Skip to main content

kimberlite_query/
window.rs

1//! AUDIT-2026-04 S3.2 — SQL window functions.
2//!
3//! Supports `ROW_NUMBER()`, `RANK()`, `DENSE_RANK()`, `LAG()`,
4//! `LEAD()`, `FIRST_VALUE()`, and `LAST_VALUE()` with `PARTITION BY`
5//! and `ORDER BY`. No frame clauses (the default frame for ranking
6//! functions is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`,
7//! which is what these implementations apply).
8//!
9//! # Execution model
10//!
11//! Window functions execute as a *post-pass* over the rows produced
12//! by the underlying SELECT. This sidesteps the need for a
13//! `Plan::Window` node and keeps the change additive — see
14//! `apply_window_fns` in `lib.rs`.
15//!
16//! Pseudo-code:
17//!
18//! ```text
19//! for fn in window_fns:
20//!     sort rows by (partition_keys ++ order_keys)
21//!     iterate rows once:
22//!         on partition boundary: reset rank counters
23//!         compute fn value, append to row
24//! ```
25//!
//! Determinism: rows are ordered with the standard library's stable
//! `sort_by`, using this module's `compare_partition_then_order`
//! comparator, so two equal rows retain their original order — a
//! property `LAG`/`LEAD` rely on.
29
30use std::cmp::Ordering;
31
32use crate::error::{QueryError, Result};
33use crate::executor::{QueryResult, Row};
34use crate::parser::ParsedWindowFn;
35use crate::schema::ColumnName;
36use crate::value::Value;
37
38/// Window function operations supported by the engine.
39///
40/// `ROW_NUMBER` / `RANK` / `DENSE_RANK` are pure ranking functions
41/// (no args). `LAG` / `LEAD` look at a sibling row offset
42/// (default 1). `FIRST_VALUE` / `LAST_VALUE` return the column at
43/// the partition boundary.
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum WindowFunction {
46    RowNumber,
47    Rank,
48    DenseRank,
49    /// `LAG(column, offset = 1)` — value `offset` rows back, NULL
50    /// if before partition start.
51    Lag {
52        column: ColumnName,
53        offset: usize,
54    },
55    /// `LEAD(column, offset = 1)` — value `offset` rows forward,
56    /// NULL if past partition end.
57    Lead {
58        column: ColumnName,
59        offset: usize,
60    },
61    /// `FIRST_VALUE(column)` — value of `column` at the first row
62    /// of the current partition (under the ORDER BY).
63    FirstValue {
64        column: ColumnName,
65    },
66    /// `LAST_VALUE(column)` — value of `column` at the last row
67    /// of the current partition. Per ANSI default frame, "last
68    /// row" here means the *current* row — so we treat
69    /// `LAST_VALUE` with no explicit frame as "value of column on
70    /// the current row". Postgres parity in `tests/`.
71    LastValue {
72        column: ColumnName,
73    },
74}
75
76impl WindowFunction {
77    /// Output column name when no alias is present.
78    pub fn default_alias(&self) -> &'static str {
79        match self {
80            Self::RowNumber => "row_number",
81            Self::Rank => "rank",
82            Self::DenseRank => "dense_rank",
83            Self::Lag { .. } => "lag",
84            Self::Lead { .. } => "lead",
85            Self::FirstValue { .. } => "first_value",
86            Self::LastValue { .. } => "last_value",
87        }
88    }
89}
90
91/// Apply each window function in order to the rows produced by the
92/// underlying SELECT. Returns a new [`QueryResult`] with the
93/// window-function output columns appended in left-to-right order.
94///
95/// `result.rows` is consumed. The base columns are preserved at
96/// their original positions; window output columns are appended in
97/// the order the parser saw them.
98pub fn apply_window_fns(base: QueryResult, window_fns: &[ParsedWindowFn]) -> Result<QueryResult> {
99    if window_fns.is_empty() {
100        return Ok(base);
101    }
102
103    // Resolve column indices for each window fn's
104    // partition_by + order_by + arg references against base.columns.
105    let columns_idx = build_column_index(&base.columns);
106
107    let QueryResult { columns, rows } = base;
108    let mut out_columns = columns.clone();
109
110    // Each window fn produces one new column. Compute one fn at a
111    // time; the algorithm needs the rows sorted by that fn's
112    // (partition_by ++ order_by), so re-sort per fn. For
113    // partition_by = [] and order_by = [] (whole-table frame) the
114    // sort is a no-op.
115    let mut work_rows = rows;
116    let original_index_col = work_rows.len(); // sentinel marker (unused)
117    let _ = original_index_col;
118
119    // Stamp each row with its original index so we can restore
120    // ordering at the end. The original column positions stay
121    // unchanged; we only append window-fn output columns.
122    let mut indexed: Vec<(usize, Row)> = work_rows.drain(..).enumerate().collect();
123
124    for win in window_fns {
125        let fn_col = compute_window_column(win, &mut indexed, &columns_idx)?;
126        out_columns.push(ColumnName::new(
127            win.alias
128                .clone()
129                .unwrap_or_else(|| win.function.default_alias().to_string()),
130        ));
131        for ((_, row), val) in indexed.iter_mut().zip(fn_col.into_iter()) {
132            row.push(val);
133        }
134    }
135
136    // Restore original input order so callers see rows in the
137    // pre-window position (the SELECT's own ORDER BY, if any, ran
138    // before this point).
139    indexed.sort_by_key(|(idx, _)| *idx);
140    let final_rows = indexed.into_iter().map(|(_, r)| r).collect();
141
142    Ok(QueryResult {
143        columns: out_columns,
144        rows: final_rows,
145    })
146}
147
148/// Resolve column name → row index for the base columns.
149fn build_column_index(columns: &[ColumnName]) -> Vec<(String, usize)> {
150    columns
151        .iter()
152        .enumerate()
153        .map(|(i, c)| (c.as_str().to_string(), i))
154        .collect()
155}
156
157fn lookup_col(idx: &[(String, usize)], name: &str) -> Result<usize> {
158    idx.iter()
159        .find(|(n, _)| n == name)
160        .map(|(_, i)| *i)
161        .ok_or_else(|| {
162            QueryError::ParseError(format!(
163                "window function references unknown column '{name}'"
164            ))
165        })
166}
167
168/// Compute the window-function column for `win` over `indexed_rows`.
169///
170/// Mutates `indexed_rows` (re-sorts by partition + order) so the
171/// caller can append the resulting Vec<Value> column-wise.
172fn compute_window_column(
173    win: &ParsedWindowFn,
174    indexed_rows: &mut [(usize, Row)],
175    columns_idx: &[(String, usize)],
176) -> Result<Vec<Value>> {
177    // Resolve indices once.
178    let partition_idx: Vec<usize> = win
179        .partition_by
180        .iter()
181        .map(|c| lookup_col(columns_idx, c.as_str()))
182        .collect::<Result<_>>()?;
183    let order_idx: Vec<(usize, bool)> = win
184        .order_by
185        .iter()
186        .map(|c| Ok((lookup_col(columns_idx, c.column.as_str())?, c.ascending)))
187        .collect::<Result<_>>()?;
188
189    indexed_rows
190        .sort_by(|(_, a), (_, b)| compare_partition_then_order(a, b, &partition_idx, &order_idx));
191
192    let n = indexed_rows.len();
193    let mut out = vec![Value::Null; n];
194
195    let mut row_num: i64 = 0;
196    let mut rank: i64 = 0;
197    let mut dense_rank: i64 = 0;
198    let mut last_partition_key: Option<Vec<Value>> = None;
199    let mut last_order_key: Option<Vec<Value>> = None;
200
201    for i in 0..n {
202        let row = &indexed_rows[i].1;
203        let part_key: Vec<Value> = partition_idx.iter().map(|&j| row[j].clone()).collect();
204        let ord_key: Vec<Value> = order_idx.iter().map(|&(j, _)| row[j].clone()).collect();
205
206        let new_partition = last_partition_key.as_ref() != Some(&part_key);
207        if new_partition {
208            row_num = 0;
209            rank = 0;
210            dense_rank = 0;
211            last_partition_key = Some(part_key.clone());
212            last_order_key = None;
213        }
214
215        row_num += 1;
216        let order_changed = last_order_key.as_ref() != Some(&ord_key);
217        if order_changed {
218            rank = row_num;
219            dense_rank += 1;
220            last_order_key = Some(ord_key.clone());
221        }
222
223        out[i] = match &win.function {
224            WindowFunction::RowNumber => Value::BigInt(row_num),
225            WindowFunction::Rank => Value::BigInt(rank),
226            WindowFunction::DenseRank => Value::BigInt(dense_rank),
227            WindowFunction::Lag { column, offset } => lookup_offset(
228                indexed_rows,
229                columns_idx,
230                &partition_idx,
231                column,
232                i,
233                -(*offset as isize),
234            )?,
235            WindowFunction::Lead { column, offset } => lookup_offset(
236                indexed_rows,
237                columns_idx,
238                &partition_idx,
239                column,
240                i,
241                *offset as isize,
242            )?,
243            WindowFunction::FirstValue { column } => {
244                first_in_partition(indexed_rows, columns_idx, column, i, &partition_idx)?
245            }
246            WindowFunction::LastValue { column } => {
247                // ANSI default frame for LAST_VALUE without an
248                // explicit frame is the current row — mirror that.
249                let col_i = lookup_col(columns_idx, column.as_str())?;
250                indexed_rows[i].1[col_i].clone()
251            }
252        };
253    }
254    Ok(out)
255}
256
257fn compare_partition_then_order(
258    a: &Row,
259    b: &Row,
260    partition_idx: &[usize],
261    order_idx: &[(usize, bool)],
262) -> Ordering {
263    for &j in partition_idx {
264        match cmp_values(&a[j], &b[j]) {
265            Ordering::Equal => {}
266            other => return other,
267        }
268    }
269    for &(j, asc) in order_idx {
270        let ord = cmp_values(&a[j], &b[j]);
271        match ord {
272            Ordering::Equal => {}
273            other => return if asc { other } else { other.reverse() },
274        }
275    }
276    Ordering::Equal
277}
278
279/// Best-effort total order over `Value`. NULLs sort first, mirroring
280/// PostgreSQL's `NULLS FIRST` ascending default.
281///
282/// Arms look identical (`x.cmp(y)`) across integer-family variants but
283/// each is intentionally its own arm — the binding type differs
284/// (`i64` / `i32` / `i16` / `i8` / timestamp-ns), and merging them
285/// would tangle semantic equality. Silence the lint here.
286#[allow(clippy::match_same_arms)]
287fn cmp_values(a: &Value, b: &Value) -> Ordering {
288    use Value::{BigInt, Boolean, Date, Integer, Null, Real, SmallInt, Text, Time, TinyInt};
289    match (a, b) {
290        (Null, Null) => Ordering::Equal,
291        (Null, _) => Ordering::Less,
292        (_, Null) => Ordering::Greater,
293        (BigInt(x), BigInt(y)) => x.cmp(y),
294        (Integer(x), Integer(y)) => x.cmp(y),
295        (SmallInt(x), SmallInt(y)) => x.cmp(y),
296        (TinyInt(x), TinyInt(y)) => x.cmp(y),
297        (Real(x), Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
298        (Text(x), Text(y)) => x.cmp(y),
299        (Boolean(x), Boolean(y)) => x.cmp(y),
300        (Date(x), Date(y)) => x.cmp(y),
301        (Time(x), Time(y)) => x.cmp(y),
302        // Cross-type or unhandled: fall back to debug-string compare so
303        // sort is total. Real-world window queries don't hit this since
304        // schema enforces typed columns; the fallback exists for safety.
305        (lhs, rhs) => format!("{lhs:?}").cmp(&format!("{rhs:?}")),
306    }
307}
308
309fn lookup_offset(
310    indexed: &[(usize, Row)],
311    columns_idx: &[(String, usize)],
312    partition_idx: &[usize],
313    column: &ColumnName,
314    i: usize,
315    delta: isize,
316) -> Result<Value> {
317    let col_i = lookup_col(columns_idx, column.as_str())?;
318    let target_pos = i as isize + delta;
319    if target_pos < 0 || (target_pos as usize) >= indexed.len() {
320        return Ok(Value::Null);
321    }
322    let target = target_pos as usize;
323    // Partition-boundary check: LAG/LEAD must NOT cross partition
324    // boundaries — a row at the start of partition B should report
325    // NULL even though indexed[i-1] holds the last row of A.
326    if !same_partition(&indexed[i].1, &indexed[target].1, partition_idx) {
327        return Ok(Value::Null);
328    }
329    Ok(indexed[target].1[col_i].clone())
330}
331
332fn same_partition(a: &Row, b: &Row, partition_idx: &[usize]) -> bool {
333    partition_idx.iter().all(|&j| a[j] == b[j])
334}
335
336fn first_in_partition(
337    indexed: &[(usize, Row)],
338    columns_idx: &[(String, usize)],
339    column: &ColumnName,
340    i: usize,
341    partition_idx: &[usize],
342) -> Result<Value> {
343    let col_i = lookup_col(columns_idx, column.as_str())?;
344    let current_part: Vec<Value> = partition_idx
345        .iter()
346        .map(|&j| indexed[i].1[j].clone())
347        .collect();
348    // Walk back to the partition start.
349    let mut start = i;
350    while start > 0 {
351        let prev_part: Vec<Value> = partition_idx
352            .iter()
353            .map(|&j| indexed[start - 1].1[j].clone())
354            .collect();
355        if prev_part != current_part {
356            break;
357        }
358        start -= 1;
359    }
360    Ok(indexed[start].1[col_i].clone())
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::OrderByClause;
    use crate::schema::ColumnName;
    use std::collections::HashMap;

    /// Column list from plain names.
    fn cols(names: &[&str]) -> Vec<ColumnName> {
        let mut out = Vec::with_capacity(names.len());
        for name in names {
            out.push(ColumnName::new(*name));
        }
        out
    }

    /// A row is just its values, in column order.
    fn row(vals: Vec<Value>) -> Row {
        vals
    }

    /// Ascending ORDER BY entry on `name`.
    fn order_asc(name: &str) -> OrderByClause {
        OrderByClause {
            column: ColumnName::new(name),
            ascending: true,
        }
    }

    /// Window spec with ascending order keys.
    fn window(
        function: WindowFunction,
        partition_by: &[&str],
        order_by: &[&str],
        alias: Option<&str>,
    ) -> ParsedWindowFn {
        ParsedWindowFn {
            function,
            partition_by: cols(partition_by),
            order_by: order_by.iter().map(|name| order_asc(name)).collect(),
            alias: alias.map(String::from),
        }
    }

    /// Payload of a `BigInt` cell; any other variant fails the test.
    fn big_int(value: &Value) -> i64 {
        match value {
            Value::BigInt(i) => *i,
            _ => panic!(),
        }
    }

    /// Payload of a `Text` cell; any other variant fails the test.
    fn text(value: &Value) -> String {
        match value {
            Value::Text(s) => s.clone(),
            _ => panic!(),
        }
    }

    #[test]
    fn row_number_no_partition_no_order_assigns_1_to_n_in_input_order() {
        let base = QueryResult {
            columns: cols(&["id"]),
            rows: vec![row(vec![Value::BigInt(10)]), row(vec![Value::BigInt(20)])],
        };
        let spec = window(WindowFunction::RowNumber, &[], &[], None);

        let out = apply_window_fns(base, &[spec]).expect("apply");

        assert_eq!(out.columns.len(), 2);
        assert_eq!(out.rows[0][1], Value::BigInt(1));
        assert_eq!(out.rows[1][1], Value::BigInt(2));
    }

    #[test]
    fn row_number_resets_per_partition() {
        let dept_salary = |dept: &str, salary: i64| {
            row(vec![Value::Text(dept.into()), Value::BigInt(salary)])
        };
        let base = QueryResult {
            columns: cols(&["dept", "salary"]),
            rows: vec![
                dept_salary("A", 100),
                dept_salary("B", 200),
                dept_salary("A", 150),
                dept_salary("B", 250),
            ],
        };
        let spec = window(
            WindowFunction::RowNumber,
            &["dept"],
            &["salary"],
            Some("rn"),
        );

        let out = apply_window_fns(base, &[spec]).expect("apply");

        // Key the result by (dept, salary) so the check does not depend on
        // row order.
        let rn_by_key: HashMap<(String, i64), i64> = out
            .rows
            .iter()
            .map(|r| ((text(&r[0]), big_int(&r[1])), big_int(&r[2])))
            .collect();

        // Within each department the lower salary is numbered first.
        assert_eq!(rn_by_key.get(&(String::from("A"), 100)), Some(&1));
        assert_eq!(rn_by_key.get(&(String::from("A"), 150)), Some(&2));
        assert_eq!(rn_by_key.get(&(String::from("B"), 200)), Some(&1));
        assert_eq!(rn_by_key.get(&(String::from("B"), 250)), Some(&2));
    }

    #[test]
    fn rank_and_dense_rank_distinguish_ties() {
        // Salaries 100, 100, 200 — RANK = 1, 1, 3; DENSE_RANK = 1, 1, 2.
        let base = QueryResult {
            columns: cols(&["salary"]),
            rows: vec![
                row(vec![Value::BigInt(100)]),
                row(vec![Value::BigInt(100)]),
                row(vec![Value::BigInt(200)]),
            ],
        };
        let specs = [
            window(WindowFunction::Rank, &[], &["salary"], Some("r")),
            window(WindowFunction::DenseRank, &[], &["salary"], Some("dr")),
        ];

        let out = apply_window_fns(base, &specs).expect("apply");

        for r in &out.rows {
            let salary = big_int(&r[0]);
            let rank = big_int(&r[1]);
            let dense = big_int(&r[2]);
            if salary == 100 {
                assert_eq!(rank, 1, "rank ties");
                assert_eq!(dense, 1, "dense_rank ties");
            } else {
                assert_eq!(rank, 3, "rank skips after ties");
                assert_eq!(dense, 2, "dense_rank does not skip");
            }
        }
    }

    #[test]
    fn first_value_returns_partition_start_value() {
        let dept_salary = |salary: i64| row(vec![Value::Text("A".into()), Value::BigInt(salary)]);
        let base = QueryResult {
            columns: cols(&["dept", "salary"]),
            rows: vec![dept_salary(300), dept_salary(100), dept_salary(200)],
        };
        let spec = window(
            WindowFunction::FirstValue {
                column: ColumnName::new("salary"),
            },
            &["dept"],
            &["salary"],
            Some("first"),
        );

        let out = apply_window_fns(base, &[spec]).expect("apply");

        // Every row of the single partition reports its lowest salary.
        for r in &out.rows {
            assert_eq!(r[2], Value::BigInt(100));
        }
    }

    #[test]
    fn lag_returns_null_at_partition_start() {
        let base = QueryResult {
            columns: cols(&["id"]),
            rows: vec![
                row(vec![Value::BigInt(10)]),
                row(vec![Value::BigInt(20)]),
                row(vec![Value::BigInt(30)]),
            ],
        };
        let spec = window(
            WindowFunction::Lag {
                column: ColumnName::new("id"),
                offset: 1,
            },
            &[],
            &["id"],
            Some("prev"),
        );

        let out = apply_window_fns(base, &[spec]).expect("apply");

        // Ordered by id: 10 has no predecessor, 20 sees 10, 30 sees 20.
        let prev_by_id: HashMap<i64, Value> = out
            .rows
            .iter()
            .map(|r| (big_int(&r[0]), r[1].clone()))
            .collect();
        assert_eq!(prev_by_id[&10], Value::Null, "first row lag is NULL");
        assert_eq!(prev_by_id[&20], Value::BigInt(10));
        assert_eq!(prev_by_id[&30], Value::BigInt(20));
    }
}