1use darling::{FromDeriveInput, FromField, FromMeta};
2use heck::ToSnakeCase;
3use syn::ext::IdentExt;
4use syn::spanned::Spanned;
5
6use crate::symbol_resolver::SymbolResolver;
7
/// Options that can be parsed (through darling's `FromMeta`) from the
/// arguments of a `#[model(...)]` attribute, for example
/// `#[model(model_type = "migration")]`.
#[expect(clippy::module_name_repetitions)]
#[derive(Debug, Default, FromMeta)]
pub struct ModelArgs {
    /// Kind of the model. Falls back to [`ModelType::default`] when the
    /// argument is omitted.
    #[darling(default)]
    pub model_type: ModelType,
    /// Explicit table name. When `None`, [`ModelOpts::as_model`] derives the
    /// table name by converting the model name to snake case.
    pub table_name: Option<String>,
}
15
/// Kind of a model.
///
/// Within this file the kind only influences how [`ModelOpts::as_model`]
/// derives the model name: [`ModelType::Migration`] models must have a name
/// starting with an underscore, which is stripped.
#[expect(clippy::module_name_repetitions)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default, FromMeta)]
pub enum ModelType {
    /// The default kind, used when no `model_type` is given.
    #[default]
    Application,
    /// A model whose struct name must start with an underscore; the
    /// underscore is removed when computing the model's `original_name`.
    Migration,
    /// NOTE(review): not treated differently from `Application` anywhere in
    /// this file; its exact meaning is defined by the consumers of this type.
    Internal,
}
24
/// Struct-level information about a model, extracted from a
/// [`syn::DeriveInput`] by darling.
///
/// darling is configured (`supports(struct_named)`) to accept only structs
/// with named fields.
#[expect(clippy::module_name_repetitions)]
#[derive(Debug, Clone, FromDeriveInput)]
#[darling(forward_attrs(allow, doc, cfg), supports(struct_named))]
pub struct ModelOpts {
    /// Identifier of the struct.
    pub ident: syn::Ident,
    /// Visibility of the struct.
    pub vis: syn::Visibility,
    /// Generics declared on the struct. [`ModelOpts::new_from_derive_input`]
    /// rejects inputs where this contains any parameters.
    pub generics: syn::Generics,
    /// Body of the input; enum variants are ignored and struct fields are
    /// parsed as [`FieldOpts`].
    pub data: darling::ast::Data<darling::util::Ignored, FieldOpts>,
}
34
35impl ModelOpts {
36 pub fn new_from_derive_input(input: &syn::DeriveInput) -> Result<Self, darling::error::Error> {
37 let opts = Self::from_derive_input(input)?;
38 if !opts.generics.params.is_empty() {
39 return Err(
40 darling::Error::custom("generics in models are not supported")
41 .with_span(&opts.generics),
42 );
43 }
44 Ok(opts)
45 }
46
47 #[must_use]
53 pub fn fields(&self) -> Vec<&FieldOpts> {
54 self.data
55 .as_ref()
56 .take_struct()
57 .expect("Only structs are supported")
58 .fields
59 }
60
61 pub fn as_model(
68 &self,
69 args: &ModelArgs,
70 symbol_resolver: &SymbolResolver,
71 ) -> Result<Model, syn::Error> {
72 let self_reference = self.ident.to_string();
73 let as_field = |field: &&FieldOpts| field.as_field(symbol_resolver, Some(&self_reference));
74
75 let fields = self
76 .fields()
77 .iter()
78 .map(as_field)
79 .collect::<Result<Vec<_>, _>>()?;
80
81 let mut original_name = self.ident.unraw().to_string();
82 if args.model_type == ModelType::Migration {
83 original_name = original_name
84 .strip_prefix("_")
85 .ok_or_else(|| {
86 syn::Error::new(
87 self.ident.span(),
88 "migration model names must start with an underscore",
89 )
90 })?
91 .to_string();
92 }
93 let table_name = if let Some(table_name) = &args.table_name {
94 table_name.clone()
95 } else {
96 original_name.to_snake_case()
97 };
98
99 let primary_key_field = self.get_primary_key_field(&fields)?;
100
101 let ty = {
102 let mut ty = syn::Type::Path(syn::TypePath {
103 qself: None,
104 path: syn::Path::from(self.ident.clone()),
105 });
106 symbol_resolver.resolve(&mut ty, Some(&original_name));
107 ty
108 };
109
110 Ok(Model {
111 name: self.ident.clone(),
112 vis: self.vis.clone(),
113 original_name,
114 resolved_ty: ty,
115 model_type: args.model_type,
116 table_name,
117 pk_field: primary_key_field.clone(),
118 fields,
119 })
120 }
121
122 fn get_primary_key_field<'a>(&self, fields: &'a [Field]) -> Result<&'a Field, syn::Error> {
123 let pks: Vec<_> = fields.iter().filter(|field| field.primary_key).collect();
124 if pks.is_empty() {
125 return Err(syn::Error::new(
126 self.ident.span(),
127 "models must have a primary key field annotated with \
128 the `#[model(primary_key)]` attribute",
129 ));
130 }
131 if pks.len() > 1 {
132 return Err(syn::Error::new(
133 pks[1].name.span(),
134 "composite primary keys are not supported; only one primary key field is allowed",
135 ));
136 }
137
138 Ok(pks[0])
139 }
140}
141
/// Field-level information about a model field, extracted by darling from
/// the field itself and its `#[model(...)]` attributes.
#[derive(Debug, Clone, FromField)]
#[darling(attributes(model))]
pub struct FieldOpts {
    /// Identifier of the field; `None` for unnamed fields.
    pub ident: Option<syn::Ident>,
    /// Type of the field as written in the source.
    pub ty: syn::Type,
    /// Set when the field is annotated with `#[model(primary_key)]`.
    pub primary_key: darling::util::Flag,
    /// Set when the field is annotated with `#[model(unique)]`.
    pub unique: darling::util::Flag,
}
150
151impl FieldOpts {
152 fn find_type(&self, type_to_find: &str, symbol_resolver: &SymbolResolver) -> Option<syn::Type> {
153 let mut ty = self.ty.clone();
154 symbol_resolver.resolve(&mut ty, None);
155 Self::find_type_resolved(&ty, type_to_find)
156 }
157
158 fn find_type_resolved(ty: &syn::Type, type_to_find: &str) -> Option<syn::Type> {
159 if let syn::Type::Path(type_path) = ty {
160 let name = type_path
161 .path
162 .segments
163 .iter()
164 .map(|s| s.ident.to_string())
165 .collect::<Vec<_>>()
166 .join("::");
167
168 if name == type_to_find {
169 return Some(ty.clone());
170 }
171
172 for arg in &type_path.path.segments {
173 if let syn::PathArguments::AngleBracketed(arg) = &arg.arguments
174 && let Some(ty) = Self::find_type_in_generics(arg, type_to_find)
175 {
176 return Some(ty);
177 }
178 }
179 }
180
181 None
182 }
183
184 fn find_type_in_generics(
185 arg: &syn::AngleBracketedGenericArguments,
186 type_to_find: &str,
187 ) -> Option<syn::Type> {
188 arg.args.iter().find_map(|arg| {
189 if let syn::GenericArgument::Type(ty) = arg {
190 Self::find_type_resolved(ty, type_to_find)
191 } else {
192 None
193 }
194 })
195 }
196
197 pub fn as_field(
204 &self,
205 symbol_resolver: &SymbolResolver,
206 self_reference: Option<&String>,
207 ) -> Result<Field, syn::Error> {
208 let name = self
209 .ident
210 .clone()
211 .expect("Only named struct fields are supported");
212 let column_name = name.unraw().to_string();
213
214 let (auto_value, foreign_key) = (
215 self.find_type("cot::db::Auto", symbol_resolver).is_some(),
216 self.find_type("cot::db::ForeignKey", symbol_resolver)
217 .map(ForeignKeySpec::try_from)
218 .transpose()?,
219 );
220 let is_primary_key = self.primary_key.is_present();
221 let mut resolved_ty = self.ty.clone();
222 symbol_resolver.resolve(&mut resolved_ty, self_reference);
223 Ok(Field {
224 name: name.clone(),
225 column_name,
226 ty: resolved_ty,
227 auto_value,
228 primary_key: is_primary_key,
229 foreign_key,
230 unique: self.unique.is_present(),
231 })
232 }
233}
234
/// A fully processed model, as produced by [`ModelOpts::as_model`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Model {
    /// Identifier of the model struct as written (a raw identifier keeps its
    /// `r#` prefix).
    pub name: syn::Ident,
    /// Visibility of the model struct.
    pub vis: syn::Visibility,
    /// Name of the model without any raw-identifier prefix and, for
    /// migration models, without the leading underscore.
    pub original_name: String,
    /// Path type of the model struct after being processed by the symbol
    /// resolver.
    pub resolved_ty: syn::Type,
    /// Kind of the model, copied from the attribute arguments.
    #[expect(clippy::struct_field_names)] pub model_type: ModelType,
    /// Explicitly configured table name, or the snake-case form of
    /// `original_name`.
    pub table_name: String,
    /// Copy of the field that is the primary key.
    pub pk_field: Field,
    /// All fields of the model, in declaration order.
    pub fields: Vec<Field>,
}
248
249impl Model {
250 #[must_use]
251 pub fn field_count(&self) -> usize {
252 self.fields.len()
253 }
254}
255
/// A fully processed model field, as produced by [`FieldOpts::as_field`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Field {
    /// Identifier of the field as written (a raw identifier keeps its `r#`
    /// prefix).
    pub name: syn::Ident,
    /// Field name without any raw-identifier prefix.
    pub column_name: String,
    /// Type of the field after being processed by the symbol resolver.
    pub ty: syn::Type,
    /// Whether the field's type contains `cot::db::Auto`.
    pub auto_value: bool,
    /// Whether the field is annotated with `#[model(primary_key)]`.
    pub primary_key: bool,
    /// Foreign key specification, present when the field's type contains
    /// `cot::db::ForeignKey`.
    pub foreign_key: Option<ForeignKeySpec>,
    /// Whether the field is annotated with `#[model(unique)]`.
    pub unique: bool,
}
269
/// Description of a foreign key, extracted from a `ForeignKey<...>` type.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ForeignKeySpec {
    /// The single generic type argument of the foreign key type, i.e. the
    /// type the foreign key points to.
    pub to_model: syn::Type,
}
274
275impl TryFrom<syn::Type> for ForeignKeySpec {
276 type Error = syn::Error;
277
278 fn try_from(ty: syn::Type) -> Result<Self, Self::Error> {
279 let syn::Type::Path(type_path) = &ty else {
280 panic!("Expected a path type for a foreign key");
281 };
282
283 let syn::PathArguments::AngleBracketed(args) = &type_path
284 .path
285 .segments
286 .last()
287 .expect("type path must have at least one segment")
288 .arguments
289 else {
290 return Err(syn::Error::new(
291 ty.span(),
292 "expected ForeignKey to have angle-bracketed generic arguments",
293 ));
294 };
295
296 if args.args.len() != 1 {
297 return Err(syn::Error::new(
298 ty.span(),
299 "expected ForeignKey to have only one generic parameter",
300 ));
301 }
302
303 let inner = &args.args[0];
304 if let syn::GenericArgument::Type(ty) = inner {
305 Ok(Self {
306 to_model: ty.clone(),
307 })
308 } else {
309 Err(syn::Error::new(
310 ty.span(),
311 "expected ForeignKey to have a type generic argument",
312 ))
313 }
314 }
315}
316
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;
    use crate::symbol_resolver::{SymbolResolver, VisibleSymbol, VisibleSymbolKind};

    // Default arguments: application model without an explicit table name.
    #[test]
    fn model_args_default() {
        let args: ModelArgs = ModelArgs::default();
        assert_eq!(args.model_type, ModelType::Application);
        assert!(args.table_name.is_none());
    }

    // The default model type is `Application`.
    #[test]
    fn model_type_default() {
        let model_type: ModelType = ModelType::default();
        assert_eq!(model_type, ModelType::Application);
    }

    // Fields are returned in declaration order.
    #[test]
    fn model_opts_fields() {
        let input: syn::DeriveInput = parse_quote! {
            struct TestModel {
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let fields = opts.fields();
        assert_eq!(fields.len(), 2);
        assert_eq!(fields[0].ident.as_ref().unwrap().to_string(), "id");
        assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "name");
    }

    // Happy path: the table name is the snake-case form of the struct name.
    #[test]
    fn model_opts_as_model() {
        let input: syn::DeriveInput = parse_quote! {
            struct TestModel {
                #[model(primary_key)]
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap();
        assert_eq!(model.name.to_string(), "TestModel");
        assert_eq!(model.table_name, "test_model");
        assert_eq!(model.fields.len(), 2);
        assert_eq!(model.field_count(), 2);
    }

    // A raw struct identifier keeps its `r#` prefix in `name`, but the prefix
    // is not part of the table name.
    #[test]
    fn model_opts_raw_name() {
        let input: syn::DeriveInput = parse_quote! {
            struct r#abstract {
                #[model(primary_key)]
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let model = opts
            .as_model(&ModelArgs::default(), &SymbolResolver::new(vec![]))
            .unwrap();
        assert_eq!(model.name.to_string(), "r#abstract");
        assert_eq!(model.table_name, "abstract");
    }

    // Migration models whose name lacks the leading underscore are rejected.
    #[test]
    fn model_opts_as_model_migration() {
        let input: syn::DeriveInput = parse_quote! {
            #[model(model_type = "migration")]
            struct TestModel {
                id: i32,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap();
        let err = opts
            .as_model(&args, &SymbolResolver::new(vec![]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "migration model names must start with an underscore"
        );
    }

    // `#[model(primary_key)]` marks the field as the primary key.
    #[test]
    fn model_opts_as_model_pk_attr() {
        let input: syn::DeriveInput = parse_quote! {
            #[model]
            struct TestModel {
                #[model(primary_key)]
                name: i32,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap();
        assert_eq!(model.fields.len(), 1);
        assert!(model.fields[0].primary_key);
    }

    // A model without any primary key field is rejected.
    #[test]
    fn model_opts_as_model_no_pk() {
        let input: syn::DeriveInput = parse_quote! {
            #[model]
            struct TestModel {
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let err = opts
            .as_model(&args, &SymbolResolver::new(vec![]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "models must have a primary key field annotated with \
            the `#[model(primary_key)]` attribute"
        );
    }

    // More than one primary key field is rejected.
    #[test]
    fn model_opts_as_model_multiple_pks() {
        let input: syn::DeriveInput = parse_quote! {
            #[model]
            struct TestModel {
                #[model(primary_key)]
                id: i64,
                #[model(primary_key)]
                id_2: i64,
                name: String,
            }
        };
        let opts = ModelOpts::new_from_derive_input(&input).unwrap();
        let args = ModelArgs::default();
        let err = opts
            .as_model(&args, &SymbolResolver::new(vec![]))
            .unwrap_err();
        assert_eq!(
            err.to_string(),
            "composite primary keys are not supported; only one primary key field is allowed"
        );
    }

    // Field conversion carries over the name, the type and the `unique` flag.
    #[test]
    fn field_opts_as_field() {
        let input: syn::Field = parse_quote! {
            #[model(unique)]
            name: String
        };
        let field_opts = FieldOpts::from_field(&input).unwrap();
        let field = field_opts
            .as_field(&SymbolResolver::new(vec![]), Some(&"TestModel".to_string()))
            .unwrap();
        assert_eq!(field.name.to_string(), "name");
        assert_eq!(field.column_name, "name");
        assert_eq!(field.ty, parse_quote!(String));
        assert!(field.unique);
    }

    // A raw field identifier keeps its `r#` prefix in `name`, but the column
    // name is the plain identifier.
    #[test]
    fn field_opts_raw_name() {
        let input: syn::Field = parse_quote! {
            r#abstract: String
        };
        let field_opts = FieldOpts::from_field(&input).unwrap();
        let field = field_opts
            .as_field(&SymbolResolver::new(vec![]), Some(&"TestModel".to_string()))
            .unwrap();
        assert_eq!(field.name.to_string(), "r#abstract");
        assert_eq!(field.column_name, "abstract");
    }

    // The search matches the outer type (ignoring the leading `::` and the
    // generic arguments) as well as types nested in generic arguments.
    #[test]
    fn find_type_resolved() {
        let input: syn::Type =
            parse_quote! { ::my_crate::MyContainer<'a, Vec<std::string::String>> };
        assert!(FieldOpts::find_type_resolved(&input, "my_crate::MyContainer").is_some());
        assert!(FieldOpts::find_type_resolved(&input, "Vec").is_some());
        assert!(FieldOpts::find_type_resolved(&input, "std::string::String").is_some());
        assert!(FieldOpts::find_type_resolved(&input, "OtherType").is_none());
    }

    // `find_type` searches the type after symbol resolution, so only the
    // resolved (full) paths match, not the names as written.
    #[test]
    fn find_type() {
        let symbols = vec![VisibleSymbol::new(
            "MyContainer",
            "my_crate::MyContainer",
            VisibleSymbolKind::Use,
        )];
        let resolver = SymbolResolver::new(symbols);

        let opts = FieldOpts {
            ident: None,
            ty: parse_quote! { MyContainer<std::string::String> },
            primary_key: darling::util::Flag::default(),
            unique: darling::util::Flag::default(),
        };

        assert!(opts.find_type("my_crate::MyContainer", &resolver).is_some());
        assert!(opts.find_type("std::string::String", &resolver).is_some());
        assert!(opts.find_type("MyContainer", &resolver).is_none());
        assert!(opts.find_type("String", &resolver).is_none());
    }
}