Skip to main content

nodedb_query/window/
value_agg.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Aggregate window functions (sum, count, avg, min, max, first_value, last_value)
4//! and frame-bound resolution for the Value-native evaluator.
5
6use std::collections::HashMap;
7
8use nodedb_types::Value;
9
10use super::spec::{FrameBound, WindowFrame, WindowFuncSpec};
11use super::value_eval::{cmp_values, eval_arg_for_row, order_keys_equal_v, set_cell};
12use crate::simd_agg;
13
14pub(super) fn apply_v_aggregate(
15    rows: &mut [Vec<Value>],
16    indices: &[usize],
17    column_index: &HashMap<String, usize>,
18    spec: &WindowFuncSpec,
19    write_col: usize,
20) {
21    let use_running = spec.frame.mode == "range"
22        && matches!(spec.frame.start, FrameBound::UnboundedPreceding)
23        && matches!(spec.frame.end, FrameBound::CurrentRow);
24
25    if use_running {
26        apply_v_running_aggregate(rows, indices, column_index, spec, write_col);
27    } else {
28        apply_v_per_row_aggregate(rows, indices, column_index, spec, write_col);
29    }
30}
31
32fn eval_arg(spec: &WindowFuncSpec, row: &[Value], column_index: &HashMap<String, usize>) -> Value {
33    spec.args
34        .first()
35        .map(|expr| eval_arg_for_row(expr, row, column_index))
36        .unwrap_or(Value::Null)
37}
38
39fn apply_v_running_aggregate(
40    rows: &mut [Vec<Value>],
41    indices: &[usize],
42    column_index: &HashMap<String, usize>,
43    spec: &WindowFuncSpec,
44    write_col: usize,
45) {
46    let len = indices.len();
47    if len == 0 {
48        return;
49    }
50
51    let mut running_sum = 0.0f64;
52    let mut running_count = 0u64;
53    let mut running_min: Option<f64> = None;
54    let mut running_max: Option<f64> = None;
55    let mut peer_start = 0usize;
56
57    for pos in 0..len {
58        let i = indices[pos];
59        let val = rows
60            .get(i)
61            .map(|row| eval_arg(spec, row, column_index))
62            .unwrap_or(Value::Null);
63
64        if let Some(n) = val.as_f64() {
65            running_sum += n;
66            running_count += 1;
67            running_min = Some(running_min.map_or(n, |m: f64| m.min(n)));
68            running_max = Some(running_max.map_or(n, |m: f64| m.max(n)));
69        } else if spec.func_name == "count" {
70            running_count += 1;
71        }
72
73        let is_last_in_group = pos + 1 == len
74            || !order_keys_equal_v(rows, i, indices[pos + 1], column_index, &spec.order_by);
75
76        if is_last_in_group {
77            let first_val = rows
78                .get(indices[0])
79                .map(|row| eval_arg(spec, row, column_index))
80                .unwrap_or(Value::Null);
81            let last_val = rows
82                .get(indices[pos])
83                .map(|row| eval_arg(spec, row, column_index))
84                .unwrap_or(Value::Null);
85
86            let result = match spec.func_name.as_str() {
87                "sum" => Value::Float(running_sum),
88                "count" => Value::Integer(running_count as i64),
89                "avg" => {
90                    if running_count > 0 {
91                        Value::Float(running_sum / running_count as f64)
92                    } else {
93                        Value::Null
94                    }
95                }
96                "min" => running_min.map(Value::Float).unwrap_or(Value::Null),
97                "max" => running_max.map(Value::Float).unwrap_or(Value::Null),
98                "first_value" => first_val,
99                "last_value" => last_val,
100                _ => Value::Null,
101            };
102
103            for &peer_idx in &indices[peer_start..=pos] {
104                set_cell(rows, peer_idx, write_col, result.clone());
105            }
106            peer_start = pos + 1;
107        }
108    }
109}
110
111fn apply_v_per_row_aggregate(
112    rows: &mut [Vec<Value>],
113    indices: &[usize],
114    column_index: &HashMap<String, usize>,
115    spec: &WindowFuncSpec,
116    write_col: usize,
117) {
118    let len = indices.len();
119    if len == 0 {
120        return;
121    }
122
123    let order_expr = spec.order_by.first().map(|(expr, _)| expr);
124    let order_values: Vec<Value> = indices
125        .iter()
126        .map(|&i| {
127            order_expr
128                .and_then(|expr| {
129                    rows.get(i)
130                        .map(|row| eval_arg_for_row(expr, row, column_index))
131                })
132                .unwrap_or(Value::Null)
133        })
134        .collect();
135
136    let peer_groups: Vec<usize> = if spec.frame.mode == "groups" {
137        build_v_peer_groups(&order_values)
138    } else {
139        Vec::new()
140    };
141
142    let all_vals: Vec<Option<f64>> = indices
143        .iter()
144        .map(|&i| {
145            rows.get(i)
146                .map(|row| eval_arg(spec, row, column_index).as_f64())
147                .unwrap_or(None)
148        })
149        .collect();
150
151    let results: Vec<Value> = (0..len)
152        .map(|pos| {
153            let (start_idx, end_idx) =
154                evaluate_v_frame_bounds(&spec.frame, pos, len, &order_values, &peer_groups);
155            aggregate_v_slice(
156                &all_vals,
157                indices,
158                rows,
159                column_index,
160                spec,
161                start_idx,
162                end_idx,
163            )
164        })
165        .collect();
166
167    for (pos, result) in results.into_iter().enumerate() {
168        set_cell(rows, indices[pos], write_col, result);
169    }
170}
171
172fn aggregate_v_slice(
173    all_vals: &[Option<f64>],
174    indices: &[usize],
175    rows: &[Vec<Value>],
176    column_index: &HashMap<String, usize>,
177    spec: &WindowFuncSpec,
178    start_idx: usize,
179    end_idx: usize,
180) -> Value {
181    let slice_vals: Vec<f64> = all_vals[start_idx..=end_idx]
182        .iter()
183        .filter_map(|v| *v)
184        .collect();
185    let slice_count = end_idx - start_idx + 1;
186
187    match spec.func_name.as_str() {
188        "sum" => {
189            let rt = simd_agg::ts_runtime();
190            Value::Float((rt.sum_f64)(&slice_vals))
191        }
192        "count" => Value::Integer(slice_count as i64),
193        "avg" => {
194            if slice_vals.is_empty() {
195                Value::Null
196            } else {
197                let rt = simd_agg::ts_runtime();
198                Value::Float((rt.sum_f64)(&slice_vals) / slice_vals.len() as f64)
199            }
200        }
201        "min" => {
202            if slice_vals.is_empty() {
203                Value::Null
204            } else {
205                let rt = simd_agg::ts_runtime();
206                Value::Float((rt.min_f64)(&slice_vals))
207            }
208        }
209        "max" => {
210            if slice_vals.is_empty() {
211                Value::Null
212            } else {
213                let rt = simd_agg::ts_runtime();
214                Value::Float((rt.max_f64)(&slice_vals))
215            }
216        }
217        "first_value" => indices
218            .get(start_idx)
219            .and_then(|&i| rows.get(i))
220            .map(|row| {
221                eval_arg_for_row(
222                    spec.args
223                        .first()
224                        .unwrap_or(&crate::expr::types::SqlExpr::Literal(Value::Null)),
225                    row,
226                    column_index,
227                )
228            })
229            .unwrap_or(Value::Null),
230        "last_value" => indices
231            .get(end_idx)
232            .and_then(|&i| rows.get(i))
233            .map(|row| {
234                eval_arg_for_row(
235                    spec.args
236                        .first()
237                        .unwrap_or(&crate::expr::types::SqlExpr::Literal(Value::Null)),
238                    row,
239                    column_index,
240                )
241            })
242            .unwrap_or(Value::Null),
243        _ => Value::Null,
244    }
245}
246
247fn build_v_peer_groups(order_values: &[Value]) -> Vec<usize> {
248    let mut groups = Vec::with_capacity(order_values.len());
249    let mut current_group = 0usize;
250    for (i, val) in order_values.iter().enumerate() {
251        if i > 0
252            && !matches!(
253                cmp_values(val, &order_values[i - 1]),
254                std::cmp::Ordering::Equal
255            )
256        {
257            current_group += 1;
258        }
259        groups.push(current_group);
260    }
261    groups
262}
263
264pub(super) fn evaluate_v_frame_bounds(
265    frame: &WindowFrame,
266    pos: usize,
267    len: usize,
268    order_values: &[Value],
269    peer_groups: &[usize],
270) -> (usize, usize) {
271    match frame.mode.as_str() {
272        "rows" => v_rows_bounds(&frame.start, &frame.end, pos, len),
273        "range" => v_range_bounds(&frame.start, &frame.end, pos, len, order_values),
274        "groups" => v_groups_bounds(&frame.start, &frame.end, pos, len, peer_groups),
275        _ => (0, len.saturating_sub(1)),
276    }
277}
278
279fn v_rows_bounds(start: &FrameBound, end: &FrameBound, pos: usize, len: usize) -> (usize, usize) {
280    let s = v_rows_bound_to_idx(start, pos, len);
281    let e = v_rows_bound_to_idx(end, pos, len);
282    (s.min(e), s.max(e))
283}
284
285fn v_rows_bound_to_idx(bound: &FrameBound, pos: usize, len: usize) -> usize {
286    match bound {
287        FrameBound::UnboundedPreceding => 0,
288        FrameBound::Preceding(n) => pos.saturating_sub(*n as usize),
289        FrameBound::CurrentRow => pos,
290        FrameBound::Following(n) => (pos + *n as usize).min(len.saturating_sub(1)),
291        FrameBound::UnboundedFollowing => len.saturating_sub(1),
292    }
293}
294
295fn v_range_bounds(
296    start: &FrameBound,
297    end: &FrameBound,
298    pos: usize,
299    len: usize,
300    order_values: &[Value],
301) -> (usize, usize) {
302    let current_val = order_values.get(pos).and_then(|v| v.as_f64());
303    let s = v_range_bound_to_idx(start, pos, len, order_values, current_val, true);
304    let e = v_range_bound_to_idx(end, pos, len, order_values, current_val, false);
305    (s.min(e), s.max(e))
306}
307
308fn v_range_bound_to_idx(
309    bound: &FrameBound,
310    pos: usize,
311    len: usize,
312    order_values: &[Value],
313    current_val: Option<f64>,
314    is_start: bool,
315) -> usize {
316    match bound {
317        FrameBound::UnboundedPreceding => 0,
318        FrameBound::UnboundedFollowing => len.saturating_sub(1),
319        FrameBound::CurrentRow => {
320            if is_start {
321                let mut idx = pos;
322                while idx > 0
323                    && matches!(
324                        cmp_values(
325                            order_values.get(idx - 1).unwrap_or(&Value::Null),
326                            order_values.get(pos).unwrap_or(&Value::Null),
327                        ),
328                        std::cmp::Ordering::Equal
329                    )
330                {
331                    idx -= 1;
332                }
333                idx
334            } else {
335                let mut idx = pos;
336                while idx + 1 < len
337                    && matches!(
338                        cmp_values(
339                            order_values.get(idx + 1).unwrap_or(&Value::Null),
340                            order_values.get(pos).unwrap_or(&Value::Null),
341                        ),
342                        std::cmp::Ordering::Equal
343                    )
344                {
345                    idx += 1;
346                }
347                idx
348            }
349        }
350        FrameBound::Preceding(n) => {
351            let threshold = match current_val {
352                Some(cv) => cv - *n as f64,
353                None => return pos,
354            };
355            let mut idx = 0;
356            for (i, v) in order_values.iter().enumerate() {
357                if v.as_f64().is_some_and(|fv| fv >= threshold) {
358                    idx = i;
359                    break;
360                }
361                idx = i + 1;
362            }
363            idx.min(len.saturating_sub(1))
364        }
365        FrameBound::Following(n) => {
366            let threshold = match current_val {
367                Some(cv) => cv + *n as f64,
368                None => return pos,
369            };
370            let mut idx = pos;
371            for (i, v) in order_values.iter().enumerate().skip(pos) {
372                if v.as_f64().is_none_or(|fv| fv > threshold) {
373                    break;
374                }
375                idx = i;
376            }
377            idx.min(len.saturating_sub(1))
378        }
379    }
380}
381
382fn v_groups_bounds(
383    start: &FrameBound,
384    end: &FrameBound,
385    pos: usize,
386    len: usize,
387    peer_groups: &[usize],
388) -> (usize, usize) {
389    let current_group = peer_groups.get(pos).copied().unwrap_or(0);
390    let max_group = peer_groups.last().copied().unwrap_or(0);
391    let start_group = v_groups_bound_to_group(start, current_group, max_group);
392    let end_group = v_groups_bound_to_group(end, current_group, max_group);
393    let start_idx = peer_groups
394        .iter()
395        .position(|&g| g == start_group)
396        .unwrap_or(0);
397    let end_idx = peer_groups
398        .iter()
399        .rposition(|&g| g == end_group)
400        .unwrap_or(len.saturating_sub(1));
401    (start_idx, end_idx)
402}
403
404fn v_groups_bound_to_group(bound: &FrameBound, current_group: usize, max_group: usize) -> usize {
405    match bound {
406        FrameBound::UnboundedPreceding => 0,
407        FrameBound::UnboundedFollowing => max_group,
408        FrameBound::CurrentRow => current_group,
409        FrameBound::Preceding(n) => current_group.saturating_sub(*n as usize),
410        FrameBound::Following(n) => (current_group + *n as usize).min(max_group),
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;
    use crate::expr::types::SqlExpr;

    /// Column-reference expression for `name`.
    fn col(name: &str) -> SqlExpr {
        SqlExpr::Column(name.to_string())
    }

    /// Column-name → position map, assigning positions in slice order.
    fn ci(names: &[&str]) -> HashMap<String, usize> {
        names
            .iter()
            .enumerate()
            .map(|(i, n)| (n.to_string(), i))
            .collect()
    }

    /// One single-column integer row per input value.
    fn rows_v(vals: &[i64]) -> Vec<Vec<Value>> {
        vals.iter().map(|&v| vec![Value::Integer(v)]).collect()
    }

    /// Window spec for `func` over column `v` with no partitioning.
    fn agg_spec(func: &str, frame: WindowFrame, order_by: Vec<(SqlExpr, bool)>) -> WindowFuncSpec {
        WindowFuncSpec {
            alias: format!("w_{func}"),
            func_name: func.to_string(),
            args: vec![col("v")],
            partition_by: vec![],
            order_by,
            frame,
        }
    }

    /// Drive `apply_v_aggregate` over the whole single-partition row set the
    /// same way `evaluate_window_functions_value` does: push a Null cell onto
    /// every row, then fill it. Nothing is returned; the produced column is
    /// the cell appended to each row of `rows`.
    fn run_agg(rows: &mut [Vec<Value>], cols: &HashMap<String, usize>, spec: &WindowFuncSpec) {
        let write_col = rows.first().map(|r| r.len()).unwrap_or(0);
        for row in rows.iter_mut() {
            row.push(Value::Null);
        }
        let indices: Vec<usize> = (0..rows.len()).collect();
        apply_v_aggregate(rows, &indices, cols, spec, write_col);
    }

    /// Reads the appended result column (index 1) of every row as `f64`.
    fn result_f64(rows: &[Vec<Value>]) -> Vec<f64> {
        rows.iter().map(|r| r[1].as_f64().unwrap()).collect()
    }

    fn frame(mode: &str, start: FrameBound, end: FrameBound) -> WindowFrame {
        WindowFrame {
            mode: mode.into(),
            start,
            end,
        }
    }

    #[test]
    fn running_sum_is_cumulative() {
        // Default frame (range, unbounded preceding → current row) with a
        // strictly increasing order key → cumulative sum.
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[1, 2, 3]);
        let s = agg_spec("sum", WindowFrame::default(), vec![(col("v"), true)]);
        run_agg(&mut rows, &cols, &s);
        assert_eq!(result_f64(&rows), vec![1.0, 3.0, 6.0]);
    }

    #[test]
    fn running_sum_shares_value_across_peers() {
        // Tied order keys form a peer group; the running aggregate assigns the
        // group's running total to every peer.
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[5, 5, 9]);
        let s = agg_spec("sum", WindowFrame::default(), vec![(col("v"), true)]);
        run_agg(&mut rows, &cols, &s);
        assert_eq!(result_f64(&rows), vec![10.0, 10.0, 19.0]);
    }

    #[test]
    fn rows_frame_sliding_sum() {
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[10, 20, 30]);
        let s = agg_spec(
            "sum",
            frame("rows", FrameBound::Preceding(1), FrameBound::CurrentRow),
            vec![(col("v"), true)],
        );
        run_agg(&mut rows, &cols, &s);
        assert_eq!(result_f64(&rows), vec![10.0, 30.0, 50.0]);
    }

    #[test]
    fn groups_frame_sliding_sum() {
        // Keys 1,1,2,3 → peer groups g0,g0,g1,g2. GROUPS 1 PRECEDING →
        // CURRENT ROW sums the current group plus the one before it.
        let cols = ci(&["v"]);
        let mut rows = rows_v(&[1, 1, 2, 3]);
        let s = agg_spec(
            "sum",
            frame("groups", FrameBound::Preceding(1), FrameBound::CurrentRow),
            vec![(col("v"), true)],
        );
        run_agg(&mut rows, &cols, &s);
        assert_eq!(result_f64(&rows), vec![2.0, 2.0, 4.0, 5.0]);
    }

    #[test]
    fn rows_frame_count_and_avg() {
        let cols = ci(&["v"]);
        let f = frame(
            "rows",
            FrameBound::UnboundedPreceding,
            FrameBound::CurrentRow,
        );

        let mut rows = rows_v(&[4, 8, 12]);
        let cnt = agg_spec("count", f.clone(), vec![(col("v"), true)]);
        run_agg(&mut rows, &cols, &cnt);
        let counts: Vec<i64> = rows
            .iter()
            .map(|r| match r[1] {
                Value::Integer(n) => n,
                _ => panic!("count must be integer"),
            })
            .collect();
        assert_eq!(counts, vec![1, 2, 3]);

        let mut rows = rows_v(&[4, 8, 12]);
        let avg = agg_spec("avg", f, vec![(col("v"), true)]);
        run_agg(&mut rows, &cols, &avg);
        assert_eq!(result_f64(&rows), vec![4.0, 6.0, 8.0]);
    }

    #[test]
    fn rows_frame_min_max() {
        // The frame spans the whole partition, so every row must see the
        // same extreme, not just the first one.
        let cols = ci(&["v"]);
        let f = frame(
            "rows",
            FrameBound::UnboundedPreceding,
            FrameBound::UnboundedFollowing,
        );

        let mut rows = rows_v(&[3, 1, 2]);
        let mn = agg_spec("min", f.clone(), vec![]);
        run_agg(&mut rows, &cols, &mn);
        for got in result_f64(&rows) {
            assert!((got - 1.0).abs() < 1e-9);
        }

        let mut rows = rows_v(&[3, 1, 2]);
        let mx = agg_spec("max", f, vec![]);
        run_agg(&mut rows, &cols, &mx);
        for got in result_f64(&rows) {
            assert!((got - 3.0).abs() < 1e-9);
        }
    }

    #[test]
    fn first_and_last_value() {
        // Whole-partition frame: every row reports the same first/last value.
        let cols = ci(&["v"]);
        let f = frame(
            "rows",
            FrameBound::UnboundedPreceding,
            FrameBound::UnboundedFollowing,
        );

        let mut rows = rows_v(&[7, 8, 9]);
        let fv = agg_spec("first_value", f.clone(), vec![]);
        run_agg(&mut rows, &cols, &fv);
        for row in &rows {
            assert_eq!(row[1].as_f64().unwrap() as i64, 7);
        }

        let mut rows = rows_v(&[7, 8, 9]);
        let lv = agg_spec("last_value", f, vec![]);
        run_agg(&mut rows, &cols, &lv);
        for row in &rows {
            assert_eq!(row[1].as_f64().unwrap() as i64, 9);
        }
    }

    #[test]
    fn rows_bounds_resolution() {
        let order = vec![];
        let groups = vec![];
        let f = frame("rows", FrameBound::Preceding(1), FrameBound::Following(1));
        // Middle of a 5-row partition → window [pos-1, pos+1].
        assert_eq!(evaluate_v_frame_bounds(&f, 2, 5, &order, &groups), (1, 3));
        // First row clamps the preceding bound to 0.
        assert_eq!(evaluate_v_frame_bounds(&f, 0, 5, &order, &groups), (0, 1));
        // Last row clamps the following bound to len-1.
        assert_eq!(evaluate_v_frame_bounds(&f, 4, 5, &order, &groups), (3, 4));
    }

    #[test]
    fn range_bounds_expand_over_peers() {
        // RANGE with CURRENT ROW spans the whole peer group of equal keys.
        let order = vec![Value::Integer(10), Value::Integer(10), Value::Integer(20)];
        let f = frame(
            "range",
            FrameBound::UnboundedPreceding,
            FrameBound::CurrentRow,
        );
        // pos 0 (key 10) → end extends across both 10s.
        assert_eq!(evaluate_v_frame_bounds(&f, 0, 3, &order, &[]), (0, 1));
        // pos 2 (key 20) → spans everything up to and including itself.
        assert_eq!(evaluate_v_frame_bounds(&f, 2, 3, &order, &[]), (0, 2));
    }

    #[test]
    fn groups_bounds_resolution() {
        // Peer groups: [g0, g0, g1, g2]. GROUPS 1 PRECEDING → CURRENT ROW at
        // pos 2 (group 1) covers groups 0..=1 → indices 0..=2.
        let peer_groups = vec![0usize, 0, 1, 2];
        let order = vec![
            Value::Integer(1),
            Value::Integer(1),
            Value::Integer(2),
            Value::Integer(3),
        ];
        let f = frame("groups", FrameBound::Preceding(1), FrameBound::CurrentRow);
        assert_eq!(
            evaluate_v_frame_bounds(&f, 2, 4, &order, &peer_groups),
            (0, 2)
        );
    }
}
613            evaluate_v_frame_bounds(&f, 2, 4, &order, &peer_groups),
614            (0, 2)
615        );
616    }
617}