1use std::collections::HashSet;
17
18use crate::sql::db::table::Value;
19use crate::sql::parser::select::{AggregateArg, AggregateCall, AggregateFn};
20
21#[derive(Debug, Clone)]
25pub enum SumAcc {
26 Int(i64),
27 Real(f64),
28}
29
30impl SumAcc {
31 fn add_int(&mut self, j: i64) {
32 match *self {
33 SumAcc::Int(i) => match i.checked_add(j) {
34 Some(s) => *self = SumAcc::Int(s),
35 None => *self = SumAcc::Real(i as f64 + j as f64),
36 },
37 SumAcc::Real(r) => *self = SumAcc::Real(r + j as f64),
38 }
39 }
40 fn add_real(&mut self, r: f64) {
41 match *self {
42 SumAcc::Int(i) => *self = SumAcc::Real(i as f64 + r),
43 SumAcc::Real(x) => *self = SumAcc::Real(x + r),
44 }
45 }
46 fn as_value(&self) -> Value {
47 match self {
48 SumAcc::Int(i) => Value::Integer(*i),
49 SumAcc::Real(r) => Value::Real(*r),
50 }
51 }
52 fn as_f64(&self) -> f64 {
53 match self {
54 SumAcc::Int(i) => *i as f64,
55 SumAcc::Real(r) => *r,
56 }
57 }
58}
59
60#[derive(Debug, Clone)]
63pub enum AggState {
64 CountStar(i64),
66 Count {
68 non_null: i64,
69 distinct: Option<HashSet<DistinctKey>>,
70 },
71 Sum {
74 acc: SumAcc,
75 all_null: bool,
76 },
77 Avg {
79 acc: SumAcc,
80 n: i64,
81 },
82 Min(Option<Value>),
85 Max(Option<Value>),
86}
87
88impl AggState {
89 pub fn new(call: &AggregateCall) -> Self {
91 match call.func {
92 AggregateFn::Count => match &call.arg {
93 AggregateArg::Star => AggState::CountStar(0),
94 AggregateArg::Column(_) => AggState::Count {
95 non_null: 0,
96 distinct: if call.distinct {
97 Some(HashSet::new())
98 } else {
99 None
100 },
101 },
102 },
103 AggregateFn::Sum => AggState::Sum {
104 acc: SumAcc::Int(0),
105 all_null: true,
106 },
107 AggregateFn::Avg => AggState::Avg {
108 acc: SumAcc::Int(0),
109 n: 0,
110 },
111 AggregateFn::Min => AggState::Min(None),
112 AggregateFn::Max => AggState::Max(None),
113 }
114 }
115
116 pub fn update(&mut self, value: &Value) -> crate::error::Result<()> {
119 match self {
120 AggState::CountStar(c) => *c += 1,
121 AggState::Count { non_null, distinct } => {
122 if !matches!(value, Value::Null) {
123 if let Some(set) = distinct {
124 set.insert(DistinctKey::from_value(value));
125 } else {
126 *non_null += 1;
127 }
128 }
129 }
130 AggState::Sum { acc, all_null } => match value {
131 Value::Null => {}
132 Value::Integer(i) => {
133 *all_null = false;
134 acc.add_int(*i);
135 }
136 Value::Real(r) => {
137 *all_null = false;
138 acc.add_real(*r);
139 }
140 Value::Bool(b) => {
141 *all_null = false;
142 acc.add_int(if *b { 1 } else { 0 });
143 }
144 other => {
145 return Err(crate::error::SQLRiteError::Internal(format!(
146 "SUM expects a numeric column, got {}",
147 other.to_display_string()
148 )));
149 }
150 },
151 AggState::Avg { acc, n } => match value {
152 Value::Null => {}
153 Value::Integer(i) => {
154 acc.add_int(*i);
155 *n += 1;
156 }
157 Value::Real(r) => {
158 acc.add_real(*r);
159 *n += 1;
160 }
161 Value::Bool(b) => {
162 acc.add_int(if *b { 1 } else { 0 });
163 *n += 1;
164 }
165 other => {
166 return Err(crate::error::SQLRiteError::Internal(format!(
167 "AVG expects a numeric column, got {}",
168 other.to_display_string()
169 )));
170 }
171 },
172 AggState::Min(cur) => {
173 if !matches!(value, Value::Null) {
174 match cur {
175 None => *cur = Some(value.clone()),
176 Some(c) => {
177 if compare_values_total(value, c).is_lt() {
178 *cur = Some(value.clone());
179 }
180 }
181 }
182 }
183 }
184 AggState::Max(cur) => {
185 if !matches!(value, Value::Null) {
186 match cur {
187 None => *cur = Some(value.clone()),
188 Some(c) => {
189 if compare_values_total(value, c).is_gt() {
190 *cur = Some(value.clone());
191 }
192 }
193 }
194 }
195 }
196 }
197 Ok(())
198 }
199
200 pub fn finalize(&self) -> Value {
202 match self {
203 AggState::CountStar(c) => Value::Integer(*c),
204 AggState::Count { non_null, distinct } => match distinct {
205 Some(set) => Value::Integer(set.len() as i64),
206 None => Value::Integer(*non_null),
207 },
208 AggState::Sum { acc, all_null } => {
209 if *all_null {
210 Value::Null
211 } else {
212 acc.as_value()
213 }
214 }
215 AggState::Avg { acc, n } => {
216 if *n == 0 {
217 Value::Null
218 } else {
219 Value::Real(acc.as_f64() / (*n as f64))
220 }
221 }
222 AggState::Min(v) | AggState::Max(v) => v.clone().unwrap_or(Value::Null),
223 }
224 }
225}
226
227#[derive(Debug, Clone, Hash, PartialEq, Eq)]
235pub enum DistinctKey {
236 Null,
237 Bool(bool),
238 Int(i64),
239 Real(u64),
240 Text(String),
241 Vector(Vec<u8>),
242}
243
244impl DistinctKey {
245 pub fn from_value(v: &Value) -> Self {
246 match v {
247 Value::Null => DistinctKey::Null,
248 Value::Bool(b) => DistinctKey::Bool(*b),
249 Value::Integer(i) => DistinctKey::Int(*i),
250 Value::Real(r) => DistinctKey::Real(r.to_bits()),
251 Value::Text(s) => DistinctKey::Text(s.clone()),
252 Value::Vector(v) => {
253 let mut bytes = Vec::with_capacity(v.len() * 4);
254 for f in v {
255 bytes.extend_from_slice(&f.to_le_bytes());
256 }
257 DistinctKey::Vector(bytes)
258 }
259 }
260 }
261}
262
263fn compare_values_total(a: &Value, b: &Value) -> std::cmp::Ordering {
268 use std::cmp::Ordering;
269 match (a, b) {
270 (Value::Null, Value::Null) => Ordering::Equal,
271 (Value::Null, _) => Ordering::Less,
272 (_, Value::Null) => Ordering::Greater,
273 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
274 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
275 (Value::Integer(x), Value::Real(y)) => {
276 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
277 }
278 (Value::Real(x), Value::Integer(y)) => {
279 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
280 }
281 (Value::Text(x), Value::Text(y)) => x.cmp(y),
282 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
283 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
284 }
285}
286
/// Matches `text` against a SQL `LIKE`-style `pattern`.
///
/// * `%` matches any run of characters, including an empty one.
/// * `_` matches exactly one character.
/// * A backslash directly before `%`, `_` or another backslash makes that
///   character literal. A backslash anywhere else (including at the very end
///   of the pattern) is an ordinary literal backslash.
///
/// Matching works on `char`s, not bytes. With `case_insensitive` set, only
/// ASCII letters are case-folded.
///
/// The matcher is iterative: it remembers the most recent `%` and, on a
/// mismatch, retries with that `%` absorbing one more character of `text`.
pub fn like_match(text: &str, pattern: &str, case_insensitive: bool) -> bool {
    let haystack: Vec<char> = text.chars().collect();
    let glob: Vec<char> = pattern.chars().collect();

    let mut t = 0usize;
    let mut p = 0usize;
    // (index of the latest `%` in `glob`, index in `haystack` where the text
    // after that `%` is currently assumed to start).
    let mut resume: Option<(usize, usize)> = None;

    while t < haystack.len() {
        // Pattern width to consume if the current token accepts `haystack[t]`;
        // `None` means mismatch (or the pattern ran out before the text did).
        let advance = match glob.get(p).copied() {
            None => None,
            Some('%') => {
                resume = Some((p, t));
                p += 1;
                continue;
            }
            Some('_') => Some(1),
            Some(pc) => {
                let (literal, width) = match glob.get(p + 1) {
                    Some(&next) if pc == '\\' && matches!(next, '%' | '_' | '\\') => (next, 2),
                    _ => (pc, 1),
                };
                if char_eq(haystack[t], literal, case_insensitive) {
                    Some(width)
                } else {
                    None
                }
            }
        };

        match (advance, resume) {
            (Some(width), _) => {
                p += width;
                t += 1;
            }
            // Let the last `%` swallow one more character and try again.
            (None, Some((star_p, star_t))) => {
                p = star_p + 1;
                t = star_t + 1;
                resume = Some((star_p, t));
            }
            (None, None) => return false,
        }
    }

    // Text is used up: whatever is left of the pattern may only be `%`.
    glob[p..].iter().all(|&c| c == '%')
}

/// Compares two characters, optionally ignoring ASCII case.
///
/// `eq_ignore_ascii_case` folds only ASCII letters, so non-ASCII characters
/// are equal only when they are identical.
fn char_eq(a: char, b: char, case_insensitive: bool) -> bool {
    a == b || (case_insensitive && a.eq_ignore_ascii_case(&b))
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Describes `func(x)` (or `func(DISTINCT x)`) over a column named `x`.
    fn on_column(func: AggregateFn, distinct: bool) -> AggregateCall {
        AggregateCall {
            func,
            arg: AggregateArg::Column("x".into()),
            distinct,
        }
    }

    /// Creates the state for `call` and pushes every input through `update`.
    fn fold(call: &AggregateCall, inputs: &[Value]) -> AggState {
        let mut state = AggState::new(call);
        for input in inputs {
            state.update(input).unwrap();
        }
        state
    }

    /// Asserts that `actual` is a `Value::Real` within 1e-9 of `expected`.
    fn assert_real(actual: Value, expected: f64) {
        match actual {
            Value::Real(r) => assert!((r - expected).abs() < 1e-9),
            v => panic!("expected Real, got {:?}", v),
        }
    }

    #[test]
    fn like_simple_literal() {
        for (pattern, expected) in [("apple", true), ("apples", false)] {
            assert_eq!(like_match("apple", pattern, true), expected);
        }
    }

    #[test]
    fn like_percent_wildcard() {
        for pattern in ["a%", "%le", "%pp%"] {
            assert!(like_match("apple", pattern, true));
        }
        assert!(!like_match("banana", "a%", true));
    }

    #[test]
    fn like_underscore_wildcard() {
        let one_between = |text: &str| like_match(text, "a_c", true);
        assert!(one_between("abc"));
        assert!(!one_between("abbc"));
    }

    #[test]
    fn like_case_insensitive_default() {
        assert!(like_match("Apple", "a%", true));
        assert!(like_match("APPLE", "%le", true));
        let case_sensitive = like_match("Apple", "a%", false);
        assert!(!case_sensitive, "case-sensitive should fail");
    }

    #[test]
    fn like_escape_percent_literal() {
        let pattern = "100\\%";
        assert!(like_match("100%", pattern, true));
        assert!(!like_match("100x", pattern, true));
    }

    #[test]
    fn like_no_pathological_recursion() {
        // Many `%` groups followed by a character the text never contains.
        let haystack = "a".repeat(40);
        let matched = like_match(&haystack, "a%a%a%a%a%a%a%a%b", true);
        assert!(!matched);
    }

    #[test]
    fn distinct_key_real_distinguishes_from_int() {
        let int_key = DistinctKey::from_value(&Value::Integer(1));
        let real_key = DistinctKey::from_value(&Value::Real(1.0));
        assert_ne!(
            int_key, real_key,
            "Integer(1) vs Real(1.0) must hash differently"
        );
    }

    #[test]
    fn count_star_includes_nulls() {
        let call = AggregateCall {
            func: AggregateFn::Count,
            arg: AggregateArg::Star,
            distinct: false,
        };
        let state = fold(&call, &[Value::Null, Value::Integer(7), Value::Null]);
        assert_eq!(state.finalize(), Value::Integer(3));
    }

    #[test]
    fn count_col_skips_nulls() {
        let state = fold(
            &on_column(AggregateFn::Count, false),
            &[Value::Null, Value::Integer(7), Value::Null],
        );
        assert_eq!(state.finalize(), Value::Integer(1));
    }

    #[test]
    fn count_distinct_dedupes() {
        let mut inputs: Vec<Value> = [1, 1, 2, 2, 3, 3]
            .iter()
            .map(|&n| Value::Integer(n))
            .collect();
        inputs.push(Value::Null);
        let state = fold(&on_column(AggregateFn::Count, true), &inputs);
        assert_eq!(state.finalize(), Value::Integer(3));
    }

    #[test]
    fn sum_int_stays_int_until_real() {
        let mut total = fold(
            &on_column(AggregateFn::Sum, false),
            &[Value::Integer(2), Value::Integer(3)],
        );
        assert_eq!(total.finalize(), Value::Integer(5));

        total.update(&Value::Real(0.5)).unwrap();
        assert_real(total.finalize(), 5.5);
    }

    #[test]
    fn sum_all_null_is_null() {
        let state = fold(
            &on_column(AggregateFn::Sum, false),
            &[Value::Null, Value::Null],
        );
        assert_eq!(state.finalize(), Value::Null);
    }

    #[test]
    fn avg_always_real() {
        let state = fold(
            &on_column(AggregateFn::Avg, false),
            &[Value::Integer(2), Value::Integer(4)],
        );
        assert_real(state.finalize(), 3.0);
    }

    #[test]
    fn min_max_skip_nulls() {
        let inputs = [
            Value::Null,
            Value::Integer(7),
            Value::Integer(3),
            Value::Integer(9),
            Value::Null,
        ];
        let smallest = fold(&on_column(AggregateFn::Min, false), &inputs);
        let largest = fold(&on_column(AggregateFn::Max, false), &inputs);
        assert_eq!(smallest.finalize(), Value::Integer(3));
        assert_eq!(largest.finalize(), Value::Integer(9));
    }
}