mik_sql/builder/select.rs

1//! SELECT query builder.
2
3use crate::dialect::Dialect;
4use crate::pagination::{Cursor, IntoCursor};
5use crate::validate::{assert_valid_sql_expression, assert_valid_sql_identifier};
6
7use super::filter::{build_condition_impl, build_filter_expr_impl};
8use super::types::{
9    Aggregate, CompoundFilter, ComputedField, CursorDirection, Filter, FilterExpr, Operator,
10    QueryResult, SortDir, SortField, Value,
11};
12
13/// SQL query builder with dialect support.
14#[derive(Debug)]
15#[must_use = "builder does nothing until .build() is called"]
16pub struct QueryBuilder<D: Dialect> {
17    dialect: D,
18    table: String,
19    fields: Vec<String>,
20    computed: Vec<ComputedField>,
21    aggregates: Vec<Aggregate>,
22    filters: Vec<Filter>,
23    filter_expr: Option<FilterExpr>,
24    group_by: Vec<String>,
25    having: Option<FilterExpr>,
26    sorts: Vec<SortField>,
27    limit: Option<u32>,
28    offset: Option<u32>,
29    cursor: Option<Cursor>,
30    cursor_direction: Option<CursorDirection>,
31}
32
33impl<D: Dialect> QueryBuilder<D> {
34    /// Create a new query builder for the given table.
35    ///
36    /// # Panics
37    ///
38    /// Panics if the table name is not a valid SQL identifier.
39    pub fn new(dialect: D, table: impl Into<String>) -> Self {
40        let table = table.into();
41        assert_valid_sql_identifier(&table, "table");
42        Self {
43            dialect,
44            table,
45            fields: Vec::new(),
46            computed: Vec::new(),
47            aggregates: Vec::new(),
48            filters: Vec::new(),
49            filter_expr: None,
50            group_by: Vec::new(),
51            having: None,
52            sorts: Vec::new(),
53            limit: None,
54            offset: None,
55            cursor: None,
56            cursor_direction: None,
57        }
58    }
59
60    /// Set the fields to SELECT.
61    ///
62    /// # Panics
63    ///
64    /// Panics if any field name is not a valid SQL identifier.
65    pub fn fields(mut self, fields: &[&str]) -> Self {
66        for field in fields {
67            assert_valid_sql_identifier(field, "field");
68        }
69        self.fields = fields.iter().map(|s| (*s).to_string()).collect();
70        self
71    }
72
73    /// Add a computed field to the SELECT clause.
74    ///
75    /// # Example
76    ///
77    /// ```
78    /// # use mik_sql::prelude::*;
79    /// let result = postgres("orders")
80    ///     .computed("line_total", "quantity * price")
81    ///     .build();
82    /// assert!(result.sql.contains("(quantity * price) AS line_total"));
83    /// ```
84    ///
85    /// # Panics
86    ///
87    /// Panics if alias is not a valid SQL identifier or expression contains
88    /// dangerous patterns (comments, semicolons, SQL keywords).
89    ///
90    /// # Security
91    ///
92    /// **WARNING**: Only use with trusted expressions from code, never with user input.
93    pub fn computed(mut self, alias: impl Into<String>, expression: impl Into<String>) -> Self {
94        let alias = alias.into();
95        let expression = expression.into();
96        assert_valid_sql_identifier(&alias, "computed field alias");
97        assert_valid_sql_expression(&expression, "computed field");
98        self.computed.push(ComputedField::new(alias, expression));
99        self
100    }
101
102    /// Add an aggregation to the SELECT clause.
103    pub fn aggregate(mut self, agg: Aggregate) -> Self {
104        self.aggregates.push(agg);
105        self
106    }
107
108    /// Add a COUNT(*) aggregation.
109    pub fn count(mut self) -> Self {
110        self.aggregates.push(Aggregate::count());
111        self
112    }
113
114    /// Add a SUM(field) aggregation.
115    pub fn sum(mut self, field: impl Into<String>) -> Self {
116        self.aggregates.push(Aggregate::sum(field));
117        self
118    }
119
120    /// Add an AVG(field) aggregation.
121    pub fn avg(mut self, field: impl Into<String>) -> Self {
122        self.aggregates.push(Aggregate::avg(field));
123        self
124    }
125
126    /// Add a MIN(field) aggregation.
127    pub fn min(mut self, field: impl Into<String>) -> Self {
128        self.aggregates.push(Aggregate::min(field));
129        self
130    }
131
132    /// Add a MAX(field) aggregation.
133    pub fn max(mut self, field: impl Into<String>) -> Self {
134        self.aggregates.push(Aggregate::max(field));
135        self
136    }
137
138    /// Add a filter condition.
139    ///
140    /// # Panics
141    ///
142    /// Panics if the field name is not a valid SQL identifier.
143    pub fn filter(mut self, field: impl Into<String>, op: Operator, value: Value) -> Self {
144        let field = field.into();
145        assert_valid_sql_identifier(&field, "filter field");
146        self.filters.push(Filter { field, op, value });
147        self
148    }
149
150    /// Set a compound filter expression (replaces simple filters for WHERE clause).
151    pub fn filter_expr(mut self, expr: FilterExpr) -> Self {
152        self.filter_expr = Some(expr);
153        self
154    }
155
156    /// Add an AND compound filter.
157    pub fn and(mut self, filters: Vec<FilterExpr>) -> Self {
158        self.filter_expr = Some(FilterExpr::Compound(CompoundFilter::and(filters)));
159        self
160    }
161
162    /// Add an OR compound filter.
163    pub fn or(mut self, filters: Vec<FilterExpr>) -> Self {
164        self.filter_expr = Some(FilterExpr::Compound(CompoundFilter::or(filters)));
165        self
166    }
167
168    /// Add GROUP BY fields.
169    ///
170    /// # Panics
171    ///
172    /// Panics if any field name is not a valid SQL identifier.
173    pub fn group_by(mut self, fields: &[&str]) -> Self {
174        for field in fields {
175            assert_valid_sql_identifier(field, "group by field");
176        }
177        self.group_by = fields.iter().map(|s| (*s).to_string()).collect();
178        self
179    }
180
181    /// Add a HAVING clause (for filtering aggregated results).
182    pub fn having(mut self, expr: FilterExpr) -> Self {
183        self.having = Some(expr);
184        self
185    }
186
187    /// Add a sort field.
188    ///
189    /// # Panics
190    ///
191    /// Panics if the field name is not a valid SQL identifier.
192    pub fn sort(mut self, field: impl Into<String>, dir: SortDir) -> Self {
193        let field = field.into();
194        assert_valid_sql_identifier(&field, "sort field");
195        self.sorts.push(SortField::new(field, dir));
196        self
197    }
198
199    /// Add multiple sort fields.
200    pub fn sorts(mut self, sorts: &[SortField]) -> Self {
201        self.sorts.extend(sorts.iter().cloned());
202        self
203    }
204
205    /// Set pagination with page number (1-indexed) and limit.
206    pub const fn page(mut self, page: u32, limit: u32) -> Self {
207        self.limit = Some(limit);
208        self.offset = Some(page.saturating_sub(1).saturating_mul(limit));
209        self
210    }
211
212    /// Set explicit limit and offset.
213    pub const fn limit_offset(mut self, limit: u32, offset: u32) -> Self {
214        self.limit = Some(limit);
215        self.offset = Some(offset);
216        self
217    }
218
219    /// Set a limit without offset.
220    pub const fn limit(mut self, limit: u32) -> Self {
221        self.limit = Some(limit);
222        self
223    }
224
225    /// Paginate after this cursor (forward pagination).
226    ///
227    /// This method accepts flexible input types for great DX:
228    /// - `&Cursor` - when you have an already-parsed cursor
229    /// - `&str` - automatically decodes the base64 cursor
230    /// - `Option<&str>` - perfect for `req.query("after")` results
231    ///
232    /// If the cursor is invalid or None, it's silently ignored.
233    /// This makes it safe to pass `req.query("after")` directly.
234    pub fn after_cursor(mut self, cursor: impl IntoCursor) -> Self {
235        if let Some(c) = cursor.into_cursor() {
236            self.cursor = Some(c);
237            self.cursor_direction = Some(CursorDirection::After);
238        }
239        self
240    }
241
242    /// Paginate before this cursor (backward pagination).
243    ///
244    /// This method accepts flexible input types for great DX:
245    /// - `&Cursor` - when you have an already-parsed cursor
246    /// - `&str` - automatically decodes the base64 cursor
247    /// - `Option<&str>` - perfect for `req.query("before")` results
248    ///
249    /// If the cursor is invalid or None, it's silently ignored.
250    pub fn before_cursor(mut self, cursor: impl IntoCursor) -> Self {
251        if let Some(c) = cursor.into_cursor() {
252            self.cursor = Some(c);
253            self.cursor_direction = Some(CursorDirection::Before);
254        }
255        self
256    }
257
258    /// Build the SQL query and parameters.
259    pub fn build(self) -> QueryResult {
260        let mut sql = String::new();
261        let mut params = Vec::new();
262        let mut param_idx = 1usize;
263
264        // SELECT clause
265        let mut select_parts = Vec::new();
266
267        // Add regular fields
268        if !self.fields.is_empty() {
269            select_parts.extend(self.fields.clone());
270        }
271
272        // Add computed fields
273        for comp in &self.computed {
274            select_parts.push(comp.to_sql());
275        }
276
277        // Add aggregations
278        for agg in &self.aggregates {
279            select_parts.push(agg.to_sql());
280        }
281
282        let select_str = if select_parts.is_empty() {
283            "*".to_string()
284        } else {
285            select_parts.join(", ")
286        };
287
288        sql.push_str(&format!("SELECT {} FROM {}", select_str, self.table));
289
290        // WHERE clause - combine filter_expr, simple filters, and cursor conditions
291        let has_filter_expr = self.filter_expr.is_some();
292        let has_simple_filters = !self.filters.is_empty();
293        let has_cursor = self.cursor.is_some() && self.cursor_direction.is_some();
294
295        if has_filter_expr || has_simple_filters || has_cursor {
296            sql.push_str(" WHERE ");
297            let mut all_conditions = Vec::new();
298
299            // Add filter_expr conditions first
300            if let Some(ref expr) = self.filter_expr {
301                let (condition, new_params, new_idx) =
302                    build_filter_expr_impl(&self.dialect, expr, param_idx);
303                all_conditions.push(condition);
304                params.extend(new_params);
305                param_idx = new_idx;
306            }
307
308            // Add simple filters (from merge or direct .filter() calls)
309            for filter in &self.filters {
310                let (condition, new_params, new_idx) =
311                    build_condition_impl(&self.dialect, filter, param_idx);
312                all_conditions.push(condition);
313                params.extend(new_params);
314                param_idx = new_idx;
315            }
316
317            // Add cursor pagination conditions
318            if let (Some(cursor), Some(direction)) = (&self.cursor, self.cursor_direction) {
319                let (condition, new_params, new_idx) =
320                    self.build_cursor_condition(cursor, direction, param_idx);
321                if !condition.is_empty() {
322                    all_conditions.push(condition);
323                    params.extend(new_params);
324                    param_idx = new_idx;
325                }
326            }
327
328            sql.push_str(&all_conditions.join(" AND "));
329        }
330
331        // GROUP BY clause
332        if !self.group_by.is_empty() {
333            sql.push_str(&format!(" GROUP BY {}", self.group_by.join(", ")));
334        }
335
336        // HAVING clause
337        // Note: _new_idx intentionally unused - ORDER BY/LIMIT/OFFSET don't use parameters
338        if let Some(ref expr) = self.having {
339            let (condition, new_params, _new_idx) =
340                build_filter_expr_impl(&self.dialect, expr, param_idx);
341            sql.push_str(&format!(" HAVING {condition}"));
342            params.extend(new_params);
343        }
344
345        // ORDER BY clause
346        if !self.sorts.is_empty() {
347            sql.push_str(" ORDER BY ");
348            let sort_parts: Vec<String> = self
349                .sorts
350                .iter()
351                .map(|s| {
352                    let dir = match s.dir {
353                        SortDir::Asc => "ASC",
354                        SortDir::Desc => "DESC",
355                    };
356                    format!("{} {}", s.field, dir)
357                })
358                .collect();
359            sql.push_str(&sort_parts.join(", "));
360        }
361
362        // LIMIT/OFFSET clause
363        if let Some(limit) = self.limit {
364            sql.push_str(&format!(" LIMIT {limit}"));
365        }
366        if let Some(offset) = self.offset {
367            sql.push_str(&format!(" OFFSET {offset}"));
368        }
369
370        QueryResult { sql, params }
371    }
372
373    /// Build cursor pagination condition.
374    ///
375    /// Generates keyset-style WHERE conditions based on sort fields and cursor values.
376    /// For single field: `field > $1` (or `<` for DESC)
377    /// For multiple fields: `(a, b) > ($1, $2)` using row comparison.
378    fn build_cursor_condition(
379        &self,
380        cursor: &Cursor,
381        direction: CursorDirection,
382        start_idx: usize,
383    ) -> (String, Vec<Value>, usize) {
384        // If no sorts defined, try using cursor fields directly with ascending order
385        let sort_fields: Vec<SortField> = if self.sorts.is_empty() {
386            cursor
387                .fields
388                .iter()
389                .map(|(name, _)| SortField::new(name.clone(), SortDir::Asc))
390                .collect()
391        } else {
392            self.sorts.clone()
393        };
394
395        if sort_fields.is_empty() {
396            return (String::new(), vec![], start_idx);
397        }
398
399        // Collect values for each sort field from cursor
400        let mut cursor_values: Vec<(&str, &Value)> = Vec::new();
401        for sort in &sort_fields {
402            if let Some((_, value)) = cursor.fields.iter().find(|(name, _)| name == &sort.field) {
403                cursor_values.push((&sort.field, value));
404            }
405        }
406
407        if cursor_values.is_empty() {
408            return (String::new(), vec![], start_idx);
409        }
410
411        let mut idx = start_idx;
412        let mut params = Vec::new();
413
414        if cursor_values.len() == 1 {
415            // Single field: simple comparison
416            let (field, value) = cursor_values[0];
417            let sort = &sort_fields[0];
418            let op = match (direction, sort.dir) {
419                (CursorDirection::After, SortDir::Asc) => ">",
420                (CursorDirection::After, SortDir::Desc) => "<",
421                (CursorDirection::Before, SortDir::Asc) => "<",
422                (CursorDirection::Before, SortDir::Desc) => ">",
423            };
424
425            let sql = format!("{} {} {}", field, op, self.dialect.param(idx));
426            params.push(value.clone());
427            idx += 1;
428
429            (sql, params, idx)
430        } else {
431            // Multiple fields: use row/tuple comparison for efficiency
432            // (a, b, c) > ($1, $2, $3) handles lexicographic ordering correctly
433            let fields: Vec<&str> = cursor_values.iter().map(|(f, _)| *f).collect();
434            let placeholders: Vec<String> = cursor_values
435                .iter()
436                .enumerate()
437                .map(|(i, (_, value))| {
438                    params.push((*value).clone());
439                    self.dialect.param(idx + i)
440                })
441                .collect();
442            idx += cursor_values.len();
443
444            // Determine comparison operator based on primary sort direction
445            let primary_dir = sort_fields[0].dir;
446            let op = match (direction, primary_dir) {
447                (CursorDirection::After, SortDir::Asc) => ">",
448                (CursorDirection::After, SortDir::Desc) => "<",
449                (CursorDirection::Before, SortDir::Asc) => "<",
450                (CursorDirection::Before, SortDir::Desc) => ">",
451            };
452
453            let sql = format!(
454                "({}) {} ({})",
455                fields.join(", "),
456                op,
457                placeholders.join(", ")
458            );
459
460            (sql, params, idx)
461        }
462    }
463}