byteorm_lib/
rustgen.rs

1use quote::{quote, format_ident};
2use proc_macro2::TokenStream;
3use crate::{Schema, Model, Field, Modifier};
4
5pub fn generate_rust_code(schema: &Schema) -> String {
6    let structs_and_impls = schema.models.iter().map(|model| {
7        generate_model_with_query_builder(model)
8    });
9
10    let code = quote! {
11        use serde::{Deserialize, Serialize};
12        use chrono::{DateTime, Utc};
13        use tokio_postgres::Client;
14
15        fn calculate_json_diff(before: &serde_json::Value, after: &serde_json::Value) -> serde_json::Value {
16            let mut diff = serde_json::Map::new();
17
18            if let (Some(before_obj), Some(after_obj)) = (before.as_object(), after.as_object()) {
19                for (key, after_val) in after_obj {
20                    if let Some(before_val) = before_obj.get(key) {
21                        if before_val != after_val {
22                            diff.insert(
23                                key.clone(),
24                                serde_json::json!({
25                                    "from": before_val,
26                                    "to": after_val
27                                })
28                            );
29                        }
30                    } else {
31                        diff.insert(key.clone(), serde_json::json!({ "added": after_val }));
32                    }
33                }
34
35                for (key, before_val) in before_obj {
36                    if !after_obj.contains_key(key) {
37                        diff.insert(key.clone(), serde_json::json!({ "removed": before_val }));
38                    }
39                }
40            }
41
42            serde_json::Value::Object(diff)
43        }
44
45        #(#structs_and_impls)*
46    };
47
48    let file: syn::File = syn::parse2(code).unwrap();
49    prettyplease::unparse(&file)
50}
51
52fn generate_model_with_query_builder(model: &Model) -> TokenStream {
53    let model_struct = generate_model_struct(model);
54    let query_builder_struct = generate_query_builder_struct(model);
55    let query_builder_impl = generate_query_builder_impl(model);
56    let model_impl = generate_model_impl(model);
57
58    quote! {
59        #model_struct
60
61        #query_builder_struct
62
63        #model_impl
64
65        #query_builder_impl
66    }
67}
68
69fn generate_model_struct(model: &Model) -> TokenStream {
70    let name = format_ident!("{}", model.name);
71    let fields = model.fields.iter().map(|field| {
72        let field_name = format_ident!("{}", field.name);
73        let field_type = rust_type_from_schema(&field.type_name);
74
75        quote! {
76            pub #field_name: #field_type
77        }
78    });
79
80    quote! {
81        #[derive(Debug, Clone, Serialize, Deserialize)]
82        pub struct #name {
83            #(#fields),*
84        }
85    }
86}
87
88fn generate_model_impl(model: &Model) -> TokenStream {
89    let model_name = format_ident!("{}", model.name);
90    let builder_name = format_ident!("{}Query", model.name);
91
92    let pk_field = model.fields.iter()
93        .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));
94
95    let find_by_id_impl = if let Some(pk) = pk_field {
96        let pk_name = format_ident!("{}", to_snake_case(&pk.name));
97        let pk_type = rust_type_from_schema(&pk.type_name);
98
99        quote! {
100            pub fn find_by_id(id: #pk_type) -> String {
101                format!("SELECT * FROM {} WHERE {} = {}",
102                    stringify!(#pk_name).replace("_", ""),
103                    stringify!(#pk_name),
104                    id
105                )
106            }
107        }
108    } else {
109        quote! {}
110    };
111
112    quote! {
113        impl #model_name {
114            pub fn query() -> #builder_name {
115                #builder_name::new()
116            }
117
118            #find_by_id_impl
119        }
120    }
121}
122
123fn generate_query_builder_struct(model: &Model) -> TokenStream {
124    let builder_name = format_ident!("{}Query", model.name);
125
126    quote! {
127        pub struct #builder_name {
128            table: String,
129            conditions: Vec<String>,
130            limit: Option<usize>,
131            offset: Option<usize>,
132            order_by: Vec<(String, String)>,
133        }
134
135        impl Clone for #builder_name {
136            fn clone(&self) -> Self {
137                Self {
138                    table: self.table.clone(),
139                    conditions: self.conditions.clone(),
140                    limit: self.limit,
141                    offset: self.offset,
142                    order_by: self.order_by.clone(),
143                }
144            }
145        }
146    }
147}
148
149fn generate_query_builder_impl(model: &Model) -> TokenStream {
150    let builder_name = format_ident!("{}Query", model.name);
151    let model_name = format_ident!("{}", model.name);
152    let table_name = model.name.to_lowercase();
153
154    let where_methods = model.fields.iter().map(|field| {
155        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
156        let field_name = to_snake_case(&field.name);
157        let field_type = rust_type_from_schema(&field.type_name);
158
159        quote! {
160            pub fn #method_name(mut self, value: #field_type) -> Self {
161                self.conditions.push(format!("{} = {:?}", #field_name, value));
162                self
163            }
164        }
165    });
166
167    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
168        let field_name = format_ident!("{}", field.name);
169
170        quote! {
171            #field_name: row.get(#idx)
172        }
173    });
174
175    let has_audit = model.fields.iter().any(|f| f.get_audit_model().is_some());
176
177    let update_method = if has_audit {
178        generate_update_with_audit(model)
179    } else {
180        quote! {}
181    };
182
183    quote! {
184        impl #builder_name {
185            pub fn new() -> Self {
186                Self {
187                    table: #table_name.to_string(),
188                    conditions: vec![],
189                    limit: None,
190                    offset: None,
191                    order_by: vec![],
192                }
193            }
194
195            #(#where_methods)*
196
197            pub fn limit(mut self, limit: usize) -> Self {
198                self.limit = Some(limit);
199                self
200            }
201
202            pub fn offset(mut self, offset: usize) -> Self {
203                self.offset = Some(offset);
204                self
205            }
206
207            pub fn order_by(mut self, column: &str, direction: &str) -> Self {
208                self.order_by.push((column.to_string(), direction.to_string()));
209                self
210            }
211
212            pub async fn select(&self, client: &Client)
213                -> Result<Vec<#model_name>, Box<dyn std::error::Error>>
214            {
215                let sql = self.build_select();
216
217                let rows = client.query(&sql, &[]).await?;
218                let mut results = Vec::new();
219
220                for row in rows {
221                    results.push(#model_name {
222                        #(#field_gets),*
223                    });
224                }
225
226                Ok(results)
227            }
228
229            pub async fn first(&self, client: &Client)
230                -> Result<Option<#model_name>, Box<dyn std::error::Error>>
231            {
232                let mut query = self.clone();
233                query.limit = Some(1);
234                let results = query.select(client).await?;
235                Ok(results.into_iter().next())
236            }
237
238            pub async fn count(&self, client: &Client)
239                -> Result<i64, Box<dyn std::error::Error>>
240            {
241                let sql = format!("SELECT COUNT(*) FROM {}{}",
242                    self.table,
243                    if self.conditions.is_empty() {
244                        String::new()
245                    } else {
246                        format!(" WHERE {}", self.conditions.join(" AND "))
247                    }
248                );
249
250                let row = client.query_one(&sql, &[]).await?;
251                Ok(row.get(0))
252            }
253
254            #update_method
255
256            fn build_select(&self) -> String {
257                let mut sql = format!("SELECT * FROM {}", self.table);
258
259                if !self.conditions.is_empty() {
260                    sql.push_str(" WHERE ");
261                    sql.push_str(&self.conditions.join(" AND "));
262                }
263
264                if !self.order_by.is_empty() {
265                    sql.push_str(" ORDER BY ");
266                    let order_clauses: Vec<String> = self.order_by.iter()
267                        .map(|(col, dir)| format!("{} {}", col, dir))
268                        .collect();
269                    sql.push_str(&order_clauses.join(", "));
270                }
271
272                if let Some(limit) = self.limit {
273                    sql.push_str(&format!(" LIMIT {}", limit));
274                }
275
276                if let Some(offset) = self.offset {
277                    sql.push_str(&format!(" OFFSET {}", offset));
278                }
279
280                sql
281            }
282        }
283    }
284}
285
286fn generate_update_with_audit(model: &Model) -> TokenStream {
287    let table_name = model.name.to_lowercase();
288
289    let audit_field = model.fields.iter()
290        .find(|f| f.get_audit_model().is_some())
291        .expect("Called generate_update_with_audit without audit field");
292
293    let audit_model_name = audit_field.get_audit_model().unwrap();
294    let audit_table = audit_model_name.to_lowercase();
295    let audited_field_name = &audit_field.name;
296
297    let pk_field = model.fields.iter()
298        .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
299        .expect("Model must have primary key for audit");
300
301    let pk_column = to_snake_case(&pk_field.name);
302
303    quote! {
304        pub async fn update(
305            &self,
306            client: &Client,
307            new_value: serde_json::Value,
308            who: i64,
309        ) -> Result<(), Box<dyn std::error::Error>> {
310            let transaction = client.transaction().await?;
311
312            if self.conditions.is_empty() {
313                return Err("No WHERE condition specified for update".into());
314            }
315
316            let pk_value: i64 = self.conditions[0]
317                .split('=')
318                .nth(1)
319                .and_then(|s| s.trim().parse().ok())
320                .ok_or("Failed to parse primary key value")?;
321
322            let before_sql = format!(
323                "SELECT {} FROM {} WHERE {} = $1",
324                #audited_field_name,
325                #table_name,
326                #pk_column
327            );
328
329            let before_row = transaction.query_one(&before_sql, &[&pk_value]).await?;
330            let before: serde_json::Value = before_row.get(0);
331
332            let update_sql = format!(
333                "UPDATE {} SET {} = $1, updated_at = now() WHERE {} = $2",
334                #table_name,
335                #audited_field_name,
336                #pk_column
337            );
338
339            transaction.execute(&update_sql, &[&new_value, &pk_value]).await?;
340
341            let diff = calculate_json_diff(&before, &new_value);
342
343            let audit_sql = format!(
344                "INSERT INTO {} ({}, who, changed_at, before, after, diff) VALUES ($1, $2, now(), $3, $4, $5)",
345                #audit_table,
346                #pk_column
347            );
348
349            transaction.execute(
350                &audit_sql,
351                &[&pk_value, &who, &before, &new_value, &diff]
352            ).await?;
353
354            transaction.commit().await?;
355
356            Ok(())
357        }
358    }
359}
360
361
362fn rust_type_from_schema(type_name: &str) -> TokenStream {
363    match type_name {
364        "BigInt" => quote! { i64 },
365        "Int" => quote! { i32 },
366        "String" => quote! { String },
367        "JsonB" => quote! { serde_json::Value },
368        "TimestamptZ" => quote! { DateTime<Utc> },
369        "Boolean" => quote! { bool },
370        "Float" => quote! { f64 },
371        _ => quote! { String },
372    }
373}
374
/// Converts a `camelCase` / `PascalCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the
/// first, so runs of capitals are split one by one (`userID` -> `user_i_d`).
/// The full lowercase expansion of each character is kept. Characters such as
/// `İ`, whose lowercase form is several code points, are therefore not
/// truncated.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() && i > 0 {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }
    result
}