Skip to main content

kimberlite_query/
window.rs

1//! AUDIT-2026-04 S3.2 — SQL window functions.
2//!
3//! Supports `ROW_NUMBER()`, `RANK()`, `DENSE_RANK()`, `LAG()`,
4//! `LEAD()`, `FIRST_VALUE()`, and `LAST_VALUE()` with `PARTITION BY`
5//! and `ORDER BY`. No frame clauses (the default frame for ranking
6//! functions is `RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`,
7//! which is what these implementations apply).
8//!
9//! # Execution model
10//!
11//! Window functions execute as a *post-pass* over the rows produced
12//! by the underlying SELECT. This sidesteps the need for a
13//! `Plan::Window` node and keeps the change additive — see
14//! `apply_window_fns` in `lib.rs`.
15//!
16//! Pseudo-code:
17//!
18//! ```text
19//! for fn in window_fns:
20//!     sort rows by (partition_keys ++ order_keys)
21//!     iterate rows once:
22//!         on partition boundary: reset rank counters
23//!         compute fn value, append to row
24//! ```
25//!
26//! Determinism: each pass sorts with the stable `slice::sort_by`, driven
27//! by the local `cmp_values` comparator, so rows whose sort keys compare
28//! equal come out in a deterministic order — a property `LAG`/`LEAD` rely on.
29
30use std::cmp::Ordering;
31
32use crate::error::{QueryError, Result};
33use crate::executor::{QueryResult, Row};
34use crate::parser::ParsedWindowFn;
35use crate::schema::ColumnName;
36use crate::value::Value;
37
38/// Window function operations supported by the engine.
39///
40/// `ROW_NUMBER` / `RANK` / `DENSE_RANK` are pure ranking functions
41/// (no args). `LAG` / `LEAD` look at a sibling row offset
42/// (default 1). `FIRST_VALUE` / `LAST_VALUE` return the column at
43/// the partition boundary.
44#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum WindowFunction {
46    RowNumber,
47    Rank,
48    DenseRank,
49    /// `LAG(column, offset = 1)` — value `offset` rows back, NULL
50    /// if before partition start.
51    Lag { column: ColumnName, offset: usize },
52    /// `LEAD(column, offset = 1)` — value `offset` rows forward,
53    /// NULL if past partition end.
54    Lead { column: ColumnName, offset: usize },
55    /// `FIRST_VALUE(column)` — value of `column` at the first row
56    /// of the current partition (under the ORDER BY).
57    FirstValue { column: ColumnName },
58    /// `LAST_VALUE(column)` — value of `column` at the last row
59    /// of the current partition. Per ANSI default frame, "last
60    /// row" here means the *current* row — so we treat
61    /// `LAST_VALUE` with no explicit frame as "value of column on
62    /// the current row". Postgres parity in `tests/`.
63    LastValue { column: ColumnName },
64}
65
66impl WindowFunction {
67    /// Output column name when no alias is present.
68    pub fn default_alias(&self) -> &'static str {
69        match self {
70            Self::RowNumber => "row_number",
71            Self::Rank => "rank",
72            Self::DenseRank => "dense_rank",
73            Self::Lag { .. } => "lag",
74            Self::Lead { .. } => "lead",
75            Self::FirstValue { .. } => "first_value",
76            Self::LastValue { .. } => "last_value",
77        }
78    }
79}
80
81/// Apply each window function in order to the rows produced by the
82/// underlying SELECT. Returns a new [`QueryResult`] with the
83/// window-function output columns appended in left-to-right order.
84///
85/// `result.rows` is consumed. The base columns are preserved at
86/// their original positions; window output columns are appended in
87/// the order the parser saw them.
88pub fn apply_window_fns(
89    base: QueryResult,
90    window_fns: &[ParsedWindowFn],
91) -> Result<QueryResult> {
92    if window_fns.is_empty() {
93        return Ok(base);
94    }
95
96    // Resolve column indices for each window fn's
97    // partition_by + order_by + arg references against base.columns.
98    let columns_idx = build_column_index(&base.columns);
99
100    let QueryResult { columns, rows } = base;
101    let mut out_columns = columns.clone();
102
103    // Each window fn produces one new column. Compute one fn at a
104    // time; the algorithm needs the rows sorted by that fn's
105    // (partition_by ++ order_by), so re-sort per fn. For
106    // partition_by = [] and order_by = [] (whole-table frame) the
107    // sort is a no-op.
108    let mut work_rows = rows;
109    let original_index_col = work_rows.len(); // sentinel marker (unused)
110    let _ = original_index_col;
111
112    // Stamp each row with its original index so we can restore
113    // ordering at the end. The original column positions stay
114    // unchanged; we only append window-fn output columns.
115    let mut indexed: Vec<(usize, Row)> = work_rows.drain(..).enumerate().collect();
116
117    for win in window_fns {
118        let fn_col = compute_window_column(win, &mut indexed, &columns_idx)?;
119        out_columns.push(ColumnName::new(
120            win.alias
121                .clone()
122                .unwrap_or_else(|| win.function.default_alias().to_string()),
123        ));
124        for ((_, row), val) in indexed.iter_mut().zip(fn_col.into_iter()) {
125            row.push(val);
126        }
127    }
128
129    // Restore original input order so callers see rows in the
130    // pre-window position (the SELECT's own ORDER BY, if any, ran
131    // before this point).
132    indexed.sort_by_key(|(idx, _)| *idx);
133    let final_rows = indexed.into_iter().map(|(_, r)| r).collect();
134
135    Ok(QueryResult {
136        columns: out_columns,
137        rows: final_rows,
138    })
139}
140
141/// Resolve column name → row index for the base columns.
142fn build_column_index(columns: &[ColumnName]) -> Vec<(String, usize)> {
143    columns
144        .iter()
145        .enumerate()
146        .map(|(i, c)| (c.as_str().to_string(), i))
147        .collect()
148}
149
150fn lookup_col(idx: &[(String, usize)], name: &str) -> Result<usize> {
151    idx.iter()
152        .find(|(n, _)| n == name)
153        .map(|(_, i)| *i)
154        .ok_or_else(|| {
155            QueryError::ParseError(format!(
156                "window function references unknown column '{name}'"
157            ))
158        })
159}
160
161/// Compute the window-function column for `win` over `indexed_rows`.
162///
163/// Mutates `indexed_rows` (re-sorts by partition + order) so the
164/// caller can append the resulting Vec<Value> column-wise.
165fn compute_window_column(
166    win: &ParsedWindowFn,
167    indexed_rows: &mut [(usize, Row)],
168    columns_idx: &[(String, usize)],
169) -> Result<Vec<Value>> {
170    // Resolve indices once.
171    let partition_idx: Vec<usize> = win
172        .partition_by
173        .iter()
174        .map(|c| lookup_col(columns_idx, c.as_str()))
175        .collect::<Result<_>>()?;
176    let order_idx: Vec<(usize, bool)> = win
177        .order_by
178        .iter()
179        .map(|c| Ok((lookup_col(columns_idx, c.column.as_str())?, c.ascending)))
180        .collect::<Result<_>>()?;
181
182    indexed_rows.sort_by(|(_, a), (_, b)| {
183        compare_partition_then_order(a, b, &partition_idx, &order_idx)
184    });
185
186    let n = indexed_rows.len();
187    let mut out = vec![Value::Null; n];
188
189    let mut row_num: i64 = 0;
190    let mut rank: i64 = 0;
191    let mut dense_rank: i64 = 0;
192    let mut last_partition_key: Option<Vec<Value>> = None;
193    let mut last_order_key: Option<Vec<Value>> = None;
194
195    for i in 0..n {
196        let row = &indexed_rows[i].1;
197        let part_key: Vec<Value> = partition_idx.iter().map(|&j| row[j].clone()).collect();
198        let ord_key: Vec<Value> = order_idx.iter().map(|&(j, _)| row[j].clone()).collect();
199
200        let new_partition = last_partition_key.as_ref() != Some(&part_key);
201        if new_partition {
202            row_num = 0;
203            rank = 0;
204            dense_rank = 0;
205            last_partition_key = Some(part_key.clone());
206            last_order_key = None;
207        }
208
209        row_num += 1;
210        let order_changed = last_order_key.as_ref() != Some(&ord_key);
211        if order_changed {
212            rank = row_num;
213            dense_rank += 1;
214            last_order_key = Some(ord_key.clone());
215        }
216
217        out[i] = match &win.function {
218            WindowFunction::RowNumber => Value::BigInt(row_num),
219            WindowFunction::Rank => Value::BigInt(rank),
220            WindowFunction::DenseRank => Value::BigInt(dense_rank),
221            WindowFunction::Lag { column, offset } => lookup_offset(
222                indexed_rows,
223                columns_idx,
224                &partition_idx,
225                column,
226                i,
227                -(*offset as isize),
228            )?,
229            WindowFunction::Lead { column, offset } => lookup_offset(
230                indexed_rows,
231                columns_idx,
232                &partition_idx,
233                column,
234                i,
235                *offset as isize,
236            )?,
237            WindowFunction::FirstValue { column } => {
238                first_in_partition(indexed_rows, columns_idx, column, i, &partition_idx)?
239            }
240            WindowFunction::LastValue { column } => {
241                // ANSI default frame for LAST_VALUE without an
242                // explicit frame is the current row — mirror that.
243                let col_i = lookup_col(columns_idx, column.as_str())?;
244                indexed_rows[i].1[col_i].clone()
245            }
246        };
247    }
248    Ok(out)
249}
250
251fn compare_partition_then_order(
252    a: &Row,
253    b: &Row,
254    partition_idx: &[usize],
255    order_idx: &[(usize, bool)],
256) -> Ordering {
257    for &j in partition_idx {
258        match cmp_values(&a[j], &b[j]) {
259            Ordering::Equal => continue,
260            other => return other,
261        }
262    }
263    for &(j, asc) in order_idx {
264        let ord = cmp_values(&a[j], &b[j]);
265        match ord {
266            Ordering::Equal => continue,
267            other => return if asc { other } else { other.reverse() },
268        }
269    }
270    Ordering::Equal
271}
272
273/// Best-effort total order over `Value`. NULLs sort first, mirroring
274/// PostgreSQL's `NULLS FIRST` ascending default.
275fn cmp_values(a: &Value, b: &Value) -> Ordering {
276    use Value::*;
277    match (a, b) {
278        (Null, Null) => Ordering::Equal,
279        (Null, _) => Ordering::Less,
280        (_, Null) => Ordering::Greater,
281        (BigInt(x), BigInt(y)) => x.cmp(y),
282        (Integer(x), Integer(y)) => x.cmp(y),
283        (SmallInt(x), SmallInt(y)) => x.cmp(y),
284        (TinyInt(x), TinyInt(y)) => x.cmp(y),
285        (Real(x), Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
286        (Text(x), Text(y)) => x.cmp(y),
287        (Boolean(x), Boolean(y)) => x.cmp(y),
288        (Date(x), Date(y)) => x.cmp(y),
289        (Time(x), Time(y)) => x.cmp(y),
290        // Cross-type or unhandled: fall back to debug-string compare so
291        // sort is total. Real-world window queries don't hit this since
292        // schema enforces typed columns; the fallback exists for safety.
293        (lhs, rhs) => format!("{lhs:?}").cmp(&format!("{rhs:?}")),
294    }
295}
296
297fn lookup_offset(
298    indexed: &[(usize, Row)],
299    columns_idx: &[(String, usize)],
300    partition_idx: &[usize],
301    column: &ColumnName,
302    i: usize,
303    delta: isize,
304) -> Result<Value> {
305    let col_i = lookup_col(columns_idx, column.as_str())?;
306    let target_pos = i as isize + delta;
307    if target_pos < 0 || (target_pos as usize) >= indexed.len() {
308        return Ok(Value::Null);
309    }
310    let target = target_pos as usize;
311    // Partition-boundary check: LAG/LEAD must NOT cross partition
312    // boundaries — a row at the start of partition B should report
313    // NULL even though indexed[i-1] holds the last row of A.
314    if !same_partition(&indexed[i].1, &indexed[target].1, partition_idx) {
315        return Ok(Value::Null);
316    }
317    Ok(indexed[target].1[col_i].clone())
318}
319
320fn same_partition(a: &Row, b: &Row, partition_idx: &[usize]) -> bool {
321    partition_idx.iter().all(|&j| a[j] == b[j])
322}
323
324fn first_in_partition(
325    indexed: &[(usize, Row)],
326    columns_idx: &[(String, usize)],
327    column: &ColumnName,
328    i: usize,
329    partition_idx: &[usize],
330) -> Result<Value> {
331    let col_i = lookup_col(columns_idx, column.as_str())?;
332    let current_part: Vec<Value> = partition_idx
333        .iter()
334        .map(|&j| indexed[i].1[j].clone())
335        .collect();
336    // Walk back to the partition start.
337    let mut start = i;
338    while start > 0 {
339        let prev_part: Vec<Value> = partition_idx
340            .iter()
341            .map(|&j| indexed[start - 1].1[j].clone())
342            .collect();
343        if prev_part != current_part {
344            break;
345        }
346        start -= 1;
347    }
348    Ok(indexed[start].1[col_i].clone())
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::OrderByClause;
    use crate::schema::ColumnName;
    use std::collections::HashMap;

    /// Build a `QueryResult` from column names and row values.
    fn table(names: &[&str], rows: Vec<Vec<Value>>) -> QueryResult {
        QueryResult {
            columns: names.iter().map(|n| ColumnName::new(*n)).collect(),
            rows,
        }
    }

    /// Build a window spec; every ORDER BY column is ascending.
    fn window(
        function: WindowFunction,
        partition_by: &[&str],
        order_by: &[&str],
        alias: Option<&str>,
    ) -> ParsedWindowFn {
        ParsedWindowFn {
            function,
            partition_by: partition_by.iter().map(|c| ColumnName::new(*c)).collect(),
            order_by: order_by
                .iter()
                .map(|c| OrderByClause {
                    column: ColumnName::new(*c),
                    ascending: true,
                })
                .collect(),
            alias: alias.map(Into::into),
        }
    }

    /// Unwrap a `Value::BigInt`, panicking on any other variant.
    fn as_i64(value: &Value) -> i64 {
        match value {
            Value::BigInt(i) => *i,
            _ => panic!(),
        }
    }

    /// Unwrap a `Value::Text`, panicking on any other variant.
    fn as_text(value: &Value) -> String {
        match value {
            Value::Text(s) => s.clone(),
            _ => panic!(),
        }
    }

    #[test]
    fn row_number_no_partition_no_order_assigns_1_to_n_in_input_order() {
        let input = table(
            &["id"],
            vec![vec![Value::BigInt(10)], vec![Value::BigInt(20)]],
        );
        let spec = window(WindowFunction::RowNumber, &[], &[], None);

        let out = apply_window_fns(input, &[spec]).expect("apply");

        assert_eq!(out.columns.len(), 2);
        assert_eq!(out.rows[0][1], Value::BigInt(1));
        assert_eq!(out.rows[1][1], Value::BigInt(2));
    }

    #[test]
    fn row_number_resets_per_partition() {
        let input = table(
            &["dept", "salary"],
            vec![
                vec![Value::Text("A".into()), Value::BigInt(100)],
                vec![Value::Text("B".into()), Value::BigInt(200)],
                vec![Value::Text("A".into()), Value::BigInt(150)],
                vec![Value::Text("B".into()), Value::BigInt(250)],
            ],
        );
        let spec = window(
            WindowFunction::RowNumber,
            &["dept"],
            &["salary"],
            Some("rn"),
        );

        let out = apply_window_fns(input, &[spec]).expect("apply");

        // Output keeps input order, so key each row number by
        // (dept, salary) instead of relying on position.
        let row_numbers: HashMap<(String, i64), i64> = out
            .rows
            .iter()
            .map(|r| ((as_text(&r[0]), as_i64(&r[1])), as_i64(&r[2])))
            .collect();

        // Within each department the lower salary is row 1, the
        // higher one row 2.
        for (dept, salary, expected) in [("A", 100, 1), ("A", 150, 2), ("B", 200, 1), ("B", 250, 2)]
        {
            assert_eq!(
                row_numbers.get(&(dept.to_string(), salary)),
                Some(&expected)
            );
        }
    }

    #[test]
    fn rank_and_dense_rank_distinguish_ties() {
        // Salaries 100, 100, 200 — RANK = 1, 1, 3 and
        // DENSE_RANK = 1, 1, 2. PostgreSQL parity.
        let input = table(
            &["salary"],
            vec![
                vec![Value::BigInt(100)],
                vec![Value::BigInt(100)],
                vec![Value::BigInt(200)],
            ],
        );
        let specs = [
            window(WindowFunction::Rank, &[], &["salary"], Some("r")),
            window(WindowFunction::DenseRank, &[], &["salary"], Some("dr")),
        ];

        let out = apply_window_fns(input, &specs).expect("apply");

        // Rows come back in input order; check each one by salary.
        for r in &out.rows {
            let salary = as_i64(&r[0]);
            let rank = as_i64(&r[1]);
            let dense = as_i64(&r[2]);
            if salary == 100 {
                assert_eq!(rank, 1, "rank ties");
                assert_eq!(dense, 1, "dense_rank ties");
            } else {
                assert_eq!(rank, 3, "rank skips after ties");
                assert_eq!(dense, 2, "dense_rank does not skip");
            }
        }
    }

    #[test]
    fn first_value_returns_partition_start_value() {
        let input = table(
            &["dept", "salary"],
            vec![
                vec![Value::Text("A".into()), Value::BigInt(300)],
                vec![Value::Text("A".into()), Value::BigInt(100)],
                vec![Value::Text("A".into()), Value::BigInt(200)],
            ],
        );
        let spec = window(
            WindowFunction::FirstValue {
                column: ColumnName::new("salary"),
            },
            &["dept"],
            &["salary"],
            Some("first"),
        );

        let out = apply_window_fns(input, &[spec]).expect("apply");

        // Every row reports the partition's lowest salary.
        for r in &out.rows {
            assert_eq!(r[2], Value::BigInt(100));
        }
    }

    #[test]
    fn lag_returns_null_at_partition_start() {
        let input = table(
            &["id"],
            vec![
                vec![Value::BigInt(10)],
                vec![Value::BigInt(20)],
                vec![Value::BigInt(30)],
            ],
        );
        let spec = window(
            WindowFunction::Lag {
                column: ColumnName::new("id"),
                offset: 1,
            },
            &[],
            &["id"],
            Some("prev"),
        );

        let out = apply_window_fns(input, &[spec]).expect("apply");

        // Ordered by id ascending: id=10 has no predecessor, id=20
        // sees 10, id=30 sees 20.
        let previous: HashMap<i64, Value> = out
            .rows
            .iter()
            .map(|r| (as_i64(&r[0]), r[1].clone()))
            .collect();
        assert_eq!(previous[&10], Value::Null, "first row lag is NULL");
        assert_eq!(previous[&20], Value::BigInt(10));
        assert_eq!(previous[&30], Value::BigInt(20));
    }
}