1use std::fmt::Write as _;
13
14use crate::{
15 compiler::{
16 aggregation::OrderDirection,
17 window_functions::{
18 FrameBoundary, FrameExclusion, FrameType, WindowExecutionPlan, WindowFrame,
19 WindowFunction, WindowFunctionType,
20 },
21 },
22 db::{GenericWhereGenerator, PostgresDialect, types::DatabaseType},
23 error::{FraiseQLError, Result},
24};
25
26#[derive(Debug, Clone)]
28pub struct WindowSql {
29 pub raw_sql: String,
34
35 pub parameters: Vec<serde_json::Value>,
38}
39
40pub struct WindowSqlGenerator {
42 database_type: DatabaseType,
43}
44
45impl WindowSqlGenerator {
46 #[must_use]
48 pub const fn new(database_type: DatabaseType) -> Self {
49 Self { database_type }
50 }
51
52 pub fn generate(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
61 match self.database_type {
62 DatabaseType::PostgreSQL => self.generate_postgres(plan),
63 DatabaseType::MySQL => self.generate_mysql(plan),
64 DatabaseType::SQLite => self.generate_sqlite(plan),
65 DatabaseType::SQLServer => self.generate_sqlserver(plan),
66 }
67 }
68
69 fn generate_postgres(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
71 let mut sql = String::from("SELECT ");
72 let mut parameters = Vec::new();
73
74 for (i, col) in plan.select.iter().enumerate() {
76 if i > 0 {
77 sql.push_str(", ");
78 }
79 let _ = write!(sql, "{} AS {}", col.expression, col.alias);
80 }
81
82 for window in &plan.windows {
84 if !plan.select.is_empty() || sql.len() > "SELECT ".len() {
85 sql.push_str(", ");
86 }
87 sql.push_str(&self.generate_window_function(window)?);
88 }
89
90 let _ = write!(sql, " FROM {}", plan.table);
92
93 if let Some(clause) = &plan.where_clause {
96 let gen = GenericWhereGenerator::new(PostgresDialect);
97 let (where_sql, where_params) = gen.generate(clause)?;
98 sql.push_str(" WHERE ");
99 sql.push_str(&where_sql);
100 parameters.extend(where_params);
101 }
102
103 if !plan.order_by.is_empty() {
105 sql.push_str(" ORDER BY ");
106 for (i, order) in plan.order_by.iter().enumerate() {
107 if i > 0 {
108 sql.push_str(", ");
109 }
110 #[allow(clippy::match_same_arms)]
111 let dir = match order.direction {
113 OrderDirection::Asc => "ASC",
114 OrderDirection::Desc => "DESC",
115 _ => "ASC",
116 };
117 let _ = write!(sql, "{} {}", order.field, dir);
121 }
122 }
123
124 if let Some(limit) = plan.limit {
126 let _ = write!(sql, " LIMIT {limit}");
127 }
128 if let Some(offset) = plan.offset {
129 let _ = write!(sql, " OFFSET {offset}");
130 }
131
132 Ok(WindowSql {
133 raw_sql: sql,
134 parameters,
135 })
136 }
137
138 fn generate_window_function(&self, window: &WindowFunction) -> Result<String> {
140 let func_sql = self.generate_function_call(&window.function)?;
141 let mut sql = format!("{func_sql} OVER (");
142
143 if !window.partition_by.is_empty() {
146 sql.push_str("PARTITION BY ");
147 sql.push_str(&window.partition_by.join(", "));
148 }
149
150 if !window.order_by.is_empty() {
152 if !window.partition_by.is_empty() {
153 sql.push(' ');
154 }
155 sql.push_str("ORDER BY ");
156 for (i, order) in window.order_by.iter().enumerate() {
157 if i > 0 {
158 sql.push_str(", ");
159 }
160 #[allow(clippy::match_same_arms)]
161 let dir = match order.direction {
163 OrderDirection::Asc => "ASC",
164 OrderDirection::Desc => "DESC",
165 _ => "ASC",
166 };
167 let _ = write!(sql, "{} {}", order.field, dir);
168 }
169 }
170
171 if let Some(frame) = &window.frame {
173 if !window.partition_by.is_empty() || !window.order_by.is_empty() {
174 sql.push(' ');
175 }
176 sql.push_str(&self.generate_frame_clause(frame)?);
177 }
178
179 sql.push(')');
180 let _ = write!(sql, " AS {}", window.alias);
181
182 Ok(sql)
183 }
184
185 fn generate_function_call(&self, function: &WindowFunctionType) -> Result<String> {
187 let sql = match function {
188 WindowFunctionType::RowNumber => "ROW_NUMBER()".to_string(),
189 WindowFunctionType::Rank => "RANK()".to_string(),
190 WindowFunctionType::DenseRank => "DENSE_RANK()".to_string(),
191 WindowFunctionType::Ntile { n } => format!("NTILE({n})"),
192 WindowFunctionType::PercentRank => "PERCENT_RANK()".to_string(),
193 WindowFunctionType::CumeDist => "CUME_DIST()".to_string(),
194
195 WindowFunctionType::Lag {
196 field,
197 offset,
198 default,
199 } => {
200 if let Some(default_val) = default {
201 format!("LAG({field}, {offset}, {default_val})")
202 } else {
203 format!("LAG({field}, {offset})")
204 }
205 },
206 WindowFunctionType::Lead {
207 field,
208 offset,
209 default,
210 } => {
211 if let Some(default_val) = default {
212 format!("LEAD({field}, {offset}, {default_val})")
213 } else {
214 format!("LEAD({field}, {offset})")
215 }
216 },
217 WindowFunctionType::FirstValue { field } => format!("FIRST_VALUE({field})"),
218 WindowFunctionType::LastValue { field } => format!("LAST_VALUE({field})"),
219 WindowFunctionType::NthValue { field, n } => format!("NTH_VALUE({field}, {n})"),
220
221 WindowFunctionType::Sum { field } => format!("SUM({field})"),
222 WindowFunctionType::Avg { field } => format!("AVG({field})"),
223 WindowFunctionType::Count { field: Some(field) } => format!("COUNT({field})"),
224 WindowFunctionType::Count { field: None } => "COUNT(*)".to_string(),
225 WindowFunctionType::Min { field } => format!("MIN({field})"),
226 WindowFunctionType::Max { field } => format!("MAX({field})"),
227 WindowFunctionType::Stddev { field } => {
228 match self.database_type {
230 DatabaseType::SQLServer => format!("STDEV({field})"),
231 _ => format!("STDDEV({field})"),
232 }
233 },
234 WindowFunctionType::Variance { field } => {
235 match self.database_type {
237 DatabaseType::SQLServer => format!("VAR({field})"),
238 _ => format!("VARIANCE({field})"),
239 }
240 },
241 };
242
243 Ok(sql)
244 }
245
246 fn generate_frame_clause(&self, frame: &WindowFrame) -> Result<String> {
248 let frame_type = match frame.frame_type {
249 FrameType::Rows => "ROWS",
250 FrameType::Range => "RANGE",
251 FrameType::Groups => {
252 if !matches!(self.database_type, DatabaseType::PostgreSQL) {
253 return Err(FraiseQLError::validation(
254 "GROUPS frame type only supported on PostgreSQL",
255 ));
256 }
257 "GROUPS"
258 },
259 };
260
261 let start = self.format_frame_boundary(&frame.start);
262 let end = self.format_frame_boundary(&frame.end);
263
264 let mut sql = format!("{frame_type} BETWEEN {start} AND {end}");
265
266 if let Some(exclusion) = &frame.exclusion {
268 if matches!(self.database_type, DatabaseType::PostgreSQL) {
269 let excl = match exclusion {
270 FrameExclusion::CurrentRow => "EXCLUDE CURRENT ROW",
271 FrameExclusion::Group => "EXCLUDE GROUP",
272 FrameExclusion::Ties => "EXCLUDE TIES",
273 FrameExclusion::NoOthers => "EXCLUDE NO OTHERS",
274 };
275 let _ = write!(sql, " {excl}");
276 }
277 }
278
279 Ok(sql)
280 }
281
282 #[must_use]
284 pub fn format_frame_boundary(&self, boundary: &FrameBoundary) -> String {
285 match boundary {
286 FrameBoundary::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
287 FrameBoundary::NPreceding { n } => format!("{n} PRECEDING"),
288 FrameBoundary::CurrentRow => "CURRENT ROW".to_string(),
289 FrameBoundary::NFollowing { n } => format!("{n} FOLLOWING"),
290 FrameBoundary::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
291 }
292 }
293
294 fn generate_mysql(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
296 self.generate_postgres(plan)
300 }
301
302 fn generate_sqlite(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
304 self.generate_postgres(plan)
307 }
308
309 fn generate_sqlserver(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
311 self.generate_postgres(plan)
313 }
314}
315
#[cfg(test)]
mod tests {
    // Tests may unwrap: a failed generation should abort the test with a panic.
    #![allow(clippy::unwrap_used)]

    use super::*;
    use crate::{
        compiler::{
            aggregation::{OrderByClause, OrderDirection},
            window_functions::*,
        },
        db::{WhereClause, WhereOperator},
    };

    // ---- fixture helpers -------------------------------------------------

    /// Plan over `tf_sales` without outer ordering or pagination.
    fn sales_plan(
        select: Vec<SelectColumn>,
        windows: Vec<WindowFunction>,
        where_clause: Option<WhereClause>,
    ) -> WindowExecutionPlan {
        WindowExecutionPlan {
            table: String::from("tf_sales"),
            select,
            windows,
            where_clause,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// Select-list column `expression AS alias`.
    fn column(expression: &str, alias: &str) -> SelectColumn {
        SelectColumn {
            expression: String::from(expression),
            alias: String::from(alias),
        }
    }

    /// Window function with an empty `OVER ()` specification.
    fn bare_window(function: WindowFunctionType, alias: &str) -> WindowFunction {
        WindowFunction {
            function,
            alias: String::from(alias),
            partition_by: Vec::new(),
            order_by: Vec::new(),
            frame: None,
        }
    }

    /// Single-term ordering on `field`.
    fn ordering(field: &str, direction: OrderDirection) -> Vec<OrderByClause> {
        vec![OrderByClause {
            field: String::from(field),
            direction,
        }]
    }

    /// `ROWS` frame from `start` to the current row, without exclusion.
    fn rows_up_to_current(start: FrameBoundary) -> WindowFrame {
        WindowFrame {
            frame_type: FrameType::Rows,
            start,
            end: FrameBoundary::CurrentRow,
            exclusion: None,
        }
    }

    /// Filter `status = "active"`.
    fn status_is_active() -> WhereClause {
        WhereClause::Field {
            path: vec![String::from("status")],
            operator: WhereOperator::Eq,
            value: serde_json::json!("active"),
        }
    }

    /// Generates SQL for `plan` on `database_type`, panicking on failure.
    fn render(database_type: DatabaseType, plan: &WindowExecutionPlan) -> WindowSql {
        WindowSqlGenerator::new(database_type).generate(plan).unwrap()
    }

    // ---- tests -----------------------------------------------------------

    #[test]
    fn test_generate_row_number() {
        let ranked = WindowFunction {
            partition_by: vec![String::from("data->>'category'")],
            order_by: ordering("revenue", OrderDirection::Desc),
            ..bare_window(WindowFunctionType::RowNumber, "rank")
        };
        let plan = sales_plan(vec![column("revenue", "revenue")], vec![ranked], None);

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.raw_sql.contains("ROW_NUMBER()"));
        assert!(sql.raw_sql.contains("PARTITION BY data->>'category'"));
        assert!(sql.raw_sql.contains("ORDER BY revenue DESC"));
    }

    #[test]
    fn test_generate_running_total() {
        let running_total = WindowFunction {
            order_by: ordering("occurred_at", OrderDirection::Asc),
            frame: Some(rows_up_to_current(FrameBoundary::UnboundedPreceding)),
            ..bare_window(
                WindowFunctionType::Sum {
                    field: String::from("revenue"),
                },
                "running_total",
            )
        };
        let plan = sales_plan(
            vec![column("occurred_at", "date"), column("revenue", "revenue")],
            vec![running_total],
            None,
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.raw_sql.contains("SUM(revenue) OVER"));
        assert!(sql.raw_sql.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_generate_lag_lead() {
        let previous = WindowFunction {
            order_by: ordering("occurred_at", OrderDirection::Asc),
            ..bare_window(
                WindowFunctionType::Lag {
                    field: String::from("revenue"),
                    offset: 1,
                    default: Some(serde_json::json!(0)),
                },
                "prev_revenue",
            )
        };
        let next = WindowFunction {
            order_by: ordering("occurred_at", OrderDirection::Asc),
            ..bare_window(
                WindowFunctionType::Lead {
                    field: String::from("revenue"),
                    offset: 1,
                    default: None,
                },
                "next_revenue",
            )
        };
        let plan = sales_plan(Vec::new(), vec![previous, next], None);

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.raw_sql.contains("LAG(revenue, 1, 0)"));
        assert!(sql.raw_sql.contains("LEAD(revenue, 1)"));
    }

    #[test]
    fn test_frame_boundary_formatting() {
        let generator = WindowSqlGenerator::new(DatabaseType::PostgreSQL);

        let cases = [
            (FrameBoundary::UnboundedPreceding, "UNBOUNDED PRECEDING"),
            (FrameBoundary::NPreceding { n: 5 }, "5 PRECEDING"),
            (FrameBoundary::CurrentRow, "CURRENT ROW"),
            (FrameBoundary::NFollowing { n: 3 }, "3 FOLLOWING"),
            (FrameBoundary::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        ];

        for (boundary, expected) in &cases {
            assert_eq!(generator.format_frame_boundary(boundary), *expected);
        }
    }

    #[test]
    fn test_moving_average() {
        let moving_average = WindowFunction {
            order_by: ordering("occurred_at", OrderDirection::Asc),
            frame: Some(rows_up_to_current(FrameBoundary::NPreceding { n: 6 })),
            ..bare_window(
                WindowFunctionType::Avg {
                    field: String::from("revenue"),
                },
                "moving_avg_7d",
            )
        };
        let plan = sales_plan(Vec::new(), vec![moving_average], None);

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.raw_sql.contains("AVG(revenue) OVER"));
        assert!(sql.raw_sql.contains("ROWS BETWEEN 6 PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_sqlserver_stddev_variance() {
        let stddev = bare_window(
            WindowFunctionType::Stddev {
                field: String::from("revenue"),
            },
            "stddev",
        );
        let variance = bare_window(
            WindowFunctionType::Variance {
                field: String::from("revenue"),
            },
            "variance",
        );
        let plan = sales_plan(Vec::new(), vec![stddev, variance], None);

        let sql = render(DatabaseType::SQLServer, &plan);

        assert!(sql.raw_sql.contains("STDEV(revenue)"));
        assert!(sql.raw_sql.contains("VAR(revenue)"));
    }

    #[test]
    fn test_where_clause_uses_bind_parameters() {
        let plan = sales_plan(
            vec![column("revenue", "revenue")],
            vec![bare_window(WindowFunctionType::RowNumber, "rank")],
            Some(status_is_active()),
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(
            sql.raw_sql.contains("WHERE data->>'status' = $1"),
            "expected bind parameter $1, got: {}",
            sql.raw_sql
        );
        assert!(!sql.raw_sql.contains("WHERE 1=1"));
        assert_eq!(sql.parameters, vec![serde_json::json!("active")]);
    }

    #[test]
    fn test_where_clause_applied() {
        let plan = sales_plan(
            vec![column("revenue", "revenue")],
            vec![bare_window(WindowFunctionType::RowNumber, "rank")],
            Some(status_is_active()),
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.raw_sql.contains("WHERE"), "WHERE clause must appear in SQL");
        assert!(!sql.raw_sql.contains("WHERE 1=1"));
    }

    #[test]
    fn test_no_where_clause_omitted() {
        let plan = sales_plan(
            Vec::new(),
            vec![bare_window(WindowFunctionType::RowNumber, "rank")],
            None,
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(!sql.raw_sql.contains("WHERE"));
    }
}