kaccy_db/
query_builder.rs

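//! Dynamic query builder with type-safe construction and SQL injection prevention.
//!
//! This module provides:
//! - Type-safe query construction
//! - Runtime filter composition
//! - SQL injection prevention for filter values, which are emitted as `$n`
//!   placeholders; table and column names are inserted verbatim and must come
//!   from trusted input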
7
8use std::fmt::Write;
9
/// Comparison operator that can appear in a filter condition.
///
/// Each variant corresponds to one SQL keyword or symbol; see
/// [`Operator::as_sql`] for the exact text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Operator {
    /// Equal (`=`)
    Eq,
    /// Not equal (`<>`)
    Ne,
    /// Greater than (`>`)
    Gt,
    /// Greater than or equal (`>=`)
    Gte,
    /// Less than (`<`)
    Lt,
    /// Less than or equal (`<=`)
    Lte,
    /// `LIKE`
    Like,
    /// `ILIKE` (case-insensitive)
    ILike,
    /// `IN`
    In,
    /// `NOT IN`
    NotIn,
    /// `IS NULL`
    IsNull,
    /// `IS NOT NULL`
    IsNotNull,
}

impl Operator {
    /// Returns the SQL text for this operator, e.g. `"="` or `"NOT IN"`.
    ///
    /// Every arm yields a string literal, so the result is `&'static str` and
    /// stays usable after the operator value itself has gone out of scope.
    pub fn as_sql(&self) -> &'static str {
        match self {
            Operator::Eq => "=",
            Operator::Ne => "<>",
            Operator::Gt => ">",
            Operator::Gte => ">=",
            Operator::Lt => "<",
            Operator::Lte => "<=",
            Operator::Like => "LIKE",
            Operator::ILike => "ILIKE",
            Operator::In => "IN",
            Operator::NotIn => "NOT IN",
            Operator::IsNull => "IS NULL",
            Operator::IsNotNull => "IS NOT NULL",
        }
    }
}
58
59/// Filter condition
60#[derive(Debug, Clone)]
61pub struct Filter {
62    /// Column name
63    pub column: String,
64    /// Operator
65    pub operator: Operator,
66    /// Value (None for IS NULL/IS NOT NULL)
67    pub value: Option<FilterValue>,
68}
69
70/// Filter value type
71#[derive(Debug, Clone)]
72pub enum FilterValue {
73    /// String value
74    String(String),
75    /// Integer value
76    Int(i64),
77    /// Float value
78    Float(f64),
79    /// Boolean value
80    Bool(bool),
81    /// Array of strings
82    StringArray(Vec<String>),
83    /// Array of integers
84    IntArray(Vec<i64>),
85}
86
87impl Filter {
88    /// Create a new filter
89    pub fn new(column: String, operator: Operator, value: Option<FilterValue>) -> Self {
90        Self {
91            column,
92            operator,
93            value,
94        }
95    }
96
97    /// Create an equality filter
98    pub fn eq<T: Into<FilterValue>>(column: String, value: T) -> Self {
99        Self::new(column, Operator::Eq, Some(value.into()))
100    }
101
102    /// Create a not-equal filter
103    pub fn ne<T: Into<FilterValue>>(column: String, value: T) -> Self {
104        Self::new(column, Operator::Ne, Some(value.into()))
105    }
106
107    /// Create a greater-than filter
108    pub fn gt<T: Into<FilterValue>>(column: String, value: T) -> Self {
109        Self::new(column, Operator::Gt, Some(value.into()))
110    }
111
112    /// Create an IN filter
113    pub fn in_values(column: String, values: Vec<String>) -> Self {
114        Self::new(column, Operator::In, Some(FilterValue::StringArray(values)))
115    }
116
117    /// Create an IS NULL filter
118    pub fn is_null(column: String) -> Self {
119        Self::new(column, Operator::IsNull, None)
120    }
121
122    /// Convert to SQL with parameterized value
123    pub fn to_sql(&self, param_index: &mut usize) -> String {
124        let mut sql = format!("{} {}", self.column, self.operator.as_sql());
125
126        match (&self.operator, &self.value) {
127            (Operator::IsNull | Operator::IsNotNull, _) => {
128                // No value needed
129            }
130            (Operator::In | Operator::NotIn, Some(FilterValue::StringArray(values))) => {
131                let placeholders: Vec<String> = values
132                    .iter()
133                    .map(|_| {
134                        let placeholder = format!("${}", param_index);
135                        *param_index += 1;
136                        placeholder
137                    })
138                    .collect();
139                write!(sql, " ({})", placeholders.join(", ")).unwrap();
140            }
141            (Operator::In | Operator::NotIn, Some(FilterValue::IntArray(values))) => {
142                let placeholders: Vec<String> = values
143                    .iter()
144                    .map(|_| {
145                        let placeholder = format!("${}", param_index);
146                        *param_index += 1;
147                        placeholder
148                    })
149                    .collect();
150                write!(sql, " ({})", placeholders.join(", ")).unwrap();
151            }
152            (_, Some(_)) => {
153                write!(sql, " ${}", param_index).unwrap();
154                *param_index += 1;
155            }
156            _ => {}
157        }
158
159        sql
160    }
161}
162
163impl From<String> for FilterValue {
164    fn from(s: String) -> Self {
165        FilterValue::String(s)
166    }
167}
168
169impl From<&str> for FilterValue {
170    fn from(s: &str) -> Self {
171        FilterValue::String(s.to_string())
172    }
173}
174
175impl From<i64> for FilterValue {
176    fn from(i: i64) -> Self {
177        FilterValue::Int(i)
178    }
179}
180
181impl From<i32> for FilterValue {
182    fn from(i: i32) -> Self {
183        FilterValue::Int(i as i64)
184    }
185}
186
187impl From<bool> for FilterValue {
188    fn from(b: bool) -> Self {
189        FilterValue::Bool(b)
190    }
191}
192
193impl From<f64> for FilterValue {
194    fn from(f: f64) -> Self {
195        FilterValue::Float(f)
196    }
197}
198
/// Logical connective placed between the filters of a `WHERE` clause.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogicalOp {
    /// `AND`
    And,
    /// `OR`
    Or,
}

impl LogicalOp {
    /// Returns the SQL keyword for this connective (`"AND"` or `"OR"`).
    ///
    /// Both arms yield string literals, so the result is `&'static str` and is
    /// not tied to the lifetime of the value it was called on.
    pub fn as_sql(&self) -> &'static str {
        match self {
            LogicalOp::And => "AND",
            LogicalOp::Or => "OR",
        }
    }
}
217
/// Sort direction of an `ORDER BY` term.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OrderDirection {
    /// Ascending (`ASC`)
    Asc,
    /// Descending (`DESC`)
    Desc,
}

impl OrderDirection {
    /// Returns the SQL keyword for this direction (`"ASC"` or `"DESC"`).
    ///
    /// Both arms yield string literals, so the result is `&'static str` and is
    /// not tied to the lifetime of the value it was called on.
    pub fn as_sql(&self) -> &'static str {
        match self {
            OrderDirection::Asc => "ASC",
            OrderDirection::Desc => "DESC",
        }
    }
}
236
237/// Order by clause
238#[derive(Debug, Clone)]
239pub struct OrderBy {
240    /// Column name
241    pub column: String,
242    /// Direction
243    pub direction: OrderDirection,
244}
245
246impl OrderBy {
247    /// Create a new order by clause
248    pub fn new(column: String, direction: OrderDirection) -> Self {
249        Self { column, direction }
250    }
251
252    /// Create ascending order
253    pub fn asc(column: String) -> Self {
254        Self::new(column, OrderDirection::Asc)
255    }
256
257    /// Create descending order
258    pub fn desc(column: String) -> Self {
259        Self::new(column, OrderDirection::Desc)
260    }
261
262    /// Convert to SQL
263    pub fn to_sql(&self) -> String {
264        format!("{} {}", self.column, self.direction.as_sql())
265    }
266}
267
268/// Dynamic query builder
269#[derive(Debug, Clone)]
270pub struct QueryBuilder {
271    table: String,
272    columns: Vec<String>,
273    filters: Vec<Filter>,
274    logical_op: LogicalOp,
275    order_by: Vec<OrderBy>,
276    limit: Option<i64>,
277    offset: Option<i64>,
278}
279
280impl QueryBuilder {
281    /// Create a new query builder for a table
282    pub fn new(table: String) -> Self {
283        Self {
284            table,
285            columns: vec!["*".to_string()],
286            filters: Vec::new(),
287            logical_op: LogicalOp::And,
288            order_by: Vec::new(),
289            limit: None,
290            offset: None,
291        }
292    }
293
294    /// Select specific columns
295    pub fn select(mut self, columns: Vec<String>) -> Self {
296        self.columns = columns;
297        self
298    }
299
300    /// Add a filter
301    pub fn filter(mut self, filter: Filter) -> Self {
302        self.filters.push(filter);
303        self
304    }
305
306    /// Add multiple filters
307    pub fn filters(mut self, filters: Vec<Filter>) -> Self {
308        self.filters.extend(filters);
309        self
310    }
311
312    /// Set the logical operator for combining filters (AND/OR)
313    pub fn logical_op(mut self, op: LogicalOp) -> Self {
314        self.logical_op = op;
315        self
316    }
317
318    /// Add an order by clause
319    pub fn order_by(mut self, order: OrderBy) -> Self {
320        self.order_by.push(order);
321        self
322    }
323
324    /// Set limit
325    pub fn limit(mut self, limit: i64) -> Self {
326        self.limit = Some(limit);
327        self
328    }
329
330    /// Set offset
331    pub fn offset(mut self, offset: i64) -> Self {
332        self.offset = Some(offset);
333        self
334    }
335
336    /// Build the SQL query
337    pub fn build(&self) -> String {
338        let mut sql = format!("SELECT {} FROM {}", self.columns.join(", "), self.table);
339
340        if !self.filters.is_empty() {
341            sql.push_str(" WHERE ");
342            let mut param_index = 1;
343
344            for (i, filter) in self.filters.iter().enumerate() {
345                if i > 0 {
346                    write!(sql, " {} ", self.logical_op.as_sql()).unwrap();
347                }
348                write!(sql, "{}", filter.to_sql(&mut param_index)).unwrap();
349            }
350        }
351
352        if !self.order_by.is_empty() {
353            sql.push_str(" ORDER BY ");
354            let order_clauses: Vec<String> = self.order_by.iter().map(|o| o.to_sql()).collect();
355            sql.push_str(&order_clauses.join(", "));
356        }
357
358        if let Some(limit) = self.limit {
359            write!(sql, " LIMIT {}", limit).unwrap();
360        }
361
362        if let Some(offset) = self.offset {
363            write!(sql, " OFFSET {}", offset).unwrap();
364        }
365
366        sql
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Every operator maps to its SQL keyword or symbol.
    #[test]
    fn test_operator_as_sql() {
        assert_eq!(Operator::Eq.as_sql(), "=");
        assert_eq!(Operator::Ne.as_sql(), "<>");
        assert_eq!(Operator::Gt.as_sql(), ">");
        assert_eq!(Operator::Gte.as_sql(), ">=");
        assert_eq!(Operator::Lt.as_sql(), "<");
        assert_eq!(Operator::Lte.as_sql(), "<=");
        assert_eq!(Operator::Like.as_sql(), "LIKE");
        assert_eq!(Operator::ILike.as_sql(), "ILIKE");
        assert_eq!(Operator::In.as_sql(), "IN");
        assert_eq!(Operator::NotIn.as_sql(), "NOT IN");
        assert_eq!(Operator::IsNull.as_sql(), "IS NULL");
        assert_eq!(Operator::IsNotNull.as_sql(), "IS NOT NULL");
    }

    #[test]
    fn test_filter_eq() {
        let filter = Filter::eq("name".to_string(), "John");
        assert_eq!(filter.column, "name");
        assert_eq!(filter.operator, Operator::Eq);
        assert!(matches!(
            filter.value,
            Some(FilterValue::String(ref s)) if s.as_str() == "John"
        ));
    }

    #[test]
    fn test_filter_is_null() {
        let filter = Filter::is_null("deleted_at".to_string());
        assert_eq!(filter.column, "deleted_at");
        assert_eq!(filter.operator, Operator::IsNull);
        assert!(filter.value.is_none());
    }

    #[test]
    fn test_filter_to_sql() {
        let mut param_index = 1;
        let filter = Filter::eq("age".to_string(), 25);
        let sql = filter.to_sql(&mut param_index);
        assert_eq!(sql, "age = $1");
        assert_eq!(param_index, 2);
    }

    #[test]
    fn test_filter_to_sql_is_null() {
        let mut param_index = 1;
        let filter = Filter::is_null("deleted_at".to_string());
        let sql = filter.to_sql(&mut param_index);
        assert_eq!(sql, "deleted_at IS NULL");
        assert_eq!(param_index, 1); // No parameter added
    }

    /// An IN list gets one placeholder per element and advances the counter
    /// by the number of elements.
    #[test]
    fn test_filter_to_sql_in_list() {
        let mut param_index = 1;
        let filter = Filter::in_values("id".to_string(), vec!["a".to_string(), "b".to_string()]);
        let sql = filter.to_sql(&mut param_index);
        assert_eq!(sql, "id IN ($1, $2)");
        assert_eq!(param_index, 3);
    }

    #[test]
    fn test_order_by() {
        let order = OrderBy::asc("created_at".to_string());
        assert_eq!(order.to_sql(), "created_at ASC");

        let order = OrderBy::desc("updated_at".to_string());
        assert_eq!(order.to_sql(), "updated_at DESC");
    }

    #[test]
    fn test_query_builder_simple() {
        let query = QueryBuilder::new("users".to_string()).build();
        assert_eq!(query, "SELECT * FROM users");
    }

    #[test]
    fn test_query_builder_with_filter() {
        let query = QueryBuilder::new("users".to_string())
            .filter(Filter::eq("email".to_string(), "test@example.com"))
            .build();

        assert_eq!(query, "SELECT * FROM users WHERE email = $1");
    }

    #[test]
    fn test_query_builder_with_multiple_filters() {
        let query = QueryBuilder::new("users".to_string())
            .filter(Filter::eq("active".to_string(), true))
            .filter(Filter::gt("age".to_string(), 18))
            .build();

        assert_eq!(query, "SELECT * FROM users WHERE active = $1 AND age > $2");
    }

    /// The configured connective is used between every pair of filters.
    #[test]
    fn test_query_builder_with_or() {
        let query = QueryBuilder::new("users".to_string())
            .filters(vec![
                Filter::eq("active".to_string(), true),
                Filter::gt("age".to_string(), 18),
            ])
            .logical_op(LogicalOp::Or)
            .build();

        assert_eq!(query, "SELECT * FROM users WHERE active = $1 OR age > $2");
    }

    #[test]
    fn test_query_builder_with_order() {
        let query = QueryBuilder::new("users".to_string())
            .order_by(OrderBy::desc("created_at".to_string()))
            .build();

        assert_eq!(query, "SELECT * FROM users ORDER BY created_at DESC");
    }

    #[test]
    fn test_query_builder_with_limit_offset() {
        let query = QueryBuilder::new("users".to_string())
            .limit(10)
            .offset(20)
            .build();

        assert_eq!(query, "SELECT * FROM users LIMIT 10 OFFSET 20");
    }

    /// Exact match on the whole statement so clause order and spacing are
    /// checked, not just the presence of each fragment.
    #[test]
    fn test_query_builder_full() {
        let query = QueryBuilder::new("users".to_string())
            .select(vec![
                "id".to_string(),
                "name".to_string(),
                "email".to_string(),
            ])
            .filter(Filter::eq("active".to_string(), true))
            .filter(Filter::gt("age".to_string(), 18))
            .order_by(OrderBy::desc("created_at".to_string()))
            .limit(10)
            .offset(0)
            .build();

        assert_eq!(
            query,
            "SELECT id, name, email FROM users WHERE active = $1 AND age > $2 \
             ORDER BY created_at DESC LIMIT 10 OFFSET 0"
        );
    }
}
493}