1use quote::{quote, format_ident};
2use proc_macro2::TokenStream;
3use crate::{Schema, Model, Field, Modifier};
4
5pub fn generate_rust_code(schema: &Schema) -> String {
6 let structs_and_impls = schema.models.iter().map(|model| {
7 generate_model_with_query_builder(model)
8 });
9
10 let code = quote! {
11 use serde::{Deserialize, Serialize};
12 use chrono::{DateTime, Utc};
13 use tokio_postgres::Client;
14
15 fn calculate_json_diff(before: &serde_json::Value, after: &serde_json::Value) -> serde_json::Value {
16 let mut diff = serde_json::Map::new();
17 if let (Some(before_obj), Some(after_obj)) = (before.as_object(), after.as_object()) {
18 for (key, after_val) in after_obj {
19 if let Some(before_val) = before_obj.get(key) {
20 if before_val != after_val {
21 diff.insert(
22 key.clone(),
23 serde_json::json!({ "from": before_val, "to": after_val })
24 );
25 }
26 } else {
27 diff.insert(key.clone(), serde_json::json!({ "added": after_val }));
28 }
29 }
30 for (key, before_val) in before_obj {
31 if !after_obj.contains_key(key) {
32 diff.insert(key.clone(), serde_json::json!({ "removed": before_val }));
33 }
34 }
35 }
36 serde_json::Value::Object(diff)
37 }
38
39 #(#structs_and_impls)*
40 };
41
42 let file: syn::File = syn::parse2(code).unwrap();
43 prettyplease::unparse(&file)
44}
45
46fn generate_model_with_query_builder(model: &Model) -> TokenStream {
47 let model_struct = generate_model_struct(model);
48 let query_builder_struct = generate_query_builder_struct(model);
49 let query_builder_impl = generate_query_builder_impl(model);
50 let model_impl = generate_model_impl(model);
51
52 quote! {
53 #model_struct
54 #query_builder_struct
55 #model_impl
56 #query_builder_impl
57 }
58}
59
60fn generate_model_struct(model: &Model) -> TokenStream {
61 let name = format_ident!("{}", model.name);
62 let fields = model.fields.iter().map(|field| {
63 let field_name = format_ident!("{}", field.name);
64 let field_type = rust_type_from_schema(&field.type_name);
65
66 quote! {
67 pub #field_name: #field_type
68 }
69 });
70
71 quote! {
72 #[derive(Debug, Clone, Serialize, Deserialize)]
73 pub struct #name {
74 #(#fields),*
75 }
76 }
77}
78
79fn generate_model_impl(model: &Model) -> TokenStream {
80 let model_name = format_ident!("{}", model.name);
81 let builder_name = format_ident!("{}Query", model.name);
82
83 let pk_field = model.fields.iter()
84 .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));
85
86 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
87 let field_name = format_ident!("{}", field.name);
88 quote! { #field_name: row.get(#idx) }
89 });
90
91 let find_by_id_impl = if let Some(pk) = pk_field {
92 let pk_type = rust_type_from_schema(&pk.type_name);
93 let pk_name = to_snake_case(&pk.name);
94
95 quote! {
96 pub async fn find_by_id(client: &Client, id: #pk_type)
97 -> Result<Option<#model_name>, Box<dyn std::error::Error>>
98 {
99 let sql = format!("SELECT * FROM {} WHERE {} = $1", stringify!(#model_name).to_lowercase(), #pk_name);
100 let row_opt = client.query_opt(&sql, &[&id]).await?;
101 Ok(row_opt.map(|row| #model_name {
102 #(#field_gets),*
103 }))
104 }
105 }
106 } else {
107 quote! {}
108 };
109
110 quote! {
111 impl #model_name {
112 pub fn query() -> #builder_name {
113 #builder_name::new()
114 }
115 #find_by_id_impl
116 }
117 }
118}
119
120fn generate_query_builder_struct(model: &Model) -> TokenStream {
121 let builder_name = format_ident!("{}Query", model.name);
122
123 quote! {
124 pub struct #builder_name {
125 table: String,
126 where_fragments: Vec<(&'static str, usize)>,
127 args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
128 limit: Option<usize>,
129 offset: Option<usize>,
130 order_by: Vec<(String, String)>,
131 }
132 impl Clone for #builder_name {
133 fn clone(&self) -> Self {
134 Self {
135 table: self.table.clone(),
136 where_fragments: self.where_fragments.clone(),
137 args: Vec::new(),
138 limit: self.limit,
139 offset: self.offset,
140 order_by: self.order_by.clone(),
141 }
142 }
143 }
144 }
145}
146
147fn generate_query_builder_impl(model: &Model) -> TokenStream {
148 let builder_name = format_ident!("{}Query", model.name);
149 let model_name = format_ident!("{}", model.name);
150 let table_name = model.name.to_lowercase();
151
152 let field_methods = model.fields.iter().enumerate().map(|(i, field)| {
153 let method_name = format_ident!("where_{}", to_snake_case(&field.name));
154 let field_type = rust_type_from_schema(&field.type_name);
155 let field_col = to_snake_case(&field.name);
156
157 quote! {
158 pub fn #method_name(mut self, value: #field_type) -> Self {
159 self.args.push(Box::new(value));
160 self.where_fragments.push((#field_col, self.args.len()));
161 self
162 }
163 }
164 });
165
166 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
167 let field_name = format_ident!("{}", field.name);
168 quote! { #field_name: row.get(#idx) }
169 });
170
171 quote! {
172 impl #builder_name {
173 pub fn new() -> Self {
174 Self {
175 table: #table_name.to_string(),
176 where_fragments: vec![],
177 args: vec![],
178 limit: None,
179 offset: None,
180 order_by: vec![],
181 }
182 }
183 #(#field_methods)*
184 pub fn limit(mut self, limit: usize) -> Self {
185 self.limit = Some(limit);
186 self
187 }
188 pub fn offset(mut self, offset: usize) -> Self {
189 self.offset = Some(offset);
190 self
191 }
192 pub fn order_by(mut self, column: &str, direction: &str) -> Self {
193 self.order_by.push((column.to_string(), direction.to_string()));
194 self
195 }
196
197 pub async fn select(&self, client: &Client)
198 -> Result<Vec<#model_name>, Box<dyn std::error::Error>>
199 {
200 let (sql, params) = self.build_select();
201 let rows = client.query(&sql, ¶ms[..]).await?;
202 let mut results = Vec::new();
203 for row in rows {
204 results.push(#model_name { #(#field_gets),* });
205 }
206 Ok(results)
207 }
208
209 pub async fn first(&self, client: &Client)
210 -> Result<Option<#model_name>, Box<dyn std::error::Error>>
211 {
212 let mut query = #builder_name::new();
213 query.table = self.table.clone();
214 query.where_fragments = self.where_fragments.clone();
215 query.args = Vec::new();
216 query.limit = Some(1);
217 query.offset = self.offset;
218 query.order_by = self.order_by.clone();
219
220 let results = query.select(client).await?;
221 Ok(results.into_iter().next())
222 }
223
224 pub async fn count(&self, client: &Client)
225 -> Result<i64, Box<dyn std::error::Error>>
226 {
227 let (sql, params) = self.build_count();
228 let row = client.query_one(&sql, ¶ms[..]).await?;
229 Ok(row.get(0))
230 }
231
232 fn build_select(&self) -> (String, Vec<&(dyn tokio_postgres::types::ToSql + Sync)>) {
233 let mut sql = format!("SELECT * FROM {}", self.table);
234 let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
235
236 if !self.where_fragments.is_empty() {
237 let conds: Vec<String> = self.where_fragments.iter()
238 .enumerate()
239 .map(|(i, &(col, idx))| format!("{} = ${}", col, i + 1))
240 .collect();
241 sql.push_str(" WHERE ");
242 sql.push_str(&conds.join(" AND "));
243 for arg in &self.args {
244 params.push(arg.as_ref());
245 }
246 }
247 if !self.order_by.is_empty() {
248 sql.push_str(" ORDER BY ");
249 let order_clauses: Vec<String> = self.order_by.iter()
250 .map(|(col, dir)| format!("{} {}", col, dir))
251 .collect();
252 sql.push_str(&order_clauses.join(", "));
253 }
254 if let Some(limit) = self.limit {
255 sql.push_str(&format!(" LIMIT {}", limit));
256 }
257 if let Some(offset) = self.offset {
258 sql.push_str(&format!(" OFFSET {}", offset));
259 }
260 (sql, params)
261 }
262 fn build_count(&self) -> (String, Vec<&(dyn tokio_postgres::types::ToSql + Sync)>) {
263 let mut sql = format!("SELECT COUNT(*) FROM {}", self.table);
264 let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
265 if !self.where_fragments.is_empty() {
266 let conds: Vec<String> = self.where_fragments.iter()
267 .enumerate()
268 .map(|(i, &(col, idx))| format!("{} = ${}", col, i + 1))
269 .collect();
270 sql.push_str(" WHERE ");
271 sql.push_str(&conds.join(" AND "));
272 for arg in &self.args {
273 params.push(arg.as_ref());
274 }
275 }
276 (sql, params)
277 }
278 }
279 }
280}
281
282fn rust_type_from_schema(type_name: &str) -> TokenStream {
283 match type_name {
284 "BigInt" => quote! { i64 },
285 "Int" => quote! { i32 },
286 "String" => quote! { String },
287 "JsonB" => quote! { serde_json::Value },
288 "TimestamptZ" => quote! { DateTime<Utc> },
289 "Boolean" => quote! { bool },
290 "Float" => quote! { f64 },
291 _ => quote! { String },
292 }
293}
294
/// Converts a `camelCase` / `PascalCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except the first, and every
/// character is lowercased.
///
/// The lowercasing uses the full Unicode mapping. Some characters expand into several
/// (e.g. `'\u{130}'` becomes `"i\u{307}"`); the previous `.next()` kept only the first one.
///
/// Consecutive capitals are split individually, so `"HTTPServer"` becomes `"h_t_t_p_server"`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for (i, ch) in s.chars().enumerate() {
        if i > 0 && ch.is_uppercase() {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }
    result
}