1pub use sqlx;
2use sqlx::{Database, Executor as SqlxExecutor, IntoArguments};
3
/// Namespace for crate-level helpers that are not tied to a single model
/// instance, such as [`Premix::sync`].
#[derive(Debug, Clone, Copy, Default)]
pub struct Premix;
5pub mod migrator;
6pub use migrator::{Migration, Migrator};
7
8pub trait SqlDialect: Database + Sized {
11 fn placeholder(n: usize) -> String;
12 fn auto_increment_pk() -> &'static str;
13 fn rows_affected(res: &Self::QueryResult) -> u64;
14
15 fn current_timestamp_fn() -> &'static str {
16 "CURRENT_TIMESTAMP"
17 }
18 fn int_type() -> &'static str {
19 "INTEGER"
20 }
21 fn text_type() -> &'static str {
22 "TEXT"
23 }
24 fn bool_type() -> &'static str {
25 "BOOLEAN"
26 }
27 fn float_type() -> &'static str {
28 "REAL"
29 }
30}
31
#[cfg(feature = "sqlite")]
impl SqlDialect for sqlx::Sqlite {
    /// SQLite binds positionally with a bare `?`, so the index is unused.
    fn placeholder(_n: usize) -> String {
        String::from("?")
    }

    /// SQLite spelling of an auto-incrementing primary key column.
    fn auto_increment_pk() -> &'static str {
        "INTEGER PRIMARY KEY"
    }

    /// Delegates to the SQLite result's own row counter.
    fn rows_affected(res: &sqlx::sqlite::SqliteQueryResult) -> u64 {
        sqlx::sqlite::SqliteQueryResult::rows_affected(res)
    }
}
44
#[cfg(feature = "postgres")]
impl SqlDialect for sqlx::Postgres {
    /// PostgreSQL uses numbered placeholders: `$1`, `$2`, ...
    fn placeholder(n: usize) -> String {
        format!("${n}")
    }

    /// PostgreSQL spelling of an auto-incrementing primary key column.
    fn auto_increment_pk() -> &'static str {
        "SERIAL PRIMARY KEY"
    }

    /// Delegates to the PostgreSQL result's own row counter.
    fn rows_affected(res: &sqlx::postgres::PgQueryResult) -> u64 {
        sqlx::postgres::PgQueryResult::rows_affected(res)
    }
}
57
#[cfg(feature = "mysql")]
impl SqlDialect for sqlx::MySql {
    /// MySQL binds positionally with a bare `?`, so the index is unused.
    fn placeholder(_n: usize) -> String {
        String::from("?")
    }

    /// MySQL spelling of an auto-incrementing primary key column.
    fn auto_increment_pk() -> &'static str {
        "INTEGER AUTO_INCREMENT PRIMARY KEY"
    }

    /// Delegates to the MySQL result's own row counter.
    fn rows_affected(res: &sqlx::mysql::MySqlQueryResult) -> u64 {
        sqlx::mysql::MySqlQueryResult::rows_affected(res)
    }
}
70
71pub enum Executor<'a, DB: Database> {
73 Pool(&'a sqlx::Pool<DB>),
74 Conn(&'a mut DB::Connection),
75}
76
77pub trait ModelHooks {
79 #[allow(async_fn_in_trait)]
80 async fn before_save(&mut self) -> Result<(), sqlx::Error> {
81 Ok(())
82 }
83 #[allow(async_fn_in_trait)]
84 async fn after_save(&mut self) -> Result<(), sqlx::Error> {
85 Ok(())
86 }
87}
88
89impl<T> ModelHooks for T {}
90
/// Outcome of `Model::update` when the statement itself ran without a
/// database error.
///
/// It is a plain fieldless value, so it is `Copy`, `Eq` and `Hash` and can be
/// compared, stored and passed around freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UpdateResult {
    /// The update was applied.
    Success,
    /// Presumably an optimistic-locking version check failed — confirm
    /// against the implementing code.
    VersionConflict,
    /// Presumably no matching row existed.
    NotFound,
    /// Presumably the model does not support updates.
    NotImplemented,
}
99
/// A single validation failure, as returned in the `Err` list of
/// `ModelValidation::validate`.
///
/// Displays as `"<field>: <message>"` and implements [`std::error::Error`] so
/// it composes with ordinary error handling.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationError {
    /// Name of the offending field.
    pub field: String,
    /// Human-readable description of the problem.
    pub message: String,
}

impl std::fmt::Display for ValidationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}: {}", self.field, self.message)
    }
}

impl std::error::Error for ValidationError {}
106
107pub trait ModelValidation {
108 fn validate(&self) -> Result<(), Vec<ValidationError>> {
109 Ok(())
110 }
111}
112
113impl<T> ModelValidation for T {}
114
115#[allow(async_fn_in_trait)]
116pub trait Model<DB: Database>: Sized + Send + Sync + Unpin
117where
118 DB: SqlDialect,
119 for<'r> Self: sqlx::FromRow<'r, DB::Row>,
120{
121 fn table_name() -> &'static str;
122 fn create_table_sql() -> String;
123 fn list_columns() -> Vec<String>;
124
125 async fn save<'e, E>(&mut self, executor: E) -> Result<(), sqlx::Error>
135 where
136 E: SqlxExecutor<'e, Database = DB>;
137
138 async fn update(&mut self, executor: Executor<'_, DB>) -> Result<UpdateResult, sqlx::Error>;
139
140 async fn delete(&mut self, executor: Executor<'_, DB>) -> Result<(), sqlx::Error>;
142 fn has_soft_delete() -> bool;
143
144 async fn find_by_id<'e, E>(executor: E, id: i32) -> Result<Option<Self>, sqlx::Error>
151 where
152 E: SqlxExecutor<'e, Database = DB>;
153
154 async fn eager_load<'e, E>(
155 _models: &mut [Self],
156 _relation: &str,
157 _executor: E,
158 ) -> Result<(), sqlx::Error>
159 where
160 E: SqlxExecutor<'e, Database = DB>,
161 {
162 Ok(())
163 }
164
165 fn find(executor: Executor<'_, DB>) -> QueryBuilder<'_, Self, DB> {
166 QueryBuilder::new(executor)
167 }
168
169 fn find_in_pool(pool: &sqlx::Pool<DB>) -> QueryBuilder<'_, Self, DB> {
171 QueryBuilder::new(Executor::Pool(pool))
172 }
173
174 fn find_in_tx(conn: &mut DB::Connection) -> QueryBuilder<'_, Self, DB> {
175 QueryBuilder::new(Executor::Conn(conn))
176 }
177}
178
179pub struct QueryBuilder<'a, T, DB: Database> {
180 executor: Executor<'a, DB>,
181 filters: Vec<String>,
182 limit: Option<i32>,
183 offset: Option<i32>,
184 includes: Vec<String>,
185 include_deleted: bool, _marker: std::marker::PhantomData<T>,
187}
188
189impl<'a, T, DB> QueryBuilder<'a, T, DB>
190where
191 DB: SqlDialect,
192 T: Model<DB>,
193{
194 pub fn new(executor: Executor<'a, DB>) -> Self {
195 Self {
196 executor,
197 filters: Vec::new(),
198 limit: None,
199 offset: None,
200 includes: Vec::new(),
201 include_deleted: false,
202 _marker: std::marker::PhantomData,
203 }
204 }
205
206 pub fn filter(mut self, condition: impl Into<String>) -> Self {
207 self.filters.push(condition.into());
208 self
209 }
210
211 pub fn limit(mut self, limit: i32) -> Self {
212 self.limit = Some(limit);
213 self
214 }
215
216 pub fn offset(mut self, offset: i32) -> Self {
217 self.offset = Some(offset);
218 self
219 }
220
221 pub fn include(mut self, relation: impl Into<String>) -> Self {
222 self.includes.push(relation.into());
223 self
224 }
225
226 pub fn with_deleted(mut self) -> Self {
228 self.include_deleted = true;
229 self
230 }
231
232 fn build_where_clause(&self) -> String {
233 let mut filters = self.filters.clone();
234
235 if T::has_soft_delete() && !self.include_deleted {
237 filters.push("deleted_at IS NULL".to_string());
238 }
239
240 if filters.is_empty() {
241 "".to_string()
242 } else {
243 format!(" WHERE {}", filters.join(" AND "))
244 }
245 }
246}
247
248impl<'a, T, DB> QueryBuilder<'a, T, DB>
249where
250 DB: SqlDialect,
251 T: Model<DB>,
252 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
253 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
254 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
255{
256 pub async fn all(mut self) -> Result<Vec<T>, sqlx::Error> {
257 let mut sql = format!(
258 "SELECT * FROM {}{}",
259 T::table_name(),
260 self.build_where_clause()
261 );
262
263 if let Some(limit) = self.limit {
264 sql.push_str(&format!(" LIMIT {}", limit));
265 }
266
267 if let Some(offset) = self.offset {
268 sql.push_str(&format!(" OFFSET {}", offset));
269 }
270
271 let mut results: Vec<T> = match &mut self.executor {
272 Executor::Pool(pool) => sqlx::query_as::<DB, T>(&sql).fetch_all(*pool).await?,
273 Executor::Conn(conn) => sqlx::query_as::<DB, T>(&sql).fetch_all(&mut **conn).await?,
274 };
275
276 for relation in self.includes {
277 match &mut self.executor {
278 Executor::Pool(pool) => {
279 T::eager_load(&mut results, &relation, *pool).await?;
280 }
281 Executor::Conn(conn) => {
282 T::eager_load(&mut results, &relation, &mut **conn).await?;
283 }
284 }
285 }
286
287 Ok(results)
288 }
289
290 pub async fn update(mut self, values: serde_json::Value) -> Result<u64, sqlx::Error>
292 where
293 String: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
294 i64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
295 f64: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
296 bool: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
297 Option<String>: for<'q> sqlx::Encode<'q, DB> + sqlx::Type<DB>,
298 {
299 let obj = values.as_object().ok_or_else(|| {
300 sqlx::Error::Protocol("Bulk update requires a JSON object".to_string())
301 })?;
302
303 let mut i = 1;
304 let set_clause = obj
305 .keys()
306 .map(|k| {
307 let p = DB::placeholder(i);
308 i += 1;
309 format!("{} = {}", k, p)
310 })
311 .collect::<Vec<_>>()
312 .join(", ");
313
314 let sql = format!(
315 "UPDATE {} SET {}{}",
316 T::table_name(),
317 set_clause,
318 self.build_where_clause()
319 );
320
321 let mut query = sqlx::query::<DB>(&sql);
322 for val in obj.values() {
323 match val {
324 serde_json::Value::String(s) => query = query.bind(s.clone()),
325 serde_json::Value::Number(n) => {
326 if let Some(v) = n.as_i64() {
327 query = query.bind(v);
328 } else if let Some(v) = n.as_f64() {
329 query = query.bind(v);
330 }
331 }
332 serde_json::Value::Bool(b) => query = query.bind(*b),
333 serde_json::Value::Null => query = query.bind(Option::<String>::None),
334 _ => {
335 return Err(sqlx::Error::Protocol(
336 "Unsupported type in bulk update".to_string(),
337 ));
338 }
339 }
340 }
341
342 match &mut self.executor {
343 Executor::Pool(pool) => {
344 let res = query.execute(*pool).await?;
345 Ok(DB::rows_affected(&res))
346 }
347 Executor::Conn(conn) => {
348 let res = query.execute(&mut **conn).await?;
349 Ok(DB::rows_affected(&res))
350 }
351 }
352 }
353
354 pub async fn delete(mut self) -> Result<u64, sqlx::Error> {
355 let sql = if T::has_soft_delete() {
356 format!(
357 "UPDATE {} SET deleted_at = {}{}",
358 T::table_name(),
359 DB::current_timestamp_fn(),
360 self.build_where_clause()
361 )
362 } else {
363 format!(
364 "DELETE FROM {}{}",
365 T::table_name(),
366 self.build_where_clause()
367 )
368 };
369
370 match &mut self.executor {
371 Executor::Pool(pool) => {
372 let res = sqlx::query::<DB>(&sql).execute(*pool).await?;
373 Ok(DB::rows_affected(&res))
374 }
375 Executor::Conn(conn) => {
376 let res = sqlx::query::<DB>(&sql).execute(&mut **conn).await?;
377 Ok(DB::rows_affected(&res))
378 }
379 }
380 }
381}
382
383impl Premix {
384 pub async fn sync<DB, M>(pool: &sqlx::Pool<DB>) -> Result<(), sqlx::Error>
385 where
386 DB: SqlDialect,
387 M: Model<DB>,
388 for<'q> <DB as Database>::Arguments<'q>: IntoArguments<'q, DB>,
389 for<'c> &'c mut <DB as Database>::Connection: SqlxExecutor<'c, Database = DB>,
390 for<'c> &'c str: sqlx::ColumnIndex<DB::Row>,
391 {
392 let sql = M::create_table_sql();
393 sqlx::query::<DB>(&sql).execute(pool).await?;
394 Ok(())
395 }
396}