Skip to main content

sqlrite/sql/
agg.rs

1//! SQLR-3 aggregate runtime.
2//!
3//! Three concerns live here:
4//!   1. `AggState` — per-group accumulator state for COUNT/SUM/AVG/MIN/MAX,
5//!      with SQLite-style numeric type rules (Sum stays Integer until a
6//!      Real input or i64 overflow forces a one-time promotion to f64).
7//!   2. `DistinctKey` — a hashable typed wrapper around `Value`, used both
8//!      as the per-row key for GROUP BY and as the dedupe key for
9//!      `COUNT(DISTINCT col)` and `SELECT DISTINCT`.
10//!   3. `like_match` — the iterative two-pointer LIKE matcher (case
11//!      insensitive ASCII to match SQLite's default).
12//!
13//! All of this is pure-functional in the sense that nothing here touches
14//! the `Database`/`Table`. The executor walks rows and feeds values in.
15
16use std::collections::HashSet;
17
18use crate::sql::db::table::Value;
19use crate::sql::parser::select::{AggregateArg, AggregateCall, AggregateFn};
20
/// Running total shared by `SUM` and `AVG`, following SQLite's numeric
/// typing rules.
///
/// The total is held as an exact `i64` for as long as every input has
/// been an integer and no addition has overflowed. The first Real input
/// (or the first `i64` overflow) switches the accumulator to `f64`, and
/// nothing ever switches it back.
#[derive(Debug, Clone)]
pub enum SumAcc {
    /// Exact integer total; no Real input and no overflow seen so far.
    Int(i64),
    /// Floating-point total; the accumulator has been promoted.
    Real(f64),
}
29
30impl SumAcc {
31    fn add_int(&mut self, j: i64) {
32        match *self {
33            SumAcc::Int(i) => match i.checked_add(j) {
34                Some(s) => *self = SumAcc::Int(s),
35                None => *self = SumAcc::Real(i as f64 + j as f64),
36            },
37            SumAcc::Real(r) => *self = SumAcc::Real(r + j as f64),
38        }
39    }
40    fn add_real(&mut self, r: f64) {
41        match *self {
42            SumAcc::Int(i) => *self = SumAcc::Real(i as f64 + r),
43            SumAcc::Real(x) => *self = SumAcc::Real(x + r),
44        }
45    }
46    fn as_value(&self) -> Value {
47        match self {
48            SumAcc::Int(i) => Value::Integer(*i),
49            SumAcc::Real(r) => Value::Real(*r),
50        }
51    }
52    fn as_f64(&self) -> f64 {
53        match self {
54            SumAcc::Int(i) => *i as f64,
55            SumAcc::Real(r) => *r,
56        }
57    }
58}
59
60/// Per-aggregate accumulator. One instance per (group, projection-slot)
61/// pair lives for the duration of the SELECT.
62#[derive(Debug, Clone)]
63pub enum AggState {
64    /// `COUNT(*)` — counts every row, including all-NULL rows.
65    CountStar(i64),
66    /// `COUNT(col)` — counts non-NULL values, optionally with DISTINCT.
67    Count {
68        non_null: i64,
69        distinct: Option<HashSet<DistinctKey>>,
70    },
71    /// `SUM(col)` — skips NULLs; `all_null` tracks the SQL semantic that
72    /// SUM over an all-NULL or empty set yields NULL (not 0).
73    Sum {
74        acc: SumAcc,
75        all_null: bool,
76    },
77    /// `AVG(col)` — always returns Real (or NULL on empty / all-NULL).
78    Avg {
79        acc: SumAcc,
80        n: i64,
81    },
82    /// `MIN(col)` / `MAX(col)` — track the running winner (or None until
83    /// the first non-NULL input).
84    Min(Option<Value>),
85    Max(Option<Value>),
86}
87
88impl AggState {
89    /// Construct the initial accumulator for an aggregate call.
90    pub fn new(call: &AggregateCall) -> Self {
91        match call.func {
92            AggregateFn::Count => match &call.arg {
93                AggregateArg::Star => AggState::CountStar(0),
94                AggregateArg::Column(_) => AggState::Count {
95                    non_null: 0,
96                    distinct: if call.distinct {
97                        Some(HashSet::new())
98                    } else {
99                        None
100                    },
101                },
102            },
103            AggregateFn::Sum => AggState::Sum {
104                acc: SumAcc::Int(0),
105                all_null: true,
106            },
107            AggregateFn::Avg => AggState::Avg {
108                acc: SumAcc::Int(0),
109                n: 0,
110            },
111            AggregateFn::Min => AggState::Min(None),
112            AggregateFn::Max => AggState::Max(None),
113        }
114    }
115
116    /// Fold one row's value into the accumulator.
117    /// For `COUNT(*)`, the value is irrelevant — pass anything.
118    pub fn update(&mut self, value: &Value) -> crate::error::Result<()> {
119        match self {
120            AggState::CountStar(c) => *c += 1,
121            AggState::Count { non_null, distinct } => {
122                if !matches!(value, Value::Null) {
123                    if let Some(set) = distinct {
124                        set.insert(DistinctKey::from_value(value));
125                    } else {
126                        *non_null += 1;
127                    }
128                }
129            }
130            AggState::Sum { acc, all_null } => match value {
131                Value::Null => {}
132                Value::Integer(i) => {
133                    *all_null = false;
134                    acc.add_int(*i);
135                }
136                Value::Real(r) => {
137                    *all_null = false;
138                    acc.add_real(*r);
139                }
140                Value::Bool(b) => {
141                    *all_null = false;
142                    acc.add_int(if *b { 1 } else { 0 });
143                }
144                other => {
145                    return Err(crate::error::SQLRiteError::Internal(format!(
146                        "SUM expects a numeric column, got {}",
147                        other.to_display_string()
148                    )));
149                }
150            },
151            AggState::Avg { acc, n } => match value {
152                Value::Null => {}
153                Value::Integer(i) => {
154                    acc.add_int(*i);
155                    *n += 1;
156                }
157                Value::Real(r) => {
158                    acc.add_real(*r);
159                    *n += 1;
160                }
161                Value::Bool(b) => {
162                    acc.add_int(if *b { 1 } else { 0 });
163                    *n += 1;
164                }
165                other => {
166                    return Err(crate::error::SQLRiteError::Internal(format!(
167                        "AVG expects a numeric column, got {}",
168                        other.to_display_string()
169                    )));
170                }
171            },
172            AggState::Min(cur) => {
173                if !matches!(value, Value::Null) {
174                    match cur {
175                        None => *cur = Some(value.clone()),
176                        Some(c) => {
177                            if compare_values_total(value, c).is_lt() {
178                                *cur = Some(value.clone());
179                            }
180                        }
181                    }
182                }
183            }
184            AggState::Max(cur) => {
185                if !matches!(value, Value::Null) {
186                    match cur {
187                        None => *cur = Some(value.clone()),
188                        Some(c) => {
189                            if compare_values_total(value, c).is_gt() {
190                                *cur = Some(value.clone());
191                            }
192                        }
193                    }
194                }
195            }
196        }
197        Ok(())
198    }
199
200    /// Produce the final SQL value emitted for this group.
201    pub fn finalize(&self) -> Value {
202        match self {
203            AggState::CountStar(c) => Value::Integer(*c),
204            AggState::Count { non_null, distinct } => match distinct {
205                Some(set) => Value::Integer(set.len() as i64),
206                None => Value::Integer(*non_null),
207            },
208            AggState::Sum { acc, all_null } => {
209                if *all_null {
210                    Value::Null
211                } else {
212                    acc.as_value()
213                }
214            }
215            AggState::Avg { acc, n } => {
216                if *n == 0 {
217                    Value::Null
218                } else {
219                    Value::Real(acc.as_f64() / (*n as f64))
220                }
221            }
222            AggState::Min(v) | AggState::Max(v) => v.clone().unwrap_or(Value::Null),
223        }
224    }
225}
226
/// Hashable, `Eq`-comparable stand-in for a `Value`.
///
/// Used as one element of a GROUP BY key and as the set entry behind
/// `COUNT(DISTINCT col)`. `Value` itself cannot derive `Hash`/`Eq`
/// because of its `Real(f64)` variant (`f64` implements neither), so
/// floats are stored here as their `f64::to_bits` encoding instead. A
/// consequence is that NaNs with different payloads are different keys;
/// no attempt is made to unify them.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub enum DistinctKey {
    /// SQL NULL.
    Null,
    /// A boolean value.
    Bool(bool),
    /// An integer value.
    Int(i64),
    /// A Real value, stored as the `u64` bit encoding of the `f64`.
    Real(u64),
    /// A text value.
    Text(String),
    /// A vector value, flattened to the bytes of its elements.
    Vector(Vec<u8>),
}
243
244impl DistinctKey {
245    pub fn from_value(v: &Value) -> Self {
246        match v {
247            Value::Null => DistinctKey::Null,
248            Value::Bool(b) => DistinctKey::Bool(*b),
249            Value::Integer(i) => DistinctKey::Int(*i),
250            Value::Real(r) => DistinctKey::Real(r.to_bits()),
251            Value::Text(s) => DistinctKey::Text(s.clone()),
252            Value::Vector(v) => {
253                let mut bytes = Vec::with_capacity(v.len() * 4);
254                for f in v {
255                    bytes.extend_from_slice(&f.to_le_bytes());
256                }
257                DistinctKey::Vector(bytes)
258            }
259        }
260    }
261}
262
263/// Total-order comparison used by MIN/MAX. Mirrors the executor's
264/// `compare_values` semantics (Int↔Real cross-coerce; otherwise stringify).
265/// Kept separate to avoid a dependency from this module back into
266/// executor.rs's private comparator.
267fn compare_values_total(a: &Value, b: &Value) -> std::cmp::Ordering {
268    use std::cmp::Ordering;
269    match (a, b) {
270        (Value::Null, Value::Null) => Ordering::Equal,
271        (Value::Null, _) => Ordering::Less,
272        (_, Value::Null) => Ordering::Greater,
273        (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
274        (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
275        (Value::Integer(x), Value::Real(y)) => {
276            (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
277        }
278        (Value::Real(x), Value::Integer(y)) => {
279            x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
280        }
281        (Value::Text(x), Value::Text(y)) => x.cmp(y),
282        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
283        (x, y) => x.to_display_string().cmp(&y.to_display_string()),
284    }
285}
286
287/// SQL `LIKE` matcher.
288///
289/// Wildcards: `%` matches any (possibly empty) char sequence; `_`
290/// matches exactly one char. `\` escapes the next char (so `\%` matches
291/// a literal percent). When `case_insensitive` is true, ASCII letters
292/// fold; non-ASCII characters compare by code-point (we don't pull in
293/// Unicode case folding for v1).
294///
295/// Iterative two-pointer with backtracking — no recursion, so adversarial
296/// patterns like `%a%a%a%a%a%b` against `aaaa…aa` can't blow the stack.
297/// Worst case is O(|text| · |pattern|).
298pub fn like_match(text: &str, pattern: &str, case_insensitive: bool) -> bool {
299    let text: Vec<char> = text.chars().collect();
300    let pat: Vec<char> = pattern.chars().collect();
301    let n = text.len();
302    let m = pat.len();
303
304    let mut ti = 0usize;
305    let mut pi = 0usize;
306    // Backtrack point: the last position where we saw `%` and committed to
307    // matching zero characters with it.
308    let mut star_ti: Option<usize> = None;
309    let mut star_pi: Option<usize> = None;
310
311    while ti < n {
312        if pi < m {
313            let pc = pat[pi];
314            if pc == '%' {
315                star_pi = Some(pi);
316                star_ti = Some(ti);
317                pi += 1;
318                continue;
319            }
320            if pc == '_' {
321                pi += 1;
322                ti += 1;
323                continue;
324            }
325            // Escape support: `\X` matches a literal X for X in {%, _, \}.
326            // Outside that set the backslash is itself literal (matches
327            // SQLite's loose default).
328            let (effective_pat, advance) = if pc == '\\' && pi + 1 < m {
329                let nxt = pat[pi + 1];
330                if nxt == '%' || nxt == '_' || nxt == '\\' {
331                    (nxt, 2)
332                } else {
333                    (pc, 1)
334                }
335            } else {
336                (pc, 1)
337            };
338            if char_eq(text[ti], effective_pat, case_insensitive) {
339                pi += advance;
340                ti += 1;
341                continue;
342            }
343        }
344        // Mismatch (or pattern exhausted before text). If a backtrack point
345        // exists, expand the last `%` to absorb one more char and retry.
346        if let (Some(spi), Some(sti)) = (star_pi, star_ti) {
347            pi = spi + 1;
348            star_ti = Some(sti + 1);
349            ti = sti + 1;
350        } else {
351            return false;
352        }
353    }
354    // Text exhausted; pattern must be done (or all that's left is `%`).
355    while pi < m && pat[pi] == '%' {
356        pi += 1;
357    }
358    pi == m
359}
360
/// Compares two characters for `like_match`.
///
/// Identical characters always match. When `case_insensitive` is set, two
/// ASCII characters additionally match if they differ only in letter
/// case; anything non-ASCII has to be identical.
fn char_eq(a: char, b: char, case_insensitive: bool) -> bool {
    a == b || (case_insensitive && a.is_ascii() && b.is_ascii() && a.eq_ignore_ascii_case(&b))
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Describes `func(x)` — or `func(DISTINCT x)` — over a column named `x`.
    fn column_call(func: AggregateFn, distinct: bool) -> AggregateCall {
        AggregateCall {
            func,
            arg: AggregateArg::Column("x".into()),
            distinct,
        }
    }

    /// Creates a fresh accumulator for `call` and folds `inputs` into it
    /// in order.
    fn fed(call: &AggregateCall, inputs: &[Value]) -> AggState {
        let mut state = AggState::new(call);
        for input in inputs {
            state.update(input).unwrap();
        }
        state
    }

    /// Runs a table of `(text, pattern, expected)` cases through the
    /// case-insensitive matcher.
    fn check_like(cases: &[(&str, &str, bool)]) {
        for &(text, pattern, expected) in cases {
            assert_eq!(
                like_match(text, pattern, true),
                expected,
                "{:?} LIKE {:?}",
                text,
                pattern
            );
        }
    }

    #[test]
    fn like_simple_literal() {
        check_like(&[("apple", "apple", true), ("apple", "apples", false)]);
    }

    #[test]
    fn like_percent_wildcard() {
        check_like(&[
            ("apple", "a%", true),
            ("apple", "%le", true),
            ("apple", "%pp%", true),
            ("banana", "a%", false),
        ]);
    }

    #[test]
    fn like_underscore_wildcard() {
        check_like(&[("abc", "a_c", true), ("abbc", "a_c", false)]);
    }

    #[test]
    fn like_case_insensitive_default() {
        check_like(&[("Apple", "a%", true), ("APPLE", "%le", true)]);
        assert!(
            !like_match("Apple", "a%", false),
            "case-sensitive should fail"
        );
    }

    #[test]
    fn like_escape_percent_literal() {
        // The pattern `100\%` must match the literal text "100%" only.
        check_like(&[("100%", "100\\%", true), ("100x", "100\\%", false)]);
    }

    #[test]
    fn like_no_pathological_recursion() {
        // Classic stress input for naive exponential matchers. This is
        // mostly a smoke test: a recursive matcher would hang or overflow
        // the stack here instead of returning false.
        let text = "a".repeat(40);
        let pat = "a%a%a%a%a%a%a%a%b";
        assert!(!like_match(&text, pat, true));
    }

    #[test]
    fn distinct_key_real_distinguishes_from_int() {
        let int_key = DistinctKey::from_value(&Value::Integer(1));
        let real_key = DistinctKey::from_value(&Value::Real(1.0));
        assert_ne!(
            int_key, real_key,
            "Integer(1) vs Real(1.0) must hash differently"
        );
    }

    #[test]
    fn count_star_includes_nulls() {
        let call = AggregateCall {
            func: AggregateFn::Count,
            arg: AggregateArg::Star,
            distinct: false,
        };
        let state = fed(&call, &[Value::Null, Value::Integer(7), Value::Null]);
        assert_eq!(state.finalize(), Value::Integer(3));
    }

    #[test]
    fn count_col_skips_nulls() {
        let state = fed(
            &column_call(AggregateFn::Count, false),
            &[Value::Null, Value::Integer(7), Value::Null],
        );
        assert_eq!(state.finalize(), Value::Integer(1));
    }

    #[test]
    fn count_distinct_dedupes() {
        let mut inputs: Vec<Value> = [1, 1, 2, 2, 3, 3]
            .iter()
            .map(|&n| Value::Integer(n))
            .collect();
        inputs.push(Value::Null);
        let state = fed(&column_call(AggregateFn::Count, true), &inputs);
        assert_eq!(state.finalize(), Value::Integer(3));
    }

    #[test]
    fn sum_int_stays_int_until_real() {
        let mut state = fed(
            &column_call(AggregateFn::Sum, false),
            &[Value::Integer(2), Value::Integer(3)],
        );
        assert_eq!(state.finalize(), Value::Integer(5));

        // A single Real input must flip the total to Real.
        state.update(&Value::Real(0.5)).unwrap();
        match state.finalize() {
            Value::Real(r) => assert!((r - 5.5).abs() < 1e-9),
            v => panic!("expected Real, got {:?}", v),
        }
    }

    #[test]
    fn sum_all_null_is_null() {
        let state = fed(
            &column_call(AggregateFn::Sum, false),
            &[Value::Null, Value::Null],
        );
        assert_eq!(state.finalize(), Value::Null);
    }

    #[test]
    fn avg_always_real() {
        let state = fed(
            &column_call(AggregateFn::Avg, false),
            &[Value::Integer(2), Value::Integer(4)],
        );
        match state.finalize() {
            Value::Real(r) => assert!((r - 3.0).abs() < 1e-9),
            v => panic!("expected Real, got {:?}", v),
        }
    }

    #[test]
    fn min_max_skip_nulls() {
        let inputs = [
            Value::Null,
            Value::Integer(7),
            Value::Integer(3),
            Value::Integer(9),
            Value::Null,
        ];
        let smallest = fed(&column_call(AggregateFn::Min, false), &inputs);
        let largest = fed(&column_call(AggregateFn::Max, false), &inputs);
        assert_eq!(smallest.finalize(), Value::Integer(3));
        assert_eq!(largest.finalize(), Value::Integer(9));
    }
}