1use crate::executor::Row;
9use crate::expr::{
10 Expr, FrameBound, WindowFrame, WindowFunction, WindowFunctionType, WindowOrderByExpr,
11 WindowSortOrder,
12};
13use featherdb_core::{Result, Value};
14use std::collections::HashMap;
15use std::sync::Arc;
16
/// Evaluates SQL window functions (`ROW_NUMBER`, `RANK`, `LAG`, `SUM`, ...)
/// over an in-memory set of rows.
///
/// The executor carries no state; all functionality is exposed through
/// associated functions such as [`WindowExecutor::execute`].
#[derive(Debug, Clone, Copy, Default)]
pub struct WindowExecutor;
19
20impl WindowExecutor {
21 pub fn execute(rows: Vec<Row>, window_exprs: &[(Expr, String)]) -> Result<Vec<Row>> {
26 if rows.is_empty() {
27 return Ok(rows);
28 }
29
30 let col_map = build_col_map(&rows[0].columns);
32
33 let mut result_rows = rows;
35
36 for (window_expr, alias) in window_exprs {
37 if let Expr::Window(window_func) = window_expr {
38 result_rows =
39 Self::execute_window_function(result_rows, window_func, alias, &col_map)?;
40 }
41 }
42
43 Ok(result_rows)
44 }
45
46 fn execute_window_function(
48 mut rows: Vec<Row>,
49 window_func: &WindowFunction,
50 alias: &str,
51 col_map: &HashMap<String, usize>,
52 ) -> Result<Vec<Row>> {
53 let partitions = Self::partition_rows(&rows, &window_func.partition_by, col_map)?;
55
56 let mut window_values: Vec<(usize, Value)> = Vec::new();
58
59 for partition_indices in partitions {
60 let mut partition_rows: Vec<(usize, &Row)> =
62 partition_indices.iter().map(|&i| (i, &rows[i])).collect();
63
64 if !window_func.order_by.is_empty() {
66 Self::sort_partition(&mut partition_rows, &window_func.order_by, col_map)?;
67 }
68
69 let values = Self::compute_window_values(
71 &partition_rows,
72 &window_func.function,
73 window_func.frame.as_ref(),
74 &window_func.order_by,
75 col_map,
76 )?;
77
78 for ((orig_idx, _), value) in partition_rows.into_iter().zip(values) {
80 window_values.push((orig_idx, value));
81 }
82 }
83
84 window_values.sort_by_key(|(idx, _)| *idx);
86
87 for (row, (_, value)) in rows.iter_mut().zip(window_values) {
89 row.values.push(value);
90 let cols = Arc::make_mut(&mut row.columns);
92 cols.push(alias.to_string());
93 }
94
95 Ok(rows)
96 }
97
98 fn partition_rows(
100 rows: &[Row],
101 partition_by: &[Expr],
102 col_map: &HashMap<String, usize>,
103 ) -> Result<Vec<Vec<usize>>> {
104 if partition_by.is_empty() {
105 return Ok(vec![(0..rows.len()).collect()]);
107 }
108
109 let mut partitions: HashMap<Vec<Value>, Vec<usize>> = HashMap::new();
110
111 for (idx, row) in rows.iter().enumerate() {
112 let key: Vec<Value> = partition_by
113 .iter()
114 .map(|expr| expr.eval(&row.values, col_map))
115 .collect::<Result<_>>()?;
116
117 partitions.entry(key).or_default().push(idx);
118 }
119
120 Ok(partitions.into_values().collect())
121 }
122
123 fn sort_partition(
125 partition: &mut [(usize, &Row)],
126 order_by: &[WindowOrderByExpr],
127 col_map: &HashMap<String, usize>,
128 ) -> Result<()> {
129 partition.sort_by(|(_, row_a), (_, row_b)| {
130 for order_expr in order_by {
131 let val_a = order_expr
132 .expr
133 .eval(&row_a.values, col_map)
134 .unwrap_or(Value::Null);
135 let val_b = order_expr
136 .expr
137 .eval(&row_b.values, col_map)
138 .unwrap_or(Value::Null);
139
140 let cmp = match order_expr.order {
141 WindowSortOrder::Asc => val_a.cmp(&val_b),
142 WindowSortOrder::Desc => val_b.cmp(&val_a),
143 };
144
145 if cmp != std::cmp::Ordering::Equal {
146 return cmp;
147 }
148 }
149 std::cmp::Ordering::Equal
150 });
151
152 Ok(())
153 }
154
155 fn compute_window_values(
157 partition: &[(usize, &Row)],
158 function: &WindowFunctionType,
159 frame: Option<&WindowFrame>,
160 order_by: &[WindowOrderByExpr],
161 col_map: &HashMap<String, usize>,
162 ) -> Result<Vec<Value>> {
163 let n = partition.len();
164 let mut values = Vec::with_capacity(n);
165
166 for (pos, (_, row)) in partition.iter().enumerate() {
167 let value = match function {
168 WindowFunctionType::RowNumber => Value::Integer((pos + 1) as i64),
169
170 WindowFunctionType::Rank => {
171 Self::compute_rank(partition, pos, order_by, col_map, false)?
172 }
173
174 WindowFunctionType::DenseRank => {
175 Self::compute_rank(partition, pos, order_by, col_map, true)?
176 }
177
178 WindowFunctionType::NTile(num_buckets) => {
179 let bucket = ((pos as u32 * *num_buckets) / n as u32) + 1;
180 Value::Integer(bucket as i64)
181 }
182
183 WindowFunctionType::Lag {
184 expr,
185 offset,
186 default,
187 } => {
188 let target_pos = pos as i64 - *offset;
189 if target_pos >= 0 && (target_pos as usize) < n {
190 let target_row = partition[target_pos as usize].1;
191 expr.eval(&target_row.values, col_map)?
192 } else {
193 default
194 .as_ref()
195 .map(|d| d.eval(&row.values, col_map))
196 .transpose()?
197 .unwrap_or(Value::Null)
198 }
199 }
200
201 WindowFunctionType::Lead {
202 expr,
203 offset,
204 default,
205 } => {
206 let target_pos = pos as i64 + *offset;
207 if target_pos >= 0 && (target_pos as usize) < n {
208 let target_row = partition[target_pos as usize].1;
209 expr.eval(&target_row.values, col_map)?
210 } else {
211 default
212 .as_ref()
213 .map(|d| d.eval(&row.values, col_map))
214 .transpose()?
215 .unwrap_or(Value::Null)
216 }
217 }
218
219 WindowFunctionType::FirstValue(expr) => {
220 let (start, _) = Self::compute_frame_bounds(pos, n, frame);
221 if start < n {
222 let first_row = partition[start].1;
223 expr.eval(&first_row.values, col_map)?
224 } else {
225 Value::Null
226 }
227 }
228
229 WindowFunctionType::LastValue(expr) => {
230 let (_, end) = Self::compute_frame_bounds(pos, n, frame);
231 if end > 0 && end <= n {
232 let last_row = partition[end - 1].1;
233 expr.eval(&last_row.values, col_map)?
234 } else {
235 Value::Null
236 }
237 }
238
239 WindowFunctionType::NthValue(expr, nth) => {
240 let (start, end) = Self::compute_frame_bounds(pos, n, frame);
241 let target = start + (*nth as usize) - 1;
242 if target < end && target < n {
243 let target_row = partition[target].1;
244 expr.eval(&target_row.values, col_map)?
245 } else {
246 Value::Null
247 }
248 }
249
250 WindowFunctionType::Sum(expr) => {
252 Self::compute_aggregate_sum(partition, pos, expr, frame, col_map)?
253 }
254
255 WindowFunctionType::Avg(expr) => {
256 Self::compute_aggregate_avg(partition, pos, expr, frame, col_map)?
257 }
258
259 WindowFunctionType::Count(expr) => {
260 Self::compute_aggregate_count(partition, pos, expr.as_deref(), frame, col_map)?
261 }
262
263 WindowFunctionType::Min(expr) => {
264 Self::compute_aggregate_min(partition, pos, expr, frame, col_map)?
265 }
266
267 WindowFunctionType::Max(expr) => {
268 Self::compute_aggregate_max(partition, pos, expr, frame, col_map)?
269 }
270 };
271
272 values.push(value);
273 }
274
275 Ok(values)
276 }
277
278 fn order_by_values_equal(
280 row_a: &Row,
281 row_b: &Row,
282 order_by: &[WindowOrderByExpr],
283 col_map: &HashMap<String, usize>,
284 ) -> bool {
285 for order_expr in order_by {
286 let val_a = order_expr
287 .expr
288 .eval(&row_a.values, col_map)
289 .unwrap_or(Value::Null);
290 let val_b = order_expr
291 .expr
292 .eval(&row_b.values, col_map)
293 .unwrap_or(Value::Null);
294 if val_a != val_b {
295 return false;
296 }
297 }
298 true
299 }
300
301 fn compute_rank(
303 partition: &[(usize, &Row)],
304 pos: usize,
305 order_by: &[WindowOrderByExpr],
306 col_map: &HashMap<String, usize>,
307 dense: bool,
308 ) -> Result<Value> {
309 if pos == 0 || order_by.is_empty() {
310 return Ok(Value::Integer(1));
311 }
312
313 if dense {
315 let mut dense_rank = 1i64;
317 for i in 1..=pos {
318 if !Self::order_by_values_equal(
319 partition[i].1,
320 partition[i - 1].1,
321 order_by,
322 col_map,
323 ) {
324 dense_rank += 1;
325 }
326 }
327 Ok(Value::Integer(dense_rank))
328 } else {
329 let mut rank_start = pos;
331 while rank_start > 0
332 && Self::order_by_values_equal(
333 partition[rank_start].1,
334 partition[rank_start - 1].1,
335 order_by,
336 col_map,
337 )
338 {
339 rank_start -= 1;
340 }
341 Ok(Value::Integer((rank_start + 1) as i64))
342 }
343 }
344
345 fn compute_frame_bounds(
347 pos: usize,
348 partition_size: usize,
349 frame: Option<&WindowFrame>,
350 ) -> (usize, usize) {
351 let frame = match frame {
352 Some(f) => f,
353 None => {
354 return (0, pos + 1);
356 }
357 };
358
359 let start = match &frame.start {
360 FrameBound::UnboundedPreceding => 0,
361 FrameBound::Preceding(n) => pos.saturating_sub(*n as usize),
362 FrameBound::CurrentRow => pos,
363 FrameBound::Following(n) => (pos + *n as usize).min(partition_size),
364 FrameBound::UnboundedFollowing => partition_size,
365 };
366
367 let end = match &frame.end {
368 FrameBound::UnboundedPreceding => 0,
369 FrameBound::Preceding(n) => pos.saturating_sub(*n as usize),
370 FrameBound::CurrentRow => pos + 1,
371 FrameBound::Following(n) => (pos + *n as usize + 1).min(partition_size),
372 FrameBound::UnboundedFollowing => partition_size,
373 };
374
375 (start, end)
376 }
377
378 fn compute_aggregate_sum(
380 partition: &[(usize, &Row)],
381 pos: usize,
382 expr: &Expr,
383 frame: Option<&WindowFrame>,
384 col_map: &HashMap<String, usize>,
385 ) -> Result<Value> {
386 let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
387 let mut sum = 0.0;
388 let mut has_value = false;
389
390 for (_, row) in partition.iter().skip(start).take(end - start) {
391 let val = expr.eval(&row.values, col_map)?;
392 if let Some(n) = val.as_f64() {
393 sum += n;
394 has_value = true;
395 }
396 }
397
398 if has_value {
399 Ok(Value::Real(sum))
400 } else {
401 Ok(Value::Null)
402 }
403 }
404
405 fn compute_aggregate_avg(
407 partition: &[(usize, &Row)],
408 pos: usize,
409 expr: &Expr,
410 frame: Option<&WindowFrame>,
411 col_map: &HashMap<String, usize>,
412 ) -> Result<Value> {
413 let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
414 let mut sum = 0.0;
415 let mut count = 0;
416
417 for (_, row) in partition.iter().skip(start).take(end - start) {
418 let val = expr.eval(&row.values, col_map)?;
419 if let Some(n) = val.as_f64() {
420 sum += n;
421 count += 1;
422 }
423 }
424
425 if count > 0 {
426 Ok(Value::Real(sum / count as f64))
427 } else {
428 Ok(Value::Null)
429 }
430 }
431
432 fn compute_aggregate_count(
434 partition: &[(usize, &Row)],
435 pos: usize,
436 expr: Option<&Expr>,
437 frame: Option<&WindowFrame>,
438 col_map: &HashMap<String, usize>,
439 ) -> Result<Value> {
440 let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
441 let mut count = 0i64;
442
443 for (_, row) in partition.iter().skip(start).take(end - start) {
444 match expr {
445 Some(e) => {
446 let val = e.eval(&row.values, col_map)?;
447 if !val.is_null() {
448 count += 1;
449 }
450 }
451 None => {
452 count += 1;
454 }
455 }
456 }
457
458 Ok(Value::Integer(count))
459 }
460
461 fn compute_aggregate_min(
463 partition: &[(usize, &Row)],
464 pos: usize,
465 expr: &Expr,
466 frame: Option<&WindowFrame>,
467 col_map: &HashMap<String, usize>,
468 ) -> Result<Value> {
469 let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
470 let mut min: Option<Value> = None;
471
472 for (_, row) in partition.iter().skip(start).take(end - start) {
473 let val = expr.eval(&row.values, col_map)?;
474 if !val.is_null() {
475 min = Some(match min {
476 Some(m) if val < m => val,
477 Some(m) => m,
478 None => val,
479 });
480 }
481 }
482
483 Ok(min.unwrap_or(Value::Null))
484 }
485
486 fn compute_aggregate_max(
488 partition: &[(usize, &Row)],
489 pos: usize,
490 expr: &Expr,
491 frame: Option<&WindowFrame>,
492 col_map: &HashMap<String, usize>,
493 ) -> Result<Value> {
494 let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
495 let mut max: Option<Value> = None;
496
497 for (_, row) in partition.iter().skip(start).take(end - start) {
498 let val = expr.eval(&row.values, col_map)?;
499 if !val.is_null() {
500 max = Some(match max {
501 Some(m) if val > m => val,
502 Some(m) => m,
503 None => val,
504 });
505 }
506 }
507
508 Ok(max.unwrap_or(Value::Null))
509 }
510}
511
/// Maps column names to their positions within a row.
///
/// Every column is reachable by its full name. A qualified column such as
/// `"t.id"` is also reachable by its last `.`-separated segment (`"id"`).
/// Exact names take precedence over these short aliases, so an unqualified
/// `"id"` column is never shadowed by a qualified `"t.id"`. Among duplicates
/// of the same kind, the later column wins.
fn build_col_map(columns: &[String]) -> HashMap<String, usize> {
    let mut map = HashMap::with_capacity(columns.len() * 2);

    // Short aliases go in first, so the exact-name pass below can override them.
    for (i, col) in columns.iter().enumerate() {
        if let Some((_, short)) = col.rsplit_once('.') {
            map.insert(short.to_string(), i);
        }
    }
    for (i, col) in columns.iter().enumerate() {
        map.insert(col.clone(), i);
    }

    map
}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FrameBound, FrameUnit};

    fn make_row(values: Vec<Value>, columns: Vec<&str>) -> Row {
        Row::new(values, columns.into_iter().map(|s| s.to_string()).collect())
    }

    /// Builds an unqualified column reference.
    fn col(name: &str) -> Expr {
        Expr::Column {
            table: None,
            name: name.into(),
            index: None,
        }
    }

    /// Builds an ascending `ORDER BY` entry on column `name`.
    fn asc(name: &str) -> WindowOrderByExpr {
        WindowOrderByExpr {
            expr: col(name),
            order: WindowSortOrder::Asc,
            nulls_first: None,
        }
    }

    #[test]
    fn test_row_number() {
        let rows = vec![
            make_row(
                vec![Value::Integer(1), Value::Text("Alice".into())],
                vec!["id", "name"],
            ),
            make_row(
                vec![Value::Integer(2), Value::Text("Bob".into())],
                vec!["id", "name"],
            ),
            make_row(
                vec![Value::Integer(3), Value::Text("Carol".into())],
                vec!["id", "name"],
            ),
        ];

        let window_func = WindowFunction {
            function: WindowFunctionType::RowNumber,
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: None,
        };

        let window_exprs = vec![(Expr::Window(window_func), "row_num".to_string())];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values.last(), Some(&Value::Integer(1)));
        assert_eq!(result[1].values.last(), Some(&Value::Integer(2)));
        assert_eq!(result[2].values.last(), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_partition_by() {
        let rows = vec![
            make_row(
                vec![Value::Text("A".into()), Value::Integer(10)],
                vec!["dept", "salary"],
            ),
            make_row(
                vec![Value::Text("A".into()), Value::Integer(20)],
                vec!["dept", "salary"],
            ),
            make_row(
                vec![Value::Text("B".into()), Value::Integer(15)],
                vec!["dept", "salary"],
            ),
            make_row(
                vec![Value::Text("B".into()), Value::Integer(25)],
                vec!["dept", "salary"],
            ),
        ];

        let window_func = WindowFunction {
            function: WindowFunctionType::RowNumber,
            partition_by: vec![col("dept")],
            order_by: vec![asc("salary")],
            frame: None,
        };

        let window_exprs = vec![(Expr::Window(window_func), "row_num".to_string())];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        assert_eq!(result.len(), 4);
        // Row numbering restarts in each department; input order is preserved.
        assert_eq!(result[0].values.last(), Some(&Value::Integer(1)));
        assert_eq!(result[1].values.last(), Some(&Value::Integer(2)));
        assert_eq!(result[2].values.last(), Some(&Value::Integer(1)));
        assert_eq!(result[3].values.last(), Some(&Value::Integer(2)));
        // The alias is appended as the new trailing column name.
        for row in &result {
            assert_eq!(row.columns.last().map(String::as_str), Some("row_num"));
        }
    }

    #[test]
    fn test_running_sum() {
        let rows = vec![
            make_row(
                vec![Value::Integer(1), Value::Integer(100)],
                vec!["id", "amount"],
            ),
            make_row(
                vec![Value::Integer(2), Value::Integer(200)],
                vec!["id", "amount"],
            ),
            make_row(
                vec![Value::Integer(3), Value::Integer(150)],
                vec!["id", "amount"],
            ),
        ];

        let window_func = WindowFunction {
            function: WindowFunctionType::Sum(Box::new(col("amount"))),
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: Some(WindowFrame {
                unit: FrameUnit::Rows,
                start: FrameBound::UnboundedPreceding,
                end: FrameBound::CurrentRow,
            }),
        };

        let window_exprs = vec![(Expr::Window(window_func), "running_total".to_string())];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values.last(), Some(&Value::Real(100.0)));
        assert_eq!(result[1].values.last(), Some(&Value::Real(300.0)));
        assert_eq!(result[2].values.last(), Some(&Value::Real(450.0)));
    }

    #[test]
    fn test_lag_lead() {
        let rows = vec![
            make_row(
                vec![Value::Integer(1), Value::Integer(100)],
                vec!["id", "value"],
            ),
            make_row(
                vec![Value::Integer(2), Value::Integer(200)],
                vec!["id", "value"],
            ),
            make_row(
                vec![Value::Integer(3), Value::Integer(300)],
                vec!["id", "value"],
            ),
        ];

        let lag_func = WindowFunction {
            function: WindowFunctionType::Lag {
                expr: Box::new(col("value")),
                offset: 1,
                default: Some(Box::new(Expr::Literal(Value::Integer(0)))),
            },
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: None,
        };

        let lead_func = WindowFunction {
            function: WindowFunctionType::Lead {
                expr: Box::new(col("value")),
                offset: 1,
                default: None,
            },
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: None,
        };

        let window_exprs = vec![
            (Expr::Window(lag_func), "prev_value".to_string()),
            (Expr::Window(lead_func), "next_value".to_string()),
        ];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        assert_eq!(result.len(), 3);
        // Columns: id, value, prev_value, next_value.
        assert_eq!(result[0].values[2], Value::Integer(0));
        assert_eq!(result[1].values[2], Value::Integer(100));
        assert_eq!(result[2].values[2], Value::Integer(200));
        assert_eq!(result[0].values[3], Value::Integer(200));
        assert_eq!(result[1].values[3], Value::Integer(300));
        assert_eq!(result[2].values[3], Value::Null);
    }

    #[test]
    fn test_rank_and_dense_rank_with_ties() {
        let rows: Vec<Row> = [10, 20, 20, 30]
            .iter()
            .map(|&score| make_row(vec![Value::Integer(score)], vec!["score"]))
            .collect();

        let rank = WindowFunction {
            function: WindowFunctionType::Rank,
            partition_by: vec![],
            order_by: vec![asc("score")],
            frame: None,
        };
        let dense_rank = WindowFunction {
            function: WindowFunctionType::DenseRank,
            partition_by: vec![],
            order_by: vec![asc("score")],
            frame: None,
        };

        let window_exprs = vec![
            (Expr::Window(rank), "rnk".to_string()),
            (Expr::Window(dense_rank), "drnk".to_string()),
        ];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        let expected = [(1, 1), (2, 2), (2, 2), (4, 3)];
        assert_eq!(result.len(), expected.len());
        for (row, (rank, dense)) in result.iter().zip(expected) {
            assert_eq!(row.values[1], Value::Integer(rank));
            assert_eq!(row.values[2], Value::Integer(dense));
        }
    }
}
721}