// oxidite_macros/lib.rs

use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::ext::IdentExt;
use syn::{parse_macro_input, Data, DeriveInput, Fields};
4
5#[proc_macro_derive(Model, attributes(validate))]
6pub fn derive_model(input: TokenStream) -> TokenStream {
7    let input = parse_macro_input!(input as DeriveInput);
8    let name = input.ident;
9    
10    // Generate table name: snake_case + 's' (very naive pluralization for now)
11    let table_name = name.to_string().to_lowercase() + "s";
12    
13    // Get fields
14    let fields = match input.data {
15        Data::Struct(ref data) => {
16            match data.fields {
17                Fields::Named(ref fields) => {
18                    fields.named.iter().collect::<Vec<_>>()
19                },
20                _ => Vec::new(),
21            }
22        },
23        _ => Vec::new(),
24    };
25    
26    let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
27    let field_names_str: Vec<_> = field_names.iter().map(|f| f.to_string()).collect();
28    
29    // Check for timestamp fields
30    let has_created_at = field_names_str.contains(&"created_at".to_string());
31    let has_updated_at = field_names_str.contains(&"updated_at".to_string());
32    
33    // Check for soft delete field
34    let has_deleted_at = field_names_str.contains(&"deleted_at".to_string());
35
36    // Filter out 'id' for create/update columns
37    // Also filter out timestamps from bind list because we will handle them manually
38    let non_id_fields: Vec<_> = fields.iter()
39        .filter(|f| {
40            let name = f.ident.as_ref().unwrap().to_string();
41            name != "id" && name != "created_at" && name != "updated_at" && name != "deleted_at"
42        })
43        .collect();
44        
45    let non_id_names: Vec<_> = non_id_fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
46    let non_id_names_str: Vec<_> = non_id_names.iter().map(|f| f.to_string()).collect();
47    
48    // Create query generation
49    let mut create_cols_list = non_id_names_str.clone();
50    if has_created_at { create_cols_list.push("created_at".to_string()); }
51    if has_updated_at { create_cols_list.push("updated_at".to_string()); }
52    // deleted_at is usually null on creation, so we skip it unless we want to support creating deleted records
53    
54    let create_cols = create_cols_list.join(", ");
55    let create_placeholders: Vec<_> = (1..=create_cols_list.len()).map(|i| format!("${}", i)).collect();
56    let create_placeholders_str = create_placeholders.join(", ");
57    let create_query = format!("INSERT INTO {} ({}) VALUES ({})", table_name, create_cols, create_placeholders_str);
58    
59    // Update query generation
60    let mut update_sets_list = Vec::new();
61    for (i, name) in non_id_names_str.iter().enumerate() {
62        update_sets_list.push(format!("{} = ${}", name, i + 1));
63    }
64    
65    let mut param_count = non_id_names_str.len();
66    if has_updated_at {
67        param_count += 1;
68        update_sets_list.push(format!("updated_at = ${}", param_count));
69    }
70    
71    let update_sets_str = update_sets_list.join(", ");
72    let update_where = format!("WHERE id = ${}", param_count + 1);
73    let update_query = format!("UPDATE {} SET {} {}", table_name, update_sets_str, update_where);
74    
75    // Delete query generation
76    let hard_delete_query = format!("DELETE FROM {} WHERE id = $1", table_name);
77    
78    let delete_impl = if has_deleted_at {
79        let soft_delete_query = format!("UPDATE {} SET deleted_at = $1 WHERE id = $2", table_name);
80        quote! {
81            async fn delete(&self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
82                let now = oxidite_db::chrono::Utc::now().timestamp();
83                let query = oxidite_db::sqlx::query(#soft_delete_query)
84                    .bind(now)
85                    .bind(&self.id);
86                db.execute_query(query).await?;
87                Ok(())
88            }
89        }
90    } else {
91        quote! {
92            async fn delete(&self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
93                let query = oxidite_db::sqlx::query(#hard_delete_query)
94                    .bind(&self.id);
95                db.execute_query(query).await?;
96                Ok(())
97            }
98        }
99    };
100
101    // Code generation parts for timestamps
102    let created_at_logic = if has_created_at {
103        quote! {
104            let now = oxidite_db::chrono::Utc::now().timestamp();
105            self.created_at = now;
106            let query = query.bind(now);
107        }
108    } else {
109        quote! {}
110    };
111
112    let updated_at_create_logic = if has_updated_at {
113        quote! {
114            let now = oxidite_db::chrono::Utc::now().timestamp();
115            self.updated_at = now;
116            let query = query.bind(now);
117        }
118    } else {
119        quote! {}
120    };
121
122    let updated_at_update_logic = if has_updated_at {
123        quote! {
124            let now = oxidite_db::chrono::Utc::now().timestamp();
125            self.updated_at = now;
126            let query = query.bind(now);
127        }
128    } else {
129        quote! {}
130    };
131
132    // Generate validation checks
133    let mut validation_checks = Vec::new();
134    for field in &fields {
135        let field_name = field.ident.as_ref().unwrap();
136        for attr in &field.attrs {
137            if attr.path().is_ident("validate") {
138                let attr_str = attr.to_token_stream().to_string();
139                if attr_str.contains("email") {
140                    validation_checks.push(quote! {
141                        {
142                            static EMAIL_REGEX: oxidite_db::once_cell::sync::Lazy<oxidite_db::regex::Regex> = 
143                                oxidite_db::once_cell::sync::Lazy::new(|| oxidite_db::regex::Regex::new(r"^[^@\s]+@[^@\s]+\.[^@\s]+$").unwrap());
144                            if !EMAIL_REGEX.is_match(&self.#field_name) {
145                                return Err(format!("Invalid email format for field {}", stringify!(#field_name)));
146                            }
147                        }
148                    });
149                }
150            }
151        }
152    }
153
154    let expanded = quote! {
155        #[oxidite_db::async_trait]
156        impl oxidite_db::Model for #name {
157            fn table_name() -> &'static str {
158                #table_name
159            }
160            
161            fn fields() -> &'static [&'static str] {
162                &[#(#field_names_str),*]
163            }
164            
165            fn has_soft_delete() -> bool {
166                #has_deleted_at
167            }
168            
169            async fn create(&mut self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
170                let query = oxidite_db::sqlx::query(#create_query);
171                #(
172                    let query = query.bind(&self.#non_id_names);
173                )*
174                #created_at_logic
175                #updated_at_create_logic
176                
177                db.execute_query(query).await?;
178                Ok(())
179            }
180            
181            async fn update(&mut self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
182                let query = oxidite_db::sqlx::query(#update_query);
183                #(
184                    let query = query.bind(&self.#non_id_names);
185                )*
186                #updated_at_update_logic
187                
188                let query = query.bind(&self.id);
189                db.execute_query(query).await?;
190                Ok(())
191            }
192            
193            #delete_impl
194            
195            async fn force_delete(&self, db: &impl oxidite_db::Database) -> oxidite_db::Result<()> {
196                let query = oxidite_db::sqlx::query(#hard_delete_query)
197                    .bind(&self.id);
198                db.execute_query(query).await?;
199                Ok(())
200            }
201            
202            fn validate(&self) -> std::result::Result<(), String> {
203                #(#validation_checks)*
204                Ok(())
205            }
206        }
207    };
208    
209    TokenStream::from(expanded)
210}