bottle_orm/
query_builder.rs

1use crate::{
2    database::{Database, Drivers},
3    model::{ColumnInfo, Model},
4    Error,
5};
6use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
7use heck::ToSnakeCase;
8use sqlx::{
9    any::{AnyArguments, AnyRow},
10    Any, Arguments, Encode, FromRow, Type,
11};
12use std::marker::PhantomData;
13
14/// A type alias for filter closures that support manual SQL construction and argument binding.
15///
16/// It receives:
17/// 1. `&mut String`: The SQL query buffer being built.
18/// 2. `&mut AnyArguments`: The argument container for binding values.
19/// 3. `&Drivers`: The current database driver (to decide between `$n` or `?`).
20/// 4. `&mut usize`: The argument counter (for PostgreSQL `$n` placeholders).
21pub type FilterFn = Box<dyn Fn(&mut String, &mut AnyArguments<'_>, &Drivers, &mut usize) + Send + Sync>;
22
/// A fluent Query Builder for constructing SQL queries.
///
/// Handles SELECT, INSERT, filtering (WHERE), pagination (LIMIT/OFFSET), and ordering.
///
/// Builder methods take `self` by value and return the updated builder, so calls are
/// chained and finished with a terminal method such as `scan`, `first`, or `insert`.
pub struct QueryBuilder<'a, T> {
    /// Database handle: `driver` selects the placeholder style and `pool` executes queries.
    pub(crate) db: &'a Database,
    /// Table name of the model; converted to snake_case when SQL is generated.
    pub(crate) table_name: &'static str,
    /// Column metadata supplied to `new`.
    pub(crate) columns_info: Vec<ColumnInfo>,
    /// Column names supplied to `new`.
    pub(crate) columns: Vec<String>,
    /// Explicit SELECT list fragments, joined with `", "`; when empty, `*` is selected.
    pub(crate) select_columns: Vec<String>,
    /// WHERE conditions; each closure appends an ` AND ...` fragment and binds its value.
    pub(crate) where_clauses: Vec<FilterFn>,
    /// ORDER BY expressions, inserted into the SQL verbatim and joined with `", "`.
    pub(crate) order_clauses: Vec<String>,
    /// Optional LIMIT, bound as a query argument when the query is executed.
    pub(crate) limit: Option<usize>,
    /// Optional OFFSET, bound as a query argument when the query is executed.
    pub(crate) offset: Option<usize>,
    /// Ties the builder to the model type `T` without storing a value of it.
    pub(crate) _marker: PhantomData<T>,
}
38
39impl<'a, T: Model + Send + Sync + Unpin> QueryBuilder<'a, T> {
40    /// Creates a new QueryBuilder instance.
41    ///
42    /// Usually called via `db.model::<T>()`.
43    pub fn new(
44        db: &'a Database,
45        table_name: &'static str,
46        columns_info: Vec<ColumnInfo>,
47        columns: Vec<String>,
48    ) -> Self {
49        Self {
50            db,
51            table_name,
52            columns_info,
53            columns,
54            select_columns: Vec::new(),
55            where_clauses: Vec::new(),
56            order_clauses: Vec::new(),
57            limit: None,
58            offset: None,
59            _marker: PhantomData,
60        }
61    }
62
63    /// Adds a WHERE clause to the query.
64    ///
65    /// # Arguments
66    ///
67    /// * `col` - The column name.
68    /// * `op` - The operator (e.g., "=", ">", "LIKE").
69    /// * `value` - The value to compare against.
70    ///
71    /// # Example
72    ///
73    /// ```rust,ignore
74    /// db.model::<User>().filter("age", ">", 18).scan().await?;
75    /// ```
76    pub fn filter<V>(mut self, col: &'static str, op: &'static str, value: V) -> Self
77    where
78        V: 'static + for<'q> Encode<'q, Any> + Type<Any> + Send + Sync + Clone,
79    {
80        let clause: FilterFn = Box::new(move |query, args, driver, arg_counter| {
81            query.push_str(" AND \"");
82            query.push_str(col);
83            query.push_str("\" ");
84            query.push_str(op);
85            query.push(' ');
86
87            match driver {
88                Drivers::Postgres => {
89                    query.push_str(&format!("${}", arg_counter));
90                    *arg_counter += 1;
91                }
92                _ => query.push('?'),
93            }
94
95            args.add(value.clone());
96        });
97        self.where_clauses.push(clause);
98        self
99    }
100
101    pub fn order(mut self, order: &str) -> Self {
102        self.order_clauses.push(order.to_string());
103        self
104    }
105    
106    pub fn preload(mut self) -> Self {
107    	self
108    }
109
110    pub fn join(mut self) -> Self {
111        self
112    }
113
114    pub fn pagination(
115        mut self,
116        max_value: usize,
117        default: usize,
118        page: usize,
119        value: isize,
120    ) -> Result<Self, Error> {
121        if value < 0 {
122            return Err(Error::InvalidArgument("value cannot be negative".into()));
123        }
124        let mut f_value = value as usize;
125
126        if f_value > max_value {
127            f_value = default;
128        }
129        self = self.offset(f_value * page);
130        self = self.limit(f_value);
131        Ok(self)
132    }
133
134    /// Selects specific columns to return.
135    ///
136    /// By default, all columns (`*`) are selected.
137    pub fn select(mut self, columns: &str) -> Self {
138        self.select_columns.push(columns.to_string());
139        self
140    }
141
142    /// Sets the query offset (pagination).
143    pub fn offset(mut self, offset: usize) -> Self {
144        self.offset = Some(offset);
145        self
146    }
147
148    /// Sets the maximum number of records to return.
149    pub fn limit(mut self, limit: usize) -> Self {
150        self.limit = Some(limit);
151        self
152    }
153
154    /// Inserts a new record into the database based on the model instance.
155    ///
156    /// Uses manual string parsing to bind values (temporary solution until fuller serialization support).
157    pub async fn insert(&self, model: &T) -> Result<&Self, sqlx::Error> {
158        let data_map = model.to_map();
159
160        if data_map.is_empty() {
161            return Ok(&self);
162        }
163
164        let table_name = self.table_name.to_snake_case();
165        let columns_info = T::columns();
166
167        let mut target_columns = Vec::new();
168        let mut bindings: Vec<(String, &str)> = Vec::new();
169
170        for (col_name, value) in data_map {
171            let col_name_clean = col_name.strip_prefix("r#").unwrap_or(&col_name).to_snake_case();
172            target_columns.push(format!("\"{}\"", col_name_clean));
173
174            let sql_type = columns_info.iter().find(|c| c.name == col_name).map(|c| c.sql_type).unwrap_or("TEXT");
175
176            bindings.push((value, sql_type));
177        }
178
179        let placeholders: Vec<String> = bindings
180            .iter()
181            .enumerate()
182            .map(|(i, (_, sql_type))| match self.db.driver {
183                Drivers::Postgres => {
184                    let idx = i + 1;
185                    match *sql_type {
186                        "TIMESTAMPTZ" | "DateTime" => format!("${}::TIMESTAMPTZ", idx),
187                        "TIMESTAMP" | "NaiveDateTime" => format!("${}::TIMESTAMP", idx),
188                        "DATE" | "NaiveDate" => format!("${}::DATE", idx),
189                        "TIME" | "NaiveTime" => format!("${}::TIME", idx),
190                        _ => format!("${}", idx),
191                    }
192                }
193                _ => "?".to_string(),
194            })
195            .collect();
196
197        let query_str = format!(
198            "INSERT INTO \"{}\" ({}) VALUES ({})",
199            table_name,
200            target_columns.join(", "),
201            placeholders.join(", ")
202        );
203
204        // println!("{}", query_str); // Debug if needed
205        let mut query = sqlx::query::<sqlx::Any>(&query_str);
206
207        // Manual binding based on string parsing
208        for (val_str, sql_type) in bindings {
209            match sql_type {
210                "INTEGER" | "INT" | "SERIAL" | "serial" | "int4" => {
211                    let val: i32 = val_str.parse().unwrap_or_default();
212                    query = query.bind(val);
213                }
214                "BIGINT" | "INT8" | "int8" => {
215                    let val: i64 = val_str.parse().unwrap_or_default();
216                    query = query.bind(val);
217                }
218                "BOOLEAN" | "BOOL" | "bool" => {
219                    let val: bool = val_str.parse().unwrap_or(false);
220                    query = query.bind(val);
221                }
222                "DOUBLE PRECISION" | "FLOAT" | "float8" => {
223                    let val: f64 = val_str.parse().unwrap_or_default();
224                    query = query.bind(val);
225                }
226                "TIMESTAMP" | "NaiveDateTime" => {
227                    if let Ok(val) = val_str.parse::<NaiveDateTime>() {
228                        query = query.bind(val.to_string());
229                    } else {
230                        query = query.bind(val_str);
231                    }
232                }
233                "TIMESTAMPTZ" | "DateTime" => {
234                    if let Ok(val) = val_str.parse::<DateTime<Utc>>() {
235                        query = query.bind(val.to_string());
236                    } else {
237                        query = query.bind(val_str);
238                    }
239                }
240                "DATE" | "NaiveDate" => {
241                    if let Ok(val) = val_str.parse::<NaiveDate>() {
242                        query = query.bind(val.to_string());
243                    } else {
244                        query = query.bind(val_str);
245                    }
246                }
247                "TIME" | "NaiveTime" => {
248                    if let Ok(val) = val_str.parse::<NaiveTime>() {
249                        query = query.bind(val.to_string());
250                    } else {
251                        query = query.bind(val_str);
252                    }
253                }
254                _ => query = query.bind(val_str),
255            }
256        }
257
258        query.execute(&self.db.pool).await?;
259        Ok(&self)
260    }
261
262    /// Returns the generated SQL string (for debugging purposes, without arguments).
263    pub fn to_sql(&self) -> String {
264        let mut query = String::from("SELECT ");
265        if self.select_columns.is_empty() {
266            query.push('*');
267        } else {
268            query.push_str(&self.select_columns.join(", "));
269        }
270        query.push_str(" FROM \"");
271        query.push_str(&self.table_name.to_snake_case());
272        query.push_str("\" WHERE 1=1");
273
274        let mut dummy_args = AnyArguments::default();
275        let mut dummy_counter = 1;
276
277        for clause in &self.where_clauses {
278            clause(&mut query, &mut dummy_args, &self.db.driver, &mut dummy_counter);
279        }
280
281        if !self.order_clauses.is_empty() {
282            query.push_str(&format!(" ORDER BY {}", &self.order_clauses.join(", ")));
283        }
284
285        query
286    }
287
288    /// Executes the query and returns a list of results.
289    ///
290    /// # Example
291    ///
292    /// ```rust,ignore
293    /// let users: Vec<User> = db.model::<User>().scan().await?;
294    /// ```
295    pub async fn scan<R>(self) -> Result<Vec<R>, sqlx::Error>
296    where
297        R: for<'r> FromRow<'r, AnyRow> + Send + Unpin,
298    {
299        let mut query = String::from("SELECT ");
300        if self.select_columns.is_empty() {
301            query.push('*');
302        } else {
303            query.push_str(&self.select_columns.join(", "));
304        }
305        query.push_str(" FROM \"");
306        query.push_str(&self.table_name.to_snake_case());
307        query.push_str("\" WHERE 1=1");
308
309        let mut args = AnyArguments::default();
310        let mut arg_counter = 1;
311
312        for clause in &self.where_clauses {
313            clause(&mut query, &mut args, &self.db.driver, &mut arg_counter);
314        }
315
316        if let Some(limit) = self.limit {
317            query.push_str(" LIMIT ");
318            match self.db.driver {
319                Drivers::Postgres => {
320                    query.push_str(&format!("${}", arg_counter));
321                    arg_counter += 1;
322                }
323                _ => query.push('?'),
324            }
325            args.add(limit as i64);
326        }
327
328        if let Some(offset) = self.offset {
329            query.push_str(" OFFSET ");
330            match self.db.driver {
331                Drivers::Postgres => {
332                    query.push_str(&format!("${}", arg_counter));
333                    // arg_counter += 1; // Ignored as it is last usage
334                }
335                _ => query.push('?'),
336            }
337            args.add(offset as i64);
338        }
339
340        sqlx::query_as_with::<_, R, _>(&query, args).fetch_all(&self.db.pool).await
341    }
342
343    /// Executes the query and returns only the first result.
344    ///
345    /// Automatically adds `LIMIT 1` and orders by Primary Key if available.
346    ///
347    /// # Example
348    ///
349    /// ```rust,ignore
350    /// let user: User = db.model::<User>().filter("id", "=", 1).first().await?;
351    /// ```
352    pub async fn first<R>(self) -> Result<R, sqlx::Error>
353    where
354        R: for<'r> FromRow<'r, AnyRow> + Send + Unpin,
355    {
356        let mut query = String::from("SELECT ");
357        if self.select_columns.is_empty() {
358            query.push('*');
359        } else {
360            query.push_str(&self.select_columns.join(", "));
361        }
362        query.push_str(" FROM \"");
363        query.push_str(&self.table_name.to_snake_case());
364        query.push_str("\" WHERE 1=1");
365
366        let mut args = AnyArguments::default();
367        let mut arg_counter = 1;
368
369        for clause in &self.where_clauses {
370            clause(&mut query, &mut args, &self.db.driver, &mut arg_counter);
371        }
372
373        let pk_column = T::columns()
374            .iter()
375            .find(|c| c.is_primary_key)
376            .map(|c| c.name.strip_prefix("r#").unwrap_or(c.name).to_snake_case());
377
378        if let Some(pk) = pk_column {
379            query.push_str(" ORDER BY \"");
380            query.push_str(&pk);
381            query.push_str("\" ASC");
382        }
383
384        query.push_str(" LIMIT 1");
385
386        sqlx::query_as_with::<_, R, _>(&query, args).fetch_one(&self.db.pool).await
387    }
388}