mik_sql/builder/
types.rs

1//! Core types for the SQL query builder.
2
3use crate::validate::assert_valid_sql_identifier;
4
/// Comparison operator used by a [`Filter`].
///
/// Each variant documents the SQL it is intended to render as.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Operator {
    /// Equality: `=`
    Eq,
    /// Inequality: `!=`
    Ne,
    /// Strictly greater: `>`
    Gt,
    /// Greater or equal: `>=`
    Gte,
    /// Strictly less: `<`
    Lt,
    /// Less or equal: `<=`
    Lte,
    /// Membership in an array: `IN` or `= ANY`
    In,
    /// Non-membership in an array: `NOT IN` or `!= ALL`
    NotIn,
    /// Regular-expression match: `~` (Postgres) or `LIKE` (`SQLite`)
    Regex,
    /// Pattern match: `LIKE`
    Like,
    /// Case-insensitive pattern match: `ILIKE` (Postgres) or `LIKE` (`SQLite`)
    ILike,
    /// Prefix match: `LIKE $1 || '%'`
    StartsWith,
    /// Suffix match: `LIKE '%' || $1`
    EndsWith,
    /// Substring match: `LIKE '%' || $1 || '%'`
    Contains,
    /// Inclusive range: `BETWEEN $1 AND $2`
    Between,
}
40
/// Boolean connective joining the children of a [`CompoundFilter`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LogicalOp {
    /// Every child condition must hold: `AND`
    And,
    /// Any one child condition suffices: `OR`
    Or,
    /// Inverts the wrapped condition: `NOT`
    Not,
}
52
53/// A filter expression that can be simple or compound.
54#[derive(Debug, Clone, PartialEq)]
55#[non_exhaustive]
56pub enum FilterExpr {
57    /// A simple field comparison.
58    Simple(Filter),
59    /// A compound filter with logical operator.
60    Compound(CompoundFilter),
61}
62
63/// A compound filter combining multiple expressions with a logical operator.
64#[derive(Debug, Clone, PartialEq)]
65#[non_exhaustive]
66pub struct CompoundFilter {
67    /// The logical operator (AND, OR, NOT) to combine filters.
68    pub op: LogicalOp,
69    /// The filter expressions to combine.
70    pub filters: Vec<FilterExpr>,
71}
72
73impl CompoundFilter {
74    /// Create an AND compound filter.
75    #[must_use]
76    pub const fn and(filters: Vec<FilterExpr>) -> Self {
77        Self {
78            op: LogicalOp::And,
79            filters,
80        }
81    }
82
83    /// Create an OR compound filter.
84    #[must_use]
85    pub const fn or(filters: Vec<FilterExpr>) -> Self {
86        Self {
87            op: LogicalOp::Or,
88            filters,
89        }
90    }
91
92    /// Create a NOT compound filter (wraps a single filter).
93    #[must_use]
94    pub fn not(filter: FilterExpr) -> Self {
95        Self {
96            op: LogicalOp::Not,
97            filters: vec![filter],
98        }
99    }
100}
101
/// SQL aggregate function applied by an [`Aggregate`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunc {
    /// Row count: `COUNT(*)`
    Count,
    /// Number of distinct values: `COUNT(DISTINCT field)`
    CountDistinct,
    /// Total: `SUM(field)`
    Sum,
    /// Mean: `AVG(field)`
    Avg,
    /// Smallest value: `MIN(field)`
    Min,
    /// Largest value: `MAX(field)`
    Max,
}
119
120/// An aggregation expression.
121#[derive(Debug, Clone, PartialEq, Eq)]
122#[non_exhaustive]
123pub struct Aggregate {
124    /// The aggregation function to apply.
125    pub func: AggregateFunc,
126    /// Field to aggregate, None for COUNT(*).
127    pub field: Option<String>,
128    /// Optional alias for the result.
129    pub alias: Option<String>,
130}
131
132impl Aggregate {
133    /// Create a COUNT(*) aggregation.
134    #[must_use]
135    pub fn count() -> Self {
136        Self {
137            func: AggregateFunc::Count,
138            field: None,
139            alias: Some("count".to_string()),
140        }
141    }
142
143    /// Create a COUNT(field) aggregation.
144    ///
145    /// # Panics
146    ///
147    /// Panics if the field name is not a valid SQL identifier.
148    pub fn count_field(field: impl Into<String>) -> Self {
149        let field = field.into();
150        assert_valid_sql_identifier(&field, "aggregate field");
151        Self {
152            func: AggregateFunc::Count,
153            field: Some(field),
154            alias: None,
155        }
156    }
157
158    /// Create a COUNT(DISTINCT field) aggregation.
159    ///
160    /// # Panics
161    ///
162    /// Panics if the field name is not a valid SQL identifier.
163    pub fn count_distinct(field: impl Into<String>) -> Self {
164        let field = field.into();
165        assert_valid_sql_identifier(&field, "aggregate field");
166        Self {
167            func: AggregateFunc::CountDistinct,
168            field: Some(field),
169            alias: None,
170        }
171    }
172
173    /// Create a SUM(field) aggregation.
174    ///
175    /// # Panics
176    ///
177    /// Panics if the field name is not a valid SQL identifier.
178    pub fn sum(field: impl Into<String>) -> Self {
179        let field = field.into();
180        assert_valid_sql_identifier(&field, "aggregate field");
181        Self {
182            func: AggregateFunc::Sum,
183            field: Some(field),
184            alias: None,
185        }
186    }
187
188    /// Create an AVG(field) aggregation.
189    ///
190    /// # Panics
191    ///
192    /// Panics if the field name is not a valid SQL identifier.
193    pub fn avg(field: impl Into<String>) -> Self {
194        let field = field.into();
195        assert_valid_sql_identifier(&field, "aggregate field");
196        Self {
197            func: AggregateFunc::Avg,
198            field: Some(field),
199            alias: None,
200        }
201    }
202
203    /// Create a MIN(field) aggregation.
204    ///
205    /// # Panics
206    ///
207    /// Panics if the field name is not a valid SQL identifier.
208    pub fn min(field: impl Into<String>) -> Self {
209        let field = field.into();
210        assert_valid_sql_identifier(&field, "aggregate field");
211        Self {
212            func: AggregateFunc::Min,
213            field: Some(field),
214            alias: None,
215        }
216    }
217
218    /// Create a MAX(field) aggregation.
219    ///
220    /// # Panics
221    ///
222    /// Panics if the field name is not a valid SQL identifier.
223    pub fn max(field: impl Into<String>) -> Self {
224        let field = field.into();
225        assert_valid_sql_identifier(&field, "aggregate field");
226        Self {
227            func: AggregateFunc::Max,
228            field: Some(field),
229            alias: None,
230        }
231    }
232
233    /// Set an alias for the aggregation result.
234    ///
235    /// # Panics
236    ///
237    /// Panics if the alias is not a valid SQL identifier.
238    pub fn as_alias(mut self, alias: impl Into<String>) -> Self {
239        let alias = alias.into();
240        assert_valid_sql_identifier(&alias, "aggregate alias");
241        self.alias = Some(alias);
242        self
243    }
244
245    /// Generate SQL for this aggregation.
246    #[must_use]
247    pub fn to_sql(&self) -> String {
248        let expr = match (&self.func, &self.field) {
249            (AggregateFunc::Count, None) => "COUNT(*)".to_string(),
250            (AggregateFunc::Count, Some(f)) => format!("COUNT({f})"),
251            (AggregateFunc::CountDistinct, Some(f)) => format!("COUNT(DISTINCT {f})"),
252            (AggregateFunc::Sum, Some(f)) => format!("SUM({f})"),
253            (AggregateFunc::Avg, Some(f)) => format!("AVG({f})"),
254            (AggregateFunc::Min, Some(f)) => format!("MIN({f})"),
255            (AggregateFunc::Max, Some(f)) => format!("MAX({f})"),
256            _ => "COUNT(*)".to_string(),
257        };
258
259        match &self.alias {
260            Some(a) => format!("{expr} AS {a}"),
261            None => expr,
262        }
263    }
264}
265
/// A value bound as a query parameter.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// The SQL `NULL`.
    Null,
    /// A boolean.
    Bool(bool),
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit float.
    Float(f64),
    /// A UTF-8 string.
    String(String),
    /// A list of values, used by operators such as `IN` and `BETWEEN`.
    Array(Vec<Self>),
}
283
/// Sort direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SortDir {
    /// Ascending order (smallest to largest, A to Z).
    Asc,
    /// Descending order (largest to smallest, Z to A).
    Desc,
}

/// Sort field with direction.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct SortField {
    /// The column name to sort by.
    pub field: String,
    /// The sort direction (ascending or descending).
    pub dir: SortDir,
}

impl SortField {
    /// Create a new sort field.
    ///
    /// The field name is stored as given; it is not validated here.
    #[must_use]
    pub fn new(field: impl Into<String>, dir: SortDir) -> Self {
        Self {
            field: field.into(),
            dir,
        }
    }

    /// Parse a sort string like `"name,-created_at"` into sort fields.
    ///
    /// Entries are comma-separated and trimmed, and blank entries are skipped. A
    /// leading `-` selects descending order; whitespace after the `-` is ignored.
    ///
    /// # Security Note
    ///
    /// If `allowed` is empty, ALL fields are allowed. For user input, always
    /// provide an explicit whitelist to prevent sorting by sensitive columns.
    ///
    /// # Errors
    ///
    /// Returns an error message if an entry has no field name (e.g. a lone `-`),
    /// or if `allowed` is non-empty and does not contain a requested field.
    pub fn parse_sort_string(sort: &str, allowed: &[&str]) -> Result<Vec<Self>, String> {
        sort.split(',')
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(|part| {
                let (field, dir) = part
                    .strip_prefix('-')
                    .map_or((part, SortDir::Asc), |rest| (rest.trim_start(), SortDir::Desc));

                // A bare "-" would otherwise yield an empty column name, which is
                // never a valid ORDER BY target.
                if field.is_empty() {
                    return Err(format!("Sort field missing in '{part}'"));
                }

                // Validate against whitelist (empty = allow all, consistent with FilterValidator)
                if !allowed.is_empty() && !allowed.contains(&field) {
                    return Err(format!(
                        "Sort field '{field}' not allowed. Allowed: {allowed:?}"
                    ));
                }

                Ok(Self::new(field, dir))
            })
            .collect()
    }
}
348
349/// Filter condition.
350#[derive(Debug, Clone, PartialEq)]
351#[non_exhaustive]
352pub struct Filter {
353    /// The column name to filter on.
354    pub field: String,
355    /// The comparison operator to use.
356    pub op: Operator,
357    /// The value to compare against.
358    pub value: Value,
359}
360
361impl Filter {
362    /// Create a new filter condition.
363    #[must_use]
364    pub fn new(field: impl Into<String>, op: Operator, value: Value) -> Self {
365        Self {
366            field: field.into(),
367            op,
368            value,
369        }
370    }
371}
372
373/// Query result with SQL string and parameters.
374#[derive(Debug, Clone, PartialEq)]
375#[non_exhaustive]
376#[must_use = "QueryResult must be used to execute the query"]
377pub struct QueryResult {
378    /// The generated SQL query string.
379    pub sql: String,
380    /// The parameter values to bind to the query.
381    pub params: Vec<Value>,
382}
383
384impl QueryResult {
385    /// Create a new query result.
386    #[must_use]
387    pub fn new(sql: impl Into<String>, params: Vec<Value>) -> Self {
388        Self {
389            sql: sql.into(),
390            params,
391        }
392    }
393}
394
/// A computed field expression with alias.
///
/// NOTE: [`ComputedField::to_sql`] inserts both `expression` and `alias` verbatim,
/// and neither is validated or escaped. Build them only from trusted, code-defined
/// strings, never from user input.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub struct ComputedField {
    /// The alias for the computed field.
    pub alias: String,
    /// The raw SQL expression, e.g. `first_name || ' ' || last_name`.
    pub expression: String,
}

impl ComputedField {
    /// Create a new computed field (see the type-level note on trusted input).
    #[must_use]
    pub fn new(alias: impl Into<String>, expression: impl Into<String>) -> Self {
        Self {
            alias: alias.into(),
            expression: expression.into(),
        }
    }

    /// Generate the SQL for this computed field: `(<expression>) AS <alias>`.
    #[must_use]
    pub fn to_sql(&self) -> String {
        format!("({}) AS {}", self.expression, self.alias)
    }
}
420
/// Which side of a pagination cursor to read from.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CursorDirection {
    /// Move forward: rows after the cursor.
    After,
    /// Move backward: rows before the cursor.
    Before,
}
430
431/// Helper function to create a simple filter expression.
432///
433/// # Panics
434///
435/// Panics if the field name is not a valid SQL identifier.
436pub fn simple(field: impl Into<String>, op: Operator, value: Value) -> FilterExpr {
437    let field = field.into();
438    assert_valid_sql_identifier(&field, "filter field");
439    FilterExpr::Simple(Filter { field, op, value })
440}
441
442/// Helper function to create an AND compound filter.
443#[must_use]
444pub const fn and(filters: Vec<FilterExpr>) -> FilterExpr {
445    FilterExpr::Compound(CompoundFilter::and(filters))
446}
447
448/// Helper function to create an OR compound filter.
449#[must_use]
450pub const fn or(filters: Vec<FilterExpr>) -> FilterExpr {
451    FilterExpr::Compound(CompoundFilter::or(filters))
452}
453
454/// Helper function to create a NOT filter.
455#[must_use]
456pub fn not(filter: FilterExpr) -> FilterExpr {
457    FilterExpr::Compound(CompoundFilter::not(filter))
458}