// byteorm_lib/rustgen/client.rs

1use std::collections::HashMap;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use crate::rustgen::{capitalize_first, rust_type_from_schema, to_snake_case};
5use crate::{Modifier, Schema};
6
7fn pk_args(model: &crate::Model) -> (Vec<proc_macro2::Ident>, Vec<proc_macro2::TokenStream>, Vec<String>, Vec<String>, Vec<proc_macro2::TokenStream>) {
8    let pk_fields: Vec<_> = model.fields.iter()
9        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
10        .collect();
11    let pk_names = pk_fields.iter().map(|pk| format_ident!("{}", to_snake_case(&pk.name))).collect();
12    let pk_types = pk_fields.iter().map(|pk| {
13        let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
14        rust_type_from_schema(&pk.type_name, is_nullable)
15    }).collect();
16    let pk_cols: Vec<_> = pk_fields.iter().map(|pk| to_snake_case(&pk.name)).collect();
17    let pk_placeholders: Vec<_> = (1..=pk_fields.len()).map(|i| format!("${}", i)).collect();
18    let pk_arg_refs = pk_fields.iter().map(|pk| {
19        let name = format_ident!("{}", to_snake_case(&pk.name));
20        quote! { &#name }
21    }).collect();
22    (pk_names, pk_types, pk_cols, pk_placeholders, pk_arg_refs)
23}
24
25fn generate_find_unique(model_name: &proc_macro2::Ident, model: &crate::Model) -> TokenStream {
26    let pk_fields: Vec<_> = model.fields.iter()
27        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
28        .collect();
29
30    if pk_fields.is_empty() {
31        quote! {}
32    } else if pk_fields.len() == 1 {
33        let pk = &pk_fields[0];
34        let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
35        let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
36        quote! {
37            pub async fn find_unique(&self, id: #pk_type)
38                -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
39            {
40                #model_name::find_by_id(&self.client, id).await
41            }
42        }
43    } else {
44        let pk_params = pk_fields.iter().map(|pk| {
45            let name = format_ident!("{}", to_snake_case(&pk.name));
46            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
47            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
48            quote! { #name: #pk_type }
49        });
50
51        let pk_args = pk_fields.iter().map(|pk| {
52            let name = format_ident!("{}", to_snake_case(&pk.name));
53            quote! { #name }
54        });
55
56        quote! {
57            pub async fn find_unique(&self, #(#pk_params),*)
58                -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
59            {
60                #model_name::find_by_composite_pk(&self.client, #(#pk_args),*).await
61            }
62        }
63    }
64}
65
66fn generate_find_or_create(model_name: &proc_macro2::Ident, model: &crate::Model, table_name: &str) -> TokenStream {
67    let pk_fields: Vec<_> = model.fields.iter()
68        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
69        .collect();
70
71    if pk_fields.is_empty() {
72        quote! {}
73    } else if pk_fields.len() == 1 {
74        let pk = &pk_fields[0];
75        let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
76        let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
77        let pk_col = to_snake_case(&pk.name);
78        quote! {
79            pub async fn find_or_create(&self, id: #pk_type)
80                -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>>
81            {
82                self.client.execute(
83                    &format!("INSERT INTO {} ({}) VALUES ($1) ON CONFLICT DO NOTHING", #table_name, #pk_col),
84                    &[&id]
85                ).await?;
86                self.find_unique(id).await?.ok_or("Record should exist after find_or_create".into())
87            }
88        }
89    } else {
90        let pk_params = pk_fields.iter().map(|pk| {
91            let name = format_ident!("{}", to_snake_case(&pk.name));
92            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
93            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
94            quote! { #name: #pk_type }
95        });
96        let pk_cols: Vec<_> = pk_fields.iter().map(|pk| to_snake_case(&pk.name)).collect();
97        let pk_cols_str = pk_cols.join(", ");
98        let pk_placeholders: Vec<_> = (1..=pk_fields.len()).map(|i| format!("${}", i)).collect();
99        let pk_placeholders_str = pk_placeholders.join(", ");
100        let pk_conflict = pk_cols.join(", ");
101        let pk_args = pk_fields.iter().map(|pk| {
102            let name = format_ident!("{}", to_snake_case(&pk.name));
103            quote! { &#name }
104        });
105        let pk_args_call = pk_fields.iter().map(|pk| {
106            let name = format_ident!("{}", to_snake_case(&pk.name));
107            quote! { #name }
108        });
109        quote! {
110            pub async fn find_or_create(&self, #(#pk_params),*)
111                -> Result<#model_name, Box<dyn std::error::Error + Send + Sync>>
112            {
113                let sql = format!(
114                    "INSERT INTO {} ({}) VALUES ({}) ON CONFLICT ({}) DO NOTHING",
115                    #table_name, #pk_cols_str, #pk_placeholders_str, #pk_conflict
116                );
117                self.client.execute(&sql, &[#(#pk_args),*]).await?;
118                self.find_unique(#(#pk_args_call),*).await?.ok_or("Record should exist after find_or_create".into())
119            }
120        }
121    }
122}
123
124fn generate_jsonb_sub_accessors(model: &crate::Model, jsonb_defaults: &HashMap<(String, String), String>) -> Vec<TokenStream> {
125    let model_name = &model.name;
126    let query_builder = format_ident!("{}Query", model.name);
127    let table_name = model.name.to_lowercase();
128
129    let pk_fields: Vec<_> = model.fields.iter()
130        .filter(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)))
131        .collect();
132
133    let jsonb_fields: Vec<_> = model.fields.iter()
134        .filter(|f| f.type_name == "JsonB")
135        .collect();
136
137    jsonb_fields.into_iter().map(|jsonb| {
138        let jsonb_name = &jsonb.name;
139        let jsonb_snake = to_snake_case(jsonb_name);
140        let jsonb_field_ident = format_ident!("{}", jsonb_name);
141        let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(jsonb_name));
142        let defaults_const = format_ident!("{}_DEFAULTS", jsonb_snake.to_uppercase());
143
144        let default_json_init = if let Some(json_content) = jsonb_defaults.get(&(model.name.clone(), jsonb.name.clone())) {
145            quote! {
146                static #defaults_const: Lazy<serde_json::Value> = Lazy::new(|| {
147                    serde_json::from_str(#json_content)
148                        .expect(&format!("Failed to parse default JSON for {}.{}", stringify!(#model_name), #jsonb_name))
149                });
150            }
151        } else {
152            quote! {
153                static #defaults_const: Lazy<serde_json::Value> = Lazy::new(|| {
154                    serde_json::json!({})
155                });
156            }
157        };
158
159        let (pk_params, pk_where_methods, pk_args_for_set) = if pk_fields.len() == 1 {
160            let pk = &pk_fields[0];
161            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
162            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
163            let pk_field_name = format_ident!("where_{}", to_snake_case(&pk.name));
164            (
165                quote! { id: #pk_type },
166                quote! { .#pk_field_name(id) },
167                vec![quote! { &id }],
168            )
169        } else {
170            let params = pk_fields.iter().map(|pk| {
171                let param_name = format_ident!("{}", to_snake_case(&pk.name));
172                let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
173                let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
174                quote! { #param_name: #pk_type }
175            });
176
177            let where_methods = pk_fields.iter().map(|pk| {
178                let method_name = format_ident!("where_{}", to_snake_case(&pk.name));
179                let param_name = format_ident!("{}", to_snake_case(&pk.name));
180                quote! { .#method_name(#param_name) }
181            });
182
183            let set_args = pk_fields.iter().map(|pk| {
184                let param_name = format_ident!("{}", to_snake_case(&pk.name));
185                quote! { &#param_name }
186            });
187
188            (
189                quote! { #(#params),* },
190                quote! { #(#where_methods)* },
191                set_args.collect::<Vec<_>>(),
192            )
193        };
194
195        let pk_args_clone = if pk_fields.len() == 1 {
196            quote! { id }
197        } else {
198            let args = pk_fields.iter().map(|pk| {
199                let param_name = format_ident!("{}", to_snake_case(&pk.name));
200                quote! { #param_name }
201            });
202            quote! { #(#args),* }
203        };
204
205        let pk_columns: Vec<_> = pk_fields.iter().map(|pk| to_snake_case(&pk.name)).collect();
206        let pk_placeholders: Vec<_> = (1..=pk_fields.len()).map(|i| format!("${}", i)).collect();
207        let insert_pk_part = pk_columns.join(", ");
208        let insert_values_part = pk_placeholders.join(", ");
209        let conflict_clause = pk_columns.join(", ");
210
211        quote! {
212            #default_json_init
213
214            #[derive(Clone)]
215            pub struct #sub_accessor_struct {
216                client: Arc<PgClient>,
217            }
218
219            impl std::fmt::Debug for #sub_accessor_struct {
220                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221                    f.debug_struct(stringify!(#sub_accessor_struct))
222                        .field("client", &"<PgClient>")
223                        .finish()
224                }
225            }
226
227            impl #sub_accessor_struct {
228                pub fn new(client: Arc<PgClient>) -> Self {
229                    Self { client }
230                }
231
232                pub async fn get(&self, #pk_params, key: &str)
233                    -> Result<String, Box<dyn std::error::Error + Send + Sync>>
234                {
235                    match #query_builder::new()
236                        #pk_where_methods
237                        .first(&self.client)
238                        .await
239                    {
240                        Ok(Some(record)) => {
241                            record.#jsonb_field_ident.get_string(key)
242                                .or_else(|_| #defaults_const.get_string(key))
243                        },
244                        Ok(None) => #defaults_const.get_string(key),
245                        Err(e) => Err(e),
246                    }
247                }
248
249                pub async fn get_as<T>(&self, #pk_params, key: &str)
250                    -> Result<T, Box<dyn std::error::Error + Send + Sync>>
251                where
252                    T: serde::de::DeserializeOwned,
253                {
254                    match #query_builder::new()
255                        #pk_where_methods
256                        .first(&self.client)
257                        .await
258                    {
259                        Ok(Some(record)) => {
260                            record.#jsonb_field_ident.get_value(key)
261                                .or_else(|_| #defaults_const.get_value(key))
262                        },
263                        Ok(None) => #defaults_const.get_value(key),
264                        Err(e) => Err(e),
265                    }
266                }
267
268                pub async fn get_or<T>(&self, #pk_params, key: &str, default: T)
269                    -> Result<T, Box<dyn std::error::Error + Send + Sync>>
270                where
271                    T: serde::de::DeserializeOwned,
272                {
273                    match self.get_as(#pk_args_clone, key).await {
274                        Ok(value) => Ok(value),
275                        Err(_) => Ok(default),
276                    }
277                }
278
279                pub async fn has(&self, #pk_params, key: &str)
280                    -> Result<bool, Box<dyn std::error::Error + Send + Sync>>
281                {
282                    match #query_builder::new()
283                        #pk_where_methods
284                        .first(&self.client)
285                        .await
286                    {
287                        Ok(Some(record)) => Ok(record.#jsonb_field_ident.has_key(key) || #defaults_const.has_key(key)),
288                        Ok(None) => Ok(#defaults_const.has_key(key)),
289                        Err(e) => Err(e),
290                    }
291                }
292
293                pub async fn set<T>(&self, #pk_params, key: &str, value: T)
294                    -> Result<(), Box<dyn std::error::Error + Send + Sync>>
295                where
296                    T: serde::Serialize + Send + Sync,
297                {
298                    let value_json = serde_json::to_value(&value)?;
299                    let value_str = value_json.to_string();
300
301                    let sql = format!(
302                        "INSERT INTO {} ({}, {}, updated_at) VALUES ({}, jsonb_build_object($1, $2), NOW()) \
303                         ON CONFLICT ({}) DO UPDATE SET {} = jsonb_set(COALESCE({}.{}, '{{}}'::jsonb), $3, $4, true), updated_at = NOW()",
304                        #table_name,
305                        #insert_pk_part,
306                        #jsonb_snake,
307                        #insert_values_part,
308                        #conflict_clause,
309                        #jsonb_snake,
310                        #table_name,
311                        #jsonb_snake,
312                    );
313
314                    let key_path = format!("{{{}}}", key);
315                    self.client.execute(
316                        &sql,
317                        &[#(#pk_args_for_set),*, &key, &value_str, &key_path, &value_str]
318                    ).await?;
319
320                    Ok(())
321                }
322
323                pub async fn get_many(
324                    &self, #pk_params, keys: &[&str]
325                ) -> Result<HashMap<String, serde_json::Value>, Box<dyn std::error::Error + Send + Sync>>
326                {
327                    let opt = #query_builder::new()
328                        #pk_where_methods
329                        .first(&self.client).await?;
330
331                    let mut out = HashMap::new();
332
333                    if let Some(record) = opt {
334                        for &key in keys {
335                            if let Some(v) = record.#jsonb_field_ident.get(key) {
336                                out.insert(key.to_string(), v.clone());
337                            }
338                        }
339                    } else {
340                        for &key in keys {
341                            if let Some(v) = #defaults_const.get(key) {
342                                out.insert(key.to_string(), v.clone());
343                            }
344                        }
345                    }
346                    Ok(out)
347                }
348
349                pub async fn get_many_as<T>(
350                    &self, #pk_params, keys: &[&str]
351                ) -> Result<HashMap<String, T>, Box<dyn std::error::Error + Send + Sync>>
352                where T: serde::de::DeserializeOwned
353                {
354                    let values = self.get_many(#pk_args_clone, keys).await?;
355                    let mut map = HashMap::new();
356                    for (k, v) in values {
357                        if let Ok(x) = serde_json::from_value::<T>(v) {
358                            map.insert(k, x);
359                        }
360                    }
361                    Ok(map)
362                }
363            }
364        }
365    }).collect()
366}
367
368pub fn generate_client_struct(schema: &Schema, jsonb_defaults: &HashMap<(String, String), String>) -> TokenStream {
369    let model_accessors = schema.models.iter().map(|model| {
370        let accessor_name = format_ident!("{}", to_snake_case(&model.name));
371        let accessor_struct = format_ident!("{}Accessor", model.name);
372        quote! { pub #accessor_name: #accessor_struct }
373    });
374
375    let accessor_structs = schema.models.iter().map(|model| {
376        let model_name = format_ident!("{}", model.name);
377        let accessor_struct = format_ident!("{}Accessor", model.name);
378        let query_builder = format_ident!("{}Query", model.name);
379        let update_builder = format_ident!("{}Update", model.name);
380        let upsert_builder = format_ident!("{}Upsert", model.name);
381        let table_name = model.name.to_lowercase();
382
383        let find_unique = generate_find_unique(&model_name, model);
384        let find_or_create = generate_find_or_create(&model_name, model, &table_name);
385
386        let jsonb_fields: Vec<_> = model.fields.iter()
387            .filter(|f| f.type_name == "JsonB")
388            .collect();
389
390        let jsonb_accessor_fields = jsonb_fields.iter().map(|jsonb| {
391            let jsonb_snake = to_snake_case(&jsonb.name);
392            let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(&jsonb.name));
393            let sub_accessor_field = format_ident!("{}", jsonb_snake);
394            quote! { pub #sub_accessor_field: #sub_accessor_struct }
395        });
396
397        let jsonb_accessor_inits = jsonb_fields.iter().map(|jsonb| {
398            let jsonb_snake = to_snake_case(&jsonb.name);
399            let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(&jsonb.name));
400            let sub_accessor_field = format_ident!("{}", jsonb_snake);
401            quote! { #sub_accessor_field: #sub_accessor_struct::new(client.clone()) }
402        });
403
404        let jsonb_debug_fields = jsonb_fields.iter().map(|jsonb| {
405            let jsonb_snake = to_snake_case(&jsonb.name);
406            let sub_accessor_field = format_ident!("{}", jsonb_snake);
407            quote! { .field(stringify!(#sub_accessor_field), &self.#sub_accessor_field) }
408        });
409
410        let jsonb_sub_accessors = generate_jsonb_sub_accessors(model, jsonb_defaults);
411
412        quote! {
413            #(#jsonb_sub_accessors)*
414
415            #[derive(Clone)]
416            pub struct #accessor_struct {
417                client: Arc<PgClient>,
418                #(#jsonb_accessor_fields),*
419            }
420            impl std::fmt::Debug for #accessor_struct {
421                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
422                    f.debug_struct(stringify!(#accessor_struct))
423                        .field("client", &"<PgClient>")
424                        #(#jsonb_debug_fields)*
425                        .finish()
426                }
427            }
428            impl #accessor_struct {
429                pub fn new(client: Arc<PgClient>) -> Self {
430                    Self {
431                        client: client.clone(),
432                        #(#jsonb_accessor_inits),*
433                    }
434                }
435                pub fn find_many(&self) -> #query_builder { #query_builder::new() }
436                pub fn update(&self) -> #update_builder { #update_builder::new(self.client.clone()) }
437                pub fn upsert(&self) -> #upsert_builder { #upsert_builder::new(self.client.clone()) }
438                #find_unique
439                #find_or_create
440                pub async fn find_first(&self) -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>> {
441                    #query_builder::new().first(&self.client).await
442                }
443                pub async fn count(&self) -> Result<i64, Box<dyn std::error::Error + Send + Sync>> {
444                    #query_builder::new().count(&self.client).await
445                }
446                pub fn client(&self) -> &PgClient { &self.client }
447            }
448        }
449    });
450
451    let accessor_inits = schema.models.iter().map(|model| {
452        let accessor_name = format_ident!("{}", to_snake_case(&model.name));
453        let accessor_struct = format_ident!("{}Accessor", model.name);
454        quote! { #accessor_name: #accessor_struct::new(client.clone()) }
455    });
456
457    let debug_accessor_fields = schema.models.iter().map(|model| {
458        let accessor_name = to_snake_case(&model.name);
459        let accessor_name_ident = format_ident!("{}", accessor_name);
460        quote! { .field(#accessor_name, &self.#accessor_name_ident) }
461    });
462
463    quote! {
464        #(#accessor_structs)*
465
466        pub struct Client {
467            client: Arc<PgClient>,
468            #(#model_accessors),*
469        }
470        impl std::fmt::Debug for Client {
471            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
472                f.debug_struct("Client")
473                    .field("client", &"<PgClient>")
474                    #(#debug_accessor_fields)*
475                    .finish()
476            }
477        }
478        impl Client {
479            pub async fn new(connection_string: &str) -> Result<Self, Error> {
480                let (client, connection) = tokio_postgres::connect(connection_string, NoTls).await?;
481                tokio::spawn(async move {
482                    if let Err(e) = connection.await {
483                        eprintln!("connection error: {}", e);
484                    }
485                });
486                let client = Arc::new(client);
487                Ok(Self {
488                    client: client.clone(),
489                    #(#accessor_inits),*
490                })
491            }
492            pub fn client(&self) -> &PgClient { &self.client }
493        }
494    }
495}