Skip to main content

fraiseql_core/runtime/
window.rs

1//! Window Function SQL Generation
2//!
3//! Generates database-specific SQL for window functions.
4//!
5//! # Supported Databases
6//!
7//! - **PostgreSQL**: Full support (all functions + GROUPS frames + frame exclusion)
8//! - **MySQL 8.0+**: Full support (no GROUPS, no frame exclusion)
9//! - **SQLite 3.25+**: Basic support (no GROUPS, no `PERCENT_RANK/CUME_DIST`)
10//! - **SQL Server**: Full support (STDEV/VAR naming difference)
11
12use std::fmt::Write as _;
13
14use crate::{
15    compiler::{
16        aggregation::OrderDirection,
17        window_functions::{
18            FrameBoundary, FrameExclusion, FrameType, WindowExecutionPlan, WindowFrame,
19            WindowFunction, WindowFunctionType,
20        },
21    },
22    db::{GenericWhereGenerator, PostgresDialect, types::DatabaseType},
23    error::{FraiseQLError, Result},
24};
25
26/// Generated SQL for window function query
27#[derive(Debug, Clone)]
28pub struct WindowSql {
29    /// Parameterized SQL template. WHERE clause values use dialect-specific
30    /// placeholders (`$1`, `?`, `@P1`); column names are schema-derived and
31    /// allowlist-validated via [`crate::compiler::window_allowlist::WindowAllowlist`]
32    /// and are not user-controlled at runtime.
33    pub raw_sql: String,
34
35    /// Bind parameters in placeholder order, passed to
36    /// `execute_parameterized_aggregate`.
37    pub parameters: Vec<serde_json::Value>,
38}
39
40/// Window function SQL generator
41pub struct WindowSqlGenerator {
42    database_type: DatabaseType,
43}
44
45impl WindowSqlGenerator {
46    /// Create new generator for database type
47    #[must_use]
48    pub const fn new(database_type: DatabaseType) -> Self {
49        Self { database_type }
50    }
51
52    /// Generate SQL from window execution plan
53    ///
54    /// # Errors
55    ///
56    /// Returns error if:
57    /// - Unsupported function for database
58    /// - Invalid frame specification
59    /// - WHERE clause generation fails
60    pub fn generate(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
61        match self.database_type {
62            DatabaseType::PostgreSQL => self.generate_postgres(plan),
63            DatabaseType::MySQL => self.generate_mysql(plan),
64            DatabaseType::SQLite => self.generate_sqlite(plan),
65            DatabaseType::SQLServer => self.generate_sqlserver(plan),
66        }
67    }
68
69    /// Generate PostgreSQL window function SQL
70    fn generate_postgres(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
71        let mut sql = String::from("SELECT ");
72        let mut parameters = Vec::new();
73
74        // Add regular SELECT columns
75        for (i, col) in plan.select.iter().enumerate() {
76            if i > 0 {
77                sql.push_str(", ");
78            }
79            let _ = write!(sql, "{} AS {}", col.expression, col.alias);
80        }
81
82        // Add window functions
83        for window in &plan.windows {
84            if !plan.select.is_empty() || sql.len() > "SELECT ".len() {
85                sql.push_str(", ");
86            }
87            sql.push_str(&self.generate_window_function(window)?);
88        }
89
90        // FROM clause
91        let _ = write!(sql, " FROM {}", plan.table);
92
93        // WHERE clause (if any) — use parameterized generation to avoid literal
94        // value escaping and enable the database to cache execution plans.
95        if let Some(clause) = &plan.where_clause {
96            let gen = GenericWhereGenerator::new(PostgresDialect);
97            let (where_sql, where_params) = gen.generate(clause)?;
98            sql.push_str(" WHERE ");
99            sql.push_str(&where_sql);
100            parameters.extend(where_params);
101        }
102
103        // ORDER BY clause
104        if !plan.order_by.is_empty() {
105            sql.push_str(" ORDER BY ");
106            for (i, order) in plan.order_by.iter().enumerate() {
107                if i > 0 {
108                    sql.push_str(", ");
109                }
110                #[allow(clippy::match_same_arms)]
111                // Reason: non_exhaustive enum requires catch-all; explicit Asc arm documents intent
112                let dir = match order.direction {
113                    OrderDirection::Asc => "ASC",
114                    OrderDirection::Desc => "DESC",
115                    _ => "ASC",
116                };
117                // Fields in the outer ORDER BY may be JSONB path expressions
118                // (e.g. `data->>'category'`) or window aliases (e.g. `rank`); they
119                // are validated at planner parse time and must not be identifier-quoted.
120                let _ = write!(sql, "{} {}", order.field, dir);
121            }
122        }
123
124        // LIMIT / OFFSET
125        if let Some(limit) = plan.limit {
126            let _ = write!(sql, " LIMIT {limit}");
127        }
128        if let Some(offset) = plan.offset {
129            let _ = write!(sql, " OFFSET {offset}");
130        }
131
132        Ok(WindowSql {
133            raw_sql: sql,
134            parameters,
135        })
136    }
137
138    /// Generate window function expression
139    fn generate_window_function(&self, window: &WindowFunction) -> Result<String> {
140        let func_sql = self.generate_function_call(&window.function)?;
141        let mut sql = format!("{func_sql} OVER (");
142
143        // PARTITION BY — values are pre-validated SQL expressions (may be JSONB paths
144        // like `data->>'col'`); rejection of unsafe chars happens at planner parse time.
145        if !window.partition_by.is_empty() {
146            sql.push_str("PARTITION BY ");
147            sql.push_str(&window.partition_by.join(", "));
148        }
149
150        // ORDER BY — same: values validated at parse time, may be JSONB expressions.
151        if !window.order_by.is_empty() {
152            if !window.partition_by.is_empty() {
153                sql.push(' ');
154            }
155            sql.push_str("ORDER BY ");
156            for (i, order) in window.order_by.iter().enumerate() {
157                if i > 0 {
158                    sql.push_str(", ");
159                }
160                #[allow(clippy::match_same_arms)]
161                // Reason: non_exhaustive enum requires catch-all; explicit Asc arm documents intent
162                let dir = match order.direction {
163                    OrderDirection::Asc => "ASC",
164                    OrderDirection::Desc => "DESC",
165                    _ => "ASC",
166                };
167                let _ = write!(sql, "{} {}", order.field, dir);
168            }
169        }
170
171        // Frame clause
172        if let Some(frame) = &window.frame {
173            if !window.partition_by.is_empty() || !window.order_by.is_empty() {
174                sql.push(' ');
175            }
176            sql.push_str(&self.generate_frame_clause(frame)?);
177        }
178
179        sql.push(')');
180        let _ = write!(sql, " AS {}", window.alias);
181
182        Ok(sql)
183    }
184
185    /// Generate function call SQL
186    fn generate_function_call(&self, function: &WindowFunctionType) -> Result<String> {
187        let sql = match function {
188            WindowFunctionType::RowNumber => "ROW_NUMBER()".to_string(),
189            WindowFunctionType::Rank => "RANK()".to_string(),
190            WindowFunctionType::DenseRank => "DENSE_RANK()".to_string(),
191            WindowFunctionType::Ntile { n } => format!("NTILE({n})"),
192            WindowFunctionType::PercentRank => "PERCENT_RANK()".to_string(),
193            WindowFunctionType::CumeDist => "CUME_DIST()".to_string(),
194
195            WindowFunctionType::Lag {
196                field,
197                offset,
198                default,
199            } => {
200                if let Some(default_val) = default {
201                    format!("LAG({field}, {offset}, {default_val})")
202                } else {
203                    format!("LAG({field}, {offset})")
204                }
205            },
206            WindowFunctionType::Lead {
207                field,
208                offset,
209                default,
210            } => {
211                if let Some(default_val) = default {
212                    format!("LEAD({field}, {offset}, {default_val})")
213                } else {
214                    format!("LEAD({field}, {offset})")
215                }
216            },
217            WindowFunctionType::FirstValue { field } => format!("FIRST_VALUE({field})"),
218            WindowFunctionType::LastValue { field } => format!("LAST_VALUE({field})"),
219            WindowFunctionType::NthValue { field, n } => format!("NTH_VALUE({field}, {n})"),
220
221            WindowFunctionType::Sum { field } => format!("SUM({field})"),
222            WindowFunctionType::Avg { field } => format!("AVG({field})"),
223            WindowFunctionType::Count { field: Some(field) } => format!("COUNT({field})"),
224            WindowFunctionType::Count { field: None } => "COUNT(*)".to_string(),
225            WindowFunctionType::Min { field } => format!("MIN({field})"),
226            WindowFunctionType::Max { field } => format!("MAX({field})"),
227            WindowFunctionType::Stddev { field } => {
228                // PostgreSQL/MySQL use STDDEV, SQL Server uses STDEV
229                match self.database_type {
230                    DatabaseType::SQLServer => format!("STDEV({field})"),
231                    _ => format!("STDDEV({field})"),
232                }
233            },
234            WindowFunctionType::Variance { field } => {
235                // PostgreSQL/MySQL use VARIANCE, SQL Server uses VAR
236                match self.database_type {
237                    DatabaseType::SQLServer => format!("VAR({field})"),
238                    _ => format!("VARIANCE({field})"),
239                }
240            },
241        };
242
243        Ok(sql)
244    }
245
246    /// Generate window frame clause
247    fn generate_frame_clause(&self, frame: &WindowFrame) -> Result<String> {
248        let frame_type = match frame.frame_type {
249            FrameType::Rows => "ROWS",
250            FrameType::Range => "RANGE",
251            FrameType::Groups => {
252                if !matches!(self.database_type, DatabaseType::PostgreSQL) {
253                    return Err(FraiseQLError::validation(
254                        "GROUPS frame type only supported on PostgreSQL",
255                    ));
256                }
257                "GROUPS"
258            },
259        };
260
261        let start = self.format_frame_boundary(&frame.start);
262        let end = self.format_frame_boundary(&frame.end);
263
264        let mut sql = format!("{frame_type} BETWEEN {start} AND {end}");
265
266        // Frame exclusion (PostgreSQL only)
267        if let Some(exclusion) = &frame.exclusion {
268            if matches!(self.database_type, DatabaseType::PostgreSQL) {
269                let excl = match exclusion {
270                    FrameExclusion::CurrentRow => "EXCLUDE CURRENT ROW",
271                    FrameExclusion::Group => "EXCLUDE GROUP",
272                    FrameExclusion::Ties => "EXCLUDE TIES",
273                    FrameExclusion::NoOthers => "EXCLUDE NO OTHERS",
274                };
275                let _ = write!(sql, " {excl}");
276            }
277        }
278
279        Ok(sql)
280    }
281
282    /// Format frame boundary
283    #[must_use]
284    pub fn format_frame_boundary(&self, boundary: &FrameBoundary) -> String {
285        match boundary {
286            FrameBoundary::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
287            FrameBoundary::NPreceding { n } => format!("{n} PRECEDING"),
288            FrameBoundary::CurrentRow => "CURRENT ROW".to_string(),
289            FrameBoundary::NFollowing { n } => format!("{n} FOLLOWING"),
290            FrameBoundary::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
291        }
292    }
293
294    /// Generate MySQL window function SQL
295    fn generate_mysql(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
296        // MySQL 8.0+ supports window functions similar to PostgreSQL
297        // Main differences handled in generate_function_call (no STDEV/VAR differences for window
298        // functions)
299        self.generate_postgres(plan)
300    }
301
302    /// Generate SQLite window function SQL
303    fn generate_sqlite(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
304        // SQLite 3.25+ supports window functions
305        // Similar to PostgreSQL but no PERCENT_RANK, CUME_DIST validation done in planner
306        self.generate_postgres(plan)
307    }
308
309    /// Generate SQL Server window function SQL
310    fn generate_sqlserver(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
311        // SQL Server supports window functions with minor differences (STDEV/VAR naming)
312        self.generate_postgres(plan)
313    }
314}
315
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)] // Reason: test code, panics are acceptable

    use super::*;
    use crate::{
        compiler::{
            aggregation::{OrderByClause, OrderDirection},
            window_functions::*,
        },
        db::{WhereClause, WhereOperator},
    };

    /// Generate SQL for `plan` on `database_type`, panicking on failure.
    fn render(database_type: DatabaseType, plan: &WindowExecutionPlan) -> WindowSql {
        WindowSqlGenerator::new(database_type).generate(plan).unwrap()
    }

    /// Plan over `tf_sales` with no outer ORDER BY, LIMIT or OFFSET.
    fn sales_plan(
        select: Vec<SelectColumn>,
        windows: Vec<WindowFunction>,
        where_clause: Option<WhereClause>,
    ) -> WindowExecutionPlan {
        WindowExecutionPlan {
            table: "tf_sales".to_string(),
            select,
            windows,
            where_clause,
            order_by: vec![],
            limit: None,
            offset: None,
        }
    }

    /// SELECT column with the given expression and alias.
    fn column(expression: &str, alias: &str) -> SelectColumn {
        SelectColumn {
            expression: expression.to_string(),
            alias:      alias.to_string(),
        }
    }

    /// Single-key ordering on `field`.
    fn order_by(field: &str, direction: OrderDirection) -> Vec<OrderByClause> {
        vec![OrderByClause {
            field: field.to_string(),
            direction,
        }]
    }

    /// Window with no PARTITION BY, ORDER BY or frame.
    fn window(function: WindowFunctionType, alias: &str) -> WindowFunction {
        WindowFunction {
            function,
            alias: alias.to_string(),
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        }
    }

    /// `ROWS BETWEEN <start> AND CURRENT ROW` without exclusion.
    fn rows_up_to_current(start: FrameBoundary) -> WindowFrame {
        WindowFrame {
            frame_type: FrameType::Rows,
            start,
            end: FrameBoundary::CurrentRow,
            exclusion: None,
        }
    }

    /// Plan selecting `revenue` plus a bare ROW_NUMBER, filtered on `status = "active"`.
    fn active_status_plan() -> WindowExecutionPlan {
        sales_plan(
            vec![column("revenue", "revenue")],
            vec![window(WindowFunctionType::RowNumber, "rank")],
            Some(WhereClause::Field {
                path:     vec!["status".to_string()],
                operator: WhereOperator::Eq,
                value:    serde_json::json!("active"),
            }),
        )
    }

    #[test]
    fn test_generate_row_number() {
        let ranked = WindowFunction {
            partition_by: vec!["data->>'category'".to_string()],
            order_by: order_by("revenue", OrderDirection::Desc),
            ..window(WindowFunctionType::RowNumber, "rank")
        };
        let plan = sales_plan(vec![column("revenue", "revenue")], vec![ranked], None);

        let rendered = render(DatabaseType::PostgreSQL, &plan).raw_sql;

        assert!(rendered.contains("ROW_NUMBER()"));
        assert!(rendered.contains("PARTITION BY data->>'category'"));
        assert!(rendered.contains("ORDER BY revenue DESC"));
    }

    #[test]
    fn test_generate_running_total() {
        let running_total = WindowFunction {
            order_by: order_by("occurred_at", OrderDirection::Asc),
            frame: Some(rows_up_to_current(FrameBoundary::UnboundedPreceding)),
            ..window(
                WindowFunctionType::Sum {
                    field: "revenue".to_string(),
                },
                "running_total",
            )
        };
        let plan = sales_plan(
            vec![column("occurred_at", "date"), column("revenue", "revenue")],
            vec![running_total],
            None,
        );

        let rendered = render(DatabaseType::PostgreSQL, &plan).raw_sql;

        assert!(rendered.contains("SUM(revenue) OVER"));
        assert!(rendered.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_generate_lag_lead() {
        let by_time = || order_by("occurred_at", OrderDirection::Asc);

        let previous = WindowFunction {
            order_by: by_time(),
            ..window(
                WindowFunctionType::Lag {
                    field:   "revenue".to_string(),
                    offset:  1,
                    default: Some(serde_json::json!(0)),
                },
                "prev_revenue",
            )
        };
        let next = WindowFunction {
            order_by: by_time(),
            ..window(
                WindowFunctionType::Lead {
                    field:   "revenue".to_string(),
                    offset:  1,
                    default: None,
                },
                "next_revenue",
            )
        };
        let plan = sales_plan(vec![], vec![previous, next], None);

        let rendered = render(DatabaseType::PostgreSQL, &plan).raw_sql;

        assert!(rendered.contains("LAG(revenue, 1, 0)"));
        assert!(rendered.contains("LEAD(revenue, 1)"));
    }

    #[test]
    fn test_frame_boundary_formatting() {
        let generator = WindowSqlGenerator::new(DatabaseType::PostgreSQL);

        // One row per boundary variant: (input, expected SQL phrase).
        let cases = [
            (FrameBoundary::UnboundedPreceding, "UNBOUNDED PRECEDING"),
            (FrameBoundary::NPreceding { n: 5 }, "5 PRECEDING"),
            (FrameBoundary::CurrentRow, "CURRENT ROW"),
            (FrameBoundary::NFollowing { n: 3 }, "3 FOLLOWING"),
            (FrameBoundary::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        ];

        for (boundary, expected) in &cases {
            assert_eq!(generator.format_frame_boundary(boundary), *expected);
        }
    }

    #[test]
    fn test_moving_average() {
        let moving_avg = WindowFunction {
            order_by: order_by("occurred_at", OrderDirection::Asc),
            frame: Some(rows_up_to_current(FrameBoundary::NPreceding { n: 6 })),
            ..window(
                WindowFunctionType::Avg {
                    field: "revenue".to_string(),
                },
                "moving_avg_7d",
            )
        };
        let plan = sales_plan(vec![], vec![moving_avg], None);

        let rendered = render(DatabaseType::PostgreSQL, &plan).raw_sql;

        assert!(rendered.contains("AVG(revenue) OVER"));
        assert!(rendered.contains("ROWS BETWEEN 6 PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_sqlserver_stddev_variance() {
        let plan = sales_plan(
            vec![],
            vec![
                window(
                    WindowFunctionType::Stddev {
                        field: "revenue".to_string(),
                    },
                    "stddev",
                ),
                window(
                    WindowFunctionType::Variance {
                        field: "revenue".to_string(),
                    },
                    "variance",
                ),
            ],
            None,
        );

        let rendered = render(DatabaseType::SQLServer, &plan).raw_sql;

        // SQL Server uses STDEV/VAR instead of STDDEV/VARIANCE
        assert!(rendered.contains("STDEV(revenue)"));
        assert!(rendered.contains("VAR(revenue)"));
    }

    #[test]
    fn test_where_clause_uses_bind_parameters() {
        // Ensures WHERE clause is rendered with $N bind parameters (not literal values).
        // Literals would require escaping and are vulnerable to injection edge-cases;
        // bind parameters are always safe.
        let WindowSql {
            raw_sql,
            parameters,
        } = render(DatabaseType::PostgreSQL, &active_status_plan());

        // WHERE clause must use bind parameter ($1), not a literal string value.
        assert!(
            raw_sql.contains("WHERE data->>'status' = $1"),
            "expected bind parameter $1, got: {}",
            raw_sql
        );
        assert!(!raw_sql.contains("WHERE 1=1"));
        assert_eq!(parameters, vec![serde_json::json!("active")]);
    }

    #[test]
    fn test_where_clause_applied() {
        let rendered = render(DatabaseType::PostgreSQL, &active_status_plan()).raw_sql;

        // WHERE clause is rendered (not 1=1), value is a bind parameter.
        assert!(rendered.contains("WHERE"), "WHERE clause must appear in SQL");
        assert!(!rendered.contains("WHERE 1=1"));
    }

    #[test]
    fn test_no_where_clause_omitted() {
        let plan = sales_plan(
            vec![],
            vec![window(WindowFunctionType::RowNumber, "rank")],
            None,
        );

        let rendered = render(DatabaseType::PostgreSQL, &plan).raw_sql;

        // No WHERE clause in output
        assert!(!rendered.contains("WHERE"));
    }
}