Skip to main content

cot_codegen/
model.rs

1use darling::{FromDeriveInput, FromField, FromMeta};
2use heck::ToSnakeCase;
3use syn::ext::IdentExt;
4use syn::spanned::Spanned;
5
6use crate::symbol_resolver::SymbolResolver;
7
8#[expect(clippy::module_name_repetitions)]
9#[derive(Debug, Default, FromMeta)]
10pub struct ModelArgs {
11    #[darling(default)]
12    pub model_type: ModelType,
13    pub table_name: Option<String>,
14}
15
16#[expect(clippy::module_name_repetitions)]
17#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, FromMeta)]
18pub enum ModelType {
19    #[default]
20    Application,
21    Migration,
22    Internal,
23}
24
25#[expect(clippy::module_name_repetitions)]
26#[derive(Debug, Clone, FromDeriveInput)]
27#[darling(forward_attrs(allow, doc, cfg), supports(struct_named))]
28pub struct ModelOpts {
29    pub ident: syn::Ident,
30    pub vis: syn::Visibility,
31    pub generics: syn::Generics,
32    pub data: darling::ast::Data<darling::util::Ignored, FieldOpts>,
33}
34
35impl ModelOpts {
36    pub fn new_from_derive_input(input: &syn::DeriveInput) -> Result<Self, darling::error::Error> {
37        let opts = Self::from_derive_input(input)?;
38        if !opts.generics.params.is_empty() {
39            return Err(
40                darling::Error::custom("generics in models are not supported")
41                    .with_span(&opts.generics),
42            );
43        }
44        Ok(opts)
45    }
46
47    /// Get the fields of the struct.
48    ///
49    /// # Panics
50    ///
51    /// Panics if the [`ModelOpts`] was not parsed from a struct.
52    #[must_use]
53    pub fn fields(&self) -> Vec<&FieldOpts> {
54        self.data
55            .as_ref()
56            .take_struct()
57            .expect("Only structs are supported")
58            .fields
59    }
60
61    /// Convert the model options into a model.
62    ///
63    /// # Errors
64    ///
65    /// Returns an error if the model name does not start with an underscore
66    /// when the model type is [`ModelType::Migration`].
67    pub fn as_model(
68        &self,
69        args: &ModelArgs,
70        symbol_resolver: &SymbolResolver,
71    ) -> Result<Model, syn::Error> {
72        let self_reference = self.ident.to_string();
73        let as_field = |field: &&FieldOpts| field.as_field(symbol_resolver, Some(&self_reference));
74
75        let fields = self
76            .fields()
77            .iter()
78            .map(as_field)
79            .collect::<Result<Vec<_>, _>>()?;
80
81        let mut original_name = self.ident.unraw().to_string();
82        if args.model_type == ModelType::Migration {
83            original_name = original_name
84                .strip_prefix("_")
85                .ok_or_else(|| {
86                    syn::Error::new(
87                        self.ident.span(),
88                        "migration model names must start with an underscore",
89                    )
90                })?
91                .to_string();
92        }
93        let table_name = if let Some(table_name) = &args.table_name {
94            table_name.clone()
95        } else {
96            original_name.to_snake_case()
97        };
98
99        let primary_key_field = self.get_primary_key_field(&fields)?;
100
101        let ty = {
102            let mut ty = syn::Type::Path(syn::TypePath {
103                qself: None,
104                path: syn::Path::from(self.ident.clone()),
105            });
106            symbol_resolver.resolve(&mut ty, Some(&original_name));
107            ty
108        };
109
110        Ok(Model {
111            name: self.ident.clone(),
112            vis: self.vis.clone(),
113            original_name,
114            resolved_ty: ty,
115            model_type: args.model_type,
116            table_name,
117            pk_field: primary_key_field.clone(),
118            fields,
119        })
120    }
121
122    fn get_primary_key_field<'a>(&self, fields: &'a [Field]) -> Result<&'a Field, syn::Error> {
123        let pks: Vec<_> = fields.iter().filter(|field| field.primary_key).collect();
124        if pks.is_empty() {
125            return Err(syn::Error::new(
126                self.ident.span(),
127                "models must have a primary key field annotated with \
128                the `#[model(primary_key)]` attribute",
129            ));
130        }
131        if pks.len() > 1 {
132            return Err(syn::Error::new(
133                pks[1].name.span(),
134                "composite primary keys are not supported; only one primary key field is allowed",
135            ));
136        }
137
138        Ok(pks[0])
139    }
140}
141
142#[derive(Debug, Clone, FromField)]
143#[darling(attributes(model))]
144pub struct FieldOpts {
145    pub ident: Option<syn::Ident>,
146    pub ty: syn::Type,
147    pub primary_key: darling::util::Flag,
148    pub unique: darling::util::Flag,
149}
150
151impl FieldOpts {
152    fn find_type(&self, type_to_find: &str, symbol_resolver: &SymbolResolver) -> Option<syn::Type> {
153        let mut ty = self.ty.clone();
154        symbol_resolver.resolve(&mut ty, None);
155        Self::find_type_resolved(&ty, type_to_find)
156    }
157
158    fn find_type_resolved(ty: &syn::Type, type_to_find: &str) -> Option<syn::Type> {
159        if let syn::Type::Path(type_path) = ty {
160            let name = type_path
161                .path
162                .segments
163                .iter()
164                .map(|s| s.ident.to_string())
165                .collect::<Vec<_>>()
166                .join("::");
167
168            if name == type_to_find {
169                return Some(ty.clone());
170            }
171
172            for arg in &type_path.path.segments {
173                if let syn::PathArguments::AngleBracketed(arg) = &arg.arguments
174                    && let Some(ty) = Self::find_type_in_generics(arg, type_to_find)
175                {
176                    return Some(ty);
177                }
178            }
179        }
180
181        None
182    }
183
184    fn find_type_in_generics(
185        arg: &syn::AngleBracketedGenericArguments,
186        type_to_find: &str,
187    ) -> Option<syn::Type> {
188        arg.args.iter().find_map(|arg| {
189            if let syn::GenericArgument::Type(ty) = arg {
190                Self::find_type_resolved(ty, type_to_find)
191            } else {
192                None
193            }
194        })
195    }
196
197    /// Convert the field options into a field.
198    ///
199    /// # Panics
200    ///
201    /// Panics if the field does not have an identifier (i.e. it is a tuple
202    /// struct).
203    pub fn as_field(
204        &self,
205        symbol_resolver: &SymbolResolver,
206        self_reference: Option<&String>,
207    ) -> Result<Field, syn::Error> {
208        let name = self
209            .ident
210            .clone()
211            .expect("Only named struct fields are supported");
212        let column_name = name.unraw().to_string();
213
214        let (auto_value, foreign_key) = (
215            self.find_type("cot::db::Auto", symbol_resolver).is_some(),
216            self.find_type("cot::db::ForeignKey", symbol_resolver)
217                .map(ForeignKeySpec::try_from)
218                .transpose()?,
219        );
220        let is_primary_key = self.primary_key.is_present();
221        let mut resolved_ty = self.ty.clone();
222        symbol_resolver.resolve(&mut resolved_ty, self_reference);
223        Ok(Field {
224            name: name.clone(),
225            column_name,
226            ty: resolved_ty,
227            auto_value,
228            primary_key: is_primary_key,
229            foreign_key,
230            unique: self.unique.is_present(),
231        })
232    }
233}
234
235#[derive(Debug, Clone, PartialEq, Eq, Hash)]
236pub struct Model {
237    pub name: syn::Ident,
238    pub vis: syn::Visibility,
239    pub original_name: String,
240    /// The type of the model resolved by symbol resolver.
241    pub resolved_ty: syn::Type,
242    #[expect(clippy::struct_field_names)] // `type` is not an allowed identifier in Rust
243    pub model_type: ModelType,
244    pub table_name: String,
245    pub pk_field: Field,
246    pub fields: Vec<Field>,
247}
248
249impl Model {
250    #[must_use]
251    pub fn field_count(&self) -> usize {
252        self.fields.len()
253    }
254}
255
256#[derive(Debug, Clone, PartialEq, Eq, Hash)]
257pub struct Field {
258    pub name: syn::Ident,
259    pub column_name: String,
260    pub ty: syn::Type,
261    /// Whether the field is an auto field (e.g. `id`).
262    pub auto_value: bool,
263    pub primary_key: bool,
264    /// [`Some`] if this field is a foreign key; [`None`] if this field is
265    /// determined not to be a foreign key.
266    pub foreign_key: Option<ForeignKeySpec>,
267    pub unique: bool,
268}
269
270#[derive(Debug, Clone, PartialEq, Eq, Hash)]
271pub struct ForeignKeySpec {
272    pub to_model: syn::Type,
273}
274
275impl TryFrom<syn::Type> for ForeignKeySpec {
276    type Error = syn::Error;
277
278    fn try_from(ty: syn::Type) -> Result<Self, Self::Error> {
279        let syn::Type::Path(type_path) = &ty else {
280            panic!("Expected a path type for a foreign key");
281        };
282
283        let syn::PathArguments::AngleBracketed(args) = &type_path
284            .path
285            .segments
286            .last()
287            .expect("type path must have at least one segment")
288            .arguments
289        else {
290            return Err(syn::Error::new(
291                ty.span(),
292                "expected ForeignKey to have angle-bracketed generic arguments",
293            ));
294        };
295
296        if args.args.len() != 1 {
297            return Err(syn::Error::new(
298                ty.span(),
299                "expected ForeignKey to have only one generic parameter",
300            ));
301        }
302
303        let inner = &args.args[0];
304        if let syn::GenericArgument::Type(ty) = inner {
305            Ok(Self {
306                to_model: ty.clone(),
307            })
308        } else {
309            Err(syn::Error::new(
310                ty.span(),
311                "expected ForeignKey to have a type generic argument",
312            ))
313        }
314    }
315}
316
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;
    use crate::symbol_resolver::{SymbolResolver, VisibleSymbol, VisibleSymbolKind};

    #[test]
    fn model_args_default() {
        let args: ModelArgs = ModelArgs::default();
        assert_eq!(args.model_type, ModelType::Application);
        assert!(args.table_name.is_none());
    }

    #[test]
    fn model_type_default() {
        let model_type: ModelType = ModelType::default();
        assert_eq!(model_type, ModelType::Application);
    }

    #[test]
    fn model_opts_fields() {
        let input: syn::DeriveInput = parse_quote! {
            struct TestModel {
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let fields = opts.fields();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].ident.as_ref().unwrap().to_string(), "id");
        assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "name");
    }

    #[test]
    fn model_opts_as_model() {
        let input: syn::DeriveInput = parse_quote! {
            struct TestModel {
                #[model(primary_key)]
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap();
        assert_eq!(model.name.to_string(), "TestModel");
        assert_eq!(model.table_name, "test_model");
        assert_eq!(model.fields.len(), 2);
        assert_eq!(model.field_count(), 2);
    }

    #[test]
    fn model_opts_raw_name() {
        let input: syn::DeriveInput = parse_quote! {
            struct r#abstract {
                #[model(primary_key)]
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let model = opts
            .as_model(&ModelArgs::default(), &SymbolResolver::new(vec![]))
            .unwrap();
        assert_eq!(model.name.to_string(), "r#abstract");
        assert_eq!(model.table_name, "abstract");
    }

    #[test]
    fn model_opts_as_model_migration() {
        let input: syn::DeriveInput = parse_quote! {
            #[model(model_type = "migration")]
            struct TestModel {
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap();
        let err = opts
            .as_model(&args, &SymbolResolver::new(vec![]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "migration model names must start with an underscore"
        );
    }

    /// A migration model named `_Name` is accepted; the underscore is dropped
    /// from the original name and the derived table name.
    #[test]
    fn model_opts_as_model_migration_strips_underscore() {
        let input: syn::DeriveInput = parse_quote! {
            #[model(model_type = "migration")]
            struct _TestModel {
                #[model(primary_key)]
                id: i32,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap();
        let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap();
        assert_eq!(model.model_type, ModelType::Migration);
        assert_eq!(model.name.to_string(), "_TestModel");
        assert_eq!(model.original_name, "TestModel");
        assert_eq!(model.table_name, "test_model");
        assert_eq!(model.pk_field.name.to_string(), "id");
    }

    /// An explicit `table_name` overrides the name derived from the struct.
    #[test]
    fn model_opts_as_model_custom_table_name() {
        let input: syn::DeriveInput = parse_quote! {
            #[model(table_name = "my_table")]
            struct TestModel {
                #[model(primary_key)]
                id: i32,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap();
        let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap();
        assert_eq!(model.model_type, ModelType::Application);
        assert_eq!(model.table_name, "my_table");
    }

    #[test]
    fn model_opts_as_model_pk_attr() {
        let input: syn::DeriveInput = parse_quote! {
            #[model]
            struct TestModel {
                #[model(primary_key)]
                name: i32,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap();
        assert_eq!(model.fields.len(), 1);
        assert!(model.fields[0].primary_key);
    }

    #[test]
    fn model_opts_as_model_no_pk() {
        let input: syn::DeriveInput = parse_quote! {
            #[model]
            struct TestModel {
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let err = opts
            .as_model(&args, &SymbolResolver::new(vec![]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "models must have a primary key field annotated with \
            the `#[model(primary_key)]` attribute"
        );
    }

    #[test]
    fn model_opts_as_model_multiple_pks() {
        let input: syn::DeriveInput = parse_quote! {
            #[model]
            struct TestModel {
                #[model(primary_key)]
                id: i64,
                #[model(primary_key)]
                id_2: i64,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let err = opts
            .as_model(&args, &SymbolResolver::new(vec![]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "composite primary keys are not supported; only one primary key field is allowed"
        );
    }

    #[test]
    fn field_opts_as_field() {
        let input: syn::Field = parse_quote! {
            #[model(unique)]
            name: String
        };
        let field_opts = FieldOpts::from_field(&input).unwrap();
        let field = field_opts
            .as_field(&SymbolResolver::new(vec![]), Some(&"TestModel".to_string()))
            .unwrap();
        assert_eq!(field.name.to_string(), "name");
        assert_eq!(field.column_name, "name");
        assert_eq!(field.ty, parse_quote!(String));
        assert!(field.unique);
    }

    #[test]
    fn field_opts_raw_name() {
        let input: syn::Field = parse_quote! {
            r#abstract: String
        };
        let field_opts = FieldOpts::from_field(&input).unwrap();
        let field = field_opts
            .as_field(&SymbolResolver::new(vec![]), Some(&"TestModel".to_string()))
            .unwrap();
        assert_eq!(field.name.to_string(), "r#abstract");
        assert_eq!(field.column_name, "abstract");
    }

    /// A fully qualified `ForeignKey<T>` field is detected as a foreign key to `T`.
    #[test]
    fn field_opts_as_field_foreign_key() {
        let input: syn::Field = parse_quote! {
            author: cot::db::ForeignKey<Author>
        };
        let field_opts = FieldOpts::from_field(&input).unwrap();
        let field = field_opts
            .as_field(&SymbolResolver::new(vec![]), Some(&"TestModel".to_string()))
            .unwrap();
        assert_eq!(
            field.foreign_key,
            Some(ForeignKeySpec {
                to_model: parse_quote!(Author),
            })
        );
        assert!(!field.auto_value);
        assert!(!field.primary_key);
    }

    /// A fully qualified `Auto<T>` field is detected as an auto-valued field.
    #[test]
    fn field_opts_as_field_auto() {
        let input: syn::Field = parse_quote! {
            #[model(primary_key)]
            id: cot::db::Auto<i64>
        };
        let field_opts = FieldOpts::from_field(&input).unwrap();
        let field = field_opts
            .as_field(&SymbolResolver::new(vec![]), Some(&"TestModel".to_string()))
            .unwrap();
        assert!(field.auto_value);
        assert!(field.primary_key);
        assert!(field.foreign_key.is_none());
    }

    #[test]
    fn foreign_key_spec_try_from() {
        let ty: syn::Type = parse_quote!(cot::db::ForeignKey<Author>);
        let spec = ForeignKeySpec::try_from(ty).unwrap();
        assert_eq!(spec.to_model, parse_quote!(Author));
    }

    #[test]
    fn foreign_key_spec_try_from_no_generics() {
        let ty: syn::Type = parse_quote!(cot::db::ForeignKey);
        let err = ForeignKeySpec::try_from(ty).unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected ForeignKey to have angle-bracketed generic arguments"
        );
    }

    #[test]
    fn foreign_key_spec_try_from_too_many_generics() {
        let ty: syn::Type = parse_quote!(cot::db::ForeignKey<Author, Book>);
        let err = ForeignKeySpec::try_from(ty).unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected ForeignKey to have only one generic parameter"
        );
    }

    #[test]
    fn foreign_key_spec_try_from_non_type_generic() {
        let ty: syn::Type = parse_quote!(cot::db::ForeignKey<'a>);
        let err = ForeignKeySpec::try_from(ty).unwrap_err();
        assert_eq!(
            err.to_string(),
            "expected ForeignKey to have a type generic argument"
        );
    }

    #[test]
    fn find_type_resolved() {
        let input: syn::Type =
            parse_quote! { ::my_crate::MyContainer<'a, Vec<std::string::String>> };
        assert!(FieldOpts::find_type_resolved(&input, "my_crate::MyContainer").is_some());
        assert!(FieldOpts::find_type_resolved(&input, "Vec").is_some());
        assert!(FieldOpts::find_type_resolved(&input, "std::string::String").is_some());
        assert!(FieldOpts::find_type_resolved(&input, "OtherType").is_none());
    }

    #[test]
    fn find_type() {
        let symbols = vec![VisibleSymbol::new(
            "MyContainer",
            "my_crate::MyContainer",
            VisibleSymbolKind::Use,
        )];
        let resolver = SymbolResolver::new(symbols);

        let opts = FieldOpts {
            ident: None,
            ty: parse_quote! { MyContainer<std::string::String> },
            primary_key: darling::util::Flag::default(),
            unique: darling::util::Flag::default(),
        };

        assert!(opts.find_type("my_crate::MyContainer", &resolver).is_some());
        assert!(opts.find_type("std::string::String", &resolver).is_some());
        assert!(opts.find_type("MyContainer", &resolver).is_none());
        assert!(opts.find_type("String", &resolver).is_none());
    }
}