Skip to main content

fraiseql_core/runtime/
window.rs

1//! Window Function SQL Generation
2//!
3//! Generates database-specific SQL for window functions.
4//!
5//! # Supported Databases
6//!
7//! - **PostgreSQL**: Full support (all functions + GROUPS frames + frame exclusion)
8//! - **MySQL 8.0+**: Full support (no GROUPS, no frame exclusion)
9//! - **SQLite 3.25+**: Basic support (no GROUPS, no PERCENT_RANK/CUME_DIST)
10//! - **SQL Server**: Full support (STDEV/VAR naming difference)
11
12use crate::{
13    compiler::{
14        aggregation::OrderDirection,
15        window_functions::{
16            FrameBoundary, FrameExclusion, FrameType, WindowExecutionPlan, WindowFrame,
17            WindowFunction, WindowFunctionType,
18        },
19    },
20    db::types::DatabaseType,
21    error::{FraiseQLError, Result},
22};
23
24/// Generated SQL for window function query
25#[derive(Debug, Clone)]
26pub struct WindowSql {
27    /// Complete SQL query
28    pub complete_sql: String,
29
30    /// Parameterized values (for WHERE clause)
31    pub parameters: Vec<serde_json::Value>,
32}
33
34/// Window function SQL generator
35pub struct WindowSqlGenerator {
36    database_type: DatabaseType,
37}
38
39impl WindowSqlGenerator {
40    /// Create new generator for database type
41    #[must_use]
42    pub const fn new(database_type: DatabaseType) -> Self {
43        Self { database_type }
44    }
45
46    /// Generate SQL from window execution plan
47    ///
48    /// # Errors
49    ///
50    /// Returns error if:
51    /// - Unsupported function for database
52    /// - Invalid frame specification
53    /// - WHERE clause generation fails
54    pub fn generate(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
55        match self.database_type {
56            DatabaseType::PostgreSQL => self.generate_postgres(plan),
57            DatabaseType::MySQL => self.generate_mysql(plan),
58            DatabaseType::SQLite => self.generate_sqlite(plan),
59            DatabaseType::SQLServer => self.generate_sqlserver(plan),
60        }
61    }
62
63    /// Generate PostgreSQL window function SQL
64    fn generate_postgres(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
65        let mut sql = String::from("SELECT ");
66        let parameters = Vec::new();
67
68        // Add regular SELECT columns
69        for (i, col) in plan.select.iter().enumerate() {
70            if i > 0 {
71                sql.push_str(", ");
72            }
73            sql.push_str(&format!("{} AS {}", col.expression, col.alias));
74        }
75
76        // Add window functions
77        for window in &plan.windows {
78            if !plan.select.is_empty() || sql.len() > "SELECT ".len() {
79                sql.push_str(", ");
80            }
81            sql.push_str(&self.generate_window_function(window)?);
82        }
83
84        // FROM clause
85        sql.push_str(&format!(" FROM {}", plan.table));
86
87        // WHERE clause (if any)
88        if plan.where_clause.is_some() {
89            sql.push_str(" WHERE 1=1"); // Placeholder
90        }
91
92        // ORDER BY clause
93        if !plan.order_by.is_empty() {
94            sql.push_str(" ORDER BY ");
95            for (i, order) in plan.order_by.iter().enumerate() {
96                if i > 0 {
97                    sql.push_str(", ");
98                }
99                let dir = match order.direction {
100                    OrderDirection::Asc => "ASC",
101                    OrderDirection::Desc => "DESC",
102                };
103                sql.push_str(&format!("{} {}", order.field, dir));
104            }
105        }
106
107        // LIMIT / OFFSET
108        if let Some(limit) = plan.limit {
109            sql.push_str(&format!(" LIMIT {limit}"));
110        }
111        if let Some(offset) = plan.offset {
112            sql.push_str(&format!(" OFFSET {offset}"));
113        }
114
115        Ok(WindowSql {
116            complete_sql: sql,
117            parameters,
118        })
119    }
120
121    /// Generate window function expression
122    fn generate_window_function(&self, window: &WindowFunction) -> Result<String> {
123        let func_sql = self.generate_function_call(&window.function)?;
124        let mut sql = format!("{func_sql} OVER (");
125
126        // PARTITION BY
127        if !window.partition_by.is_empty() {
128            sql.push_str("PARTITION BY ");
129            sql.push_str(&window.partition_by.join(", "));
130        }
131
132        // ORDER BY
133        if !window.order_by.is_empty() {
134            if !window.partition_by.is_empty() {
135                sql.push(' ');
136            }
137            sql.push_str("ORDER BY ");
138            for (i, order) in window.order_by.iter().enumerate() {
139                if i > 0 {
140                    sql.push_str(", ");
141                }
142                let dir = match order.direction {
143                    OrderDirection::Asc => "ASC",
144                    OrderDirection::Desc => "DESC",
145                };
146                sql.push_str(&format!("{} {}", order.field, dir));
147            }
148        }
149
150        // Frame clause
151        if let Some(frame) = &window.frame {
152            if !window.partition_by.is_empty() || !window.order_by.is_empty() {
153                sql.push(' ');
154            }
155            sql.push_str(&self.generate_frame_clause(frame)?);
156        }
157
158        sql.push(')');
159        sql.push_str(&format!(" AS {}", window.alias));
160
161        Ok(sql)
162    }
163
164    /// Generate function call SQL
165    fn generate_function_call(&self, function: &WindowFunctionType) -> Result<String> {
166        let sql = match function {
167            WindowFunctionType::RowNumber => "ROW_NUMBER()".to_string(),
168            WindowFunctionType::Rank => "RANK()".to_string(),
169            WindowFunctionType::DenseRank => "DENSE_RANK()".to_string(),
170            WindowFunctionType::Ntile { n } => format!("NTILE({n})"),
171            WindowFunctionType::PercentRank => "PERCENT_RANK()".to_string(),
172            WindowFunctionType::CumeDist => "CUME_DIST()".to_string(),
173
174            WindowFunctionType::Lag {
175                field,
176                offset,
177                default,
178            } => {
179                if let Some(default_val) = default {
180                    format!("LAG({field}, {offset}, {default_val})")
181                } else {
182                    format!("LAG({field}, {offset})")
183                }
184            },
185            WindowFunctionType::Lead {
186                field,
187                offset,
188                default,
189            } => {
190                if let Some(default_val) = default {
191                    format!("LEAD({field}, {offset}, {default_val})")
192                } else {
193                    format!("LEAD({field}, {offset})")
194                }
195            },
196            WindowFunctionType::FirstValue { field } => format!("FIRST_VALUE({field})"),
197            WindowFunctionType::LastValue { field } => format!("LAST_VALUE({field})"),
198            WindowFunctionType::NthValue { field, n } => format!("NTH_VALUE({field}, {n})"),
199
200            WindowFunctionType::Sum { field } => format!("SUM({field})"),
201            WindowFunctionType::Avg { field } => format!("AVG({field})"),
202            WindowFunctionType::Count { field: Some(field) } => format!("COUNT({field})"),
203            WindowFunctionType::Count { field: None } => "COUNT(*)".to_string(),
204            WindowFunctionType::Min { field } => format!("MIN({field})"),
205            WindowFunctionType::Max { field } => format!("MAX({field})"),
206            WindowFunctionType::Stddev { field } => {
207                // PostgreSQL/MySQL use STDDEV, SQL Server uses STDEV
208                match self.database_type {
209                    DatabaseType::SQLServer => format!("STDEV({field})"),
210                    _ => format!("STDDEV({field})"),
211                }
212            },
213            WindowFunctionType::Variance { field } => {
214                // PostgreSQL/MySQL use VARIANCE, SQL Server uses VAR
215                match self.database_type {
216                    DatabaseType::SQLServer => format!("VAR({field})"),
217                    _ => format!("VARIANCE({field})"),
218                }
219            },
220        };
221
222        Ok(sql)
223    }
224
225    /// Generate window frame clause
226    fn generate_frame_clause(&self, frame: &WindowFrame) -> Result<String> {
227        let frame_type = match frame.frame_type {
228            FrameType::Rows => "ROWS",
229            FrameType::Range => "RANGE",
230            FrameType::Groups => {
231                if !matches!(self.database_type, DatabaseType::PostgreSQL) {
232                    return Err(FraiseQLError::validation(
233                        "GROUPS frame type only supported on PostgreSQL",
234                    ));
235                }
236                "GROUPS"
237            },
238        };
239
240        let start = self.format_frame_boundary(&frame.start);
241        let end = self.format_frame_boundary(&frame.end);
242
243        let mut sql = format!("{frame_type} BETWEEN {start} AND {end}");
244
245        // Frame exclusion (PostgreSQL only)
246        if let Some(exclusion) = &frame.exclusion {
247            if matches!(self.database_type, DatabaseType::PostgreSQL) {
248                let excl = match exclusion {
249                    FrameExclusion::CurrentRow => "EXCLUDE CURRENT ROW",
250                    FrameExclusion::Group => "EXCLUDE GROUP",
251                    FrameExclusion::Ties => "EXCLUDE TIES",
252                    FrameExclusion::NoOthers => "EXCLUDE NO OTHERS",
253                };
254                sql.push_str(&format!(" {excl}"));
255            }
256        }
257
258        Ok(sql)
259    }
260
261    /// Format frame boundary
262    #[must_use]
263    pub fn format_frame_boundary(&self, boundary: &FrameBoundary) -> String {
264        match boundary {
265            FrameBoundary::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
266            FrameBoundary::NPreceding { n } => format!("{n} PRECEDING"),
267            FrameBoundary::CurrentRow => "CURRENT ROW".to_string(),
268            FrameBoundary::NFollowing { n } => format!("{n} FOLLOWING"),
269            FrameBoundary::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
270        }
271    }
272
273    /// Generate MySQL window function SQL
274    fn generate_mysql(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
275        // MySQL 8.0+ supports window functions similar to PostgreSQL
276        // Main differences handled in generate_function_call (no STDEV/VAR differences for window
277        // functions)
278        self.generate_postgres(plan)
279    }
280
281    /// Generate SQLite window function SQL
282    fn generate_sqlite(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
283        // SQLite 3.25+ supports window functions
284        // Similar to PostgreSQL but no PERCENT_RANK, CUME_DIST validation done in planner
285        self.generate_postgres(plan)
286    }
287
288    /// Generate SQL Server window function SQL
289    fn generate_sqlserver(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
290        // SQL Server supports window functions with minor differences (STDEV/VAR naming)
291        self.generate_postgres(plan)
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::{
        aggregation::{OrderByClause, OrderDirection},
        window_functions::*,
    };

    /// Plan over `tf_sales` with no filter, outer sort or pagination.
    fn sales_plan(select: Vec<SelectColumn>, windows: Vec<WindowFunction>) -> WindowExecutionPlan {
        WindowExecutionPlan {
            table: "tf_sales".to_string(),
            select,
            windows,
            where_clause: None,
            order_by: vec![],
            limit: None,
            offset: None,
        }
    }

    /// Projected column rendered as `expression AS alias`.
    fn column(expression: &str, alias: &str) -> SelectColumn {
        SelectColumn {
            expression: expression.to_string(),
            alias:      alias.to_string(),
        }
    }

    /// Single-key sort specification.
    fn sort(field: &str, direction: OrderDirection) -> OrderByClause {
        OrderByClause {
            field: field.to_string(),
            direction,
        }
    }

    /// Window function assembled from its OVER-clause pieces.
    fn window(
        function: WindowFunctionType,
        alias: &str,
        partition_by: &[&str],
        order_by: Vec<OrderByClause>,
        frame: Option<WindowFrame>,
    ) -> WindowFunction {
        WindowFunction {
            function,
            alias: alias.to_string(),
            partition_by: partition_by.iter().map(|key| (*key).to_string()).collect(),
            order_by,
            frame,
        }
    }

    /// `ROWS BETWEEN start AND end` frame without exclusion.
    fn rows_frame(start: FrameBoundary, end: FrameBoundary) -> Option<WindowFrame> {
        Some(WindowFrame {
            frame_type: FrameType::Rows,
            start,
            end,
            exclusion: None,
        })
    }

    /// Generate SQL for `plan` in the given dialect, panicking on error.
    fn render(database: DatabaseType, plan: &WindowExecutionPlan) -> String {
        WindowSqlGenerator::new(database).generate(plan).unwrap().complete_sql
    }

    #[test]
    fn test_generate_row_number() {
        let plan = sales_plan(
            vec![column("revenue", "revenue")],
            vec![window(
                WindowFunctionType::RowNumber,
                "rank",
                &["data->>'category'"],
                vec![sort("revenue", OrderDirection::Desc)],
                None,
            )],
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("ROW_NUMBER()"));
        assert!(sql.contains("PARTITION BY data->>'category'"));
        assert!(sql.contains("ORDER BY revenue DESC"));
    }

    #[test]
    fn test_generate_running_total() {
        let plan = sales_plan(
            vec![column("occurred_at", "date"), column("revenue", "revenue")],
            vec![window(
                WindowFunctionType::Sum {
                    field: "revenue".to_string(),
                },
                "running_total",
                &[],
                vec![sort("occurred_at", OrderDirection::Asc)],
                rows_frame(FrameBoundary::UnboundedPreceding, FrameBoundary::CurrentRow),
            )],
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("SUM(revenue) OVER"));
        assert!(sql.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_generate_lag_lead() {
        let by_time = || vec![sort("occurred_at", OrderDirection::Asc)];
        let plan = sales_plan(
            vec![],
            vec![
                window(
                    WindowFunctionType::Lag {
                        field:   "revenue".to_string(),
                        offset:  1,
                        default: Some(serde_json::json!(0)),
                    },
                    "prev_revenue",
                    &[],
                    by_time(),
                    None,
                ),
                window(
                    WindowFunctionType::Lead {
                        field:   "revenue".to_string(),
                        offset:  1,
                        default: None,
                    },
                    "next_revenue",
                    &[],
                    by_time(),
                    None,
                ),
            ],
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("LAG(revenue, 1, 0)"));
        assert!(sql.contains("LEAD(revenue, 1)"));
    }

    #[test]
    fn test_frame_boundary_formatting() {
        let generator = WindowSqlGenerator::new(DatabaseType::PostgreSQL);
        let cases = [
            (FrameBoundary::UnboundedPreceding, "UNBOUNDED PRECEDING"),
            (FrameBoundary::NPreceding { n: 5 }, "5 PRECEDING"),
            (FrameBoundary::CurrentRow, "CURRENT ROW"),
            (FrameBoundary::NFollowing { n: 3 }, "3 FOLLOWING"),
            (FrameBoundary::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        ];

        for (boundary, expected) in &cases {
            assert_eq!(generator.format_frame_boundary(boundary), *expected);
        }
    }

    #[test]
    fn test_moving_average() {
        let plan = sales_plan(
            vec![],
            vec![window(
                WindowFunctionType::Avg {
                    field: "revenue".to_string(),
                },
                "moving_avg_7d",
                &[],
                vec![sort("occurred_at", OrderDirection::Asc)],
                rows_frame(FrameBoundary::NPreceding { n: 6 }, FrameBoundary::CurrentRow),
            )],
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("AVG(revenue) OVER"));
        assert!(sql.contains("ROWS BETWEEN 6 PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_sqlserver_stddev_variance() {
        let plan = sales_plan(
            vec![],
            vec![
                window(
                    WindowFunctionType::Stddev {
                        field: "revenue".to_string(),
                    },
                    "stddev",
                    &[],
                    vec![],
                    None,
                ),
                window(
                    WindowFunctionType::Variance {
                        field: "revenue".to_string(),
                    },
                    "variance",
                    &[],
                    vec![],
                    None,
                ),
            ],
        );

        let sql = render(DatabaseType::SQLServer, &plan);

        // SQL Server spells these STDEV/VAR rather than STDDEV/VARIANCE.
        assert!(sql.contains("STDEV(revenue)"));
        assert!(sql.contains("VAR(revenue)"));
    }
}