// chopin_orm_macro/lib.rs

extern crate proc_macro;

use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{Data, DeriveInput, Fields, LitStr, parse_macro_input};

7#[proc_macro_derive(Model, attributes(model))]
8pub fn derive_model(input: TokenStream) -> TokenStream {
9    let input = parse_macro_input!(input as DeriveInput);
10    let name = &input.ident;
11
12    let mut table_name = name.to_string().to_lowercase() + "s"; // Default plural table name
13    let mut pk_fields = Vec::new();
14    let mut generated_fields = Vec::new();
15    let mut columns = Vec::new();
16    let mut has_many_rels = Vec::new(); // stores (related_model_ident, fk_column_name_str)
17
18    // Parse struct attributes for table_name
19    for attr in &input.attrs {
20        if attr.path().is_ident("model") {
21            let _ = attr.parse_nested_meta(|meta| {
22                if meta.path.is_ident("table_name") {
23                    let value = meta.value()?;
24                    let s: LitStr = value.parse()?;
25                    table_name = s.value();
26                }
27                if meta.path.is_ident("has_many") {
28                    let mut target_ident: Option<syn::Ident> = None;
29                    let mut fk_name = String::new();
30
31                    let _ = meta.parse_nested_meta(|inner| {
32                        if inner.path.is_ident("fk") {
33                            let value = inner.value()?;
34                            let s: LitStr = value.parse()?;
35                            fk_name = s.value();
36                        } else if target_ident.is_none() {
37                            target_ident = inner.path.get_ident().cloned();
38                        }
39                        Ok(())
40                    });
41
42                    if let Some(ident) = target_ident {
43                        has_many_rels.push((ident, fk_name));
44                    }
45                }
46                Ok(())
47            });
48        }
49    }
50
51    let mut field_types = Vec::new();
52    let mut non_pk_fields = Vec::new();
53    let mut non_pk_types = Vec::new();
54    let mut belongs_to_fks = Vec::new(); // stores (field_ident, related_model_ident)
55
56    let fields_list = if let Data::Struct(data_struct) = &input.data {
57        if let Fields::Named(syn_fields) = &data_struct.fields {
58            let mut extracted = Vec::new();
59            for f in &syn_fields.named {
60                let field_name = match &f.ident {
61                    Some(ident) => ident,
62                    None => {
63                        return syn::Error::new_spanned(f, "All fields must have names")
64                            .to_compile_error()
65                            .into();
66                    }
67                };
68                let field_name_str = field_name.to_string();
69                columns.push(field_name_str.clone());
70                field_types.push(f.ty.clone());
71
72                let mut is_pk = false;
73                let mut is_gen = false;
74                // Check for primary_key attribute
75                for attr in &f.attrs {
76                    if attr.path().is_ident("model") {
77                        let _ = attr.parse_nested_meta(|meta| {
78                            if meta.path.is_ident("primary_key") {
79                                is_pk = true;
80                            }
81                            if meta.path.is_ident("generated") {
82                                is_gen = true;
83                            }
84                            if meta.path.is_ident("belongs_to") {
85                                let _ = meta.parse_nested_meta(|inner| {
86                                    if let Some(ident) = inner.path.get_ident() {
87                                        belongs_to_fks.push((field_name.clone(), ident.clone()));
88                                    }
89                                    Ok(())
90                                });
91                            }
92                            Ok(())
93                        });
94                    }
95                }
96
97                if is_pk {
98                    pk_fields.push(field_name.clone());
99                    let ty = &f.ty;
100                    let ty_str = quote::quote!(#ty).to_string().replace(" ", "");
101                    if ty_str == "i32" || ty_str == "i64" {
102                        is_gen = true;
103                    }
104                } else {
105                    non_pk_fields.push(field_name.clone());
106                    non_pk_types.push(f.ty.clone());
107                }
108
109                if is_gen {
110                    generated_fields.push(field_name.clone());
111                }
112
113                extracted.push(field_name.clone());
114            }
115            extracted
116        } else {
117            return syn::Error::new_spanned(
118                input,
119                "Model can only be derived for structs with named fields",
120            )
121            .to_compile_error()
122            .into();
123        }
124    } else {
125        return syn::Error::new_spanned(
126            input,
127            "Model can only be derived for structs with named fields",
128        )
129        .to_compile_error()
130        .into();
131    };
132
133    if pk_fields.is_empty() {
134        if columns.contains(&"id".to_string()) {
135            pk_fields.push(syn::Ident::new("id", proc_macro2::Span::call_site()));
136            generated_fields.push(syn::Ident::new("id", proc_macro2::Span::call_site()));
137        } else {
138            return syn::Error::new_spanned(name, "Model requires at least one primary key field (e.g., #[model(primary_key)] id) or a field named 'id'").to_compile_error().into();
139        }
140    }
141
142    let field_names_str: Vec<String> = columns.clone();
143    let pk_names_str: Vec<String> = pk_fields.iter().map(|i| i.to_string()).collect();
144    let gen_names_str: Vec<String> = generated_fields.iter().map(|i| i.to_string()).collect();
145
146    let column_enum_name =
147        syn::Ident::new(&format!("{}Column", name), proc_macro2::Span::call_site());
148    let _active_model_name = syn::Ident::new(
149        &format!("{}ActiveModel", name),
150        proc_macro2::Span::call_site(),
151    );
152
153    let gen_field_names = generated_fields.clone();
154    let gen_fields_len = generated_fields.len();
155
156    let mut column_defs = Vec::new();
157    let mut col_names = Vec::new();
158    let mut col_types = Vec::new();
159    for (i, field_name) in columns.iter().enumerate() {
160        let ty = &field_types[i];
161        let is_pk = pk_names_str.contains(field_name);
162        let is_gen = gen_names_str.contains(field_name);
163
164        let mut not_null = true;
165        let mut inner_ty = ty;
166        if let syn::Type::Path(type_path) = ty
167            && let Some(segment) = type_path.path.segments.last()
168            && segment.ident == "Option"
169        {
170            not_null = false;
171            if let syn::PathArguments::AngleBracketed(args) = &segment.arguments
172                && let Some(syn::GenericArgument::Type(t)) = args.args.first()
173            {
174                inner_ty = t;
175            }
176        }
177
178        let type_str = quote::quote!(#inner_ty).to_string().replace(" ", "");
179
180        let mut sql_type = match type_str.as_str() {
181            "i32" if is_gen && is_pk && pk_fields.len() == 1 => "SERIAL PRIMARY KEY".to_string(),
182            "i32" if is_gen => "SERIAL".to_string(),
183            "i32" => "INT".to_string(),
184            "i64" if is_gen && is_pk && pk_fields.len() == 1 => "BIGSERIAL PRIMARY KEY".to_string(),
185            "i64" if is_gen => "BIGSERIAL".to_string(),
186            "i64" => "BIGINT".to_string(),
187            "String" => "TEXT".to_string(),
188            "bool" => "BOOLEAN".to_string(),
189            "f64" => "DOUBLE PRECISION".to_string(),
190            _ => "TEXT".to_string(),
191        };
192
193        if is_pk && pk_fields.len() == 1 && !sql_type.contains("PRIMARY KEY") {
194            sql_type.push_str(" PRIMARY KEY");
195        }
196
197        if not_null && !sql_type.contains("PRIMARY KEY") && !sql_type.contains("SERIAL") {
198            sql_type.push_str(" NOT NULL");
199        }
200
201        col_names.push(field_name.clone());
202        col_types.push(sql_type.clone());
203        column_defs.push(format!("{} {}", field_name, sql_type));
204    }
205
206    if pk_fields.len() > 1 {
207        let pk_csv = pk_names_str.join(", ");
208        column_defs.push(format!("PRIMARY KEY ({})", pk_csv));
209    }
210
211    let base_sql = format!(
212        "CREATE TABLE IF NOT EXISTS {} (\n    {}\n)",
213        table_name,
214        column_defs.join(",\n    ")
215    );
216
217    let fk_fields: Vec<_> = belongs_to_fks.iter().map(|(f, _)| f.clone()).collect();
218    let fk_models: Vec<_> = belongs_to_fks.iter().map(|(_, m)| m.clone()).collect();
219
220    let hm_targets: Vec<_> = has_many_rels.iter().map(|(m, _)| m.clone()).collect();
221    let hm_fks: Vec<_> = has_many_rels.iter().map(|(_, fk)| fk.clone()).collect();
222    let fetch_hm_names: Vec<_> = hm_targets
223        .iter()
224        .map(|m| {
225            syn::Ident::new(
226                &format!("fetch_{}s", m.to_string().to_lowercase()),
227                proc_macro2::Span::call_site(),
228            )
229        })
230        .collect();
231    let fetch_bt_names: Vec<_> = fk_fields
232        .iter()
233        .map(|f| {
234            let fname = f.to_string();
235            let base = fname.strip_suffix("_id").unwrap_or(&fname);
236            syn::Ident::new(&format!("fetch_{}", base), proc_macro2::Span::call_site())
237        })
238        .collect();
239    let first_pk = pk_fields[0].clone();
240    let field_names_join = field_names_str.join(", ");
241    let fields_indices: Vec<usize> = (0..columns.len()).collect();
242
243    let expanded = quote! {
244        impl chopin_orm::Model for #name {
245            fn table_name() -> &'static str {
246                #table_name
247            }
248
249            fn create_table_stmt() -> String {
250                let mut sql = String::from(#base_sql);
251                #(
252                    sql.pop(); // Remove closing parenthesis
253                    sql.pop(); // Remove newline
254                    let fk_constraint = format!(",\n    FOREIGN KEY ({}) REFERENCES {} (id)\n)", stringify!(#fk_fields), <#fk_models as chopin_orm::Model>::table_name());
255                    sql.push_str(&fk_constraint);
256                )*
257                sql
258            }
259
260            fn column_definitions() -> Vec<(&'static str, &'static str)> {
261                vec![
262                    #( (#col_names, #col_types) ),*
263                ]
264            }
265
266            fn primary_key_columns() -> &'static [&'static str] {
267                &[#(#pk_names_str),*]
268            }
269
270            fn generated_columns() -> &'static [&'static str] {
271                &[#(#gen_names_str),*]
272            }
273
274            fn columns() -> &'static [&'static str] {
275                &[#(#field_names_str),*]
276            }
277
278            fn select_clause() -> &'static str {
279                const COLS: &[&str] = &[#(#field_names_str),*];
280                const JOINED: &str = #field_names_join;
281                JOINED
282            }
283
284            fn primary_key_values(&self) -> Vec<chopin_pg::PgValue> {
285                use chopin_pg::types::ToSql;
286                vec![
287                    #(self.#pk_fields.to_sql()),*
288                ]
289            }
290
291            fn get_values(&self) -> Vec<chopin_pg::PgValue> {
292                use chopin_pg::types::ToSql;
293                vec![
294                    #(self.#fields_list.to_sql()),*
295                ]
296            }
297
298            fn set_generated_values(&mut self, mut values: Vec<chopin_pg::PgValue>) -> chopin_orm::OrmResult<()> {
299                if values.len() != #gen_fields_len {
300                    return Err(chopin_orm::OrmError::ModelError("Generated values length mismatch".to_string()));
301                }
302                let mut iter = values.into_iter();
303                #(
304                    if let Some(val) = iter.next() {
305                        self.#gen_field_names = chopin_orm::ExtractValue::from_pg_value(val)?;
306                    }
307                )*
308                Ok(())
309            }
310        }
311
312        impl chopin_orm::FromRow for #name {
313            fn from_row(row: &chopin_pg::Row) -> chopin_orm::OrmResult<Self> {
314                Ok(Self {
315                    #(
316                        #fields_list: chopin_orm::ExtractValue::extract_at(row, #fields_indices)?,
317                    )*
318                })
319            }
320        }
321
322        impl #name {
323            #(
324                pub fn #fetch_bt_names(&self, executor: &mut impl chopin_orm::Executor) -> chopin_orm::OrmResult<Option<#fk_models>> {
325                    use chopin_pg::types::ToSql;
326                    use chopin_pg::types::ToParam;
327                    let qb = #fk_models::find().filter((
328                        format!("{} = $1", <#fk_models as chopin_orm::Model>::primary_key_columns()[0]),
329                        vec![self.#fk_fields.to_param()]
330                    ));
331                    qb.one(executor)
332                }
333            )*
334
335            #(
336                pub fn #fetch_hm_names(&self, executor: &mut impl chopin_orm::Executor) -> chopin_orm::OrmResult<Vec<#hm_targets>> {
337                    use chopin_pg::types::ToSql;
338                    use chopin_pg::types::ToParam;
339                    let target_pk: chopin_pg::PgValue = self.#first_pk.clone().to_param();
340                    let qb = #hm_targets::find().filter((
341                        format!("{} = $1", #hm_fks),
342                        vec![target_pk]
343                    ));
344                    qb.all(executor)
345                }
346            )*
347        }
348
349        #[allow(non_camel_case_types)]
350        #[derive(Clone, Copy, Debug, PartialEq, Eq)]
351        pub enum #column_enum_name {
352            #(#fields_list),*
353        }
354
355        impl chopin_orm::builder::ColumnTrait<#name> for #column_enum_name {
356            fn column_name(&self) -> &'static str {
357                match self {
358                    #(Self::#fields_list => #field_names_str),*
359                }
360            }
361        }
362    };
363
364    let active_expanded = quote! {};
365
366    let mut belongs_to_field_names = Vec::new();
367    let mut belongs_to_related_models = Vec::new();
368    for (f, r) in &belongs_to_fks {
369        belongs_to_field_names.push(f.clone());
370        belongs_to_related_models.push(r.clone());
371    }
372
373    let final_expanded = quote! {
374        #expanded
375        #active_expanded
376
377        #(
378            impl chopin_orm::HasForeignKey<#belongs_to_related_models> for #name {
379                fn foreign_key_info() -> (&'static str, Vec<(&'static str, &'static str)>) {
380                    (<Self as chopin_orm::Model>::table_name(), vec![(stringify!(#belongs_to_field_names), "id")])
381                }
382            }
383        )*
384    };
385
386    TokenStream::from(final_expanded)
387}