// byteorm_lib/rustgen.rs

1use quote::{quote, format_ident};
2use proc_macro2::TokenStream;
3use crate::{Schema, Model, Field, Modifier};
4
5pub fn generate_rust_code(schema: &Schema) -> String {
6    let structs_and_impls = schema.models.iter().map(|model| {
7        generate_model_with_query_builder(model)
8    });
9
10    let code = quote! {
11        use serde::{Deserialize, Serialize};
12        use chrono::{DateTime, Utc};
13        use tokio_postgres::Client;
14
15        fn calculate_json_diff(before: &serde_json::Value, after: &serde_json::Value) -> serde_json::Value {
16            let mut diff = serde_json::Map::new();
17            if let (Some(before_obj), Some(after_obj)) = (before.as_object(), after.as_object()) {
18                for (key, after_val) in after_obj {
19                    if let Some(before_val) = before_obj.get(key) {
20                        if before_val != after_val {
21                            diff.insert(
22                                key.clone(),
23                                serde_json::json!({ "from": before_val, "to": after_val })
24                            );
25                        }
26                    } else {
27                        diff.insert(key.clone(), serde_json::json!({ "added": after_val }));
28                    }
29                }
30                for (key, before_val) in before_obj {
31                    if !after_obj.contains_key(key) {
32                        diff.insert(key.clone(), serde_json::json!({ "removed": before_val }));
33                    }
34                }
35            }
36            serde_json::Value::Object(diff)
37        }
38
39        #(#structs_and_impls)*
40    };
41
42    let file: syn::File = syn::parse2(code).unwrap();
43    prettyplease::unparse(&file)
44}
45
46fn generate_model_with_query_builder(model: &Model) -> TokenStream {
47    let model_struct = generate_model_struct(model);
48    let query_builder_struct = generate_query_builder_struct(model);
49    let query_builder_impl = generate_query_builder_impl(model);
50    let model_impl = generate_model_impl(model);
51
52    quote! {
53        #model_struct
54        #query_builder_struct
55        #model_impl
56        #query_builder_impl
57    }
58}
59
60fn generate_model_struct(model: &Model) -> TokenStream {
61    let name = format_ident!("{}", model.name);
62    let fields = model.fields.iter().map(|field| {
63        let field_name = format_ident!("{}", field.name);
64        let field_type = rust_type_from_schema(&field.type_name);
65
66        quote! {
67            pub #field_name: #field_type
68        }
69    });
70
71    quote! {
72        #[derive(Debug, Clone, Serialize, Deserialize)]
73        pub struct #name {
74            #(#fields),*
75        }
76    }
77}
78
79fn generate_model_impl(model: &Model) -> TokenStream {
80    let model_name = format_ident!("{}", model.name);
81    let builder_name = format_ident!("{}Query", model.name);
82
83    let pk_field = model.fields.iter()
84        .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));
85
86    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
87        let field_name = format_ident!("{}", field.name);
88        quote! { #field_name: row.get(#idx) }
89    });
90
91    let find_by_id_impl = if let Some(pk) = pk_field {
92        let pk_type = rust_type_from_schema(&pk.type_name);
93        let pk_name = to_snake_case(&pk.name);
94
95        quote! {
96            pub async fn find_by_id(client: &Client, id: #pk_type)
97                -> Result<Option<#model_name>, Box<dyn std::error::Error>>
98            {
99                let sql = format!("SELECT * FROM {} WHERE {} = $1", stringify!(#model_name).to_lowercase(), #pk_name);
100                let row_opt = client.query_opt(&sql, &[&id]).await?;
101                Ok(row_opt.map(|row| #model_name {
102                    #(#field_gets),*
103                }))
104            }
105        }
106    } else {
107        quote! {}
108    };
109
110    quote! {
111        impl #model_name {
112            pub fn query() -> #builder_name {
113                #builder_name::new()
114            }
115            #find_by_id_impl
116        }
117    }
118}
119
120fn generate_query_builder_struct(model: &Model) -> TokenStream {
121    let builder_name = format_ident!("{}Query", model.name);
122
123    quote! {
124        pub struct #builder_name {
125            table: String,
126            where_fragments: Vec<(&'static str, usize)>,
127            args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
128            limit: Option<usize>,
129            offset: Option<usize>,
130            order_by: Vec<(String, String)>,
131        }
132        impl Clone for #builder_name {
133            fn clone(&self) -> Self {
134                Self {
135                    table: self.table.clone(),
136                    where_fragments: self.where_fragments.clone(),
137                    args: Vec::new(),
138                    limit: self.limit,
139                    offset: self.offset,
140                    order_by: self.order_by.clone(),
141                }
142            }
143        }
144    }
145}
146
147fn generate_query_builder_impl(model: &Model) -> TokenStream {
148    let builder_name = format_ident!("{}Query", model.name);
149    let model_name = format_ident!("{}", model.name);
150    let table_name = model.name.to_lowercase();
151
152    let field_methods = model.fields.iter().enumerate().map(|(i, field)| {
153        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
154        let field_type = rust_type_from_schema(&field.type_name);
155        let field_col = to_snake_case(&field.name);
156
157        quote! {
158            pub fn #method_name(mut self, value: #field_type) -> Self {
159                self.args.push(Box::new(value));
160                self.where_fragments.push((#field_col, self.args.len()));
161                self
162            }
163        }
164    });
165
166    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
167        let field_name = format_ident!("{}", field.name);
168        quote! { #field_name: row.get(#idx) }
169    });
170
171    quote! {
172        impl #builder_name {
173            pub fn new() -> Self {
174                Self {
175                    table: #table_name.to_string(),
176                    where_fragments: vec![],
177                    args: vec![],
178                    limit: None,
179                    offset: None,
180                    order_by: vec![],
181                }
182            }
183            #(#field_methods)*
184            pub fn limit(mut self, limit: usize) -> Self {
185                self.limit = Some(limit);
186                self
187            }
188            pub fn offset(mut self, offset: usize) -> Self {
189                self.offset = Some(offset);
190                self
191            }
192            pub fn order_by(mut self, column: &str, direction: &str) -> Self {
193                self.order_by.push((column.to_string(), direction.to_string()));
194                self
195            }
196
197            pub async fn select(&self, client: &Client)
198                -> Result<Vec<#model_name>, Box<dyn std::error::Error>>
199            {
200                let (sql, params) = self.build_select();
201                let rows = client.query(&sql, &params[..]).await?;
202                let mut results = Vec::new();
203                for row in rows {
204                    results.push(#model_name { #(#field_gets),* });
205                }
206                Ok(results)
207            }
208
209            pub async fn first(&self, client: &Client)
210                -> Result<Option<#model_name>, Box<dyn std::error::Error>>
211            {
212                let mut query = #builder_name::new();
213                query.table = self.table.clone();
214                query.where_fragments = self.where_fragments.clone();
215                query.args = Vec::new();
216                query.limit = Some(1);
217                query.offset = self.offset;
218                query.order_by = self.order_by.clone();
219
220                let results = query.select(client).await?;
221                Ok(results.into_iter().next())
222            }
223
224            pub async fn count(&self, client: &Client)
225                -> Result<i64, Box<dyn std::error::Error>>
226            {
227                let (sql, params) = self.build_count();
228                let row = client.query_one(&sql, &params[..]).await?;
229                Ok(row.get(0))
230            }
231
232            fn build_select(&self) -> (String, Vec<&(dyn tokio_postgres::types::ToSql + Sync)>) {
233                let mut sql = format!("SELECT * FROM {}", self.table);
234                let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
235
236                if !self.where_fragments.is_empty() {
237                    let conds: Vec<String> = self.where_fragments.iter()
238                        .enumerate()
239                        .map(|(i, &(col, idx))| format!("{} = ${}", col, i + 1))
240                        .collect();
241                    sql.push_str(" WHERE ");
242                    sql.push_str(&conds.join(" AND "));
243                    for arg in &self.args {
244                        params.push(arg.as_ref());
245                    }
246                }
247                if !self.order_by.is_empty() {
248                    sql.push_str(" ORDER BY ");
249                    let order_clauses: Vec<String> = self.order_by.iter()
250                        .map(|(col, dir)| format!("{} {}", col, dir))
251                        .collect();
252                    sql.push_str(&order_clauses.join(", "));
253                }
254                if let Some(limit) = self.limit {
255                    sql.push_str(&format!(" LIMIT {}", limit));
256                }
257                if let Some(offset) = self.offset {
258                    sql.push_str(&format!(" OFFSET {}", offset));
259                }
260                (sql, params)
261            }
262            fn build_count(&self) -> (String, Vec<&(dyn tokio_postgres::types::ToSql + Sync)>) {
263                let mut sql = format!("SELECT COUNT(*) FROM {}", self.table);
264                let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
265                if !self.where_fragments.is_empty() {
266                    let conds: Vec<String> = self.where_fragments.iter()
267                        .enumerate()
268                        .map(|(i, &(col, idx))| format!("{} = ${}", col, i + 1))
269                        .collect();
270                    sql.push_str(" WHERE ");
271                    sql.push_str(&conds.join(" AND "));
272                    for arg in &self.args {
273                        params.push(arg.as_ref());
274                    }
275                }
276                (sql, params)
277            }
278        }
279    }
280}
281
282fn rust_type_from_schema(type_name: &str) -> TokenStream {
283    match type_name {
284        "BigInt" => quote! { i64 },
285        "Int" => quote! { i32 },
286        "String" => quote! { String },
287        "JsonB" => quote! { serde_json::Value },
288        "TimestamptZ" => quote! { DateTime<Utc> },
289        "Boolean" => quote! { bool },
290        "Float" => quote! { f64 },
291        _ => quote! { String },
292    }
293}
294
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// Every uppercase character except the first is preceded by `_`, and every
/// character is lower-cased with its full Unicode mapping. Some characters,
/// such as `'İ'`, lower-case to more than one `char`.
///
/// Consecutive capitals are split individually (`"userID"` becomes
/// `"user_i_d"`). That convention is kept as-is because generated column and
/// method names depend on it.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if i > 0 && ch.is_uppercase() {
            result.push('_');
        }
        // `extend` keeps the whole lowercase mapping, matching `str::to_lowercase`;
        // taking only the first char would drop e.g. the combining dot of 'İ'.
        result.extend(ch.to_lowercase());
    }
    result
}