1pub use async_trait;
2pub use sqlx;
3
4pub mod prelude {
5 pub use crate::{Executor, IntoExecutor, Model, Premix, UpdateResult};
6}
7use sqlx::{Database, Executor as SqlxExecutor, IntoArguments};
8
/// Namespace for crate-level operations such as `Premix::sync`, which creates
/// a model's table.
pub struct Premix;
/// Schema migration support; its main types are re-exported at the crate root.
pub mod migrator;
pub use migrator::{Migration, Migrator};
12
13pub trait SqlDialect: Database + Sized + Send + Sync
16where
17 Self::Connection: Send,
18{
19 fn placeholder(n: usize) -> String;
20 fn auto_increment_pk() -> &'static str;
21 fn rows_affected(res: &Self::QueryResult) -> u64;
22 fn last_insert_id(res: &Self::QueryResult) -> i64;
23
24 fn current_timestamp_fn() -> &'static str {
25 "CURRENT_TIMESTAMP"
26 }
27 fn int_type() -> &'static str {
28 "INTEGER"
29 }
30 fn text_type() -> &'static str {
31 "TEXT"
32 }
33 fn bool_type() -> &'static str {
34 "BOOLEAN"
35 }
36 fn float_type() -> &'static str {
37 "REAL"
38 }
39}
40
/// SQLite: positional `?` placeholders; `INTEGER PRIMARY KEY` makes the
/// column an alias of the rowid, which SQLite fills in automatically.
#[cfg(feature = "sqlite")]
impl SqlDialect for sqlx::Sqlite {
    fn placeholder(_n: usize) -> String {
        String::from("?")
    }

    fn auto_increment_pk() -> &'static str {
        "INTEGER PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::sqlite::SqliteQueryResult) -> u64 {
        res.rows_affected()
    }

    fn last_insert_id(res: &sqlx::sqlite::SqliteQueryResult) -> i64 {
        res.last_insert_rowid()
    }
}
56
/// PostgreSQL: numbered `$n` placeholders and `SERIAL` primary keys.
#[cfg(feature = "postgres")]
impl SqlDialect for sqlx::Postgres {
    fn placeholder(n: usize) -> String {
        format!("${n}")
    }

    fn auto_increment_pk() -> &'static str {
        "SERIAL PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::postgres::PgQueryResult) -> u64 {
        res.rows_affected()
    }

    /// `PgQueryResult` carries no generated id, so this always yields 0;
    /// on PostgreSQL generated keys must be fetched with `RETURNING`.
    fn last_insert_id(_res: &sqlx::postgres::PgQueryResult) -> i64 {
        0
    }
}
72
/// MySQL: positional `?` placeholders and `AUTO_INCREMENT` primary keys.
#[cfg(feature = "mysql")]
impl SqlDialect for sqlx::MySql {
    fn placeholder(_n: usize) -> String {
        "?".into()
    }

    fn auto_increment_pk() -> &'static str {
        "INTEGER AUTO_INCREMENT PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::mysql::MySqlQueryResult) -> u64 {
        res.rows_affected()
    }

    fn last_insert_id(res: &sqlx::mysql::MySqlQueryResult) -> i64 {
        // MySQL reports the id as `u64`; ids above `i64::MAX` wrap negative here.
        let id: u64 = res.last_insert_id();
        id as i64
    }
}
88
/// Where a query runs: through a borrowed connection pool, or on a single
/// exclusively borrowed connection.
pub enum Executor<'a, DB: Database> {
    /// Queries are executed through the pool.
    Pool(&'a sqlx::Pool<DB>),
    /// Queries are executed on this one connection (for example the
    /// connection behind a transaction).
    Conn(&'a mut DB::Connection),
}

// SAFETY: `Executor` only ever holds `&'a sqlx::Pool<DB>` or `&'a mut DB::Connection`.
// `&mut DB::Connection` is `Send` exactly when `DB::Connection: Send`, which this impl
// requires. `&Pool<DB>` is `Send` when `Pool<DB>: Sync`, and sqlx documents `Pool` as
// `Send + Sync`. NOTE(review): if the compiler can derive these auto traits by itself,
// both manual impls are redundant and could be removed.
unsafe impl<'a, DB: Database> Send for Executor<'a, DB> where DB::Connection: Send {}
// SAFETY: a shared `&Executor` only exposes `&&Pool<DB>` (fine because `Pool<DB>` is
// `Sync`) or `&&mut DB::Connection`, which is `Sync` exactly when `DB::Connection: Sync`,
// as this impl requires.
unsafe impl<'a, DB: Database> Sync for Executor<'a, DB> where DB::Connection: Sync {}
97
98impl<'a, DB: Database> From<&'a sqlx::Pool<DB>> for Executor<'a, DB> {
99 fn from(pool: &'a sqlx::Pool<DB>) -> Self {
100 Self::Pool(pool)
101 }
102}
103
104impl<'a, DB: Database> From<&'a mut DB::Connection> for Executor<'a, DB> {
105 fn from(conn: &'a mut DB::Connection) -> Self {
106 Self::Conn(conn)
107 }
108}
109
110pub trait IntoExecutor<'a>: Send + 'a {
111 type DB: SqlDialect;
112 fn into_executor(self) -> Executor<'a, Self::DB>;
113}
114
115impl<'a, DB: SqlDialect> IntoExecutor<'a> for &'a sqlx::Pool<DB> {
116 type DB = DB;
117 fn into_executor(self) -> Executor<'a, DB> {
118 Executor::Pool(self)
119 }
120}
121
/// A borrowed SQLite connection becomes [`Executor::Conn`].
#[cfg(feature = "sqlite")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::SqliteConnection {
    type DB = sqlx::Sqlite;

    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}
129
/// A borrowed PostgreSQL connection becomes [`Executor::Conn`].
#[cfg(feature = "postgres")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::postgres::PgConnection {
    type DB = sqlx::Postgres;

    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}

/// A borrowed MySQL connection becomes [`Executor::Conn`], mirroring the
/// SQLite and PostgreSQL impls so MySQL connections (e.g. inside a
/// transaction) can be passed straight to model methods.
#[cfg(feature = "mysql")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::mysql::MySqlConnection {
    type DB = sqlx::MySql;

    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}
137
138impl<'a, DB: SqlDialect> IntoExecutor<'a> for Executor<'a, DB> {
139 type DB = DB;
140 fn into_executor(self) -> Executor<'a, DB> {
141 self
142 }
143}
144
145impl<'a, DB: Database> Executor<'a, DB> {
146 pub async fn execute<'q, A>(
147 &mut self,
148 query: sqlx::query::Query<'q, DB, A>,
149 ) -> Result<DB::QueryResult, sqlx::Error>
150 where
151 A: sqlx::IntoArguments<'q, DB> + 'q,
152 DB: SqlDialect,
153 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
154 {
155 match self {
156 Self::Pool(pool) => query.execute(*pool).await,
157 Self::Conn(conn) => query.execute(&mut **conn).await,
158 }
159 }
160
161 pub async fn fetch_all<'q, T, A>(
162 &mut self,
163 query: sqlx::query::QueryAs<'q, DB, T, A>,
164 ) -> Result<Vec<T>, sqlx::Error>
165 where
166 T: for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
167 A: sqlx::IntoArguments<'q, DB> + 'q,
168 DB: SqlDialect,
169 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
170 {
171 match self {
172 Self::Pool(pool) => query.fetch_all(*pool).await,
173 Self::Conn(conn) => query.fetch_all(&mut **conn).await,
174 }
175 }
176
177 pub async fn fetch_optional<'q, T, A>(
178 &mut self,
179 query: sqlx::query::QueryAs<'q, DB, T, A>,
180 ) -> Result<Option<T>, sqlx::Error>
181 where
182 T: for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
183 A: sqlx::IntoArguments<'q, DB> + 'q,
184 DB: SqlDialect,
185 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
186 {
187 match self {
188 Self::Pool(pool) => query.fetch_optional(*pool).await,
189 Self::Conn(conn) => query.fetch_optional(&mut **conn).await,
190 }
191 }
192}
193
194#[async_trait::async_trait]
196pub trait ModelHooks {
197 async fn before_save(&mut self) -> Result<(), sqlx::Error> {
198 Ok(())
199 }
200 async fn after_save(&mut self) -> Result<(), sqlx::Error> {
201 Ok(())
202 }
203}
204
205#[async_trait::async_trait]
206impl<T: Send + Sync> ModelHooks for T {}
207
/// Outcome reported by `Model::update`.
///
/// The variants are produced by `Model::update` implementations outside this
/// file; the notes below follow the variant names — confirm against them.
/// Being a plain tag, it is `Copy` and `Eq` so callers can match and compare
/// it freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateResult {
    /// The update was applied.
    Success,
    /// Presumably an optimistic-locking version mismatch.
    VersionConflict,
    /// Presumably no row matched the model's key.
    NotFound,
    /// Presumably the model does not support `update`.
    NotImplemented,
}
215}
216
/// A single validation failure returned by `ModelValidation::validate`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the field that failed validation.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}

/// Renders as `"<field>: <message>"`.
impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

/// Lets a `ValidationError` be boxed as `dyn Error` and propagated with `?`.
impl std::error::Error for ValidationError {}
223
224pub trait ModelValidation {
225 fn validate(&self) -> Result<(), Vec<ValidationError>> {
226 Ok(())
227 }
228}
229
230impl<T> ModelValidation for T {}
231
/// A database-backed model whose rows decode into `Self` via `sqlx::FromRow`.
///
/// NOTE(review): the required methods are presumably generated by a derive
/// macro elsewhere in the project — confirm.
#[async_trait::async_trait]
pub trait Model<DB: Database>: Sized + Send + Sync + Unpin
where
    DB: SqlDialect,
    for<'r> Self: sqlx::FromRow<'r, DB::Row>,
{
    /// Name of the table holding this model's rows.
    fn table_name() -> &'static str;
    /// DDL executed by `Premix::sync` to create the table.
    fn create_table_sql() -> String;
    /// Column names of the table.
    fn list_columns() -> Vec<String>;

    /// Persists `self` using `executor`.
    async fn save<'a, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;

    /// Writes `self` back to its row; the outcome is reported as an [`UpdateResult`].
    async fn update<'a, E>(&mut self, executor: E) -> Result<UpdateResult, sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;

    /// Removes `self`'s row (presumably a soft delete when `has_soft_delete()`
    /// is true — confirm against the implementation).
    async fn delete<'a, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;
    /// Whether rows are soft-deleted through a `deleted_at` column.
    /// `QueryBuilder` uses this to hide such rows and to turn `delete` into an
    /// `UPDATE ... SET deleted_at = ...`.
    fn has_soft_delete() -> bool;

    /// Looks up a single row by integer id; presumably `Ok(None)` when no row
    /// matches.
    async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>;

    /// Loads the relation named `_relation` into every model in `_models`.
    /// Called by `QueryBuilder::all` for each `include`d relation; the default
    /// does nothing.
    async fn eager_load<'a, E>(
        _models: &mut [Self],
        _relation: &str,
        _executor: E,
    ) -> Result<(), sqlx::Error>
    where
        E: IntoExecutor<'a, DB = DB>,
    {
        Ok(())
    }
    /// Starts a [`QueryBuilder`] that runs on whatever `executor` converts into.
    fn find<'a, E>(executor: E) -> QueryBuilder<'a, Self, DB>
    where
        E: IntoExecutor<'a, DB = DB>,
    {
        QueryBuilder::new(executor.into_executor())
    }

    /// Starts a [`QueryBuilder`] that runs through `pool`.
    fn find_in_pool(pool: &sqlx::Pool<DB>) -> QueryBuilder<'_, Self, DB> {
        QueryBuilder::new(Executor::Pool(pool))
    }

    /// Starts a [`QueryBuilder`] that runs on the borrowed connection `conn`
    /// (e.g. a transaction's connection).
    fn find_in_tx(conn: &mut DB::Connection) -> QueryBuilder<'_, Self, DB> {
        QueryBuilder::new(Executor::Conn(conn))
    }
}
288
289pub struct QueryBuilder<'a, T, DB: Database> {
290 executor: Executor<'a, DB>,
291 filters: Vec<String>,
292 limit: Option<i32>,
293 offset: Option<i32>,
294 includes: Vec<String>,
295 include_deleted: bool, _marker: std::marker::PhantomData<T>,
297}
298
299impl<'a, T, DB> QueryBuilder<'a, T, DB>
300where
301 DB: SqlDialect,
302 T: Model<DB>,
303{
304 pub fn new(executor: Executor<'a, DB>) -> Self {
305 Self {
306 executor,
307 filters: Vec::new(),
308 limit: None,
309 offset: None,
310 includes: Vec::new(),
311 include_deleted: false,
312 _marker: std::marker::PhantomData,
313 }
314 }
315
316 pub fn filter(mut self, condition: impl Into<String>) -> Self {
317 self.filters.push(condition.into());
318 self
319 }
320
321 pub fn limit(mut self, limit: i32) -> Self {
322 self.limit = Some(limit);
323 self
324 }
325
326 pub fn offset(mut self, offset: i32) -> Self {
327 self.offset = Some(offset);
328 self
329 }
330
331 pub fn include(mut self, relation: impl Into<String>) -> Self {
332 self.includes.push(relation.into());
333 self
334 }
335
336 pub fn with_deleted(mut self) -> Self {
338 self.include_deleted = true;
339 self
340 }
341
342 fn build_where_clause(&self) -> String {
343 let mut filters = self.filters.clone();
344
345 if T::has_soft_delete() && !self.include_deleted {
347 filters.push("deleted_at IS NULL".to_string());
348 }
349
350 if filters.is_empty() {
351 "".to_string()
352 } else {
353 format!(" WHERE {}", filters.join(" AND "))
354 }
355 }
356}
357
358impl<'a, T, DB> QueryBuilder<'a, T, DB>
359where
360 DB: SqlDialect,
361 T: Model<DB>,
362 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
363 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
364 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
365 DB::Connection: Send,
366 T: Send,
367{
368 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
369 let mut sql = format!(
370 "SELECT * FROM {}{}",
371 T::table_name(),
372 self.build_where_clause()
373 );
374
375 if let Some(limit) = self.limit {
376 sql.push_str(&format!(" LIMIT {}", limit));
377 }
378
379 if let Some(offset) = self.offset {
380 sql.push_str(&format!(" OFFSET {}", offset));
381 }
382
383 let mut results: Vec<T> = match &mut self.executor {
384 Executor::Pool(pool) => sqlx::query_as::<DB, T>(&sql).fetch_all(*pool).await?,
385 Executor::Conn(conn) => sqlx::query_as::<DB, T>(&sql).fetch_all(&mut **conn).await?,
386 };
387
388 for relation in self.includes {
389 match &mut self.executor {
390 Executor::Pool(pool) => {
391 T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
392 }
393 Executor::Conn(conn) => {
394 T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
395 }
396 }
397 }
398
399 Ok(results)
400 }
401
402 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
404 where
405 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
406 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
407 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
408 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
409 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
410 {
411 let obj = values.as_object().ok_or_else(|| {
412 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
413 })?;
414
415 let mut i = 1;
416 let set_clause = obj
417 .keys()
418 .map(|k| {
419 let p = DB::placeholder(i);
420 i += 1;
421 format!("{} = {}", k, p)
422 })
423 .collect::<Vec<_>>()
424 .join(", ");
425
426 let sql = format!(
427 "UPDATE {} SET {}{}",
428 T::table_name(),
429 set_clause,
430 self.build_where_clause()
431 );
432
433 let mut query = sqlx::query::<DB>(&sql);
434 for val in obj.values() {
435 match val {
436 serde_json::Value::String(s) => query = query.bind(s.clone()),
437 serde_json::Value::Number(n) => {
438 if let Some(v) = n.as_i64() {
439 query = query.bind(v);
440 } else if let Some(v) = n.as_f64() {
441 query = query.bind(v);
442 }
443 }
444 serde_json::Value::Bool(b) => query = query.bind(*b),
445 serde_json::Value::Null => query = query.bind(Option::<String>::None),
446 _ => {
447 return Err(sqlx::Error::Protocol(
448 "Unsupported type in bulk update".to_string(),
449 ));
450 }
451 }
452 }
453
454 match &mut self.executor {
455 Executor::Pool(pool) => {
456 let res = query.execute(*pool).await?;
457 Ok(DB::rows_affected(&res))
458 }
459 Executor::Conn(conn) => {
460 let res = query.execute(&mut **conn).await?;
461 Ok(DB::rows_affected(&res))
462 }
463 }
464 }
465
466 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
467 let sql = if T::has_soft_delete() {
468 format!(
469 "UPDATE {} SET deleted_at = {}{}",
470 T::table_name(),
471 DB::current_timestamp_fn(),
472 self.build_where_clause()
473 )
474 } else {
475 format!(
476 "DELETE FROM {}{}",
477 T::table_name(),
478 self.build_where_clause()
479 )
480 };
481
482 match &mut self.executor {
483 Executor::Pool(pool) => {
484 let res = sqlx::query::<DB>(&sql).execute(*pool).await?;
485 Ok(DB::rows_affected(&res))
486 }
487 Executor::Conn(conn) => {
488 let res = sqlx::query::<DB>(&sql).execute(&mut **conn).await?;
489 Ok(DB::rows_affected(&res))
490 }
491 }
492 }
493}
494
495impl Premix {
496 pub async fn sync<DB, M>(pool: &sqlx::Pool<DB>) -> Result<(), sqlx::Error>
497 where
498 DB: SqlDialect,
499 M: Model<DB>,
500 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
501 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
502 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
503 {
504 let sql = M::create_table_sql();
505 sqlx::query::<DB>(&sql).execute(pool).await?;
506 Ok(())
507 }
508}