byteorm_lib/rustgen/
query.rs1use proc_macro2::TokenStream;
2use quote::{format_ident, quote};
3use crate::{Model, Modifier};
4use crate::rustgen::{rust_type_from_schema, to_snake_case};
5
6pub fn generate_query_builder_struct(model: &Model) -> TokenStream {
7 let model_name = format_ident!("{}", model.name);
8 let builder_name = format_ident!("{}Query", model.name);
9 let table_name = model.name.to_lowercase();
10
11 let where_methods = model.fields.iter().map(|field| {
12 let method_name = format_ident!("where_{}", to_snake_case(&field.name));
13 let is_nullable = field.modifiers.iter().any(|m| matches!(m, Modifier::Nullable));
14 let field_type = rust_type_from_schema(&field.type_name, is_nullable);
15 let field_col = to_snake_case(&field.name);
16
17 quote! {
18 pub fn #method_name(mut self, value: #field_type) -> Self {
19 self.args.push(Box::new(value));
20 self.where_fragments.push((#field_col, self.args.len()));
21 self
22 }
23 }
24 });
25
26 let order_by_methods = model.fields.iter().map(|field| {
27 let asc_method = format_ident!("order_by_{}_asc", to_snake_case(&field.name));
28 let desc_method = format_ident!("order_by_{}_desc", to_snake_case(&field.name));
29 let field_col = to_snake_case(&field.name);
30
31 quote! {
32 pub fn #asc_method(mut self) -> Self {
33 self.order_by.push((#field_col.to_string(), "ASC".to_string()));
34 self
35 }
36 pub fn #desc_method(mut self) -> Self {
37 self.order_by.push((#field_col.to_string(), "DESC".to_string()));
38 self
39 }
40 }
41 });
42
43 let field_gets = model.fields.iter().enumerate().map(|(idx, field)| {
44 let field_name = format_ident!("{}", field.name);
45 quote! { #field_name: row.get(#idx) }
46 });
47
48 quote! {
49 pub struct #builder_name {
50 table: String,
51 where_fragments: Vec<(&'static str, usize)>,
52 args: Vec<Box<dyn tokio_postgres::types::ToSql + Sync>>,
53 limit: Option<usize>,
54 offset: Option<usize>,
55 order_by: Vec<(String, String)>,
56 }
57
58 unsafe impl Send for #builder_name {}
59
60 impl Clone for #builder_name {
61 fn clone(&self) -> Self {
62 Self {
63 table: self.table.clone(),
64 where_fragments: self.where_fragments.clone(),
65 args: Vec::new(),
66 limit: self.limit,
67 offset: self.offset,
68 order_by: self.order_by.clone(),
69 }
70 }
71 }
72
73 impl #builder_name {
74 pub fn new() -> Self {
75 Self {
76 table: #table_name.to_string(),
77 where_fragments: vec![],
78 args: vec![],
79 limit: None,
80 offset: None,
81 order_by: vec![],
82 }
83 }
84
85 #(#where_methods)*
86 #(#order_by_methods)*
87
88 pub fn limit(mut self, limit: usize) -> Self {
89 self.limit = Some(limit);
90 self
91 }
92
93 pub fn offset(mut self, offset: usize) -> Self {
94 self.offset = Some(offset);
95 self
96 }
97
98 pub async fn select(self, client: &PgClient)
99 -> Result<Vec<#model_name>, Box<dyn std::error::Error + Send + Sync>>
100 {
101 let mut sql = format!("SELECT * FROM {}", self.table);
102
103 let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
104 self.args.iter().map(|b| b.as_ref()).collect();
105
106 if !self.where_fragments.is_empty() {
107 let where_clauses: Vec<String> = self.where_fragments.iter()
108 .map(|&(col, idx)| format!("{} = ${}", col, idx))
109 .collect();
110 sql.push_str(" WHERE ");
111 sql.push_str(&where_clauses.join(" AND "));
112 }
113
114 if !self.order_by.is_empty() {
115 let order_clauses: Vec<String> = self.order_by.iter()
116 .map(|(col, dir)| format!("{} {}", col, dir))
117 .collect();
118 sql.push_str(" ORDER BY ");
119 sql.push_str(&order_clauses.join(", "));
120 }
121
122 if let Some(limit) = self.limit {
123 sql.push_str(&format!(" LIMIT {}", limit));
124 }
125
126 if let Some(offset) = self.offset {
127 sql.push_str(&format!(" OFFSET {}", offset));
128 }
129
130 let rows = client.query(&sql, ¶ms[..]).await?;
131 Ok(rows.into_iter().map(|row| #model_name {
132 #(#field_gets),*
133 }).collect())
134 }
135
136 pub async fn first(self, client: &PgClient)
137 -> Result<Option<#model_name>, Box<dyn std::error::Error + Send + Sync>>
138 {
139 let result = self.limit(1).select(client).await?;
140 Ok(result.into_iter().next())
141 }
142
143 pub async fn count(self, client: &PgClient)
144 -> Result<i64, Box<dyn std::error::Error + Send + Sync>>
145 {
146 let mut sql = format!("SELECT COUNT(*) FROM {}", self.table);
147
148 let params: Vec<&(dyn tokio_postgres::types::ToSql + Sync)> =
149 self.args.iter().map(|b| b.as_ref()).collect();
150
151 if !self.where_fragments.is_empty() {
152 let where_clauses: Vec<String> = self.where_fragments.iter()
153 .map(|&(col, idx)| format!("{} = ${}", col, idx))
154 .collect();
155 sql.push_str(" WHERE ");
156 sql.push_str(&where_clauses.join(" AND "));
157 }
158
159 let row = client.query_one(&sql, ¶ms[..]).await?;
160 Ok(row.get(0))
161 }
162 }
163 }
164}