Skip to main content

nodedb_query/
window.rs

//! Window function specification and evaluation.
//!
//! Evaluated after sort, before projection. Each spec produces a
//! new column appended to every row (e.g., ROW_NUMBER, RANK, SUM OVER).
5
6use crate::expr::SqlExpr;
7
8/// A window function specification.
9#[derive(
10    Debug,
11    Clone,
12    serde::Serialize,
13    serde::Deserialize,
14    zerompk::ToMessagePack,
15    zerompk::FromMessagePack,
16)]
17pub struct WindowFuncSpec {
18    /// Output column name (e.g., "row_num", "running_sum").
19    pub alias: String,
20    /// Function name: row_number, rank, dense_rank, lag, lead, sum, count, avg, min, max.
21    pub func_name: String,
22    /// Function arguments (e.g., `salary` for SUM(salary)). Empty for ROW_NUMBER.
23    pub args: Vec<SqlExpr>,
24    /// PARTITION BY column names. Empty = single partition (entire result set).
25    pub partition_by: Vec<String>,
26    /// ORDER BY within each partition: [(field, ascending)].
27    pub order_by: Vec<(String, bool)>,
28    /// Window frame specification.
29    pub frame: WindowFrame,
30}
31
32/// Window frame: defines which rows within the partition are visible to the function.
33#[derive(
34    Debug,
35    Clone,
36    serde::Serialize,
37    serde::Deserialize,
38    zerompk::ToMessagePack,
39    zerompk::FromMessagePack,
40)]
41pub struct WindowFrame {
42    /// Frame mode: "rows" or "range".
43    pub mode: String,
44    /// Start bound.
45    pub start: FrameBound,
46    /// End bound.
47    pub end: FrameBound,
48}
49
50impl Default for WindowFrame {
51    fn default() -> Self {
52        Self {
53            mode: "range".into(),
54            start: FrameBound::UnboundedPreceding,
55            end: FrameBound::CurrentRow,
56        }
57    }
58}
59
60/// Window frame boundary.
61#[derive(
62    Debug,
63    Clone,
64    serde::Serialize,
65    serde::Deserialize,
66    zerompk::ToMessagePack,
67    zerompk::FromMessagePack,
68)]
69pub enum FrameBound {
70    UnboundedPreceding,
71    Preceding(u64),
72    CurrentRow,
73    Following(u64),
74    UnboundedFollowing,
75}
76
77/// Evaluate window functions over sorted, partitioned rows.
78///
79/// `rows` is the sorted result set. Each row is a `(doc_id, serde_json::Value)`.
80/// Returns the same rows with window columns appended to each document.
81pub fn evaluate_window_functions(
82    rows: &mut [(String, serde_json::Value)],
83    specs: &[WindowFuncSpec],
84) {
85    for spec in specs {
86        let partitions = build_partitions(rows, &spec.partition_by);
87
88        for partition_indices in &partitions {
89            match spec.func_name.as_str() {
90                "row_number" => apply_row_number(rows, partition_indices, &spec.alias),
91                "rank" => apply_rank(rows, partition_indices, &spec.alias, &spec.order_by),
92                "dense_rank" => {
93                    apply_dense_rank(rows, partition_indices, &spec.alias, &spec.order_by)
94                }
95                "lag" => apply_lag(rows, partition_indices, spec),
96                "lead" => apply_lead(rows, partition_indices, spec),
97                "ntile" => apply_ntile(rows, partition_indices, spec),
98                "sum" | "count" | "avg" | "min" | "max" | "first_value" | "last_value" => {
99                    apply_aggregate_window(rows, partition_indices, spec)
100                }
101                _ => {}
102            }
103        }
104    }
105}
106
107fn build_partitions(
108    rows: &[(String, serde_json::Value)],
109    partition_by: &[String],
110) -> Vec<Vec<usize>> {
111    if partition_by.is_empty() {
112        return vec![(0..rows.len()).collect()];
113    }
114
115    let mut groups: std::collections::HashMap<String, Vec<usize>> =
116        std::collections::HashMap::new();
117    let mut order = Vec::new();
118
119    for (i, (_id, doc)) in rows.iter().enumerate() {
120        let key: String = partition_by
121            .iter()
122            .map(|col| {
123                doc.get(col)
124                    .map(|v| v.to_string())
125                    .unwrap_or_else(|| "null".to_string())
126            })
127            .collect::<Vec<_>>()
128            .join("\x00");
129        let entry = groups.entry(key.clone()).or_default();
130        if entry.is_empty() {
131            order.push(key);
132        }
133        entry.push(i);
134    }
135
136    order.iter().filter_map(|k| groups.remove(k)).collect()
137}
138
139fn set_window_col(row: &mut serde_json::Value, alias: &str, val: serde_json::Value) {
140    if let serde_json::Value::Object(map) = row {
141        map.insert(alias.to_string(), val);
142    }
143}
144
145fn get_field(doc: &serde_json::Value, field: &str) -> serde_json::Value {
146    doc.get(field).cloned().unwrap_or(serde_json::Value::Null)
147}
148
149fn as_f64(v: &serde_json::Value) -> Option<f64> {
150    match v {
151        serde_json::Value::Number(n) => n.as_f64(),
152        serde_json::Value::String(s) => s.parse().ok(),
153        _ => None,
154    }
155}
156
157fn apply_row_number(rows: &mut [(String, serde_json::Value)], indices: &[usize], alias: &str) {
158    for (rank, &i) in indices.iter().enumerate() {
159        set_window_col(&mut rows[i].1, alias, serde_json::json!(rank + 1));
160    }
161}
162
163fn apply_rank(
164    rows: &mut [(String, serde_json::Value)],
165    indices: &[usize],
166    alias: &str,
167    order_by: &[(String, bool)],
168) {
169    if indices.is_empty() {
170        return;
171    }
172    let mut current_rank = 1;
173    set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
174
175    for pos in 1..indices.len() {
176        let prev = &rows[indices[pos - 1]].1;
177        let curr = &rows[indices[pos]].1;
178        let same = order_by
179            .iter()
180            .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
181        if !same {
182            current_rank = pos + 1;
183        }
184        set_window_col(
185            &mut rows[indices[pos]].1,
186            alias,
187            serde_json::json!(current_rank),
188        );
189    }
190}
191
192fn apply_dense_rank(
193    rows: &mut [(String, serde_json::Value)],
194    indices: &[usize],
195    alias: &str,
196    order_by: &[(String, bool)],
197) {
198    if indices.is_empty() {
199        return;
200    }
201    let mut current_rank = 1;
202    set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
203
204    for pos in 1..indices.len() {
205        let prev = &rows[indices[pos - 1]].1;
206        let curr = &rows[indices[pos]].1;
207        let same = order_by
208            .iter()
209            .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
210        if !same {
211            current_rank += 1;
212        }
213        set_window_col(
214            &mut rows[indices[pos]].1,
215            alias,
216            serde_json::json!(current_rank),
217        );
218    }
219}
220
221fn apply_ntile(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
222    let n = spec
223        .args
224        .first()
225        .and_then(|e| {
226            if let SqlExpr::Literal(v) = e {
227                v.as_f64().map(|x| x as usize)
228            } else {
229                None
230            }
231        })
232        .unwrap_or(1)
233        .max(1);
234    let total = indices.len();
235    for (pos, &i) in indices.iter().enumerate() {
236        let bucket = (pos * n / total) + 1;
237        set_window_col(&mut rows[i].1, &spec.alias, serde_json::json!(bucket));
238    }
239}
240
241fn apply_lag(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
242    let field = spec
243        .args
244        .first()
245        .and_then(|e| {
246            if let SqlExpr::Column(c) = e {
247                Some(c.as_str())
248            } else {
249                None
250            }
251        })
252        .unwrap_or("*");
253    let offset = spec
254        .args
255        .get(1)
256        .and_then(|e| {
257            if let SqlExpr::Literal(v) = e {
258                v.as_f64().map(|n| n as usize)
259            } else {
260                None
261            }
262        })
263        .unwrap_or(1);
264    let default = spec
265        .args
266        .get(2)
267        .and_then(|e| {
268            if let SqlExpr::Literal(v) = e {
269                Some(serde_json::Value::from(v.clone()))
270            } else {
271                None
272            }
273        })
274        .unwrap_or(serde_json::Value::Null);
275
276    for (pos, &i) in indices.iter().enumerate() {
277        let val = if pos >= offset {
278            get_field(&rows[indices[pos - offset]].1, field)
279        } else {
280            default.clone()
281        };
282        set_window_col(&mut rows[i].1, &spec.alias, val);
283    }
284}
285
286fn apply_lead(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
287    let field = spec
288        .args
289        .first()
290        .and_then(|e| {
291            if let SqlExpr::Column(c) = e {
292                Some(c.as_str())
293            } else {
294                None
295            }
296        })
297        .unwrap_or("*");
298    let offset = spec
299        .args
300        .get(1)
301        .and_then(|e| {
302            if let SqlExpr::Literal(v) = e {
303                v.as_f64().map(|n| n as usize)
304            } else {
305                None
306            }
307        })
308        .unwrap_or(1);
309    let default = spec
310        .args
311        .get(2)
312        .and_then(|e| {
313            if let SqlExpr::Literal(v) = e {
314                Some(serde_json::Value::from(v.clone()))
315            } else {
316                None
317            }
318        })
319        .unwrap_or(serde_json::Value::Null);
320
321    for (pos, &i) in indices.iter().enumerate() {
322        let val = if pos + offset < indices.len() {
323            get_field(&rows[indices[pos + offset]].1, field)
324        } else {
325            default.clone()
326        };
327        set_window_col(&mut rows[i].1, &spec.alias, val);
328    }
329}
330
331fn apply_aggregate_window(
332    rows: &mut [(String, serde_json::Value)],
333    indices: &[usize],
334    spec: &WindowFuncSpec,
335) {
336    let field = spec
337        .args
338        .first()
339        .and_then(|e| {
340            if let SqlExpr::Column(c) = e {
341                Some(c.as_str())
342            } else {
343                None
344            }
345        })
346        .unwrap_or("*");
347
348    let use_running = spec.frame.mode == "range"
349        && matches!(spec.frame.start, FrameBound::UnboundedPreceding)
350        && matches!(spec.frame.end, FrameBound::CurrentRow);
351
352    if use_running {
353        let mut running_sum = 0.0f64;
354        let mut running_count = 0u64;
355        let mut running_min: Option<f64> = None;
356        let mut running_max: Option<f64> = None;
357
358        for (pos, &i) in indices.iter().enumerate() {
359            let val = get_field(&rows[i].1, field);
360            if let Some(n) = as_f64(&val) {
361                running_sum += n;
362                running_count += 1;
363                running_min = Some(running_min.map_or(n, |m: f64| m.min(n)));
364                running_max = Some(running_max.map_or(n, |m: f64| m.max(n)));
365            } else if spec.func_name == "count" {
366                running_count += 1;
367            }
368
369            let result = match spec.func_name.as_str() {
370                "sum" => serde_json::json!(running_sum),
371                "count" => serde_json::json!(running_count),
372                "avg" => {
373                    if running_count > 0 {
374                        serde_json::json!(running_sum / running_count as f64)
375                    } else {
376                        serde_json::Value::Null
377                    }
378                }
379                "min" => running_min
380                    .map(|m| serde_json::json!(m))
381                    .unwrap_or(serde_json::Value::Null),
382                "max" => running_max
383                    .map(|m| serde_json::json!(m))
384                    .unwrap_or(serde_json::Value::Null),
385                "first_value" => get_field(&rows[indices[0]].1, field),
386                "last_value" => get_field(&rows[indices[pos]].1, field),
387                _ => serde_json::Value::Null,
388            };
389            set_window_col(&mut rows[i].1, &spec.alias, result);
390        }
391    } else {
392        let values: Vec<f64> = indices
393            .iter()
394            .filter_map(|&i| as_f64(&get_field(&rows[i].1, field)))
395            .collect();
396
397        // Dispatch sum/min/max to SIMD kernels when the values slice is available.
398        let rt = crate::simd_agg::ts_runtime();
399        let result = match spec.func_name.as_str() {
400            "sum" => serde_json::json!((rt.sum_f64)(&values)),
401            "count" => serde_json::json!(indices.len()),
402            "avg" => {
403                if values.is_empty() {
404                    serde_json::Value::Null
405                } else {
406                    serde_json::json!((rt.sum_f64)(&values) / values.len() as f64)
407                }
408            }
409            "min" => {
410                if values.is_empty() {
411                    serde_json::Value::Null
412                } else {
413                    serde_json::json!((rt.min_f64)(&values))
414                }
415            }
416            "max" => {
417                if values.is_empty() {
418                    serde_json::Value::Null
419                } else {
420                    serde_json::json!((rt.max_f64)(&values))
421                }
422            }
423            "first_value" => indices
424                .first()
425                .map(|&i| get_field(&rows[i].1, field))
426                .unwrap_or(serde_json::Value::Null),
427            "last_value" => indices
428                .last()
429                .map(|&i| get_field(&rows[i].1, field))
430                .unwrap_or(serde_json::Value::Null),
431            _ => serde_json::Value::Null,
432        };
433
434        for &i in indices {
435            set_window_col(&mut rows[i].1, &spec.alias, result.clone());
436        }
437    }
438}
439
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Five employees: three in "eng" followed by two in "sales".
    fn make_rows() -> Vec<(String, serde_json::Value)> {
        let staff = [
            ("1", "eng", 100, "Alice"),
            ("2", "eng", 120, "Bob"),
            ("3", "eng", 90, "Carol"),
            ("4", "sales", 80, "Dave"),
            ("5", "sales", 110, "Eve"),
        ];
        staff
            .iter()
            .map(|&(id, dept, salary, name)| {
                (
                    id.to_string(),
                    json!({"dept": dept, "salary": salary, "name": name}),
                )
            })
            .collect()
    }

    /// Evaluates `spec` over a fresh copy of the fixture rows.
    fn evaluate(spec: WindowFuncSpec) -> Vec<(String, serde_json::Value)> {
        let mut rows = make_rows();
        evaluate_window_functions(&mut rows, &[spec]);
        rows
    }

    /// Checks column `alias` at each `(row index, expected value)` pair.
    fn assert_column(
        rows: &[(String, serde_json::Value)],
        alias: &str,
        expected: &[(usize, serde_json::Value)],
    ) {
        for (idx, want) in expected {
            assert_eq!(&rows[*idx].1[alias], want, "row {}", idx);
        }
    }

    #[test]
    fn row_number_single_partition() {
        let rows = evaluate(WindowFuncSpec {
            alias: "rn".into(),
            func_name: "row_number".into(),
            args: vec![],
            partition_by: vec![],
            order_by: vec![],
            frame: WindowFrame::default(),
        });
        assert_column(&rows, "rn", &[(0, json!(1)), (4, json!(5))]);
    }

    #[test]
    fn row_number_partitioned() {
        let rows = evaluate(WindowFuncSpec {
            alias: "rn".into(),
            func_name: "row_number".into(),
            args: vec![],
            partition_by: vec!["dept".into()],
            order_by: vec![],
            frame: WindowFrame::default(),
        });
        // Numbering restarts at the first "sales" row (index 3).
        assert_column(
            &rows,
            "rn",
            &[(0, json!(1)), (2, json!(3)), (3, json!(1)), (4, json!(2))],
        );
    }

    #[test]
    fn running_sum() {
        let rows = evaluate(WindowFuncSpec {
            alias: "running_total".into(),
            func_name: "sum".into(),
            args: vec![SqlExpr::Column("salary".into())],
            partition_by: vec!["dept".into()],
            order_by: vec![("salary".into(), true)],
            frame: WindowFrame::default(),
        });
        assert_column(
            &rows,
            "running_total",
            &[
                (0, json!(100.0)),
                (1, json!(220.0)),
                (2, json!(310.0)),
                (3, json!(80.0)),
                (4, json!(190.0)),
            ],
        );
    }
}