Skip to main content

prax_sqlx/
engine.rs

1//! SQLx query engine implementation.
2
3use crate::config::{DatabaseBackend, SqlxConfig};
4use crate::error::SqlxResult;
5use crate::pool::SqlxPool;
6use crate::row::SqlxRow;
7use crate::types::quote_identifier;
8use prax_query::QueryResult;
9use prax_query::filter::FilterValue;
10use prax_query::traits::{BoxFuture, Model, QueryEngine};
11use sqlx::Row;
12use std::sync::Arc;
13use tracing::debug;
14
15/// SQLx-based query engine for Prax.
16///
17/// This engine provides compile-time checked queries through SQLx,
18/// supporting PostgreSQL, MySQL, and SQLite.
19///
20/// # Example
21///
22/// ```rust,ignore
23/// use prax_sqlx::{SqlxEngine, SqlxConfig};
24///
25/// let config = SqlxConfig::from_url("postgres://localhost/mydb")?;
26/// let engine = SqlxEngine::new(config).await?;
27///
28/// // Execute queries
29/// let count = engine.count_table("users", None).await?;
30/// ```
31#[derive(Clone)]
32pub struct SqlxEngine {
33    pool: Arc<SqlxPool>,
34    backend: DatabaseBackend,
35}
36
37impl SqlxEngine {
38    /// Create a new SQLx engine from configuration.
39    pub async fn new(config: SqlxConfig) -> SqlxResult<Self> {
40        let backend = config.backend;
41        let pool = SqlxPool::connect(&config).await?;
42        Ok(Self {
43            pool: Arc::new(pool),
44            backend,
45        })
46    }
47
48    /// Create a new engine from an existing pool.
49    pub fn from_pool(pool: SqlxPool) -> Self {
50        let backend = pool.backend();
51        Self {
52            pool: Arc::new(pool),
53            backend,
54        }
55    }
56
57    /// Get the database backend type.
58    pub fn backend(&self) -> DatabaseBackend {
59        self.backend
60    }
61
62    /// Get the connection pool.
63    pub fn pool(&self) -> &SqlxPool {
64        &self.pool
65    }
66
67    /// Close the engine and all connections.
68    pub async fn close(&self) {
69        self.pool.close().await;
70    }
71
72    // ==================== Low-Level Query Methods ====================
73
74    /// Execute a raw SQL query and return multiple rows.
75    pub async fn raw_query_many(
76        &self,
77        sql: &str,
78        params: &[FilterValue],
79    ) -> SqlxResult<Vec<SqlxRow>> {
80        debug!(sql = %sql, "Executing raw_query_many");
81
82        match &*self.pool {
83            #[cfg(feature = "postgres")]
84            SqlxPool::Postgres(pool) => {
85                let mut query = sqlx::query(sql);
86                for param in params {
87                    query = bind_pg_param(query, param);
88                }
89                let rows = query.fetch_all(pool).await?;
90                Ok(rows.into_iter().map(SqlxRow::Postgres).collect())
91            }
92            #[cfg(feature = "mysql")]
93            SqlxPool::MySql(pool) => {
94                let mut query = sqlx::query(sql);
95                for param in params {
96                    query = bind_mysql_param(query, param);
97                }
98                let rows = query.fetch_all(pool).await?;
99                Ok(rows.into_iter().map(SqlxRow::MySql).collect())
100            }
101            #[cfg(feature = "sqlite")]
102            SqlxPool::Sqlite(pool) => {
103                let mut query = sqlx::query(sql);
104                for param in params {
105                    query = bind_sqlite_param(query, param);
106                }
107                let rows = query.fetch_all(pool).await?;
108                Ok(rows.into_iter().map(SqlxRow::Sqlite).collect())
109            }
110        }
111    }
112
113    /// Execute a raw SQL query and return a single row.
114    pub async fn raw_query_one(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<SqlxRow> {
115        debug!(sql = %sql, "Executing raw_query_one");
116
117        match &*self.pool {
118            #[cfg(feature = "postgres")]
119            SqlxPool::Postgres(pool) => {
120                let mut query = sqlx::query(sql);
121                for param in params {
122                    query = bind_pg_param(query, param);
123                }
124                let row = query.fetch_one(pool).await?;
125                Ok(SqlxRow::Postgres(row))
126            }
127            #[cfg(feature = "mysql")]
128            SqlxPool::MySql(pool) => {
129                let mut query = sqlx::query(sql);
130                for param in params {
131                    query = bind_mysql_param(query, param);
132                }
133                let row = query.fetch_one(pool).await?;
134                Ok(SqlxRow::MySql(row))
135            }
136            #[cfg(feature = "sqlite")]
137            SqlxPool::Sqlite(pool) => {
138                let mut query = sqlx::query(sql);
139                for param in params {
140                    query = bind_sqlite_param(query, param);
141                }
142                let row = query.fetch_one(pool).await?;
143                Ok(SqlxRow::Sqlite(row))
144            }
145        }
146    }
147
148    /// Execute a raw SQL query and return an optional row.
149    pub async fn raw_query_optional(
150        &self,
151        sql: &str,
152        params: &[FilterValue],
153    ) -> SqlxResult<Option<SqlxRow>> {
154        debug!(sql = %sql, "Executing raw_query_optional");
155
156        match &*self.pool {
157            #[cfg(feature = "postgres")]
158            SqlxPool::Postgres(pool) => {
159                let mut query = sqlx::query(sql);
160                for param in params {
161                    query = bind_pg_param(query, param);
162                }
163                let row = query.fetch_optional(pool).await?;
164                Ok(row.map(SqlxRow::Postgres))
165            }
166            #[cfg(feature = "mysql")]
167            SqlxPool::MySql(pool) => {
168                let mut query = sqlx::query(sql);
169                for param in params {
170                    query = bind_mysql_param(query, param);
171                }
172                let row = query.fetch_optional(pool).await?;
173                Ok(row.map(SqlxRow::MySql))
174            }
175            #[cfg(feature = "sqlite")]
176            SqlxPool::Sqlite(pool) => {
177                let mut query = sqlx::query(sql);
178                for param in params {
179                    query = bind_sqlite_param(query, param);
180                }
181                let row = query.fetch_optional(pool).await?;
182                Ok(row.map(SqlxRow::Sqlite))
183            }
184        }
185    }
186
187    /// Execute a SQL statement (INSERT, UPDATE, DELETE) and return affected rows.
188    pub async fn raw_execute(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<u64> {
189        debug!(sql = %sql, "Executing raw_execute");
190
191        match &*self.pool {
192            #[cfg(feature = "postgres")]
193            SqlxPool::Postgres(pool) => {
194                let mut query = sqlx::query(sql);
195                for param in params {
196                    query = bind_pg_param(query, param);
197                }
198                let result = query.execute(pool).await?;
199                Ok(result.rows_affected())
200            }
201            #[cfg(feature = "mysql")]
202            SqlxPool::MySql(pool) => {
203                let mut query = sqlx::query(sql);
204                for param in params {
205                    query = bind_mysql_param(query, param);
206                }
207                let result = query.execute(pool).await?;
208                Ok(result.rows_affected())
209            }
210            #[cfg(feature = "sqlite")]
211            SqlxPool::Sqlite(pool) => {
212                let mut query = sqlx::query(sql);
213                for param in params {
214                    query = bind_sqlite_param(query, param);
215                }
216                let result = query.execute(pool).await?;
217                Ok(result.rows_affected())
218            }
219        }
220    }
221
222    /// Count rows in a table with optional filter.
223    pub async fn count_table(&self, table: &str, filter: Option<&str>) -> SqlxResult<u64> {
224        let table = quote_identifier(self.backend, table);
225        let sql = match filter {
226            Some(f) => format!("SELECT COUNT(*) as count FROM {} WHERE {}", table, f),
227            None => format!("SELECT COUNT(*) as count FROM {}", table),
228        };
229
230        let row = self.raw_query_one(&sql, &[]).await?;
231        match row {
232            #[cfg(feature = "postgres")]
233            SqlxRow::Postgres(r) => Ok(r.try_get::<i64, _>("count")? as u64),
234            #[cfg(feature = "mysql")]
235            SqlxRow::MySql(r) => Ok(r.try_get::<i64, _>("count")? as u64),
236            #[cfg(feature = "sqlite")]
237            SqlxRow::Sqlite(r) => Ok(r.try_get::<i64, _>("count")? as u64),
238        }
239    }
240}
241
242// ==================== Parameter Binding Helpers ====================
243
/// Bind one [`FilterValue`] as the next positional parameter of a
/// PostgreSQL query and return the extended query.
///
/// Scalars are bound with their native Rust type, `Json` is bound as a
/// cloned `serde_json::Value`, and `List` is serialized to a JSON value.
#[cfg(feature = "postgres")]
fn bind_pg_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
    match value {
        // NULL is bound as an absent `String`.
        // NOTE(review): a text-typed NULL may not suit non-text parameters
        // on PostgreSQL - confirm against the callers.
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.clone()),
        // A list that cannot be serialized is bound as JSON null instead of
        // failing, because this helper has no error channel.
        FilterValue::List(items) => {
            query.bind(serde_json::to_value(items).unwrap_or(serde_json::Value::Null))
        }
    }
}
263
/// Bind one [`FilterValue`] as the next positional parameter of a MySQL
/// query and return the extended query.
///
/// Scalars are bound with their native Rust type; `Json` and `List` are
/// bound as their JSON text representation.
#[cfg(feature = "mysql")]
fn bind_mysql_param<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    match value {
        // NULL is bound as an absent `String`.
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.to_string()),
        // A list that cannot be serialized is bound as an empty string
        // instead of failing, because this helper has no error channel.
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
282
/// Bind one [`FilterValue`] as the next positional parameter of a SQLite
/// query and return the extended query.
///
/// Scalars are bound with their native Rust type; `Json` and `List` are
/// bound as their JSON text representation.
#[cfg(feature = "sqlite")]
fn bind_sqlite_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    match value {
        // NULL is bound as an absent `String`.
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.to_string()),
        // A list that cannot be serialized is bound as an empty string
        // instead of failing, because this helper has no error channel.
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
301
302// ==================== QueryEngine Trait Implementation ====================
303
304impl QueryEngine for SqlxEngine {
305    fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
306        match self.backend {
307            DatabaseBackend::Postgres => &prax_query::dialect::Postgres,
308            DatabaseBackend::MySql => &prax_query::dialect::Mysql,
309            DatabaseBackend::Sqlite => &prax_query::dialect::Sqlite,
310        }
311    }
312
313    fn query_many<T: Model + prax_query::row::FromRow + Send + 'static>(
314        &self,
315        sql: &str,
316        params: Vec<FilterValue>,
317    ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
318        let sql = sql.to_string();
319        Box::pin(async move {
320            debug!(sql = %sql, "Executing query_many via QueryEngine");
321
322            let rows = self
323                .raw_query_many(&sql, &params)
324                .await
325                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
326
327            rows.iter()
328                .map(|r| {
329                    let rr = crate::row_ref::SqlxRowRef::from_sqlx(r).map_err(|e| {
330                        let msg = e.to_string();
331                        prax_query::QueryError::deserialization(msg).with_source(e)
332                    })?;
333                    T::from_row(&rr).map_err(|e| {
334                        let msg = e.to_string();
335                        prax_query::QueryError::deserialization(msg).with_source(e)
336                    })
337                })
338                .collect()
339        })
340    }
341
342    fn query_one<T: Model + prax_query::row::FromRow + Send + 'static>(
343        &self,
344        sql: &str,
345        params: Vec<FilterValue>,
346    ) -> BoxFuture<'_, QueryResult<T>> {
347        let sql = sql.to_string();
348        Box::pin(async move {
349            debug!(sql = %sql, "Executing query_one via QueryEngine");
350
351            let row = self.raw_query_one(&sql, &params).await.map_err(|e| {
352                let msg = e.to_string();
353                if msg.contains("no rows") {
354                    prax_query::QueryError::not_found(T::MODEL_NAME)
355                } else {
356                    prax_query::QueryError::database(msg)
357                }
358            })?;
359
360            let rr = crate::row_ref::SqlxRowRef::from_sqlx(&row).map_err(|e| {
361                let msg = e.to_string();
362                prax_query::QueryError::deserialization(msg).with_source(e)
363            })?;
364            T::from_row(&rr).map_err(|e| {
365                let msg = e.to_string();
366                prax_query::QueryError::deserialization(msg).with_source(e)
367            })
368        })
369    }
370
371    fn query_optional<T: Model + prax_query::row::FromRow + Send + 'static>(
372        &self,
373        sql: &str,
374        params: Vec<FilterValue>,
375    ) -> BoxFuture<'_, QueryResult<Option<T>>> {
376        let sql = sql.to_string();
377        Box::pin(async move {
378            debug!(sql = %sql, "Executing query_optional via QueryEngine");
379
380            let row = self
381                .raw_query_optional(&sql, &params)
382                .await
383                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
384
385            match row {
386                Some(r) => {
387                    let rr = crate::row_ref::SqlxRowRef::from_sqlx(&r).map_err(|e| {
388                        let msg = e.to_string();
389                        prax_query::QueryError::deserialization(msg).with_source(e)
390                    })?;
391                    T::from_row(&rr).map(Some).map_err(|e| {
392                        let msg = e.to_string();
393                        prax_query::QueryError::deserialization(msg).with_source(e)
394                    })
395                }
396                None => Ok(None),
397            }
398        })
399    }
400
401    fn execute_insert<T: Model + prax_query::row::FromRow + Send + 'static>(
402        &self,
403        sql: &str,
404        params: Vec<FilterValue>,
405    ) -> BoxFuture<'_, QueryResult<T>> {
406        let sql = sql.to_string();
407        Box::pin(async move {
408            debug!(sql = %sql, "Executing execute_insert via QueryEngine");
409
410            let row = self
411                .raw_query_one(&sql, &params)
412                .await
413                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
414
415            let rr = crate::row_ref::SqlxRowRef::from_sqlx(&row).map_err(|e| {
416                let msg = e.to_string();
417                prax_query::QueryError::deserialization(msg).with_source(e)
418            })?;
419            T::from_row(&rr).map_err(|e| {
420                let msg = e.to_string();
421                prax_query::QueryError::deserialization(msg).with_source(e)
422            })
423        })
424    }
425
426    fn execute_update<T: Model + prax_query::row::FromRow + Send + 'static>(
427        &self,
428        sql: &str,
429        params: Vec<FilterValue>,
430    ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
431        let sql = sql.to_string();
432        Box::pin(async move {
433            debug!(sql = %sql, "Executing execute_update via QueryEngine");
434
435            let rows = self
436                .raw_query_many(&sql, &params)
437                .await
438                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
439
440            rows.iter()
441                .map(|r| {
442                    let rr = crate::row_ref::SqlxRowRef::from_sqlx(r).map_err(|e| {
443                        let msg = e.to_string();
444                        prax_query::QueryError::deserialization(msg).with_source(e)
445                    })?;
446                    T::from_row(&rr).map_err(|e| {
447                        let msg = e.to_string();
448                        prax_query::QueryError::deserialization(msg).with_source(e)
449                    })
450                })
451                .collect()
452        })
453    }
454
455    fn execute_delete(
456        &self,
457        sql: &str,
458        params: Vec<FilterValue>,
459    ) -> BoxFuture<'_, QueryResult<u64>> {
460        let sql = sql.to_string();
461        Box::pin(async move {
462            debug!(sql = %sql, "Executing execute_delete via QueryEngine");
463
464            let affected = self
465                .raw_execute(&sql, &params)
466                .await
467                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
468
469            Ok(affected)
470        })
471    }
472
473    fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
474        let sql = sql.to_string();
475        Box::pin(async move {
476            debug!(sql = %sql, "Executing execute_raw via QueryEngine");
477
478            let affected = self
479                .raw_execute(&sql, &params)
480                .await
481                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
482
483            Ok(affected)
484        })
485    }
486
487    fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
488        let sql = sql.to_string();
489        Box::pin(async move {
490            debug!(sql = %sql, "Executing count via QueryEngine");
491
492            let row = self
493                .raw_query_one(&sql, &params)
494                .await
495                .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
496
497            let count = match row {
498                #[cfg(feature = "postgres")]
499                SqlxRow::Postgres(r) => r
500                    .try_get::<i64, _>(0)
501                    .map_err(|e| prax_query::QueryError::database(e.to_string()))?
502                    as u64,
503                #[cfg(feature = "mysql")]
504                SqlxRow::MySql(r) => r
505                    .try_get::<i64, _>(0)
506                    .map_err(|e| prax_query::QueryError::database(e.to_string()))?
507                    as u64,
508                #[cfg(feature = "sqlite")]
509                SqlxRow::Sqlite(r) => r
510                    .try_get::<i64, _>(0)
511                    .map_err(|e| prax_query::QueryError::database(e.to_string()))?
512                    as u64,
513            };
514
515            Ok(count)
516        })
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::placeholder;

    /// The first positional placeholder must use each backend's syntax.
    #[test]
    fn test_placeholder_generation() {
        let expectations = [
            (DatabaseBackend::Postgres, "$1"),
            (DatabaseBackend::MySql, "?"),
            (DatabaseBackend::Sqlite, "?"),
        ];

        for (backend, expected) in expectations {
            assert_eq!(placeholder(backend, 1), expected);
        }
    }

    /// Identifiers must be wrapped in each backend's quoting characters.
    #[test]
    fn test_quote_identifier() {
        let expectations = [
            (DatabaseBackend::Postgres, "\"users\""),
            (DatabaseBackend::MySql, "`users`"),
        ];

        for (backend, expected) in expectations {
            assert_eq!(quote_identifier(backend, "users"), expected);
        }
    }
}