mik_sql/builder/types.rs

//! Core types for the SQL query builder.
2
3use crate::validate::assert_valid_sql_identifier;
4
/// Comparison operator applied by a [`Filter`] between a column and a value.
///
/// The SQL shown on each variant is its documented rendering; a few variants
/// render differently depending on the backend (Postgres vs `SQLite`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Operator {
    /// Equality: `=`
    Eq,
    /// Inequality: `!=`
    Ne,
    /// Strictly greater: `>`
    Gt,
    /// Greater or equal: `>=`
    Gte,
    /// Strictly less: `<`
    Lt,
    /// Less or equal: `<=`
    Lte,
    /// Membership in an array: `IN` or `= ANY`
    In,
    /// Non-membership in an array: `NOT IN` or `!= ALL`
    NotIn,
    /// Regular-expression match: `~` (Postgres) or `LIKE` (`SQLite`)
    Regex,
    /// Pattern match: `LIKE`
    Like,
    /// Case-insensitive pattern match: `ILIKE` (Postgres) or `LIKE` (`SQLite`)
    ILike,
    /// Prefix match: `LIKE $1 || '%'`
    StartsWith,
    /// Suffix match: `LIKE '%' || $1`
    EndsWith,
    /// Substring match: `LIKE '%' || $1 || '%'`
    Contains,
    /// Range match: `BETWEEN $1 AND $2`
    Between,
}
40
/// Boolean connective that joins the members of a [`CompoundFilter`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum LogicalOp {
    /// Conjunction (`AND`): every sub-condition must match.
    And,
    /// Disjunction (`OR`): at least one sub-condition must match.
    Or,
    /// Negation (`NOT`) of the wrapped condition.
    Not,
}
52
53/// A filter expression that can be simple or compound.
54#[derive(Debug, Clone, PartialEq)]
55#[non_exhaustive]
56pub enum FilterExpr {
57    /// A simple field comparison.
58    Simple(Filter),
59    /// A compound filter with logical operator.
60    Compound(CompoundFilter),
61}
62
63/// A compound filter combining multiple expressions with a logical operator.
64#[derive(Debug, Clone, PartialEq)]
65#[non_exhaustive]
66pub struct CompoundFilter {
67    /// The logical operator (AND, OR, NOT) to combine filters.
68    pub op: LogicalOp,
69    /// The filter expressions to combine.
70    pub filters: Vec<FilterExpr>,
71}
72
73impl CompoundFilter {
74    /// Create an AND compound filter.
75    #[must_use]
76    pub const fn and(filters: Vec<FilterExpr>) -> Self {
77        Self {
78            op: LogicalOp::And,
79            filters,
80        }
81    }
82
83    /// Create an OR compound filter.
84    #[must_use]
85    pub const fn or(filters: Vec<FilterExpr>) -> Self {
86        Self {
87            op: LogicalOp::Or,
88            filters,
89        }
90    }
91
92    /// Create a NOT compound filter (wraps a single filter).
93    #[must_use]
94    pub fn not(filter: FilterExpr) -> Self {
95        Self {
96            op: LogicalOp::Not,
97            filters: vec![filter],
98        }
99    }
100}
101
102impl FilterExpr {
103    /// Collect all simple filters from this expression into a vector.
104    ///
105    /// This flattens compound filters, extracting all individual `Filter` items.
106    /// Used by the `merge:` option in `sql_read!` to iterate over user filters.
107    #[must_use]
108    pub fn collect_filters(&self) -> Vec<Filter> {
109        let mut result = Vec::new();
110        self.collect_filters_into(&mut result);
111        result
112    }
113
114    fn collect_filters_into(&self, result: &mut Vec<Filter>) {
115        match self {
116            Self::Simple(f) => result.push(f.clone()),
117            Self::Compound(c) => {
118                for expr in &c.filters {
119                    expr.collect_filters_into(result);
120                }
121            },
122        }
123    }
124
125    /// Returns an iterator over all simple filters in this expression.
126    ///
127    /// This flattens compound filters, yielding all individual `Filter` items.
128    #[must_use]
129    pub fn iter(&self) -> FilterExprIter {
130        self.into_iter()
131    }
132}
133
134/// Iterator over filters in a `FilterExpr`.
135#[derive(Debug)]
136pub struct FilterExprIter {
137    filters: std::vec::IntoIter<Filter>,
138}
139
140impl Iterator for FilterExprIter {
141    type Item = Filter;
142
143    fn next(&mut self) -> Option<Self::Item> {
144        self.filters.next()
145    }
146
147    fn size_hint(&self) -> (usize, Option<usize>) {
148        self.filters.size_hint()
149    }
150}
151
152impl IntoIterator for &FilterExpr {
153    type Item = Filter;
154    type IntoIter = FilterExprIter;
155
156    fn into_iter(self) -> Self::IntoIter {
157        FilterExprIter {
158            filters: self.collect_filters().into_iter(),
159        }
160    }
161}
162
163/// Aggregation functions.
164#[derive(Debug, Clone, Copy, PartialEq, Eq)]
165#[non_exhaustive]
166pub enum AggregateFunc {
167    /// Count rows: `COUNT(*)`
168    Count,
169    /// Count distinct values: `COUNT(DISTINCT field)`
170    CountDistinct,
171    /// Sum values: `SUM(field)`
172    Sum,
173    /// Average value: `AVG(field)`
174    Avg,
175    /// Minimum value: `MIN(field)`
176    Min,
177    /// Maximum value: `MAX(field)`
178    Max,
179}
180
181/// An aggregation expression.
182#[derive(Debug, Clone, PartialEq, Eq)]
183#[non_exhaustive]
184pub struct Aggregate {
185    /// The aggregation function to apply.
186    pub func: AggregateFunc,
187    /// Field to aggregate, None for COUNT(*).
188    pub field: Option<String>,
189    /// Optional alias for the result.
190    pub alias: Option<String>,
191}
192
193impl Aggregate {
194    /// Create a COUNT(*) aggregation.
195    #[must_use]
196    pub fn count() -> Self {
197        Self {
198            func: AggregateFunc::Count,
199            field: None,
200            alias: Some("count".to_string()),
201        }
202    }
203
204    /// Create a COUNT(field) aggregation.
205    ///
206    /// # Panics
207    ///
208    /// Panics if the field name is not a valid SQL identifier.
209    pub fn count_field(field: impl Into<String>) -> Self {
210        let field = field.into();
211        assert_valid_sql_identifier(&field, "aggregate field");
212        Self {
213            func: AggregateFunc::Count,
214            field: Some(field),
215            alias: None,
216        }
217    }
218
219    /// Create a COUNT(DISTINCT field) aggregation.
220    ///
221    /// # Panics
222    ///
223    /// Panics if the field name is not a valid SQL identifier.
224    pub fn count_distinct(field: impl Into<String>) -> Self {
225        let field = field.into();
226        assert_valid_sql_identifier(&field, "aggregate field");
227        Self {
228            func: AggregateFunc::CountDistinct,
229            field: Some(field),
230            alias: None,
231        }
232    }
233
234    /// Create a SUM(field) aggregation.
235    ///
236    /// # Panics
237    ///
238    /// Panics if the field name is not a valid SQL identifier.
239    pub fn sum(field: impl Into<String>) -> Self {
240        let field = field.into();
241        assert_valid_sql_identifier(&field, "aggregate field");
242        Self {
243            func: AggregateFunc::Sum,
244            field: Some(field),
245            alias: None,
246        }
247    }
248
249    /// Create an AVG(field) aggregation.
250    ///
251    /// # Panics
252    ///
253    /// Panics if the field name is not a valid SQL identifier.
254    pub fn avg(field: impl Into<String>) -> Self {
255        let field = field.into();
256        assert_valid_sql_identifier(&field, "aggregate field");
257        Self {
258            func: AggregateFunc::Avg,
259            field: Some(field),
260            alias: None,
261        }
262    }
263
264    /// Create a MIN(field) aggregation.
265    ///
266    /// # Panics
267    ///
268    /// Panics if the field name is not a valid SQL identifier.
269    pub fn min(field: impl Into<String>) -> Self {
270        let field = field.into();
271        assert_valid_sql_identifier(&field, "aggregate field");
272        Self {
273            func: AggregateFunc::Min,
274            field: Some(field),
275            alias: None,
276        }
277    }
278
279    /// Create a MAX(field) aggregation.
280    ///
281    /// # Panics
282    ///
283    /// Panics if the field name is not a valid SQL identifier.
284    pub fn max(field: impl Into<String>) -> Self {
285        let field = field.into();
286        assert_valid_sql_identifier(&field, "aggregate field");
287        Self {
288            func: AggregateFunc::Max,
289            field: Some(field),
290            alias: None,
291        }
292    }
293
294    /// Set an alias for the aggregation result.
295    ///
296    /// # Panics
297    ///
298    /// Panics if the alias is not a valid SQL identifier.
299    pub fn as_alias(mut self, alias: impl Into<String>) -> Self {
300        let alias = alias.into();
301        assert_valid_sql_identifier(&alias, "aggregate alias");
302        self.alias = Some(alias);
303        self
304    }
305
306    /// Generate SQL for this aggregation.
307    #[must_use]
308    pub fn to_sql(&self) -> String {
309        let expr = match (&self.func, &self.field) {
310            (AggregateFunc::Count, None) => "COUNT(*)".to_string(),
311            (AggregateFunc::Count, Some(f)) => format!("COUNT({f})"),
312            (AggregateFunc::CountDistinct, Some(f)) => format!("COUNT(DISTINCT {f})"),
313            (AggregateFunc::Sum, Some(f)) => format!("SUM({f})"),
314            (AggregateFunc::Avg, Some(f)) => format!("AVG({f})"),
315            (AggregateFunc::Min, Some(f)) => format!("MIN({f})"),
316            (AggregateFunc::Max, Some(f)) => format!("MAX({f})"),
317            _ => "COUNT(*)".to_string(),
318        };
319
320        match &self.alias {
321            Some(a) => format!("{expr} AS {a}"),
322            None => expr,
323        }
324    }
325}
326
/// A value bound to a SQL query as a parameter.
#[derive(Clone, Debug, PartialEq)]
#[non_exhaustive]
pub enum Value {
    /// The SQL `NULL`.
    Null,
    /// A boolean (`true` / `false`).
    Bool(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// A UTF-8 string.
    String(String),
    /// A list of values, e.g. the operand of `IN` or `BETWEEN`.
    Array(Vec<Self>),
}
344
/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SortDir {
    /// Ascending order (smallest to largest, A to Z).
    Asc,
    /// Descending order (largest to smallest, Z to A).
    Desc,
}

/// Sort field with direction.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SortField {
    /// The column name to sort by.
    pub field: String,
    /// The sort direction (ascending or descending).
    pub dir: SortDir,
}

impl SortField {
    /// Create a new sort field.
    ///
    /// The field name is stored as given; this constructor does not validate it.
    #[must_use]
    pub fn new(field: impl Into<String>, dir: SortDir) -> Self {
        Self {
            field: field.into(),
            dir,
        }
    }

    /// Parse a sort string like `"name,-created_at"` into sort fields.
    ///
    /// Entries are comma-separated. Surrounding whitespace and empty entries are
    /// ignored. A leading `-` selects descending order; whitespace between the
    /// `-` and the name is ignored. Otherwise the order is ascending.
    ///
    /// # Errors
    ///
    /// Returns an error message if:
    /// - an entry has no field name (e.g. a bare `-`), or
    /// - `allowed` is non-empty and a field is not in it.
    ///
    /// # Security Note
    ///
    /// If `allowed` is empty, ALL fields are allowed: field names are not checked
    /// as SQL identifiers here. For user input, always provide an explicit
    /// whitelist to prevent sorting by sensitive columns.
    pub fn parse_sort_string(sort: &str, allowed: &[&str]) -> Result<Vec<Self>, String> {
        let mut result = Vec::new();

        for part in sort.split(',').map(str::trim).filter(|p| !p.is_empty()) {
            let (field, dir) = match part.strip_prefix('-') {
                // `part` is already trimmed, so only the gap after `-` can hold whitespace.
                Some(rest) => (rest.trim_start(), SortDir::Desc),
                None => (part, SortDir::Asc),
            };

            // A bare `-` would otherwise yield an empty column name and invalid SQL.
            if field.is_empty() {
                return Err(format!("Sort field '{part}' is missing a field name"));
            }

            // Validate against whitelist (empty = allow all, consistent with FilterValidator)
            if !allowed.is_empty() && !allowed.contains(&field) {
                return Err(format!("Sort field '{field}' not allowed. Allowed: {allowed:?}"));
            }

            result.push(Self::new(field, dir));
        }

        Ok(result)
    }
}
409
410/// Filter condition.
411#[derive(Debug, Clone, PartialEq)]
412#[non_exhaustive]
413pub struct Filter {
414    /// The column name to filter on.
415    pub field: String,
416    /// The comparison operator to use.
417    pub op: Operator,
418    /// The value to compare against.
419    pub value: Value,
420}
421
422impl Filter {
423    /// Create a new filter condition.
424    #[must_use]
425    pub fn new(field: impl Into<String>, op: Operator, value: Value) -> Self {
426        Self {
427            field: field.into(),
428            op,
429            value,
430        }
431    }
432}
433
434/// Query result with SQL string and parameters.
435#[derive(Debug, Clone, PartialEq)]
436#[non_exhaustive]
437#[must_use = "QueryResult must be used to execute the query"]
438pub struct QueryResult {
439    /// The generated SQL query string.
440    pub sql: String,
441    /// The parameter values to bind to the query.
442    pub params: Vec<Value>,
443}
444
445impl QueryResult {
446    /// Create a new query result.
447    #[must_use]
448    pub fn new(sql: impl Into<String>, params: Vec<Value>) -> Self {
449        Self {
450            sql: sql.into(),
451            params,
452        }
453    }
454}
455
/// A raw SQL expression selected under an alias.
#[derive(Clone, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub struct ComputedField {
    /// Name under which the expression's result is exposed.
    pub alias: String,
    /// The SQL expression (e.g., "`first_name` || ' ' || `last_name`").
    pub expression: String,
}

impl ComputedField {
    /// Creates a computed field from an alias and a SQL expression.
    ///
    /// Neither argument is validated here; both are emitted verbatim by `to_sql`.
    pub fn new(alias: impl Into<String>, expression: impl Into<String>) -> Self {
        let alias: String = alias.into();
        let expression: String = expression.into();
        Self { alias, expression }
    }

    /// Renders this field as `(<expression>) AS <alias>`.
    #[must_use]
    pub fn to_sql(&self) -> String {
        let Self { alias, expression } = self;
        format!("({expression}) AS {alias}")
    }
}
481
/// Which side of the cursor a cursor-paginated query reads from.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum CursorDirection {
    /// Forward: rows after the cursor.
    After,
    /// Backward: rows before the cursor.
    Before,
}
491
492/// Helper function to create a simple filter expression.
493///
494/// # Panics
495///
496/// Panics if the field name is not a valid SQL identifier.
497pub fn simple(field: impl Into<String>, op: Operator, value: Value) -> FilterExpr {
498    let field = field.into();
499    assert_valid_sql_identifier(&field, "filter field");
500    FilterExpr::Simple(Filter { field, op, value })
501}
502
503/// Helper function to create an AND compound filter.
504#[must_use]
505pub const fn and(filters: Vec<FilterExpr>) -> FilterExpr {
506    FilterExpr::Compound(CompoundFilter::and(filters))
507}
508
509/// Helper function to create an OR compound filter.
510#[must_use]
511pub const fn or(filters: Vec<FilterExpr>) -> FilterExpr {
512    FilterExpr::Compound(CompoundFilter::or(filters))
513}
514
515/// Helper function to create a NOT filter.
516#[must_use]
517pub fn not(filter: FilterExpr) -> FilterExpr {
518    FilterExpr::Compound(CompoundFilter::not(filter))
519}