byteorm_lib/
rustgen.rs

1use quote::{quote, format_ident};
2use proc_macro2::TokenStream;
3use crate::{Schema, Model, Field, Modifier};
4use std::fs;
5use std::collections::HashMap;
6
7pub fn generate_rust_code(schema: &Schema) -> String {
8    let mut jsonb_defaults = HashMap::new();
9    for model in &schema.models {
10        for field in &model.fields {
11            if let Some(path) = field.get_jsonb_default_path() {
12                match fs::read_to_string(&path) {
13                    Ok(content) => {
14                        jsonb_defaults.insert((model.name.clone(), field.name.clone()), content);
15                    }
16                    Err(e) => {
17                        eprintln!("Warning: Could not read default file '{}': {}", path, e);
18                    }
19                }
20            }
21        }
22    }
23
24    let structs_and_impls = schema.models.iter().map(|model| {
25        generate_model_with_query_builder(model)
26    });
27
28    let client_struct = generate_client_struct(schema, &jsonb_defaults);
29    let jsonb_ext = generate_jsonb_ext();
30
31    let code = quote! {
32        use serde::{Deserialize, Serialize};
33        use chrono::{DateTime, Utc};
34        use tokio_postgres::{Client as PgClient, NoTls, Error};
35        use std::sync::Arc;
36        use once_cell::sync::Lazy;
37
38        fn calculate_json_diff(before: &serde_json::Value, after: &serde_json::Value) -> serde_json::Value {
39            let mut diff = serde_json::Map::new();
40            if let (Some(before_obj), Some(after_obj)) = (before.as_object(), after.as_object()) {
41                for (key, after_val) in after_obj {
42                    if let Some(before_val) = before_obj.get(key) {
43                        if before_val != after_val {
44                            diff.insert(
45                                key.clone(),
46                                serde_json::json!({ "from": before_val, "to": after_val })
47                            );
48                        }
49                    } else {
50                        diff.insert(key.clone(), serde_json::json!({ "added": after_val }));
51                    }
52                }
53                for (key, before_val) in before_obj {
54                    if !after_obj.contains_key(key) {
55                        diff.insert(key.clone(), serde_json::json!({ "removed": before_val }));
56                    }
57                }
58            }
59            serde_json::Value::Object(diff)
60        }
61
62        #jsonb_ext
63        #client_struct
64        #(#structs_and_impls)*
65    };
66
67    let file: syn::File = syn::parse2(code).unwrap();
68    prettyplease::unparse(&file)
69}
70
71fn generate_jsonb_ext() -> TokenStream {
72    quote! {
73        /// Extension trait for easier JSONB field access
74        pub trait JsonbExt {
75            fn get_value<T>(&self, key: &str) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
76            where
77                T: serde::de::DeserializeOwned;
78
79            fn get_string(&self, key: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>>;
80
81            fn get_i64(&self, key: &str) -> Result<i64, Box<dyn std::error::Error + Send + Sync>>;
82
83            fn get_bool(&self, key: &str) -> Result<bool, Box<dyn std::error::Error + Send + Sync>>;
84
85            fn get_or_default<T>(&self, key: &str, default: T) -> T
86            where
87                T: serde::de::DeserializeOwned;
88
89            fn has_key(&self, key: &str) -> bool;
90        }
91
92        impl JsonbExt for serde_json::Value {
93            fn get_value<T>(&self, key: &str) -> Result<T, Box<dyn std::error::Error + Send + Sync>>
94            where
95                T: serde::de::DeserializeOwned,
96            {
97                let value = self.get(key)
98                    .ok_or_else(|| format!("Key '{}' not found", key))?;
99
100                serde_json::from_value(value.clone())
101                    .map_err(|e| format!("Failed to parse key '{}': {}", key, e).into())
102            }
103
104            fn get_string(&self, key: &str) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
105                match self.get(key) {
106                    Some(serde_json::Value::String(s)) => Ok(s.clone()),
107                    Some(v) => serde_json::from_value(v.clone())
108                        .map_err(|e| format!("Failed to parse '{}' as string: {}", key, e).into()),
109                    None => Err(format!("Key '{}' not found", key).into()),
110                }
111            }
112
113            fn get_i64(&self, key: &str) -> Result<i64, Box<dyn std::error::Error + Send + Sync>> {
114                match self.get(key) {
115                    Some(serde_json::Value::Number(n)) => n.as_i64()
116                        .ok_or_else(|| format!("Key '{}' is not a valid i64", key).into()),
117                    Some(v) => serde_json::from_value(v.clone())
118                        .map_err(|e| format!("Failed to parse '{}' as i64: {}", key, e).into()),
119                    None => Err(format!("Key '{}' not found", key).into()),
120                }
121            }
122
123            fn get_bool(&self, key: &str) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
124                match self.get(key) {
125                    Some(serde_json::Value::Bool(b)) => Ok(*b),
126                    Some(v) => serde_json::from_value(v.clone())
127                        .map_err(|e| format!("Failed to parse '{}' as bool: {}", key, e).into()),
128                    None => Err(format!("Key '{}' not found", key).into()),
129                }
130            }
131
132            fn get_or_default<T>(&self, key: &str, default: T) -> T
133            where
134                T: serde::de::DeserializeOwned,
135            {
136                self.get_value(key).unwrap_or(default)
137            }
138
139            fn has_key(&self, key: &str) -> bool {
140                self.get(key).is_some()
141            }
142        }
143    }
144}
145
146fn generate_client_struct(schema: &Schema, jsonb_defaults: &HashMap<(String, String), String>) -> TokenStream {
147    let model_accessors = schema.models.iter().map(|model| {
148        let accessor_name = format_ident!("{}", to_snake_case(&model.name));
149        let accessor_struct = format_ident!("{}Accessor", model.name);
150
151        quote! {
152            pub #accessor_name: #accessor_struct
153        }
154    });
155
156    let accessor_structs = schema.models.iter().map(|model| {
157        let model_name = format_ident!("{}", model.name);
158        let accessor_struct = format_ident!("{}Accessor", model.name);
159        let query_builder = format_ident!("{}Query", model.name);
160        let table_name = model.name.to_lowercase();
161
162        let pk_field = model.fields.iter()
163            .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));
164
165        let jsonb_fields: Vec<_> = model.fields.iter()
166            .filter(|f| f.type_name == "JsonB")
167            .collect();
168
169        let find_unique = if let Some(pk) = pk_field {
170            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
171            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
172
173            quote! {
174                pub async fn find_unique(&self, id: #pk_type)
175                    -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
176                {
177                    #model_name::find_by_id(&self.client, id).await
178                }
179            }
180        } else {
181            quote! {}
182        };
183
184        // Generate JSONB sub-accessors with default support
185        let jsonb_sub_accessors = if let Some(pk) = pk_field {
186            let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
187            let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
188            let pk_field_name = format_ident!("where_{}", to_snake_case(&pk.name));
189            let pk_col_name = to_snake_case(&pk.name);
190
191            jsonb_fields.iter().map(|jsonb| {
192                let jsonb_name = &jsonb.name;
193                let jsonb_snake = to_snake_case(jsonb_name);
194                let jsonb_field_ident = format_ident!("{}", jsonb_name);
195                let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(jsonb_name));
196                let defaults_const = format_ident!("{}_DEFAULTS", jsonb_snake.to_uppercase());
197
198                // Generate default JSON constant
199                let default_json_init = if let Some(json_content) = jsonb_defaults.get(&(model.name.clone(), jsonb.name.clone())) {
200                    quote! {
201                        static #defaults_const: Lazy<serde_json::Value> = Lazy::new(|| {
202                            serde_json::from_str(#json_content)
203                                .expect(&format!("Failed to parse default JSON for {}.{}", stringify!(#model_name), #jsonb_name))
204                        });
205                    }
206                } else {
207                    quote! {
208                        static #defaults_const: Lazy<serde_json::Value> = Lazy::new(|| {
209                            serde_json::json!({})
210                        });
211                    }
212                };
213
214                let doc_struct = format!("Accessor for the `{}` JSONB field with default fallback support", jsonb_name);
215                let doc_get = format!("Get a string value from `{}` by key (falls back to default if not found)", jsonb_name);
216                let doc_get_as = format!("Get a typed value from `{}` by key (falls back to default if not found)", jsonb_name);
217                let doc_get_or = format!("Get a value from `{}` with a runtime default fallback", jsonb_name);
218                let doc_has = format!("Check if key exists in `{}` (checks both DB and defaults)", jsonb_name);
219                let doc_set = format!("Set a value in `{}` by key (creates/updates the key)", jsonb_name);
220
221                quote! {
222                    #default_json_init
223
224                    #[doc = #doc_struct]
225                    #[derive(Clone)]
226                    pub struct #sub_accessor_struct {
227                        client: Arc<PgClient>,
228                    }
229
230                    impl std::fmt::Debug for #sub_accessor_struct {
231                        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
232                            f.debug_struct(stringify!(#sub_accessor_struct))
233                                .field("client", &"<PgClient>")
234                                .finish()
235                        }
236                    }
237
238                    impl #sub_accessor_struct {
239                        pub fn new(client: Arc<PgClient>) -> Self {
240                            Self { client }
241                        }
242
243                        #[doc = #doc_get]
244                        ///
245                        /// # Example
246                        /// ```
247                        /// let value = client.guild_settings.settings.get(guild_id, "settingsLang").await?;
248                        /// ```
249                        pub async fn get(&self, id: #pk_type, key: &str)
250                            -> Result<String, Box<dyn std::error::Error + Send + Sync>>
251                        {
252
253                            match #query_builder::new()
254                                .#pk_field_name(id)
255                                .first(&self.client)
256                                .await
257                            {
258                                Ok(Some(record)) => {
259
260                                    match record.#jsonb_field_ident.get_string(key) {
261                                        Ok(value) => Ok(value),
262                                        Err(_) => {
263
264                                            #defaults_const.get_string(key)
265                                        }
266                                    }
267                                },
268                                Ok(None) => {
269
270                                    #defaults_const.get_string(key)
271                                },
272                                Err(e) => Err(e),
273                            }
274                        }
275
276                        #[doc = #doc_get_as]
277                        ///
278                        /// # Example
279                        /// ```
280                        /// let count: i64 = client.guild_settings.settings.get_as(guild_id, "messageCount").await?;
281                        /// let tags: Vec<String> = client.guild_settings.settings.get_as(guild_id, "tags").await?;
282                        /// ```
283                        pub async fn get_as<T>(&self, id: #pk_type, key: &str)
284                            -> Result<T, Box<dyn std::error::Error + Send + Sync>>
285                        where
286                            T: serde::de::DeserializeOwned,
287                        {
288                            match #query_builder::new()
289                                .#pk_field_name(id)
290                                .first(&self.client)
291                                .await
292                            {
293                                Ok(Some(record)) => {
294                                    match record.#jsonb_field_ident.get_value(key) {
295                                        Ok(value) => Ok(value),
296                                        Err(_) => #defaults_const.get_value(key),
297                                    }
298                                },
299                                Ok(None) => #defaults_const.get_value(key),
300                                Err(e) => Err(e),
301                            }
302                        }
303
304                        #[doc = #doc_get_or]
305                        ///
306                        /// # Example
307                        /// ```
308                        /// let prefix = client.guild_settings.settings.get_or(guild_id, "prefix", "!".to_string()).await?;
309                        /// ```
310                        pub async fn get_or<T>(&self, id: #pk_type, key: &str, default: T)
311                            -> Result<T, Box<dyn std::error::Error + Send + Sync>>
312                        where
313                            T: serde::de::DeserializeOwned,
314                        {
315                            match self.get_as(id, key).await {
316                                Ok(value) => Ok(value),
317                                Err(_) => Ok(default),
318                            }
319                        }
320
321                        #[doc = #doc_has]
322                        ///
323                        /// # Example
324                        /// ```
325                        /// if client.guild_settings.settings.has(guild_id, "premium").await? {
326                        ///     // Premium is configured
327                        /// }
328                        /// ```
329                        pub async fn has(&self, id: #pk_type, key: &str)
330                            -> Result<bool, Box<dyn std::error::Error + Send + Sync>>
331                        {
332                            match #query_builder::new()
333                                .#pk_field_name(id)
334                                .first(&self.client)
335                                .await
336                            {
337                                Ok(Some(record)) => Ok(record.#jsonb_field_ident.has_key(key) || #defaults_const.has_key(key)),
338                                Ok(None) => Ok(#defaults_const.has_key(key)),
339                                Err(e) => Err(e),
340                            }
341                        }
342
343                        #[doc = #doc_set]
344                        ///
345                        /// # Example
346                        /// ```
347                        /// client.guild_settings.settings.set(guild_id, "settingsLang", "en").await?;
348                        /// ```
349                        pub async fn set<T>(&self, id: #pk_type, key: &str, value: T)
350                            -> Result<(), Box<dyn std::error::Error + Send + Sync>>
351                        where
352                            T: serde::Serialize + Send + Sync,
353                        {
354                            let value_json = serde_json::to_value(&value)?;
355                            let value_str = value_json.to_string();
356
357                            let sql = format!(
358                                "INSERT INTO {} ({}, {}, updated_at) VALUES ($1, jsonb_build_object($2, $3), NOW()) \
359                                 ON CONFLICT ({}) DO UPDATE SET {} = jsonb_set(COALESCE({}.{}, '{{}}'::jsonb), $4, $5, true), updated_at = NOW()",
360                                #table_name,
361                                #pk_col_name,
362                                #jsonb_snake,
363                                #pk_col_name,
364                                #jsonb_snake,
365                                #table_name,
366                                #jsonb_snake
367                            );
368
369                            let key_path = format!("{{{}}}", key);
370                            self.client.execute(
371                                &sql,
372                                &[&id, &key, &value_str, &key_path, &value_str]
373                            ).await?;
374
375                            Ok(())
376                        }
377                    }
378                }
379            }).collect::<Vec<_>>()
380        } else {
381            vec![]
382        };
383
384        // Fields for JSONB sub-accessors in main accessor
385        let jsonb_accessor_fields = jsonb_fields.iter().map(|jsonb| {
386            let jsonb_snake = to_snake_case(&jsonb.name);
387            let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(&jsonb.name));
388            let sub_accessor_field = format_ident!("{}", jsonb_snake);
389
390            quote! {
391                pub #sub_accessor_field: #sub_accessor_struct
392            }
393        });
394
395        // Initialize JSONB sub-accessors
396        let jsonb_accessor_inits = jsonb_fields.iter().map(|jsonb| {
397            let jsonb_snake = to_snake_case(&jsonb.name);
398            let sub_accessor_struct = format_ident!("{}{}Accessor", model.name, capitalize_first(&jsonb.name));
399            let sub_accessor_field = format_ident!("{}", jsonb_snake);
400
401            quote! {
402                #sub_accessor_field: #sub_accessor_struct::new(client.clone())
403            }
404        });
405
406        let jsonb_debug_fields = jsonb_fields.iter().map(|jsonb| {
407            let jsonb_snake = to_snake_case(&jsonb.name);
408            let sub_accessor_field = format_ident!("{}", jsonb_snake);
409
410            quote! {
411                .field(stringify!(#sub_accessor_field), &self.#sub_accessor_field)
412            }
413        });
414
415        quote! {
416            #(#jsonb_sub_accessors)*
417
418            #[derive(Clone)]
419            pub struct #accessor_struct {
420                client: Arc<PgClient>,
421                #(#jsonb_accessor_fields),*
422            }
423
424            impl std::fmt::Debug for #accessor_struct {
425                fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
426                    f.debug_struct(stringify!(#accessor_struct))
427                        .field("client", &"<PgClient>")
428                        #(#jsonb_debug_fields)*
429                        .finish()
430                }
431            }
432
433            impl #accessor_struct {
434                pub fn new(client: Arc<PgClient>) -> Self {
435                    Self {
436                        client: client.clone(),
437                        #(#jsonb_accessor_inits),*
438                    }
439                }
440
441                pub fn find_many(&self) -> #query_builder {
442                    #query_builder::new()
443                }
444
445                #find_unique
446
447                pub async fn find_first(&self) -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>> {
448                    #query_builder::new().first(&self.client).await
449                }
450
451                pub async fn count(&self) -> Result<i64, Box<dyn std::error::Error + Send + Sync>> {
452                    #query_builder::new().count(&self.client).await
453                }
454
455                pub fn client(&self) -> &PgClient {
456                    &self.client
457                }
458            }
459        }
460    });
461
462    let accessor_inits = schema.models.iter().map(|model| {
463        let accessor_name = format_ident!("{}", to_snake_case(&model.name));
464        let accessor_struct = format_ident!("{}Accessor", model.name);
465
466        quote! {
467            #accessor_name: #accessor_struct::new(client.clone())
468        }
469    });
470
471    let debug_accessor_fields = schema.models.iter().map(|model| {
472        let accessor_name = to_snake_case(&model.name);
473        let accessor_name_ident = format_ident!("{}", accessor_name);
474
475        quote! {
476            .field(#accessor_name, &self.#accessor_name_ident)
477        }
478    });
479
480    quote! {
481        #(#accessor_structs)*
482
483        pub struct Client {
484            client: Arc<PgClient>,
485            #(#model_accessors),*
486        }
487
488        impl std::fmt::Debug for Client {
489            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
490                f.debug_struct("Client")
491                    .field("client", &"<PgClient>")
492                    #(#debug_accessor_fields)*
493                    .finish()
494            }
495        }
496
497        impl Client {
498            pub async fn new(connection_string: &str) -> Result<Self, Error> {
499                let (client, connection) = tokio_postgres::connect(connection_string, NoTls).await?;
500
501                tokio::spawn(async move {
502                    if let Err(e) = connection.await {
503                        eprintln!("connection error: {}", e);
504                    }
505                });
506
507                let client = Arc::new(client);
508
509                Ok(Self {
510                    client: client.clone(),
511                    #(#accessor_inits),*
512                })
513            }
514
515            pub fn client(&self) -> &PgClient {
516                &self.client
517            }
518        }
519    }
520}
521
522fn generate_model_with_query_builder(model: &Model) -> TokenStream {
523    let model_struct = generate_model_struct(model);
524    let query_builder_struct = generate_query_builder_struct(model);
525    let query_builder_impl = generate_query_builder_impl(model);
526    let model_impl = generate_model_impl(model);
527
528    quote! {
529        #model_struct
530        #query_builder_struct
531        #model_impl
532        #query_builder_impl
533    }
534}
535
536fn generate_model_struct(model: &Model) -> TokenStream {
537    let name = format_ident!("{}", model.name);
538    let fields = model.fields.iter().map(|field| {
539        let field_name = format_ident!("{}", field.name);
540        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
541        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
542
543        quote! {
544            pub #field_name: #field_type
545        }
546    });
547
548    quote! {
549        #[derive(Debug, Clone, Serialize, Deserialize)]
550        pub struct #name {
551            #(#fields),*
552        }
553    }
554}
555
556fn generate_model_impl(model: &Model) -> TokenStream {
557    let model_name = format_ident!("{}", model.name);
558    let builder_name = format_ident!("{}Query", model.name);
559
560    let pk_field = model.fields.iter()
561        .find(|f| f.modifiers.iter().any(|m| matches!(m, Modifier::PrimaryKey)));
562
563    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
564        let field_name = format_ident!("{}", field.name);
565        quote! { #field_name: row.get(#idx) }
566    });
567
568    let find_by_id_impl = if let Some(pk) = pk_field {
569        let is_nullable = pk.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
570        let pk_type = rust_type_from_schema(&pk.type_name, is_nullable);
571
572        let pk_name = to_snake_case(&pk.name);
573
574        quote! {
575            pub async fn find_by_id(client: &PgClient, id: #pk_type)
576                -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
577            {
578                let sql = format!("SELECT * FROM {} WHERE {} = $1", stringify!(#model_name).to_lowercase(), #pk_name);
579                let row_opt = client.query_opt(&sql, &[&id]).await?;
580                Ok(row_opt.map(|row| #model_name {
581                    #(#field_gets),*
582                }))
583            }
584        }
585    } else {
586        quote! {}
587    };
588
589    quote! {
590        impl #model_name {
591            pub fn query() -> #builder_name {
592                #builder_name::new()
593            }
594            #find_by_id_impl
595        }
596    }
597}
598
599fn generate_query_builder_struct(model: &Model) -> TokenStream {
600    let builder_name = format_ident!("{}Query", model.name);
601
602    quote! {
603        pub struct #builder_name {
604            table: String,
605            where_fragments: Vec<(&'static str, usize)>,
606            args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
607            limit: Option<usize>,
608            offset: Option<usize>,
609            order_by: Vec<(String, String)>,
610        }
611
612        // Safety: All types that implement ToSql for common Rust types (i64, String, etc.) are Send
613        unsafe impl Send for #builder_name {}
614
615        impl Clone for #builder_name {
616            fn clone(&self) -> Self {
617                Self {
618                    table: self.table.clone(),
619                    where_fragments: self.where_fragments.clone(),
620                    args: Vec::new(),
621                    limit: self.limit,
622                    offset: self.offset,
623                    order_by: self.order_by.clone(),
624                }
625            }
626        }
627    }
628}
629
630fn generate_query_builder_impl(model: &Model) -> TokenStream {
631    let builder_name = format_ident!("{}Query", model.name);
632    let model_name = format_ident!("{}", model.name);
633    let table_name = model.name.to_lowercase();
634
635    let field_methods = model.fields.iter().map(|field| {
636        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
637        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
638        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
639        let field_col = to_snake_case(&field.name);
640
641        quote! {
642            pub fn #method_name(mut self, value: #field_type) -> Self {
643                self.args.push(Box::new(value));
644                self.where_fragments.push((#field_col, self.args.len()));
645                self
646            }
647        }
648    });
649
650    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
651        let field_name = format_ident!("{}", field.name);
652        quote! { #field_name: row.get(#idx) }
653    });
654
655    quote! {
656        impl #builder_name {
657            pub fn new() -> Self {
658                Self {
659                    table: #table_name.to_string(),
660                    where_fragments: vec![],
661                    args: vec![],
662                    limit: None,
663                    offset: None,
664                    order_by: vec![],
665                }
666            }
667            #(#field_methods)*
668            pub fn limit(mut self, limit: usize) -> Self {
669                self.limit = Some(limit);
670                self
671            }
672            pub fn offset(mut self, offset: usize) -> Self {
673                self.offset = Some(offset);
674                self
675            }
676            pub fn order_by(mut self, column: &str, direction: &str) -> Self {
677                self.order_by.push((column.to_string(), direction.to_string()));
678                self
679            }
680
681            pub async fn select(&self, client: &PgClient)
682                -> Result<Vec<#model_name>, Box<dyn std::error::Error + Send + Sync>>
683            {
684                let (sql, params) = self.build_select();
685                let rows = client.query(&sql, &params[..]).await?;
686                let mut results = Vec::new();
687                for row in rows {
688                    results.push(#model_name { #(#field_gets),* });
689                }
690                Ok(results)
691            }
692
693            pub async fn first(&self, client: &PgClient)
694                -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
695            {
696                let mut query = #builder_name::new();
697                query.table = self.table.clone();
698                query.where_fragments = self.where_fragments.clone();
699                query.args = Vec::new();
700                query.limit = Some(1);
701                query.offset = self.offset;
702                query.order_by = self.order_by.clone();
703
704                let results = query.select(client).await?;
705                Ok(results.into_iter().next())
706            }
707
708            pub async fn count(&self, client: &PgClient)
709                -> Result<i64, Box<dyn std::error::Error + Send + Sync>>
710            {
711                let (sql, params) = self.build_count();
712                let row = client.query_one(&sql, &params[..]).await?;
713                Ok(row.get(0))
714            }
715
716            fn build_select(&self) -> (String, Vec<&(dyn tokio_postgres::types::ToSql + Sync)>) {
717                let mut sql = format!("SELECT * FROM {}", self.table);
718                let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
719
720                if !self.where_fragments.is_empty() {
721                    let conds: Vec<String> = self.where_fragments.iter()
722                        .enumerate()
723                        .map(|(i, &(col, _idx))| format!("{} = ${}", col, i + 1))
724                        .collect();
725                    sql.push_str(" WHERE ");
726                    sql.push_str(&conds.join(" AND "));
727                    for arg in &self.args {
728                        params.push(arg.as_ref());
729                    }
730                }
731                if !self.order_by.is_empty() {
732                    sql.push_str(" ORDER BY ");
733                    let order_clauses: Vec<String> = self.order_by.iter()
734                        .map(|(col, dir)| format!("{} {}", col, dir))
735                        .collect();
736                    sql.push_str(&order_clauses.join(", "));
737                }
738                if let Some(limit) = self.limit {
739                    sql.push_str(&format!(" LIMIT {}", limit));
740                }
741                if let Some(offset) = self.offset {
742                    sql.push_str(&format!(" OFFSET {}", offset));
743                }
744                (sql, params)
745            }
746            fn build_count(&self) -> (String, Vec<&(dyn tokio_postgres::types::ToSql + Sync)>) {
747                let mut sql = format!("SELECT COUNT(*) FROM {}", self.table);
748                let mut params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> = vec![];
749                if !self.where_fragments.is_empty() {
750                    let conds: Vec<String> = self.where_fragments.iter()
751                        .enumerate()
752                        .map(|(i, &(col, _idx))| format!("{} = ${}", col, i + 1))
753                        .collect();
754                    sql.push_str(" WHERE ");
755                    sql.push_str(&conds.join(" AND "));
756                    for arg in &self.args {
757                        params.push(arg.as_ref());
758                    }
759                }
760                (sql, params)
761            }
762        }
763    }
764}
765
766fn rust_type_from_schema(type_name: &str, nullable: bool) -> TokenStream {
767    let base_type = match type_name {
768        "BigInt" => quote! { i64 },
769        "Int" => quote! { i32 },
770        "String" => quote! { String },
771        "JsonB" => quote! { serde_json::Value },
772        "TimestamptZ" | "Timestamp" => quote! { DateTime<Utc> },
773        "Boolean" => quote! { bool },
774        "Float" => quote! { f64 },
775        "Serial" => quote! { i32 },
776        "Real" => quote! { f32 },
777        _ => quote! { String },
778    };
779
780    if nullable {
781        quote! { Option<#base_type> }
782    } else {
783        base_type
784    }
785}
786
/// Converts a `camelCase` / `PascalCase` identifier to `snake_case`.
///
/// An underscore is inserted before every uppercase character except a leading one, so
/// runs of capitals are split letter by letter (`"ABC"` becomes `"a_b_c"`). Only the first
/// character of each character's lowercase mapping is kept.
fn to_snake_case(s: &str) -> String {
    s.chars()
        .enumerate()
        .fold(String::with_capacity(s.len()), |mut out, (pos, c)| {
            if pos > 0 && c.is_uppercase() {
                out.push('_');
            }
            out.push(c.to_lowercase().next().unwrap_or(c));
            out
        })
}
797
/// Returns `s` with its first character uppercased; the rest is copied unchanged.
///
/// The uppercase form may expand to several characters (e.g. `ß` becomes `SS`). An
/// empty input yields an empty string.
fn capitalize_first(s: &str) -> String {
    let mut rest = s.chars();
    rest.next()
        .map(|head| head.to_uppercase().chain(rest).collect::<String>())
        .unwrap_or_default()
}