1use std::collections::HashSet;
17
18use crate::sql::db::table::Value;
19use crate::sql::parser::select::{AggregateArg, AggregateCall, AggregateFn};
20
21#[derive(Debug, Clone)]
25pub enum SumAcc {
26 Int(i64),
27 Real(f64),
28}
29
30impl SumAcc {
31 fn add_int(&mut self, j: i64) {
32 match *self {
33 SumAcc::Int(i) => match i.checked_add(j) {
34 Some(s) => *self = SumAcc::Int(s),
35 None => *self = SumAcc::Real(i as f64 + j as f64),
36 },
37 SumAcc::Real(r) => *self = SumAcc::Real(r + j as f64),
38 }
39 }
40 fn add_real(&mut self, r: f64) {
41 match *self {
42 SumAcc::Int(i) => *self = SumAcc::Real(i as f64 + r),
43 SumAcc::Real(x) => *self = SumAcc::Real(x + r),
44 }
45 }
46 fn as_value(&self) -> Value {
47 match self {
48 SumAcc::Int(i) => Value::Integer(*i),
49 SumAcc::Real(r) => Value::Real(*r),
50 }
51 }
52 fn as_f64(&self) -> f64 {
53 match self {
54 SumAcc::Int(i) => *i as f64,
55 SumAcc::Real(r) => *r,
56 }
57 }
58}
59
60#[derive(Debug, Clone)]
63pub enum AggState {
64 CountStar(i64),
66 Count {
68 non_null: i64,
69 distinct: Option<HashSet<DistinctKey>>,
70 },
71 Sum {
74 acc: SumAcc,
75 all_null: bool,
76 },
77 Avg {
79 acc: SumAcc,
80 n: i64,
81 },
82 Min(Option<Value>),
85 Max(Option<Value>),
86}
87
88impl AggState {
89 pub fn new(call: &AggregateCall) -> Self {
91 match call.func {
92 AggregateFn::Count => match &call.arg {
93 AggregateArg::Star => AggState::CountStar(0),
94 AggregateArg::Column { .. } => AggState::Count {
95 non_null: 0,
96 distinct: if call.distinct {
97 Some(HashSet::new())
98 } else {
99 None
100 },
101 },
102 },
103 AggregateFn::Sum => AggState::Sum {
104 acc: SumAcc::Int(0),
105 all_null: true,
106 },
107 AggregateFn::Avg => AggState::Avg {
108 acc: SumAcc::Int(0),
109 n: 0,
110 },
111 AggregateFn::Min => AggState::Min(None),
112 AggregateFn::Max => AggState::Max(None),
113 }
114 }
115
116 pub fn update(&mut self, value: &Value) -> crate::error::Result<()> {
119 match self {
120 AggState::CountStar(c) => *c += 1,
121 AggState::Count { non_null, distinct } => {
122 if !matches!(value, Value::Null) {
123 if let Some(set) = distinct {
124 set.insert(DistinctKey::from_value(value));
125 } else {
126 *non_null += 1;
127 }
128 }
129 }
130 AggState::Sum { acc, all_null } => match value {
131 Value::Null => {}
132 Value::Integer(i) => {
133 *all_null = false;
134 acc.add_int(*i);
135 }
136 Value::Real(r) => {
137 *all_null = false;
138 acc.add_real(*r);
139 }
140 Value::Bool(b) => {
141 *all_null = false;
142 acc.add_int(if *b { 1 } else { 0 });
143 }
144 other => {
145 return Err(crate::error::SQLRiteError::Internal(format!(
146 "SUM expects a numeric column, got {}",
147 other.to_display_string()
148 )));
149 }
150 },
151 AggState::Avg { acc, n } => match value {
152 Value::Null => {}
153 Value::Integer(i) => {
154 acc.add_int(*i);
155 *n += 1;
156 }
157 Value::Real(r) => {
158 acc.add_real(*r);
159 *n += 1;
160 }
161 Value::Bool(b) => {
162 acc.add_int(if *b { 1 } else { 0 });
163 *n += 1;
164 }
165 other => {
166 return Err(crate::error::SQLRiteError::Internal(format!(
167 "AVG expects a numeric column, got {}",
168 other.to_display_string()
169 )));
170 }
171 },
172 AggState::Min(cur) => {
173 if !matches!(value, Value::Null) {
174 match cur {
175 None => *cur = Some(value.clone()),
176 Some(c) => {
177 if compare_values_total(value, c).is_lt() {
178 *cur = Some(value.clone());
179 }
180 }
181 }
182 }
183 }
184 AggState::Max(cur) => {
185 if !matches!(value, Value::Null) {
186 match cur {
187 None => *cur = Some(value.clone()),
188 Some(c) => {
189 if compare_values_total(value, c).is_gt() {
190 *cur = Some(value.clone());
191 }
192 }
193 }
194 }
195 }
196 }
197 Ok(())
198 }
199
200 pub fn finalize(&self) -> Value {
202 match self {
203 AggState::CountStar(c) => Value::Integer(*c),
204 AggState::Count { non_null, distinct } => match distinct {
205 Some(set) => Value::Integer(set.len() as i64),
206 None => Value::Integer(*non_null),
207 },
208 AggState::Sum { acc, all_null } => {
209 if *all_null {
210 Value::Null
211 } else {
212 acc.as_value()
213 }
214 }
215 AggState::Avg { acc, n } => {
216 if *n == 0 {
217 Value::Null
218 } else {
219 Value::Real(acc.as_f64() / (*n as f64))
220 }
221 }
222 AggState::Min(v) | AggState::Max(v) => v.clone().unwrap_or(Value::Null),
223 }
224 }
225}
226
227#[derive(Debug, Clone, Hash, PartialEq, Eq)]
235pub enum DistinctKey {
236 Null,
237 Bool(bool),
238 Int(i64),
239 Real(u64),
240 Text(String),
241 Vector(Vec<u8>),
242}
243
244impl DistinctKey {
245 pub fn from_value(v: &Value) -> Self {
246 match v {
247 Value::Null => DistinctKey::Null,
248 Value::Bool(b) => DistinctKey::Bool(*b),
249 Value::Integer(i) => DistinctKey::Int(*i),
250 Value::Real(r) => DistinctKey::Real(r.to_bits()),
251 Value::Text(s) => DistinctKey::Text(s.clone()),
252 Value::Vector(v) => {
253 let mut bytes = Vec::with_capacity(v.len() * 4);
254 for f in v {
255 bytes.extend_from_slice(&f.to_le_bytes());
256 }
257 DistinctKey::Vector(bytes)
258 }
259 }
260 }
261}
262
263fn compare_values_total(a: &Value, b: &Value) -> std::cmp::Ordering {
268 use std::cmp::Ordering;
269 match (a, b) {
270 (Value::Null, Value::Null) => Ordering::Equal,
271 (Value::Null, _) => Ordering::Less,
272 (_, Value::Null) => Ordering::Greater,
273 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
274 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
275 (Value::Integer(x), Value::Real(y)) => {
276 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
277 }
278 (Value::Real(x), Value::Integer(y)) => {
279 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
280 }
281 (Value::Text(x), Value::Text(y)) => x.cmp(y),
282 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
283 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
284 }
285}
286
/// SQL `LIKE` matching of `text` against `pattern`.
///
/// - `%` matches any run of characters (possibly empty); `_` matches exactly one.
/// - A backslash makes a following `%`, `_` or `\` literal. A backslash before any
///   other character, or at the end of the pattern, is itself an ordinary character.
/// - Matching works on `char`s (Unicode scalar values), not bytes.
/// - With `case_insensitive`, only ASCII letters are folded; all other characters
///   must match exactly.
///
/// The matcher is iterative and only remembers the most recent `%` as a backtrack
/// point. There is no recursion, and the worst case is O(len(text) * len(pattern)).
pub fn like_match(text: &str, pattern: &str, case_insensitive: bool) -> bool {
    let subject: Vec<char> = text.chars().collect();
    let pattern: Vec<char> = pattern.chars().collect();

    let mut s = 0usize; // next subject char to match
    let mut p = 0usize; // next pattern char to match
    // Backtrack point for the most recent `%`: (index of that `%` in `pattern`,
    // first subject index it has not absorbed yet).
    let mut resume: Option<(usize, usize)> = None;

    while s < subject.len() {
        if let Some(&pc) = pattern.get(p) {
            match pc {
                '%' => {
                    // Let `%` match nothing for now; widen it on a later mismatch.
                    resume = Some((p, s));
                    p += 1;
                    continue;
                }
                '_' => {
                    p += 1;
                    s += 1;
                    continue;
                }
                _ => {
                    let (literal, width) = match (pc, pattern.get(p + 1)) {
                        ('\\', Some(&next)) if matches!(next, '%' | '_' | '\\') => (next, 2),
                        _ => (pc, 1),
                    };
                    if char_eq(subject[s], literal, case_insensitive) {
                        p += width;
                        s += 1;
                        continue;
                    }
                }
            }
        }
        // Mismatch or pattern exhausted: make the last `%` swallow one more char.
        match resume {
            Some((star_p, star_s)) => {
                p = star_p + 1;
                s = star_s + 1;
                resume = Some((star_p, s));
            }
            None => return false,
        }
    }
    // Subject fully consumed: whatever is left of the pattern must be all `%`.
    pattern[p..].iter().all(|&c| c == '%')
}

/// Character equality for LIKE. When `case_insensitive` is set, ASCII pairs are
/// compared ignoring case; any pair involving a non-ASCII char must match exactly.
fn char_eq(a: char, b: char, case_insensitive: bool) -> bool {
    if case_insensitive && a.is_ascii() && b.is_ascii() {
        a.eq_ignore_ascii_case(&b)
    } else {
        a == b
    }
}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// An aggregate call over the unqualified column `x`.
    fn col_call(func: AggregateFn, distinct: bool) -> AggregateCall {
        AggregateCall {
            func,
            arg: AggregateArg::Column {
                qualifier: None,
                name: "x".into(),
            },
            distinct,
        }
    }

    /// Feeds `inputs` into a fresh state for `call` and returns the finalized value.
    fn aggregate(call: &AggregateCall, inputs: &[Value]) -> Value {
        let mut state = AggState::new(call);
        for input in inputs {
            state.update(input).unwrap();
        }
        state.finalize()
    }

    #[test]
    fn like_simple_literal() {
        assert!(like_match("apple", "apple", true));
        assert!(!like_match("apple", "apples", true));
    }

    #[test]
    fn like_percent_wildcard() {
        for pat in ["a%", "%le", "%pp%"] {
            assert!(like_match("apple", pat, true));
        }
        assert!(!like_match("banana", "a%", true));
    }

    #[test]
    fn like_underscore_wildcard() {
        assert!(like_match("abc", "a_c", true));
        assert!(!like_match("abbc", "a_c", true));
    }

    #[test]
    fn like_case_insensitive_default() {
        assert!(like_match("Apple", "a%", true));
        assert!(like_match("APPLE", "%le", true));
        assert!(
            !like_match("Apple", "a%", false),
            "case-sensitive should fail"
        );
    }

    #[test]
    fn like_escape_percent_literal() {
        let pat = "100\\%";
        assert!(like_match("100%", pat, true));
        assert!(!like_match("100x", pat, true));
    }

    #[test]
    fn like_no_pathological_recursion() {
        // Many `%`s followed by a literal that never matches must still finish quickly.
        let text = "a".repeat(40);
        assert!(!like_match(&text, "a%a%a%a%a%a%a%a%b", true));
    }

    #[test]
    fn distinct_key_real_distinguishes_from_int() {
        assert_ne!(
            DistinctKey::from_value(&Value::Integer(1)),
            DistinctKey::from_value(&Value::Real(1.0)),
            "Integer(1) vs Real(1.0) must hash differently"
        );
    }

    #[test]
    fn count_star_includes_nulls() {
        let call = AggregateCall {
            func: AggregateFn::Count,
            arg: AggregateArg::Star,
            distinct: false,
        };
        let inputs = [Value::Null, Value::Integer(7), Value::Null];
        assert_eq!(aggregate(&call, &inputs), Value::Integer(3));
    }

    #[test]
    fn count_col_skips_nulls() {
        let inputs = [Value::Null, Value::Integer(7), Value::Null];
        assert_eq!(
            aggregate(&col_call(AggregateFn::Count, false), &inputs),
            Value::Integer(1)
        );
    }

    #[test]
    fn count_distinct_dedupes() {
        let inputs: Vec<Value> = [1, 1, 2, 2, 3, 3]
            .iter()
            .map(|&v| Value::Integer(v))
            .chain(std::iter::once(Value::Null))
            .collect();
        assert_eq!(
            aggregate(&col_call(AggregateFn::Count, true), &inputs),
            Value::Integer(3)
        );
    }

    #[test]
    fn sum_int_stays_int_until_real() {
        let mut state = AggState::new(&col_call(AggregateFn::Sum, false));
        for v in [2, 3] {
            state.update(&Value::Integer(v)).unwrap();
        }
        assert_eq!(state.finalize(), Value::Integer(5));

        state.update(&Value::Real(0.5)).unwrap();
        match state.finalize() {
            Value::Real(r) => assert!((r - 5.5).abs() < 1e-9),
            v => panic!("expected Real, got {:?}", v),
        }
    }

    #[test]
    fn sum_all_null_is_null() {
        let inputs = [Value::Null, Value::Null];
        assert_eq!(
            aggregate(&col_call(AggregateFn::Sum, false), &inputs),
            Value::Null
        );
    }

    #[test]
    fn avg_always_real() {
        let inputs = [Value::Integer(2), Value::Integer(4)];
        match aggregate(&col_call(AggregateFn::Avg, false), &inputs) {
            Value::Real(r) => assert!((r - 3.0).abs() < 1e-9),
            v => panic!("expected Real, got {:?}", v),
        }
    }

    #[test]
    fn min_max_skip_nulls() {
        let inputs = [
            Value::Null,
            Value::Integer(7),
            Value::Integer(3),
            Value::Integer(9),
            Value::Null,
        ];
        assert_eq!(
            aggregate(&col_call(AggregateFn::Min, false), &inputs),
            Value::Integer(3)
        );
        assert_eq!(
            aggregate(&col_call(AggregateFn::Max, false), &inputs),
            Value::Integer(9)
        );
    }
}