Skip to main content

nodedb_query/
window.rs

//! Window function specification and evaluation.
//!
//! Window functions are evaluated after sorting and before projection. Each
//! spec produces a new column appended to every row (e.g., ROW_NUMBER, RANK,
//! SUM OVER).
5
6use crate::expr::SqlExpr;
7
8/// A window function specification.
9#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
10pub struct WindowFuncSpec {
11    /// Output column name (e.g., "row_num", "running_sum").
12    pub alias: String,
13    /// Function name: row_number, rank, dense_rank, lag, lead, sum, count, avg, min, max.
14    pub func_name: String,
15    /// Function arguments (e.g., `salary` for SUM(salary)). Empty for ROW_NUMBER.
16    pub args: Vec<SqlExpr>,
17    /// PARTITION BY column names. Empty = single partition (entire result set).
18    pub partition_by: Vec<String>,
19    /// ORDER BY within each partition: [(field, ascending)].
20    pub order_by: Vec<(String, bool)>,
21    /// Window frame specification.
22    pub frame: WindowFrame,
23}
24
25/// Window frame: defines which rows within the partition are visible to the function.
26#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
27pub struct WindowFrame {
28    /// Frame mode: "rows" or "range".
29    pub mode: String,
30    /// Start bound.
31    pub start: FrameBound,
32    /// End bound.
33    pub end: FrameBound,
34}
35
36impl Default for WindowFrame {
37    fn default() -> Self {
38        Self {
39            mode: "range".into(),
40            start: FrameBound::UnboundedPreceding,
41            end: FrameBound::CurrentRow,
42        }
43    }
44}
45
46/// Window frame boundary.
47#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
48pub enum FrameBound {
49    UnboundedPreceding,
50    Preceding(u64),
51    CurrentRow,
52    Following(u64),
53    UnboundedFollowing,
54}
55
56/// Evaluate window functions over sorted, partitioned rows.
57///
58/// `rows` is the sorted result set. Each row is a `(doc_id, serde_json::Value)`.
59/// Returns the same rows with window columns appended to each document.
60pub fn evaluate_window_functions(
61    rows: &mut [(String, serde_json::Value)],
62    specs: &[WindowFuncSpec],
63) {
64    for spec in specs {
65        let partitions = build_partitions(rows, &spec.partition_by);
66
67        for partition_indices in &partitions {
68            match spec.func_name.as_str() {
69                "row_number" => apply_row_number(rows, partition_indices, &spec.alias),
70                "rank" => apply_rank(rows, partition_indices, &spec.alias, &spec.order_by),
71                "dense_rank" => {
72                    apply_dense_rank(rows, partition_indices, &spec.alias, &spec.order_by)
73                }
74                "lag" => apply_lag(rows, partition_indices, spec),
75                "lead" => apply_lead(rows, partition_indices, spec),
76                "ntile" => apply_ntile(rows, partition_indices, spec),
77                "sum" | "count" | "avg" | "min" | "max" | "first_value" | "last_value" => {
78                    apply_aggregate_window(rows, partition_indices, spec)
79                }
80                _ => {}
81            }
82        }
83    }
84}
85
86fn build_partitions(
87    rows: &[(String, serde_json::Value)],
88    partition_by: &[String],
89) -> Vec<Vec<usize>> {
90    if partition_by.is_empty() {
91        return vec![(0..rows.len()).collect()];
92    }
93
94    let mut groups: std::collections::HashMap<String, Vec<usize>> =
95        std::collections::HashMap::new();
96    let mut order = Vec::new();
97
98    for (i, (_id, doc)) in rows.iter().enumerate() {
99        let key: String = partition_by
100            .iter()
101            .map(|col| {
102                doc.get(col)
103                    .map(|v| v.to_string())
104                    .unwrap_or_else(|| "null".to_string())
105            })
106            .collect::<Vec<_>>()
107            .join("\x00");
108        let entry = groups.entry(key.clone()).or_default();
109        if entry.is_empty() {
110            order.push(key);
111        }
112        entry.push(i);
113    }
114
115    order.iter().filter_map(|k| groups.remove(k)).collect()
116}
117
118fn set_window_col(row: &mut serde_json::Value, alias: &str, val: serde_json::Value) {
119    if let serde_json::Value::Object(map) = row {
120        map.insert(alias.to_string(), val);
121    }
122}
123
124fn get_field(doc: &serde_json::Value, field: &str) -> serde_json::Value {
125    doc.get(field).cloned().unwrap_or(serde_json::Value::Null)
126}
127
128fn as_f64(v: &serde_json::Value) -> Option<f64> {
129    match v {
130        serde_json::Value::Number(n) => n.as_f64(),
131        serde_json::Value::String(s) => s.parse().ok(),
132        _ => None,
133    }
134}
135
136fn apply_row_number(rows: &mut [(String, serde_json::Value)], indices: &[usize], alias: &str) {
137    for (rank, &i) in indices.iter().enumerate() {
138        set_window_col(&mut rows[i].1, alias, serde_json::json!(rank + 1));
139    }
140}
141
142fn apply_rank(
143    rows: &mut [(String, serde_json::Value)],
144    indices: &[usize],
145    alias: &str,
146    order_by: &[(String, bool)],
147) {
148    if indices.is_empty() {
149        return;
150    }
151    let mut current_rank = 1;
152    set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
153
154    for pos in 1..indices.len() {
155        let prev = &rows[indices[pos - 1]].1;
156        let curr = &rows[indices[pos]].1;
157        let same = order_by
158            .iter()
159            .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
160        if !same {
161            current_rank = pos + 1;
162        }
163        set_window_col(
164            &mut rows[indices[pos]].1,
165            alias,
166            serde_json::json!(current_rank),
167        );
168    }
169}
170
171fn apply_dense_rank(
172    rows: &mut [(String, serde_json::Value)],
173    indices: &[usize],
174    alias: &str,
175    order_by: &[(String, bool)],
176) {
177    if indices.is_empty() {
178        return;
179    }
180    let mut current_rank = 1;
181    set_window_col(&mut rows[indices[0]].1, alias, serde_json::json!(1));
182
183    for pos in 1..indices.len() {
184        let prev = &rows[indices[pos - 1]].1;
185        let curr = &rows[indices[pos]].1;
186        let same = order_by
187            .iter()
188            .all(|(col, _)| get_field(prev, col) == get_field(curr, col));
189        if !same {
190            current_rank += 1;
191        }
192        set_window_col(
193            &mut rows[indices[pos]].1,
194            alias,
195            serde_json::json!(current_rank),
196        );
197    }
198}
199
200fn apply_ntile(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
201    let n = spec
202        .args
203        .first()
204        .and_then(|e| {
205            if let SqlExpr::Literal(v) = e {
206                as_f64(v).map(|x| x as usize)
207            } else {
208                None
209            }
210        })
211        .unwrap_or(1)
212        .max(1);
213    let total = indices.len();
214    for (pos, &i) in indices.iter().enumerate() {
215        let bucket = (pos * n / total) + 1;
216        set_window_col(&mut rows[i].1, &spec.alias, serde_json::json!(bucket));
217    }
218}
219
220fn apply_lag(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
221    let field = spec
222        .args
223        .first()
224        .and_then(|e| {
225            if let SqlExpr::Column(c) = e {
226                Some(c.as_str())
227            } else {
228                None
229            }
230        })
231        .unwrap_or("*");
232    let offset = spec
233        .args
234        .get(1)
235        .and_then(|e| {
236            if let SqlExpr::Literal(v) = e {
237                as_f64(v).map(|n| n as usize)
238            } else {
239                None
240            }
241        })
242        .unwrap_or(1);
243    let default = spec
244        .args
245        .get(2)
246        .and_then(|e| {
247            if let SqlExpr::Literal(v) = e {
248                Some(v.clone())
249            } else {
250                None
251            }
252        })
253        .unwrap_or(serde_json::Value::Null);
254
255    for (pos, &i) in indices.iter().enumerate() {
256        let val = if pos >= offset {
257            get_field(&rows[indices[pos - offset]].1, field)
258        } else {
259            default.clone()
260        };
261        set_window_col(&mut rows[i].1, &spec.alias, val);
262    }
263}
264
265fn apply_lead(rows: &mut [(String, serde_json::Value)], indices: &[usize], spec: &WindowFuncSpec) {
266    let field = spec
267        .args
268        .first()
269        .and_then(|e| {
270            if let SqlExpr::Column(c) = e {
271                Some(c.as_str())
272            } else {
273                None
274            }
275        })
276        .unwrap_or("*");
277    let offset = spec
278        .args
279        .get(1)
280        .and_then(|e| {
281            if let SqlExpr::Literal(v) = e {
282                as_f64(v).map(|n| n as usize)
283            } else {
284                None
285            }
286        })
287        .unwrap_or(1);
288    let default = spec
289        .args
290        .get(2)
291        .and_then(|e| {
292            if let SqlExpr::Literal(v) = e {
293                Some(v.clone())
294            } else {
295                None
296            }
297        })
298        .unwrap_or(serde_json::Value::Null);
299
300    for (pos, &i) in indices.iter().enumerate() {
301        let val = if pos + offset < indices.len() {
302            get_field(&rows[indices[pos + offset]].1, field)
303        } else {
304            default.clone()
305        };
306        set_window_col(&mut rows[i].1, &spec.alias, val);
307    }
308}
309
310fn apply_aggregate_window(
311    rows: &mut [(String, serde_json::Value)],
312    indices: &[usize],
313    spec: &WindowFuncSpec,
314) {
315    let field = spec
316        .args
317        .first()
318        .and_then(|e| {
319            if let SqlExpr::Column(c) = e {
320                Some(c.as_str())
321            } else {
322                None
323            }
324        })
325        .unwrap_or("*");
326
327    let use_running = spec.frame.mode == "range"
328        && matches!(spec.frame.start, FrameBound::UnboundedPreceding)
329        && matches!(spec.frame.end, FrameBound::CurrentRow);
330
331    if use_running {
332        let mut running_sum = 0.0f64;
333        let mut running_count = 0u64;
334        let mut running_min: Option<f64> = None;
335        let mut running_max: Option<f64> = None;
336
337        for (pos, &i) in indices.iter().enumerate() {
338            let val = get_field(&rows[i].1, field);
339            if let Some(n) = as_f64(&val) {
340                running_sum += n;
341                running_count += 1;
342                running_min = Some(running_min.map_or(n, |m: f64| m.min(n)));
343                running_max = Some(running_max.map_or(n, |m: f64| m.max(n)));
344            } else if spec.func_name == "count" {
345                running_count += 1;
346            }
347
348            let result = match spec.func_name.as_str() {
349                "sum" => serde_json::json!(running_sum),
350                "count" => serde_json::json!(running_count),
351                "avg" => {
352                    if running_count > 0 {
353                        serde_json::json!(running_sum / running_count as f64)
354                    } else {
355                        serde_json::Value::Null
356                    }
357                }
358                "min" => running_min
359                    .map(|m| serde_json::json!(m))
360                    .unwrap_or(serde_json::Value::Null),
361                "max" => running_max
362                    .map(|m| serde_json::json!(m))
363                    .unwrap_or(serde_json::Value::Null),
364                "first_value" => get_field(&rows[indices[0]].1, field),
365                "last_value" => get_field(&rows[indices[pos]].1, field),
366                _ => serde_json::Value::Null,
367            };
368            set_window_col(&mut rows[i].1, &spec.alias, result);
369        }
370    } else {
371        let values: Vec<f64> = indices
372            .iter()
373            .filter_map(|&i| as_f64(&get_field(&rows[i].1, field)))
374            .collect();
375
376        let result = match spec.func_name.as_str() {
377            "sum" => serde_json::json!(values.iter().sum::<f64>()),
378            "count" => serde_json::json!(indices.len()),
379            "avg" => {
380                if values.is_empty() {
381                    serde_json::Value::Null
382                } else {
383                    serde_json::json!(values.iter().sum::<f64>() / values.len() as f64)
384                }
385            }
386            "min" => values
387                .iter()
388                .copied()
389                .reduce(f64::min)
390                .map(|m| serde_json::json!(m))
391                .unwrap_or(serde_json::Value::Null),
392            "max" => values
393                .iter()
394                .copied()
395                .reduce(f64::max)
396                .map(|m| serde_json::json!(m))
397                .unwrap_or(serde_json::Value::Null),
398            "first_value" => get_field(&rows[indices[0]].1, field),
399            "last_value" => get_field(&rows[*indices.last().unwrap()].1, field),
400            _ => serde_json::Value::Null,
401        };
402
403        for &i in indices {
404            set_window_col(&mut rows[i].1, &spec.alias, result.clone());
405        }
406    }
407}
408
409#[cfg(test)]
410mod tests {
411    use super::*;
412    use serde_json::json;
413
414    fn make_rows() -> Vec<(String, serde_json::Value)> {
415        vec![
416            (
417                "1".into(),
418                json!({"dept": "eng", "salary": 100, "name": "Alice"}),
419            ),
420            (
421                "2".into(),
422                json!({"dept": "eng", "salary": 120, "name": "Bob"}),
423            ),
424            (
425                "3".into(),
426                json!({"dept": "eng", "salary": 90, "name": "Carol"}),
427            ),
428            (
429                "4".into(),
430                json!({"dept": "sales", "salary": 80, "name": "Dave"}),
431            ),
432            (
433                "5".into(),
434                json!({"dept": "sales", "salary": 110, "name": "Eve"}),
435            ),
436        ]
437    }
438
439    #[test]
440    fn row_number_single_partition() {
441        let mut rows = make_rows();
442        let spec = WindowFuncSpec {
443            alias: "rn".into(),
444            func_name: "row_number".into(),
445            args: vec![],
446            partition_by: vec![],
447            order_by: vec![],
448            frame: WindowFrame::default(),
449        };
450        evaluate_window_functions(&mut rows, &[spec]);
451        assert_eq!(rows[0].1["rn"], json!(1));
452        assert_eq!(rows[4].1["rn"], json!(5));
453    }
454
455    #[test]
456    fn row_number_partitioned() {
457        let mut rows = make_rows();
458        let spec = WindowFuncSpec {
459            alias: "rn".into(),
460            func_name: "row_number".into(),
461            args: vec![],
462            partition_by: vec!["dept".into()],
463            order_by: vec![],
464            frame: WindowFrame::default(),
465        };
466        evaluate_window_functions(&mut rows, &[spec]);
467        assert_eq!(rows[0].1["rn"], json!(1));
468        assert_eq!(rows[2].1["rn"], json!(3));
469        assert_eq!(rows[3].1["rn"], json!(1));
470        assert_eq!(rows[4].1["rn"], json!(2));
471    }
472
473    #[test]
474    fn running_sum() {
475        let mut rows = make_rows();
476        let spec = WindowFuncSpec {
477            alias: "running_total".into(),
478            func_name: "sum".into(),
479            args: vec![SqlExpr::Column("salary".into())],
480            partition_by: vec!["dept".into()],
481            order_by: vec![("salary".into(), true)],
482            frame: WindowFrame::default(),
483        };
484        evaluate_window_functions(&mut rows, &[spec]);
485        assert_eq!(rows[0].1["running_total"], json!(100.0));
486        assert_eq!(rows[1].1["running_total"], json!(220.0));
487        assert_eq!(rows[2].1["running_total"], json!(310.0));
488        assert_eq!(rows[3].1["running_total"], json!(80.0));
489        assert_eq!(rows[4].1["running_total"], json!(190.0));
490    }
491}