premix_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4    Attribute, Data, DeriveInput, Field, Fields, Ident, Token, parse_macro_input,
5    punctuated::Punctuated,
6};
7
8mod relations;
9
10#[proc_macro_derive(Model, attributes(has_many, belongs_to, premix))]
11pub fn derive_model(input: TokenStream) -> TokenStream {
12    let input = parse_macro_input!(input as DeriveInput);
13    match derive_model_impl(&input) {
14        Ok(tokens) => TokenStream::from(tokens),
15        Err(err) => TokenStream::from(err.to_compile_error()),
16    }
17}
18
19fn derive_model_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
20    let impl_block = generate_generic_impl(input)?;
21    let rel_block = relations::impl_relations(input)?;
22    Ok(quote! {
23        #impl_block
24        #rel_block
25    })
26}
27
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Expands `generate_generic_impl`, panicking on failure, and renders the
    /// generated tokens as a string for substring checks.
    fn expand(input: &DeriveInput) -> String {
        generate_generic_impl(input).unwrap().to_string()
    }

    /// Expands `generate_generic_impl` on input that must be rejected and
    /// returns the rendered error message.
    fn expand_err(input: &DeriveInput) -> String {
        generate_generic_impl(input).unwrap_err().to_string()
    }

    /// Asserts that every entry of `needles` occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected expansion to contain {needle:?}"
            );
        }
    }

    #[test]
    fn generate_generic_impl_includes_table_and_columns() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                version: i32,
                deleted_at: Option<String>,
            }
        };
        assert_contains_all(
            &expand(&input),
            &["CREATE TABLE IF NOT EXISTS", "users", "deleted_at", "version"],
        );
    }

    #[test]
    fn generate_generic_impl_rejects_tuple_struct() {
        let input: DeriveInput = parse_quote! {
            struct User(i32, String);
        };
        assert!(expand_err(&input).contains("named fields"));
    }

    #[test]
    fn generate_generic_impl_rejects_non_struct() {
        let input: DeriveInput = parse_quote! {
            enum User {
                A,
                B,
            }
        };
        assert!(expand_err(&input).contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_version_update_branch() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                version: i32,
                name: String,
            }
        };
        assert!(expand(&input).contains("version = version + 1"));
    }

    #[test]
    fn generate_generic_impl_no_version_branch() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        assert!(!expand(&input).contains("version = version + 1"));
    }

    #[test]
    fn generate_generic_impl_includes_default_hooks_and_validation() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        assert_contains_all(&expand(&input), &["ModelHooks", "ModelValidation"]);
    }

    #[test]
    fn generate_generic_impl_skips_custom_hooks_and_validation() {
        let input: DeriveInput = parse_quote! {
            #[premix(custom_hooks, custom_validation)]
            struct User {
                id: i32,
                name: String,
            }
        };
        let tokens = expand(&input);
        for unwanted in [
            "impl premix_orm :: ModelHooks",
            "impl premix_orm :: ModelValidation",
        ] {
            assert!(
                !tokens.contains(unwanted),
                "unexpected {unwanted:?} in expansion"
            );
        }
    }

    #[test]
    fn is_ignored_detects_attribute() {
        let field: Field = parse_quote! {
            #[premix(ignore)]
            ignored: Option<String>
        };
        assert!(is_ignored(&field));
    }

    #[test]
    fn is_ignored_false_for_other_attrs() {
        let field: Field = parse_quote! {
            #[serde(skip)]
            name: String
        };
        assert!(!is_ignored(&field));
    }

    #[test]
    fn is_ignored_false_for_premix_other_arg() {
        let field: Field = parse_quote! {
            #[premix(skip)]
            name: String
        };
        assert!(!is_ignored(&field));
    }

    #[test]
    fn is_ignored_false_when_premix_has_no_args() {
        let field: Field = parse_quote! {
            #[premix]
            name: String
        };
        assert!(!is_ignored(&field));
    }

    #[test]
    fn derive_model_impl_emits_tokens() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
            }
        };
        let rendered = derive_model_impl(&input).unwrap().to_string();
        assert!(rendered.contains("impl"));
    }

    #[test]
    fn derive_model_impl_propagates_error() {
        let input: DeriveInput = parse_quote! {
            enum User {
                A,
            }
        };
        let message = derive_model_impl(&input).unwrap_err().to_string();
        assert!(message.contains("only supports structs"));
    }

    #[test]
    fn generate_generic_impl_includes_soft_delete_delete_impl() {
        let input: DeriveInput = parse_quote! {
            struct AuditLog {
                id: i32,
                deleted_at: Option<String>,
            }
        };
        assert_contains_all(&expand(&input), &["deleted_at =", "has_soft_delete"]);
    }

    #[test]
    fn generate_generic_impl_ignores_marked_fields() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                name: String,
                #[premix(ignore)]
                temp: Option<String>,
            }
        };
        let tokens = expand(&input);
        // Ignored fields are initialised in `from_row` but never named as a column.
        assert!(tokens.contains("temp : None"));
        assert!(!tokens.contains("\"temp\""));
    }

    #[test]
    fn generate_generic_impl_adds_relation_bounds() {
        let input: DeriveInput = parse_quote! {
            struct User {
                id: i32,
                #[has_many(Post)]
                posts: Vec<Post>,
            }
        };
        assert!(expand(&input).contains("Post : premix_orm :: Model < DB >"));
    }

    #[test]
    fn generate_generic_impl_records_field_names() {
        let input: DeriveInput = parse_quote! {
            struct Account {
                id: i32,
                user_id: i32,
                is_active: bool,
            }
        };
        assert_contains_all(&expand(&input), &["\"user_id\"", "\"is_active\""]);
    }
}
238
239fn generate_generic_impl(input: &DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
240    let struct_name = &input.ident;
241    let table_name = struct_name.to_string().to_lowercase() + "s";
242    let custom_hooks = has_premix_flag(&input.attrs, "custom_hooks");
243    let custom_validation = has_premix_flag(&input.attrs, "custom_validation");
244
245    let all_fields = if let Data::Struct(data) = &input.data {
246        if let Fields::Named(fields) = &data.fields {
247            &fields.named
248        } else {
249            return Err(syn::Error::new_spanned(
250                &data.fields,
251                "Premix Model only supports structs with named fields",
252            ));
253        }
254    } else {
255        return Err(syn::Error::new_spanned(
256            input,
257            "Premix Model only supports structs",
258        ));
259    };
260
261    let mut db_fields = Vec::new();
262    let mut ignored_field_idents = Vec::new();
263
264    for field in all_fields {
265        if is_ignored(field) {
266            ignored_field_idents.push(field.ident.as_ref().unwrap());
267        } else {
268            db_fields.push(field);
269        }
270    }
271
272    let field_idents: Vec<_> = db_fields
273        .iter()
274        .map(|f| f.ident.as_ref().unwrap())
275        .collect();
276    let field_types: Vec<_> = db_fields.iter().map(|f| &f.ty).collect();
277    let field_indices: Vec<_> = (0..db_fields.len()).collect();
278    let field_names: Vec<_> = field_idents.iter().map(|id| id.to_string()).collect();
279    let field_idents_len = field_idents.len();
280
281    let eager_load_body = relations::generate_eager_load_body(input)?;
282    let has_version = field_names.contains(&"version".to_string());
283    let has_soft_delete = field_names.contains(&"deleted_at".to_string());
284
285    let update_impl = if has_version {
286        quote! {
287            async fn update<'a, E>(&mut self, executor: E) -> Result<premix_orm::UpdateResult, premix_orm::sqlx::Error>
288            where
289                E: premix_orm::IntoExecutor<'a, DB = DB>
290            {
291                let mut executor = executor.into_executor();
292                let table_name = Self::table_name();
293                let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_orm::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
294                let id_p = <DB as premix_orm::SqlDialect>::placeholder(1 + #field_idents_len);
295                let ver_p = <DB as premix_orm::SqlDialect>::placeholder(2 + #field_idents_len);
296                let sql = format!(
297                    "UPDATE {} SET {}, version = version + 1 WHERE id = {} AND version = {}",
298                    table_name, set_clause, id_p, ver_p
299                );
300
301                let mut query = premix_orm::sqlx::query::<DB>(&sql)
302                    #( .bind(&self.#field_idents) )*
303                    .bind(&self.id)
304                    .bind(&self.version);
305
306                let result = executor.execute(query).await?;
307
308                if <DB as premix_orm::SqlDialect>::rows_affected(&result) == 0 {
309                    let exists_p = <DB as premix_orm::SqlDialect>::placeholder(1);
310                    let exists_sql = format!("SELECT id FROM {} WHERE id = {}", table_name, exists_p);
311                    let exists_query = premix_orm::sqlx::query_as::<DB, (i32,)>(&exists_sql).bind(&self.id);
312                    let exists = executor.fetch_optional(exists_query).await?;
313
314                    if exists.is_none() {
315                        Ok(premix_orm::UpdateResult::NotFound)
316                    } else {
317                        Ok(premix_orm::UpdateResult::VersionConflict)
318                    }
319                } else {
320                    self.version += 1;
321                    Ok(premix_orm::UpdateResult::Success)
322                }
323            }
324        }
325    } else {
326        quote! {
327            async fn update<'a, E>(&mut self, executor: E) -> Result<premix_orm::UpdateResult, premix_orm::sqlx::Error>
328            where
329                E: premix_orm::IntoExecutor<'a, DB = DB>
330            {
331                let mut executor = executor.into_executor();
332                let table_name = Self::table_name();
333                let set_clause = vec![ #( format!("{} = {}", #field_names, <DB as premix_orm::SqlDialect>::placeholder(1 + #field_indices)) ),* ].join(", ");
334                let id_p = <DB as premix_orm::SqlDialect>::placeholder(1 + #field_idents_len);
335                let sql = format!("UPDATE {} SET {} WHERE id = {}", table_name, set_clause, id_p);
336
337                let mut query = premix_orm::sqlx::query::<DB>(&sql)
338                    #( .bind(&self.#field_idents) )*
339                    .bind(&self.id);
340
341                let result = executor.execute(query).await?;
342
343                if <DB as premix_orm::SqlDialect>::rows_affected(&result) == 0 {
344                    Ok(premix_orm::UpdateResult::NotFound)
345                } else {
346                    Ok(premix_orm::UpdateResult::Success)
347                }
348            }
349        }
350    };
351
352    let delete_impl = if has_soft_delete {
353        quote! {
354            async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_orm::sqlx::Error>
355            where
356                E: premix_orm::IntoExecutor<'a, DB = DB>
357            {
358                let mut executor = executor.into_executor();
359                let table_name = Self::table_name();
360                let id_p = <DB as premix_orm::SqlDialect>::placeholder(1);
361                let sql = format!("UPDATE {} SET deleted_at = {} WHERE id = {}", table_name, <DB as premix_orm::SqlDialect>::current_timestamp_fn(), id_p);
362
363                let query = premix_orm::sqlx::query::<DB>(&sql).bind(&self.id);
364                executor.execute(query).await?;
365
366                self.deleted_at = Some("DELETED".to_string());
367                Ok(())
368            }
369            fn has_soft_delete() -> bool { true }
370        }
371    } else {
372        quote! {
373            async fn delete<'a, E>(&mut self, executor: E) -> Result<(), premix_orm::sqlx::Error>
374            where
375                E: premix_orm::IntoExecutor<'a, DB = DB>
376            {
377                let mut executor = executor.into_executor();
378                let table_name = Self::table_name();
379                let id_p = <DB as premix_orm::SqlDialect>::placeholder(1);
380                let sql = format!("DELETE FROM {} WHERE id = {}", table_name, id_p);
381
382                let query = premix_orm::sqlx::query::<DB>(&sql).bind(&self.id);
383                executor.execute(query).await?;
384
385                Ok(())
386            }
387            fn has_soft_delete() -> bool { false }
388        }
389    };
390
391    let mut related_model_bounds = Vec::new();
392    for field in all_fields {
393        for attr in &field.attrs {
394            if (attr.path().is_ident("has_many") || attr.path().is_ident("belongs_to"))
395                && let Ok(related_ident) = attr.parse_args::<syn::Ident>()
396            {
397                related_model_bounds.push(quote! { #related_ident: premix_orm::Model<DB> });
398            }
399        }
400    }
401
402    let hooks_impl = if custom_hooks {
403        quote! {}
404    } else {
405        quote! {
406            #[premix_orm::async_trait::async_trait]
407            impl premix_orm::ModelHooks for #struct_name {}
408        }
409    };
410
411    let validation_impl = if custom_validation {
412        quote! {}
413    } else {
414        quote! {
415            impl premix_orm::ModelValidation for #struct_name {}
416        }
417    };
418
419    // Generic Implementation
420    Ok(quote! {
421        impl<'r, R> premix_orm::sqlx::FromRow<'r, R> for #struct_name
422        where
423            R: premix_orm::sqlx::Row,
424            R::Database: premix_orm::sqlx::Database,
425            #(
426                #field_types: premix_orm::sqlx::Type<R::Database> + premix_orm::sqlx::Decode<'r, R::Database>,
427            )*
428            for<'c> &'c str: premix_orm::sqlx::ColumnIndex<R>,
429        {
430            fn from_row(row: &'r R) -> Result<Self, premix_orm::sqlx::Error> {
431                use premix_orm::sqlx::Row;
432                Ok(Self {
433                    #(
434                        #field_idents: row.try_get(#field_names)?,
435                    )*
436                    #(
437                        #ignored_field_idents: None,
438                    )*
439                })
440            }
441        }
442
443        #[premix_orm::async_trait::async_trait]
444        impl<DB> premix_orm::Model<DB> for #struct_name
445        where
446            DB: premix_orm::SqlDialect,
447            for<'c> &'c str: premix_orm::sqlx::ColumnIndex<DB::Row>,
448            usize: premix_orm::sqlx::ColumnIndex<DB::Row>,
449            for<'q> <DB as premix_orm::sqlx::Database>::Arguments<'q>: premix_orm::sqlx::IntoArguments<'q, DB>,
450            for<'c> &'c mut <DB as premix_orm::sqlx::Database>::Connection: premix_orm::sqlx::Executor<'c, Database = DB>,
451            i32: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
452            i64: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
453            String: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
454            bool: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
455            Option<String>: premix_orm::sqlx::Type<DB> + for<'q> premix_orm::sqlx::Encode<'q, DB> + for<'r> premix_orm::sqlx::Decode<'r, DB>,
456            #( #related_model_bounds, )*
457        {
458            fn table_name() -> &'static str {
459                #table_name
460            }
461
462            fn create_table_sql() -> String {
463                let mut cols = vec!["id ".to_string() + <DB as premix_orm::SqlDialect>::auto_increment_pk()];
464                #(
465                    if #field_names != "id" {
466                        let field_name: &str = #field_names;
467                        let sql_type = if field_name.ends_with("_id") {
468                            <DB as premix_orm::SqlDialect>::int_type()
469                        } else {
470                            match field_name {
471                                "name" | "title" | "status" | "email" | "role" => <DB as premix_orm::SqlDialect>::text_type(),
472                                "age" | "version" | "price" | "balance" => <DB as premix_orm::SqlDialect>::int_type(),
473                                "is_active" => <DB as premix_orm::SqlDialect>::bool_type(),
474                                "deleted_at" => <DB as premix_orm::SqlDialect>::text_type(),
475                                _ => <DB as premix_orm::SqlDialect>::text_type(),
476                            }
477                        };
478                        cols.push(format!("{} {}", #field_names, sql_type));
479                    }
480                )*
481                format!("CREATE TABLE IF NOT EXISTS {} ({})", #table_name, cols.join(", "))
482            }
483
484            fn list_columns() -> Vec<String> {
485                vec![ #( #field_names.to_string() ),* ]
486            }
487
488            async fn save<'a, E>(&mut self, executor: E) -> Result<(), premix_orm::sqlx::Error>
489            where
490                E: premix_orm::IntoExecutor<'a, DB = DB>
491            {
492                let mut executor = executor.into_executor();
493                use premix_orm::ModelHooks;
494                self.before_save().await?;
495
496                // Filter out 'id' and 'version' for INSERT
497                let columns: Vec<&str> = vec![ #( #field_names ),* ]
498                    .into_iter()
499                    .filter(|&c| {
500                        if c == "id" { return self.id != 0; }
501                        true
502                    })
503                    .collect();
504
505                let placeholders = (1..=columns.len())
506                    .map(|i| <DB as premix_orm::SqlDialect>::placeholder(i))
507                    .collect::<Vec<_>>()
508                    .join(", ");
509
510                let sql = format!("INSERT INTO {} ({}) VALUES ({})", #table_name, columns.join(", "), placeholders);
511
512                let mut query = premix_orm::sqlx::query::<DB>(&sql);
513
514                // Bind only non-id/version fields
515                #(
516                    if #field_names != "id" {
517                        query = query.bind(&self.#field_idents);
518                    } else {
519                        if self.id != 0 {
520                            query = query.bind(&self.id);
521                        }
522                    }
523                )*
524
525                let result = executor.execute(query).await?;
526
527                // Sync the ID from Database
528                let last_id = <DB as premix_orm::SqlDialect>::last_insert_id(&result);
529                if last_id > 0 {
530                     self.id = last_id as i32;
531                }
532
533                self.after_save().await?;
534                Ok(())
535            }
536
537            #update_impl
538            #delete_impl
539
540            async fn find_by_id<'a, E>(executor: E, id: i32) -> Result<Option<Self>, premix_orm::sqlx::Error>
541            where
542                E: premix_orm::IntoExecutor<'a, DB = DB>
543            {
544                let mut executor = executor.into_executor();
545                let p = <DB as premix_orm::SqlDialect>::placeholder(1);
546                let mut where_clause = format!("WHERE id = {}", p);
547                if Self::has_soft_delete() {
548                    where_clause.push_str(" AND deleted_at IS NULL");
549                }
550                let sql = format!("SELECT * FROM {} {} LIMIT 1", #table_name, where_clause);
551                let query = premix_orm::sqlx::query_as::<DB, Self>(&sql).bind(id);
552
553                executor.fetch_optional(query).await
554            }
555
556            async fn eager_load<'a, E>(models: &mut [Self], relation: &str, executor: E) -> Result<(), premix_orm::sqlx::Error>
557            where
558                E: premix_orm::IntoExecutor<'a, DB = DB>
559            {
560                let mut executor = executor.into_executor();
561                #eager_load_body
562            }
563        }
564
565        #hooks_impl
566        #validation_impl
567    })
568}
569
570fn is_ignored(field: &Field) -> bool {
571    for attr in &field.attrs {
572        if attr.path().is_ident("premix")
573            && let Ok(meta) = attr.parse_args::<syn::Ident>()
574            && meta == "ignore"
575        {
576            return true;
577        }
578    }
579    false
580}
581
582fn has_premix_flag(attrs: &[Attribute], flag: &str) -> bool {
583    for attr in attrs {
584        if attr.path().is_ident("premix") {
585            let args = attr.parse_args_with(Punctuated::<Ident, Token![,]>::parse_terminated);
586            if let Ok(args) = args {
587                if args.iter().any(|ident| ident == flag) {
588                    return true;
589                }
590            }
591        }
592    }
593    false
594}