1use proc_macro2::TokenStream;
2use quote::quote;
3use syn::Ident;
4
5#[proc_macro_attribute]
6pub fn entity(
7 args: proc_macro::TokenStream,
8 input: proc_macro::TokenStream,
9) -> proc_macro::TokenStream {
10 match generate(args.into(), input.into()) {
11 Ok(output) => output.into(),
12 Err(e) => e.to_compile_error().into(),
13 }
14}
15
16#[derive(Clone)]
17enum Relation {
18 ForeignKey {
19 entity: Ident,
20 foreign_key_field: Ident,
21 references_field: Option<Ident>, },
23 ReferencedBy {
24 entity: Ident,
25 relation_field: Ident,
26 is_collection: bool,
27 },
28}
29
30#[derive(Clone)]
31struct ParsedField {
32 field_name: Ident,
33 iden_name: Ident,
34 is_pk: bool,
35 relation: Option<Relation>,
36 raw: syn::Field,
37}
38
39fn parse_entity_from_type(entity: &syn::Type) -> Result<(bool, Ident), syn::Error> {
41 if let syn::Type::Path(type_path) = entity {
42 if let Some(segment) = type_path.path.segments.last() {
43 let is_collection = segment.ident == "Collection";
44 if is_collection || segment.ident == "Reference" {
45 if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
46 if let Some(syn::GenericArgument::Type(ty)) = args.args.first() {
47 if let syn::Type::Path(type_path) = ty {
48 return Ok((
49 is_collection,
50 type_path.path.segments.last().unwrap().ident.clone(),
51 ));
52 }
53 }
54 }
55 }
56 }
57 }
58 Err(syn::Error::new_spanned(
59 entity,
60 "expected Collection<T> or Reference<T>",
61 ))
62}
63
64fn parse_fields(entity: syn::ItemStruct) -> Result<Vec<ParsedField>, syn::Error> {
65 entity
66 .fields
67 .into_iter()
68 .map(|mut f| {
69 let ident = f
70 .ident
71 .as_ref()
72 .ok_or_else(|| syn::Error::new_spanned(&f, "expected named field"))?;
73
74 let field_name = ident.clone();
75 let iden_name = Ident::new(
76 &to_upper_camel_case(&field_name.to_string()),
77 field_name.span(),
78 );
79
80 let mut is_pk = false;
81 let mut relation_attr = None;
82
83 f.attrs.retain(|attr| {
84 if attr.path().is_ident("primary_key") {
85 is_pk = true;
86 false
87 } else if attr.path().is_ident("relation") {
88 relation_attr = Some(attr.clone());
89 false
90 } else {
91 true
92 }
93 });
94
95 let relation = if let Some(relation_attr) = relation_attr {
96 let (is_collection, entity) = parse_entity_from_type(&f.ty)?;
97 let is_owning = relation_attr.path().is_ident("foreign_key");
98 if is_owning {
99 return Err(syn::Error::new_spanned(
100 &relation_attr,
101 "expected `#[relation(referenced_by = ...)]` for Collection<T>",
102 ));
103 }
104
105 let mut referenced_by = None;
106 let mut foreign_key = None;
107 let mut references = None;
108 relation_attr.parse_nested_meta(|meta| {
109 if meta.path.is_ident("referenced_by") {
110 let value = meta.value()?;
111 referenced_by = Some(value.parse()?);
112 Ok(())
113 } else if meta.path.is_ident("foreign_key") {
114 let value = meta.value()?;
115 foreign_key = Some(value.parse()?);
116 Ok(())
117 } else if meta.path.is_ident("references") {
118 let value = meta.value()?;
119 references = Some(value.parse()?);
120 Ok(())
121 } else {
122 return Err(syn::Error::new_spanned(
123 &meta.path,
124 "expected `referenced_by` or `foreign_key` attribute",
125 ));
126 }
127 })?;
128
129 match (referenced_by, foreign_key, references) {
130 (None, Some(fk), None) => {
131 Some(Relation::ForeignKey {
132 entity,
133 foreign_key_field: fk,
134 references_field: None,
135 })
136 }
137 (None, Some(fk), Some(refs)) => {
138 Some(Relation::ForeignKey {
139 entity,
140 foreign_key_field: fk,
141 references_field: Some(refs),
142 })
143 },
144 (Some(refs), None, None) => {
145 Some(Relation::ReferencedBy {
146 entity,
147 relation_field: refs,
148 is_collection,
149 })
150 }
151 _ => {
152 return Err(syn::Error::new_spanned(
153 &relation_attr,
154 "expected either `#[relation(referenced_by = ...)]` or `#[relation(foreign_key = ...)]`",
155 ));
156 }
157 }
158 } else {
159 None
160 };
161
162 Ok(ParsedField {
163 field_name,
164 iden_name,
165 is_pk,
166 relation,
167 raw: f,
168 })
169 })
170 .collect()
171}
172
173fn generate_entity_column_enum(
174 vis: &syn::Visibility,
175 entity_name: &syn::Ident,
176 parsed_fields: &[ParsedField],
177) -> syn::Result<(Ident, TokenStream)> {
178 let col_enum_ident = Ident::new(&format!("{}Column", entity_name), entity_name.span());
179
180 let col_enum_variants = parsed_fields
181 .iter()
182 .map(|f| f.iden_name.clone())
183 .collect::<Vec<_>>();
184
185 let field_name_mappings = parsed_fields
187 .iter()
188 .map(|f| {
189 let variant = &f.iden_name;
190 let field_name = &f.field_name;
191 quote! { #col_enum_ident::#variant => stringify!(#field_name) }
192 })
193 .collect::<Vec<_>>();
194
195 let col_enum = quote! {
196 #[derive(Debug, Clone, Copy, PartialEq, Eq)]
197 #vis enum #col_enum_ident {
198 #(#col_enum_variants),*
199 }
200
201 impl kali::column::Column for #col_enum_ident {
202 fn to_col_name(&self) -> &str {
203 match self {
204 #(#field_name_mappings),*
205 }
206 }
207 }
208 };
209
210 Ok((col_enum_ident, col_enum))
211}
212
213fn generate_entity_constants(
214 vis: &syn::Visibility,
215 entity_enum_ident: &Ident,
216 parsed_fields: &[ParsedField],
217 primary_key: &ParsedField,
218) -> syn::Result<TokenStream> {
219 let col_enum_variants = parsed_fields
220 .iter()
221 .map(|f| f.iden_name.clone())
222 .collect::<Vec<_>>();
223
224 let col_constants = parsed_fields
225 .iter()
226 .map(|f| {
227 let iden_name = &f.iden_name;
228 quote! {
229 #[allow(non_upper_case_globals)]
230 pub const #iden_name: #entity_enum_ident = #entity_enum_ident::#iden_name;
231 }
232 })
233 .collect::<Vec<_>>();
234
235 let primary_key_iden_name = &primary_key.iden_name;
236
237 Ok(quote! {
238 #vis const COLUMNS: &'static [#entity_enum_ident] = &[#(#entity_enum_ident::#col_enum_variants),*];
239 #vis const PRIMARY_KEY: #entity_enum_ident = #entity_enum_ident::#primary_key_iden_name;
240 #(#col_constants)*
241 })
242}
243
244fn generate_relation_functions(
245 entity_name: &Ident,
246 col_enum_name: &Ident,
247 vis: &syn::Visibility,
248 parsed_relations: &[ParsedField],
249) -> syn::Result<TokenStream> {
250 let relation_functions = parsed_relations
251 .iter()
252 .map(|f| {
253 let relation = f.relation.as_ref().unwrap();
254 match relation {
255 Relation::ForeignKey {
256 entity: inversed_entity,
257 foreign_key_field,
258 references_field,
259 } => {
260 let field_name = &f.field_name;
261 let references_field = references_field
262 .as_ref();
263
264 let inversed_primary_key_getter = match references_field {
265 Some(refs) => {
266 quote! { entity.#refs }
267 },
268 None => {
269 quote! { entity.__primary_key_value() }
270 },
271 };
272
273 let references_field_iden_ident = if let Some(refs) = references_field {
274 Ident::new(
275 &to_upper_camel_case(&refs.to_string()),
276 refs.span(),
277 )
278 } else {
279 Ident::new(
280 "PRIMARY_KEY",
281 foreign_key_field.span(),
282 )
283 };
284
285 let foreign_key_iden_ident = Ident::new(
286 &to_upper_camel_case(&foreign_key_field.to_string()),
287 foreign_key_field.span(),
288 );
289
290 let inversed_filter_name = Ident::new(
291 &format!("__{}_inversed_filter", field_name),
292 field_name.span(),
293 );
294
295
296 quote! {
297 #vis fn #field_name(&self) -> kali::reference::Reference<#inversed_entity> {
298 kali::reference::Reference::new(#inversed_entity::#references_field_iden_ident.eq(self.#foreign_key_field))
299 }
300
301 #[doc(hidden)]
305 #vis fn #inversed_filter_name<'a>(entity: &#inversed_entity) -> kali::builder::expr::Expr<'a, #col_enum_name> {
306 #entity_name::#foreign_key_iden_ident.eq(#inversed_primary_key_getter)
307 }
308 }
309 }
310 Relation::ReferencedBy {
311 entity: owning_entity,
312 relation_field,
313 is_collection,
314 } => {
315 let field_name = &f.field_name;
317 let inversed_filter_name = Ident::new(
318 &format!("__{}_inversed_filter", relation_field),
319 relation_field.span(),
320 );
321
322 let return_kind = if *is_collection {
323 quote! { kali::collection::Collection<#owning_entity> }
324 } else {
325 quote! { kali::reference::Reference<#owning_entity> }
326 };
327
328 let struct_kind = if *is_collection {
329 quote! { kali::collection::Collection }
330 } else {
331 quote! { kali::reference::Reference }
332 };
333
334 quote! {
335 #vis fn #field_name(&self) -> #return_kind {
336 #struct_kind::new(#owning_entity::#inversed_filter_name(self))
337 }
338 }
339
340 }
341 }
342 })
343 .collect::<Vec<_>>();
344
345 Ok(quote! {
346 #(#relation_functions)*
347 })
348}
349
350fn generate(args: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
351 let mut entity: syn::ItemStruct = syn::parse2(input)?; let table_name = if args.is_empty() {
355 let name = entity.ident.to_string();
356 let snake_case_name = to_snake_case(&name);
357 quote! { #snake_case_name }
358 } else {
359 let table_name: syn::LitStr = syn::parse2(args)?;
360 quote! { #table_name }
361 };
362
363 let entity_vis = entity.vis.clone();
364 let entity_name = entity.ident.clone();
365
366 let parsed_fields = parse_fields(entity.clone())?;
367 let (parsed_fields, relation_fields): (Vec<_>, Vec<_>) = parsed_fields
368 .into_iter()
369 .partition(|f| f.relation.is_none());
370
371 match entity.fields {
372 syn::Fields::Named(ref mut fields) => {
373 fields.named = parsed_fields.clone().into_iter().map(|f| f.raw).collect();
374 }
375 syn::Fields::Unnamed(_) => {
376 return Err(syn::Error::new_spanned(&entity, "expected named fields"));
377 }
378 syn::Fields::Unit => {
379 return Err(syn::Error::new_spanned(&entity, "expected named fields"));
380 }
381 }
382
383 let primary_key = parsed_fields
385 .iter()
386 .find(|f| f.is_pk)
387 .or_else(|| parsed_fields.iter().find(|f| f.field_name == "id"));
388
389 let Some(primary_key) = primary_key else {
390 return Err(syn::Error::new_spanned(
391 &entity,
392 "missing primary key field with #[primary_key] attribute or named 'id'",
393 ));
394 };
395
396 let primary_key_name = &primary_key.field_name;
397 let primary_key_type = &primary_key.raw.ty;
398 let (col_enum_name, col_enum) =
399 generate_entity_column_enum(&entity_vis, &entity_name, &parsed_fields)?;
400 let entity_constants =
401 generate_entity_constants(&entity_vis, &col_enum_name, &parsed_fields, primary_key)?;
402
403 let relation_functions = generate_relation_functions(
404 &entity_name,
405 &col_enum_name,
406 &entity_vis,
407 &relation_fields,
408 )?;
409
410 Ok(quote! {
411 #entity
412
413 #[allow(non_upper_case_globals)]
414 impl #entity_name {
415 #entity_vis const TABLE_NAME: &'static str = #table_name;
416 #entity_constants
417
418 #relation_functions
419
420 #entity_vis async fn fetch_one<'e, E>(
421 executor: E,
422 id: #primary_key_type,
423 ) -> Result<Self, sqlx::Error>
424 where
425 E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
426 {
427 kali::builder::QueryBuilder::select_from(Self::TABLE_NAME)
428 .columns(Self::COLUMNS)
429 .filter(Self::PRIMARY_KEY.eq(id))
430 .limit(1)
431 .fetch_one(executor)
432 .await
433 }
434
435 #entity_vis async fn fetch_optional<'e, E>(
436 executor: E,
437 id: #primary_key_type,
438 ) -> Result<Option<Self>, sqlx::Error>
439 where
440 E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
441 {
442 kali::builder::QueryBuilder::select_from(Self::TABLE_NAME)
443 .columns(Self::COLUMNS)
444 .filter(Self::PRIMARY_KEY.eq(id))
445 .limit(1)
446 .fetch_optional(executor)
447 .await
448 }
449
450 #entity_vis async fn fetch_all<'e, E>(
451 executor: E,
452 ) -> Result<Vec<Self>, sqlx::Error>
453 where
454 E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
455 {
456 kali::builder::QueryBuilder::select_from(Self::TABLE_NAME)
457 .columns(Self::COLUMNS)
458 .fetch_all(executor)
459 .await
460 }
461
462 #entity_vis async fn delete_one<'e, E>(
463 executor: E,
464 id: #primary_key_type,
465 ) -> Result<sqlx::sqlite::SqliteQueryResult, sqlx::Error>
466 where
467 E: 'e + sqlx::Executor<'e, Database = sqlx::Sqlite>,
468 {
469 kali::builder::QueryBuilder::delete_from(Self::TABLE_NAME)
470 .filter(Self::PRIMARY_KEY.eq(id))
471 .execute(executor)
472 .await
473 }
474
475 #[doc(hidden)]
476 #entity_vis fn __primary_key_value(&self) -> #primary_key_type {
477 self.#primary_key_name
478 }
479 }
480
481 impl kali::entity::Entity for #entity_name {
482 type C = #col_enum_name;
483
484 fn table_name() -> &'static str {
485 Self::TABLE_NAME
486 }
487
488 fn columns() -> &'static [#col_enum_name] {
489 Self::COLUMNS
490 }
491
492 fn primary_key() -> &'static #col_enum_name {
493 &Self::PRIMARY_KEY
494 }
495 }
496
497 #col_enum
498 })
499}
500
/// Converts an `UpperCamelCase` type name into `snake_case`.
///
/// An underscore is inserted before an uppercase character only when it is
/// not the first character and the previous character was not uppercase, so
/// runs of capitals stay together (`HTTPServer` -> `httpserver`). Uppercase
/// characters are lowered with ASCII rules; everything else is copied as is.
fn to_snake_case(name: &str) -> String {
    let (snake, _) = name.chars().enumerate().fold(
        (String::new(), false),
        |(mut acc, prev_upper), (idx, ch)| {
            let is_upper = ch.is_uppercase();
            if is_upper && idx > 0 && !prev_upper {
                acc.push('_');
            }
            acc.push(if is_upper { ch.to_ascii_lowercase() } else { ch });
            (acc, is_upper)
        },
    );
    snake
}
520
/// Converts a `snake_case` identifier into `UpperCamelCase`.
///
/// Underscores are dropped; the first character of the input and the first
/// character following any run of underscores are ASCII-uppercased, and all
/// other characters are copied unchanged (`author_id` -> `AuthorId`).
fn to_upper_camel_case(name: &str) -> String {
    name.split('_')
        .flat_map(|segment| {
            let mut chars = segment.chars();
            chars
                .next()
                .map(|first| first.to_ascii_uppercase())
                .into_iter()
                .chain(chars)
        })
        .collect()
}