1use std::cmp::Ordering;
9use std::collections::HashSet;
10
11use nodedb_types::Value;
12
13use crate::msgpack_scan::compare::compare_field_bytes;
14use crate::msgpack_scan::field::extract_field;
15use crate::msgpack_scan::reader::{read_f64, read_null, read_str};
16use crate::value_ops;
17
/// Computes one SQL-style aggregate over a batch of msgpack-encoded documents,
/// dispatching on the aggregate name in `op`.
///
/// * `op`    - lowercase aggregate name (e.g. `"count"`, `"sum"`, `"approx_topk"`).
/// * `field` - field to aggregate. `"*"` means "all rows" for `count`. For the
///   percentile/top-k aggregates the field may carry a parameter prefix
///   separated by `:` (e.g. `"0.9:price"`, `"5:category"`).
/// * `expr`  - optional SQL expression; when present it is evaluated per
///   document instead of extracting `field` from the raw bytes.
/// * `docs`  - raw msgpack document bodies.
///
/// Unknown `op` values return `Value::Null`. Aggregates with no applicable
/// input return `Null` (avg, percentiles) or their identity value
/// (count -> 0, sum -> 0.0).
pub fn compute_aggregate_binary(
    op: &str,
    field: &str,
    expr: Option<&crate::expr::SqlExpr>,
    docs: &[&[u8]],
) -> Value {
    match op {
        "count" => {
            // COUNT(*) with no expression is just the document count;
            // otherwise count documents whose extracted value is non-null.
            if field == "*" && expr.is_none() {
                Value::Integer(docs.len() as i64)
            } else {
                let count = docs
                    .iter()
                    .filter_map(|d| extract_as_value(d, field, expr))
                    .filter(|v| !v.is_null())
                    .count();
                Value::Integer(count as i64)
            }
        }

        "sum" => {
            // Missing / non-numeric fields are skipped by extract_f64_val.
            // NOTE(review): an empty input sums to Float(0.0); SQL SUM would
            // return NULL here — confirm 0.0 is the intended contract.
            let total: f64 = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, field, expr))
                .sum();
            Value::Float(total)
        }

        "avg" => {
            // Single pass accumulating (sum, count); Null when no numeric
            // values were found (avoids 0/0).
            let (sum, count) = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, field, expr))
                .fold((0.0f64, 0u64), |(s, c), v| (s + v, c + 1));
            if count == 0 {
                Value::Null
            } else {
                Value::Float(sum / count as f64)
            }
        }

        "min" => find_minmax(docs, field, expr, false),
        "max" => find_minmax(docs, field, expr, true),

        "count_distinct" => {
            // Distinctness is defined on the raw msgpack byte encoding of
            // each value; explicit nulls (0xc0) are excluded.
            let mut seen = HashSet::new();
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    seen.insert(bytes);
                }
            }
            Value::Integer(seen.len() as i64)
        }

        // Population standard deviation (divides by n).
        "stddev" | "stddev_pop" => {
            stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), true)
        }

        // Sample standard deviation (divides by n - 1).
        "stddev_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), false),

        // Population variance.
        "variance" | "var_pop" => stat_aggregate(docs, field, expr, |variance, _n| variance, true),

        // Sample variance.
        "var_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance, false),

        "array_agg" => {
            // Collects all non-null values in document order.
            let values: Vec<Value> = docs
                .iter()
                .filter_map(|d| extract_as_value(d, field, expr))
                .filter(|v| !v.is_null())
                .collect();
            Value::Array(values)
        }

        "array_agg_distinct" => {
            // First-occurrence order is preserved; distinctness is tracked by
            // the serialized byte form of each value.
            let mut seen_bytes = HashSet::new();
            let mut values = Vec::new();
            for doc in docs {
                if let Some(expr) = expr {
                    let Some(val) = eval_expr_on_doc(doc, expr) else {
                        continue;
                    };
                    if val.is_null() {
                        continue;
                    }
                    // NOTE(review): this branch serializes via zerompk while
                    // extract_value_bytes uses json_msgpack; also
                    // unwrap_or_default() collapses any serialization failure
                    // into one shared empty-bytes key — confirm both are
                    // intended.
                    let bytes = zerompk::to_msgpack_vec(&val).unwrap_or_default();
                    if seen_bytes.insert(bytes) {
                        values.push(val);
                    }
                } else if let Some(bytes) = extract_value_bytes(doc, field, None)
                    && !value_bytes_are_null(&bytes)
                    && seen_bytes.insert(bytes)
                    && let Some(v) = value_from_field(doc, field)
                {
                    values.push(v);
                }
            }
            Value::Array(values)
        }

        "string_agg" | "group_concat" => {
            // Joins the string form of each present value with a fixed ","
            // separator (no custom-separator support here).
            let values: Vec<String> = docs
                .iter()
                .filter_map(|d| extract_str_val(d, field, expr))
                .collect();
            Value::String(values.join(","))
        }

        "approx_count_distinct" => {
            // HyperLogLog over an FNV-1a hash of each value's byte encoding.
            let mut hll = nodedb_types::approx::HyperLogLog::new();
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    let hash = hash_bytes(&bytes);
                    hll.add(hash);
                }
            }
            Value::Integer(hll.estimate().round() as i64)
        }

        "approx_percentile" => {
            // Field syntax "pct:name" selects the percentile; a bare field
            // name defaults to the median (0.5). A malformed prefix is an
            // immediate Null.
            let (pct, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<f64>() {
                    Ok(p) => (p, &field[idx + 1..]),
                    Err(_) => return Value::Null, }
            } else {
                (0.5, field)
            };
            // t-digest sketch: approximate quantile in bounded memory.
            let mut digest = nodedb_types::approx::TDigest::new();
            for doc in docs {
                if let Some(v) = extract_f64_val(doc, actual_field, expr) {
                    digest.add(v);
                }
            }
            let result = digest.quantile(pct);
            // NaN signals an empty digest.
            if result.is_nan() {
                Value::Null
            } else {
                Value::Float(result)
            }
        }

        "approx_topk" => {
            // Field syntax "k:name" selects k; a bare field name defaults to
            // the top 10.
            let (k, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<usize>() {
                    Ok(k) => (k, &field[idx + 1..]),
                    Err(_) => return Value::Null, }
            } else {
                (10, field)
            };
            let mut ss = nodedb_types::approx::SpaceSaving::new(k);
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, actual_field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    ss.add(hash_bytes(&bytes));
                }
            }
            let top = ss.top_k();
            // Each entry is reported as {item, count, error}. Note `item` is
            // the 64-bit hash of the value, not the original value itself.
            let arr: Vec<Value> = top
                .into_iter()
                .map(|(item, count, error)| {
                    Value::Object(
                        [
                            ("item".to_string(), Value::Integer(item as i64)),
                            ("count".to_string(), Value::Integer(count as i64)),
                            ("error".to_string(), Value::Integer(error as i64)),
                        ]
                        .into_iter()
                        .collect(),
                    )
                })
                .collect();
            Value::Array(arr)
        }

        "percentile_cont" => {
            // Exact continuous percentile: sort all values and linearly
            // interpolate. Same "pct:name" field syntax as approx_percentile.
            let (pct, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<f64>() {
                    Ok(p) => (p, &field[idx + 1..]),
                    Err(_) => return Value::Null, }
            } else {
                (0.5, field)
            };
            let mut values: Vec<f64> = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, actual_field, expr))
                .collect();
            if values.is_empty() {
                return Value::Null;
            }
            // NaNs compare Equal, so they cannot poison the sort order.
            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
            // Clamp keeps out-of-range pct values inside the valid index span.
            let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64);
            let lower = idx.floor() as usize;
            let upper = idx.ceil() as usize;
            let frac = idx - lower as f64;
            // Linear interpolation between the two neighboring order statistics.
            let result = values[lower] * (1.0 - frac) + values[upper] * frac;
            Value::Float(result)
        }

        // Unrecognized aggregate names degrade to Null rather than erroring.
        _ => Value::Null,
    }
}
236
237#[inline]
242fn eval_expr_on_doc(doc: &[u8], expr: &crate::expr::SqlExpr) -> Option<Value> {
243 let doc_val = nodedb_types::json_msgpack::value_from_msgpack(doc).ok()?;
244 Some(expr.eval(&doc_val))
245}
246
247#[inline]
249fn extract_f64_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<f64> {
250 if let Some(expr) = expr {
251 return value_ops::value_to_f64(&eval_expr_on_doc(doc, expr)?, false);
252 }
253 let (start, _end) = extract_field(doc, 0, field)?;
254 read_f64(doc, start)
255}
256
257fn extract_str_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<String> {
259 if let Some(expr) = expr {
260 return Some(value_ops::value_to_display_string(&eval_expr_on_doc(
261 doc, expr,
262 )?));
263 }
264 let (start, _end) = extract_field(doc, 0, field)?;
265 read_str(doc, start).map(|s| s.to_string())
266}
267
268fn extract_as_value(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<Value> {
271 if let Some(expr) = expr {
272 return eval_expr_on_doc(doc, expr);
273 }
274 value_from_field(doc, field)
275}
276
277#[inline]
278fn value_from_field(doc: &[u8], field: &str) -> Option<Value> {
279 let (start, end) = extract_field(doc, 0, field)?;
280 if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
282 return Some(v);
283 }
284 let field_bytes = &doc[start..end];
286 nodedb_types::json_msgpack::value_from_msgpack(field_bytes).ok()
287}
288
/// Finds the minimum (`want_max == false`) or maximum (`want_max == true`)
/// value of `field` / `expr` across `docs`, skipping nulls and missing
/// fields. Returns `Value::Null` when no candidate value exists.
///
/// Two strategies are used:
/// * With an expression, each document is decoded, evaluated, and compared
///   as a `Value` via `value_ops::compare_values`.
/// * Without one, candidates are compared directly on their raw msgpack
///   byte ranges via `compare_field_bytes`, and only the winning value is
///   decoded at the end — avoiding a full decode per document.
fn find_minmax(
    docs: &[&[u8]],
    field: &str,
    expr: Option<&crate::expr::SqlExpr>,
    want_max: bool,
) -> Value {
    if let Some(expr) = expr {
        // Expression path: track the best decoded Value seen so far.
        let mut best: Option<Value> = None;
        for doc in docs {
            let Some(value) = eval_expr_on_doc(doc, expr) else {
                continue;
            };
            // Nulls never participate in min/max.
            if value.is_null() {
                continue;
            }
            let replace = match &best {
                None => true,
                Some(current) => {
                    // compare_values(value, current): Greater means the new
                    // value beats the incumbent for max, Less for min.
                    let ord = value_ops::compare_values(&value, current);
                    if want_max {
                        ord == Ordering::Greater
                    } else {
                        ord == Ordering::Less
                    }
                }
            };
            if replace {
                best = Some(value);
            }
        }
        return best.unwrap_or(Value::Null);
    }

    // Byte-comparison path: remember which document holds the current best
    // candidate and where its field's bytes live within that document.
    let mut best_doc: Option<&[u8]> = None;
    let mut best_range: Option<(usize, usize)> = None;

    for doc in docs {
        if let Some(range) = extract_field(doc, 0, field) {
            // Skip explicit msgpack nil values.
            if read_null(doc, range.0) {
                continue;
            }
            match best_range {
                None => {
                    best_doc = Some(doc);
                    best_range = Some(range);
                }
                Some(br) => {
                    // best_doc is always Some whenever best_range is; the
                    // let-else is defensive only.
                    let Some(bd) = best_doc else { continue };
                    // Argument order matters: (candidate, incumbent).
                    let cmp = compare_field_bytes(doc, range, bd, br);
                    let replace = if want_max {
                        cmp == Ordering::Greater
                    } else {
                        cmp == Ordering::Less
                    };
                    if replace {
                        best_doc = Some(doc);
                        best_range = Some(range);
                    }
                }
            }
        }
    }

    // Decode only the winner: scalar fast path first, then the full decoder.
    match (best_doc, best_range) {
        (Some(doc), Some((start, end))) => {
            if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
                return v;
            }
            let bytes = &doc[start..end];
            nodedb_types::json_msgpack::value_from_msgpack(bytes).unwrap_or(Value::Null)
        }
        _ => Value::Null,
    }
}
366
367fn stat_aggregate(
370 docs: &[&[u8]],
371 field: &str,
372 expr: Option<&crate::expr::SqlExpr>,
373 finalize: fn(f64, usize) -> f64,
374 population: bool,
375) -> Value {
376 let values: Vec<f64> = docs
377 .iter()
378 .filter_map(|d| extract_f64_val(d, field, expr))
379 .collect();
380 if values.len() < 2 {
381 return Value::Null;
382 }
383 let mean = values.iter().sum::<f64>() / values.len() as f64;
384 let divisor = if population {
385 values.len() as f64
386 } else {
387 (values.len() - 1) as f64
388 };
389 let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / divisor;
390 Value::Float(finalize(variance, values.len()))
391}
392
393fn extract_value_bytes(
394 doc: &[u8],
395 field: &str,
396 expr: Option<&crate::expr::SqlExpr>,
397) -> Option<Vec<u8>> {
398 if let Some(expr) = expr {
399 let val = eval_expr_on_doc(doc, expr)?;
400 return nodedb_types::json_msgpack::value_to_msgpack(&val).ok();
401 }
402 let (start, end) = extract_field(doc, 0, field)?;
403 Some(doc[start..end].to_vec())
404}
405
/// True iff `bytes` is exactly the single-byte msgpack nil marker (0xc0).
fn value_bytes_are_null(bytes: &[u8]) -> bool {
    matches!(bytes, [0xc0])
}
410
/// 64-bit FNV-1a hash of `bytes`, used to feed the HyperLogLog and
/// SpaceSaving sketches with a compact, deterministic value identity.
fn hash_bytes(bytes: &[u8]) -> u64 {
    // Standard FNV-1a 64-bit offset basis and prime.
    const FNV_OFFSET: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    bytes
        .iter()
        .fold(FNV_OFFSET, |acc, &byte| (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME))
}
420
// Unit tests: each encodes small JSON documents to msgpack and runs one
// aggregate end-to-end through compute_aggregate_binary.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Helper: JSON -> msgpack document bytes.
    fn encode(v: &serde_json::Value) -> Vec<u8> {
        nodedb_types::json_msgpack::json_to_msgpack(v).expect("encode")
    }

    #[test]
    fn count() {
        let d1 = encode(&json!({"x": 1}));
        let d2 = encode(&json!({"x": 2}));
        let d3 = encode(&json!({"x": 3}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("count", "x", None, &docs),
            Value::Integer(3)
        );
    }

    #[test]
    fn sum() {
        let d1 = encode(&json!({"v": 10}));
        let d2 = encode(&json!({"v": 20}));
        let d3 = encode(&json!({"v": 30}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("sum", "v", None, &docs),
            Value::Float(60.0)
        );
    }

    #[test]
    fn avg() {
        let d1 = encode(&json!({"v": 10}));
        let d2 = encode(&json!({"v": 20}));
        let docs: Vec<&[u8]> = vec![&d1, &d2];
        assert_eq!(
            compute_aggregate_binary("avg", "v", None, &docs),
            Value::Float(15.0)
        );
    }

    // avg over zero matching values must be Null, not a division by zero.
    #[test]
    fn avg_empty() {
        let d1 = encode(&json!({"other": 1}));
        let docs: Vec<&[u8]> = vec![&d1];
        assert_eq!(
            compute_aggregate_binary("avg", "v", None, &docs),
            Value::Null
        );
    }

    #[test]
    fn min_max() {
        let d1 = encode(&json!({"v": 5}));
        let d2 = encode(&json!({"v": 1}));
        let d3 = encode(&json!({"v": 9}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];

        let min = compute_aggregate_binary("min", "v", None, &docs);
        let max = compute_aggregate_binary("max", "v", None, &docs);
        assert_eq!(min, Value::Integer(1));
        assert_eq!(max, Value::Integer(9));
    }

    #[test]
    fn count_distinct() {
        let d1 = encode(&json!({"v": "a"}));
        let d2 = encode(&json!({"v": "b"}));
        let d3 = encode(&json!({"v": "a"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("count_distinct", "v", None, &docs),
            Value::Integer(2)
        );
    }

    #[test]
    fn string_agg() {
        let d1 = encode(&json!({"n": "alice"}));
        let d2 = encode(&json!({"n": "bob"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2];
        assert_eq!(
            compute_aggregate_binary("string_agg", "n", None, &docs),
            Value::String("alice,bob".into())
        );
    }

    #[test]
    fn array_agg() {
        let d1 = encode(&json!({"v": 1}));
        let d2 = encode(&json!({"v": 2}));
        let docs: Vec<&[u8]> = vec![&d1, &d2];
        let result = compute_aggregate_binary("array_agg", "v", None, &docs);
        assert_eq!(
            result,
            Value::Array(vec![Value::Integer(1), Value::Integer(2),])
        );
    }

    // Classic population stddev example: the data set {2,4,4,4,5,5,7,9}
    // has mean 5 and population standard deviation exactly 2.
    #[test]
    fn stddev_pop() {
        let d1 = encode(&json!({"v": 2.0}));
        let d2 = encode(&json!({"v": 4.0}));
        let d3 = encode(&json!({"v": 4.0}));
        let d4 = encode(&json!({"v": 4.0}));
        let d5 = encode(&json!({"v": 5.0}));
        let d6 = encode(&json!({"v": 5.0}));
        let d7 = encode(&json!({"v": 7.0}));
        let d8 = encode(&json!({"v": 9.0}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3, &d4, &d5, &d6, &d7, &d8];
        let result = compute_aggregate_binary("stddev_pop", "v", None, &docs);
        if let Value::Float(v) = result {
            assert!((v - 2.0).abs() < 0.01);
        } else {
            panic!("expected Float");
        }
    }

    // Bare field name defaults percentile_cont to the median (0.5).
    #[test]
    fn percentile_cont_median() {
        let d1 = encode(&json!({"v": 1.0}));
        let d2 = encode(&json!({"v": 2.0}));
        let d3 = encode(&json!({"v": 3.0}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("percentile_cont", "v", None, &docs),
            Value::Float(2.0)
        );
    }

    // Documents lacking the field contribute nothing to the sum.
    #[test]
    fn missing_field_skipped() {
        let d1 = encode(&json!({"v": 10}));
        let d2 = encode(&json!({"other": 99}));
        let d3 = encode(&json!({"v": 30}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("sum", "v", None, &docs),
            Value::Float(40.0)
        );
    }

    // Explicit nulls are excluded from distinct counting.
    #[test]
    fn null_field_skipped_in_count_distinct() {
        let d1 = encode(&json!({"v": "a"}));
        let d2 = encode(&json!({"v": null}));
        let d3 = encode(&json!({"v": "a"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("count_distinct", "v", None, &docs),
            Value::Integer(1)
        );
    }

    // Distinct array_agg keeps first-occurrence order.
    #[test]
    fn array_agg_distinct() {
        let d1 = encode(&json!({"v": 1}));
        let d2 = encode(&json!({"v": 2}));
        let d3 = encode(&json!({"v": 1}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        let result = compute_aggregate_binary("array_agg_distinct", "v", None, &docs);
        assert_eq!(
            result,
            Value::Array(vec![Value::Integer(1), Value::Integer(2),])
        );
    }

    // SUM(CASE WHEN category = 'tools' THEN 1 ELSE 0 END) — exercises the
    // expression path rather than raw field extraction.
    #[test]
    fn sum_case_when_expression() {
        let d1 = encode(&json!({"category": "tools"}));
        let d2 = encode(&json!({"category": "books"}));
        let d3 = encode(&json!({"category": "tools"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        let expr = crate::expr::SqlExpr::Case {
            operand: None,
            when_thens: vec![(
                crate::expr::SqlExpr::BinaryOp {
                    left: Box::new(crate::expr::SqlExpr::Column("category".into())),
                    op: crate::expr::BinaryOp::Eq,
                    right: Box::new(crate::expr::SqlExpr::Literal(Value::String("tools".into()))),
                },
                crate::expr::SqlExpr::Literal(Value::Integer(1)),
            )],
            else_expr: Some(Box::new(crate::expr::SqlExpr::Literal(Value::Integer(0)))),
        };

        assert_eq!(
            compute_aggregate_binary("sum", "*", Some(&expr), &docs),
            Value::Float(2.0)
        );
    }

    // HLL is approximate: accept a band around the true value of 3.
    #[test]
    fn approx_count_distinct_basic() {
        let docs: Vec<Vec<u8>> = vec![
            encode(&json!({"region": "us"})),
            encode(&json!({"region": "eu"})),
            encode(&json!({"region": "us"})),
            encode(&json!({"region": "ap"})),
        ];
        let refs: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
        let result = compute_aggregate_binary("approx_count_distinct", "region", None, &refs);
        if let Value::Integer(n) = result {
            assert!((2..=4).contains(&n), "expected ~3 distinct, got {n}");
        } else {
            panic!("expected Integer, got {result:?}");
        }
    }

    // "0.5:val" prefix syntax selects the percentile.
    #[test]
    fn approx_percentile_basic() {
        let docs: Vec<Vec<u8>> = (1..=100).map(|i| encode(&json!({"val": i}))).collect();
        let refs: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
        let result = compute_aggregate_binary("approx_percentile", "0.5:val", None, &refs);
        if let Value::Float(f) = result {
            assert!(
                (f - 50.0).abs() < 10.0,
                "p50 of 1..100 should be ~50, got {f}"
            );
        } else {
            panic!("expected Float, got {result:?}");
        }
    }

    // "3:cat" prefix syntax selects k for the top-k sketch.
    #[test]
    fn approx_topk_basic() {
        let mut docs: Vec<Vec<u8>> = Vec::new();
        for _ in 0..10 {
            docs.push(encode(&json!({"cat": "a"})));
        }
        for _ in 0..5 {
            docs.push(encode(&json!({"cat": "b"})));
        }
        for _ in 0..1 {
            docs.push(encode(&json!({"cat": "c"})));
        }
        let refs: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
        let result = compute_aggregate_binary("approx_topk", "3:cat", None, &refs);
        if let Value::Array(arr) = result {
            assert!(!arr.is_empty(), "should have top-k results");
        } else {
            panic!("expected Array, got {result:?}");
        }
    }
}
669}