1use quote::{quote, format_ident};
2use proc_macro2::TokenStream;
3use crate::{Schema, Model, Field, Modifier};
4
5pub fn generate_rust_code(schema: &Schema) -> String {
6 let structs_and_impls = schema.models.iter().map(|model| {
7 generate_model_with_query_builder(model)
8 });
9
10 let code = quote! {
11 use serde::{Deserialize, Serialize};
12 use chrono::{DateTime, Utc};
13 use tokio_postgres::Client;
14
15 fn calculate_json_diff(before: &serde_json::Value, after: &serde_json::Value) -> serde_json::Value {
16 let mut diff = serde_json::Map::new();
17
18 if let (Some(before_obj), Some(after_obj)) = (before.as_object(), after.as_object()) {
19 for (key, after_val) in after_obj {
20 if let Some(before_val) = before_obj.get(key) {
21 if before_val != after_val {
22 diff.insert(
23 key.clone(),
24 serde_json::json!({
25 "from": before_val,
26 "to": after_val
27 })
28 );
29 }
30 } else {
31 diff.insert(key.clone(), serde_json::json!({ "added": after_val }));
32 }
33 }
34
35 for (key, before_val) in before_obj {
36 if !after_obj.contains_key(key) {
37 diff.insert(key.clone(), serde_json::json!({ "removed": before_val }));
38 }
39 }
40 }
41
42 serde_json::Value::Object(diff)
43 }
44
45 #(#structs_and_impls)*
46 };
47
48 let file: syn::File = syn::parse2(code).unwrap();
49 prettyplease::unparse(&file)
50}
51
52fn generate_model_with_query_builder(model: &Model) -> TokenStream {
53 let model_struct = generate_model_struct(model);
54 let query_builder_struct = generate_query_builder_struct(model);
55 let query_builder_impl = generate_query_builder_impl(model);
56 let model_impl = generate_model_impl(model);
57
58 quote! {
59 #model_struct
60
61 #query_builder_struct
62
63 #model_impl
64
65 #query_builder_impl
66 }
67}
68
69fn generate_model_struct(model: &Model) -> TokenStream {
70 let name = format_ident!("{}", model.name);
71 let fields = model.fields.iter().map(|field| {
72 let field_name = format_ident!("{}", field.name);
73 let field_type = rust_type_from_schema(&field.type_name);
74
75 quote! {
76 pub #field_name: #field_type
77 }
78 });
79
80 quote! {
81 #[derive(Debug, Clone, Serialize, Deserialize)]
82 pub struct #name {
83 #(#fields),*
84 }
85 }
86}
87
88fn generate_model_impl(model: &Model) -> TokenStream {
89 let model_name = format_ident!("{}", model.name);
90 let builder_name = format_ident!("{}Query", model.name);
91
92 let pk_field = model.fields.iter()
93 .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));
94
95 let find_by_id_impl = if let Some(pk) = pk_field {
96 let pk_name = format_ident!("{}", to_snake_case(&pk.name));
97 let pk_type = rust_type_from_schema(&pk.type_name);
98
99 quote! {
100 pub fn find_by_id(id: #pk_type) -> String {
101 format!("SELECT * FROM {} WHERE {} = {}",
102 stringify!(#pk_name).replace("_", ""),
103 stringify!(#pk_name),
104 id
105 )
106 }
107 }
108 } else {
109 quote! {}
110 };
111
112 quote! {
113 impl #model_name {
114 pub fn query() -> #builder_name {
115 #builder_name::new()
116 }
117
118 #find_by_id_impl
119 }
120 }
121}
122
123fn generate_query_builder_struct(model: &Model) -> TokenStream {
124 let builder_name = format_ident!("{}Query", model.name);
125
126 quote! {
127 pub struct #builder_name {
128 table: String,
129 conditions: Vec<String>,
130 limit: Option<usize>,
131 offset: Option<usize>,
132 order_by: Vec<(String, String)>,
133 }
134
135 impl Clone for #builder_name {
136 fn clone(&self) -> Self {
137 Self {
138 table: self.table.clone(),
139 conditions: self.conditions.clone(),
140 limit: self.limit,
141 offset: self.offset,
142 order_by: self.order_by.clone(),
143 }
144 }
145 }
146 }
147}
148
149fn generate_query_builder_impl(model: &Model) -> TokenStream {
150 let builder_name = format_ident!("{}Query", model.name);
151 let model_name = format_ident!("{}", model.name);
152 let table_name = model.name.to_lowercase();
153
154 let where_methods = model.fields.iter().map(|field| {
155 let method_name = format_ident!("where_{}", to_snake_case(&field.name));
156 let field_name = to_snake_case(&field.name);
157 let field_type = rust_type_from_schema(&field.type_name);
158
159 quote! {
160 pub fn #method_name(mut self, value: #field_type) -> Self {
161 self.conditions.push(format!("{} = {:?}", #field_name, value));
162 self
163 }
164 }
165 });
166
167 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
168 let field_name = format_ident!("{}", field.name);
169
170 quote! {
171 #field_name: row.get(#idx)
172 }
173 });
174
175 let has_audit = model.fields.iter().any(|f| f.get_audit_model().is_some());
176
177 let update_method = if has_audit {
178 generate_update_with_audit(model)
179 } else {
180 quote! {}
181 };
182
183 quote! {
184 impl #builder_name {
185 pub fn new() -> Self {
186 Self {
187 table: #table_name.to_string(),
188 conditions: vec![],
189 limit: None,
190 offset: None,
191 order_by: vec![],
192 }
193 }
194
195 #(#where_methods)*
196
197 pub fn limit(mut self, limit: usize) -> Self {
198 self.limit = Some(limit);
199 self
200 }
201
202 pub fn offset(mut self, offset: usize) -> Self {
203 self.offset = Some(offset);
204 self
205 }
206
207 pub fn order_by(mut self, column: &str, direction: &str) -> Self {
208 self.order_by.push((column.to_string(), direction.to_string()));
209 self
210 }
211
212 pub async fn select(&self, client: &Client)
213 -> Result<Vec<#model_name>, Box<dyn std::error::Error>>
214 {
215 let sql = self.build_select();
216
217 let rows = client.query(&sql, &[]).await?;
218 let mut results = Vec::new();
219
220 for row in rows {
221 results.push(#model_name {
222 #(#field_gets),*
223 });
224 }
225
226 Ok(results)
227 }
228
229 pub async fn first(&self, client: &Client)
230 -> Result<Option<#model_name>, Box<dyn std::error::Error>>
231 {
232 let mut query = self.clone();
233 query.limit = Some(1);
234 let results = query.select(client).await?;
235 Ok(results.into_iter().next())
236 }
237
238 pub async fn count(&self, client: &Client)
239 -> Result<i64, Box<dyn std::error::Error>>
240 {
241 let sql = format!("SELECT COUNT(*) FROM {}{}",
242 self.table,
243 if self.conditions.is_empty() {
244 String::new()
245 } else {
246 format!(" WHERE {}", self.conditions.join(" AND "))
247 }
248 );
249
250 let row = client.query_one(&sql, &[]).await?;
251 Ok(row.get(0))
252 }
253
254 #update_method
255
256 fn build_select(&self) -> String {
257 let mut sql = format!("SELECT * FROM {}", self.table);
258
259 if !self.conditions.is_empty() {
260 sql.push_str(" WHERE ");
261 sql.push_str(&self.conditions.join(" AND "));
262 }
263
264 if !self.order_by.is_empty() {
265 sql.push_str(" ORDER BY ");
266 let order_clauses: Vec<String> = self.order_by.iter()
267 .map(|(col, dir)| format!("{} {}", col, dir))
268 .collect();
269 sql.push_str(&order_clauses.join(", "));
270 }
271
272 if let Some(limit) = self.limit {
273 sql.push_str(&format!(" LIMIT {}", limit));
274 }
275
276 if let Some(offset) = self.offset {
277 sql.push_str(&format!(" OFFSET {}", offset));
278 }
279
280 sql
281 }
282 }
283 }
284}
285
286fn generate_update_with_audit(model: &Model) -> TokenStream {
287 let table_name = model.name.to_lowercase();
288
289 let audit_field = model.fields.iter()
290 .find(|f| f.get_audit_model().is_some())
291 .expect("Called generate_update_with_audit without audit field");
292
293 let audit_model_name = audit_field.get_audit_model().unwrap();
294 let audit_table = audit_model_name.to_lowercase();
295 let audited_field_name = &audit_field.name;
296
297 let pk_field = model.fields.iter()
298 .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
299 .expect("Model must have primary key for audit");
300
301 let pk_column = to_snake_case(&pk_field.name);
302
303 quote! {
304 pub async fn update(
305 &self,
306 client: &Client,
307 new_value: serde_json::Value,
308 who: i64,
309 ) -> Result<(), Box<dyn std::error::Error>> {
310 let transaction = client.transaction().await?;
311
312 if self.conditions.is_empty() {
313 return Err("No WHERE condition specified for update".into());
314 }
315
316 let pk_value: i64 = self.conditions[0]
317 .split('=')
318 .nth(1)
319 .and_then(|s| s.trim().parse().ok())
320 .ok_or("Failed to parse primary key value")?;
321
322 let before_sql = format!(
323 "SELECT {} FROM {} WHERE {} = $1",
324 #audited_field_name,
325 #table_name,
326 #pk_column
327 );
328
329 let before_row = transaction.query_one(&before_sql, &[&pk_value]).await?;
330 let before: serde_json::Value = before_row.get(0);
331
332 let update_sql = format!(
333 "UPDATE {} SET {} = $1, updated_at = now() WHERE {} = $2",
334 #table_name,
335 #audited_field_name,
336 #pk_column
337 );
338
339 transaction.execute(&update_sql, &[&new_value, &pk_value]).await?;
340
341 let diff = calculate_json_diff(&before, &new_value);
342
343 let audit_sql = format!(
344 "INSERT INTO {} ({}, who, changed_at, before, after, diff) VALUES ($1, $2, now(), $3, $4, $5)",
345 #audit_table,
346 #pk_column
347 );
348
349 transaction.execute(
350 &audit_sql,
351 &[&pk_value, &who, &before, &new_value, &diff]
352 ).await?;
353
354 transaction.commit().await?;
355
356 Ok(())
357 }
358 }
359}
360
361
362fn rust_type_from_schema(type_name: &str) -> TokenStream {
363 match type_name {
364 "BigInt" => quote! { i64 },
365 "Int" => quote! { i32 },
366 "String" => quote! { String },
367 "JsonB" => quote! { serde_json::Value },
368 "TimestamptZ" => quote! { DateTime<Utc> },
369 "Boolean" => quote! { bool },
370 "Float" => quote! { f64 },
371 _ => quote! { String },
372 }
373}
374
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// An underscore is inserted before every uppercase character except the
/// first, and every character is lowercased. So `"createdAt"` becomes
/// `"created_at"` and `"userID"` becomes `"user_i_d"`. Characters whose
/// lowercase form spans several code points (e.g. `'İ'` → `"i\u{307}"`) are
/// expanded in full rather than truncated to the first code point.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() && i > 0 {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }
    result
}