1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::parse::Parser;
4use syn::{parse_macro_input, Attribute, Field, Fields, Ident, ItemStruct, Meta, Type};
5
6#[proc_macro_derive(Model, attributes(model, key, autoincrement, unique, index, has_many, belongs_to, many_to_many))]
7pub fn derive_model(_input: TokenStream) -> TokenStream {
8 TokenStream::from(quote! {
9 compile_error!("dbkit: use #[model] instead of #[derive(Model)]");
10 })
11}
12
13#[proc_macro_derive(DbEnum, attributes(dbkit))]
14pub fn derive_db_enum(input: TokenStream) -> TokenStream {
15 let input = parse_macro_input!(input as syn::ItemEnum);
16 match expand_db_enum(input) {
17 Ok(tokens) => tokens,
18 Err(err) => err.to_compile_error().into(),
19 }
20}
21
22#[proc_macro_attribute]
23pub fn model(attr: TokenStream, item: TokenStream) -> TokenStream {
24 let input = parse_macro_input!(item as ItemStruct);
25 let args = parse_macro_input!(attr with syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated);
26 let args = parse_model_args(args);
27 match expand_model(args, input) {
28 Ok(tokens) => tokens,
29 Err(err) => err.to_compile_error().into(),
30 }
31}
32
/// Which association a relation field declares.
///
/// Determined (via `relation_type`) for fields carrying one of the
/// `has_many`, `belongs_to` or `many_to_many` attributes.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
enum RelationKind {
    /// One-to-many: when loaded, the field holds a `Vec` of related rows.
    HasMany,
    /// Many-to-one: when loaded, the field holds an `Option` of the related row.
    BelongsTo,
    /// Many-to-many: when loaded, the field holds a `Vec` of related rows,
    /// joined through a separate `through` model.
    ManyToMany,
}
39
40struct RelationInfo {
41 field: Field,
42 param_ident: Ident,
43 state_mod_ident: Ident,
44 child_type: Type,
45 kind: RelationKind,
46 belongs_to_key: Option<Ident>,
47 belongs_to_ref: Option<Ident>,
48 many_to_many_through: Option<Ident>,
49 many_to_many_left_key: Option<Ident>,
50 many_to_many_right_key: Option<Ident>,
51}
52
53struct ScalarFieldInfo {
54 field: Field,
55 ident: Ident,
56 ty: Type,
57 is_key: bool,
58 is_autoincrement: bool,
59}
60
/// Arguments accepted inside `#[model(...)]`; every option is optional.
#[derive(Default)]
struct ModelArgs {
    /// Optional schema qualifier, applied to the table with `with_schema`.
    schema: Option<String>,
    /// Explicit table name. When `None`, the snake_case struct name is used.
    table: Option<String>,
}
66
67fn expand_model(args: ModelArgs, input: ItemStruct) -> syn::Result<TokenStream> {
68 if !input.generics.params.is_empty() {
69 return Err(syn::Error::new_spanned(
70 input.generics,
71 "dbkit: #[model] does not support generics yet",
72 ));
73 }
74
75 let struct_ident = input.ident;
76 let model_ident = format_ident!("{}Model", struct_ident);
77 let insert_ident = format_ident!("{}Insert", struct_ident);
78 let vis = input.vis;
79
80 let table_name = args.table.unwrap_or_else(|| to_snake_case(&struct_ident.to_string()));
81 let schema_name = args.schema;
82
83 let mut primary_keys: Vec<(Ident, Type)> = Vec::new();
84 let mut relation_fields = Vec::new();
85 let mut output_fields = Vec::new();
86 let mut insert_fields = Vec::new();
87 let mut scalar_fields = Vec::new();
88
89 let struct_attrs = filter_struct_attrs(&input.attrs);
90
91 let fields = match input.fields {
92 Fields::Named(named) => named.named,
93 _ => {
94 return Err(syn::Error::new_spanned(
95 struct_ident,
96 "dbkit: #[model] requires a struct with named fields",
97 ))
98 }
99 };
100
101 for field in fields {
102 let field_ident = field
103 .ident
104 .clone()
105 .ok_or_else(|| syn::Error::new_spanned(&field, "dbkit: unnamed field"))?;
106
107 let is_relation =
108 has_attr(&field.attrs, "has_many") || has_attr(&field.attrs, "belongs_to") || has_attr(&field.attrs, "many_to_many");
109
110 let is_key = has_attr(&field.attrs, "key");
111 let is_autoincrement = has_attr(&field.attrs, "autoincrement");
112
113 if is_key {
114 primary_keys.push((field_ident.clone(), field.ty.clone()));
115 }
116
117 if is_relation {
118 let (kind, child_type) = relation_type(&field)?;
119 let state_mod_ident = format_ident!("{}_{}_state", to_snake_case(&struct_ident.to_string()), field_ident);
120 let param_ident = format_ident!("{}Rel", to_camel_case(&field_ident.to_string()));
121 let (belongs_to_key, belongs_to_ref) = if kind == RelationKind::BelongsTo {
122 let (key, references) = parse_belongs_to_args(&field.attrs)?;
123 (Some(key), Some(references))
124 } else {
125 (None, None)
126 };
127 let (many_to_many_through, many_to_many_left_key, many_to_many_right_key) = if kind == RelationKind::ManyToMany {
128 let (through, left_key, right_key) = parse_many_to_many_args(&field.attrs)?;
129 (Some(through), Some(left_key), Some(right_key))
130 } else {
131 (None, None, None)
132 };
133
134 relation_fields.push(RelationInfo {
135 field: field.clone(),
136 param_ident: param_ident.clone(),
137 state_mod_ident,
138 child_type,
139 kind,
140 belongs_to_key,
141 belongs_to_ref,
142 many_to_many_through,
143 many_to_many_left_key,
144 many_to_many_right_key,
145 });
146
147 let cleaned_field = Field {
148 attrs: filter_field_attrs(&field.attrs),
149 ty: syn::parse_quote!(#param_ident),
150 ..field
151 };
152 output_fields.push(cleaned_field);
153 continue;
154 }
155
156 let cleaned_field = Field {
157 attrs: filter_field_attrs(&field.attrs),
158 ..field.clone()
159 };
160 output_fields.push(cleaned_field.clone());
161
162 if !(is_key && is_autoincrement) {
163 insert_fields.push(cleaned_field.clone());
164 }
165
166 scalar_fields.push(ScalarFieldInfo {
167 field: cleaned_field,
168 ident: field_ident,
169 ty: field.ty.clone(),
170 is_key,
171 is_autoincrement,
172 });
173 }
174
175 let table_expr = if let Some(schema) = schema_name {
176 quote!(::dbkit::Table::new(#table_name).with_schema(#schema))
177 } else {
178 quote!(::dbkit::Table::new(#table_name))
179 };
180
181 if relation_fields.iter().any(|rel| rel.kind == RelationKind::ManyToMany) && primary_keys.len() != 1 {
182 return Err(syn::Error::new_spanned(
183 struct_ident,
184 "dbkit: many-to-many requires exactly one #[key] on the parent model",
185 ));
186 }
187
188 let generics_with_defaults = relation_fields
189 .iter()
190 .map(|rel| {
191 let ident = &rel.param_ident;
192 let state_mod = &rel.state_mod_ident;
193 quote!(#ident: #state_mod::State = ::dbkit::NotLoaded)
194 })
195 .collect::<Vec<_>>();
196
197 let impl_generics_params = relation_fields
198 .iter()
199 .map(|rel| {
200 let ident = &rel.param_ident;
201 let state_mod = &rel.state_mod_ident;
202 quote!(#ident: #state_mod::State)
203 })
204 .collect::<Vec<_>>();
205
206 let generic_idents = relation_fields.iter().map(|rel| &rel.param_ident).collect::<Vec<_>>();
207
208 let struct_generics = if generics_with_defaults.is_empty() {
209 quote!()
210 } else {
211 quote!(<#(#generics_with_defaults),*>)
212 };
213
214 let impl_generics = if impl_generics_params.is_empty() {
215 quote!()
216 } else {
217 quote!(<#(#impl_generics_params),*>)
218 };
219
220 let struct_type_args = if generic_idents.is_empty() {
221 quote!()
222 } else {
223 quote!(<#(#generic_idents),*>)
224 };
225
226 let columns = output_fields
227 .iter()
228 .filter(|field| !is_relation_field(field, &relation_fields))
229 .map(|field| {
230 let ident = field.ident.as_ref().expect("field ident");
231 let name = ident.to_string();
232 let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
233 quote!(pub const #ident: ::dbkit::Column<#struct_ident, #ty> = ::dbkit::Column::new(Self::TABLE, #name);)
234 })
235 .collect::<Vec<_>>();
236
237 let column_refs = output_fields
238 .iter()
239 .filter(|field| !is_relation_field(field, &relation_fields))
240 .map(|field| {
241 let ident = field.ident.as_ref().expect("field ident");
242 quote!(Self::#ident.as_ref())
243 })
244 .collect::<Vec<_>>();
245
246 let columns_const = quote!(
247 pub const COLUMNS: &'static [::dbkit::ColumnRef] = &[#(#column_refs),*];
248 );
249
250 let primary_key_refs = primary_keys
251 .iter()
252 .map(|(ident, _)| quote!(Self::#ident.as_ref()))
253 .collect::<Vec<_>>();
254
255 let primary_keys_const = if primary_keys.is_empty() {
256 quote!(
257 pub const PRIMARY_KEYS: &'static [::dbkit::ColumnRef] = &[];
258 )
259 } else {
260 quote!(pub const PRIMARY_KEYS: &'static [::dbkit::ColumnRef] = &[#(#primary_key_refs),*];)
261 };
262
263 let insert_values = insert_fields.iter().map(|field| {
264 let ident = field.ident.as_ref().expect("field ident");
265 quote!(insert = insert.value(Self::#ident, values.#ident);)
266 });
267 let insert_field_idents = insert_fields
268 .iter()
269 .map(|field| field.ident.as_ref().expect("field ident"))
270 .collect::<Vec<_>>();
271
272 let active_ident = format_ident!("{}Active", struct_ident);
273
274 let active_fields = scalar_fields.iter().map(|field| {
275 let ident = &field.ident;
276 let vis = &field.field.vis;
277 let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
278 quote!(#vis #ident: ::dbkit::ActiveValue<#ty>)
279 });
280
281 let active_from_model = scalar_fields.iter().map(|field| {
282 let ident = &field.ident;
283 if option_inner_type(&field.ty).is_some() {
284 quote!(#ident: ::dbkit::ActiveValue::unchanged_option(#ident))
285 } else {
286 quote!(#ident: ::dbkit::ActiveValue::unchanged(#ident))
287 }
288 });
289
290 let active_destructure = scalar_fields.iter().map(|field| field.ident.clone()).collect::<Vec<_>>();
291
292 let active_insert_steps = scalar_fields.iter().map(|field| {
293 let ident = &field.ident;
294 let name = ident.to_string();
295 let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
296 let is_option = option_inner_type(&field.ty).is_some();
297 let required = !field.is_autoincrement && !is_option;
298 let required_check = if required {
299 quote!(return Err(::dbkit::Error::Decode(format!("missing required field: {}", #name)));)
300 } else {
301 quote!()
302 };
303 quote!(
304 match #ident {
305 ::dbkit::ActiveValue::Unset => {
306 #required_check
307 }
308 ::dbkit::ActiveValue::Set(value) => {
309 insert = insert.value(#struct_ident::#ident, value);
310 }
311 ::dbkit::ActiveValue::Unchanged(value) => {
312 insert = insert.value(#struct_ident::#ident, value);
313 }
314 ::dbkit::ActiveValue::UnchangedNull => {
315 insert = insert.value(#struct_ident::#ident, None::<#ty>);
316 }
317 ::dbkit::ActiveValue::Null => {
318 insert = insert.value(#struct_ident::#ident, None::<#ty>);
319 }
320 }
321 )
322 });
323
324 let active_insert_fn = quote!(
325 pub async fn insert(
326 self,
327 ex: &(impl ::dbkit::Executor + Send + Sync),
328 ) -> Result<#struct_ident, ::dbkit::Error> {
329 let Self { #(#active_destructure,)* } = self;
330 let mut insert = ::dbkit::Insert::new(#struct_ident::TABLE);
331 #(#active_insert_steps)*
332 let insert = insert.returning_all();
333 let row = ::dbkit::InsertExt::one(insert, ex).await?;
334 row.ok_or(::dbkit::Error::NotFound)
335 }
336 );
337
338 let pk_idents = primary_keys.iter().map(|(ident, _)| ident.clone()).collect::<Vec<_>>();
339
340 let active_update_fn = if !primary_keys.is_empty() {
341 let pk_vars = primary_keys
342 .iter()
343 .enumerate()
344 .map(|(idx, _)| format_ident!("pk_value_{}", idx))
345 .collect::<Vec<_>>();
346 let pk_extracts = primary_keys.iter().zip(pk_vars.iter()).map(|((ident, _), var)| {
347 let pk_name = ident.to_string();
348 quote!(
349 let #var = match #ident {
350 ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => value,
351 ::dbkit::ActiveValue::Null | ::dbkit::ActiveValue::Unset | ::dbkit::ActiveValue::UnchangedNull => {
352 return Err(::dbkit::Error::Decode(format!(
353 "missing required field: {}",
354 #pk_name
355 )));
356 }
357 };
358 )
359 });
360 let pk_filters = primary_keys
361 .iter()
362 .zip(pk_vars.iter())
363 .map(|((ident, _), var)| quote!(update = update.filter(#struct_ident::#ident.eq(#var));));
364 let update_steps = scalar_fields.iter().filter(|field| !field.is_key).map(|field| {
365 let ident = &field.ident;
366 let ty = option_inner_type(&field.ty).unwrap_or_else(|| field.ty.clone());
367 quote!(
368 match #ident {
369 ::dbkit::ActiveValue::Unset => {}
370 ::dbkit::ActiveValue::Set(value) => {
371 update = update.set(#struct_ident::#ident, value);
372 any_set = true;
373 }
374 ::dbkit::ActiveValue::Unchanged(_) | ::dbkit::ActiveValue::UnchangedNull => {}
375 ::dbkit::ActiveValue::Null => {
376 update = update.set(#struct_ident::#ident, None::<#ty>);
377 any_set = true;
378 }
379 }
380 )
381 });
382 quote!(
383 pub async fn update(
384 self,
385 ex: &(impl ::dbkit::Executor + Send + Sync),
386 ) -> Result<#struct_ident, ::dbkit::Error> {
387 let Self { #(#active_destructure,)* } = self;
388 #(#pk_extracts)*
389 let mut update = ::dbkit::Update::new(#struct_ident::TABLE);
390 let mut any_set = false;
391 #(#update_steps)*
392 if !any_set {
393 return Err(::dbkit::Error::Decode("no fields set for update".to_string()));
394 }
395 #(#pk_filters)*
396 let update = update.returning_all();
397 let mut rows = ::dbkit::UpdateExt::all(update, ex).await?;
398 rows.pop().ok_or(::dbkit::Error::NotFound)
399 }
400 )
401 } else {
402 quote!()
403 };
404
405 let active_delete_fn = if !primary_keys.is_empty() {
406 let pk_vars = primary_keys
407 .iter()
408 .enumerate()
409 .map(|(idx, _)| format_ident!("pk_value_{}", idx))
410 .collect::<Vec<_>>();
411 let pk_extracts = primary_keys.iter().zip(pk_vars.iter()).map(|((ident, _), var)| {
412 let pk_name = ident.to_string();
413 quote!(
414 let #var = match #ident {
415 ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => value,
416 ::dbkit::ActiveValue::Null | ::dbkit::ActiveValue::Unset | ::dbkit::ActiveValue::UnchangedNull => {
417 return Err(::dbkit::Error::Decode(format!(
418 "missing required field: {}",
419 #pk_name
420 )));
421 }
422 };
423 )
424 });
425 let pk_filters = primary_keys
426 .iter()
427 .zip(pk_vars.iter())
428 .map(|((ident, _), var)| quote!(delete = delete.filter(#struct_ident::#ident.eq(#var));));
429 quote!(
430 pub async fn delete(
431 self,
432 ex: &(impl ::dbkit::Executor + Send + Sync),
433 ) -> Result<u64, ::dbkit::Error> {
434 let Self { #(#pk_idents,)* .. } = self;
435 #(#pk_extracts)*
436 let mut delete = ::dbkit::Delete::new(#struct_ident::TABLE);
437 #(#pk_filters)*
438 ::dbkit::DeleteExt::execute(delete, ex).await
439 }
440 )
441 } else {
442 quote!()
443 };
444
445 let active_save_flag_checks = scalar_fields.iter().map(|field| {
446 let ident = &field.ident;
447 quote!(
448 match &#ident {
449 ::dbkit::ActiveValue::Unchanged(_) | ::dbkit::ActiveValue::UnchangedNull => {
450 any_loaded = true;
451 }
452 ::dbkit::ActiveValue::Set(_) | ::dbkit::ActiveValue::Null => {
453 any_changed = true;
454 }
455 ::dbkit::ActiveValue::Unset => {}
456 }
457 )
458 });
459
460 let active_save_model_fields = scalar_fields.iter().map(|field| {
461 let ident = &field.ident;
462 let name = ident.to_string();
463 if option_inner_type(&field.ty).is_some() {
464 quote!(
465 #ident: match #ident {
466 ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => Some(value),
467 ::dbkit::ActiveValue::Null | ::dbkit::ActiveValue::UnchangedNull => None,
468 ::dbkit::ActiveValue::Unset => {
469 return Err(::dbkit::Error::Decode(format!(
470 "missing required field: {}",
471 #name
472 )));
473 }
474 },
475 )
476 } else {
477 quote!(
478 #ident: match #ident {
479 ::dbkit::ActiveValue::Set(value) | ::dbkit::ActiveValue::Unchanged(value) => value,
480 ::dbkit::ActiveValue::Null
481 | ::dbkit::ActiveValue::Unset
482 | ::dbkit::ActiveValue::UnchangedNull => {
483 return Err(::dbkit::Error::Decode(format!(
484 "missing required field: {}",
485 #name
486 )));
487 }
488 },
489 )
490 }
491 });
492
493 let active_save_relation_defaults = relation_fields.iter().map(|rel| {
494 let ident = rel.field.ident.as_ref().expect("field ident");
495 quote!(#ident: Default::default(),)
496 });
497
498 let active_save_update_branch = if !primary_keys.is_empty() {
499 quote!(return Self { #(#active_destructure,)* }.update(ex).await;)
500 } else {
501 quote!(
502 return Err(::dbkit::Error::Decode(
503 "update requires primary key".to_string(),
504 ));
505 )
506 };
507
508 let active_save_fn = quote!(
509 pub async fn save(
510 self,
511 ex: &(impl ::dbkit::Executor + Send + Sync),
512 ) -> Result<#struct_ident, ::dbkit::Error> {
513 let Self { #(#active_destructure,)* } = self;
514 let mut any_loaded = false;
515 let mut any_changed = false;
516 #(#active_save_flag_checks)*
517
518 if any_loaded {
519 if any_changed {
520 #active_save_update_branch
521 }
522 let model = #struct_ident {
523 #(#active_save_model_fields)*
524 #(#active_save_relation_defaults)*
525 };
526 return Ok(model);
527 }
528
529 Self { #(#active_destructure,)* }.insert(ex).await
530 }
531 );
532
533 let model_delete_impl = if !primary_keys.is_empty() {
534 let pk_filters = primary_keys
535 .iter()
536 .map(|(ident, _)| quote!(delete = delete.filter(Self::#ident.eq(#ident));));
537 quote!(
538 impl #impl_generics ::dbkit::ModelDelete for #model_ident #struct_type_args {
539 fn delete<'e, E>(self, ex: &'e E) -> ::dbkit::executor::BoxFuture<'e, Result<u64, ::dbkit::Error>>
540 where
541 E: ::dbkit::Executor + Send + Sync + 'e,
542 {
543 let Self { #(#pk_idents,)* .. } = self;
544 let mut delete = ::dbkit::Delete::new(Self::TABLE);
545 #(#pk_filters)*
546 ::dbkit::DeleteExt::execute(delete, ex)
547 }
548 }
549 )
550 } else {
551 quote!()
552 };
553
554 let into_active_fn = quote!(
555 pub fn into_active(self) -> #active_ident {
556 let Self { #(#active_destructure,)* .. } = self;
557 #active_ident {
558 #(#active_from_model,)*
559 }
560 }
561 );
562
563 let primary_key_const = if primary_keys.len() == 1 {
564 let (ident, ty) = primary_keys.first().expect("primary key length checked");
565 let name = ident.to_string();
566 Some(quote!(pub const PRIMARY_KEY: ::dbkit::Column<#struct_ident, #ty> = ::dbkit::Column::new(Self::TABLE, #name);))
567 } else {
568 None
569 };
570
571 let by_id_fn = if primary_keys.len() == 1 {
572 let (ident, ty) = primary_keys.first().expect("primary key length checked");
573 Some(quote!(
574 pub fn by_id(id: #ty) -> ::dbkit::Select<#struct_ident> {
575 Self::query().filter(Self::#ident.eq(id)).limit(1)
576 }
577 ))
578 } else {
579 None
580 };
581
582 let any_state_ident = format_ident!("{}AnyState", struct_ident);
583
584 let relation_state_modules = relation_fields.iter().map(|rel| {
585 let state_mod = &rel.state_mod_ident;
586 let (sealed_impl, state_impl) = match rel.kind {
587 RelationKind::HasMany | RelationKind::ManyToMany => (
588 quote!(
589 impl<T> Sealed for Vec<T> {}
590 ),
591 quote!(
592 impl<T> State for Vec<T> {}
593 ),
594 ),
595 RelationKind::BelongsTo => (
596 quote!(
597 impl<T> Sealed for Option<T> {}
598 ),
599 quote!(
600 impl<T> State for Option<T> {}
601 ),
602 ),
603 };
604 quote!(
605 pub mod #state_mod {
606 mod sealed {
607 pub trait Sealed {}
608 impl Sealed for ::dbkit::NotLoaded {}
609 #sealed_impl
610 }
611 pub trait State: sealed::Sealed {}
612 impl State for ::dbkit::NotLoaded {}
613 #state_impl
614 }
615 )
616 });
617
618 let relation_methods = relation_fields.iter().map(|rel| {
619 let field_ident = rel.field.ident.as_ref().expect("field ident");
620 let method_ident = format_ident!("{}_loaded", field_ident);
621 let item_ident = format_ident!("{}Item", to_camel_case(&field_ident.to_string()));
622 let loaded_type: Type = match rel.kind {
623 RelationKind::HasMany | RelationKind::ManyToMany => syn::parse_quote!(Vec<#item_ident>),
624 RelationKind::BelongsTo => syn::parse_quote!(Option<#item_ident>),
625 };
626
627 let mut other_params = Vec::new();
628 let mut type_params = Vec::new();
629 for other in &relation_fields {
630 if other.field.ident == rel.field.ident {
631 type_params.push(quote!(#loaded_type));
632 } else {
633 let ident = &other.param_ident;
634 let state_mod = &other.state_mod_ident;
635 other_params.push(quote!(#ident: #state_mod::State));
636 type_params.push(quote!(#ident));
637 }
638 }
639
640 let mut impl_params = Vec::new();
641 impl_params.push(quote!(#item_ident));
642 impl_params.extend(other_params);
643
644 let impl_generics = if impl_params.is_empty() {
645 quote!()
646 } else {
647 quote!(<#(#impl_params),*>)
648 };
649 let type_args = if type_params.is_empty() {
650 quote!()
651 } else {
652 quote!(<#(#type_params),*>)
653 };
654
655 let (return_ty, body) = match rel.kind {
656 RelationKind::HasMany | RelationKind::ManyToMany => (quote!(&[#item_ident]), quote!(&self.#field_ident)),
657 RelationKind::BelongsTo => (quote!(Option<&#item_ident>), quote!(self.#field_ident.as_ref())),
658 };
659
660 quote!(
661 impl #impl_generics #model_ident #type_args {
662 pub fn #method_ident(&self) -> #return_ty {
663 #body
664 }
665 }
666 )
667 });
668
669 let model_value_arms = output_fields
670 .iter()
671 .filter(|field| !is_relation_field(field, &relation_fields))
672 .map(|field| {
673 let ident = field.ident.as_ref().expect("field ident");
674 let name = ident.to_string();
675 quote!(#name => Some(self.#ident.clone().into()),)
676 });
677
678 let model_value_impl = quote!(
679 impl #impl_generics ::dbkit::ModelValue for #model_ident #struct_type_args {
680 fn column_value(&self, column: ::dbkit::ColumnRef) -> Option<::dbkit::Value> {
681 if column.table.name != Self::TABLE.name {
682 return None;
683 }
684 match column.name {
685 #(#model_value_arms)*
686 _ => None,
687 }
688 }
689 }
690 );
691
692 let from_row_generics = relation_fields.iter().map(|rel| {
693 let ident = &rel.param_ident;
694 let state_mod = &rel.state_mod_ident;
695 quote!(#ident: #state_mod::State + Default)
696 });
697
698 let from_row_impl_generics = if relation_fields.is_empty() {
699 quote!(<'r>)
700 } else {
701 quote!(<'r, #(#from_row_generics),*>)
702 };
703
704 let from_row_fields = output_fields.iter().map(|field| {
705 let ident = field.ident.as_ref().expect("field ident");
706 if is_relation_field(field, &relation_fields) {
707 quote!(#ident: Default::default())
708 } else {
709 let name = ident.to_string();
710 quote!(#ident: ::dbkit::sqlx::Row::try_get(row, #name)?)
711 }
712 });
713
714 let from_row_impl = quote!(
715 impl #from_row_impl_generics ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow>
716 for #model_ident #struct_type_args
717 {
718 fn from_row(row: &'r ::dbkit::sqlx::postgres::PgRow) -> Result<Self, ::dbkit::sqlx::Error> {
719 Ok(Self {
720 #(#from_row_fields,)*
721 })
722 }
723 }
724 );
725
726 let joined_from_row_fields = output_fields.iter().map(|field| {
727 let ident = field.ident.as_ref().expect("field ident");
728 if is_relation_field(field, &relation_fields) {
729 quote!(#ident: Default::default())
730 } else {
731 let name = ident.to_string();
732 quote!(
733 #ident: {
734 let column = format!("{}{}", prefix, #name);
735 ::dbkit::sqlx::Row::try_get(row, column.as_str())?
736 }
737 )
738 }
739 });
740
741 let joined_pk_checks = if primary_keys.is_empty() {
742 if let Some(first_field) = scalar_fields.first() {
743 let name = first_field.ident.to_string();
744 let ty = option_inner_type(&first_field.ty).unwrap_or_else(|| first_field.ty.clone());
745 quote!(
746 let value: Option<#ty> = {
747 let column = format!("{}{}", prefix, #name);
748 ::dbkit::sqlx::Row::try_get(row, column.as_str())?
749 };
750 Ok(value.is_some())
751 )
752 } else {
753 quote!(Ok(false))
754 }
755 } else {
756 let checks = primary_keys.iter().map(|(ident, ty)| {
757 let name = ident.to_string();
758 let ty = option_inner_type(ty).unwrap_or_else(|| ty.clone());
759 quote!(
760 let value: Option<#ty> = {
761 let column = format!("{}{}", prefix, #name);
762 ::dbkit::sqlx::Row::try_get(row, column.as_str())?
763 };
764 if value.is_some() {
765 return Ok(true);
766 }
767 )
768 });
769 quote!(
770 #(#checks)*
771 Ok(false)
772 )
773 };
774
775 let joined_model_impl = quote!(
776 impl #from_row_impl_generics ::dbkit::JoinedModel for #model_ident #struct_type_args {
777 fn joined_columns() -> &'static [::dbkit::ColumnRef] {
778 Self::COLUMNS
779 }
780
781 fn joined_primary_keys() -> &'static [::dbkit::ColumnRef] {
782 Self::PRIMARY_KEYS
783 }
784
785 fn joined_from_row_prefixed(
786 row: &::dbkit::sqlx::postgres::PgRow,
787 prefix: &str,
788 ) -> Result<Self, ::dbkit::sqlx::Error> {
789 Ok(Self {
790 #(#joined_from_row_fields,)*
791 })
792 }
793
794 fn joined_row_has_pk(
795 row: &::dbkit::sqlx::postgres::PgRow,
796 prefix: &str,
797 ) -> Result<bool, ::dbkit::sqlx::Error> {
798 #joined_pk_checks
799 }
800 }
801 );
802
803 let set_relation_impls = relation_fields.iter().map(|rel| {
804 let field_ident = rel.field.ident.as_ref().expect("field ident");
805 let child_type = &rel.child_type;
806 let item_ident = format_ident!("{}Item", to_camel_case(&field_ident.to_string()));
807 let (value_ty, rel_ty) = match rel.kind {
808 RelationKind::HasMany => (quote!(Vec<#item_ident>), quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>)),
809 RelationKind::ManyToMany => {
810 let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
811 (
812 quote!(Vec<#item_ident>),
813 quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>),
814 )
815 }
816 RelationKind::BelongsTo => (
817 quote!(Option<#item_ident>),
818 quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
819 ),
820 };
821
822 let mut other_params = Vec::new();
823 let mut type_params = Vec::new();
824 for other in &relation_fields {
825 if other.field.ident == rel.field.ident {
826 type_params.push(value_ty.clone());
827 } else {
828 let ident = &other.param_ident;
829 let state_mod = &other.state_mod_ident;
830 other_params.push(quote!(#ident: #state_mod::State));
831 type_params.push(quote!(#ident));
832 }
833 }
834
835 let mut impl_params = Vec::new();
836 impl_params.push(quote!(#item_ident));
837 impl_params.extend(other_params);
838
839 let impl_generics = if impl_params.is_empty() {
840 quote!()
841 } else {
842 quote!(<#(#impl_params),*>)
843 };
844 let type_args = if type_params.is_empty() {
845 quote!()
846 } else {
847 quote!(<#(#type_params),*>)
848 };
849
850 quote!(
851 impl #impl_generics ::dbkit::SetRelation<#rel_ty, #value_ty> for #model_ident #type_args {
852 fn set_relation(&mut self, _rel: #rel_ty, value: #value_ty) -> Result<(), ::dbkit::Error> {
853 self.#field_ident = value;
854 Ok(())
855 }
856 }
857 )
858 });
859
860 let get_relation_impls = relation_fields.iter().map(|rel| {
861 let field_ident = rel.field.ident.as_ref().expect("field ident");
862 let child_type = &rel.child_type;
863 let item_ident = format_ident!("{}Item", to_camel_case(&field_ident.to_string()));
864 let (value_ty, rel_ty) = match rel.kind {
865 RelationKind::HasMany => (quote!(Vec<#item_ident>), quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>)),
866 RelationKind::ManyToMany => {
867 let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
868 (
869 quote!(Vec<#item_ident>),
870 quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>),
871 )
872 }
873 RelationKind::BelongsTo => (
874 quote!(Option<#item_ident>),
875 quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
876 ),
877 };
878
879 let mut other_params = Vec::new();
880 let mut type_params = Vec::new();
881 for other in &relation_fields {
882 if other.field.ident == rel.field.ident {
883 type_params.push(value_ty.clone());
884 } else {
885 let ident = &other.param_ident;
886 let state_mod = &other.state_mod_ident;
887 other_params.push(quote!(#ident: #state_mod::State));
888 type_params.push(quote!(#ident));
889 }
890 }
891
892 let mut impl_params = Vec::new();
893 impl_params.push(quote!(#item_ident));
894 impl_params.extend(other_params);
895
896 let impl_generics = if impl_params.is_empty() {
897 quote!()
898 } else {
899 quote!(<#(#impl_params),*>)
900 };
901 let type_args = if type_params.is_empty() {
902 quote!()
903 } else {
904 quote!(<#(#type_params),*>)
905 };
906
907 quote!(
908 impl #impl_generics ::dbkit::GetRelation<#rel_ty, #value_ty> for #model_ident #type_args {
909 fn get_relation(&self, _rel: #rel_ty) -> Option<&#value_ty> {
910 Some(&self.#field_ident)
911 }
912
913 fn get_relation_mut(&mut self, _rel: #rel_ty) -> Option<&mut #value_ty> {
914 Some(&mut self.#field_ident)
915 }
916 }
917 )
918 });
919
920 let load_method = quote!(
921 pub async fn load<Rel>(
922 self,
923 rel: Rel,
924 ex: &(impl ::dbkit::Executor + Send + Sync),
925 ) -> Result<<Self as ::dbkit::LoadRelation<Rel>>::Out, ::dbkit::Error>
926 where
927 Self: ::dbkit::LoadRelation<Rel>,
928 {
929 ::dbkit::LoadRelation::load_relation(self, rel, ex).await
930 }
931 );
932
933 let load_relation_impls = relation_fields.iter().map(|rel| {
934 let field_ident = rel.field.ident.as_ref().expect("field ident");
935 let child_type = &rel.child_type;
936 let rel_type = match rel.kind {
937 RelationKind::HasMany => quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>),
938 RelationKind::BelongsTo => quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
939 RelationKind::ManyToMany => {
940 let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
941 quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>)
942 }
943 };
944 let loaded_type = match rel.kind {
945 RelationKind::HasMany | RelationKind::ManyToMany => quote!(Vec<#child_type>),
946 RelationKind::BelongsTo => quote!(Option<#child_type>),
947 };
948 let loader_fn = match rel.kind {
949 RelationKind::HasMany => quote!(::dbkit::runtime::load_selectin_has_many),
950 RelationKind::ManyToMany => quote!(::dbkit::runtime::load_selectin_many_to_many),
951 RelationKind::BelongsTo => quote!(::dbkit::runtime::load_selectin_belongs_to),
952 };
953
954 let mut other_params = Vec::new();
955 let mut type_params = Vec::new();
956 let mut out_params = Vec::new();
957 for other in &relation_fields {
958 if other.field.ident == rel.field.ident {
959 type_params.push(quote!(::dbkit::NotLoaded));
960 out_params.push(loaded_type.clone());
961 } else {
962 let ident = &other.param_ident;
963 let state_mod = &other.state_mod_ident;
964 other_params.push(quote!(#ident: #state_mod::State + Send + 'static));
965 type_params.push(quote!(#ident));
966 out_params.push(quote!(#ident));
967 }
968 }
969
970 let impl_generics = if other_params.is_empty() {
971 quote!()
972 } else {
973 quote!(<#(#other_params),*>)
974 };
975 let type_args = if type_params.is_empty() {
976 quote!()
977 } else {
978 quote!(<#(#type_params),*>)
979 };
980 let out_type = if out_params.is_empty() {
981 quote!(#model_ident)
982 } else {
983 quote!(#model_ident<#(#out_params),*>)
984 };
985 let out_construct = if out_params.is_empty() {
986 quote!(#model_ident)
987 } else {
988 quote!(#model_ident::<#(#out_params),*>)
989 };
990
991 let destructure_fields = output_fields.iter().map(|field| {
992 let ident = field.ident.as_ref().expect("field ident");
993 if ident == field_ident {
994 quote!(#ident: _)
995 } else {
996 quote!(#ident)
997 }
998 });
999
1000 let build_fields = output_fields.iter().map(|field| {
1001 let ident = field.ident.as_ref().expect("field ident");
1002 if ident == field_ident {
1003 quote!(#ident: Default::default())
1004 } else {
1005 quote!(#ident)
1006 }
1007 });
1008
1009 quote!(
1010 impl #impl_generics ::dbkit::LoadRelation<#rel_type> for #model_ident #type_args {
1011 type Out = #out_type;
1012
1013 fn load_relation<'e, E>(
1014 self,
1015 rel: #rel_type,
1016 ex: &'e E,
1017 ) -> ::dbkit::executor::BoxFuture<'e, Result<Self::Out, ::dbkit::Error>>
1018 where
1019 E: ::dbkit::Executor + Send + Sync + 'e,
1020 {
1021 Box::pin(async move {
1022 let Self { #(#destructure_fields,)* } = self;
1023 let mut out = #out_construct {
1024 #(#build_fields,)*
1025 };
1026 let mut rows = vec![out];
1027 #loader_fn(ex, &mut rows, rel, &::dbkit::load::NoLoad).await?;
1028 Ok(rows.pop().expect("loaded row"))
1029 })
1030 }
1031 }
1032 )
1033 });
1034
1035 let relation_consts = relation_fields.iter().filter_map(|rel| {
1036 let field_ident = rel.field.ident.as_ref().expect("field ident");
1037 let child_type = &rel.child_type;
1038 match rel.kind {
1039 RelationKind::HasMany => Some(quote!(
1040 pub const #field_ident: ::dbkit::rel::HasMany<#struct_ident, #child_type> =
1041 ::dbkit::rel::HasMany::new(
1042 <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::PARENT_TABLE,
1043 <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::CHILD_TABLE,
1044 <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::PARENT_KEY,
1045 <#child_type as ::dbkit::rel::BelongsToSpec<#struct_ident>>::CHILD_KEY,
1046 );
1047 )),
1048 RelationKind::BelongsTo => {
1049 let key = rel.belongs_to_key.as_ref().expect("belongs_to key");
1050 let references = rel.belongs_to_ref.as_ref().expect("belongs_to references");
1051 Some(quote!(
1052 pub const #field_ident: ::dbkit::rel::BelongsTo<#struct_ident, #child_type> =
1053 ::dbkit::rel::BelongsTo::new(
1054 Self::TABLE,
1055 #child_type::TABLE,
1056 Self::#key.as_ref(),
1057 #child_type::#references.as_ref(),
1058 );
1059 ))
1060 }
1061 RelationKind::ManyToMany => {
1062 let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
1063 let left_key = rel.many_to_many_left_key.as_ref().expect("many-to-many left_key");
1064 let right_key = rel.many_to_many_right_key.as_ref().expect("many-to-many right_key");
1065 let parent_pk = primary_keys.first().map(|(ident, _)| ident).expect("many-to-many parent pk");
1066 Some(quote!(
1067 pub const #field_ident: ::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through> =
1068 ::dbkit::rel::ManyToMany::new(
1069 Self::TABLE,
1070 #child_type::TABLE,
1071 #through::TABLE,
1072 Self::#parent_pk.as_ref(),
1073 #child_type::PRIMARY_KEY.as_ref(),
1074 #through::#left_key.as_ref(),
1075 #through::#right_key.as_ref(),
1076 );
1077 ))
1078 }
1079 }
1080 });
1081
1082 let belongs_to_specs = relation_fields.iter().filter_map(|rel| {
1083 if rel.kind != RelationKind::BelongsTo {
1084 return None;
1085 }
1086 let parent_type = &rel.child_type;
1087 let key = rel.belongs_to_key.as_ref().expect("belongs_to key");
1088 let references = rel.belongs_to_ref.as_ref().expect("belongs_to references");
1089 Some(quote!(
1090 impl #impl_generics ::dbkit::rel::BelongsToSpec<#parent_type> for #model_ident #struct_type_args {
1091 const CHILD_TABLE: ::dbkit::Table = Self::TABLE;
1092 const PARENT_TABLE: ::dbkit::Table = #parent_type::TABLE;
1093 const CHILD_KEY: ::dbkit::ColumnRef = Self::#key.as_ref();
1094 const PARENT_KEY: ::dbkit::ColumnRef = #parent_type::#references.as_ref();
1095 }
1096 ))
1097 });
1098
1099 let apply_load_impls = relation_fields.iter().flat_map(|rel| {
1100 let child_type = &rel.child_type;
1101 let rel_type = match rel.kind {
1102 RelationKind::HasMany => quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>),
1103 RelationKind::BelongsTo => quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
1104 RelationKind::ManyToMany => {
1105 let through = rel.many_to_many_through.as_ref().expect("many-to-many through");
1106 quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>)
1107 }
1108 };
1109
1110 let loaded_child = quote!(<Nested as ::dbkit::load::ApplyLoad<#child_type>>::Out2);
1111 let loaded_param = match rel.kind {
1112 RelationKind::HasMany | RelationKind::ManyToMany => quote!(Vec<#loaded_child>),
1113 RelationKind::BelongsTo => quote!(Option<#loaded_child>),
1114 };
1115
1116 let mut out_params = Vec::new();
1117 for other in &relation_fields {
1118 if other.field.ident == rel.field.ident {
1119 out_params.push(loaded_param.clone());
1120 } else {
1121 let ident = &other.param_ident;
1122 out_params.push(quote!(#ident));
1123 }
1124 }
1125
1126 let model_type = if generic_idents.is_empty() {
1127 quote!(#model_ident)
1128 } else {
1129 quote!(#model_ident<#(#generic_idents),*>)
1130 };
1131 let out_type = if out_params.is_empty() {
1132 quote!(#model_ident)
1133 } else {
1134 quote!(#model_ident<#(#out_params),*>)
1135 };
1136
1137 let mut apply_generics = Vec::new();
1138 apply_generics.push(quote!(Nested));
1139 apply_generics.extend(impl_generics_params.iter().cloned());
1140 let apply_generics = if apply_generics.is_empty() {
1141 quote!()
1142 } else {
1143 quote!(<#(#apply_generics),*>)
1144 };
1145
1146 let mut items = Vec::new();
1147 for strategy in ["SelectIn", "Joined"] {
1148 let load_ty = if strategy == "SelectIn" {
1149 quote!(::dbkit::load::SelectIn<#rel_type, Nested>)
1150 } else {
1151 quote!(::dbkit::load::Joined<#rel_type, Nested>)
1152 };
1153 items.push(quote!(
1154 impl #apply_generics ::dbkit::load::ApplyLoad<#model_type> for #load_ty
1155 where
1156 Nested: ::dbkit::load::ApplyLoad<#child_type>,
1157 {
1158 type Out2 = #out_type;
1159 }
1160 ));
1161 }
1162 items.into_iter()
1163 });
1164
1165 let run_load_impls = relation_fields.iter().flat_map(|rel| {
1166 let child_type = &rel.child_type;
1167 let through = rel.many_to_many_through.as_ref();
1168 let rel_type = match rel.kind {
1169 RelationKind::HasMany => quote!(::dbkit::rel::HasMany<#struct_ident, #child_type>),
1170 RelationKind::BelongsTo => quote!(::dbkit::rel::BelongsTo<#struct_ident, #child_type>),
1171 RelationKind::ManyToMany => {
1172 let through = through.expect("many-to-many through");
1173 quote!(::dbkit::rel::ManyToMany<#struct_ident, #child_type, #through>)
1174 }
1175 };
1176
1177 let loaded_child = quote!(<Nested as ::dbkit::load::ApplyLoad<#child_type>>::Out2);
1178 let loaded_param = match rel.kind {
1179 RelationKind::HasMany | RelationKind::ManyToMany => quote!(Vec<#loaded_child>),
1180 RelationKind::BelongsTo => quote!(Option<#loaded_child>),
1181 };
1182
1183 let mut out_params = Vec::new();
1184 for other in &relation_fields {
1185 if other.field.ident == rel.field.ident {
1186 out_params.push(loaded_param.clone());
1187 } else {
1188 let ident = &other.param_ident;
1189 out_params.push(quote!(#ident));
1190 }
1191 }
1192
1193 let out_type = if out_params.is_empty() {
1194 quote!(#model_ident)
1195 } else {
1196 quote!(#model_ident<#(#out_params),*>)
1197 };
1198
1199 let mut apply_generics = Vec::new();
1200 apply_generics.push(quote!(Nested));
1201 for other in &relation_fields {
1202 if other.field.ident == rel.field.ident {
1203 continue;
1204 }
1205 let ident = &other.param_ident;
1206 let state_mod = &other.state_mod_ident;
1207 apply_generics.push(quote!(#ident: #state_mod::State + Send + 'static));
1208 }
1209 let apply_generics = if apply_generics.is_empty() {
1210 quote!()
1211 } else {
1212 quote!(<#(#apply_generics),*>)
1213 };
1214
1215 let (child_bounds, loader_fn) = match rel.kind {
1216 RelationKind::HasMany => (
1217 quote!(#loaded_child: ::dbkit::ModelValue + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,),
1218 quote!(::dbkit::runtime::load_selectin_has_many),
1219 ),
1220 RelationKind::ManyToMany => {
1221 let through = through.expect("many-to-many through");
1222 (
1223 quote!(
1224 #loaded_child: ::dbkit::ModelValue + Clone + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,
1225 #through: ::dbkit::ModelValue + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,
1226 ),
1227 quote!(::dbkit::runtime::load_selectin_many_to_many),
1228 )
1229 }
1230 RelationKind::BelongsTo => (
1231 quote!(#loaded_child: ::dbkit::ModelValue + Clone + for<'r> ::dbkit::sqlx::FromRow<'r, ::dbkit::sqlx::postgres::PgRow> + Send + Unpin,),
1232 quote!(::dbkit::runtime::load_selectin_belongs_to),
1233 ),
1234 };
1235
1236 let joined_loader_fn = match rel.kind {
1237 RelationKind::HasMany => quote!(::dbkit::runtime::load_joined_has_many),
1238 RelationKind::ManyToMany => quote!(::dbkit::runtime::load_joined_many_to_many),
1239 RelationKind::BelongsTo => quote!(::dbkit::runtime::load_joined_belongs_to),
1240 };
1241
1242 let mut items = Vec::new();
1243 for (strategy, loader) in [
1244 ("SelectIn", loader_fn),
1245 ("Joined", joined_loader_fn),
1246 ] {
1247 let load_ty = if strategy == "SelectIn" {
1248 quote!(::dbkit::load::SelectIn<#rel_type, Nested>)
1249 } else {
1250 quote!(::dbkit::load::Joined<#rel_type, Nested>)
1251 };
1252 let out_bound = if strategy == "SelectIn" {
1253 quote!(::dbkit::ModelValue + ::dbkit::SetRelation<#rel_type, #loaded_param>)
1254 } else {
1255 quote!(::dbkit::GetRelation<#rel_type, #loaded_param>)
1256 };
1257
1258 items.push(quote!(
1259 impl #apply_generics ::dbkit::runtime::RunLoad<#out_type> for #load_ty
1260 where
1261 Nested: ::dbkit::load::ApplyLoad<#child_type> + ::dbkit::runtime::RunLoads<#loaded_child> + Sync,
1262 #out_type: #out_bound,
1263 #child_bounds
1264 {
1265 fn run<'e, E>(
1266 &'e self,
1267 ex: &'e E,
1268 rows: &'e mut [#out_type],
1269 ) -> ::dbkit::executor::BoxFuture<'e, Result<(), ::dbkit::Error>>
1270 where
1271 E: ::dbkit::Executor + Send + Sync + 'e,
1272 {
1273 #loader(ex, rows, self.rel.clone(), &self.nested)
1274 }
1275 }
1276 ));
1277 }
1278 items.into_iter()
1279 });
1280
1281 let output = quote! {
1282 #(#struct_attrs)*
1283 #[derive(Debug, Clone)]
1284 #vis struct #model_ident #struct_generics {
1285 #(#output_fields,)*
1286 }
1287
1288 #vis type #struct_ident = #model_ident;
1289
1290 #(#relation_state_modules)*
1291
1292 #vis trait #any_state_ident {}
1293 impl #impl_generics #any_state_ident for #model_ident #struct_type_args {}
1294
1295 impl #impl_generics #model_ident #struct_type_args {
1296 pub const TABLE: ::dbkit::Table = #table_expr;
1297 #(#columns)*
1298 #columns_const
1299 #primary_key_const
1300 #primary_keys_const
1301 #(#relation_consts)*
1302
1303 pub fn query() -> ::dbkit::Select<#struct_ident> {
1304 ::dbkit::Select::new(Self::TABLE)
1305 }
1306
1307 #by_id_fn
1308
1309 pub fn insert(values: #insert_ident) -> ::dbkit::Insert<#struct_ident> {
1310 let mut insert = ::dbkit::Insert::new(Self::TABLE);
1311 #(#insert_values)*
1312 insert
1313 }
1314
1315 pub fn insert_many(values: Vec<#insert_ident>) -> ::dbkit::Insert<#struct_ident> {
1316 let mut insert = ::dbkit::Insert::new(Self::TABLE);
1317 for value in values {
1318 insert = insert.row(|row| {
1319 let mut row = row;
1320 #(
1321 row = row.value(Self::#insert_field_idents, value.#insert_field_idents);
1322 )*
1323 row
1324 });
1325 }
1326 insert
1327 }
1328
1329 pub fn update() -> ::dbkit::Update<#struct_ident> {
1330 ::dbkit::Update::new(Self::TABLE)
1331 }
1332
1333 pub fn delete() -> ::dbkit::Delete {
1334 ::dbkit::Delete::new(Self::TABLE)
1335 }
1336
1337 pub fn new_active() -> #active_ident {
1338 #active_ident::new()
1339 }
1340
1341 #into_active_fn
1342 #load_method
1343 }
1344
1345 #[derive(Debug, Clone)]
1346 #vis struct #insert_ident {
1347 #(#insert_fields,)*
1348 }
1349
1350 #[derive(Debug, Clone, Default)]
1351 #vis struct #active_ident {
1352 #(#active_fields,)*
1353 }
1354
1355 impl #active_ident {
1356 pub fn new() -> Self {
1357 Self::default()
1358 }
1359
1360 #active_insert_fn
1361 #active_update_fn
1362 #active_delete_fn
1363 #active_save_fn
1364 }
1365
1366 #(#relation_methods)*
1367 #model_value_impl
1368 #from_row_impl
1369 #joined_model_impl
1370 #(#set_relation_impls)*
1371 #(#get_relation_impls)*
1372 #(#load_relation_impls)*
1373 #(#belongs_to_specs)*
1374 #(#apply_load_impls)*
1375 #(#run_load_impls)*
1376 #model_delete_impl
1377 };
1378
1379 Ok(output.into())
1380}
1381
1382fn parse_model_args(args: syn::punctuated::Punctuated<Meta, syn::Token![,]>) -> ModelArgs {
1383 let mut out = ModelArgs::default();
1384 for meta in args {
1385 if let Meta::NameValue(nv) = meta {
1386 if nv.path.is_ident("table") {
1387 if let Some(value) = extract_lit_str(&nv.value) {
1388 out.table = Some(value);
1389 }
1390 } else if nv.path.is_ident("schema") {
1391 if let Some(value) = extract_lit_str(&nv.value) {
1392 out.schema = Some(value);
1393 }
1394 }
1395 }
1396 }
1397 out
1398}
1399
1400fn parse_belongs_to_args(attrs: &[Attribute]) -> syn::Result<(Ident, Ident)> {
1401 for attr in attrs {
1402 if !attr.path().is_ident("belongs_to") {
1403 continue;
1404 }
1405 let args = attr.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)?;
1406 let mut key = None;
1407 let mut references = None;
1408 for meta in args {
1409 if let Meta::NameValue(nv) = meta {
1410 if nv.path.is_ident("key") {
1411 key = extract_ident(&nv.value);
1412 } else if nv.path.is_ident("references") {
1413 references = extract_ident(&nv.value);
1414 }
1415 }
1416 }
1417 if let (Some(key), Some(references)) = (key, references) {
1418 return Ok((key, references));
1419 }
1420 }
1421 Err(syn::Error::new(
1422 proc_macro2::Span::call_site(),
1423 "dbkit: #[belongs_to] requires key = <field> and references = <field>",
1424 ))
1425}
1426
1427fn parse_many_to_many_args(attrs: &[Attribute]) -> syn::Result<(Ident, Ident, Ident)> {
1428 for attr in attrs {
1429 if !attr.path().is_ident("many_to_many") {
1430 continue;
1431 }
1432 let args = attr.parse_args_with(syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated)?;
1433 let mut through = None;
1434 let mut left_key = None;
1435 let mut right_key = None;
1436 for meta in args {
1437 if let Meta::NameValue(nv) = meta {
1438 if nv.path.is_ident("through") {
1439 through = extract_ident(&nv.value);
1440 } else if nv.path.is_ident("left_key") {
1441 left_key = extract_ident(&nv.value);
1442 } else if nv.path.is_ident("right_key") {
1443 right_key = extract_ident(&nv.value);
1444 }
1445 }
1446 }
1447 if let (Some(through), Some(left_key), Some(right_key)) = (through, left_key, right_key) {
1448 return Ok((through, left_key, right_key));
1449 }
1450 }
1451 Err(syn::Error::new(
1452 proc_macro2::Span::call_site(),
1453 "dbkit: #[many_to_many] requires through = <Model>, left_key = <field>, right_key = <field>",
1454 ))
1455}
1456
1457fn extract_lit_str(expr: &syn::Expr) -> Option<String> {
1458 if let syn::Expr::Lit(syn::ExprLit {
1459 lit: syn::Lit::Str(lit), ..
1460 }) = expr
1461 {
1462 Some(lit.value())
1463 } else {
1464 None
1465 }
1466}
1467
1468fn extract_ident(expr: &syn::Expr) -> Option<Ident> {
1469 if let syn::Expr::Path(path) = expr {
1470 path.path.get_ident().cloned()
1471 } else {
1472 None
1473 }
1474}
1475
1476fn option_inner_type(ty: &Type) -> Option<Type> {
1477 let path = match ty {
1478 Type::Path(path) => path,
1479 _ => return None,
1480 };
1481 let segment = path.path.segments.last()?;
1482 if segment.ident != "Option" {
1483 return None;
1484 }
1485 let args = match &segment.arguments {
1486 syn::PathArguments::AngleBracketed(args) => args,
1487 _ => return None,
1488 };
1489 let inner = args.args.first()?;
1490 match inner {
1491 syn::GenericArgument::Type(inner_ty) => Some(inner_ty.clone()),
1492 _ => None,
1493 }
1494}
1495
1496fn has_attr(attrs: &[Attribute], name: &str) -> bool {
1497 attrs.iter().any(|attr| attr.path().is_ident(name))
1498}
1499
1500fn filter_struct_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
1501 let mut kept = Vec::new();
1502 for attr in attrs {
1503 if is_model_attr(attr) {
1504 continue;
1505 }
1506 if attr.path().is_ident("derive") {
1507 if let Ok(mut paths) = attr.parse_args_with(syn::punctuated::Punctuated::<syn::Path, syn::Token![,]>::parse_terminated) {
1508 paths = paths
1509 .into_iter()
1510 .filter(|path| !path.segments.last().map(|seg| seg.ident == "Model").unwrap_or(false))
1511 .collect();
1512 if paths.is_empty() {
1513 continue;
1514 }
1515 let new_attr = quote!(#[derive(#paths)]);
1516 let parsed = syn::Attribute::parse_outer.parse2(new_attr).expect("derive attr");
1517 kept.extend(parsed);
1518 continue;
1519 }
1520 }
1521 kept.push(attr.clone());
1522 }
1523 kept
1524}
1525
1526fn filter_field_attrs(attrs: &[Attribute]) -> Vec<Attribute> {
1527 attrs.iter().filter(|attr| !is_field_orm_attr(attr)).cloned().collect()
1528}
1529
1530fn is_field_orm_attr(attr: &Attribute) -> bool {
1531 let name = attr.path().get_ident().map(|ident| ident.to_string());
1532 matches!(
1533 name.as_deref(),
1534 Some("key") | Some("autoincrement") | Some("unique") | Some("index") | Some("has_many") | Some("belongs_to") | Some("many_to_many")
1535 )
1536}
1537
1538fn is_model_attr(attr: &Attribute) -> bool {
1539 attr.path().is_ident("model")
1540}
1541
1542fn relation_type(field: &Field) -> syn::Result<(RelationKind, Type)> {
1543 let kind = if has_attr(&field.attrs, "has_many") {
1544 RelationKind::HasMany
1545 } else if has_attr(&field.attrs, "belongs_to") {
1546 RelationKind::BelongsTo
1547 } else if has_attr(&field.attrs, "many_to_many") {
1548 RelationKind::ManyToMany
1549 } else {
1550 return Err(syn::Error::new_spanned(field, "dbkit: missing relation attribute"));
1551 };
1552
1553 let child_type = match &field.ty {
1554 Type::Path(path) => {
1555 let segment = path
1556 .path
1557 .segments
1558 .last()
1559 .ok_or_else(|| syn::Error::new_spanned(&field.ty, "dbkit: invalid type"))?;
1560 let expected = match kind {
1561 RelationKind::HasMany => "HasMany",
1562 RelationKind::BelongsTo => "BelongsTo",
1563 RelationKind::ManyToMany => "ManyToMany",
1564 };
1565 if segment.ident != expected {
1566 return Err(syn::Error::new_spanned(
1567 &segment.ident,
1568 format!("dbkit: expected {} marker type", expected),
1569 ));
1570 }
1571 match &segment.arguments {
1572 syn::PathArguments::AngleBracketed(args) => {
1573 let ty = args.args.iter().find_map(|arg| match arg {
1574 syn::GenericArgument::Type(ty) => Some(ty.clone()),
1575 _ => None,
1576 });
1577 ty.ok_or_else(|| syn::Error::new_spanned(&segment, "dbkit: missing type"))?
1578 }
1579 _ => return Err(syn::Error::new_spanned(&segment.arguments, "dbkit: expected generic argument")),
1580 }
1581 }
1582 _ => return Err(syn::Error::new_spanned(&field.ty, "dbkit: relation marker must be a type path")),
1583 };
1584
1585 Ok((kind, child_type))
1586}
1587
1588fn is_relation_field(field: &Field, rels: &[RelationInfo]) -> bool {
1589 rels.iter().any(|rel| rel.field.ident == field.ident)
1590}
1591
/// Converts an identifier such as `HTTPWebhook` or `OAuthToken` to snake_case.
///
/// Uppercase characters are lowercased. An underscore is inserted before an
/// uppercase character when:
/// - it follows a lowercase letter or an ASCII digit; or
/// - it ends an uppercase run and starts a lowercase word (`HTTPWebhook` ->
///   `http_webhook`).
///
/// The exception is a two-letter leading run, so `OAuthToken` -> `oauth_token`.
/// No underscore is doubled. All non-uppercase characters are copied unchanged.
fn to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut snake = String::with_capacity(name.len() + (name.len() / 4));

    for (pos, &current) in chars.iter().enumerate() {
        if !current.is_uppercase() {
            snake.push(current);
            continue;
        }

        let before = if pos == 0 { None } else { chars.get(pos - 1).copied() };
        let after_is_lower = chars.get(pos + 1).copied().map_or(false, char::is_lowercase);

        let boundary = match before {
            None => false,
            Some(prev) if prev.is_lowercase() || prev.is_ascii_digit() => true,
            // Inside an uppercase run: split only where a new capitalized word begins,
            // but keep a two-letter leading prefix ("OAuth") together.
            Some(prev) if prev.is_uppercase() => after_is_lower && pos != 1,
            Some(_) => false,
        };

        if boundary && !snake.ends_with('_') {
            snake.push('_');
        }
        snake.extend(current.to_lowercase());
    }

    snake
}
1621
/// Converts a snake_case identifier to UpperCamelCase.
///
/// Underscores are removed. The first character of each underscore-separated word
/// is uppercased and every other character is kept as-is, so `user_id` -> `UserId`
/// and `_a__b_` -> `AB`.
fn to_camel_case(name: &str) -> String {
    name.split('_')
        .flat_map(|word| {
            let mut rest = word.chars();
            let head = rest.next();
            head.into_iter().flat_map(char::to_uppercase).chain(rest)
        })
        .collect()
}
1641
/// Enum-level options collected from `#[dbkit(...)]` on a `#[derive(DbEnum)]` enum.
#[derive(Debug, Default)]
struct DbEnumArgs {
    /// Value of `type_name = "..."`: the Postgres type name used for the enum's
    /// `PgTypeInfo` and `Value::Enum`.
    type_name: Option<String>,
    /// Value of `rename_all = "..."`: the raw strategy string, resolved later by
    /// `parse_db_enum_rename_all`.
    rename_all: Option<String>,
}
1651
/// Naming strategy for variants without an explicit `#[dbkit(rename = "...")]`,
/// selected by the enum-level `rename_all` option.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbEnumRenameAll {
    /// No `rename_all`: the variant identifier is used unchanged.
    AsIs,
    /// `"snake_case"`: words split on case boundaries, joined with `_`, lowercased.
    SnakeCase,
    /// `"lowercase"`: lowercased without inserting separators.
    LowerCase,
    /// `"UPPERCASE"`: uppercased without inserting separators.
    UpperCase,
    /// `"SCREAMING_SNAKE_CASE"`: snake_case, then uppercased.
    ScreamingSnakeCase,
}
1660
1661fn expand_db_enum(input: syn::ItemEnum) -> syn::Result<TokenStream> {
1662 if !input.generics.params.is_empty() {
1663 return Err(syn::Error::new_spanned(
1664 input.generics,
1665 "dbkit: #[derive(DbEnum)] does not support generics",
1666 ));
1667 }
1668
1669 let args = parse_db_enum_args(&input.attrs)?;
1670 let type_name = args
1671 .type_name
1672 .ok_or_else(|| syn::Error::new_spanned(&input.ident, "dbkit: DbEnum requires #[dbkit(type_name = \"...\")]"))?;
1673 let rename_rule = parse_db_enum_rename_all(args.rename_all.as_deref())?;
1674
1675 let enum_ident = input.ident.clone();
1676
1677 let mut as_db_arms = Vec::new();
1678 let mut from_db_arms = Vec::new();
1679 let mut expected_values = Vec::new();
1680 let mut seen_db_names: std::collections::BTreeMap<String, syn::Ident> = std::collections::BTreeMap::new();
1681
1682 for variant in input.variants.iter() {
1683 if !matches!(variant.fields, syn::Fields::Unit) {
1684 return Err(syn::Error::new_spanned(
1685 &variant.fields,
1686 "dbkit: DbEnum only supports unit variants",
1687 ));
1688 }
1689
1690 let variant_ident = &variant.ident;
1691 let explicit = parse_db_enum_variant_rename(&variant.attrs)?;
1692 let db_name = match explicit {
1693 Some(value) => value,
1694 None => apply_db_enum_rename_rule(&variant.ident.to_string(), rename_rule),
1695 };
1696 if let Some(first_variant) = seen_db_names.get(&db_name) {
1697 return Err(syn::Error::new_spanned(
1698 variant_ident,
1699 format!(
1700 "dbkit: duplicate DbEnum wire name `{}` for variants `{}` and `{}`",
1701 db_name, first_variant, variant_ident
1702 ),
1703 ));
1704 }
1705 seen_db_names.insert(db_name.clone(), variant_ident.clone());
1706 let db_name_lit = syn::LitStr::new(&db_name, variant.ident.span());
1707 expected_values.push(db_name);
1708
1709 as_db_arms.push(quote!(Self::#variant_ident => #db_name_lit,));
1710 from_db_arms.push(quote!(#db_name_lit => Ok(Self::#variant_ident),));
1711 }
1712
1713 if as_db_arms.is_empty() {
1714 return Err(syn::Error::new_spanned(enum_ident, "dbkit: DbEnum requires at least one variant"));
1715 }
1716
1717 let type_name_lit = syn::LitStr::new(&type_name, proc_macro2::Span::call_site());
1718 let expected_lit = syn::LitStr::new(&expected_values.join(", "), proc_macro2::Span::call_site());
1719
1720 let tokens = quote! {
1721 impl #enum_ident {
1722 pub const DB_TYPE_NAME: &'static str = #type_name_lit;
1723
1724 pub fn as_db_str(&self) -> &'static str {
1725 match self {
1726 #(#as_db_arms)*
1727 }
1728 }
1729 }
1730
1731 impl ::std::str::FromStr for #enum_ident {
1732 type Err = String;
1733
1734 fn from_str(value: &str) -> Result<Self, Self::Err> {
1735 match value {
1736 #(#from_db_arms)*
1737 _ => Err(format!(
1738 "dbkit: invalid value `{}` for enum {} (expected one of: {})",
1739 value,
1740 stringify!(#enum_ident),
1741 #expected_lit
1742 )),
1743 }
1744 }
1745 }
1746
1747 impl From<#enum_ident> for ::dbkit::Value {
1748 fn from(value: #enum_ident) -> Self {
1749 ::dbkit::Value::Enum {
1750 type_name: #type_name_lit,
1751 value: value.as_db_str().to_string(),
1752 }
1753 }
1754 }
1755
1756 impl ::dbkit::sqlx::Type<::dbkit::sqlx::Postgres> for #enum_ident {
1757 fn type_info() -> ::dbkit::sqlx::postgres::PgTypeInfo {
1758 ::dbkit::sqlx::postgres::PgTypeInfo::with_name(#type_name_lit)
1759 }
1760
1761 fn compatible(ty: &::dbkit::sqlx::postgres::PgTypeInfo) -> bool {
1762 *ty == ::dbkit::sqlx::postgres::PgTypeInfo::with_name(#type_name_lit)
1763 || <&str as ::dbkit::sqlx::Type<::dbkit::sqlx::Postgres>>::compatible(ty)
1764 }
1765 }
1766
1767 impl<'q> ::dbkit::sqlx::Encode<'q, ::dbkit::sqlx::Postgres> for #enum_ident {
1768 fn encode_by_ref(
1769 &self,
1770 buf: &mut ::dbkit::sqlx::postgres::PgArgumentBuffer,
1771 ) -> Result<::dbkit::sqlx::encode::IsNull, ::dbkit::sqlx::error::BoxDynError> {
1772 <&str as ::dbkit::sqlx::Encode<'q, ::dbkit::sqlx::Postgres>>::encode(self.as_db_str(), buf)
1773 }
1774
1775 fn produces(&self) -> Option<::dbkit::sqlx::postgres::PgTypeInfo> {
1776 Some(::dbkit::sqlx::postgres::PgTypeInfo::with_name(#type_name_lit))
1777 }
1778
1779 fn size_hint(&self) -> usize {
1780 self.as_db_str().len()
1781 }
1782 }
1783
1784 impl<'r> ::dbkit::sqlx::Decode<'r, ::dbkit::sqlx::Postgres> for #enum_ident {
1785 fn decode(value: ::dbkit::sqlx::postgres::PgValueRef<'r>) -> Result<Self, ::dbkit::sqlx::error::BoxDynError> {
1786 let value = <&str as ::dbkit::sqlx::Decode<'r, ::dbkit::sqlx::Postgres>>::decode(value)?;
1787 <Self as ::std::str::FromStr>::from_str(value).map_err(|err| err.into())
1788 }
1789 }
1790 };
1791
1792 Ok(TokenStream::from(tokens))
1793}
1794
1795fn parse_db_enum_args(attrs: &[Attribute]) -> syn::Result<DbEnumArgs> {
1796 let mut args = DbEnumArgs::default();
1797
1798 for attr in attrs {
1799 if !attr.path().is_ident("dbkit") {
1800 continue;
1801 }
1802 attr.parse_nested_meta(|meta| {
1803 if meta.path.is_ident("type_name") {
1804 let lit: syn::LitStr = meta.value()?.parse()?;
1805 args.type_name = Some(lit.value());
1806 return Ok(());
1807 }
1808 if meta.path.is_ident("rename_all") {
1809 let lit: syn::LitStr = meta.value()?.parse()?;
1810 args.rename_all = Some(lit.value());
1811 return Ok(());
1812 }
1813 Err(meta.error("dbkit: unsupported DbEnum option; expected `type_name` or `rename_all`"))
1814 })?;
1815 }
1816
1817 Ok(args)
1818}
1819
1820fn parse_db_enum_variant_rename(attrs: &[Attribute]) -> syn::Result<Option<String>> {
1821 let mut rename = None;
1822
1823 for attr in attrs {
1824 if !attr.path().is_ident("dbkit") {
1825 continue;
1826 }
1827 attr.parse_nested_meta(|meta| {
1828 if meta.path.is_ident("rename") {
1829 let lit: syn::LitStr = meta.value()?.parse()?;
1830 rename = Some(lit.value());
1831 return Ok(());
1832 }
1833 Err(meta.error("dbkit: unsupported DbEnum variant option; expected `rename`"))
1834 })?;
1835 }
1836
1837 Ok(rename)
1838}
1839
1840fn parse_db_enum_rename_all(value: Option<&str>) -> syn::Result<DbEnumRenameAll> {
1841 match value {
1842 None => Ok(DbEnumRenameAll::AsIs),
1843 Some("snake_case") => Ok(DbEnumRenameAll::SnakeCase),
1844 Some("lowercase") => Ok(DbEnumRenameAll::LowerCase),
1845 Some("UPPERCASE") => Ok(DbEnumRenameAll::UpperCase),
1846 Some("SCREAMING_SNAKE_CASE") => Ok(DbEnumRenameAll::ScreamingSnakeCase),
1847 Some(other) => Err(syn::Error::new(
1848 proc_macro2::Span::call_site(),
1849 format!(
1850 "dbkit: unsupported rename_all strategy `{}` for DbEnum; supported values: snake_case, lowercase, UPPERCASE, SCREAMING_SNAKE_CASE",
1851 other
1852 ),
1853 )),
1854 }
1855}
1856
1857fn apply_db_enum_rename_rule(value: &str, rule: DbEnumRenameAll) -> String {
1858 match rule {
1859 DbEnumRenameAll::AsIs => value.to_string(),
1860 DbEnumRenameAll::SnakeCase => to_snake_case(value),
1861 DbEnumRenameAll::LowerCase => value.to_lowercase(),
1862 DbEnumRenameAll::UpperCase => value.to_uppercase(),
1863 DbEnumRenameAll::ScreamingSnakeCase => to_snake_case(value).to_uppercase(),
1864 }
1865}
1866
#[cfg(test)]
mod tests {
    use super::{apply_db_enum_rename_rule, parse_db_enum_rename_all, to_camel_case, to_snake_case, DbEnumRenameAll};

    #[test]
    fn snake_case_respects_acronym_word_boundaries() {
        assert_eq!(to_snake_case("HTTPWebhook"), "http_webhook");
        assert_eq!(to_snake_case("OAuthToken"), "oauth_token");
        assert_eq!(to_snake_case("XMLHttpRequest"), "xml_http_request");
        assert_eq!(to_snake_case("WebhookHTTP"), "webhook_http");
    }

    #[test]
    fn snake_case_handles_digits_underscores_and_empty_input() {
        assert_eq!(to_snake_case(""), "");
        assert_eq!(to_snake_case("already_snake"), "already_snake");
        assert_eq!(to_snake_case("Foo_Bar"), "foo_bar");
        assert_eq!(to_snake_case("Version2Beta"), "version2_beta");
    }

    #[test]
    fn camel_case_capitalizes_each_underscore_separated_word() {
        assert_eq!(to_camel_case("user_id"), "UserId");
        assert_eq!(to_camel_case("posts"), "Posts");
        assert_eq!(to_camel_case("_leading__double_"), "LeadingDouble");
        assert_eq!(to_camel_case(""), "");
    }

    #[test]
    fn screaming_snake_case_respects_acronym_word_boundaries() {
        assert_eq!(
            apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::ScreamingSnakeCase),
            "HTTP_WEBHOOK"
        );
        assert_eq!(
            apply_db_enum_rename_rule("XMLHttpRequest", DbEnumRenameAll::ScreamingSnakeCase),
            "XML_HTTP_REQUEST"
        );
    }

    #[test]
    fn case_folding_rules_do_not_split_words() {
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::AsIs), "HTTPWebhook");
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::LowerCase), "httpwebhook");
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::UpperCase), "HTTPWEBHOOK");
        assert_eq!(apply_db_enum_rename_rule("HTTPWebhook", DbEnumRenameAll::SnakeCase), "http_webhook");
    }

    #[test]
    fn rename_all_parsing_accepts_only_supported_strategies() {
        assert!(matches!(parse_db_enum_rename_all(None), Ok(DbEnumRenameAll::AsIs)));
        assert!(matches!(parse_db_enum_rename_all(Some("snake_case")), Ok(DbEnumRenameAll::SnakeCase)));
        assert!(matches!(parse_db_enum_rename_all(Some("lowercase")), Ok(DbEnumRenameAll::LowerCase)));
        assert!(matches!(parse_db_enum_rename_all(Some("UPPERCASE")), Ok(DbEnumRenameAll::UpperCase)));
        assert!(matches!(
            parse_db_enum_rename_all(Some("SCREAMING_SNAKE_CASE")),
            Ok(DbEnumRenameAll::ScreamingSnakeCase)
        ));
        assert!(parse_db_enum_rename_all(Some("kebab-case")).is_err());
    }
}