1use crate::config::{DatabaseBackend, SqlxConfig};
4use crate::error::SqlxResult;
5use crate::pool::SqlxPool;
6use crate::row::SqlxRow;
7use crate::types::quote_identifier;
8use prax_query::QueryResult;
9use prax_query::filter::FilterValue;
10use prax_query::traits::{BoxFuture, Model, QueryEngine};
11use sqlx::Row;
12use std::sync::Arc;
13use tracing::debug;
14
15#[derive(Clone)]
32pub struct SqlxEngine {
33 pool: Arc<SqlxPool>,
34 backend: DatabaseBackend,
35}
36
37impl SqlxEngine {
38 pub async fn new(config: SqlxConfig) -> SqlxResult<Self> {
40 let backend = config.backend;
41 let pool = SqlxPool::connect(&config).await?;
42 Ok(Self {
43 pool: Arc::new(pool),
44 backend,
45 })
46 }
47
48 pub fn from_pool(pool: SqlxPool) -> Self {
50 let backend = pool.backend();
51 Self {
52 pool: Arc::new(pool),
53 backend,
54 }
55 }
56
57 pub fn backend(&self) -> DatabaseBackend {
59 self.backend
60 }
61
62 pub fn pool(&self) -> &SqlxPool {
64 &self.pool
65 }
66
67 pub async fn close(&self) {
69 self.pool.close().await;
70 }
71
72 pub async fn raw_query_many(
76 &self,
77 sql: &str,
78 params: &[FilterValue],
79 ) -> SqlxResult<Vec<SqlxRow>> {
80 debug!(sql = %sql, "Executing raw_query_many");
81
82 match &*self.pool {
83 #[cfg(feature = "postgres")]
84 SqlxPool::Postgres(pool) => {
85 let mut query = sqlx::query(sql);
86 for param in params {
87 query = bind_pg_param(query, param);
88 }
89 let rows = query.fetch_all(pool).await?;
90 Ok(rows.into_iter().map(SqlxRow::Postgres).collect())
91 }
92 #[cfg(feature = "mysql")]
93 SqlxPool::MySql(pool) => {
94 let mut query = sqlx::query(sql);
95 for param in params {
96 query = bind_mysql_param(query, param);
97 }
98 let rows = query.fetch_all(pool).await?;
99 Ok(rows.into_iter().map(SqlxRow::MySql).collect())
100 }
101 #[cfg(feature = "sqlite")]
102 SqlxPool::Sqlite(pool) => {
103 let mut query = sqlx::query(sql);
104 for param in params {
105 query = bind_sqlite_param(query, param);
106 }
107 let rows = query.fetch_all(pool).await?;
108 Ok(rows.into_iter().map(SqlxRow::Sqlite).collect())
109 }
110 }
111 }
112
113 pub async fn raw_query_one(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<SqlxRow> {
115 debug!(sql = %sql, "Executing raw_query_one");
116
117 match &*self.pool {
118 #[cfg(feature = "postgres")]
119 SqlxPool::Postgres(pool) => {
120 let mut query = sqlx::query(sql);
121 for param in params {
122 query = bind_pg_param(query, param);
123 }
124 let row = query.fetch_one(pool).await?;
125 Ok(SqlxRow::Postgres(row))
126 }
127 #[cfg(feature = "mysql")]
128 SqlxPool::MySql(pool) => {
129 let mut query = sqlx::query(sql);
130 for param in params {
131 query = bind_mysql_param(query, param);
132 }
133 let row = query.fetch_one(pool).await?;
134 Ok(SqlxRow::MySql(row))
135 }
136 #[cfg(feature = "sqlite")]
137 SqlxPool::Sqlite(pool) => {
138 let mut query = sqlx::query(sql);
139 for param in params {
140 query = bind_sqlite_param(query, param);
141 }
142 let row = query.fetch_one(pool).await?;
143 Ok(SqlxRow::Sqlite(row))
144 }
145 }
146 }
147
148 pub async fn raw_query_optional(
150 &self,
151 sql: &str,
152 params: &[FilterValue],
153 ) -> SqlxResult<Option<SqlxRow>> {
154 debug!(sql = %sql, "Executing raw_query_optional");
155
156 match &*self.pool {
157 #[cfg(feature = "postgres")]
158 SqlxPool::Postgres(pool) => {
159 let mut query = sqlx::query(sql);
160 for param in params {
161 query = bind_pg_param(query, param);
162 }
163 let row = query.fetch_optional(pool).await?;
164 Ok(row.map(SqlxRow::Postgres))
165 }
166 #[cfg(feature = "mysql")]
167 SqlxPool::MySql(pool) => {
168 let mut query = sqlx::query(sql);
169 for param in params {
170 query = bind_mysql_param(query, param);
171 }
172 let row = query.fetch_optional(pool).await?;
173 Ok(row.map(SqlxRow::MySql))
174 }
175 #[cfg(feature = "sqlite")]
176 SqlxPool::Sqlite(pool) => {
177 let mut query = sqlx::query(sql);
178 for param in params {
179 query = bind_sqlite_param(query, param);
180 }
181 let row = query.fetch_optional(pool).await?;
182 Ok(row.map(SqlxRow::Sqlite))
183 }
184 }
185 }
186
187 pub async fn raw_execute(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<u64> {
189 debug!(sql = %sql, "Executing raw_execute");
190
191 match &*self.pool {
192 #[cfg(feature = "postgres")]
193 SqlxPool::Postgres(pool) => {
194 let mut query = sqlx::query(sql);
195 for param in params {
196 query = bind_pg_param(query, param);
197 }
198 let result = query.execute(pool).await?;
199 Ok(result.rows_affected())
200 }
201 #[cfg(feature = "mysql")]
202 SqlxPool::MySql(pool) => {
203 let mut query = sqlx::query(sql);
204 for param in params {
205 query = bind_mysql_param(query, param);
206 }
207 let result = query.execute(pool).await?;
208 Ok(result.rows_affected())
209 }
210 #[cfg(feature = "sqlite")]
211 SqlxPool::Sqlite(pool) => {
212 let mut query = sqlx::query(sql);
213 for param in params {
214 query = bind_sqlite_param(query, param);
215 }
216 let result = query.execute(pool).await?;
217 Ok(result.rows_affected())
218 }
219 }
220 }
221
222 pub async fn count_table(&self, table: &str, filter: Option<&str>) -> SqlxResult<u64> {
224 let table = quote_identifier(self.backend, table);
225 let sql = match filter {
226 Some(f) => format!("SELECT COUNT(*) as count FROM {} WHERE {}", table, f),
227 None => format!("SELECT COUNT(*) as count FROM {}", table),
228 };
229
230 let row = self.raw_query_one(&sql, &[]).await?;
231 match row {
232 #[cfg(feature = "postgres")]
233 SqlxRow::Postgres(r) => Ok(r.try_get::<i64, _>("count")? as u64),
234 #[cfg(feature = "mysql")]
235 SqlxRow::MySql(r) => Ok(r.try_get::<i64, _>("count")? as u64),
236 #[cfg(feature = "sqlite")]
237 SqlxRow::Sqlite(r) => Ok(r.try_get::<i64, _>("count")? as u64),
238 }
239 }
240}
241
/// Binds one [`FilterValue`] to a Postgres query as its next positional parameter.
///
/// `Json` values are bound as native JSON. They are cloned because `bind` takes
/// ownership. `List` values are converted to a JSON value; if serialization fails,
/// JSON `null` is bound instead.
// NOTE(review): `Null` is bound as a TEXT-typed NULL. Postgres may reject it where the
// target expression is not text-compatible; confirm against the callers' SQL.
#[cfg(feature = "postgres")]
fn bind_pg_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
    match value {
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.clone()),
        FilterValue::List(items) => {
            query.bind(serde_json::to_value(items).unwrap_or(serde_json::Value::Null))
        }
    }
}
263
/// Binds one [`FilterValue`] to a MySQL query as its next positional parameter.
///
/// `Json` and `List` values are sent as JSON text. If serializing a list fails, an
/// empty string is bound instead.
// NOTE(review): an empty string is not valid JSON; confirm that this fallback is intended.
#[cfg(feature = "mysql")]
fn bind_mysql_param<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    match value {
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.to_string()),
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
282
/// Binds one [`FilterValue`] to a SQLite query as its next positional parameter.
///
/// `Json` and `List` values are sent as JSON text. If serializing a list fails, an
/// empty string is bound instead.
// NOTE(review): an empty string is not valid JSON; confirm that this fallback is intended.
#[cfg(feature = "sqlite")]
fn bind_sqlite_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    match value {
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.to_string()),
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
301
302impl QueryEngine for SqlxEngine {
305 fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
306 match self.backend {
307 DatabaseBackend::Postgres => &prax_query::dialect::Postgres,
308 DatabaseBackend::MySql => &prax_query::dialect::Mysql,
309 DatabaseBackend::Sqlite => &prax_query::dialect::Sqlite,
310 }
311 }
312
313 fn query_many<T: Model + prax_query::row::FromRow + Send + 'static>(
314 &self,
315 sql: &str,
316 params: Vec<FilterValue>,
317 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
318 let sql = sql.to_string();
319 Box::pin(async move {
320 debug!(sql = %sql, "Executing query_many via QueryEngine");
321
322 let rows = self
323 .raw_query_many(&sql, ¶ms)
324 .await
325 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
326
327 rows.iter()
328 .map(|r| {
329 let rr = crate::row_ref::SqlxRowRef::from_sqlx(r).map_err(|e| {
330 let msg = e.to_string();
331 prax_query::QueryError::deserialization(msg).with_source(e)
332 })?;
333 T::from_row(&rr).map_err(|e| {
334 let msg = e.to_string();
335 prax_query::QueryError::deserialization(msg).with_source(e)
336 })
337 })
338 .collect()
339 })
340 }
341
342 fn query_one<T: Model + prax_query::row::FromRow + Send + 'static>(
343 &self,
344 sql: &str,
345 params: Vec<FilterValue>,
346 ) -> BoxFuture<'_, QueryResult<T>> {
347 let sql = sql.to_string();
348 Box::pin(async move {
349 debug!(sql = %sql, "Executing query_one via QueryEngine");
350
351 let row = self.raw_query_one(&sql, ¶ms).await.map_err(|e| {
352 let msg = e.to_string();
353 if msg.contains("no rows") {
354 prax_query::QueryError::not_found(T::MODEL_NAME)
355 } else {
356 prax_query::QueryError::database(msg)
357 }
358 })?;
359
360 let rr = crate::row_ref::SqlxRowRef::from_sqlx(&row).map_err(|e| {
361 let msg = e.to_string();
362 prax_query::QueryError::deserialization(msg).with_source(e)
363 })?;
364 T::from_row(&rr).map_err(|e| {
365 let msg = e.to_string();
366 prax_query::QueryError::deserialization(msg).with_source(e)
367 })
368 })
369 }
370
371 fn query_optional<T: Model + prax_query::row::FromRow + Send + 'static>(
372 &self,
373 sql: &str,
374 params: Vec<FilterValue>,
375 ) -> BoxFuture<'_, QueryResult<Option<T>>> {
376 let sql = sql.to_string();
377 Box::pin(async move {
378 debug!(sql = %sql, "Executing query_optional via QueryEngine");
379
380 let row = self
381 .raw_query_optional(&sql, ¶ms)
382 .await
383 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
384
385 match row {
386 Some(r) => {
387 let rr = crate::row_ref::SqlxRowRef::from_sqlx(&r).map_err(|e| {
388 let msg = e.to_string();
389 prax_query::QueryError::deserialization(msg).with_source(e)
390 })?;
391 T::from_row(&rr).map(Some).map_err(|e| {
392 let msg = e.to_string();
393 prax_query::QueryError::deserialization(msg).with_source(e)
394 })
395 }
396 None => Ok(None),
397 }
398 })
399 }
400
401 fn execute_insert<T: Model + prax_query::row::FromRow + Send + 'static>(
402 &self,
403 sql: &str,
404 params: Vec<FilterValue>,
405 ) -> BoxFuture<'_, QueryResult<T>> {
406 let sql = sql.to_string();
407 Box::pin(async move {
408 debug!(sql = %sql, "Executing execute_insert via QueryEngine");
409
410 let row = self
411 .raw_query_one(&sql, ¶ms)
412 .await
413 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
414
415 let rr = crate::row_ref::SqlxRowRef::from_sqlx(&row).map_err(|e| {
416 let msg = e.to_string();
417 prax_query::QueryError::deserialization(msg).with_source(e)
418 })?;
419 T::from_row(&rr).map_err(|e| {
420 let msg = e.to_string();
421 prax_query::QueryError::deserialization(msg).with_source(e)
422 })
423 })
424 }
425
426 fn execute_update<T: Model + prax_query::row::FromRow + Send + 'static>(
427 &self,
428 sql: &str,
429 params: Vec<FilterValue>,
430 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
431 let sql = sql.to_string();
432 Box::pin(async move {
433 debug!(sql = %sql, "Executing execute_update via QueryEngine");
434
435 let rows = self
436 .raw_query_many(&sql, ¶ms)
437 .await
438 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
439
440 rows.iter()
441 .map(|r| {
442 let rr = crate::row_ref::SqlxRowRef::from_sqlx(r).map_err(|e| {
443 let msg = e.to_string();
444 prax_query::QueryError::deserialization(msg).with_source(e)
445 })?;
446 T::from_row(&rr).map_err(|e| {
447 let msg = e.to_string();
448 prax_query::QueryError::deserialization(msg).with_source(e)
449 })
450 })
451 .collect()
452 })
453 }
454
455 fn execute_delete(
456 &self,
457 sql: &str,
458 params: Vec<FilterValue>,
459 ) -> BoxFuture<'_, QueryResult<u64>> {
460 let sql = sql.to_string();
461 Box::pin(async move {
462 debug!(sql = %sql, "Executing execute_delete via QueryEngine");
463
464 let affected = self
465 .raw_execute(&sql, ¶ms)
466 .await
467 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
468
469 Ok(affected)
470 })
471 }
472
473 fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
474 let sql = sql.to_string();
475 Box::pin(async move {
476 debug!(sql = %sql, "Executing execute_raw via QueryEngine");
477
478 let affected = self
479 .raw_execute(&sql, ¶ms)
480 .await
481 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
482
483 Ok(affected)
484 })
485 }
486
487 fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
488 let sql = sql.to_string();
489 Box::pin(async move {
490 debug!(sql = %sql, "Executing count via QueryEngine");
491
492 let row = self
493 .raw_query_one(&sql, ¶ms)
494 .await
495 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
496
497 let count = match row {
498 #[cfg(feature = "postgres")]
499 SqlxRow::Postgres(r) => r
500 .try_get::<i64, _>(0)
501 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
502 as u64,
503 #[cfg(feature = "mysql")]
504 SqlxRow::MySql(r) => r
505 .try_get::<i64, _>(0)
506 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
507 as u64,
508 #[cfg(feature = "sqlite")]
509 SqlxRow::Sqlite(r) => r
510 .try_get::<i64, _>(0)
511 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
512 as u64,
513 };
514
515 Ok(count)
516 })
517 }
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::placeholder;

    /// Each backend's first positional placeholder uses its native syntax.
    #[test]
    fn test_placeholder_generation() {
        let cases = [
            (DatabaseBackend::Postgres, "$1"),
            (DatabaseBackend::MySql, "?"),
            (DatabaseBackend::Sqlite, "?"),
        ];
        for (backend, expected) in cases {
            assert_eq!(placeholder(backend, 1), expected);
        }
    }

    /// Identifiers are wrapped in the backend's quoting characters.
    #[test]
    fn test_quote_identifier() {
        let cases = [
            (DatabaseBackend::Postgres, "\"users\""),
            (DatabaseBackend::MySql, "`users`"),
        ];
        for (backend, expected) in cases {
            assert_eq!(quote_identifier(backend, "users"), expected);
        }
    }
}