1pub use async_trait;
2pub use sqlx;
3
4pub mod prelude {
5 pub use crate::{Executor, IntoExecutor, Model, Premix, UpdateResult};
6}
7use sqlx::{Database, Executor as SqlxExecutor, IntoArguments};
8
/// Namespace type for crate-level helpers such as `Premix::sync`.
///
/// It carries no data; the derives make it printable, copyable and default-constructible
/// like any other unit marker type.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Premix;
10pub mod migrator;
11pub use migrator::{Migration, Migrator};
12
13pub trait SqlDialect: Database + Sized + Send + Sync
16where
17 Self::Connection: Send,
18{
19 fn placeholder(n: usize) -> String;
20 fn auto_increment_pk() -> &'static str;
21 fn rows_affected(res: &Self::QueryResult) -> u64;
22 fn last_insert_id(res: &Self::QueryResult) -> i64;
23
24 fn current_timestamp_fn() -> &'static str {
25 "CURRENT_TIMESTAMP"
26 }
27 fn int_type() -> &'static str {
28 "INTEGER"
29 }
30 fn text_type() -> &'static str {
31 "TEXT"
32 }
33 fn bool_type() -> &'static str {
34 "BOOLEAN"
35 }
36 fn float_type() -> &'static str {
37 "REAL"
38 }
39}
40
/// SQLite dialect: positional `?` placeholders and rowid-based insert ids.
#[cfg(feature = "sqlite")]
impl SqlDialect for sqlx::Sqlite {
    fn placeholder(_n: usize) -> String {
        // `?` parameters are matched by position, so the index is not encoded.
        String::from("?")
    }

    fn auto_increment_pk() -> &'static str {
        "INTEGER PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::sqlite::SqliteQueryResult) -> u64 {
        let touched = res.rows_affected();
        touched
    }

    fn last_insert_id(res: &sqlx::sqlite::SqliteQueryResult) -> i64 {
        let rowid = res.last_insert_rowid();
        rowid
    }
}
56
/// PostgreSQL dialect: numbered `$n` placeholders and `SERIAL` primary keys.
#[cfg(feature = "postgres")]
impl SqlDialect for sqlx::Postgres {
    fn placeholder(n: usize) -> String {
        format!("${n}")
    }

    fn auto_increment_pk() -> &'static str {
        "SERIAL PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::postgres::PgQueryResult) -> u64 {
        let touched = res.rows_affected();
        touched
    }

    /// Always `0`: the Postgres query result exposes no generated id.
    /// NOTE(review): callers needing the id presumably use `RETURNING`; confirm.
    fn last_insert_id(_res: &sqlx::postgres::PgQueryResult) -> i64 {
        0
    }
}
72
/// MySQL dialect: positional `?` placeholders and `AUTO_INCREMENT` primary keys.
#[cfg(feature = "mysql")]
impl SqlDialect for sqlx::MySql {
    fn placeholder(_n: usize) -> String {
        // `?` parameters are matched by position, so the index is not encoded.
        String::from("?")
    }

    fn auto_increment_pk() -> &'static str {
        "INTEGER AUTO_INCREMENT PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::mysql::MySqlQueryResult) -> u64 {
        let touched = res.rows_affected();
        touched
    }

    fn last_insert_id(res: &sqlx::mysql::MySqlQueryResult) -> i64 {
        let generated: u64 = res.last_insert_id();
        // Reinterpreting cast: an id above `i64::MAX` would wrap to a negative value.
        generated as i64
    }
}
88
/// Where a query runs: a borrowed connection pool, or one exclusively borrowed connection.
///
/// The `Conn` variant lets consecutive queries share a single connection
/// (NOTE(review): presumably the connection behind a transaction, e.g. `&mut *tx` — confirm).
pub enum Executor<'a, DB: Database> {
    /// Shared reference to a pool; queries are dispatched through sqlx's pool executor.
    Pool(&'a sqlx::Pool<DB>),
    /// Exclusive borrow of one connection.
    Conn(&'a mut DB::Connection),
}

// SAFETY: `Executor` only holds `&'a Pool<DB>` or `&'a mut DB::Connection`. Moving a
// `&mut Connection` to another thread requires `Connection: Send`, which the bound enforces.
// Moving a `&Pool` requires `Pool<DB>: Sync`, which is NOT checked here; it is assumed
// from sqlx's `Pool` being a thread-safe handle.
// NOTE(review): confirm that assumption, or consider deleting this impl and relying on the
// compiler-derived auto trait, which would check both conditions.
unsafe impl<'a, DB: Database> Send for Executor<'a, DB> where DB::Connection: Send {}
// SAFETY: sharing `&Executor` across threads exposes `&&Pool` and `& &mut Connection`
// (i.e. shared access to the connection), hence the `Connection: Sync` bound. As above,
// `Pool<DB>: Sync` is assumed rather than checked.
unsafe impl<'a, DB: Database> Sync for Executor<'a, DB> where DB::Connection: Sync {}
97
98impl<'a, DB: Database> From<&'a sqlx::Pool<DB>> for Executor<'a, DB> {
99 fn from(pool: &'a sqlx::Pool<DB>) -> Self {
100 Self::Pool(pool)
101 }
102}
103
104impl<'a, DB: Database> From<&'a mut DB::Connection> for Executor<'a, DB> {
105 fn from(conn: &'a mut DB::Connection) -> Self {
106 Self::Conn(conn)
107 }
108}
109
110pub trait IntoExecutor<'a>: Send + 'a {
111 type DB: SqlDialect;
112 fn into_executor(self) -> Executor<'a, Self::DB>;
113}
114
115impl<'a, DB: SqlDialect> IntoExecutor<'a> for &'a sqlx::Pool<DB> {
116 type DB = DB;
117 fn into_executor(self) -> Executor<'a, DB> {
118 Executor::Pool(self)
119 }
120}
121
/// A borrowed SQLite connection executes on that connection.
#[cfg(feature = "sqlite")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::SqliteConnection {
    type DB = sqlx::Sqlite;

    fn into_executor(self) -> Executor<'a, sqlx::Sqlite> {
        Executor::Conn(self)
    }
}
129
/// A borrowed PostgreSQL connection executes on that connection.
#[cfg(feature = "postgres")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::postgres::PgConnection {
    type DB = sqlx::Postgres;
    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}

/// A borrowed MySQL connection executes on that connection.
///
/// Mirrors the SQLite and PostgreSQL impls. Without it, MySQL users could only pass a pool,
/// never a single connection or transaction, to `Model` methods.
#[cfg(feature = "mysql")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::mysql::MySqlConnection {
    type DB = sqlx::MySql;
    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}
137
138impl<'a, DB: SqlDialect> IntoExecutor<'a> for Executor<'a, DB> {
139 type DB = DB;
140 fn into_executor(self) -> Executor<'a, DB> {
141 self
142 }
143}
144
145impl<'a, DB: Database> Executor<'a, DB> {
146 pub async fn execute<'q, A>(
147 &mut self,
148 query: sqlx::query::Query<'q, DB, A>,
149 ) -> Result<DB::QueryResult, sqlx::Error>
150 where
151 A: sqlx::IntoArguments<'q, DB> + 'q,
152 DB: SqlDialect,
153 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
154 {
155 match self {
156 Self::Pool(pool) => query.execute(*pool).await,
157 Self::Conn(conn) => query.execute(&mut **conn).await,
158 }
159 }
160
161 pub async fn fetch_all<'q, T, A>(
162 &mut self,
163 query: sqlx::query::QueryAs<'q, DB, T, A>,
164 ) -> Result<Vec<T>, sqlx::Error>
165 where
166 T: for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
167 A: sqlx::IntoArguments<'q, DB> + 'q,
168 DB: SqlDialect,
169 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
170 {
171 match self {
172 Self::Pool(pool) => query.fetch_all(*pool).await,
173 Self::Conn(conn) => query.fetch_all(&mut **conn).await,
174 }
175 }
176
177 pub async fn fetch_optional<'q, T, A>(
178 &mut self,
179 query: sqlx::query::QueryAs<'q, DB, T, A>,
180 ) -> Result<Option<T>, sqlx::Error>
181 where
182 T: for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
183 A: sqlx::IntoArguments<'q, DB> + 'q,
184 DB: SqlDialect,
185 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
186 {
187 match self {
188 Self::Pool(pool) => query.fetch_optional(*pool).await,
189 Self::Conn(conn) => query.fetch_optional(&mut **conn).await,
190 }
191 }
192}
193
/// Lifecycle callbacks around persisting a model; both default to a no-op returning `Ok(())`.
///
/// NOTE(review): the blanket impl below implements this trait for every `Send + Sync` type,
/// so a model cannot provide its own `impl ModelHooks` (it would overlap the blanket impl and
/// be rejected by coherence). As written, the no-op defaults are the only behavior any such
/// type can have — confirm whether that is intended.
#[async_trait::async_trait]
pub trait ModelHooks {
    /// Hook intended to run before a save (call sites are not in this file).
    async fn before_save(&mut self) -> Result<(), sqlx::Error> {
        Ok(())
    }
    /// Hook intended to run after a save (call sites are not in this file).
    async fn after_save(&mut self) -> Result<(), sqlx::Error> {
        Ok(())
    }
}

// Blanket impl: every `Send + Sync` type gets the default no-op hooks.
#[async_trait::async_trait]
impl<T: Send + Sync> ModelHooks for T {}
207
/// Outcome of `Model::update`.
///
/// The enum is fieldless, so it is `Copy` and can be compared, hashed and matched freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpdateResult {
    /// The update was applied.
    Success,
    /// NOTE(review): presumably an optimistic-locking version mismatch — confirm in the
    /// generated `update` implementation.
    VersionConflict,
    /// No matching row was found.
    NotFound,
    /// NOTE(review): presumably returned when a model does not support `update` — confirm.
    NotImplemented,
}
216
/// A single validation failure for one field of a model.
///
/// Implements `Display` and `std::error::Error`, so it can be logged or boxed into an error
/// chain like any other error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the offending field.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    /// Formats as `field: message`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

impl std::error::Error for ValidationError {}
223
/// Field-level validation for a model; the default accepts everything.
///
/// NOTE(review): the blanket `impl<T> ModelValidation for T` below covers every sized type,
/// so no type can supply its own `validate` (a second impl would conflict under coherence).
/// As written, `validate` therefore always returns `Ok(())` — confirm whether that is intended.
pub trait ModelValidation {
    /// Returns `Ok(())` if valid, otherwise every failure found.
    fn validate(&self) -> Result<(), Vec<ValidationError>> {
        Ok(())
    }
}

// Blanket impl: every type gets the accept-everything default.
impl<T> ModelValidation for T {}
231
/// A record type persisted in a table of dialect `DB`.
///
/// NOTE(review): implementations are presumably generated by a derive macro elsewhere in the
/// project; the required items below are the contract that generated code must satisfy.
#[async_trait::async_trait]
pub trait Model<DB: Database>: Sized + Send + Sync + Unpin
where
    DB: SqlDialect,
    for<'r> Self: sqlx::FromRow<'r, DB::Row>,
{
    /// Name of the backing table; `QueryBuilder` splices it into SQL text verbatim.
    fn table_name() -> &'static str;
    /// DDL that creates the table; executed by `Premix::sync`.
    fn create_table_sql() -> String;
    /// Column names of the table.
    fn list_columns() -> Vec<String>;

    /// Persists `self` through `executor`.
    async fn save<'a, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;

    /// Writes `self`'s current state back to its row and reports the outcome.
    async fn update<'a, E>(&mut self, executor: E) -> Result<UpdateResult, sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;

    /// Deletes `self`'s row (soft or hard delete is up to the implementation).
    async fn delete<'a, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;
    /// When `true`, `QueryBuilder` deletes by setting `deleted_at` and, unless
    /// `with_deleted` is used, hides rows whose `deleted_at` is not NULL.
    fn has_soft_delete() -> bool;

    /// Loads the row with primary key `id`, or `None` if there is none.
    async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;

    /// Starts a `query_as` over hand-written SQL whose rows decode into `Self`; bind
    /// arguments and choose an executor on the returned value.
    fn raw_sql<'q>(
        sql: &'q str,
    ) -> sqlx::query::QueryAs<'q, DB, Self, <DB as Database>::Arguments<'q>> {
        sqlx::query_as::<DB, Self>(sql)
    }

    /// Loads `_relation` for every model in `_models`. `QueryBuilder::all` calls it once per
    /// `include`d relation. The default does nothing, so unknown relations are silently ignored.
    async fn eager_load<'a, E>(
        _models: &mut [Self],
        _relation: &str,
        _executor: E,
    ) -> Result<(), sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>,
    {
        Ok(())
    }
    /// Starts a `QueryBuilder` over this model's table on any executor-like value.
    fn find<'a, E>(executor: E) -> QueryBuilder<'a, Self, DB>
    where
        E: IntoExecutor<'a, DB = DB>,
    {
        QueryBuilder::new(executor.into_executor())
    }

    /// Starts a `QueryBuilder` that runs through `pool`.
    fn find_in_pool(pool: &sqlx::Pool<DB>) -> QueryBuilder<'_, Self, DB> {
        QueryBuilder::new(Executor::Pool(pool))
    }

    /// Starts a `QueryBuilder` that runs on the single connection `conn`
    /// (e.g. the one behind a transaction).
    fn find_in_tx(conn: &mut DB::Connection) -> QueryBuilder<'_, Self, DB> {
        QueryBuilder::new(Executor::Conn(conn))
    }
}
295
296pub struct QueryBuilder<'a, T, DB: Database> {
297 executor: Executor<'a, DB>,
298 filters: Vec<String>,
299 limit: Option<i32>,
300 offset: Option<i32>,
301 includes: Vec<String>,
302 include_deleted: bool, _marker: std::marker::PhantomData<T>,
304}
305
306impl<'a, T, DB> QueryBuilder<'a, T, DB>
307where
308 DB: SqlDialect,
309 T: Model<DB>,
310{
311 pub fn new(executor: Executor<'a, DB>) -> Self {
312 Self {
313 executor,
314 filters: Vec::new(),
315 limit: None,
316 offset: None,
317 includes: Vec::new(),
318 include_deleted: false,
319 _marker: std::marker::PhantomData,
320 }
321 }
322
323 pub fn filter(mut self, condition: impl Into<String>) -> Self {
324 self.filters.push(condition.into());
325 self
326 }
327
328 pub fn limit(mut self, limit: i32) -> Self {
329 self.limit = Some(limit);
330 self
331 }
332
333 pub fn offset(mut self, offset: i32) -> Self {
334 self.offset = Some(offset);
335 self
336 }
337
338 pub fn include(mut self, relation: impl Into<String>) -> Self {
339 self.includes.push(relation.into());
340 self
341 }
342
343 pub fn with_deleted(mut self) -> Self {
345 self.include_deleted = true;
346 self
347 }
348
349 pub fn to_sql(&self) -> String {
351 let mut sql = format!(
352 "SELECT * FROM {}{}",
353 T::table_name(),
354 self.build_where_clause()
355 );
356
357 if let Some(limit) = self.limit {
358 sql.push_str(&format!(" LIMIT {}", limit));
359 }
360
361 if let Some(offset) = self.offset {
362 sql.push_str(&format!(" OFFSET {}", offset));
363 }
364
365 sql
366 }
367
368 pub fn to_update_sql(&self, values: &serde_json::Value) -> Result<String, sqlx::Error> {
370 let obj = values.as_object().ok_or_else(|| {
371 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
372 })?;
373
374 let mut i = 1;
375 let set_clause = obj
376 .keys()
377 .map(|k| {
378 let p = DB::placeholder(i);
379 i += 1;
380 format!("{} = {}", k, p)
381 })
382 .collect::<Vec<_>>()
383 .join(", ");
384
385 Ok(format!(
386 "UPDATE {} SET {}{}",
387 T::table_name(),
388 set_clause,
389 self.build_where_clause()
390 ))
391 }
392
393 pub fn to_delete_sql(&self) -> String {
395 if T::has_soft_delete() {
396 format!(
397 "UPDATE {} SET deleted_at = {}{}",
398 T::table_name(),
399 DB::current_timestamp_fn(),
400 self.build_where_clause()
401 )
402 } else {
403 format!(
404 "DELETE FROM {}{}",
405 T::table_name(),
406 self.build_where_clause()
407 )
408 }
409 }
410
411 fn build_where_clause(&self) -> String {
412 let mut filters = self.filters.clone();
413
414 if T::has_soft_delete() && !self.include_deleted {
416 filters.push("deleted_at IS NULL".to_string());
417 }
418
419 if filters.is_empty() {
420 "".to_string()
421 } else {
422 format!(" WHERE {}", filters.join(" AND "))
423 }
424 }
425}
426
427impl<'a, T, DB> QueryBuilder<'a, T, DB>
428where
429 DB: SqlDialect,
430 T: Model<DB>,
431 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
432 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
433 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
434 DB::Connection: Send,
435 T: Send,
436{
437 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
438 let mut sql = format!(
439 "SELECT * FROM {}{}",
440 T::table_name(),
441 self.build_where_clause()
442 );
443
444 if let Some(limit) = self.limit {
445 sql.push_str(&format!(" LIMIT {}", limit));
446 }
447
448 if let Some(offset) = self.offset {
449 sql.push_str(&format!(" OFFSET {}", offset));
450 }
451
452 let mut results: Vec<T> = match &mut self.executor {
453 Executor::Pool(pool) => sqlx::query_as::<DB, T>(&sql).fetch_all(*pool).await?,
454 Executor::Conn(conn) => sqlx::query_as::<DB, T>(&sql).fetch_all(&mut **conn).await?,
455 };
456
457 for relation in self.includes {
458 match &mut self.executor {
459 Executor::Pool(pool) => {
460 T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
461 }
462 Executor::Conn(conn) => {
463 T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
464 }
465 }
466 }
467
468 Ok(results)
469 }
470
471 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
473 where
474 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
475 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
476 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
477 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
478 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
479 {
480 let obj = values.as_object().ok_or_else(|| {
481 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
482 })?;
483
484 let mut i = 1;
485 let set_clause = obj
486 .keys()
487 .map(|k| {
488 let p = DB::placeholder(i);
489 i += 1;
490 format!("{} = {}", k, p)
491 })
492 .collect::<Vec<_>>()
493 .join(", ");
494
495 let sql = format!(
496 "UPDATE {} SET {}{}",
497 T::table_name(),
498 set_clause,
499 self.build_where_clause()
500 );
501
502 let mut query = sqlx::query::<DB>(&sql);
503 for val in obj.values() {
504 match val {
505 serde_json::Value::String(s) => query = query.bind(s.clone()),
506 serde_json::Value::Number(n) => {
507 if let Some(v) = n.as_i64() {
508 query = query.bind(v);
509 } else if let Some(v) = n.as_f64() {
510 query = query.bind(v);
511 }
512 }
513 serde_json::Value::Bool(b) => query = query.bind(*b),
514 serde_json::Value::Null => query = query.bind(Option::<String>::None),
515 _ => {
516 return Err(sqlx::Error::Protocol(
517 "Unsupported type in bulk update".to_string(),
518 ));
519 }
520 }
521 }
522
523 match &mut self.executor {
524 Executor::Pool(pool) => {
525 let res = query.execute(*pool).await?;
526 Ok(DB::rows_affected(&res))
527 }
528 Executor::Conn(conn) => {
529 let res = query.execute(&mut **conn).await?;
530 Ok(DB::rows_affected(&res))
531 }
532 }
533 }
534
535 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
536 let sql = if T::has_soft_delete() {
537 format!(
538 "UPDATE {} SET deleted_at = {}{}",
539 T::table_name(),
540 DB::current_timestamp_fn(),
541 self.build_where_clause()
542 )
543 } else {
544 format!(
545 "DELETE FROM {}{}",
546 T::table_name(),
547 self.build_where_clause()
548 )
549 };
550
551 match &mut self.executor {
552 Executor::Pool(pool) => {
553 let res = sqlx::query::<DB>(&sql).execute(*pool).await?;
554 Ok(DB::rows_affected(&res))
555 }
556 Executor::Conn(conn) => {
557 let res = sqlx::query::<DB>(&sql).execute(&mut **conn).await?;
558 Ok(DB::rows_affected(&res))
559 }
560 }
561 }
562}
563
564impl Premix {
565 pub async fn sync<DB, M>(pool: &sqlx::Pool<DB>) -> Result<(), sqlx::Error>
566 where
567 DB: SqlDialect,
568 M: Model<DB>,
569 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
570 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
571 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
572 {
573 let sql = M::create_table_sql();
574 sqlx::query::<DB>(&sql).execute(pool).await?;
575 Ok(())
576 }
577}