1use std::cmp::Ordering;
31
32use crate::error::{QueryError, Result};
33use crate::executor::{QueryResult, Row};
34use crate::parser::ParsedWindowFn;
35use crate::schema::ColumnName;
36use crate::value::Value;
37
38#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum WindowFunction {
46 RowNumber,
47 Rank,
48 DenseRank,
49 Lag {
52 column: ColumnName,
53 offset: usize,
54 },
55 Lead {
58 column: ColumnName,
59 offset: usize,
60 },
61 FirstValue {
64 column: ColumnName,
65 },
66 LastValue {
72 column: ColumnName,
73 },
74}
75
76impl WindowFunction {
77 pub fn default_alias(&self) -> &'static str {
79 match self {
80 Self::RowNumber => "row_number",
81 Self::Rank => "rank",
82 Self::DenseRank => "dense_rank",
83 Self::Lag { .. } => "lag",
84 Self::Lead { .. } => "lead",
85 Self::FirstValue { .. } => "first_value",
86 Self::LastValue { .. } => "last_value",
87 }
88 }
89}
90
91pub fn apply_window_fns(base: QueryResult, window_fns: &[ParsedWindowFn]) -> Result<QueryResult> {
99 if window_fns.is_empty() {
100 return Ok(base);
101 }
102
103 let columns_idx = build_column_index(&base.columns);
106
107 let QueryResult { columns, rows } = base;
108 let mut out_columns = columns.clone();
109
110 let mut work_rows = rows;
116 let original_index_col = work_rows.len(); let _ = original_index_col;
118
119 let mut indexed: Vec<(usize, Row)> = work_rows.drain(..).enumerate().collect();
123
124 for win in window_fns {
125 let fn_col = compute_window_column(win, &mut indexed, &columns_idx)?;
126 out_columns.push(ColumnName::new(
127 win.alias
128 .clone()
129 .unwrap_or_else(|| win.function.default_alias().to_string()),
130 ));
131 for ((_, row), val) in indexed.iter_mut().zip(fn_col.into_iter()) {
132 row.push(val);
133 }
134 }
135
136 indexed.sort_by_key(|(idx, _)| *idx);
140 let final_rows = indexed.into_iter().map(|(_, r)| r).collect();
141
142 Ok(QueryResult {
143 columns: out_columns,
144 rows: final_rows,
145 })
146}
147
148fn build_column_index(columns: &[ColumnName]) -> Vec<(String, usize)> {
150 columns
151 .iter()
152 .enumerate()
153 .map(|(i, c)| (c.as_str().to_string(), i))
154 .collect()
155}
156
157fn lookup_col(idx: &[(String, usize)], name: &str) -> Result<usize> {
158 idx.iter()
159 .find(|(n, _)| n == name)
160 .map(|(_, i)| *i)
161 .ok_or_else(|| {
162 QueryError::ParseError(format!(
163 "window function references unknown column '{name}'"
164 ))
165 })
166}
167
168fn compute_window_column(
173 win: &ParsedWindowFn,
174 indexed_rows: &mut [(usize, Row)],
175 columns_idx: &[(String, usize)],
176) -> Result<Vec<Value>> {
177 let partition_idx: Vec<usize> = win
179 .partition_by
180 .iter()
181 .map(|c| lookup_col(columns_idx, c.as_str()))
182 .collect::<Result<_>>()?;
183 let order_idx: Vec<(usize, bool)> = win
184 .order_by
185 .iter()
186 .map(|c| Ok((lookup_col(columns_idx, c.column.as_str())?, c.ascending)))
187 .collect::<Result<_>>()?;
188
189 indexed_rows
190 .sort_by(|(_, a), (_, b)| compare_partition_then_order(a, b, &partition_idx, &order_idx));
191
192 let n = indexed_rows.len();
193 let mut out = vec![Value::Null; n];
194
195 let mut row_num: i64 = 0;
196 let mut rank: i64 = 0;
197 let mut dense_rank: i64 = 0;
198 let mut last_partition_key: Option<Vec<Value>> = None;
199 let mut last_order_key: Option<Vec<Value>> = None;
200
201 for i in 0..n {
202 let row = &indexed_rows[i].1;
203 let part_key: Vec<Value> = partition_idx.iter().map(|&j| row[j].clone()).collect();
204 let ord_key: Vec<Value> = order_idx.iter().map(|&(j, _)| row[j].clone()).collect();
205
206 let new_partition = last_partition_key.as_ref() != Some(&part_key);
207 if new_partition {
208 row_num = 0;
209 rank = 0;
210 dense_rank = 0;
211 last_partition_key = Some(part_key.clone());
212 last_order_key = None;
213 }
214
215 row_num += 1;
216 let order_changed = last_order_key.as_ref() != Some(&ord_key);
217 if order_changed {
218 rank = row_num;
219 dense_rank += 1;
220 last_order_key = Some(ord_key.clone());
221 }
222
223 out[i] = match &win.function {
224 WindowFunction::RowNumber => Value::BigInt(row_num),
225 WindowFunction::Rank => Value::BigInt(rank),
226 WindowFunction::DenseRank => Value::BigInt(dense_rank),
227 WindowFunction::Lag { column, offset } => lookup_offset(
228 indexed_rows,
229 columns_idx,
230 &partition_idx,
231 column,
232 i,
233 -(*offset as isize),
234 )?,
235 WindowFunction::Lead { column, offset } => lookup_offset(
236 indexed_rows,
237 columns_idx,
238 &partition_idx,
239 column,
240 i,
241 *offset as isize,
242 )?,
243 WindowFunction::FirstValue { column } => {
244 first_in_partition(indexed_rows, columns_idx, column, i, &partition_idx)?
245 }
246 WindowFunction::LastValue { column } => {
247 let col_i = lookup_col(columns_idx, column.as_str())?;
250 indexed_rows[i].1[col_i].clone()
251 }
252 };
253 }
254 Ok(out)
255}
256
257fn compare_partition_then_order(
258 a: &Row,
259 b: &Row,
260 partition_idx: &[usize],
261 order_idx: &[(usize, bool)],
262) -> Ordering {
263 for &j in partition_idx {
264 match cmp_values(&a[j], &b[j]) {
265 Ordering::Equal => {}
266 other => return other,
267 }
268 }
269 for &(j, asc) in order_idx {
270 let ord = cmp_values(&a[j], &b[j]);
271 match ord {
272 Ordering::Equal => {}
273 other => return if asc { other } else { other.reverse() },
274 }
275 }
276 Ordering::Equal
277}
278
279#[allow(clippy::match_same_arms)]
287fn cmp_values(a: &Value, b: &Value) -> Ordering {
288 use Value::{BigInt, Boolean, Date, Integer, Null, Real, SmallInt, Text, Time, TinyInt};
289 match (a, b) {
290 (Null, Null) => Ordering::Equal,
291 (Null, _) => Ordering::Less,
292 (_, Null) => Ordering::Greater,
293 (BigInt(x), BigInt(y)) => x.cmp(y),
294 (Integer(x), Integer(y)) => x.cmp(y),
295 (SmallInt(x), SmallInt(y)) => x.cmp(y),
296 (TinyInt(x), TinyInt(y)) => x.cmp(y),
297 (Real(x), Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
298 (Text(x), Text(y)) => x.cmp(y),
299 (Boolean(x), Boolean(y)) => x.cmp(y),
300 (Date(x), Date(y)) => x.cmp(y),
301 (Time(x), Time(y)) => x.cmp(y),
302 (lhs, rhs) => format!("{lhs:?}").cmp(&format!("{rhs:?}")),
306 }
307}
308
309fn lookup_offset(
310 indexed: &[(usize, Row)],
311 columns_idx: &[(String, usize)],
312 partition_idx: &[usize],
313 column: &ColumnName,
314 i: usize,
315 delta: isize,
316) -> Result<Value> {
317 let col_i = lookup_col(columns_idx, column.as_str())?;
318 let target_pos = i as isize + delta;
319 if target_pos < 0 || (target_pos as usize) >= indexed.len() {
320 return Ok(Value::Null);
321 }
322 let target = target_pos as usize;
323 if !same_partition(&indexed[i].1, &indexed[target].1, partition_idx) {
327 return Ok(Value::Null);
328 }
329 Ok(indexed[target].1[col_i].clone())
330}
331
332fn same_partition(a: &Row, b: &Row, partition_idx: &[usize]) -> bool {
333 partition_idx.iter().all(|&j| a[j] == b[j])
334}
335
336fn first_in_partition(
337 indexed: &[(usize, Row)],
338 columns_idx: &[(String, usize)],
339 column: &ColumnName,
340 i: usize,
341 partition_idx: &[usize],
342) -> Result<Value> {
343 let col_i = lookup_col(columns_idx, column.as_str())?;
344 let current_part: Vec<Value> = partition_idx
345 .iter()
346 .map(|&j| indexed[i].1[j].clone())
347 .collect();
348 let mut start = i;
350 while start > 0 {
351 let prev_part: Vec<Value> = partition_idx
352 .iter()
353 .map(|&j| indexed[start - 1].1[j].clone())
354 .collect();
355 if prev_part != current_part {
356 break;
357 }
358 start -= 1;
359 }
360 Ok(indexed[start].1[col_i].clone())
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::OrderByClause;
    use crate::schema::ColumnName;

    fn cols(names: &[&str]) -> Vec<ColumnName> {
        names.iter().map(|n| ColumnName::new(*n)).collect()
    }

    fn row(vals: Vec<Value>) -> Row {
        vals
    }

    fn order_asc(name: &str) -> OrderByClause {
        OrderByClause {
            column: ColumnName::new(name),
            ascending: true,
        }
    }

    /// Builds a window spec from plain names so each test states only what matters.
    fn window(
        function: WindowFunction,
        partition_by: &[&str],
        order_by: &[&str],
        alias: Option<&str>,
    ) -> ParsedWindowFn {
        ParsedWindowFn {
            function,
            partition_by: cols(partition_by),
            order_by: order_by.iter().map(|&c| order_asc(c)).collect(),
            alias: alias.map(str::to_string),
        }
    }

    /// Unwraps a `Value::BigInt`, panicking with the actual value otherwise.
    fn big_int(v: &Value) -> i64 {
        match v {
            Value::BigInt(i) => *i,
            other => panic!("expected BigInt, got {other:?}"),
        }
    }

    /// Unwraps a `Value::Text`, panicking with the actual value otherwise.
    fn text(v: &Value) -> String {
        match v {
            Value::Text(s) => s.clone(),
            other => panic!("expected Text, got {other:?}"),
        }
    }

    #[test]
    fn row_number_no_partition_no_order_assigns_1_to_n_in_input_order() {
        let qr = QueryResult {
            columns: cols(&["id"]),
            rows: vec![row(vec![Value::BigInt(10)]), row(vec![Value::BigInt(20)])],
        };
        let win = window(WindowFunction::RowNumber, &[], &[], None);
        let out = apply_window_fns(qr, &[win]).expect("apply");
        assert_eq!(out.columns.len(), 2);
        assert_eq!(out.rows[0][1], Value::BigInt(1));
        assert_eq!(out.rows[1][1], Value::BigInt(2));
    }

    #[test]
    fn row_number_resets_per_partition() {
        let qr = QueryResult {
            columns: cols(&["dept", "salary"]),
            rows: vec![
                row(vec![Value::Text("A".into()), Value::BigInt(100)]),
                row(vec![Value::Text("B".into()), Value::BigInt(200)]),
                row(vec![Value::Text("A".into()), Value::BigInt(150)]),
                row(vec![Value::Text("B".into()), Value::BigInt(250)]),
            ],
        };
        let win = window(WindowFunction::RowNumber, &["dept"], &["salary"], Some("rn"));
        let out = apply_window_fns(qr, &[win]).expect("apply");
        let rn_by_key: std::collections::HashMap<(String, i64), i64> = out
            .rows
            .iter()
            .map(|r| ((text(&r[0]), big_int(&r[1])), big_int(&r[2])))
            .collect();
        assert_eq!(rn_by_key.get(&("A".into(), 100)), Some(&1));
        assert_eq!(rn_by_key.get(&("A".into(), 150)), Some(&2));
        assert_eq!(rn_by_key.get(&("B".into(), 200)), Some(&1));
        assert_eq!(rn_by_key.get(&("B".into(), 250)), Some(&2));
    }

    #[test]
    fn rank_and_dense_rank_distinguish_ties() {
        let qr = QueryResult {
            columns: cols(&["salary"]),
            rows: vec![
                row(vec![Value::BigInt(100)]),
                row(vec![Value::BigInt(100)]),
                row(vec![Value::BigInt(200)]),
            ],
        };
        let win_rank = window(WindowFunction::Rank, &[], &["salary"], Some("r"));
        let win_dense = window(WindowFunction::DenseRank, &[], &["salary"], Some("dr"));
        let out = apply_window_fns(qr, &[win_rank, win_dense]).expect("apply");
        for r in &out.rows {
            let (salary, rank, dense) = (big_int(&r[0]), big_int(&r[1]), big_int(&r[2]));
            if salary == 100 {
                assert_eq!(rank, 1, "rank ties");
                assert_eq!(dense, 1, "dense_rank ties");
            } else {
                assert_eq!(rank, 3, "rank skips after ties");
                assert_eq!(dense, 2, "dense_rank does not skip");
            }
        }
    }

    #[test]
    fn first_value_returns_partition_start_value() {
        let qr = QueryResult {
            columns: cols(&["dept", "salary"]),
            rows: vec![
                row(vec![Value::Text("A".into()), Value::BigInt(300)]),
                row(vec![Value::Text("A".into()), Value::BigInt(100)]),
                row(vec![Value::Text("A".into()), Value::BigInt(200)]),
            ],
        };
        let win = window(
            WindowFunction::FirstValue {
                column: ColumnName::new("salary"),
            },
            &["dept"],
            &["salary"],
            Some("first"),
        );
        let out = apply_window_fns(qr, &[win]).expect("apply");
        for r in &out.rows {
            assert_eq!(r[2], Value::BigInt(100));
        }
    }

    #[test]
    fn lag_returns_null_at_partition_start() {
        let qr = QueryResult {
            columns: cols(&["id"]),
            rows: vec![
                row(vec![Value::BigInt(10)]),
                row(vec![Value::BigInt(20)]),
                row(vec![Value::BigInt(30)]),
            ],
        };
        let win = window(
            WindowFunction::Lag {
                column: ColumnName::new("id"),
                offset: 1,
            },
            &[],
            &["id"],
            Some("prev"),
        );
        let out = apply_window_fns(qr, &[win]).expect("apply");
        let prev_by_id: std::collections::HashMap<i64, Value> = out
            .rows
            .iter()
            .map(|r| (big_int(&r[0]), r[1].clone()))
            .collect();
        assert_eq!(prev_by_id[&10], Value::Null, "first row lag is NULL");
        assert_eq!(prev_by_id[&20], Value::BigInt(10));
        assert_eq!(prev_by_id[&30], Value::BigInt(20));
    }

    #[test]
    fn lead_returns_null_past_partition_end() {
        let qr = QueryResult {
            columns: cols(&["dept", "id"]),
            rows: vec![
                row(vec![Value::Text("A".into()), Value::BigInt(10)]),
                row(vec![Value::Text("B".into()), Value::BigInt(30)]),
                row(vec![Value::Text("A".into()), Value::BigInt(20)]),
            ],
        };
        let win = window(
            WindowFunction::Lead {
                column: ColumnName::new("id"),
                offset: 1,
            },
            &["dept"],
            &["id"],
            Some("next"),
        );
        let out = apply_window_fns(qr, &[win]).expect("apply");
        assert_eq!(
            out.rows,
            vec![
                row(vec![Value::Text("A".into()), Value::BigInt(10), Value::BigInt(20)]),
                row(vec![Value::Text("B".into()), Value::BigInt(30), Value::Null]),
                row(vec![Value::Text("A".into()), Value::BigInt(20), Value::Null]),
            ]
        );
    }

    #[test]
    fn output_rows_keep_input_order() {
        let qr = QueryResult {
            columns: cols(&["salary"]),
            rows: vec![
                row(vec![Value::BigInt(300)]),
                row(vec![Value::BigInt(100)]),
                row(vec![Value::BigInt(200)]),
            ],
        };
        let win = window(WindowFunction::RowNumber, &[], &["salary"], Some("rn"));
        let out = apply_window_fns(qr, &[win]).expect("apply");
        assert_eq!(
            out.rows,
            vec![
                row(vec![Value::BigInt(300), Value::BigInt(3)]),
                row(vec![Value::BigInt(100), Value::BigInt(1)]),
                row(vec![Value::BigInt(200), Value::BigInt(2)]),
            ]
        );
    }

    #[test]
    fn unaliased_column_uses_default_alias() {
        let qr = QueryResult {
            columns: cols(&["id"]),
            rows: vec![row(vec![Value::BigInt(1)])],
        };
        let win = window(WindowFunction::DenseRank, &[], &["id"], None);
        let out = apply_window_fns(qr, &[win]).expect("apply");
        assert_eq!(out.columns[1].as_str(), "dense_rank");
    }

    #[test]
    fn empty_window_list_returns_input_unchanged() {
        let rows = vec![row(vec![Value::BigInt(1)]), row(vec![Value::BigInt(2)])];
        let qr = QueryResult {
            columns: cols(&["id"]),
            rows: rows.clone(),
        };
        let out = apply_window_fns(qr, &[]).expect("apply");
        assert_eq!(out.columns.len(), 1);
        assert_eq!(out.rows, rows);
    }

    #[test]
    fn unknown_column_is_an_error() {
        let qr = QueryResult {
            columns: cols(&["id"]),
            rows: vec![row(vec![Value::BigInt(1)])],
        };
        let win = window(WindowFunction::RowNumber, &[], &["missing"], None);
        assert!(apply_window_fns(qr, &[win]).is_err());
    }
}