1use crate::{
2 database::{Database, Drivers},
3 model::{ColumnInfo, Model},
4 Error,
5};
6use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
7use heck::ToSnakeCase;
8use sqlx::{
9 any::{AnyArguments, AnyRow},
10 Any, Arguments, Encode, FromRow, Type,
11};
12use std::marker::PhantomData;
13
14pub type FilterFn = Box<dyn Fn(&mut String, &mut AnyArguments<'_>, &Drivers, &mut usize) + Send + Sync>;
22
23pub struct QueryBuilder<'a, T> {
27 pub(crate) db: &'a Database,
28 pub(crate) table_name: &'static str,
29 pub(crate) columns_info: Vec<ColumnInfo>,
30 pub(crate) columns: Vec<String>,
31 pub(crate) select_columns: Vec<String>,
32 pub(crate) where_clauses: Vec<FilterFn>,
33 pub(crate) order_clauses: Vec<String>,
34 pub(crate) limit: Option<usize>,
35 pub(crate) offset: Option<usize>,
36 pub(crate) _marker: PhantomData<T>,
37}
38
39impl<'a, T: Model + Send + Sync + Unpin> QueryBuilder<'a, T> {
40 pub fn new(
44 db: &'a Database,
45 table_name: &'static str,
46 columns_info: Vec<ColumnInfo>,
47 columns: Vec<String>,
48 ) -> Self {
49 Self {
50 db,
51 table_name,
52 columns_info,
53 columns,
54 select_columns: Vec::new(),
55 where_clauses: Vec::new(),
56 order_clauses: Vec::new(),
57 limit: None,
58 offset: None,
59 _marker: PhantomData,
60 }
61 }
62
63 pub fn filter<V>(mut self, col: &'static str, op: &'static str, value: V) -> Self
77 where
78 V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
79 {
80 let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
81 query.push_str(" AND \"");
82 query.push_str(col);
83 query.push_str("\" ");
84 query.push_str(op);
85 query.push(' ');
86
87 match driver {
88 Drivers::Postgres => {
89 query.push_str(&format!("${}", arg_counter));
90 *arg_counter += 1;
91 }
92 _ => query.push('?'),
93 }
94
95 args.add(value.clone());
96 });
97 self.where_clauses.push(clause);
98 self
99 }
100
101 pub fn order(mut self, order: &str) -> Self {
102 self.order_clauses.push(order.to_string());
103 self
104 }
105
106 pub fn preload(mut self) -> Self {
107 self
108 }
109
110 pub fn join(mut self) -> Self {
111 self
112 }
113
114 pub fn pagination(
115 mut self,
116 max_value: usize,
117 default: usize,
118 page: usize,
119 value: isize,
120 ) -> Result<Self, Error> {
121 if value < 0 {
122 return Err(Error::InvalidArgument("value cannot be negative".into()));
123 }
124 let mut f_value = value as usize;
125
126 if f_value > max_value {
127 f_value = default;
128 }
129 self = self.offset(f_value * page);
130 self = self.limit(f_value);
131 Ok(self)
132 }
133
134 pub fn select(mut self, columns: &str) -> Self {
138 self.select_columns.push(columns.to_string());
139 self
140 }
141
142 pub fn offset(mut self, offset: usize) -> Self {
144 self.offset = Some(offset);
145 self
146 }
147
148 pub fn limit(mut self, limit: usize) -> Self {
150 self.limit = Some(limit);
151 self
152 }
153
154 pub async fn insert(&self, model: &T) -> Result<&Self, sqlx::Error> {
158 let data_map = model.to_map();
159
160 if data_map.is_empty() {
161 return Ok(&self);
162 }
163
164 let table_name = self.table_name.to_snake_case();
165 let columns_info = T::columns();
166
167 let mut target_columns = Vec::new();
168 let mut bindings: Vec<(String, &str)> = Vec::new();
169
170 for (col_name, value) in data_map {
171 let col_name_clean = col_name.strip_prefix("r#").unwrap_or(&col_name).to_snake_case();
172 target_columns.push(format!("\"{}\"", col_name_clean));
173
174 let sql_type = columns_info.iter().find(|c| c.name == col_name).map(|c| c.sql_type).unwrap_or("TEXT");
175
176 bindings.push((value, sql_type));
177 }
178
179 let placeholders: Vec<String> = bindings
180 .iter()
181 .enumerate()
182 .map(|(i, (_, sql_type))| match self.db.driver {
183 Drivers::Postgres => {
184 let idx = i + 1;
185 match *sql_type {
186 "TIMESTAMPTZ" | "DateTime" => format!("${}::TIMESTAMPTZ", idx),
187 "TIMESTAMP" | "NaiveDateTime" => format!("${}::TIMESTAMP", idx),
188 "DATE" | "NaiveDate" => format!("${}::DATE", idx),
189 "TIME" | "NaiveTime" => format!("${}::TIME", idx),
190 _ => format!("${}", idx),
191 }
192 }
193 _ => "?".to_string(),
194 })
195 .collect();
196
197 let query_str = format!(
198 "INSERT INTO \"{}\" ({}) VALUES ({})",
199 table_name,
200 target_columns.join(", "),
201 placeholders.join(", ")
202 );
203
204 let mut query = sqlx::query::<sqlx::Any>(&query_str);
206
207 for (val_str, sql_type) in bindings {
209 match sql_type {
210 "INTEGER" | "INT" | "SERIAL" | "serial" | "int4" => {
211 let val: i32 = val_str.parse().unwrap_or_default();
212 query = query.bind(val);
213 }
214 "BIGINT" | "INT8" | "int8" => {
215 let val: i64 = val_str.parse().unwrap_or_default();
216 query = query.bind(val);
217 }
218 "BOOLEAN" | "BOOL" | "bool" => {
219 let val: bool = val_str.parse().unwrap_or(false);
220 query = query.bind(val);
221 }
222 "DOUBLE PRECISION" | "FLOAT" | "float8" => {
223 let val: f64 = val_str.parse().unwrap_or_default();
224 query = query.bind(val);
225 }
226 "TIMESTAMP" | "NaiveDateTime" => {
227 if let Ok(val) = val_str.parse::<NaiveDateTime>() {
228 query = query.bind(val.to_string());
229 } else {
230 query = query.bind(val_str);
231 }
232 }
233 "TIMESTAMPTZ" | "DateTime" => {
234 if let Ok(val) = val_str.parse::<DateTime<Utc>>() {
235 query = query.bind(val.to_string());
236 } else {
237 query = query.bind(val_str);
238 }
239 }
240 "DATE" | "NaiveDate" => {
241 if let Ok(val) = val_str.parse::<NaiveDate>() {
242 query = query.bind(val.to_string());
243 } else {
244 query = query.bind(val_str);
245 }
246 }
247 "TIME" | "NaiveTime" => {
248 if let Ok(val) = val_str.parse::<NaiveTime>() {
249 query = query.bind(val.to_string());
250 } else {
251 query = query.bind(val_str);
252 }
253 }
254 _ => query = query.bind(val_str),
255 }
256 }
257
258 query.execute(&self.db.pool).await?;
259 Ok(&self)
260 }
261
262 pub fn to_sql(&self) -> String {
264 let mut query = String::from("SELECT ");
265 if self.select_columns.is_empty() {
266 query.push('*');
267 } else {
268 query.push_str(&self.select_columns.join(", "));
269 }
270 query.push_str(" FROM \"");
271 query.push_str(&self.table_name.to_snake_case());
272 query.push_str("\" WHERE 1=1");
273
274 let mut dummy_args = AnyArguments::default();
275 let mut dummy_counter = 1;
276
277 for clause in &self.where_clauses {
278 clause(&mut query, &mut dummy_args, &self.db.driver, &mut dummy_counter);
279 }
280
281 if !self.order_clauses.is_empty() {
282 query.push_str(&format!(" ORDER BY {}", &self.order_clauses.join(", ")));
283 }
284
285 query
286 }
287
288 pub async fn scan<R>(self) -> Result<Vec<R>, sqlx::Error>
296 where
297 R: for<'r> FromRow<'r, AnyRow> + Send + Unpin,
298 {
299 let mut query = String::from("SELECT ");
300 if self.select_columns.is_empty() {
301 query.push('*');
302 } else {
303 query.push_str(&self.select_columns.join(", "));
304 }
305 query.push_str(" FROM \"");
306 query.push_str(&self.table_name.to_snake_case());
307 query.push_str("\" WHERE 1=1");
308
309 let mut args = AnyArguments::default();
310 let mut arg_counter = 1;
311
312 for clause in &self.where_clauses {
313 clause(&mut query, &mut args, &self.db.driver, &mut arg_counter);
314 }
315
316 if let Some(limit) = self.limit {
317 query.push_str(" LIMIT ");
318 match self.db.driver {
319 Drivers::Postgres => {
320 query.push_str(&format!("${}", arg_counter));
321 arg_counter += 1;
322 }
323 _ => query.push('?'),
324 }
325 args.add(limit as i64);
326 }
327
328 if let Some(offset) = self.offset {
329 query.push_str(" OFFSET ");
330 match self.db.driver {
331 Drivers::Postgres => {
332 query.push_str(&format!("${}", arg_counter));
333 }
335 _ => query.push('?'),
336 }
337 args.add(offset as i64);
338 }
339
340 sqlx::query_as_with::<_, R, _>(&query, args).fetch_all(&self.db.pool).await
341 }
342
343 pub async fn first<R>(self) -> Result<R, sqlx::Error>
353 where
354 R: for<'r> FromRow<'r, AnyRow> + Send + Unpin,
355 {
356 let mut query = String::from("SELECT ");
357 if self.select_columns.is_empty() {
358 query.push('*');
359 } else {
360 query.push_str(&self.select_columns.join(", "));
361 }
362 query.push_str(" FROM \"");
363 query.push_str(&self.table_name.to_snake_case());
364 query.push_str("\" WHERE 1=1");
365
366 let mut args = AnyArguments::default();
367 let mut arg_counter = 1;
368
369 for clause in &self.where_clauses {
370 clause(&mut query, &mut args, &self.db.driver, &mut arg_counter);
371 }
372
373 let pk_column = T::columns()
374 .iter()
375 .find(|c| c.is_primary_key)
376 .map(|c| c.name.strip_prefix("r#").unwrap_or(c.name).to_snake_case());
377
378 if let Some(pk) = pk_column {
379 query.push_str(" ORDER BY \"");
380 query.push_str(&pk);
381 query.push_str("\" ASC");
382 }
383
384 query.push_str(" LIMIT 1");
385
386 sqlx::query_as_with::<_, R, _>(&query, args).fetch_one(&self.db.pool).await
387 }
388}