Skip to main content

featherdb_query/
window.rs

//! Window function execution
//!
//! Implements SQL window functions (OVER clause) including:
//! - Ranking functions: ROW_NUMBER, RANK, DENSE_RANK, NTILE
//! - Navigation functions: LAG, LEAD, FIRST_VALUE, LAST_VALUE, NTH_VALUE
//! - Aggregate functions over windows: SUM, AVG, COUNT, MIN, MAX
7
8use crate::executor::Row;
9use crate::expr::{
10    Expr, FrameBound, WindowFrame, WindowFunction, WindowFunctionType, WindowOrderByExpr,
11    WindowSortOrder,
12};
13use featherdb_core::{Result, Value};
14use std::collections::HashMap;
15use std::sync::Arc;
16
17/// Window executor that computes window function results
18pub struct WindowExecutor;
19
20impl WindowExecutor {
21    /// Execute window functions on a set of rows
22    ///
23    /// For each window expression, computes the window function result for each row
24    /// and appends it to the row's values.
25    pub fn execute(rows: Vec<Row>, window_exprs: &[(Expr, String)]) -> Result<Vec<Row>> {
26        if rows.is_empty() {
27            return Ok(rows);
28        }
29
30        // Build column map from first row
31        let col_map = build_col_map(&rows[0].columns);
32
33        // Process each window expression
34        let mut result_rows = rows;
35
36        for (window_expr, alias) in window_exprs {
37            if let Expr::Window(window_func) = window_expr {
38                result_rows =
39                    Self::execute_window_function(result_rows, window_func, alias, &col_map)?;
40            }
41        }
42
43        Ok(result_rows)
44    }
45
46    /// Execute a single window function across all rows
47    fn execute_window_function(
48        mut rows: Vec<Row>,
49        window_func: &WindowFunction,
50        alias: &str,
51        col_map: &HashMap<String, usize>,
52    ) -> Result<Vec<Row>> {
53        // Partition the rows
54        let partitions = Self::partition_rows(&rows, &window_func.partition_by, col_map)?;
55
56        // For each partition, sort and compute window values
57        let mut window_values: Vec<(usize, Value)> = Vec::new();
58
59        for partition_indices in partitions {
60            // Get a view of partition rows with their original indices
61            let mut partition_rows: Vec<(usize, &Row)> =
62                partition_indices.iter().map(|&i| (i, &rows[i])).collect();
63
64            // Sort the partition by ORDER BY clause
65            if !window_func.order_by.is_empty() {
66                Self::sort_partition(&mut partition_rows, &window_func.order_by, col_map)?;
67            }
68
69            // Compute window function values for this partition
70            let values = Self::compute_window_values(
71                &partition_rows,
72                &window_func.function,
73                window_func.frame.as_ref(),
74                &window_func.order_by,
75                col_map,
76            )?;
77
78            // Map values back to original row indices
79            for ((orig_idx, _), value) in partition_rows.into_iter().zip(values) {
80                window_values.push((orig_idx, value));
81            }
82        }
83
84        // Sort window values by original index
85        window_values.sort_by_key(|(idx, _)| *idx);
86
87        // Add window values to rows
88        for (row, (_, value)) in rows.iter_mut().zip(window_values) {
89            row.values.push(value);
90            // Need to make columns mutable via Arc::make_mut
91            let cols = Arc::make_mut(&mut row.columns);
92            cols.push(alias.to_string());
93        }
94
95        Ok(rows)
96    }
97
98    /// Partition rows based on PARTITION BY expressions
99    fn partition_rows(
100        rows: &[Row],
101        partition_by: &[Expr],
102        col_map: &HashMap<String, usize>,
103    ) -> Result<Vec<Vec<usize>>> {
104        if partition_by.is_empty() {
105            // No partitioning - all rows in one partition
106            return Ok(vec![(0..rows.len()).collect()]);
107        }
108
109        let mut partitions: HashMap<Vec<Value>, Vec<usize>> = HashMap::new();
110
111        for (idx, row) in rows.iter().enumerate() {
112            let key: Vec<Value> = partition_by
113                .iter()
114                .map(|expr| expr.eval(&row.values, col_map))
115                .collect::<Result<_>>()?;
116
117            partitions.entry(key).or_default().push(idx);
118        }
119
120        Ok(partitions.into_values().collect())
121    }
122
123    /// Sort partition rows by ORDER BY expressions
124    fn sort_partition(
125        partition: &mut [(usize, &Row)],
126        order_by: &[WindowOrderByExpr],
127        col_map: &HashMap<String, usize>,
128    ) -> Result<()> {
129        partition.sort_by(|(_, row_a), (_, row_b)| {
130            for order_expr in order_by {
131                let val_a = order_expr
132                    .expr
133                    .eval(&row_a.values, col_map)
134                    .unwrap_or(Value::Null);
135                let val_b = order_expr
136                    .expr
137                    .eval(&row_b.values, col_map)
138                    .unwrap_or(Value::Null);
139
140                let cmp = match order_expr.order {
141                    WindowSortOrder::Asc => val_a.cmp(&val_b),
142                    WindowSortOrder::Desc => val_b.cmp(&val_a),
143                };
144
145                if cmp != std::cmp::Ordering::Equal {
146                    return cmp;
147                }
148            }
149            std::cmp::Ordering::Equal
150        });
151
152        Ok(())
153    }
154
155    /// Compute window function values for a sorted partition
156    fn compute_window_values(
157        partition: &[(usize, &Row)],
158        function: &WindowFunctionType,
159        frame: Option<&WindowFrame>,
160        order_by: &[WindowOrderByExpr],
161        col_map: &HashMap<String, usize>,
162    ) -> Result<Vec<Value>> {
163        let n = partition.len();
164        let mut values = Vec::with_capacity(n);
165
166        for (pos, (_, row)) in partition.iter().enumerate() {
167            let value = match function {
168                WindowFunctionType::RowNumber => Value::Integer((pos + 1) as i64),
169
170                WindowFunctionType::Rank => {
171                    Self::compute_rank(partition, pos, order_by, col_map, false)?
172                }
173
174                WindowFunctionType::DenseRank => {
175                    Self::compute_rank(partition, pos, order_by, col_map, true)?
176                }
177
178                WindowFunctionType::NTile(num_buckets) => {
179                    let bucket = ((pos as u32 * *num_buckets) / n as u32) + 1;
180                    Value::Integer(bucket as i64)
181                }
182
183                WindowFunctionType::Lag {
184                    expr,
185                    offset,
186                    default,
187                } => {
188                    let target_pos = pos as i64 - *offset;
189                    if target_pos >= 0 && (target_pos as usize) < n {
190                        let target_row = partition[target_pos as usize].1;
191                        expr.eval(&target_row.values, col_map)?
192                    } else {
193                        default
194                            .as_ref()
195                            .map(|d| d.eval(&row.values, col_map))
196                            .transpose()?
197                            .unwrap_or(Value::Null)
198                    }
199                }
200
201                WindowFunctionType::Lead {
202                    expr,
203                    offset,
204                    default,
205                } => {
206                    let target_pos = pos as i64 + *offset;
207                    if target_pos >= 0 && (target_pos as usize) < n {
208                        let target_row = partition[target_pos as usize].1;
209                        expr.eval(&target_row.values, col_map)?
210                    } else {
211                        default
212                            .as_ref()
213                            .map(|d| d.eval(&row.values, col_map))
214                            .transpose()?
215                            .unwrap_or(Value::Null)
216                    }
217                }
218
219                WindowFunctionType::FirstValue(expr) => {
220                    let (start, _) = Self::compute_frame_bounds(pos, n, frame);
221                    if start < n {
222                        let first_row = partition[start].1;
223                        expr.eval(&first_row.values, col_map)?
224                    } else {
225                        Value::Null
226                    }
227                }
228
229                WindowFunctionType::LastValue(expr) => {
230                    let (_, end) = Self::compute_frame_bounds(pos, n, frame);
231                    if end > 0 && end <= n {
232                        let last_row = partition[end - 1].1;
233                        expr.eval(&last_row.values, col_map)?
234                    } else {
235                        Value::Null
236                    }
237                }
238
239                WindowFunctionType::NthValue(expr, nth) => {
240                    let (start, end) = Self::compute_frame_bounds(pos, n, frame);
241                    let target = start + (*nth as usize) - 1;
242                    if target < end && target < n {
243                        let target_row = partition[target].1;
244                        expr.eval(&target_row.values, col_map)?
245                    } else {
246                        Value::Null
247                    }
248                }
249
250                // Aggregate functions over window
251                WindowFunctionType::Sum(expr) => {
252                    Self::compute_aggregate_sum(partition, pos, expr, frame, col_map)?
253                }
254
255                WindowFunctionType::Avg(expr) => {
256                    Self::compute_aggregate_avg(partition, pos, expr, frame, col_map)?
257                }
258
259                WindowFunctionType::Count(expr) => {
260                    Self::compute_aggregate_count(partition, pos, expr.as_deref(), frame, col_map)?
261                }
262
263                WindowFunctionType::Min(expr) => {
264                    Self::compute_aggregate_min(partition, pos, expr, frame, col_map)?
265                }
266
267                WindowFunctionType::Max(expr) => {
268                    Self::compute_aggregate_max(partition, pos, expr, frame, col_map)?
269                }
270            };
271
272            values.push(value);
273        }
274
275        Ok(values)
276    }
277
278    /// Compare ORDER BY values of two rows for tie detection
279    fn order_by_values_equal(
280        row_a: &Row,
281        row_b: &Row,
282        order_by: &[WindowOrderByExpr],
283        col_map: &HashMap<String, usize>,
284    ) -> bool {
285        for order_expr in order_by {
286            let val_a = order_expr
287                .expr
288                .eval(&row_a.values, col_map)
289                .unwrap_or(Value::Null);
290            let val_b = order_expr
291                .expr
292                .eval(&row_b.values, col_map)
293                .unwrap_or(Value::Null);
294            if val_a != val_b {
295                return false;
296            }
297        }
298        true
299    }
300
301    /// Compute RANK or DENSE_RANK with proper tie detection
302    fn compute_rank(
303        partition: &[(usize, &Row)],
304        pos: usize,
305        order_by: &[WindowOrderByExpr],
306        col_map: &HashMap<String, usize>,
307        dense: bool,
308    ) -> Result<Value> {
309        if pos == 0 || order_by.is_empty() {
310            return Ok(Value::Integer(1));
311        }
312
313        // Walk backwards to find the rank
314        if dense {
315            // DENSE_RANK: count distinct ORDER BY value groups before this position
316            let mut dense_rank = 1i64;
317            for i in 1..=pos {
318                if !Self::order_by_values_equal(
319                    partition[i].1,
320                    partition[i - 1].1,
321                    order_by,
322                    col_map,
323                ) {
324                    dense_rank += 1;
325                }
326            }
327            Ok(Value::Integer(dense_rank))
328        } else {
329            // RANK: position of first row with same ORDER BY values (1-based)
330            let mut rank_start = pos;
331            while rank_start > 0
332                && Self::order_by_values_equal(
333                    partition[rank_start].1,
334                    partition[rank_start - 1].1,
335                    order_by,
336                    col_map,
337                )
338            {
339                rank_start -= 1;
340            }
341            Ok(Value::Integer((rank_start + 1) as i64))
342        }
343    }
344
345    /// Compute frame bounds for the current position
346    fn compute_frame_bounds(
347        pos: usize,
348        partition_size: usize,
349        frame: Option<&WindowFrame>,
350    ) -> (usize, usize) {
351        let frame = match frame {
352            Some(f) => f,
353            None => {
354                // Default frame: RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
355                return (0, pos + 1);
356            }
357        };
358
359        let start = match &frame.start {
360            FrameBound::UnboundedPreceding => 0,
361            FrameBound::Preceding(n) => pos.saturating_sub(*n as usize),
362            FrameBound::CurrentRow => pos,
363            FrameBound::Following(n) => (pos + *n as usize).min(partition_size),
364            FrameBound::UnboundedFollowing => partition_size,
365        };
366
367        let end = match &frame.end {
368            FrameBound::UnboundedPreceding => 0,
369            FrameBound::Preceding(n) => pos.saturating_sub(*n as usize),
370            FrameBound::CurrentRow => pos + 1,
371            FrameBound::Following(n) => (pos + *n as usize + 1).min(partition_size),
372            FrameBound::UnboundedFollowing => partition_size,
373        };
374
375        (start, end)
376    }
377
378    /// Compute SUM aggregate over window frame
379    fn compute_aggregate_sum(
380        partition: &[(usize, &Row)],
381        pos: usize,
382        expr: &Expr,
383        frame: Option<&WindowFrame>,
384        col_map: &HashMap<String, usize>,
385    ) -> Result<Value> {
386        let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
387        let mut sum = 0.0;
388        let mut has_value = false;
389
390        for (_, row) in partition.iter().skip(start).take(end - start) {
391            let val = expr.eval(&row.values, col_map)?;
392            if let Some(n) = val.as_f64() {
393                sum += n;
394                has_value = true;
395            }
396        }
397
398        if has_value {
399            Ok(Value::Real(sum))
400        } else {
401            Ok(Value::Null)
402        }
403    }
404
405    /// Compute AVG aggregate over window frame
406    fn compute_aggregate_avg(
407        partition: &[(usize, &Row)],
408        pos: usize,
409        expr: &Expr,
410        frame: Option<&WindowFrame>,
411        col_map: &HashMap<String, usize>,
412    ) -> Result<Value> {
413        let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
414        let mut sum = 0.0;
415        let mut count = 0;
416
417        for (_, row) in partition.iter().skip(start).take(end - start) {
418            let val = expr.eval(&row.values, col_map)?;
419            if let Some(n) = val.as_f64() {
420                sum += n;
421                count += 1;
422            }
423        }
424
425        if count > 0 {
426            Ok(Value::Real(sum / count as f64))
427        } else {
428            Ok(Value::Null)
429        }
430    }
431
432    /// Compute COUNT aggregate over window frame
433    fn compute_aggregate_count(
434        partition: &[(usize, &Row)],
435        pos: usize,
436        expr: Option<&Expr>,
437        frame: Option<&WindowFrame>,
438        col_map: &HashMap<String, usize>,
439    ) -> Result<Value> {
440        let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
441        let mut count = 0i64;
442
443        for (_, row) in partition.iter().skip(start).take(end - start) {
444            match expr {
445                Some(e) => {
446                    let val = e.eval(&row.values, col_map)?;
447                    if !val.is_null() {
448                        count += 1;
449                    }
450                }
451                None => {
452                    // COUNT(*) - count all rows
453                    count += 1;
454                }
455            }
456        }
457
458        Ok(Value::Integer(count))
459    }
460
461    /// Compute MIN aggregate over window frame
462    fn compute_aggregate_min(
463        partition: &[(usize, &Row)],
464        pos: usize,
465        expr: &Expr,
466        frame: Option<&WindowFrame>,
467        col_map: &HashMap<String, usize>,
468    ) -> Result<Value> {
469        let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
470        let mut min: Option<Value> = None;
471
472        for (_, row) in partition.iter().skip(start).take(end - start) {
473            let val = expr.eval(&row.values, col_map)?;
474            if !val.is_null() {
475                min = Some(match min {
476                    Some(m) if val < m => val,
477                    Some(m) => m,
478                    None => val,
479                });
480            }
481        }
482
483        Ok(min.unwrap_or(Value::Null))
484    }
485
486    /// Compute MAX aggregate over window frame
487    fn compute_aggregate_max(
488        partition: &[(usize, &Row)],
489        pos: usize,
490        expr: &Expr,
491        frame: Option<&WindowFrame>,
492        col_map: &HashMap<String, usize>,
493    ) -> Result<Value> {
494        let (start, end) = Self::compute_frame_bounds(pos, partition.len(), frame);
495        let mut max: Option<Value> = None;
496
497        for (_, row) in partition.iter().skip(start).take(end - start) {
498            let val = expr.eval(&row.values, col_map)?;
499            if !val.is_null() {
500                max = Some(match max {
501                    Some(m) if val > m => val,
502                    Some(m) => m,
503                    None => val,
504                });
505            }
506        }
507
508        Ok(max.unwrap_or(Value::Null))
509    }
510}
511
/// Map each column name to its position in `columns`.
///
/// Every name is registered as given (e.g. `"t.id"`). For dotted names, the part after
/// the last `.` (e.g. `"id"`) is registered as well. When several columns produce the
/// same key, the column that appears last wins.
fn build_col_map(columns: &[String]) -> HashMap<String, usize> {
    let mut lookup = HashMap::new();
    for (position, full_name) in columns.iter().enumerate() {
        lookup.insert(full_name.clone(), position);
        // An undotted name is its own unqualified form and was inserted just above.
        if let Some((_, bare)) = full_name.rsplit_once('.') {
            lookup.insert(bare.to_string(), position);
        }
    }
    lookup
}
524
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{FrameBound, FrameUnit};

    /// Build a row from values and column names.
    fn make_row(values: Vec<Value>, columns: Vec<&str>) -> Row {
        Row::new(values, columns.into_iter().map(|s| s.to_string()).collect())
    }

    /// Unqualified column reference.
    fn col(name: &'static str) -> Expr {
        Expr::Column {
            table: None,
            name: name.into(),
            index: None,
        }
    }

    /// Ascending ORDER BY item with the default NULL placement.
    fn asc(name: &'static str) -> WindowOrderByExpr {
        WindowOrderByExpr {
            expr: col(name),
            order: WindowSortOrder::Asc,
            nulls_first: None,
        }
    }

    /// Rows of `(dept, salary)` with columns `["dept", "salary"]`.
    fn dept_rows(data: &[(&str, i64)]) -> Vec<Row> {
        data.iter()
            .map(|(dept, salary)| {
                make_row(
                    vec![Value::Text((*dept).into()), Value::Integer(*salary)],
                    vec!["dept", "salary"],
                )
            })
            .collect()
    }

    #[test]
    fn test_row_number() {
        let rows = vec![
            make_row(
                vec![Value::Integer(1), Value::Text("Alice".into())],
                vec!["id", "name"],
            ),
            make_row(
                vec![Value::Integer(2), Value::Text("Bob".into())],
                vec!["id", "name"],
            ),
            make_row(
                vec![Value::Integer(3), Value::Text("Carol".into())],
                vec!["id", "name"],
            ),
        ];

        let window_func = WindowFunction {
            function: WindowFunctionType::RowNumber,
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: None,
        };

        let window_exprs = vec![(Expr::Window(window_func), "row_num".to_string())];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values.last(), Some(&Value::Integer(1)));
        assert_eq!(result[1].values.last(), Some(&Value::Integer(2)));
        assert_eq!(result[2].values.last(), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_partition_by() {
        let rows = dept_rows(&[("A", 10), ("A", 20), ("B", 15), ("B", 25)]);

        let window_func = WindowFunction {
            function: WindowFunctionType::RowNumber,
            partition_by: vec![col("dept")],
            order_by: vec![asc("salary")],
            frame: None,
        };

        let window_exprs = vec![(Expr::Window(window_func), "row_num".to_string())];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        // Rows stay in input order; numbering restarts in each department.
        assert_eq!(result.len(), 4);
        assert_eq!(result[0].values.last(), Some(&Value::Integer(1))); // A, 10
        assert_eq!(result[1].values.last(), Some(&Value::Integer(2))); // A, 20
        assert_eq!(result[2].values.last(), Some(&Value::Integer(1))); // B, 15
        assert_eq!(result[3].values.last(), Some(&Value::Integer(2))); // B, 25
    }

    #[test]
    fn test_running_sum() {
        let rows = vec![
            make_row(
                vec![Value::Integer(1), Value::Integer(100)],
                vec!["id", "amount"],
            ),
            make_row(
                vec![Value::Integer(2), Value::Integer(200)],
                vec!["id", "amount"],
            ),
            make_row(
                vec![Value::Integer(3), Value::Integer(150)],
                vec!["id", "amount"],
            ),
        ];

        let window_func = WindowFunction {
            function: WindowFunctionType::Sum(Box::new(col("amount"))),
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: Some(WindowFrame {
                unit: FrameUnit::Rows,
                start: FrameBound::UnboundedPreceding,
                end: FrameBound::CurrentRow,
            }),
        };

        let window_exprs = vec![(Expr::Window(window_func), "running_total".to_string())];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values.last(), Some(&Value::Real(100.0)));
        assert_eq!(result[1].values.last(), Some(&Value::Real(300.0)));
        assert_eq!(result[2].values.last(), Some(&Value::Real(450.0)));
    }

    #[test]
    fn test_lag_lead() {
        let rows = vec![
            make_row(
                vec![Value::Integer(1), Value::Integer(100)],
                vec!["id", "value"],
            ),
            make_row(
                vec![Value::Integer(2), Value::Integer(200)],
                vec!["id", "value"],
            ),
            make_row(
                vec![Value::Integer(3), Value::Integer(300)],
                vec!["id", "value"],
            ),
        ];

        let lag_func = WindowFunction {
            function: WindowFunctionType::Lag {
                expr: Box::new(col("value")),
                offset: 1,
                default: Some(Box::new(Expr::Literal(Value::Integer(0)))),
            },
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: None,
        };

        let lead_func = WindowFunction {
            function: WindowFunctionType::Lead {
                expr: Box::new(col("value")),
                offset: 1,
                default: None,
            },
            partition_by: vec![],
            order_by: vec![asc("id")],
            frame: None,
        };

        let window_exprs = vec![
            (Expr::Window(lag_func), "prev_value".to_string()),
            (Expr::Window(lead_func), "next_value".to_string()),
        ];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        // Columns: id, value, prev_value, next_value
        assert_eq!(result.len(), 3);
        assert_eq!(result[0].values[2], Value::Integer(0)); // LAG default for first row
        assert_eq!(result[1].values[2], Value::Integer(100));
        assert_eq!(result[2].values[2], Value::Integer(200));
        assert_eq!(result[0].values[3], Value::Integer(200));
        assert_eq!(result[1].values[3], Value::Integer(300));
        assert_eq!(result[2].values[3], Value::Null); // no LEAD default for last row
    }

    #[test]
    fn test_rank_and_dense_rank_with_ties() {
        let rows: Vec<Row> = [20, 10, 30, 20]
            .iter()
            .enumerate()
            .map(|(i, score)| {
                make_row(
                    vec![Value::Integer(i as i64 + 1), Value::Integer(*score)],
                    vec!["id", "score"],
                )
            })
            .collect();

        let rank = WindowFunction {
            function: WindowFunctionType::Rank,
            partition_by: vec![],
            order_by: vec![asc("score")],
            frame: None,
        };
        let dense_rank = WindowFunction {
            function: WindowFunctionType::DenseRank,
            partition_by: vec![],
            order_by: vec![asc("score")],
            frame: None,
        };

        let window_exprs = vec![
            (Expr::Window(rank), "rnk".to_string()),
            (Expr::Window(dense_rank), "dense_rnk".to_string()),
        ];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        // Sorted scores: 10, 20, 20, 30 -> RANK 1, 2, 2, 4 and DENSE_RANK 1, 2, 2, 3.
        let ranks: Vec<&Value> = result.iter().map(|r| &r.values[2]).collect();
        let dense: Vec<&Value> = result.iter().map(|r| &r.values[3]).collect();
        assert_eq!(
            ranks,
            vec![
                &Value::Integer(2),
                &Value::Integer(1),
                &Value::Integer(4),
                &Value::Integer(2)
            ]
        );
        assert_eq!(
            dense,
            vec![
                &Value::Integer(2),
                &Value::Integer(1),
                &Value::Integer(3),
                &Value::Integer(2)
            ]
        );
    }

    #[test]
    fn test_whole_partition_aggregates() {
        let rows = dept_rows(&[("A", 10), ("A", 20), ("B", 15)]);

        let whole_partition = || WindowFrame {
            unit: FrameUnit::Rows,
            start: FrameBound::UnboundedPreceding,
            end: FrameBound::UnboundedFollowing,
        };

        let count_all = WindowFunction {
            function: WindowFunctionType::Count(None),
            partition_by: vec![col("dept")],
            order_by: vec![asc("salary")],
            frame: Some(whole_partition()),
        };
        let max_salary = WindowFunction {
            function: WindowFunctionType::Max(Box::new(col("salary"))),
            partition_by: vec![col("dept")],
            order_by: vec![asc("salary")],
            frame: Some(whole_partition()),
        };

        let window_exprs = vec![
            (Expr::Window(count_all), "cnt".to_string()),
            (Expr::Window(max_salary), "max_salary".to_string()),
        ];
        let result = WindowExecutor::execute(rows, &window_exprs).unwrap();

        // Columns: dept, salary, cnt, max_salary
        assert_eq!(result[0].values[2], Value::Integer(2));
        assert_eq!(result[1].values[2], Value::Integer(2));
        assert_eq!(result[2].values[2], Value::Integer(1));
        assert_eq!(result[0].values[3], Value::Integer(20));
        assert_eq!(result[1].values[3], Value::Integer(20));
        assert_eq!(result[2].values[3], Value::Integer(15));
    }

    #[test]
    fn test_empty_input() {
        let window_func = WindowFunction {
            function: WindowFunctionType::RowNumber,
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        };
        let window_exprs = vec![(Expr::Window(window_func), "row_num".to_string())];

        let result = WindowExecutor::execute(Vec::new(), &window_exprs).unwrap();
        assert!(result.is_empty());
    }
}