1use std::cmp::Ordering;
31
32use crate::error::{QueryError, Result};
33use crate::executor::{QueryResult, Row};
34use crate::parser::ParsedWindowFn;
35use crate::schema::ColumnName;
36use crate::value::Value;
37
38#[derive(Debug, Clone, PartialEq, Eq)]
45pub enum WindowFunction {
46 RowNumber,
47 Rank,
48 DenseRank,
49 Lag { column: ColumnName, offset: usize },
52 Lead { column: ColumnName, offset: usize },
55 FirstValue { column: ColumnName },
58 LastValue { column: ColumnName },
64}
65
66impl WindowFunction {
67 pub fn default_alias(&self) -> &'static str {
69 match self {
70 Self::RowNumber => "row_number",
71 Self::Rank => "rank",
72 Self::DenseRank => "dense_rank",
73 Self::Lag { .. } => "lag",
74 Self::Lead { .. } => "lead",
75 Self::FirstValue { .. } => "first_value",
76 Self::LastValue { .. } => "last_value",
77 }
78 }
79}
80
81pub fn apply_window_fns(
89 base: QueryResult,
90 window_fns: &[ParsedWindowFn],
91) -> Result<QueryResult> {
92 if window_fns.is_empty() {
93 return Ok(base);
94 }
95
96 let columns_idx = build_column_index(&base.columns);
99
100 let QueryResult { columns, rows } = base;
101 let mut out_columns = columns.clone();
102
103 let mut work_rows = rows;
109 let original_index_col = work_rows.len(); let _ = original_index_col;
111
112 let mut indexed: Vec<(usize, Row)> = work_rows.drain(..).enumerate().collect();
116
117 for win in window_fns {
118 let fn_col = compute_window_column(win, &mut indexed, &columns_idx)?;
119 out_columns.push(ColumnName::new(
120 win.alias
121 .clone()
122 .unwrap_or_else(|| win.function.default_alias().to_string()),
123 ));
124 for ((_, row), val) in indexed.iter_mut().zip(fn_col.into_iter()) {
125 row.push(val);
126 }
127 }
128
129 indexed.sort_by_key(|(idx, _)| *idx);
133 let final_rows = indexed.into_iter().map(|(_, r)| r).collect();
134
135 Ok(QueryResult {
136 columns: out_columns,
137 rows: final_rows,
138 })
139}
140
141fn build_column_index(columns: &[ColumnName]) -> Vec<(String, usize)> {
143 columns
144 .iter()
145 .enumerate()
146 .map(|(i, c)| (c.as_str().to_string(), i))
147 .collect()
148}
149
150fn lookup_col(idx: &[(String, usize)], name: &str) -> Result<usize> {
151 idx.iter()
152 .find(|(n, _)| n == name)
153 .map(|(_, i)| *i)
154 .ok_or_else(|| {
155 QueryError::ParseError(format!(
156 "window function references unknown column '{name}'"
157 ))
158 })
159}
160
161fn compute_window_column(
166 win: &ParsedWindowFn,
167 indexed_rows: &mut [(usize, Row)],
168 columns_idx: &[(String, usize)],
169) -> Result<Vec<Value>> {
170 let partition_idx: Vec<usize> = win
172 .partition_by
173 .iter()
174 .map(|c| lookup_col(columns_idx, c.as_str()))
175 .collect::<Result<_>>()?;
176 let order_idx: Vec<(usize, bool)> = win
177 .order_by
178 .iter()
179 .map(|c| Ok((lookup_col(columns_idx, c.column.as_str())?, c.ascending)))
180 .collect::<Result<_>>()?;
181
182 indexed_rows.sort_by(|(_, a), (_, b)| {
183 compare_partition_then_order(a, b, &partition_idx, &order_idx)
184 });
185
186 let n = indexed_rows.len();
187 let mut out = vec![Value::Null; n];
188
189 let mut row_num: i64 = 0;
190 let mut rank: i64 = 0;
191 let mut dense_rank: i64 = 0;
192 let mut last_partition_key: Option<Vec<Value>> = None;
193 let mut last_order_key: Option<Vec<Value>> = None;
194
195 for i in 0..n {
196 let row = &indexed_rows[i].1;
197 let part_key: Vec<Value> = partition_idx.iter().map(|&j| row[j].clone()).collect();
198 let ord_key: Vec<Value> = order_idx.iter().map(|&(j, _)| row[j].clone()).collect();
199
200 let new_partition = last_partition_key.as_ref() != Some(&part_key);
201 if new_partition {
202 row_num = 0;
203 rank = 0;
204 dense_rank = 0;
205 last_partition_key = Some(part_key.clone());
206 last_order_key = None;
207 }
208
209 row_num += 1;
210 let order_changed = last_order_key.as_ref() != Some(&ord_key);
211 if order_changed {
212 rank = row_num;
213 dense_rank += 1;
214 last_order_key = Some(ord_key.clone());
215 }
216
217 out[i] = match &win.function {
218 WindowFunction::RowNumber => Value::BigInt(row_num),
219 WindowFunction::Rank => Value::BigInt(rank),
220 WindowFunction::DenseRank => Value::BigInt(dense_rank),
221 WindowFunction::Lag { column, offset } => lookup_offset(
222 indexed_rows,
223 columns_idx,
224 &partition_idx,
225 column,
226 i,
227 -(*offset as isize),
228 )?,
229 WindowFunction::Lead { column, offset } => lookup_offset(
230 indexed_rows,
231 columns_idx,
232 &partition_idx,
233 column,
234 i,
235 *offset as isize,
236 )?,
237 WindowFunction::FirstValue { column } => {
238 first_in_partition(indexed_rows, columns_idx, column, i, &partition_idx)?
239 }
240 WindowFunction::LastValue { column } => {
241 let col_i = lookup_col(columns_idx, column.as_str())?;
244 indexed_rows[i].1[col_i].clone()
245 }
246 };
247 }
248 Ok(out)
249}
250
251fn compare_partition_then_order(
252 a: &Row,
253 b: &Row,
254 partition_idx: &[usize],
255 order_idx: &[(usize, bool)],
256) -> Ordering {
257 for &j in partition_idx {
258 match cmp_values(&a[j], &b[j]) {
259 Ordering::Equal => continue,
260 other => return other,
261 }
262 }
263 for &(j, asc) in order_idx {
264 let ord = cmp_values(&a[j], &b[j]);
265 match ord {
266 Ordering::Equal => continue,
267 other => return if asc { other } else { other.reverse() },
268 }
269 }
270 Ordering::Equal
271}
272
273fn cmp_values(a: &Value, b: &Value) -> Ordering {
276 use Value::*;
277 match (a, b) {
278 (Null, Null) => Ordering::Equal,
279 (Null, _) => Ordering::Less,
280 (_, Null) => Ordering::Greater,
281 (BigInt(x), BigInt(y)) => x.cmp(y),
282 (Integer(x), Integer(y)) => x.cmp(y),
283 (SmallInt(x), SmallInt(y)) => x.cmp(y),
284 (TinyInt(x), TinyInt(y)) => x.cmp(y),
285 (Real(x), Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
286 (Text(x), Text(y)) => x.cmp(y),
287 (Boolean(x), Boolean(y)) => x.cmp(y),
288 (Date(x), Date(y)) => x.cmp(y),
289 (Time(x), Time(y)) => x.cmp(y),
290 (lhs, rhs) => format!("{lhs:?}").cmp(&format!("{rhs:?}")),
294 }
295}
296
297fn lookup_offset(
298 indexed: &[(usize, Row)],
299 columns_idx: &[(String, usize)],
300 partition_idx: &[usize],
301 column: &ColumnName,
302 i: usize,
303 delta: isize,
304) -> Result<Value> {
305 let col_i = lookup_col(columns_idx, column.as_str())?;
306 let target_pos = i as isize + delta;
307 if target_pos < 0 || (target_pos as usize) >= indexed.len() {
308 return Ok(Value::Null);
309 }
310 let target = target_pos as usize;
311 if !same_partition(&indexed[i].1, &indexed[target].1, partition_idx) {
315 return Ok(Value::Null);
316 }
317 Ok(indexed[target].1[col_i].clone())
318}
319
320fn same_partition(a: &Row, b: &Row, partition_idx: &[usize]) -> bool {
321 partition_idx.iter().all(|&j| a[j] == b[j])
322}
323
324fn first_in_partition(
325 indexed: &[(usize, Row)],
326 columns_idx: &[(String, usize)],
327 column: &ColumnName,
328 i: usize,
329 partition_idx: &[usize],
330) -> Result<Value> {
331 let col_i = lookup_col(columns_idx, column.as_str())?;
332 let current_part: Vec<Value> = partition_idx
333 .iter()
334 .map(|&j| indexed[i].1[j].clone())
335 .collect();
336 let mut start = i;
338 while start > 0 {
339 let prev_part: Vec<Value> = partition_idx
340 .iter()
341 .map(|&j| indexed[start - 1].1[j].clone())
342 .collect();
343 if prev_part != current_part {
344 break;
345 }
346 start -= 1;
347 }
348 Ok(indexed[start].1[col_i].clone())
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::OrderByClause;
    use crate::schema::ColumnName;

    /// Builds a column list from plain names.
    fn cols(names: &[&str]) -> Vec<ColumnName> {
        names.iter().map(|n| ColumnName::new(*n)).collect()
    }

    /// Wraps a list of values as a row.
    fn row(vals: Vec<Value>) -> Row {
        vals
    }

    /// Ascending ORDER BY entry on `name`.
    fn order_asc(name: &str) -> OrderByClause {
        OrderByClause {
            column: ColumnName::new(name),
            ascending: true,
        }
    }

    /// Assembles a window spec with ascending ordering on each `order_by` column.
    fn window(
        function: WindowFunction,
        partition_by: &[&str],
        order_by: &[&str],
        alias: Option<&str>,
    ) -> ParsedWindowFn {
        ParsedWindowFn {
            function,
            partition_by: cols(partition_by),
            order_by: order_by.iter().map(|name| order_asc(name)).collect(),
            alias: alias.map(|a| a.into()),
        }
    }

    /// Extracts a `BigInt` payload, panicking on any other variant.
    fn int(v: &Value) -> i64 {
        match v {
            Value::BigInt(i) => *i,
            _ => panic!(),
        }
    }

    /// Extracts a `Text` payload, panicking on any other variant.
    fn text(v: &Value) -> String {
        match v {
            Value::Text(s) => s.clone(),
            _ => panic!(),
        }
    }

    /// Single-column rows holding the given integers.
    fn int_rows(values: &[i64]) -> Vec<Row> {
        values
            .iter()
            .map(|&v| row(vec![Value::BigInt(v)]))
            .collect()
    }

    #[test]
    fn row_number_no_partition_no_order_assigns_1_to_n_in_input_order() {
        let input = QueryResult {
            columns: cols(&["id"]),
            rows: int_rows(&[10, 20]),
        };
        let spec = window(WindowFunction::RowNumber, &[], &[], None);

        let out = apply_window_fns(input, &[spec]).expect("apply");

        assert_eq!(out.columns.len(), 2);
        assert_eq!(out.rows[0][1], Value::BigInt(1));
        assert_eq!(out.rows[1][1], Value::BigInt(2));
    }

    #[test]
    fn row_number_resets_per_partition() {
        let staff = [("A", 100), ("B", 200), ("A", 150), ("B", 250)];
        let input = QueryResult {
            columns: cols(&["dept", "salary"]),
            rows: staff
                .iter()
                .map(|&(dept, salary)| row(vec![Value::Text(dept.into()), Value::BigInt(salary)]))
                .collect(),
        };
        let spec = window(WindowFunction::RowNumber, &["dept"], &["salary"], Some("rn"));

        let out = apply_window_fns(input, &[spec]).expect("apply");

        // Key the row numbers by (dept, salary) so the check is order-independent.
        let mut map: std::collections::HashMap<(String, i64), i64> =
            std::collections::HashMap::new();
        for r in &out.rows {
            let dept = text(&r[0]);
            let salary = int(&r[1]);
            let rn = int(&r[2]);
            map.insert((dept, salary), rn);
        }
        assert_eq!(map.get(&("A".into(), 100)), Some(&1));
        assert_eq!(map.get(&("A".into(), 150)), Some(&2));
        assert_eq!(map.get(&("B".into(), 200)), Some(&1));
        assert_eq!(map.get(&("B".into(), 250)), Some(&2));
    }

    #[test]
    fn rank_and_dense_rank_distinguish_ties() {
        let input = QueryResult {
            columns: cols(&["salary"]),
            rows: int_rows(&[100, 100, 200]),
        };
        let by_rank = window(WindowFunction::Rank, &[], &["salary"], Some("r"));
        let by_dense = window(WindowFunction::DenseRank, &[], &["salary"], Some("dr"));

        let out = apply_window_fns(input, &[by_rank, by_dense]).expect("apply");

        for r in &out.rows {
            let salary = int(&r[0]);
            let rank = int(&r[1]);
            let dense = int(&r[2]);
            if salary == 100 {
                assert_eq!(rank, 1, "rank ties");
                assert_eq!(dense, 1, "dense_rank ties");
            } else {
                assert_eq!(rank, 3, "rank skips after ties");
                assert_eq!(dense, 2, "dense_rank does not skip");
            }
        }
    }

    #[test]
    fn first_value_returns_partition_start_value() {
        let input = QueryResult {
            columns: cols(&["dept", "salary"]),
            rows: [300, 100, 200]
                .iter()
                .map(|&salary| row(vec![Value::Text("A".into()), Value::BigInt(salary)]))
                .collect(),
        };
        let spec = window(
            WindowFunction::FirstValue {
                column: ColumnName::new("salary"),
            },
            &["dept"],
            &["salary"],
            Some("first"),
        );

        let out = apply_window_fns(input, &[spec]).expect("apply");

        for r in &out.rows {
            assert_eq!(r[2], Value::BigInt(100));
        }
    }

    #[test]
    fn lag_returns_null_at_partition_start() {
        let input = QueryResult {
            columns: cols(&["id"]),
            rows: int_rows(&[10, 20, 30]),
        };
        let spec = window(
            WindowFunction::Lag {
                column: ColumnName::new("id"),
                offset: 1,
            },
            &[],
            &["id"],
            Some("prev"),
        );

        let out = apply_window_fns(input, &[spec]).expect("apply");

        let mut map: std::collections::HashMap<i64, Value> = std::collections::HashMap::new();
        for r in &out.rows {
            map.insert(int(&r[0]), r[1].clone());
        }
        assert_eq!(map[&10], Value::Null, "first row lag is NULL");
        assert_eq!(map[&20], Value::BigInt(10));
        assert_eq!(map[&30], Value::BigInt(20));
    }
}