kali_macros/
lib.rs

1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Ident;
4
5#[proc_macro_attribute]
6pub fn entity(
7    args: proc_macro::TokenStream,
8    input: proc_macro::TokenStream,
9) -> proc_macro::TokenStream {
10    match generate(args.into(), input.into()) {
11        Ok(output) => output.into(),
12        Err(e) => e.to_compile_error().into(),
13    }
14}
15
16#[derive(Clone)]
17enum Relation {
18    ForeignKey {
19        entity: Ident,
20        foreign_key_field: Ident,
21        references_field: Option<Ident>, // defaults to primary key
22    },
23    ReferencedBy {
24        entity: Ident,
25        relation_field: Ident,
26        is_collection: bool,
27    },
28}
29
30#[derive(Clone)]
31struct ParsedField {
32    field_name: Ident,
33    iden_name: Ident,
34    is_pk: bool,
35    relation: Option<Relation>,
36    raw: syn::Field,
37}
38
39// parse Collection<T> or Reference<T> to T
40fn parse_entity_from_type(entity: &syn::Type) -> Result<(bool, Ident), syn::Error> {
41    if let syn::Type::Path(type_path) = entity {
42        if let Some(segment) = type_path.path.segments.last() {
43            let is_collection = segment.ident == "Collection";
44            if is_collection || segment.ident == "Reference" {
45                if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
46                    if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
47                        if let syn::Type::Path(type_path) = ty {
48                            return Ok((
49                                is_collection,
50                                type_path.path.segments.last().unwrap().ident.clone(),
51                            ));
52                        }
53                    }
54                }
55            }
56        }
57    }
58    Err(syn::Error::new_spanned(
59        entity,
60        "expected Collection<T> or Reference<T>",
61    ))
62}
63
64fn parse_fields(entity: syn::ItemStruct) -> Result<Vec<ParsedField>, syn::Error> {
65    entity
66        .fields
67        .into_iter()
68        .map(|mut f| {
69            let ident = f
70                .ident
71                .as_ref()
72                .ok_or_else(|| syn::Error::new_spanned(&f, "expected named field"))?;
73
74            let field_name = ident.clone();
75            let iden_name = Ident::new(
76                &to_upper_camel_case(&field_name.to_string()),
77                field_name.span(),
78            );
79
80            let mut is_pk = false;
81            let mut relation_attr = None;
82
83            f.attrs.retain(|attr| {
84                if attr.path().is_ident("primary_key") {
85                    is_pk = true;
86                    false
87                } else if attr.path().is_ident("relation") {
88                    relation_attr = Some(attr.clone());
89                    false
90                } else {
91                    true
92                }
93            });
94
95            let relation = if let Some(relation_attr) = relation_attr {
96                let (is_collection, entity) = parse_entity_from_type(&f.ty)?;
97                let is_owning = relation_attr.path().is_ident("foreign_key");
98                if is_owning {
99                    return Err(syn::Error::new_spanned(
100                        &relation_attr,
101                        "expected `#[relation(referenced_by = ...)]` for Collection<T>",
102                    ));
103                }
104
105                let mut referenced_by = None;
106                let mut foreign_key = None;
107                let mut references = None;
108                relation_attr.parse_nested_meta(|meta| {
109                    if meta.path.is_ident("referenced_by") {
110                        let value = meta.value()?;
111                        referenced_by = Some(value.parse()?);
112                        Ok(())
113                    } else if meta.path.is_ident("foreign_key") {
114                        let value = meta.value()?;
115                        foreign_key = Some(value.parse()?);
116                        Ok(())
117                    } else if meta.path.is_ident("references") {
118                        let value = meta.value()?;
119                        references = Some(value.parse()?);
120                        Ok(())
121                    } else {
122                        return Err(syn::Error::new_spanned(
123                            &meta.path,
124                            "expected `referenced_by` or `foreign_key` attribute",
125                        ));
126                    }
127                })?;
128
129                match (referenced_by, foreign_key, references) {
130                    (None, Some(fk), None) => {
131                        Some(Relation::ForeignKey {
132                            entity,
133                            foreign_key_field: fk,
134                            references_field: None,
135                        })
136                    }
137                    (None, Some(fk), Some(refs)) => {
138                        Some(Relation::ForeignKey {
139                            entity,
140                            foreign_key_field: fk,
141                            references_field: Some(refs),
142                        })
143                    },
144                    (Some(refs), None, None) => {
145                        Some(Relation::ReferencedBy {
146                            entity,
147                            relation_field: refs,
148                            is_collection,
149                        })
150                    }
151                    _ => {
152                        return Err(syn::Error::new_spanned(
153                            &relation_attr,
154                            "expected either `#[relation(referenced_by = ...)]` or `#[relation(foreign_key = ...)]`",
155                        ));
156                    }
157                }
158            } else {
159                None
160            };
161
162            Ok(ParsedField {
163                field_name,
164                iden_name,
165                is_pk,
166                relation,
167                raw: f,
168            })
169        })
170        .collect()
171}
172
173fn generate_entity_column_enum(
174    vis: &syn::Visibility,
175    entity_name: &syn::Ident,
176    parsed_fields: &[ParsedField],
177) -> syn::Result<(Ident, TokenStream)> {
178    let col_enum_ident = Ident::new(&format!("{}Column", entity_name), entity_name.span());
179
180    let col_enum_variants = parsed_fields
181        .iter()
182        .map(|f| f.iden_name.clone())
183        .collect::<Vec<_>>();
184
185    // Create a mapping of enum variant to field name for the match expression
186    let field_name_mappings = parsed_fields
187        .iter()
188        .map(|f| {
189            let variant = &f.iden_name;
190            let field_name = &f.field_name;
191            quote! { #col_enum_ident::#variant => stringify!(#field_name) }
192        })
193        .collect::<Vec<_>>();
194
195    let col_enum = quote! {
196        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
197        #vis enum #col_enum_ident {
198            #(#col_enum_variants),*
199        }
200
201        impl kali::column::Column for #col_enum_ident {
202            fn to_col_name(&self) -> &str {
203                match self {
204                    #(#field_name_mappings),*
205                }
206            }
207        }
208    };
209
210    Ok((col_enum_ident, col_enum))
211}
212
213fn generate_entity_constants(
214    vis: &syn::Visibility,
215    entity_enum_ident: &Ident,
216    parsed_fields: &[ParsedField],
217    primary_key: &ParsedField,
218) -> syn::Result<TokenStream> {
219    let col_enum_variants = parsed_fields
220        .iter()
221        .map(|f| f.iden_name.clone())
222        .collect::<Vec<_>>();
223
224    let col_constants = parsed_fields
225        .iter()
226        .map(|f| {
227            let iden_name = &f.iden_name;
228            quote! {
229                #[allow(non_upper_case_globals)]
230                pub const #iden_name: #entity_enum_ident = #entity_enum_ident::#iden_name;
231            }
232        })
233        .collect::<Vec<_>>();
234
235    let primary_key_iden_name = &primary_key.iden_name;
236
237    Ok(quote! {
238        #vis const COLUMNS: &'static [#entity_enum_ident] = &[#(#entity_enum_ident::#col_enum_variants),*];
239        #vis const PRIMARY_KEY: #entity_enum_ident = #entity_enum_ident::#primary_key_iden_name;
240        #(#col_constants)*
241    })
242}
243
244fn generate_relation_functions(
245    entity_name: &Ident,
246    col_enum_name: &Ident,
247    vis: &syn::Visibility,
248    parsed_relations: &[ParsedField],
249) -> syn::Result<TokenStream> {
250    let relation_functions = parsed_relations
251        .iter()
252        .map(|f| {
253            let relation = f.relation.as_ref().unwrap();
254            match relation {
255                Relation::ForeignKey {
256                    entity: inversed_entity,
257                    foreign_key_field,
258                    references_field,
259                } => {
260                    let field_name = &f.field_name;
261                    let references_field = references_field
262                        .as_ref();
263
264                    let inversed_primary_key_getter = match references_field {
265                        Some(refs) => {
266                            quote! { entity.#refs }
267                        },
268                        None => {
269                            quote! { entity.__primary_key_value() }
270                        },
271                    };
272
273                    let references_field_iden_ident = if let Some(refs) = references_field {
274                        Ident::new(
275                            &to_upper_camel_case(&refs.to_string()),
276                            refs.span(),
277                        )
278                    } else {
279                        Ident::new(
280                            "PRIMARY_KEY",
281                            foreign_key_field.span(),
282                        )
283                    };
284
285                    let foreign_key_iden_ident = Ident::new(
286                        &to_upper_camel_case(&foreign_key_field.to_string()),
287                        foreign_key_field.span(),
288                    );
289
290                    let inversed_filter_name = Ident::new(
291                        &format!("__{}_inversed_filter", field_name),
292                        field_name.span(),
293                    );
294
295
296                    quote! {
297                        #vis fn #field_name(&self) -> kali::reference::Reference<#inversed_entity> {
298                            kali::reference::Reference::new(#inversed_entity::#references_field_iden_ident.eq(self.#foreign_key_field))
299                        }
300
301                        // this is really awkward, but its necessary for the inversed side to know
302                        // how to filter the relation. when the macro runs, we can't inspect the owning side
303                        // to figure it out, and other workarounds aren't as clean.
304                        #[doc(hidden)]
305                        #vis fn #inversed_filter_name<'a>(entity: &#inversed_entity) -> kali::builder::expr::Expr<'a, #col_enum_name> {
306                            #entity_name::#foreign_key_iden_ident.eq(#inversed_primary_key_getter)
307                        }
308                    }
309                }
310                Relation::ReferencedBy {
311                    entity: owning_entity,
312                    relation_field,
313                    is_collection,
314                } => {
315                    // we use the inversed_filter to filter the relation appropriately
316                    let field_name = &f.field_name;
317                    let inversed_filter_name = Ident::new(
318                        &format!("__{}_inversed_filter", relation_field),
319                        relation_field.span(),
320                    );
321
322                    let return_kind = if *is_collection {
323                        quote! { kali::collection::Collection<#owning_entity> }
324                    } else {
325                        quote! { kali::reference::Reference<#owning_entity> }
326                    };
327
328                    let struct_kind = if *is_collection {
329                        quote! { kali::collection::Collection }
330                    } else {
331                        quote! { kali::reference::Reference }
332                    };
333
334                    quote! {
335                        #vis fn #field_name(&self) -> #return_kind {
336                            #struct_kind::new(#owning_entity::#inversed_filter_name(self))
337                        }
338                    }
339
340                } 
341            }
342        })
343        .collect::<Vec<_>>();
344
345    Ok(quote! {
346        #(#relation_functions)*
347    })
348}
349
350fn generate(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
351    let mut entity: syn::ItemStruct = syn::parse2(input)?; // Make entity mutable
352
353    // table name is either #[entity("table_name")] or snake_case of the struct name
354    let table_name = if args.is_empty() {
355        let name = entity.ident.to_string();
356        let snake_case_name = to_snake_case(&name);
357        quote! { #snake_case_name }
358    } else {
359        let table_name: syn::LitStr = syn::parse2(args)?;
360        quote! { #table_name }
361    };
362
363    let entity_vis = entity.vis.clone();
364    let entity_name = entity.ident.clone();
365
366    let parsed_fields = parse_fields(entity.clone())?;
367    let (parsed_fields, relation_fields): (Vec<_>, Vec<_>) = parsed_fields
368        .into_iter()
369        .partition(|f| f.relation.is_none());
370
371    match entity.fields {
372        syn::Fields::Named(ref mut fields) => {
373            fields.named = parsed_fields.clone().into_iter().map(|f| f.raw).collect();
374        }
375        syn::Fields::Unnamed(_) => {
376            return Err(syn::Error::new_spanned(&entity, "expected named fields"));
377        }
378        syn::Fields::Unit => {
379            return Err(syn::Error::new_spanned(&entity, "expected named fields"));
380        }
381    }
382
383    // pk is either with #[primary_key] attribute or named "id"
384    let primary_key = parsed_fields
385        .iter()
386        .find(|f| f.is_pk)
387        .or_else(|| parsed_fields.iter().find(|f| f.field_name == "id"));
388
389    let Some(primary_key) = primary_key else {
390        return Err(syn::Error::new_spanned(
391            &entity,
392            "missing primary key field with #[primary_key] attribute or named 'id'",
393        ));
394    };
395
396    let primary_key_name = &primary_key.field_name;
397    let primary_key_type = &primary_key.raw.ty;
398    let (col_enum_name, col_enum) =
399        generate_entity_column_enum(&entity_vis, &entity_name, &parsed_fields)?;
400    let entity_constants =
401        generate_entity_constants(&entity_vis, &col_enum_name, &parsed_fields, primary_key)?;
402
403    let relation_functions = generate_relation_functions(
404        &entity_name,
405        &col_enum_name,
406        &entity_vis,
407        &relation_fields,
408    )?;
409
410    Ok(quote! {
411        #entity
412
413        #[allow(non_upper_case_globals)]
414        impl #entity_name {
415            #entity_vis const TABLE_NAME: &'static str = #table_name;
416            #entity_constants
417
418            #relation_functions
419
420            #entity_vis async fn fetch_one<'e, E>(
421                executor: E,
422                id: #primary_key_type,
423            ) -> Result<Self, sqlx::Error>
424            where
425                E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
426            {
427                kali::builder::QueryBuilder::select_from(Self::TABLE_NAME)
428                    .columns(Self::COLUMNS)
429                    .filter(Self::PRIMARY_KEY.eq(id))
430                    .limit(1)
431                    .fetch_one(executor)
432                    .await
433            }
434
435            #entity_vis async fn fetch_optional<'e, E>(
436                executor: E,
437                id: #primary_key_type,
438            ) -> Result<Option<Self>, sqlx::Error>
439            where
440                E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
441            {
442                kali::builder::QueryBuilder::select_from(Self::TABLE_NAME)
443                    .columns(Self::COLUMNS)
444                    .filter(Self::PRIMARY_KEY.eq(id))
445                    .limit(1)
446                    .fetch_optional(executor)
447                    .await
448            }
449
450            #entity_vis async fn fetch_all<'e, E>(
451                executor: E,
452            ) -> Result<Vec<Self>, sqlx::Error>
453            where
454                E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
455            {
456                kali::builder::QueryBuilder::select_from(Self::TABLE_NAME)
457                    .columns(Self::COLUMNS)
458                    .fetch_all(executor)
459                    .await
460            }
461
462            #entity_vis async fn delete_one<'e, E>(
463                executor: E,
464                id: #primary_key_type,
465            ) -> Result<sqlx::sqlite::SqliteQueryResult, sqlx::Error>
466            where
467                E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
468            {
469                kali::builder::QueryBuilder::delete_from(Self::TABLE_NAME)
470                    .filter(Self::PRIMARY_KEY.eq(id))
471                    .execute(executor)
472                    .await
473            }
474
475            #[doc(hidden)]
476            #entity_vis fn __primary_key_value(&self) -> #primary_key_type {
477                self.#primary_key_name
478            }
479        }
480
481        impl kali::entity::Entity for #entity_name {
482            type C = #col_enum_name;
483
484            fn table_name() -> &'static str {
485                Self::TABLE_NAME
486            }
487
488            fn columns() -> &'static [#col_enum_name] {
489                Self::COLUMNS
490            }
491
492            fn primary_key() -> &'static #col_enum_name {
493                &Self::PRIMARY_KEY
494            }
495        }
496
497        #col_enum
498    })
499}
500
/// Converts a struct name such as `UserProfile` into `user_profile`.
///
/// An underscore is inserted before an uppercase character only when it is not the first
/// character and the previous character was not uppercase. Runs of capitals therefore stay
/// together (`HTTPServer` becomes `httpserver`).
///
/// Lowercasing is ASCII-only: a non-ASCII uppercase letter is kept as-is.
///
/// NOTE(review): this output is used as the default table name, so changing the acronym
/// handling would rename tables for existing users.
fn to_snake_case(name: &str) -> String {
    let mut snake = String::with_capacity(name.len() + 4);
    let mut previous_was_upper = false;

    for (position, ch) in name.chars().enumerate() {
        let is_upper = ch.is_uppercase();
        if is_upper && position > 0 && !previous_was_upper {
            snake.push('_');
        }
        snake.push(if is_upper { ch.to_ascii_lowercase() } else { ch });
        previous_was_upper = is_upper;
    }

    snake
}
520
/// Converts a snake_case field name such as `user_id` into `UserId`.
///
/// Underscores are dropped, and the first character of every non-empty `_`-separated segment
/// is ASCII-uppercased. All other characters are copied unchanged, so `aBc_d` becomes
/// `ABcD`. Used to derive column-enum variant and constant names.
fn to_upper_camel_case(name: &str) -> String {
    name.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            chars
                .next()
                .map(|first| first.to_ascii_uppercase())
                .into_iter()
                .chain(chars)
        })
        .collect()
}