// byteorm_lib/rustgen/query.rs

1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{rust_type_from_schema, to_snake_case};
5
6pub fn generate_query_builder_struct(model: &Model) -> TokenStream {
7    let model_name = format_ident!("{}", model.name);
8    let builder_name = format_ident!("{}Query", model.name);
9    let table_name = model.name.to_lowercase();
10
11    let where_methods = model.fields.iter().map(|field| {
12        let method_name = format_ident!("where_{}", to_snake_case(&field.name));
13        let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
14        let field_type = rust_type_from_schema(&field.type_name, is_nullable);
15        let field_col = to_snake_case(&field.name);
16
17        quote! {
18            pub fn #method_name(mut self, value: #field_type) -> Self {
19                self.args.push(Box::new(value));
20                self.where_fragments.push((#field_col, self.args.len()));
21                self
22            }
23        }
24    });
25
26    let order_by_methods = model.fields.iter().map(|field| {
27        let asc_method = format_ident!("order_by_{}_asc", to_snake_case(&field.name));
28        let desc_method = format_ident!("order_by_{}_desc", to_snake_case(&field.name));
29        let field_col = to_snake_case(&field.name);
30
31        quote! {
32            pub fn #asc_method(mut self) -> Self {
33                self.order_by.push((#field_col.to_string(), "ASC".to_string()));
34                self
35            }
36            pub fn #desc_method(mut self) -> Self {
37                self.order_by.push((#field_col.to_string(), "DESC".to_string()));
38                self
39            }
40        }
41    });
42
43    let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
44        let field_name = format_ident!("{}", field.name);
45        quote! { #field_name: row.get(#idx) }
46    });
47
48    quote! {
49        pub struct #builder_name {
50            table: String,
51            where_fragments: Vec<(&'static str, usize)>,
52            args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
53            limit: Option<usize>,
54            offset: Option<usize>,
55            order_by: Vec<(String, String)>,
56        }
57
58        unsafe impl Send for #builder_name {}
59
60        impl Clone for #builder_name {
61            fn clone(&self) -> Self {
62                Self {
63                    table: self.table.clone(),
64                    where_fragments: self.where_fragments.clone(),
65                    args: Vec::new(),
66                    limit: self.limit,
67                    offset: self.offset,
68                    order_by: self.order_by.clone(),
69                }
70            }
71        }
72
73        impl #builder_name {
74            pub fn new() -> Self {
75                Self {
76                    table: #table_name.to_string(),
77                    where_fragments: vec![],
78                    args: vec![],
79                    limit: None,
80                    offset: None,
81                    order_by: vec![],
82                }
83            }
84
85            #(#where_methods)*
86            #(#order_by_methods)*
87
88            pub fn limit(mut self, limit: usize) -> Self {
89                self.limit = Some(limit);
90                self
91            }
92
93            pub fn offset(mut self, offset: usize) -> Self {
94                self.offset = Some(offset);
95                self
96            }
97
98            pub async fn select(self, client: &PgClient)
99                -> Result<Vec<#model_name>, Box<dyn std::error::Error + Send + Sync>>
100            {
101                let mut sql = format!("SELECT * FROM {}", self.table);
102
103                let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
104                    self.args.iter().map(|b| b.as_ref()).collect();
105
106                if !self.where_fragments.is_empty() {
107                    let where_clauses: Vec<String> = self.where_fragments.iter()
108                        .map(|&(col, idx)| format!("{} = ${}", col, idx))
109                        .collect();
110                    sql.push_str(" WHERE ");
111                    sql.push_str(&where_clauses.join(" AND "));
112                }
113
114                if !self.order_by.is_empty() {
115                    let order_clauses: Vec<String> = self.order_by.iter()
116                        .map(|(col, dir)| format!("{} {}", col, dir))
117                        .collect();
118                    sql.push_str(" ORDER BY ");
119                    sql.push_str(&order_clauses.join(", "));
120                }
121
122                if let Some(limit) = self.limit {
123                    sql.push_str(&format!(" LIMIT {}", limit));
124                }
125
126                if let Some(offset) = self.offset {
127                    sql.push_str(&format!(" OFFSET {}", offset));
128                }
129
130                let rows = client.query(&sql, &params[..]).await?;
131                Ok(rows.into_iter().map(|row| #model_name {
132                    #(#field_gets),*
133                }).collect())
134            }
135
136            pub async fn first(self, client: &PgClient)
137                -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
138            {
139                let result = self.limit(1).select(client).await?;
140                Ok(result.into_iter().next())
141            }
142
143            pub async fn count(self, client: &PgClient)
144                -> Result<i64, Box<dyn std::error::Error + Send + Sync>>
145            {
146                let mut sql = format!("SELECT COUNT(*) FROM {}", self.table);
147
148                let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
149                    self.args.iter().map(|b| b.as_ref()).collect();
150
151                if !self.where_fragments.is_empty() {
152                    let where_clauses: Vec<String> = self.where_fragments.iter()
153                        .map(|&(col, idx)| format!("{} = ${}", col, idx))
154                        .collect();
155                    sql.push_str(" WHERE ");
156                    sql.push_str(&where_clauses.join(" AND "));
157                }
158
159                let row = client.query_one(&sql, &params[..]).await?;
160                Ok(row.get(0))
161            }
162        }
163    }
164}