1pub use async_trait;
2pub use sqlx;
3use sqlx::{Database, Executor as SqlxExecutor, IntoArguments};
4
/// Namespace type for crate-level helpers such as `Premix::sync`.
///
/// It carries no state, so it is freely copyable and has a trivial default.
#[derive(Debug, Clone, Copy, Default)]
pub struct Premix;
6pub mod migrator;
7pub use migrator::{Migration, Migrator};
8
9pub trait SqlDialect: Database + Sized + Send + Sync
12where
13 Self::Connection: Send,
14{
15 fn placeholder(n: usize) -> String;
16 fn auto_increment_pk() -> &'static str;
17 fn rows_affected(res: &Self::QueryResult) -> u64;
18 fn last_insert_id(res: &Self::QueryResult) -> i64;
19
20 fn current_timestamp_fn() -> &'static str {
21 "CURRENT_TIMESTAMP"
22 }
23 fn int_type() -> &'static str {
24 "INTEGER"
25 }
26 fn text_type() -> &'static str {
27 "TEXT"
28 }
29 fn bool_type() -> &'static str {
30 "BOOLEAN"
31 }
32 fn float_type() -> &'static str {
33 "REAL"
34 }
35}
36
/// SQLite dialect: `?` placeholders and rowid-based primary keys.
#[cfg(feature = "sqlite")]
impl SqlDialect for sqlx::Sqlite {
    /// SQLite placeholders are positional `?`, so the parameter index is not needed.
    fn placeholder(_n: usize) -> String {
        String::from("?")
    }

    /// `INTEGER PRIMARY KEY` makes the column SQLite's auto-assigned rowid alias.
    fn auto_increment_pk() -> &'static str {
        "INTEGER PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::sqlite::SqliteQueryResult) -> u64 {
        let count = res.rows_affected();
        count
    }

    /// Reads the rowid recorded for the last insert on this result.
    fn last_insert_id(res: &sqlx::sqlite::SqliteQueryResult) -> i64 {
        let rowid = res.last_insert_rowid();
        rowid
    }
}
52
/// PostgreSQL dialect: numbered `$n` placeholders and `SERIAL` primary keys.
#[cfg(feature = "postgres")]
impl SqlDialect for sqlx::Postgres {
    /// Produces `$1`, `$2`, ... for parameter number `n`.
    fn placeholder(n: usize) -> String {
        let mut marker = String::from("$");
        marker.push_str(&n.to_string());
        marker
    }

    fn auto_increment_pk() -> &'static str {
        "SERIAL PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::postgres::PgQueryResult) -> u64 {
        let count = res.rows_affected();
        count
    }

    /// The Postgres query result exposes no inserted id here, so this always returns `0`.
    /// Callers that need the new id must obtain it another way (e.g. `RETURNING`).
    fn last_insert_id(_res: &sqlx::postgres::PgQueryResult) -> i64 {
        0
    }
}
68
/// MySQL dialect: `?` placeholders and `AUTO_INCREMENT` primary keys.
#[cfg(feature = "mysql")]
impl SqlDialect for sqlx::MySql {
    /// MySQL placeholders are positional `?`, so the parameter index is not needed.
    fn placeholder(_n: usize) -> String {
        String::from("?")
    }

    fn auto_increment_pk() -> &'static str {
        "INTEGER AUTO_INCREMENT PRIMARY KEY"
    }

    fn rows_affected(res: &sqlx::mysql::MySqlQueryResult) -> u64 {
        let count = res.rows_affected();
        count
    }

    /// MySQL reports the id as `u64`; the `as` cast wraps ids above `i64::MAX` to negatives.
    fn last_insert_id(res: &sqlx::mysql::MySqlQueryResult) -> i64 {
        let raw_id: u64 = res.last_insert_id();
        raw_id as i64
    }
}
84
85pub enum Executor<'a, DB: Database> {
87 Pool(&'a sqlx::Pool<DB>),
88 Conn(&'a mut DB::Connection),
89}
90
91unsafe impl<'a, DB: Database> Send for Executor<'a, DB> where DB::Connection: Send {}
92unsafe impl<'a, DB: Database> Sync for Executor<'a, DB> where DB::Connection: Sync {}
93
94impl<'a, DB: Database> From<&'a sqlx::Pool<DB>> for Executor<'a, DB> {
95 fn from(pool: &'a sqlx::Pool<DB>) -> Self {
96 Self::Pool(pool)
97 }
98}
99
100impl<'a, DB: Database> From<&'a mut DB::Connection> for Executor<'a, DB> {
101 fn from(conn: &'a mut DB::Connection) -> Self {
102 Self::Conn(conn)
103 }
104}
105
106pub trait IntoExecutor<'a>: Send + 'a {
107 type DB: SqlDialect;
108 fn into_executor(self) -> Executor<'a, Self::DB>;
109}
110
111impl<'a, DB: SqlDialect> IntoExecutor<'a> for &'a sqlx::Pool<DB> {
112 type DB = DB;
113 fn into_executor(self) -> Executor<'a, DB> {
114 Executor::Pool(self)
115 }
116}
117
/// A mutably borrowed SQLite connection becomes a connection-backed [`Executor`].
#[cfg(feature = "sqlite")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::SqliteConnection {
    type DB = sqlx::Sqlite;

    fn into_executor(self) -> Executor<'a, Self::DB> {
        let conn: &'a mut sqlx::SqliteConnection = self;
        Executor::Conn(conn)
    }
}
125
/// A mutably borrowed PostgreSQL connection becomes a connection-backed [`Executor`].
#[cfg(feature = "postgres")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::postgres::PgConnection {
    type DB = sqlx::Postgres;

    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}

/// A mutably borrowed MySQL connection becomes a connection-backed [`Executor`].
///
/// MySQL implements `SqlDialect`, so it needs this impl too. Without it, a MySQL
/// connection (e.g. inside a transaction) cannot be passed to `Model` methods.
#[cfg(feature = "mysql")]
impl<'a> IntoExecutor<'a> for &'a mut sqlx::mysql::MySqlConnection {
    type DB = sqlx::MySql;

    fn into_executor(self) -> Executor<'a, Self::DB> {
        Executor::Conn(self)
    }
}
133
134impl<'a, DB: SqlDialect> IntoExecutor<'a> for Executor<'a, DB> {
135 type DB = DB;
136 fn into_executor(self) -> Executor<'a, DB> {
137 self
138 }
139}
140
141impl<'a, DB: Database> Executor<'a, DB> {
142 pub async fn execute<'q, A>(
143 &mut self,
144 query: sqlx::query::Query<'q, DB, A>,
145 ) -> Result<DB::QueryResult, sqlx::Error>
146 where
147 A: sqlx::IntoArguments<'q, DB> + 'q,
148 DB: SqlDialect,
149 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
150 {
151 match self {
152 Self::Pool(pool) => query.execute(*pool).await,
153 Self::Conn(conn) => query.execute(&mut **conn).await,
154 }
155 }
156
157 pub async fn fetch_all<'q, T, A>(
158 &mut self,
159 query: sqlx::query::QueryAs<'q, DB, T, A>,
160 ) -> Result<Vec<T>, sqlx::Error>
161 where
162 T: for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
163 A: sqlx::IntoArguments<'q, DB> + 'q,
164 DB: SqlDialect,
165 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
166 {
167 match self {
168 Self::Pool(pool) => query.fetch_all(*pool).await,
169 Self::Conn(conn) => query.fetch_all(&mut **conn).await,
170 }
171 }
172
173 pub async fn fetch_optional<'q, T, A>(
174 &mut self,
175 query: sqlx::query::QueryAs<'q, DB, T, A>,
176 ) -> Result<Option<T>, sqlx::Error>
177 where
178 T: for<'r> sqlx::FromRow<'r, DB::Row> + Send + Unpin,
179 A: sqlx::IntoArguments<'q, DB> + 'q,
180 DB: SqlDialect,
181 for<'c> &'c mut DB::Connection: sqlx::Executor<'c, Database = DB>,
182 {
183 match self {
184 Self::Pool(pool) => query.fetch_optional(*pool).await,
185 Self::Conn(conn) => query.fetch_optional(&mut **conn).await,
186 }
187 }
188}
189
190#[async_trait::async_trait]
192pub trait ModelHooks {
193 async fn before_save(&mut self) -> Result<(), sqlx::Error> {
194 Ok(())
195 }
196 async fn after_save(&mut self) -> Result<(), sqlx::Error> {
197 Ok(())
198 }
199}
200
201#[async_trait::async_trait]
202impl<T: Send + Sync> ModelHooks for T {}
203
/// Outcome of a single-model update.
///
/// This is a plain fieldless enum, so it is cheap to copy and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpdateResult {
    /// The update was applied.
    Success,
    /// The update was rejected because of a version mismatch
    /// (presumably optimistic locking; set by `Model::update` implementations).
    VersionConflict,
    /// No matching row was found.
    NotFound,
    /// The model does not implement this kind of update.
    NotImplemented,
}
212
/// A single validation failure tied to a model field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the offending field.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    /// Renders as `field: message`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

impl std::error::Error for ValidationError {}

/// Model-level validation; the default accepts everything.
pub trait ModelValidation {
    /// Returns every validation failure, or `Ok(())` when the value is valid.
    fn validate(&self) -> Result<(), Vec<ValidationError>> {
        Ok(())
    }
}

// NOTE: this blanket impl covers every type. Consequently, no type can provide its own
// `validate` through this trait (a second impl would conflict), and `validate` always
// returns `Ok(())`.
impl<T> ModelValidation for T {}
227
228#[async_trait::async_trait]
229pub trait Model<DB: Database>: Sized + Send + Sync + Unpin
230where
231 DB: SqlDialect,
232 for<'r> Self: sqlx::FromRow<'r, DB::Row>,
233{
234 fn table_name() -> &'static str;
235 fn create_table_sql() -> String;
236 fn list_columns() -> Vec<String>;
237
238 async fn save<'a, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
240 where
241 E: IntoExecutor<'a, DB = DB>;
242
243 async fn update<'a, E>(&mut self, executor: E) -> Result<UpdateResult, sqlx::Error>
244 where
245 E: IntoExecutor<'a, DB = DB>;
246
247 async fn delete<'a, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
249 where
250 E: IntoExecutor<'a, DB = DB>;
251 fn has_soft_delete() -> bool;
252
253 async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, sqlx::Error>
255 where
256 E: IntoExecutor<'a, DB = DB>;
257
258 async fn eager_load<'a, E>(
259 _models: &mut [Self],
260 _relation: &str,
261 _executor: E,
262 ) -> Result<(), sqlx::Error>
263 where
264 E: IntoExecutor<'a, DB = DB>,
265 {
266 Ok(())
267 }
268 fn find<'a, E>(executor: E) -> QueryBuilder<'a, Self, DB>
269 where
270 E: IntoExecutor<'a, DB = DB>,
271 {
272 QueryBuilder::new(executor.into_executor())
273 }
274
275 fn find_in_pool(pool: &sqlx::Pool<DB>) -> QueryBuilder<'_, Self, DB> {
277 QueryBuilder::new(Executor::Pool(pool))
278 }
279
280 fn find_in_tx(conn: &mut DB::Connection) -> QueryBuilder<'_, Self, DB> {
281 QueryBuilder::new(Executor::Conn(conn))
282 }
283}
284
285pub struct QueryBuilder<'a, T, DB: Database> {
286 executor: Executor<'a, DB>,
287 filters: Vec<String>,
288 limit: Option<i32>,
289 offset: Option<i32>,
290 includes: Vec<String>,
291 include_deleted: bool, _marker: std::marker::PhantomData<T>,
293}
294
295impl<'a, T, DB> QueryBuilder<'a, T, DB>
296where
297 DB: SqlDialect,
298 T: Model<DB>,
299{
300 pub fn new(executor: Executor<'a, DB>) -> Self {
301 Self {
302 executor,
303 filters: Vec::new(),
304 limit: None,
305 offset: None,
306 includes: Vec::new(),
307 include_deleted: false,
308 _marker: std::marker::PhantomData,
309 }
310 }
311
312 pub fn filter(mut self, condition: impl Into<String>) -> Self {
313 self.filters.push(condition.into());
314 self
315 }
316
317 pub fn limit(mut self, limit: i32) -> Self {
318 self.limit = Some(limit);
319 self
320 }
321
322 pub fn offset(mut self, offset: i32) -> Self {
323 self.offset = Some(offset);
324 self
325 }
326
327 pub fn include(mut self, relation: impl Into<String>) -> Self {
328 self.includes.push(relation.into());
329 self
330 }
331
332 pub fn with_deleted(mut self) -> Self {
334 self.include_deleted = true;
335 self
336 }
337
338 fn build_where_clause(&self) -> String {
339 let mut filters = self.filters.clone();
340
341 if T::has_soft_delete() && !self.include_deleted {
343 filters.push("deleted_at IS NULL".to_string());
344 }
345
346 if filters.is_empty() {
347 "".to_string()
348 } else {
349 format!(" WHERE {}", filters.join(" AND "))
350 }
351 }
352}
353
354impl<'a, T, DB> QueryBuilder<'a, T, DB>
355where
356 DB: SqlDialect,
357 T: Model<DB>,
358 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
359 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
360 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
361 DB::Connection: Send,
362 T: Send,
363{
364 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
365 let mut sql = format!(
366 "SELECT * FROM {}{}",
367 T::table_name(),
368 self.build_where_clause()
369 );
370
371 if let Some(limit) = self.limit {
372 sql.push_str(&format!(" LIMIT {}", limit));
373 }
374
375 if let Some(offset) = self.offset {
376 sql.push_str(&format!(" OFFSET {}", offset));
377 }
378
379 let mut results: Vec<T> = match &mut self.executor {
380 Executor::Pool(pool) => sqlx::query_as::<DB, T>(&sql).fetch_all(*pool).await?,
381 Executor::Conn(conn) => sqlx::query_as::<DB, T>(&sql).fetch_all(&mut **conn).await?,
382 };
383
384 for relation in self.includes {
385 match &mut self.executor {
386 Executor::Pool(pool) => {
387 T::eager_load(&mut results, &relation, Executor::Pool(*pool)).await?;
388 }
389 Executor::Conn(conn) => {
390 T::eager_load(&mut results, &relation, Executor::Conn(&mut **conn)).await?;
391 }
392 }
393 }
394
395 Ok(results)
396 }
397
398 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
400 where
401 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
402 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
403 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
404 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
405 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
406 {
407 let obj = values.as_object().ok_or_else(|| {
408 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
409 })?;
410
411 let mut i = 1;
412 let set_clause = obj
413 .keys()
414 .map(|k| {
415 let p = DB::placeholder(i);
416 i += 1;
417 format!("{} = {}", k, p)
418 })
419 .collect::<Vec<_>>()
420 .join(", ");
421
422 let sql = format!(
423 "UPDATE {} SET {}{}",
424 T::table_name(),
425 set_clause,
426 self.build_where_clause()
427 );
428
429 let mut query = sqlx::query::<DB>(&sql);
430 for val in obj.values() {
431 match val {
432 serde_json::Value::String(s) => query = query.bind(s.clone()),
433 serde_json::Value::Number(n) => {
434 if let Some(v) = n.as_i64() {
435 query = query.bind(v);
436 } else if let Some(v) = n.as_f64() {
437 query = query.bind(v);
438 }
439 }
440 serde_json::Value::Bool(b) => query = query.bind(*b),
441 serde_json::Value::Null => query = query.bind(Option::<String>::None),
442 _ => {
443 return Err(sqlx::Error::Protocol(
444 "Unsupported type in bulk update".to_string(),
445 ));
446 }
447 }
448 }
449
450 match &mut self.executor {
451 Executor::Pool(pool) => {
452 let res = query.execute(*pool).await?;
453 Ok(DB::rows_affected(&res))
454 }
455 Executor::Conn(conn) => {
456 let res = query.execute(&mut **conn).await?;
457 Ok(DB::rows_affected(&res))
458 }
459 }
460 }
461
462 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
463 let sql = if T::has_soft_delete() {
464 format!(
465 "UPDATE {} SET deleted_at = {}{}",
466 T::table_name(),
467 DB::current_timestamp_fn(),
468 self.build_where_clause()
469 )
470 } else {
471 format!(
472 "DELETE FROM {}{}",
473 T::table_name(),
474 self.build_where_clause()
475 )
476 };
477
478 match &mut self.executor {
479 Executor::Pool(pool) => {
480 let res = sqlx::query::<DB>(&sql).execute(*pool).await?;
481 Ok(DB::rows_affected(&res))
482 }
483 Executor::Conn(conn) => {
484 let res = sqlx::query::<DB>(&sql).execute(&mut **conn).await?;
485 Ok(DB::rows_affected(&res))
486 }
487 }
488 }
489}
490
491impl Premix {
492 pub async fn sync<DB, M>(pool: &sqlx::Pool<DB>) -> Result<(), sqlx::Error>
493 where
494 DB: SqlDialect,
495 M: Model<DB>,
496 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
497 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
498 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
499 {
500 let sql = M::create_table_sql();
501 sqlx::query::<DB>(&sql).execute(pool).await?;
502 Ok(())
503 }
504}