premix_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{Data, DeriveInput, Field, Fields, parse_macro_input};
4
5mod relations;
6
7#[proc_macro_derive(Model, attributes(has_many, belongs_to, premix))]
8pub fn derive_model(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    match derive_model_impl(&input) {
11        Ok(tokens) => TokenStream::from(tokens),
12        Err(err) => TokenStream::from(err.to_compile_error()),
13    }
14}
15
16fn derive_model_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
17    let impl_block = generate_generic_impl(input)?;
18    let rel_block = relations::impl_relations(input)?;
19    Ok(quote! {
20        #impl_block
21        #rel_block
22    })
23}
24
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Expands the generic impl for `input` and renders it, panicking on failure.
    fn render(input: &DeriveInput) -> String {
        generate_generic_impl(input)
            .expect("generic impl should expand")
            .to_string()
    }

    /// Expands the generic impl for `input` and returns the error text, panicking on success.
    fn render_error(input: &DeriveInput) -> String {
        generate_generic_impl(input)
            .expect_err("generic impl should be rejected")
            .to_string()
    }

    #[test]
    fn generate_generic_impl_includes_table_and_columns() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                version: i32,
                deleted_at: Option<String>,
            }
        };
        let rendered = render(&model);
        for expected in ["CREATE TABLE IF NOT EXISTS", "users", "deleted_at", "version"] {
            assert!(rendered.contains(expected), "expansion is missing `{expected}`");
        }
    }

    #[test]
    fn generate_generic_impl_rejects_tuple_struct() {
        let model: DeriveInput = parse_quote! {
            struct User(i32, String);
        };
        assert!(render_error(&model).contains("named fields"));
    }

    #[test]
    fn generate_generic_impl_rejects_non_struct() {
        let model: DeriveInput = parse_quote! {
            enum User {
                A,
                B,
            }
        };
        assert!(render_error(&model).contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_version_update_branch() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                version: i32,
                name: String,
            }
        };
        assert!(render(&model).contains("version = version + 1"));
    }

    #[test]
    fn generate_generic_impl_no_version_branch() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        assert!(!render(&model).contains("version = version + 1"));
    }

    #[test]
    fn is_ignored_detects_attribute() {
        let candidate: Field = parse_quote! {
            #[premix(ignore)]
            ignored: Option<String>
        };
        assert!(is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_for_other_attrs() {
        let candidate: Field = parse_quote! {
            #[serde(skip)]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_for_premix_other_arg() {
        let candidate: Field = parse_quote! {
            #[premix(skip)]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn is_ignored_false_when_premix_has_no_args() {
        let candidate: Field = parse_quote! {
            #[premix]
            name: String
        };
        assert!(!is_ignored(&candidate));
    }

    #[test]
    fn derive_model_impl_emits_tokens() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        let expanded = derive_model_impl(&model)
            .expect("derive should expand")
            .to_string();
        assert!(expanded.contains("impl"));
    }

    #[test]
    fn derive_model_impl_propagates_error() {
        let model: DeriveInput = parse_quote! {
            enum User {
                A,
            }
        };
        let message = derive_model_impl(&model)
            .expect_err("enums must be rejected")
            .to_string();
        assert!(message.contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_includes_soft_delete_delete_impl() {
        let model: DeriveInput = parse_quote! {
            struct AuditLog {
                id: i32,
                deleted_at: Option<String>,
            }
        };
        let rendered = render(&model);
        assert!(rendered.contains("deleted_at ="));
        assert!(rendered.contains("has_soft_delete"));
    }

    #[test]
    fn generate_generic_impl_ignores_marked_fields() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                #[premix(ignore)]
                temp: Option<String>,
            }
        };
        let rendered = render(&model);
        assert!(rendered.contains("temp : None"));
        assert!(!rendered.contains("\"temp\""));
    }

    #[test]
    fn generate_generic_impl_adds_relation_bounds() {
        let model: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                #[has_many(Post)]
                posts: Vec<Post>,
            }
        };
        assert!(render(&model).contains("Post : premix_core :: Model < DB >"));
    }

    #[test]
    fn generate_generic_impl_records_field_names() {
        let model: DeriveInput = parse_quote! {
            struct Account {
                id: i32,
                user_id: i32,
                is_active: bool,
            }
        };
        let rendered = render(&model);
        assert!(rendered.contains("\"user_id\""));
        assert!(rendered.contains("\"is_active\""));
    }
}
208
209fn generate_generic_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
210    let struct_name = &input.ident;
211    let table_name = struct_name.to_string().to_lowercase() + "s";
212
213    let all_fields = if let Data::Struct(data) = &input.data {
214        if let Fields::Named(fields) = &data.fields {
215            &fields.named
216        } else {
217            return Err(syn::Error::new_spanned(
218                &data.fields,
219                "Premix Model only supports structs with named fields",
220            ));
221        }
222    } else {
223        return Err(syn::Error::new_spanned(
224            input,
225            "Premix Model only supports structs",
226        ));
227    };
228
229    let mut db_fields = Vec::new();
230    let mut ignored_field_idents = Vec::new();
231
232    for field in all_fields {
233        if is_ignored(field) {
234            ignored_field_idents.push(field.ident.as_ref().unwrap());
235        } else {
236            db_fields.push(field);
237        }
238    }
239
240    let field_idents: Vec<_> = db_fields
241        .iter()
242        .map(|f| f.ident.as_ref().unwrap())
243        .collect();
244    let field_types: Vec<_> = db_fields.iter().map(|f| &f.ty).collect();
245    let field_indices: Vec<_> = (0..db_fields.len()).collect();
246    let field_names: Vec<_> = field_idents.iter().map(|id| id.to_string()).collect();
247    let field_idents_len = field_idents.len();
248
249    let eager_load_body = relations::generate_eager_load_body(input)?;
250    let has_version = field_names.contains(&"version".to_string());
251    let has_soft_delete = field_names.contains(&"deleted_at".to_string());
252
253    let update_impl = if has_version {
254        quote! {
255            async fn update<'a, E>(&mut self, executor: E) -> Result<premix_core::UpdateResult, premix_core::sqlx::Error>
256            where
257                E: premix_core::IntoExecutor<'a, DB = DB>
258            {
259                let mut executor = executor.into_executor();
260                let table_name = Self::table_name();
261                let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_core::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
262                let id_p = <DB as premix_core::SqlDialect>::placeholder(1 + #field_idents_len);
263                let ver_p = <DB as premix_core::SqlDialect>::placeholder(2 + #field_idents_len);
264                let sql = format!(
265                    "UPDATE {} SET {}, version = version + 1 WHERE id = {} AND version = {}",
266                    table_name, set_clause, id_p, ver_p
267                );
268
269                let mut query = premix_core::sqlx::query::<DB>(&sql)
270                    #( .bind(&self.#field_idents) )*
271                    .bind(&self.id)
272                    .bind(&self.version);
273
274                let result = executor.execute(query).await?;
275
276                if <DB as premix_core::SqlDialect>::rows_affected(&result) == 0 {
277                    let exists_p = <DB as premix_core::SqlDialect>::placeholder(1);
278                    let exists_sql = format!("SELECT id FROM {} WHERE id = {}", table_name, exists_p);
279                    let exists_query = premix_core::sqlx::query_as::<DB, (i32,)>(&exists_sql).bind(&self.id);
280                    let exists = executor.fetch_optional(exists_query).await?;
281
282                    if exists.is_none() {
283                        Ok(premix_core::UpdateResult::NotFound)
284                    } else {
285                        Ok(premix_core::UpdateResult::VersionConflict)
286                    }
287                } else {
288                    self.version += 1;
289                    Ok(premix_core::UpdateResult::Success)
290                }
291            }
292        }
293    } else {
294        quote! {
295            async fn update<'a, E>(&mut self, executor: E) -> Result<premix_core::UpdateResult, premix_core::sqlx::Error>
296            where
297                E: premix_core::IntoExecutor<'a, DB = DB>
298            {
299                let mut executor = executor.into_executor();
300                let table_name = Self::table_name();
301                let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_core::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
302                let id_p = <DB as premix_core::SqlDialect>::placeholder(1 + #field_idents_len);
303                let sql = format!("UPDATE {} SET {} WHERE id = {}", table_name, set_clause, id_p);
304
305                let mut query = premix_core::sqlx::query::<DB>(&sql)
306                    #( .bind(&self.#field_idents) )*
307                    .bind(&self.id);
308
309                let result = executor.execute(query).await?;
310
311                if <DB as premix_core::SqlDialect>::rows_affected(&result) == 0 {
312                    Ok(premix_core::UpdateResult::NotFound)
313                } else {
314                    Ok(premix_core::UpdateResult::Success)
315                }
316            }
317        }
318    };
319
320    let delete_impl = if has_soft_delete {
321        quote! {
322            async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_core::sqlx::Error>
323            where
324                E: premix_core::IntoExecutor<'a, DB = DB>
325            {
326                let mut executor = executor.into_executor();
327                let table_name = Self::table_name();
328                let id_p = <DB as premix_core::SqlDialect>::placeholder(1);
329                let sql = format!("UPDATE {} SET deleted_at = {} WHERE id = {}", table_name, <DB as premix_core::SqlDialect>::current_timestamp_fn(), id_p);
330
331                let query = premix_core::sqlx::query::<DB>(&sql).bind(&self.id);
332                executor.execute(query).await?;
333
334                self.deleted_at = Some("DELETED".to_string());
335                Ok(())
336            }
337            fn has_soft_delete() -> bool { true }
338        }
339    } else {
340        quote! {
341            async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_core::sqlx::Error>
342            where
343                E: premix_core::IntoExecutor<'a, DB = DB>
344            {
345                let mut executor = executor.into_executor();
346                let table_name = Self::table_name();
347                let id_p = <DB as premix_core::SqlDialect>::placeholder(1);
348                let sql = format!("DELETE FROM {} WHERE id = {}", table_name, id_p);
349
350                let query = premix_core::sqlx::query::<DB>(&sql).bind(&self.id);
351                executor.execute(query).await?;
352
353                Ok(())
354            }
355            fn has_soft_delete() -> bool { false }
356        }
357    };
358
359    let mut related_model_bounds = Vec::new();
360    for field in all_fields {
361        for attr in &field.attrs {
362            if (attr.path().is_ident("has_many") || attr.path().is_ident("belongs_to"))
363                && let Ok(related_ident) = attr.parse_args::<syn::Ident>()
364            {
365                related_model_bounds.push(quote! { #related_ident: premix_core::Model<DB> });
366            }
367        }
368    }
369
370    // Generic Implementation
371    Ok(quote! {
372        impl<'r, R> premix_core::sqlx::FromRow<'r, R> for #struct_name
373        where
374            R: premix_core::sqlx::Row,
375            R::Database: premix_core::sqlx::Database,
376            #(
377                #field_types: premix_core::sqlx::Type<R::Database> + premix_core::sqlx::Decode<'r, R::Database>,
378            )*
379            for<'c> &'c str: premix_core::sqlx::ColumnIndex<R>,
380        {
381            fn from_row(row: &'r R) -> Result<Self, premix_core::sqlx::Error> {
382                use premix_core::sqlx::Row;
383                Ok(Self {
384                    #(
385                        #field_idents: row.try_get(#field_names)?,
386                    )*
387                    #(
388                        #ignored_field_idents: None,
389                    )*
390                })
391            }
392        }
393
394        #[premix_core::async_trait::async_trait]
395        impl<DB> premix_core::Model<DB> for #struct_name
396        where
397            DB: premix_core::SqlDialect,
398            for<'c> &'c str: premix_core::sqlx::ColumnIndex<DB::Row>,
399            usize: premix_core::sqlx::ColumnIndex<DB::Row>,
400            for<'q> <DB as premix_core::sqlx::Database>::Arguments<'q>: premix_core::sqlx::IntoArguments<'q, DB>,
401            for<'c> &'c mut <DB as premix_core::sqlx::Database>::Connection: premix_core::sqlx::Executor<'c, Database = DB>,
402            i32: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
403            i64: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
404            String: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
405            bool: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
406            Option<String>: premix_core::sqlx::Type<DB> + for<'q> premix_core::sqlx::Encode<'q, DB> + for<'r> premix_core::sqlx::Decode<'r, DB>,
407            #( #related_model_bounds, )*
408        {
409            fn table_name() -> &'static str {
410                #table_name
411            }
412
413            fn create_table_sql() -> String {
414                let mut cols = vec!["id ".to_string() + <DB as premix_core::SqlDialect>::auto_increment_pk()];
415                #(
416                    if #field_names != "id" {
417                        let field_name: &str = #field_names;
418                        let sql_type = if field_name.ends_with("_id") {
419                            <DB as premix_core::SqlDialect>::int_type()
420                        } else {
421                            match field_name {
422                                "name" | "title" | "status" | "email" | "role" => <DB as premix_core::SqlDialect>::text_type(),
423                                "age" | "version" | "price" | "balance" => <DB as premix_core::SqlDialect>::int_type(),
424                                "is_active" => <DB as premix_core::SqlDialect>::bool_type(),
425                                "deleted_at" => <DB as premix_core::SqlDialect>::text_type(),
426                                _ => <DB as premix_core::SqlDialect>::text_type(),
427                            }
428                        };
429                        cols.push(format!("{} {}", #field_names, sql_type));
430                    }
431                )*
432                format!("CREATE TABLE IF NOT EXISTS {} ({})", #table_name, cols.join(", "))
433            }
434
435            fn list_columns() -> Vec<String> {
436                vec![ #( #field_names.to_string() ),* ]
437            }
438
439            async fn save<'a, E>(&mut self, executor: E) -> Result<(), premix_core::sqlx::Error>
440            where
441                E: premix_core::IntoExecutor<'a, DB = DB>
442            {
443                let mut executor = executor.into_executor();
444                use premix_core::ModelHooks;
445                self.before_save().await?;
446
447                // Filter out 'id' and 'version' for INSERT
448                let columns: Vec<&str> = vec![ #( #field_names ),* ]
449                    .into_iter()
450                    .filter(|&c| {
451                        if c == "id" { return self.id != 0; }
452                        true
453                    })
454                    .collect();
455
456                let placeholders = (1..=columns.len())
457                    .map(|i| <DB as premix_core::SqlDialect>::placeholder(i))
458                    .collect::<Vec<_>>()
459                    .join(", ");
460
461                let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, columns.join(", "), placeholders);
462
463                let mut query = premix_core::sqlx::query::<DB>(&sql);
464
465                // Bind only non-id/version fields
466                #(
467                    if #field_names != "id" {
468                        query = query.bind(&self.#field_idents);
469                    } else {
470                        if self.id != 0 {
471                            query = query.bind(&self.id);
472                        }
473                    }
474                )*
475
476                let result = executor.execute(query).await?;
477
478                // Sync the ID from Database
479                let last_id = <DB as premix_core::SqlDialect>::last_insert_id(&result);
480                if last_id > 0 {
481                     self.id = last_id as i32;
482                }
483
484                self.after_save().await?;
485                Ok(())
486            }
487
488            #update_impl
489            #delete_impl
490
491            async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, premix_core::sqlx::Error>
492            where
493                E: premix_core::IntoExecutor<'a, DB = DB>
494            {
495                let mut executor = executor.into_executor();
496                let p = <DB as premix_core::SqlDialect>::placeholder(1);
497                let mut where_clause = format!("WHERE id = {}", p);
498                if Self::has_soft_delete() {
499                    where_clause.push_str(" AND deleted_at IS NULL");
500                }
501                let sql = format!("SELECT * FROM {} {} LIMIT 1", #table_name, where_clause);
502                let query = premix_core::sqlx::query_as::<DB, Self>(&sql).bind(id);
503
504                executor.fetch_optional(query).await
505            }
506
507            async fn eager_load<'a, E>(models: &mut [Self], relation: &str, executor: E) -> Result<(), premix_core::sqlx::Error>
508            where
509                E: premix_core::IntoExecutor<'a, DB = DB>
510            {
511                let mut executor = executor.into_executor();
512                #eager_load_body
513            }
514        }
515    })
516}
517
518fn is_ignored(field: &Field) -> bool {
519    for attr in &field.attrs {
520        if attr.path().is_ident("premix")
521            && let Ok(meta) = attr.parse_args::<syn::Ident>()
522            && meta == "ignore"
523        {
524            return true;
525        }
526    }
527    false
528}