Skip to main content

nodedb_query/msgpack_scan/
aggregate.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Zero-deserialization aggregate computation on raw MessagePack documents.
4//!
5//! Replaces `compute_aggregate(op, field, docs: &[serde_json::Value])` with
6//! direct binary field extraction. Each document is `&[u8]` (MessagePack map).
7//! When an expression is provided, decodes msgpack → `nodedb_types::Value`
8//! directly (no JSON intermediate) and evaluates the expression once per document.
9
10use std::cmp::Ordering;
11use std::collections::HashSet;
12
13use nodedb_types::Value;
14
15use crate::msgpack_scan::compare::compare_field_bytes;
16use crate::msgpack_scan::field::extract_field;
17use crate::msgpack_scan::reader::{read_f64, read_null, read_str};
18use crate::value_ops;
19
20/// Compute an aggregate function over raw MessagePack documents.
21///
22/// Each entry in `docs` is a complete MessagePack map (the raw bytes from storage).
23/// Returns the result as `Value` — conversion to JSON happens at the
24/// response boundary only.
25pub fn compute_aggregate_binary(
26    op: &str,
27    field: &str,
28    expr: Option<&crate::expr::SqlExpr>,
29    docs: &[&[u8]],
30) -> Value {
31    match op {
32        "count" => {
33            if field == "*" && expr.is_none() {
34                Value::Integer(docs.len() as i64)
35            } else {
36                let count = docs
37                    .iter()
38                    .filter_map(|d| extract_as_value(d, field, expr))
39                    .filter(|v| !v.is_null())
40                    .count();
41                Value::Integer(count as i64)
42            }
43        }
44
45        "sum" => {
46            let total: f64 = docs
47                .iter()
48                .filter_map(|d| extract_f64_val(d, field, expr))
49                .sum();
50            Value::Float(total)
51        }
52
53        "avg" => {
54            let (sum, count) = docs
55                .iter()
56                .filter_map(|d| extract_f64_val(d, field, expr))
57                .fold((0.0f64, 0u64), |(s, c), v| (s + v, c + 1));
58            if count == 0 {
59                Value::Null
60            } else {
61                Value::Float(sum / count as f64)
62            }
63        }
64
65        "min" => find_minmax(docs, field, expr, false),
66        "max" => find_minmax(docs, field, expr, true),
67
68        "count_distinct" => {
69            let mut seen = HashSet::new();
70            for doc in docs {
71                if let Some(bytes) = extract_value_bytes(doc, field, expr)
72                    && !value_bytes_are_null(&bytes)
73                {
74                    seen.insert(bytes);
75                }
76            }
77            Value::Integer(seen.len() as i64)
78        }
79
80        "stddev" | "stddev_pop" => {
81            stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), true)
82        }
83
84        "stddev_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance.sqrt(), false),
85
86        "variance" | "var_pop" => stat_aggregate(docs, field, expr, |variance, _n| variance, true),
87
88        "var_samp" => stat_aggregate(docs, field, expr, |variance, _n| variance, false),
89
90        "array_agg" => {
91            let values: Vec<Value> = docs
92                .iter()
93                .filter_map(|d| extract_as_value(d, field, expr))
94                .filter(|v| !v.is_null())
95                .collect();
96            Value::Array(values)
97        }
98
99        "array_agg_distinct" => {
100            let mut seen_bytes = HashSet::new();
101            let mut values = Vec::new();
102            for doc in docs {
103                // When expr is present, evaluate once and derive both bytes and value
104                // from the result to avoid double-decoding the document.
105                if let Some(expr) = expr {
106                    let Some(val) = eval_expr_on_doc(doc, expr) else {
107                        continue;
108                    };
109                    if val.is_null() {
110                        continue;
111                    }
112                    let bytes = zerompk::to_msgpack_vec(&val).unwrap_or_default();
113                    if seen_bytes.insert(bytes) {
114                        values.push(val);
115                    }
116                } else if let Some(bytes) = extract_value_bytes(doc, field, None)
117                    && !value_bytes_are_null(&bytes)
118                    && seen_bytes.insert(bytes)
119                    && let Some(v) = value_from_field(doc, field)
120                {
121                    values.push(v);
122                }
123            }
124            Value::Array(values)
125        }
126
127        "string_agg" | "group_concat" => {
128            let values: Vec<String> = docs
129                .iter()
130                .filter_map(|d| extract_str_val(d, field, expr))
131                .collect();
132            Value::String(values.join(","))
133        }
134
135        "approx_count_distinct" => {
136            let mut hll = nodedb_types::approx::HyperLogLog::new();
137            for doc in docs {
138                if let Some(bytes) = extract_value_bytes(doc, field, expr)
139                    && !value_bytes_are_null(&bytes)
140                {
141                    // Hash the raw bytes for HLL.
142                    let hash = hash_bytes(&bytes);
143                    hll.add(hash);
144                }
145            }
146            Value::Integer(hll.estimate().round() as i64)
147        }
148
149        "approx_percentile" => {
150            // Format: field is "quantile:actual_field" (e.g. "0.95:latency").
151            let (pct, actual_field) = if let Some(idx) = field.find(':') {
152                match field[..idx].parse::<f64>() {
153                    Ok(p) => (p, &field[idx + 1..]),
154                    Err(_) => return Value::Null, // invalid quantile
155                }
156            } else {
157                (0.5, field)
158            };
159            let mut digest = nodedb_types::approx::TDigest::new();
160            for doc in docs {
161                if let Some(v) = extract_f64_val(doc, actual_field, expr) {
162                    digest.add(v);
163                }
164            }
165            let result = digest.quantile(pct);
166            if result.is_nan() {
167                Value::Null
168            } else {
169                Value::Float(result)
170            }
171        }
172
173        "approx_topk" => {
174            // Format: field is "k:actual_field" (e.g. "10:region").
175            let (k, actual_field) = if let Some(idx) = field.find(':') {
176                match field[..idx].parse::<usize>() {
177                    Ok(k) => (k, &field[idx + 1..]),
178                    Err(_) => return Value::Null, // invalid k
179                }
180            } else {
181                (10, field)
182            };
183            let mut ss = nodedb_types::approx::SpaceSaving::new(k);
184            for doc in docs {
185                if let Some(bytes) = extract_value_bytes(doc, actual_field, expr)
186                    && !value_bytes_are_null(&bytes)
187                {
188                    ss.add(hash_bytes(&bytes));
189                }
190            }
191            // Return as array of [hash, count, error] tuples.
192            let top = ss.top_k();
193            let arr: Vec<Value> = top
194                .into_iter()
195                .map(|(item, count, error)| {
196                    Value::Object(
197                        [
198                            ("item".to_string(), Value::Integer(item as i64)),
199                            ("count".to_string(), Value::Integer(count as i64)),
200                            ("error".to_string(), Value::Integer(error as i64)),
201                        ]
202                        .into_iter()
203                        .collect(),
204                    )
205                })
206                .collect();
207            Value::Array(arr)
208        }
209
210        "percentile_cont" => {
211            let (pct, actual_field) = if let Some(idx) = field.find(':') {
212                match field[..idx].parse::<f64>() {
213                    Ok(p) => (p, &field[idx + 1..]),
214                    Err(_) => return Value::Null, // invalid quantile
215                }
216            } else {
217                (0.5, field)
218            };
219            let mut values: Vec<f64> = docs
220                .iter()
221                .filter_map(|d| extract_f64_val(d, actual_field, expr))
222                .collect();
223            if values.is_empty() {
224                return Value::Null;
225            }
226            values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(Ordering::Equal));
227            let idx = (pct * (values.len() - 1) as f64).clamp(0.0, (values.len() - 1) as f64);
228            let lower = idx.floor() as usize;
229            let upper = idx.ceil() as usize;
230            let frac = idx - lower as f64;
231            let result = values[lower] * (1.0 - frac) + values[upper] * frac;
232            Value::Float(result)
233        }
234
235        _ => Value::Null,
236    }
237}
238
239// ── Internal helpers ───────────────────────────────────────────────────
240
241/// Decode a msgpack document directly to `nodedb_types::Value` and evaluate
242/// the expression. No JSON intermediate — msgpack → Value → eval → Value.
243#[inline]
244fn eval_expr_on_doc(doc: &[u8], expr: &crate::expr::SqlExpr) -> Option<Value> {
245    let doc_val = nodedb_types::json_msgpack::value_from_msgpack(doc).ok()?;
246    Some(expr.eval(&doc_val))
247}
248
249/// Extract a numeric value from a field or expression result.
250#[inline]
251fn extract_f64_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<f64> {
252    if let Some(expr) = expr {
253        return value_ops::value_to_f64(&eval_expr_on_doc(doc, expr)?, false);
254    }
255    let (start, _end) = extract_field(doc, 0, field)?;
256    read_f64(doc, start)
257}
258
259/// Extract a string from a field or expression result.
260fn extract_str_val(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<String> {
261    if let Some(expr) = expr {
262        return Some(value_ops::value_to_display_string(&eval_expr_on_doc(
263            doc, expr,
264        )?));
265    }
266    let (start, _end) = extract_field(doc, 0, field)?;
267    read_str(doc, start).map(|s| s.to_string())
268}
269
270/// Extract a field as `Value`. Uses direct msgpack→Value for scalars;
271/// falls back to full decode only for complex types.
272fn extract_as_value(doc: &[u8], field: &str, expr: Option<&crate::expr::SqlExpr>) -> Option<Value> {
273    if let Some(expr) = expr {
274        return eval_expr_on_doc(doc, expr);
275    }
276    value_from_field(doc, field)
277}
278
279#[inline]
280fn value_from_field(doc: &[u8], field: &str) -> Option<Value> {
281    let (start, end) = extract_field(doc, 0, field)?;
282    // Fast path: scalar types (null, bool, int, float, string).
283    if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
284        return Some(v);
285    }
286    // Slow path: complex types (array, map, bin) — decode field bytes directly.
287    let field_bytes = &doc[start..end];
288    nodedb_types::json_msgpack::value_from_msgpack(field_bytes).ok()
289}
290
291/// Find min or max across docs by comparing raw field bytes.
292fn find_minmax(
293    docs: &[&[u8]],
294    field: &str,
295    expr: Option<&crate::expr::SqlExpr>,
296    want_max: bool,
297) -> Value {
298    if let Some(expr) = expr {
299        // Evaluate expression once per doc; compare on Value
300        // since the result may be any type (not a raw field).
301        let mut best: Option<Value> = None;
302        for doc in docs {
303            let Some(value) = eval_expr_on_doc(doc, expr) else {
304                continue;
305            };
306            if value.is_null() {
307                continue;
308            }
309            let replace = match &best {
310                None => true,
311                Some(current) => {
312                    let ord = value_ops::compare_values(&value, current);
313                    if want_max {
314                        ord == Ordering::Greater
315                    } else {
316                        ord == Ordering::Less
317                    }
318                }
319            };
320            if replace {
321                best = Some(value);
322            }
323        }
324        return best.unwrap_or(Value::Null);
325    }
326
327    let mut best_doc: Option<&[u8]> = None;
328    let mut best_range: Option<(usize, usize)> = None;
329
330    for doc in docs {
331        if let Some(range) = extract_field(doc, 0, field) {
332            if read_null(doc, range.0) {
333                continue;
334            }
335            match best_range {
336                None => {
337                    best_doc = Some(doc);
338                    best_range = Some(range);
339                }
340                Some(br) => {
341                    let Some(bd) = best_doc else { continue };
342                    let cmp = compare_field_bytes(doc, range, bd, br);
343                    let replace = if want_max {
344                        cmp == Ordering::Greater
345                    } else {
346                        cmp == Ordering::Less
347                    };
348                    if replace {
349                        best_doc = Some(doc);
350                        best_range = Some(range);
351                    }
352                }
353            }
354        }
355    }
356
357    match (best_doc, best_range) {
358        (Some(doc), Some((start, end))) => {
359            if let Some(v) = crate::msgpack_scan::reader::read_value(doc, start) {
360                return v;
361            }
362            let bytes = &doc[start..end];
363            nodedb_types::json_msgpack::value_from_msgpack(bytes).unwrap_or(Value::Null)
364        }
365        _ => Value::Null,
366    }
367}
368
369/// Compute stddev or variance. `population` = true for population variant.
370/// `finalize` transforms the variance into the final result.
371fn stat_aggregate(
372    docs: &[&[u8]],
373    field: &str,
374    expr: Option<&crate::expr::SqlExpr>,
375    finalize: fn(f64, usize) -> f64,
376    population: bool,
377) -> Value {
378    let values: Vec<f64> = docs
379        .iter()
380        .filter_map(|d| extract_f64_val(d, field, expr))
381        .collect();
382    if values.len() < 2 {
383        return Value::Null;
384    }
385    let mean = values.iter().sum::<f64>() / values.len() as f64;
386    let divisor = if population {
387        values.len() as f64
388    } else {
389        (values.len() - 1) as f64
390    };
391    let variance = values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / divisor;
392    Value::Float(finalize(variance, values.len()))
393}
394
395fn extract_value_bytes(
396    doc: &[u8],
397    field: &str,
398    expr: Option<&crate::expr::SqlExpr>,
399) -> Option<Vec<u8>> {
400    if let Some(expr) = expr {
401        let val = eval_expr_on_doc(doc, expr)?;
402        return nodedb_types::json_msgpack::value_to_msgpack(&val).ok();
403    }
404    let (start, end) = extract_field(doc, 0, field)?;
405    Some(doc[start..end].to_vec())
406}
407
/// True when `bytes` is exactly the MessagePack nil marker (the single byte `0xc0`).
fn value_bytes_are_null(bytes: &[u8]) -> bool {
    matches!(bytes, [0xc0])
}
412
/// 64-bit FNV-1a hash of `bytes`.
///
/// Used to feed the approximate aggregates (HyperLogLog / SpaceSaving); the hash
/// is also what `approx_topk` reports as each item.
fn hash_bytes(bytes: &[u8]) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
    bytes.iter().fold(FNV_OFFSET_BASIS, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
422
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Encode one JSON document as MessagePack.
    fn encode(v: &serde_json::Value) -> Vec<u8> {
        nodedb_types::json_msgpack::json_to_msgpack(v).expect("encode")
    }

    /// Encode every JSON document; the caller keeps the owned buffers alive.
    fn encode_all(values: &[serde_json::Value]) -> Vec<Vec<u8>> {
        values.iter().map(encode).collect()
    }

    /// Borrow each encoded buffer as the `&[u8]` slice the aggregator expects.
    fn as_docs(owned: &[Vec<u8>]) -> Vec<&[u8]> {
        owned.iter().map(Vec::as_slice).collect()
    }

    /// Run a field-based (expression-free) aggregate over JSON documents.
    fn agg(op: &str, field: &str, values: &[serde_json::Value]) -> Value {
        let owned = encode_all(values);
        compute_aggregate_binary(op, field, None, &as_docs(&owned))
    }

    #[test]
    fn count() {
        let docs = [json!({"x": 1}), json!({"x": 2}), json!({"x": 3})];
        assert_eq!(agg("count", "x", &docs), Value::Integer(3));
    }

    #[test]
    fn sum() {
        let docs = [json!({"v": 10}), json!({"v": 20}), json!({"v": 30})];
        assert_eq!(agg("sum", "v", &docs), Value::Float(60.0));
    }

    #[test]
    fn avg() {
        let docs = [json!({"v": 10}), json!({"v": 20})];
        assert_eq!(agg("avg", "v", &docs), Value::Float(15.0));
    }

    #[test]
    fn avg_empty() {
        let docs = [json!({"other": 1})];
        assert_eq!(agg("avg", "v", &docs), Value::Null);
    }

    #[test]
    fn min_max() {
        let docs = [json!({"v": 5}), json!({"v": 1}), json!({"v": 9})];
        assert_eq!(agg("min", "v", &docs), Value::Integer(1));
        assert_eq!(agg("max", "v", &docs), Value::Integer(9));
    }

    #[test]
    fn count_distinct() {
        let docs = [json!({"v": "a"}), json!({"v": "b"}), json!({"v": "a"})];
        assert_eq!(agg("count_distinct", "v", &docs), Value::Integer(2));
    }

    #[test]
    fn string_agg() {
        let docs = [json!({"n": "alice"}), json!({"n": "bob"})];
        assert_eq!(
            agg("string_agg", "n", &docs),
            Value::String("alice,bob".into())
        );
    }

    #[test]
    fn array_agg() {
        let docs = [json!({"v": 1}), json!({"v": 2})];
        assert_eq!(
            agg("array_agg", "v", &docs),
            Value::Array(vec![Value::Integer(1), Value::Integer(2)])
        );
    }

    #[test]
    fn stddev_pop() {
        let docs: Vec<serde_json::Value> = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
            .iter()
            .map(|&v: &f64| json!({"v": v}))
            .collect();
        match agg("stddev_pop", "v", &docs) {
            Value::Float(v) => assert!((v - 2.0).abs() < 0.01),
            _ => panic!("expected Float"),
        }
    }

    #[test]
    fn percentile_cont_median() {
        let docs = [json!({"v": 1.0}), json!({"v": 2.0}), json!({"v": 3.0})];
        assert_eq!(agg("percentile_cont", "v", &docs), Value::Float(2.0));
    }

    #[test]
    fn missing_field_skipped() {
        let docs = [json!({"v": 10}), json!({"other": 99}), json!({"v": 30})];
        assert_eq!(agg("sum", "v", &docs), Value::Float(40.0));
    }

    #[test]
    fn null_field_skipped_in_count_distinct() {
        let docs = [json!({"v": "a"}), json!({"v": null}), json!({"v": "a"})];
        assert_eq!(agg("count_distinct", "v", &docs), Value::Integer(1));
    }

    #[test]
    fn array_agg_distinct() {
        let docs = [json!({"v": 1}), json!({"v": 2}), json!({"v": 1})];
        assert_eq!(
            agg("array_agg_distinct", "v", &docs),
            Value::Array(vec![Value::Integer(1), Value::Integer(2)])
        );
    }

    #[test]
    fn sum_case_when_expression() {
        let owned = encode_all(&[
            json!({"category": "tools"}),
            json!({"category": "books"}),
            json!({"category": "tools"}),
        ]);
        let docs = as_docs(&owned);

        // CASE WHEN category = 'tools' THEN 1 ELSE 0 END
        let is_tools = crate::expr::SqlExpr::BinaryOp {
            left: Box::new(crate::expr::SqlExpr::Column("category".into())),
            op: crate::expr::BinaryOp::Eq,
            right: Box::new(crate::expr::SqlExpr::Literal(Value::String("tools".into()))),
        };
        let expr = crate::expr::SqlExpr::Case {
            operand: None,
            when_thens: vec![(is_tools, crate::expr::SqlExpr::Literal(Value::Integer(1)))],
            else_expr: Some(Box::new(crate::expr::SqlExpr::Literal(Value::Integer(0)))),
        };

        assert_eq!(
            compute_aggregate_binary("sum", "*", Some(&expr), &docs),
            Value::Float(2.0)
        );
    }

    #[test]
    fn approx_count_distinct_basic() {
        let docs = [
            json!({"region": "us"}),
            json!({"region": "eu"}),
            json!({"region": "us"}),
            json!({"region": "ap"}),
        ];
        let result = agg("approx_count_distinct", "region", &docs);
        // HLL may not be exactly 3 but should be close.
        match &result {
            Value::Integer(n) => assert!((2..=4).contains(n), "expected ~3 distinct, got {n}"),
            _ => panic!("expected Integer, got {result:?}"),
        }
    }

    #[test]
    fn approx_percentile_basic() {
        let docs: Vec<serde_json::Value> = (1..=100).map(|i| json!({"val": i})).collect();
        let result = agg("approx_percentile", "0.5:val", &docs);
        match &result {
            Value::Float(f) => assert!(
                (f - 50.0).abs() < 10.0,
                "p50 of 1..100 should be ~50, got {f}"
            ),
            _ => panic!("expected Float, got {result:?}"),
        }
    }

    #[test]
    fn approx_topk_basic() {
        let docs: Vec<serde_json::Value> = std::iter::repeat_n(json!({"cat": "a"}), 10)
            .chain(std::iter::repeat_n(json!({"cat": "b"}), 5))
            .chain(std::iter::repeat_n(json!({"cat": "c"}), 1))
            .collect();
        let result = agg("approx_topk", "3:cat", &docs);
        match &result {
            Value::Array(arr) => assert!(!arr.is_empty(), "should have top-k results"),
            _ => panic!("expected Array, got {result:?}"),
        }
    }
}