1use crate::config::{DatabaseBackend, SqlxConfig};
4use crate::error::SqlxResult;
5use crate::pool::SqlxPool;
6use crate::row::SqlxRow;
7use crate::types::quote_identifier;
8use prax_query::QueryResult;
9use prax_query::filter::FilterValue;
10use prax_query::traits::{BoxFuture, Model, QueryEngine};
11use sqlx::Row;
12use std::sync::Arc;
13use tracing::debug;
14
15#[derive(Clone)]
32pub struct SqlxEngine {
33 pool: Arc<SqlxPool>,
34 backend: DatabaseBackend,
35}
36
37impl SqlxEngine {
38 pub async fn new(config: SqlxConfig) -> SqlxResult<Self> {
40 let backend = config.backend;
41 let pool = SqlxPool::connect(&config).await?;
42 Ok(Self {
43 pool: Arc::new(pool),
44 backend,
45 })
46 }
47
48 pub fn from_pool(pool: SqlxPool) -> Self {
50 let backend = pool.backend();
51 Self {
52 pool: Arc::new(pool),
53 backend,
54 }
55 }
56
57 pub fn backend(&self) -> DatabaseBackend {
59 self.backend
60 }
61
62 pub fn pool(&self) -> &SqlxPool {
64 &self.pool
65 }
66
67 pub async fn close(&self) {
69 self.pool.close().await;
70 }
71
72 pub async fn raw_query_many(
76 &self,
77 sql: &str,
78 params: &[FilterValue],
79 ) -> SqlxResult<Vec<SqlxRow>> {
80 debug!(sql = %sql, "Executing raw_query_many");
81
82 match &*self.pool {
83 #[cfg(feature = "postgres")]
84 SqlxPool::Postgres(pool) => {
85 let mut query = sqlx::query(sql);
86 for param in params {
87 query = bind_pg_param(query, param);
88 }
89 let rows = query.fetch_all(pool).await?;
90 Ok(rows.into_iter().map(SqlxRow::Postgres).collect())
91 }
92 #[cfg(feature = "mysql")]
93 SqlxPool::MySql(pool) => {
94 let mut query = sqlx::query(sql);
95 for param in params {
96 query = bind_mysql_param(query, param);
97 }
98 let rows = query.fetch_all(pool).await?;
99 Ok(rows.into_iter().map(SqlxRow::MySql).collect())
100 }
101 #[cfg(feature = "sqlite")]
102 SqlxPool::Sqlite(pool) => {
103 let mut query = sqlx::query(sql);
104 for param in params {
105 query = bind_sqlite_param(query, param);
106 }
107 let rows = query.fetch_all(pool).await?;
108 Ok(rows.into_iter().map(SqlxRow::Sqlite).collect())
109 }
110 }
111 }
112
113 pub async fn raw_query_one(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<SqlxRow> {
115 debug!(sql = %sql, "Executing raw_query_one");
116
117 match &*self.pool {
118 #[cfg(feature = "postgres")]
119 SqlxPool::Postgres(pool) => {
120 let mut query = sqlx::query(sql);
121 for param in params {
122 query = bind_pg_param(query, param);
123 }
124 let row = query.fetch_one(pool).await?;
125 Ok(SqlxRow::Postgres(row))
126 }
127 #[cfg(feature = "mysql")]
128 SqlxPool::MySql(pool) => {
129 let mut query = sqlx::query(sql);
130 for param in params {
131 query = bind_mysql_param(query, param);
132 }
133 let row = query.fetch_one(pool).await?;
134 Ok(SqlxRow::MySql(row))
135 }
136 #[cfg(feature = "sqlite")]
137 SqlxPool::Sqlite(pool) => {
138 let mut query = sqlx::query(sql);
139 for param in params {
140 query = bind_sqlite_param(query, param);
141 }
142 let row = query.fetch_one(pool).await?;
143 Ok(SqlxRow::Sqlite(row))
144 }
145 }
146 }
147
148 pub async fn raw_query_optional(
150 &self,
151 sql: &str,
152 params: &[FilterValue],
153 ) -> SqlxResult<Option<SqlxRow>> {
154 debug!(sql = %sql, "Executing raw_query_optional");
155
156 match &*self.pool {
157 #[cfg(feature = "postgres")]
158 SqlxPool::Postgres(pool) => {
159 let mut query = sqlx::query(sql);
160 for param in params {
161 query = bind_pg_param(query, param);
162 }
163 let row = query.fetch_optional(pool).await?;
164 Ok(row.map(SqlxRow::Postgres))
165 }
166 #[cfg(feature = "mysql")]
167 SqlxPool::MySql(pool) => {
168 let mut query = sqlx::query(sql);
169 for param in params {
170 query = bind_mysql_param(query, param);
171 }
172 let row = query.fetch_optional(pool).await?;
173 Ok(row.map(SqlxRow::MySql))
174 }
175 #[cfg(feature = "sqlite")]
176 SqlxPool::Sqlite(pool) => {
177 let mut query = sqlx::query(sql);
178 for param in params {
179 query = bind_sqlite_param(query, param);
180 }
181 let row = query.fetch_optional(pool).await?;
182 Ok(row.map(SqlxRow::Sqlite))
183 }
184 }
185 }
186
187 pub async fn raw_execute(&self, sql: &str, params: &[FilterValue]) -> SqlxResult<u64> {
189 debug!(sql = %sql, "Executing raw_execute");
190
191 match &*self.pool {
192 #[cfg(feature = "postgres")]
193 SqlxPool::Postgres(pool) => {
194 let mut query = sqlx::query(sql);
195 for param in params {
196 query = bind_pg_param(query, param);
197 }
198 let result = query.execute(pool).await?;
199 Ok(result.rows_affected())
200 }
201 #[cfg(feature = "mysql")]
202 SqlxPool::MySql(pool) => {
203 let mut query = sqlx::query(sql);
204 for param in params {
205 query = bind_mysql_param(query, param);
206 }
207 let result = query.execute(pool).await?;
208 Ok(result.rows_affected())
209 }
210 #[cfg(feature = "sqlite")]
211 SqlxPool::Sqlite(pool) => {
212 let mut query = sqlx::query(sql);
213 for param in params {
214 query = bind_sqlite_param(query, param);
215 }
216 let result = query.execute(pool).await?;
217 Ok(result.rows_affected())
218 }
219 }
220 }
221
222 pub async fn count_table(&self, table: &str, filter: Option<&str>) -> SqlxResult<u64> {
224 let table = quote_identifier(self.backend, table);
225 let sql = match filter {
226 Some(f) => format!("SELECT COUNT(*) as count FROM {} WHERE {}", table, f),
227 None => format!("SELECT COUNT(*) as count FROM {}", table),
228 };
229
230 let row = self.raw_query_one(&sql, &[]).await?;
231 match row {
232 #[cfg(feature = "postgres")]
233 SqlxRow::Postgres(r) => Ok(r.try_get::<i64, _>("count")? as u64),
234 #[cfg(feature = "mysql")]
235 SqlxRow::MySql(r) => Ok(r.try_get::<i64, _>("count")? as u64),
236 #[cfg(feature = "sqlite")]
237 SqlxRow::Sqlite(r) => Ok(r.try_get::<i64, _>("count")? as u64),
238 }
239 }
240}
241
/// Binds one [`FilterValue`] as the next argument of a Postgres query.
///
/// `String`, `Int`, `Float` and `Bool` are bound by value, `Null` is bound as
/// a `None` of type `Option<String>`, a `Json` payload is cloned and bound
/// as-is, and a `List` is converted to a JSON value before binding.
#[cfg(feature = "postgres")]
fn bind_pg_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
    match value {
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.clone()),
        // A list that fails to serialise is bound as JSON null.
        FilterValue::List(items) => {
            query.bind(serde_json::to_value(items).unwrap_or(serde_json::Value::Null))
        }
    }
}
263
/// Binds one [`FilterValue`] as the next argument of a MySQL query.
///
/// `String`, `Int`, `Float` and `Bool` are bound by value, `Null` is bound as
/// a `None` of type `Option<String>`, and both `Json` and `List` are bound as
/// their JSON text.
#[cfg(feature = "mysql")]
fn bind_mysql_param<'q>(
    query: sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::MySql, sqlx::mysql::MySqlArguments> {
    match value {
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.to_string()),
        // A list that fails to serialise is bound as an empty string.
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
282
/// Binds one [`FilterValue`] as the next argument of a SQLite query.
///
/// `String`, `Int`, `Float` and `Bool` are bound by value, `Null` is bound as
/// a `None` of type `Option<String>`, and both `Json` and `List` are bound as
/// their JSON text.
#[cfg(feature = "sqlite")]
fn bind_sqlite_param<'q>(
    query: sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>>,
    value: &'q FilterValue,
) -> sqlx::query::Query<'q, sqlx::Sqlite, sqlx::sqlite::SqliteArguments<'q>> {
    match value {
        FilterValue::Null => query.bind(None::<String>),
        FilterValue::Bool(flag) => query.bind(*flag),
        FilterValue::Int(number) => query.bind(*number),
        FilterValue::Float(number) => query.bind(*number),
        FilterValue::String(text) => query.bind(text.as_str()),
        FilterValue::Json(document) => query.bind(document.to_string()),
        // A list that fails to serialise is bound as an empty string.
        FilterValue::List(items) => query.bind(serde_json::to_string(items).unwrap_or_default()),
    }
}
301
302impl QueryEngine for SqlxEngine {
305 fn query_many<T: Model + Send + 'static>(
306 &self,
307 sql: &str,
308 params: Vec<FilterValue>,
309 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
310 let sql = sql.to_string();
311 Box::pin(async move {
312 debug!(sql = %sql, "Executing query_many via QueryEngine");
313
314 let _rows = self
315 .raw_query_many(&sql, ¶ms)
316 .await
317 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
318
319 Ok(Vec::new())
322 })
323 }
324
325 fn query_one<T: Model + Send + 'static>(
326 &self,
327 sql: &str,
328 params: Vec<FilterValue>,
329 ) -> BoxFuture<'_, QueryResult<T>> {
330 let sql = sql.to_string();
331 Box::pin(async move {
332 debug!(sql = %sql, "Executing query_one via QueryEngine");
333
334 let _row = self.raw_query_one(&sql, ¶ms).await.map_err(|e| {
335 let msg = e.to_string();
336 if msg.contains("no rows") {
337 prax_query::QueryError::not_found(T::MODEL_NAME)
338 } else {
339 prax_query::QueryError::database(msg)
340 }
341 })?;
342
343 Err(prax_query::QueryError::internal(
345 "deserialization not yet implemented".to_string(),
346 ))
347 })
348 }
349
350 fn query_optional<T: Model + Send + 'static>(
351 &self,
352 sql: &str,
353 params: Vec<FilterValue>,
354 ) -> BoxFuture<'_, QueryResult<Option<T>>> {
355 let sql = sql.to_string();
356 Box::pin(async move {
357 debug!(sql = %sql, "Executing query_optional via QueryEngine");
358
359 let _row = self
360 .raw_query_optional(&sql, ¶ms)
361 .await
362 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
363
364 Ok(None)
366 })
367 }
368
369 fn execute_insert<T: Model + Send + 'static>(
370 &self,
371 sql: &str,
372 params: Vec<FilterValue>,
373 ) -> BoxFuture<'_, QueryResult<T>> {
374 let sql = sql.to_string();
375 Box::pin(async move {
376 debug!(sql = %sql, "Executing execute_insert via QueryEngine");
377
378 let _row = self
379 .raw_query_one(&sql, ¶ms)
380 .await
381 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
382
383 Err(prax_query::QueryError::internal(
385 "deserialization not yet implemented".to_string(),
386 ))
387 })
388 }
389
390 fn execute_update<T: Model + Send + 'static>(
391 &self,
392 sql: &str,
393 params: Vec<FilterValue>,
394 ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
395 let sql = sql.to_string();
396 Box::pin(async move {
397 debug!(sql = %sql, "Executing execute_update via QueryEngine");
398
399 let _rows = self
400 .raw_query_many(&sql, ¶ms)
401 .await
402 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
403
404 Ok(Vec::new())
406 })
407 }
408
409 fn execute_delete(
410 &self,
411 sql: &str,
412 params: Vec<FilterValue>,
413 ) -> BoxFuture<'_, QueryResult<u64>> {
414 let sql = sql.to_string();
415 Box::pin(async move {
416 debug!(sql = %sql, "Executing execute_delete via QueryEngine");
417
418 let affected = self
419 .raw_execute(&sql, ¶ms)
420 .await
421 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
422
423 Ok(affected)
424 })
425 }
426
427 fn execute_raw(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
428 let sql = sql.to_string();
429 Box::pin(async move {
430 debug!(sql = %sql, "Executing execute_raw via QueryEngine");
431
432 let affected = self
433 .raw_execute(&sql, ¶ms)
434 .await
435 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
436
437 Ok(affected)
438 })
439 }
440
441 fn count(&self, sql: &str, params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
442 let sql = sql.to_string();
443 Box::pin(async move {
444 debug!(sql = %sql, "Executing count via QueryEngine");
445
446 let row = self
447 .raw_query_one(&sql, ¶ms)
448 .await
449 .map_err(|e| prax_query::QueryError::database(e.to_string()))?;
450
451 let count = match row {
452 #[cfg(feature = "postgres")]
453 SqlxRow::Postgres(r) => r
454 .try_get::<i64, _>(0)
455 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
456 as u64,
457 #[cfg(feature = "mysql")]
458 SqlxRow::MySql(r) => r
459 .try_get::<i64, _>(0)
460 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
461 as u64,
462 #[cfg(feature = "sqlite")]
463 SqlxRow::Sqlite(r) => r
464 .try_get::<i64, _>(0)
465 .map_err(|e| prax_query::QueryError::database(e.to_string()))?
466 as u64,
467 };
468
469 Ok(count)
470 })
471 }
472}
473
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::placeholder;

    /// The first placeholder is `$1` on Postgres and `?` on MySQL and SQLite.
    #[test]
    fn test_placeholder_generation() {
        let expectations = [
            (DatabaseBackend::Postgres, "$1"),
            (DatabaseBackend::MySql, "?"),
            (DatabaseBackend::Sqlite, "?"),
        ];

        for (backend, expected) in expectations {
            assert_eq!(placeholder(backend, 1), expected);
        }
    }

    /// Postgres quotes identifiers with double quotes, MySQL with backticks.
    #[test]
    fn test_quote_identifier() {
        let expectations = [
            (DatabaseBackend::Postgres, "\"users\""),
            (DatabaseBackend::MySql, "`users`"),
        ];

        for (backend, expected) in expectations {
            assert_eq!(quote_identifier(backend, "users"), expected);
        }
    }
}