1use std::cmp::Ordering;
11use std::collections::HashSet;
12
13use nodedb_types::Value;
14
15use crate::msgpack_scan::compare::compare_field_bytes;
16use crate::msgpack_scan::field::extract_field;
17use crate::msgpack_scan::reader::{read_f64, read_null, read_str};
18use crate::value_ops;
19
/// Computes one aggregate `op` over a set of msgpack-encoded documents.
///
/// Each entry in `docs` is a complete msgpack-encoded document. Values are
/// taken from `field` via a byte-level field scan, unless `expr` is `Some`,
/// in which case the expression is evaluated against the fully decoded
/// document instead.
///
/// For "approx_percentile", "approx_topk" and "percentile_cont", `field`
/// may carry a parameter prefix separated by ':' — e.g. "0.9:latency"
/// (percentile) or "3:category" (k). An unparsable prefix yields
/// `Value::Null`; without a prefix the defaults are 0.5 and 10.
///
/// Unknown `op` strings return `Value::Null`.
pub fn compute_aggregate_binary(
    op: &str,
    field: &str,
    expr: Option<&crate::expr::SqlExpr>,
    docs: &[&[u8]],
) -> Value {
    match op {
        "count" => {
            // COUNT(*) with no expression is just the number of rows;
            // otherwise only non-null extracted values are counted
            // (SQL COUNT(col) semantics).
            if field == "*" && expr.is_none() {
                Value::Integer(docs.len() as i64)
            } else {
                let count = docs
                    .iter()
                    .filter_map(|d| extract_as_value(d, field, expr))
                    .filter(|v| !v.is_null())
                    .count();
                Value::Integer(count as i64)
            }
        }

        // Documents whose field is missing or non-numeric are skipped;
        // an empty input therefore sums to 0.0.
        "sum" => {
            let total: f64 = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, field, expr))
                .sum();
            Value::Float(total)
        }

        "avg" => {
            // Single pass accumulating sum and count together.
            let (sum, count) = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, field, expr))
                .fold((0.0f64, 0u64), |(s, c), v| (s + v, c + 1));
            if count == 0 {
                Value::Null
            } else {
                Value::Float(sum / count as f64)
            }
        }

        "min" => find_minmax(docs, field, expr, false),
        "max" => find_minmax(docs, field, expr, true),

        "count_distinct" => {
            // Distinctness is decided on the raw msgpack byte encoding of
            // each value; nulls are excluded.
            let mut seen = HashSet::new();
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    seen.insert(bytes);
                }
            }
            Value::Integer(seen.len() as i64)
        }

        // Variance/stddev variants share stat_aggregate; the final bool
        // selects the population divisor n (true) or sample n - 1 (false).
        "stddev" | "stddev_pop" => {
            stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), true)
        }

        "stddev_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), false),

        "variance" | "var_pop" => stat_aggregate(docs, field, expr, |variance, _n| variance, true),

        "var_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance, false),

        "array_agg" => {
            // Collects all non-null values in document order.
            let values: Vec<Value> = docs
                .iter()
                .filter_map(|d| extract_as_value(d, field, expr))
                .filter(|v| !v.is_null())
                .collect();
            Value::Array(values)
        }

        "array_agg_distinct" => {
            // First-occurrence order is preserved; duplicates are detected
            // via the msgpack encoding of each value.
            let mut seen_bytes = HashSet::new();
            let mut values = Vec::new();
            for doc in docs {
                if let Some(expr) = expr {
                    let Some(val) = eval_expr_on_doc(doc, expr) else {
                        continue;
                    };
                    if val.is_null() {
                        continue;
                    }
                    // Re-encode the evaluated value to get a hashable key.
                    let bytes = zerompk::to_msgpack_vec(&val).unwrap_or_default();
                    if seen_bytes.insert(bytes) {
                        values.push(val);
                    }
                } else if let Some(bytes) = extract_value_bytes(doc, field, None)
                    && !value_bytes_are_null(&bytes)
                    && seen_bytes.insert(bytes)
                    && let Some(v) = value_from_field(doc, field)
                {
                    values.push(v);
                }
            }
            Value::Array(values)
        }

        "string_agg" | "group_concat" => {
            // NOTE(review): separator is hard-coded to "," — this interface
            // has no way to pass a custom delimiter.
            let values: Vec<String> = docs
                .iter()
                .filter_map(|d| extract_str_val(d, field, expr))
                .collect();
            Value::String(values.join(","))
        }

        "approx_count_distinct" => {
            // HyperLogLog fed with a 64-bit hash of each value's msgpack bytes.
            let mut hll = nodedb_types::approx::HyperLogLog::new();
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    let hash = hash_bytes(&bytes);
                    hll.add(hash);
                }
            }
            Value::Integer(hll.estimate().round() as i64)
        }

        "approx_percentile" => {
            // Optional "p:field" prefix selects the percentile (default 0.5).
            let (pct, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<f64>() {
                    Ok(p) => (p, &field[idx + 1..]),
                    Err(_) => return Value::Null,
                }
            } else {
                (0.5, field)
            };
            let mut digest = nodedb_types::approx::TDigest::new();
            for doc in docs {
                if let Some(v) = extract_f64_val(doc, actual_field, expr) {
                    digest.add(v);
                }
            }
            let result = digest.quantile(pct);
            // NaN is treated as "no numeric values seen" -> NULL.
            if result.is_nan() {
                Value::Null
            } else {
                Value::Float(result)
            }
        }

        "approx_topk" => {
            // Optional "k:field" prefix sets the number of heavy hitters
            // (default 10).
            let (k, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<usize>() {
                    Ok(k) => (k, &field[idx + 1..]),
                    Err(_) => return Value::Null,
                }
            } else {
                (10, field)
            };
            let mut ss = nodedb_types::approx::SpaceSaving::new(k);
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, actual_field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    ss.add(hash_bytes(&bytes));
                }
            }
            let top = ss.top_k();
            // Each entry is reported as {item, count, error}; `item` is the
            // 64-bit hash of the value, not the decoded value itself.
            let arr: Vec<Value> = top
                .into_iter()
                .map(|(item, count, error)| {
                    Value::Object(
                        [
                            ("item".to_string(), Value::Integer(item as i64)),
                            ("count".to_string(), Value::Integer(count as i64)),
                            ("error".to_string(), Value::Integer(error as i64)),
                        ]
                        .into_iter()
                        .collect(),
                    )
                })
                .collect();
            Value::Array(arr)
        }

        "percentile_cont" => {
            // Exact continuous percentile with linear interpolation, in the
            // style of SQL PERCENTILE_CONT. Optional "p:field" prefix
            // (default 0.5 = median).
            let (pct, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<f64>() {
                    Ok(p) => (p, &field[idx + 1..]),
                    Err(_) => return Value::Null,
                }
            } else {
                (0.5, field)
            };
            let mut values: Vec<f64> = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, actual_field, expr))
                .collect();
            if values.is_empty() {
                return Value::Null;
            }
            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
            // Fractional rank in [0, len-1]; clamp guards out-of-range pct.
            let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64);
            let lower = idx.floor() as usize;
            let upper = idx.ceil() as usize;
            let frac = idx - lower as f64;
            // Linear interpolation between the surrounding order statistics.
            let result = values[lower] * (1.0 - frac) + values[upper] * frac;
            Value::Float(result)
        }

        // Unknown aggregate name.
        _ => Value::Null,
    }
}
238
239#[inline]
244fn eval_expr_on_doc(doc: &[u8], expr: &crate::expr::SqlExpr) -> Option<Value> {
245 let doc_val = nodedb_types::json_msgpack::value_from_msgpack(doc).ok()?;
246 Some(expr.eval(&doc_val))
247}
248
249#[inline]
251fn extract_f64_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<f64> {
252 if let Some(expr) = expr {
253 return value_ops::value_to_f64(&eval_expr_on_doc(doc, expr)?, false);
254 }
255 let (start, _end) = extract_field(doc, 0, field)?;
256 read_f64(doc, start)
257}
258
259fn extract_str_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<String> {
261 if let Some(expr) = expr {
262 return Some(value_ops::value_to_display_string(&eval_expr_on_doc(
263 doc, expr,
264 )?));
265 }
266 let (start, _end) = extract_field(doc, 0, field)?;
267 read_str(doc, start).map(|s| s.to_string())
268}
269
270fn extract_as_value(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<Value> {
273 if let Some(expr) = expr {
274 return eval_expr_on_doc(doc, expr);
275 }
276 value_from_field(doc, field)
277}
278
279#[inline]
280fn value_from_field(doc: &[u8], field: &str) -> Option<Value> {
281 let (start, end) = extract_field(doc, 0, field)?;
282 if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
284 return Some(v);
285 }
286 let field_bytes = &doc[start..end];
288 nodedb_types::json_msgpack::value_from_msgpack(field_bytes).ok()
289}
290
/// Finds the minimum (`want_max == false`) or maximum (`want_max == true`)
/// value of `field` (or of `expr`, when given) across `docs`.
///
/// Null and missing values are skipped. Returns `Value::Null` when no
/// usable value is found.
fn find_minmax(
    docs: &[&[u8]],
    field: &str,
    expr: Option<&crate::expr::SqlExpr>,
    want_max: bool,
) -> Value {
    // Expression path: evaluate per document and compare decoded values.
    if let Some(expr) = expr {
        let mut best: Option<Value> = None;
        for doc in docs {
            let Some(value) = eval_expr_on_doc(doc, expr) else {
                continue;
            };
            if value.is_null() {
                continue;
            }
            let replace = match &best {
                None => true,
                Some(current) => {
                    let ord = value_ops::compare_values(&value, current);
                    if want_max {
                        ord == Ordering::Greater
                    } else {
                        ord == Ordering::Less
                    }
                }
            };
            if replace {
                best = Some(value);
            }
        }
        return best.unwrap_or(Value::Null);
    }

    // Field path: compare raw msgpack byte ranges without decoding each
    // candidate; only the winning value is decoded at the end.
    let mut best_doc: Option<&[u8]> = None;
    let mut best_range: Option<(usize, usize)> = None;

    for doc in docs {
        if let Some(range) = extract_field(doc, 0, field) {
            // Explicit msgpack nil is skipped, like a missing field.
            if read_null(doc, range.0) {
                continue;
            }
            match best_range {
                None => {
                    best_doc = Some(doc);
                    best_range = Some(range);
                }
                Some(br) => {
                    // best_doc and best_range are always set together, so
                    // this let-else is purely defensive.
                    let Some(bd) = best_doc else { continue };
                    let cmp = compare_field_bytes(doc, range, bd, br);
                    let replace = if want_max {
                        cmp == Ordering::Greater
                    } else {
                        cmp == Ordering::Less
                    };
                    if replace {
                        best_doc = Some(doc);
                        best_range = Some(range);
                    }
                }
            }
        }
    }

    match (best_doc, best_range) {
        (Some(doc), Some((start, end))) => {
            // Fast scalar read first; fall back to a full msgpack decode of
            // the field's byte span.
            if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
                return v;
            }
            let bytes = &doc[start..end];
            nodedb_types::json_msgpack::value_from_msgpack(bytes).unwrap_or(Value::Null)
        }
        _ => Value::Null,
    }
}
368
369fn stat_aggregate(
372 docs: &[&[u8]],
373 field: &str,
374 expr: Option<&crate::expr::SqlExpr>,
375 finalize: fn(f64, usize) -> f64,
376 population: bool,
377) -> Value {
378 let values: Vec<f64> = docs
379 .iter()
380 .filter_map(|d| extract_f64_val(d, field, expr))
381 .collect();
382 if values.len() < 2 {
383 return Value::Null;
384 }
385 let mean = values.iter().sum::<f64>() / values.len() as f64;
386 let divisor = if population {
387 values.len() as f64
388 } else {
389 (values.len() - 1) as f64
390 };
391 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / divisor;
392 Value::Float(finalize(variance, values.len()))
393}
394
395fn extract_value_bytes(
396 doc: &[u8],
397 field: &str,
398 expr: Option<&crate::expr::SqlExpr>,
399) -> Option<Vec<u8>> {
400 if let Some(expr) = expr {
401 let val = eval_expr_on_doc(doc, expr)?;
402 return nodedb_types::json_msgpack::value_to_msgpack(&val).ok();
403 }
404 let (start, end) = extract_field(doc, 0, field)?;
405 Some(doc[start..end].to_vec())
406}
407
/// True when `bytes` is exactly the single-byte msgpack nil marker (0xc0).
fn value_bytes_are_null(bytes: &[u8]) -> bool {
    matches!(bytes, [0xc0])
}
412
/// 64-bit FNV-1a hash of `bytes` (offset basis 0xcbf29ce484222325,
/// prime 0x100000001b3).
fn hash_bytes(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    bytes
        .iter()
        .fold(FNV_OFFSET_BASIS, |acc, &b| {
            (acc ^ b as u64).wrapping_mul(FNV_PRIME)
        })
}
422
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Encodes a JSON value into the msgpack document format under test.
    fn doc(v: &serde_json::Value) -> Vec<u8> {
        nodedb_types::json_msgpack::json_to_msgpack(v).expect("encode")
    }

    /// Borrows a slice of owned documents as the `&[&[u8]]` the API takes.
    fn refs(owned: &[Vec<u8>]) -> Vec<&[u8]> {
        owned.iter().map(|d| d.as_slice()).collect()
    }

    #[test]
    fn count() {
        let owned: Vec<Vec<u8>> = (1..=3).map(|i| doc(&json!({ "x": i }))).collect();
        assert_eq!(
            compute_aggregate_binary("count", "x", None, &refs(&owned)),
            Value::Integer(3)
        );
    }

    #[test]
    fn sum() {
        let owned: Vec<Vec<u8>> = [10, 20, 30]
            .into_iter()
            .map(|v| doc(&json!({ "v": v })))
            .collect();
        assert_eq!(
            compute_aggregate_binary("sum", "v", None, &refs(&owned)),
            Value::Float(60.0)
        );
    }

    #[test]
    fn avg() {
        let owned = vec![doc(&json!({"v": 10})), doc(&json!({"v": 20}))];
        assert_eq!(
            compute_aggregate_binary("avg", "v", None, &refs(&owned)),
            Value::Float(15.0)
        );
    }

    #[test]
    fn avg_empty() {
        // No document carries the aggregated field -> NULL.
        let owned = vec![doc(&json!({"other": 1}))];
        assert_eq!(
            compute_aggregate_binary("avg", "v", None, &refs(&owned)),
            Value::Null
        );
    }

    #[test]
    fn min_max() {
        let owned = vec![
            doc(&json!({"v": 5})),
            doc(&json!({"v": 1})),
            doc(&json!({"v": 9})),
        ];
        let docs = refs(&owned);
        assert_eq!(
            compute_aggregate_binary("min", "v", None, &docs),
            Value::Integer(1)
        );
        assert_eq!(
            compute_aggregate_binary("max", "v", None, &docs),
            Value::Integer(9)
        );
    }

    #[test]
    fn count_distinct() {
        let owned = vec![
            doc(&json!({"v": "a"})),
            doc(&json!({"v": "b"})),
            doc(&json!({"v": "a"})),
        ];
        assert_eq!(
            compute_aggregate_binary("count_distinct", "v", None, &refs(&owned)),
            Value::Integer(2)
        );
    }

    #[test]
    fn string_agg() {
        let owned = vec![doc(&json!({"n": "alice"})), doc(&json!({"n": "bob"}))];
        assert_eq!(
            compute_aggregate_binary("string_agg", "n", None, &refs(&owned)),
            Value::String("alice,bob".into())
        );
    }

    #[test]
    fn array_agg() {
        let owned = vec![doc(&json!({"v": 1})), doc(&json!({"v": 2}))];
        assert_eq!(
            compute_aggregate_binary("array_agg", "v", None, &refs(&owned)),
            Value::Array(vec![Value::Integer(1), Value::Integer(2)])
        );
    }

    #[test]
    fn stddev_pop() {
        // Classic dataset whose population stddev is exactly 2.
        let owned: Vec<Vec<u8>> = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
            .into_iter()
            .map(|v| doc(&json!({ "v": v })))
            .collect();
        match compute_aggregate_binary("stddev_pop", "v", None, &refs(&owned)) {
            Value::Float(v) => assert!((v - 2.0).abs() < 0.01),
            other => panic!("expected Float, got {other:?}"),
        }
    }

    #[test]
    fn percentile_cont_median() {
        let owned: Vec<Vec<u8>> = [1.0, 2.0, 3.0]
            .into_iter()
            .map(|v| doc(&json!({ "v": v })))
            .collect();
        assert_eq!(
            compute_aggregate_binary("percentile_cont", "v", None, &refs(&owned)),
            Value::Float(2.0)
        );
    }

    #[test]
    fn missing_field_skipped() {
        let owned = vec![
            doc(&json!({"v": 10})),
            doc(&json!({"other": 99})),
            doc(&json!({"v": 30})),
        ];
        assert_eq!(
            compute_aggregate_binary("sum", "v", None, &refs(&owned)),
            Value::Float(40.0)
        );
    }

    #[test]
    fn null_field_skipped_in_count_distinct() {
        let owned = vec![
            doc(&json!({"v": "a"})),
            doc(&json!({"v": null})),
            doc(&json!({"v": "a"})),
        ];
        assert_eq!(
            compute_aggregate_binary("count_distinct", "v", None, &refs(&owned)),
            Value::Integer(1)
        );
    }

    #[test]
    fn array_agg_distinct() {
        let owned = vec![
            doc(&json!({"v": 1})),
            doc(&json!({"v": 2})),
            doc(&json!({"v": 1})),
        ];
        assert_eq!(
            compute_aggregate_binary("array_agg_distinct", "v", None, &refs(&owned)),
            Value::Array(vec![Value::Integer(1), Value::Integer(2)])
        );
    }

    #[test]
    fn sum_case_when_expression() {
        use crate::expr::{BinaryOp, SqlExpr};
        let owned = vec![
            doc(&json!({"category": "tools"})),
            doc(&json!({"category": "books"})),
            doc(&json!({"category": "tools"})),
        ];
        // CASE WHEN category = 'tools' THEN 1 ELSE 0 END
        let is_tools = SqlExpr::BinaryOp {
            left: Box::new(SqlExpr::Column("category".into())),
            op: BinaryOp::Eq,
            right: Box::new(SqlExpr::Literal(Value::String("tools".into()))),
        };
        let expr = SqlExpr::Case {
            operand: None,
            when_thens: vec![(is_tools, SqlExpr::Literal(Value::Integer(1)))],
            else_expr: Some(Box::new(SqlExpr::Literal(Value::Integer(0)))),
        };
        assert_eq!(
            compute_aggregate_binary("sum", "*", Some(&expr), &refs(&owned)),
            Value::Float(2.0)
        );
    }

    #[test]
    fn approx_count_distinct_basic() {
        let owned = vec![
            doc(&json!({"region": "us"})),
            doc(&json!({"region": "eu"})),
            doc(&json!({"region": "us"})),
            doc(&json!({"region": "ap"})),
        ];
        match compute_aggregate_binary("approx_count_distinct", "region", None, &refs(&owned)) {
            Value::Integer(n) => {
                assert!((2..=4).contains(&n), "expected ~3 distinct, got {n}");
            }
            other => panic!("expected Integer, got {other:?}"),
        }
    }

    #[test]
    fn approx_percentile_basic() {
        let owned: Vec<Vec<u8>> = (1..=100).map(|i| doc(&json!({ "val": i }))).collect();
        match compute_aggregate_binary("approx_percentile", "0.5:val", None, &refs(&owned)) {
            Value::Float(f) => {
                assert!((f - 50.0).abs() < 10.0, "p50 of 1..100 should be ~50, got {f}");
            }
            other => panic!("expected Float, got {other:?}"),
        }
    }

    #[test]
    fn approx_topk_basic() {
        let mut owned = Vec::new();
        owned.extend(std::iter::repeat_with(|| doc(&json!({"cat": "a"}))).take(10));
        owned.extend(std::iter::repeat_with(|| doc(&json!({"cat": "b"}))).take(5));
        owned.push(doc(&json!({"cat": "c"})));
        match compute_aggregate_binary("approx_topk", "3:cat", None, &refs(&owned)) {
            Value::Array(arr) => assert!(!arr.is_empty(), "should have top-k results"),
            other => panic!("expected Array, got {other:?}"),
        }
    }
}