// byteorm_lib/rustgen/model.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{generate_query_builder_struct, generate_update_builder, generate_upsert_builder, rust_type_from_schema, to_snake_case};
5
6pub fn generate_model_with_query_builder(model: &Model) -> TokenStream {
7    let model_struct = generate_model_struct(model);
8    let query_builder_struct = generate_query_builder_struct(model);
9    let update_builder = generate_update_builder(model);
10    let upsert_builder = generate_upsert_builder(model);
11    let model_impl = generate_model_impl(model);
12
13    quote! {
14        #model_struct
15        #query_builder_struct
16        #update_builder
17        #upsert_builder
18        #model_impl
19    }
20}
21
22pub fn generate_model_struct(model: &Model) -> TokenStream {
23    let name = format_ident!("{}", model.name);
24    let fields = model.fields.iter().map(|field| {
25        let field_name = format_ident!("{}", field.name);
26        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
27        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
28
29        quote! {
30            pub #field_name: #field_type
31        }
32    });
33
34    quote! {
35        #[derive(Debug, Clone, Serialize, Deserialize)]
36        pub struct #name {
37            #(#fields),*
38        }
39    }
40}
41
42fn generate_model_impl(model: &Model) -> TokenStream {
43    let model_name = format_ident!("{}", model.name);
44    let builder_name = format_ident!("{}Query", model_name);
45
46    let pk_fields: Vec<_> = model.fields.iter()
47        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
48        .collect();
49
50    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
51        let field_name = format_ident!("{}", field.name);
52        quote! { #field_name: row.get(#idx) }
53    });
54
55    let find_by_id_impl = if !pk_fields.is_empty() {
56        if pk_fields.len() == 1 {
57            let pk = &pk_fields[0];
58            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
59            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
60            let pk_name = to_snake_case(&pk.name);
61
62            quote! {
63                pub async fn find_by_id(client: &PgClient, id: #pk_type)
64                    -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
65                {
66                    let sql = format!("SELECT * FROM {} WHERE {} = $1",
67                        stringify!(#model_name).to_lowercase(), #pk_name);
68                    let row_opt = client.query_opt(&sql, &[&id]).await?;
69                    Ok(row_opt.map(|row| #model_name {
70                        #(#field_gets),*
71                    }))
72                }
73            }
74        } else {
75            let pk_params = pk_fields.iter().map(|pk| {
76                let param_name = format_ident!("{}", to_snake_case(&pk.name));
77                let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
78                let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
79                quote! { #param_name: #pk_type }
80            });
81
82            let pk_conditions = pk_fields.iter().enumerate().map(|(i, pk)| {
83                let pk_col = to_snake_case(&pk.name);
84                let param_num = i + 1;
85                format!("{} = ${}", pk_col, param_num)
86            });
87            let where_clause = pk_conditions.collect::<Vec<_>>().join(" AND ");
88
89            let pk_args = pk_fields.iter().map(|pk| {
90                let param_name = format_ident!("{}", to_snake_case(&pk.name));
91                quote! { &#param_name }
92            });
93
94            quote! {
95                pub async fn find_by_composite_pk(client: &PgClient, #(#pk_params),*)
96                    -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
97                {
98                    let sql = format!("SELECT * FROM {} WHERE {}",
99                        stringify!(#model_name).to_lowercase(), #where_clause);
100                    let row_opt = client.query_opt(&sql, &[#(#pk_args),*]).await?;
101                    Ok(row_opt.map(|row| #model_name {
102                        #(#field_gets),*
103                    }))
104                }
105            }
106        }
107    } else {
108        quote! {}
109    };
110
111    quote! {
112        impl #model_name {
113            pub fn query() -> #builder_name {
114                #builder_name::new()
115            }
116            #find_by_id_impl
117        }
118    }
119}