byteorm_lib/rustgen/
model.rs1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{generate_query_builder_struct, generate_update_builder, generate_upsert_builder, rust_type_from_schema, to_snake_case};
5
6pub fn generate_model_with_query_builder(model: &Model) -> TokenStream {
7 let model_struct = generate_model_struct(model);
8 let query_builder_struct = generate_query_builder_struct(model);
9 let update_builder = generate_update_builder(model);
10 let upsert_builder = generate_upsert_builder(model);
11 let model_impl = generate_model_impl(model);
12
13 quote! {
14 #model_struct
15 #query_builder_struct
16 #update_builder
17 #upsert_builder
18 #model_impl
19 }
20}
21
22pub fn generate_model_struct(model: &Model) -> TokenStream {
23 let name = format_ident!("{}", model.name);
24 let fields = model.fields.iter().map(|field| {
25 let field_name = format_ident!("{}", field.name);
26 let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
27 let field_type = rust_type_from_schema(&field.type_name, is_nullable);
28
29 quote! {
30 pub #field_name: #field_type
31 }
32 });
33
34 quote! {
35 #[derive(Debug, Clone, Serialize, Deserialize)]
36 pub struct #name {
37 #(#fields),*
38 }
39 }
40}
41
42fn generate_model_impl(model: &Model) -> TokenStream {
43 let model_name = format_ident!("{}", model.name);
44 let builder_name = format_ident!("{}Query", model_name);
45
46 let pk_fields: Vec<_> = model.fields.iter()
47 .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
48 .collect();
49
50 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
51 let field_name = format_ident!("{}", field.name);
52 quote! { #field_name: row.get(#idx) }
53 });
54
55 let find_by_id_impl = if !pk_fields.is_empty() {
56 if pk_fields.len() == 1 {
57 let pk = &pk_fields[0];
58 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
59 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
60 let pk_name = to_snake_case(&pk.name);
61
62 quote! {
63 pub async fn find_by_id(client: &PgClient, id: #pk_type)
64 -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
65 {
66 let sql = format!("SELECT * FROM {} WHERE {} = $1",
67 stringify!(#model_name).to_lowercase(), #pk_name);
68 let row_opt = client.query_opt(&sql, &[&id]).await?;
69 Ok(row_opt.map(|row| #model_name {
70 #(#field_gets),*
71 }))
72 }
73 }
74 } else {
75 let pk_params = pk_fields.iter().map(|pk| {
76 let param_name = format_ident!("{}", to_snake_case(&pk.name));
77 let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
78 let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
79 quote! { #param_name: #pk_type }
80 });
81
82 let pk_conditions = pk_fields.iter().enumerate().map(|(i, pk)| {
83 let pk_col = to_snake_case(&pk.name);
84 let param_num = i + 1;
85 format!("{} = ${}", pk_col, param_num)
86 });
87 let where_clause = pk_conditions.collect::<Vec<_>>().join(" AND ");
88
89 let pk_args = pk_fields.iter().map(|pk| {
90 let param_name = format_ident!("{}", to_snake_case(&pk.name));
91 quote! { &#param_name }
92 });
93
94 quote! {
95 pub async fn find_by_composite_pk(client: &PgClient, #(#pk_params),*)
96 -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
97 {
98 let sql = format!("SELECT * FROM {} WHERE {}",
99 stringify!(#model_name).to_lowercase(), #where_clause);
100 let row_opt = client.query_opt(&sql, &[#(#pk_args),*]).await?;
101 Ok(row_opt.map(|row| #model_name {
102 #(#field_gets),*
103 }))
104 }
105 }
106 }
107 } else {
108 quote! {}
109 };
110
111 quote! {
112 impl #model_name {
113 pub fn query() -> #builder_name {
114 #builder_name::new()
115 }
116 #find_by_id_impl
117 }
118 }
119}