// nodedb_query/msgpack_scan/aggregate.rs

1//! Zero-deserialization aggregate computation on raw MessagePack documents.
2//!
3//! Replaces `compute_aggregate(op, field, docs: &[serde_json::Value])` with
4//! direct binary field extraction. Each document is `&[u8]` (MessagePack map).
5//! When an expression is provided, decodes msgpack → `nodedb_types::Value`
6//! directly (no JSON intermediate) and evaluates the expression once per document.
7
8use std::cmp::Ordering;
9use std::collections::HashSet;
10
11use nodedb_types::Value;
12
13use crate::msgpack_scan::compare::compare_field_bytes;
14use crate::msgpack_scan::field::extract_field;
15use crate::msgpack_scan::reader::{read_f64, read_null, read_str};
16use crate::value_ops;
17
/// Compute an aggregate function over raw MessagePack documents.
///
/// Each entry in `docs` is a complete MessagePack map (the raw bytes from storage).
/// Returns the result as `Value` — conversion to JSON happens at the
/// response boundary only.
///
/// `op` selects the aggregate; unrecognized ops return `Value::Null`.
/// `field` names the field to aggregate. The percentile and top-k ops
/// additionally encode a parameter into `field` using a `"param:field"`
/// format (documented on the arms below). When `expr` is `Some`, the
/// expression is evaluated once per document and its result is aggregated
/// instead of the raw field.
pub fn compute_aggregate_binary(
    op: &str,
    field: &str,
    expr: Option<&crate::expr::SqlExpr>,
    docs: &[&[u8]],
) -> Value {
    match op {
        "count" => {
            if field == "*" && expr.is_none() {
                // COUNT(*): every document counts; no field extraction needed.
                Value::Integer(docs.len() as i64)
            } else {
                // COUNT(field/expr): only non-null, present values count.
                let count = docs
                    .iter()
                    .filter_map(|d| extract_as_value(d, field, expr))
                    .filter(|v| !v.is_null())
                    .count();
                Value::Integer(count as i64)
            }
        }

        // NOTE(review): SUM over zero matching values yields Float(0.0)
        // rather than Null — confirm this matches the engine's intended
        // SQL semantics (standard SQL SUM of an empty set is NULL).
        "sum" => {
            let total: f64 = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, field, expr))
                .sum();
            Value::Float(total)
        }

        "avg" => {
            // Single pass accumulating both the sum and the value count.
            let (sum, count) = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, field, expr))
                .fold((0.0f64, 0u64), |(s, c), v| (s + v, c + 1));
            if count == 0 {
                Value::Null
            } else {
                Value::Float(sum / count as f64)
            }
        }

        "min" => find_minmax(docs, field, expr, false),
        "max" => find_minmax(docs, field, expr, true),

        "count_distinct" => {
            // Distinctness is decided on the encoded msgpack bytes, so the
            // values never need to be decoded.
            let mut seen = HashSet::new();
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    seen.insert(bytes);
                }
            }
            Value::Integer(seen.len() as i64)
        }

        // Bare "stddev" is treated as the population variant here.
        "stddev" | "stddev_pop" => {
            stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), true)
        }

        "stddev_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), false),

        "variance" | "var_pop" => stat_aggregate(docs, field, expr, |variance, _n| variance, true),

        "var_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance, false),

        "array_agg" => {
            // Collects all non-null values in document order.
            let values: Vec<Value> = docs
                .iter()
                .filter_map(|d| extract_as_value(d, field, expr))
                .filter(|v| !v.is_null())
                .collect();
            Value::Array(values)
        }

        "array_agg_distinct" => {
            // Dedupe on encoded bytes, collect decoded values in first-seen order.
            let mut seen_bytes = HashSet::new();
            let mut values = Vec::new();
            for doc in docs {
                // When expr is present, evaluate once and derive both bytes and value
                // from the result to avoid double-decoding the document.
                if let Some(expr) = expr {
                    let Some(val) = eval_expr_on_doc(doc, expr) else {
                        continue;
                    };
                    if val.is_null() {
                        continue;
                    }
                    // NOTE(review): this path encodes via `zerompk` while
                    // `extract_value_bytes` uses `json_msgpack::value_to_msgpack`
                    // — confirm both produce identical canonical bytes.
                    let bytes = zerompk::to_msgpack_vec(&val).unwrap_or_default();
                    if seen_bytes.insert(bytes) {
                        values.push(val);
                    }
                } else if let Some(bytes) = extract_value_bytes(doc, field, None)
                    && !value_bytes_are_null(&bytes)
                    && seen_bytes.insert(bytes)
                    && let Some(v) = value_from_field(doc, field)
                {
                    values.push(v);
                }
            }
            Value::Array(values)
        }

        // Joins display strings with a fixed "," separator.
        "string_agg" | "group_concat" => {
            let values: Vec<String> = docs
                .iter()
                .filter_map(|d| extract_str_val(d, field, expr))
                .collect();
            Value::String(values.join(","))
        }

        "approx_count_distinct" => {
            let mut hll = nodedb_types::approx::HyperLogLog::new();
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    // Hash the raw bytes for HLL.
                    let hash = hash_bytes(&bytes);
                    hll.add(hash);
                }
            }
            Value::Integer(hll.estimate().round() as i64)
        }

        "approx_percentile" => {
            // Format: field is "quantile:actual_field" (e.g. "0.95:latency").
            // Without the prefix, defaults to the median (0.5).
            let (pct, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<f64>() {
                    Ok(p) => (p, &field[idx + 1..]),
                    Err(_) => return Value::Null, // invalid quantile
                }
            } else {
                (0.5, field)
            };
            let mut digest = nodedb_types::approx::TDigest::new();
            for doc in docs {
                if let Some(v) = extract_f64_val(doc, actual_field, expr) {
                    digest.add(v);
                }
            }
            let result = digest.quantile(pct);
            // NaN signals no numeric values were added to the digest.
            if result.is_nan() {
                Value::Null
            } else {
                Value::Float(result)
            }
        }

        "approx_topk" => {
            // Format: field is "k:actual_field" (e.g. "10:region").
            // Without the prefix, defaults to k = 10.
            let (k, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<usize>() {
                    Ok(k) => (k, &field[idx + 1..]),
                    Err(_) => return Value::Null, // invalid k
                }
            } else {
                (10, field)
            };
            let mut ss = nodedb_types::approx::SpaceSaving::new(k);
            for doc in docs {
                if let Some(bytes) = extract_value_bytes(doc, actual_field, expr)
                    && !value_bytes_are_null(&bytes)
                {
                    ss.add(hash_bytes(&bytes));
                }
            }
            // Return an array of {item, count, error} objects. "item" is the
            // FNV hash of the value's bytes, not the original value.
            let top = ss.top_k();
            let arr: Vec<Value> = top
                .into_iter()
                .map(|(item, count, error)| {
                    Value::Object(
                        [
                            ("item".to_string(), Value::Integer(item as i64)),
                            ("count".to_string(), Value::Integer(count as i64)),
                            ("error".to_string(), Value::Integer(error as i64)),
                        ]
                        .into_iter()
                        .collect(),
                    )
                })
                .collect();
            Value::Array(arr)
        }

        "percentile_cont" => {
            // Same "quantile:field" format as approx_percentile, but exact:
            // sorts every value and linearly interpolates between neighbors.
            let (pct, actual_field) = if let Some(idx) = field.find(':') {
                match field[..idx].parse::<f64>() {
                    Ok(p) => (p, &field[idx + 1..]),
                    Err(_) => return Value::Null, // invalid quantile
                }
            } else {
                (0.5, field)
            };
            let mut values: Vec<f64> = docs
                .iter()
                .filter_map(|d| extract_f64_val(d, actual_field, expr))
                .collect();
            if values.is_empty() {
                return Value::Null;
            }
            // NaNs compare as equal; order among them is unspecified.
            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
            // Clamp keeps out-of-range quantiles (pct < 0 or > 1) in bounds.
            let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64);
            let lower = idx.floor() as usize;
            let upper = idx.ceil() as usize;
            let frac = idx - lower as f64;
            // Linear interpolation between the two bracketing values.
            let result = values[lower] * (1.0 - frac) + values[upper] * frac;
            Value::Float(result)
        }

        // Unknown aggregate op.
        _ => Value::Null,
    }
}
236
237// ── Internal helpers ───────────────────────────────────────────────────
238
239/// Decode a msgpack document directly to `nodedb_types::Value` and evaluate
240/// the expression. No JSON intermediate — msgpack → Value → eval → Value.
241#[inline]
242fn eval_expr_on_doc(doc: &[u8], expr: &crate::expr::SqlExpr) -> Option<Value> {
243    let doc_val = nodedb_types::json_msgpack::value_from_msgpack(doc).ok()?;
244    Some(expr.eval(&doc_val))
245}
246
247/// Extract a numeric value from a field or expression result.
248#[inline]
249fn extract_f64_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<f64> {
250    if let Some(expr) = expr {
251        return value_ops::value_to_f64(&eval_expr_on_doc(doc, expr)?, false);
252    }
253    let (start, _end) = extract_field(doc, 0, field)?;
254    read_f64(doc, start)
255}
256
257/// Extract a string from a field or expression result.
258fn extract_str_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<String> {
259    if let Some(expr) = expr {
260        return Some(value_ops::value_to_display_string(&eval_expr_on_doc(
261            doc, expr,
262        )?));
263    }
264    let (start, _end) = extract_field(doc, 0, field)?;
265    read_str(doc, start).map(|s| s.to_string())
266}
267
268/// Extract a field as `Value`. Uses direct msgpack→Value for scalars;
269/// falls back to full decode only for complex types.
270fn extract_as_value(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<Value> {
271    if let Some(expr) = expr {
272        return eval_expr_on_doc(doc, expr);
273    }
274    value_from_field(doc, field)
275}
276
277#[inline]
278fn value_from_field(doc: &[u8], field: &str) -> Option<Value> {
279    let (start, end) = extract_field(doc, 0, field)?;
280    // Fast path: scalar types (null, bool, int, float, string).
281    if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
282        return Some(v);
283    }
284    // Slow path: complex types (array, map, bin) — decode field bytes directly.
285    let field_bytes = &doc[start..end];
286    nodedb_types::json_msgpack::value_from_msgpack(field_bytes).ok()
287}
288
/// Find min or max across docs by comparing raw field bytes.
///
/// Two strategies:
/// - With `expr`: evaluate per document, compare decoded `Value`s via
///   `value_ops::compare_values`.
/// - Without: compare candidate fields directly on their msgpack byte
///   ranges (`compare_field_bytes`) and decode only the winner at the end.
///
/// Null and missing fields are skipped; returns `Value::Null` when no
/// document contributes a value.
fn find_minmax(
    docs: &[&[u8]],
    field: &str,
    expr: Option<&crate::expr::SqlExpr>,
    want_max: bool,
) -> Value {
    if let Some(expr) = expr {
        // Evaluate expression once per doc; compare on Value
        // since the result may be any type (not a raw field).
        let mut best: Option<Value> = None;
        for doc in docs {
            let Some(value) = eval_expr_on_doc(doc, expr) else {
                continue;
            };
            if value.is_null() {
                continue;
            }
            // First value always wins; later ones must strictly beat it.
            let replace = match &best {
                None => true,
                Some(current) => {
                    let ord = value_ops::compare_values(&value, current);
                    if want_max {
                        ord == Ordering::Greater
                    } else {
                        ord == Ordering::Less
                    }
                }
            };
            if replace {
                best = Some(value);
            }
        }
        return best.unwrap_or(Value::Null);
    }

    // Raw-bytes track: remember the winning document and the byte range of
    // its field; decoding is deferred until after the scan.
    let mut best_doc: Option<&[u8]> = None;
    let mut best_range: Option<(usize, usize)> = None;

    for doc in docs {
        if let Some(range) = extract_field(doc, 0, field) {
            // Nulls never participate in min/max.
            if read_null(doc, range.0) {
                continue;
            }
            match best_range {
                None => {
                    best_doc = Some(doc);
                    best_range = Some(range);
                }
                Some(br) => {
                    // best_doc and best_range are always set together, so
                    // this else-continue is defensive only.
                    let Some(bd) = best_doc else { continue };
                    let cmp = compare_field_bytes(doc, range, bd, br);
                    let replace = if want_max {
                        cmp == Ordering::Greater
                    } else {
                        cmp == Ordering::Less
                    };
                    if replace {
                        best_doc = Some(doc);
                        best_range = Some(range);
                    }
                }
            }
        }
    }

    // Decode the winning field: scalar fast path, then full decode fallback.
    match (best_doc, best_range) {
        (Some(doc), Some((start, end))) => {
            if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
                return v;
            }
            let bytes = &doc[start..end];
            nodedb_types::json_msgpack::value_from_msgpack(bytes).unwrap_or(Value::Null)
        }
        _ => Value::Null,
    }
}
366
367/// Compute stddev or variance. `population` = true for population variant.
368/// `finalize` transforms the variance into the final result.
369fn stat_aggregate(
370    docs: &[&[u8]],
371    field: &str,
372    expr: Option<&crate::expr::SqlExpr>,
373    finalize: fn(f64, usize) -> f64,
374    population: bool,
375) -> Value {
376    let values: Vec<f64> = docs
377        .iter()
378        .filter_map(|d| extract_f64_val(d, field, expr))
379        .collect();
380    if values.len() < 2 {
381        return Value::Null;
382    }
383    let mean = values.iter().sum::<f64>() / values.len() as f64;
384    let divisor = if population {
385        values.len() as f64
386    } else {
387        (values.len() - 1) as f64
388    };
389    let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / divisor;
390    Value::Float(finalize(variance, values.len()))
391}
392
393fn extract_value_bytes(
394    doc: &[u8],
395    field: &str,
396    expr: Option<&crate::expr::SqlExpr>,
397) -> Option<Vec<u8>> {
398    if let Some(expr) = expr {
399        let val = eval_expr_on_doc(doc, expr)?;
400        return nodedb_types::json_msgpack::value_to_msgpack(&val).ok();
401    }
402    let (start, end) = extract_field(doc, 0, field)?;
403    Some(doc[start..end].to_vec())
404}
405
/// Check if msgpack bytes represent null. Msgpack null is the single byte 0xc0.
fn value_bytes_are_null(bytes: &[u8]) -> bool {
    matches!(bytes, [0xc0])
}
410
/// FNV-1a hash for raw bytes (used by approx aggregates to feed HLL/SpaceSaving).
///
/// Standard 64-bit FNV-1a: start from the offset basis, then for each byte
/// XOR it in and multiply by the FNV prime (wrapping).
fn hash_bytes(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const FNV_PRIME: u64 = 0x100000001b3;
    bytes
        .iter()
        .fold(FNV_OFFSET_BASIS, |acc, &b| {
            (acc ^ u64::from(b)).wrapping_mul(FNV_PRIME)
        })
}
420
// Unit tests: each aggregate op is exercised against small JSON-encoded
// msgpack fixtures built through `json_to_msgpack`.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // Encode a JSON value into a raw msgpack document, as storage would.
    fn encode(v: &serde_json::Value) -> Vec<u8> {
        nodedb_types::json_msgpack::json_to_msgpack(v).expect("encode")
    }

    #[test]
    fn count() {
        let d1 = encode(&json!({"x": 1}));
        let d2 = encode(&json!({"x": 2}));
        let d3 = encode(&json!({"x": 3}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("count", "x", None, &docs),
            Value::Integer(3)
        );
    }

    // Note: sum always yields Float, even over integer fields.
    #[test]
    fn sum() {
        let d1 = encode(&json!({"v": 10}));
        let d2 = encode(&json!({"v": 20}));
        let d3 = encode(&json!({"v": 30}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("sum", "v", None, &docs),
            Value::Float(60.0)
        );
    }

    #[test]
    fn avg() {
        let d1 = encode(&json!({"v": 10}));
        let d2 = encode(&json!({"v": 20}));
        let docs: Vec<&[u8]> = vec![&d1, &d2];
        assert_eq!(
            compute_aggregate_binary("avg", "v", None, &docs),
            Value::Float(15.0)
        );
    }

    // avg over zero matching values must be Null, not a division by zero.
    #[test]
    fn avg_empty() {
        let d1 = encode(&json!({"other": 1}));
        let docs: Vec<&[u8]> = vec![&d1];
        assert_eq!(
            compute_aggregate_binary("avg", "v", None, &docs),
            Value::Null
        );
    }

    // min/max decode the winning field to its original type (Integer here).
    #[test]
    fn min_max() {
        let d1 = encode(&json!({"v": 5}));
        let d2 = encode(&json!({"v": 1}));
        let d3 = encode(&json!({"v": 9}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];

        let min = compute_aggregate_binary("min", "v", None, &docs);
        let max = compute_aggregate_binary("max", "v", None, &docs);
        assert_eq!(min, Value::Integer(1));
        assert_eq!(max, Value::Integer(9));
    }

    #[test]
    fn count_distinct() {
        let d1 = encode(&json!({"v": "a"}));
        let d2 = encode(&json!({"v": "b"}));
        let d3 = encode(&json!({"v": "a"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("count_distinct", "v", None, &docs),
            Value::Integer(2)
        );
    }

    // string_agg joins with the hard-coded "," separator.
    #[test]
    fn string_agg() {
        let d1 = encode(&json!({"n": "alice"}));
        let d2 = encode(&json!({"n": "bob"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2];
        assert_eq!(
            compute_aggregate_binary("string_agg", "n", None, &docs),
            Value::String("alice,bob".into())
        );
    }

    #[test]
    fn array_agg() {
        let d1 = encode(&json!({"v": 1}));
        let d2 = encode(&json!({"v": 2}));
        let docs: Vec<&[u8]> = vec![&d1, &d2];
        let result = compute_aggregate_binary("array_agg", "v", None, &docs);
        assert_eq!(
            result,
            Value::Array(vec![Value::Integer(1), Value::Integer(2),])
        );
    }

    // Classic textbook dataset: population stddev of {2,4,4,4,5,5,7,9} is 2.
    #[test]
    fn stddev_pop() {
        let d1 = encode(&json!({"v": 2.0}));
        let d2 = encode(&json!({"v": 4.0}));
        let d3 = encode(&json!({"v": 4.0}));
        let d4 = encode(&json!({"v": 4.0}));
        let d5 = encode(&json!({"v": 5.0}));
        let d6 = encode(&json!({"v": 5.0}));
        let d7 = encode(&json!({"v": 7.0}));
        let d8 = encode(&json!({"v": 9.0}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3, &d4, &d5, &d6, &d7, &d8];
        let result = compute_aggregate_binary("stddev_pop", "v", None, &docs);
        if let Value::Float(v) = result {
            assert!((v - 2.0).abs() < 0.01);
        } else {
            panic!("expected Float");
        }
    }

    // Default quantile (no "q:" prefix) is the median.
    #[test]
    fn percentile_cont_median() {
        let d1 = encode(&json!({"v": 1.0}));
        let d2 = encode(&json!({"v": 2.0}));
        let d3 = encode(&json!({"v": 3.0}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("percentile_cont", "v", None, &docs),
            Value::Float(2.0)
        );
    }

    // Documents without the aggregated field simply don't contribute.
    #[test]
    fn missing_field_skipped() {
        let d1 = encode(&json!({"v": 10}));
        let d2 = encode(&json!({"other": 99}));
        let d3 = encode(&json!({"v": 30}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("sum", "v", None, &docs),
            Value::Float(40.0)
        );
    }

    // Explicit nulls are excluded from the distinct set.
    #[test]
    fn null_field_skipped_in_count_distinct() {
        let d1 = encode(&json!({"v": "a"}));
        let d2 = encode(&json!({"v": null}));
        let d3 = encode(&json!({"v": "a"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        assert_eq!(
            compute_aggregate_binary("count_distinct", "v", None, &docs),
            Value::Integer(1)
        );
    }

    // Distinct preserves first-seen order.
    #[test]
    fn array_agg_distinct() {
        let d1 = encode(&json!({"v": 1}));
        let d2 = encode(&json!({"v": 2}));
        let d3 = encode(&json!({"v": 1}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        let result = compute_aggregate_binary("array_agg_distinct", "v", None, &docs);
        assert_eq!(
            result,
            Value::Array(vec![Value::Integer(1), Value::Integer(2),])
        );
    }

    // SUM(CASE WHEN category = 'tools' THEN 1 ELSE 0 END) — the expression
    // path, where `field` is ignored ("*") and expr drives extraction.
    #[test]
    fn sum_case_when_expression() {
        let d1 = encode(&json!({"category": "tools"}));
        let d2 = encode(&json!({"category": "books"}));
        let d3 = encode(&json!({"category": "tools"}));
        let docs: Vec<&[u8]> = vec![&d1, &d2, &d3];
        let expr = crate::expr::SqlExpr::Case {
            operand: None,
            when_thens: vec![(
                crate::expr::SqlExpr::BinaryOp {
                    left: Box::new(crate::expr::SqlExpr::Column("category".into())),
                    op: crate::expr::BinaryOp::Eq,
                    right: Box::new(crate::expr::SqlExpr::Literal(Value::String("tools".into()))),
                },
                crate::expr::SqlExpr::Literal(Value::Integer(1)),
            )],
            else_expr: Some(Box::new(crate::expr::SqlExpr::Literal(Value::Integer(0)))),
        };

        assert_eq!(
            compute_aggregate_binary("sum", "*", Some(&expr), &docs),
            Value::Float(2.0)
        );
    }

    #[test]
    fn approx_count_distinct_basic() {
        let docs: Vec<Vec<u8>> = vec![
            encode(&json!({"region": "us"})),
            encode(&json!({"region": "eu"})),
            encode(&json!({"region": "us"})),
            encode(&json!({"region": "ap"})),
        ];
        let refs: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
        let result = compute_aggregate_binary("approx_count_distinct", "region", None, &refs);
        // HLL may not be exactly 3 but should be close.
        if let Value::Integer(n) = result {
            assert!((2..=4).contains(&n), "expected ~3 distinct, got {n}");
        } else {
            panic!("expected Integer, got {result:?}");
        }
    }

    // T-digest is approximate: accept a wide tolerance around the true p50.
    #[test]
    fn approx_percentile_basic() {
        let docs: Vec<Vec<u8>> = (1..=100).map(|i| encode(&json!({"val": i}))).collect();
        let refs: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
        let result = compute_aggregate_binary("approx_percentile", "0.5:val", None, &refs);
        if let Value::Float(f) = result {
            assert!(
                (f - 50.0).abs() < 10.0,
                "p50 of 1..100 should be ~50, got {f}"
            );
        } else {
            panic!("expected Float, got {result:?}");
        }
    }

    // Skewed frequencies (10/5/1) through SpaceSaving; only checks shape,
    // since items come back as hashes.
    #[test]
    fn approx_topk_basic() {
        let mut docs: Vec<Vec<u8>> = Vec::new();
        for _ in 0..10 {
            docs.push(encode(&json!({"cat": "a"})));
        }
        for _ in 0..5 {
            docs.push(encode(&json!({"cat": "b"})));
        }
        for _ in 0..1 {
            docs.push(encode(&json!({"cat": "c"})));
        }
        let refs: Vec<&[u8]> = docs.iter().map(|d| d.as_slice()).collect();
        let result = compute_aggregate_binary("approx_topk", "3:cat", None, &refs);
        if let Value::Array(arr) = result {
            assert!(!arr.is_empty(), "should have top-k results");
        } else {
            panic!("expected Array, got {result:?}");
        }
    }
}