1use crate::{
13 compiler::{
14 aggregation::OrderDirection,
15 window_functions::{
16 FrameBoundary, FrameExclusion, FrameType, WindowExecutionPlan, WindowFrame,
17 WindowFunction, WindowFunctionType,
18 },
19 },
20 db::types::DatabaseType,
21 error::{FraiseQLError, Result},
22};
23
/// Output of [`WindowSqlGenerator::generate`]: a SQL statement for a window query.
#[derive(Debug, Clone)]
pub struct WindowSql {
    /// The complete `SELECT` statement: plain columns, window functions, the `FROM`
    /// table, and any outer ordering / pagination clauses the plan requested.
    pub complete_sql: String,

    /// Bind parameters intended to accompany `complete_sql`.
    ///
    /// NOTE(review): the generator in this module always returns this vector empty;
    /// no placeholders are emitted into `complete_sql`.
    pub parameters: Vec<serde_json::Value>,
}
33
/// Translates a [`WindowExecutionPlan`] into SQL text for one target database.
pub struct WindowSqlGenerator {
    /// Target database; selects dialect-specific function names (e.g. `STDEV` vs
    /// `STDDEV`) and which window-frame features are allowed.
    database_type: DatabaseType,
}
38
39impl WindowSqlGenerator {
40 #[must_use]
42 pub const fn new(database_type: DatabaseType) -> Self {
43 Self { database_type }
44 }
45
46 pub fn generate(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
55 match self.database_type {
56 DatabaseType::PostgreSQL => self.generate_postgres(plan),
57 DatabaseType::MySQL => self.generate_mysql(plan),
58 DatabaseType::SQLite => self.generate_sqlite(plan),
59 DatabaseType::SQLServer => self.generate_sqlserver(plan),
60 }
61 }
62
63 fn generate_postgres(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
65 let mut sql = String::from("SELECT ");
66 let parameters = Vec::new();
67
68 for (i, col) in plan.select.iter().enumerate() {
70 if i > 0 {
71 sql.push_str(", ");
72 }
73 sql.push_str(&format!("{} AS {}", col.expression, col.alias));
74 }
75
76 for window in &plan.windows {
78 if !plan.select.is_empty() || sql.len() > "SELECT ".len() {
79 sql.push_str(", ");
80 }
81 sql.push_str(&self.generate_window_function(window)?);
82 }
83
84 sql.push_str(&format!(" FROM {}", plan.table));
86
87 if plan.where_clause.is_some() {
89 sql.push_str(" WHERE 1=1"); }
91
92 if !plan.order_by.is_empty() {
94 sql.push_str(" ORDER BY ");
95 for (i, order) in plan.order_by.iter().enumerate() {
96 if i > 0 {
97 sql.push_str(", ");
98 }
99 let dir = match order.direction {
100 OrderDirection::Asc => "ASC",
101 OrderDirection::Desc => "DESC",
102 };
103 sql.push_str(&format!("{} {}", order.field, dir));
104 }
105 }
106
107 if let Some(limit) = plan.limit {
109 sql.push_str(&format!(" LIMIT {limit}"));
110 }
111 if let Some(offset) = plan.offset {
112 sql.push_str(&format!(" OFFSET {offset}"));
113 }
114
115 Ok(WindowSql {
116 complete_sql: sql,
117 parameters,
118 })
119 }
120
121 fn generate_window_function(&self, window: &WindowFunction) -> Result<String> {
123 let func_sql = self.generate_function_call(&window.function)?;
124 let mut sql = format!("{func_sql} OVER (");
125
126 if !window.partition_by.is_empty() {
128 sql.push_str("PARTITION BY ");
129 sql.push_str(&window.partition_by.join(", "));
130 }
131
132 if !window.order_by.is_empty() {
134 if !window.partition_by.is_empty() {
135 sql.push(' ');
136 }
137 sql.push_str("ORDER BY ");
138 for (i, order) in window.order_by.iter().enumerate() {
139 if i > 0 {
140 sql.push_str(", ");
141 }
142 let dir = match order.direction {
143 OrderDirection::Asc => "ASC",
144 OrderDirection::Desc => "DESC",
145 };
146 sql.push_str(&format!("{} {}", order.field, dir));
147 }
148 }
149
150 if let Some(frame) = &window.frame {
152 if !window.partition_by.is_empty() || !window.order_by.is_empty() {
153 sql.push(' ');
154 }
155 sql.push_str(&self.generate_frame_clause(frame)?);
156 }
157
158 sql.push(')');
159 sql.push_str(&format!(" AS {}", window.alias));
160
161 Ok(sql)
162 }
163
164 fn generate_function_call(&self, function: &WindowFunctionType) -> Result<String> {
166 let sql = match function {
167 WindowFunctionType::RowNumber => "ROW_NUMBER()".to_string(),
168 WindowFunctionType::Rank => "RANK()".to_string(),
169 WindowFunctionType::DenseRank => "DENSE_RANK()".to_string(),
170 WindowFunctionType::Ntile { n } => format!("NTILE({n})"),
171 WindowFunctionType::PercentRank => "PERCENT_RANK()".to_string(),
172 WindowFunctionType::CumeDist => "CUME_DIST()".to_string(),
173
174 WindowFunctionType::Lag {
175 field,
176 offset,
177 default,
178 } => {
179 if let Some(default_val) = default {
180 format!("LAG({field}, {offset}, {default_val})")
181 } else {
182 format!("LAG({field}, {offset})")
183 }
184 },
185 WindowFunctionType::Lead {
186 field,
187 offset,
188 default,
189 } => {
190 if let Some(default_val) = default {
191 format!("LEAD({field}, {offset}, {default_val})")
192 } else {
193 format!("LEAD({field}, {offset})")
194 }
195 },
196 WindowFunctionType::FirstValue { field } => format!("FIRST_VALUE({field})"),
197 WindowFunctionType::LastValue { field } => format!("LAST_VALUE({field})"),
198 WindowFunctionType::NthValue { field, n } => format!("NTH_VALUE({field}, {n})"),
199
200 WindowFunctionType::Sum { field } => format!("SUM({field})"),
201 WindowFunctionType::Avg { field } => format!("AVG({field})"),
202 WindowFunctionType::Count { field: Some(field) } => format!("COUNT({field})"),
203 WindowFunctionType::Count { field: None } => "COUNT(*)".to_string(),
204 WindowFunctionType::Min { field } => format!("MIN({field})"),
205 WindowFunctionType::Max { field } => format!("MAX({field})"),
206 WindowFunctionType::Stddev { field } => {
207 match self.database_type {
209 DatabaseType::SQLServer => format!("STDEV({field})"),
210 _ => format!("STDDEV({field})"),
211 }
212 },
213 WindowFunctionType::Variance { field } => {
214 match self.database_type {
216 DatabaseType::SQLServer => format!("VAR({field})"),
217 _ => format!("VARIANCE({field})"),
218 }
219 },
220 };
221
222 Ok(sql)
223 }
224
225 fn generate_frame_clause(&self, frame: &WindowFrame) -> Result<String> {
227 let frame_type = match frame.frame_type {
228 FrameType::Rows => "ROWS",
229 FrameType::Range => "RANGE",
230 FrameType::Groups => {
231 if !matches!(self.database_type, DatabaseType::PostgreSQL) {
232 return Err(FraiseQLError::validation(
233 "GROUPS frame type only supported on PostgreSQL",
234 ));
235 }
236 "GROUPS"
237 },
238 };
239
240 let start = self.format_frame_boundary(&frame.start);
241 let end = self.format_frame_boundary(&frame.end);
242
243 let mut sql = format!("{frame_type} BETWEEN {start} AND {end}");
244
245 if let Some(exclusion) = &frame.exclusion {
247 if matches!(self.database_type, DatabaseType::PostgreSQL) {
248 let excl = match exclusion {
249 FrameExclusion::CurrentRow => "EXCLUDE CURRENT ROW",
250 FrameExclusion::Group => "EXCLUDE GROUP",
251 FrameExclusion::Ties => "EXCLUDE TIES",
252 FrameExclusion::NoOthers => "EXCLUDE NO OTHERS",
253 };
254 sql.push_str(&format!(" {excl}"));
255 }
256 }
257
258 Ok(sql)
259 }
260
261 #[must_use]
263 pub fn format_frame_boundary(&self, boundary: &FrameBoundary) -> String {
264 match boundary {
265 FrameBoundary::UnboundedPreceding => "UNBOUNDED PRECEDING".to_string(),
266 FrameBoundary::NPreceding { n } => format!("{n} PRECEDING"),
267 FrameBoundary::CurrentRow => "CURRENT ROW".to_string(),
268 FrameBoundary::NFollowing { n } => format!("{n} FOLLOWING"),
269 FrameBoundary::UnboundedFollowing => "UNBOUNDED FOLLOWING".to_string(),
270 }
271 }
272
273 fn generate_mysql(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
275 self.generate_postgres(plan)
279 }
280
281 fn generate_sqlite(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
283 self.generate_postgres(plan)
286 }
287
288 fn generate_sqlserver(&self, plan: &WindowExecutionPlan) -> Result<WindowSql> {
290 self.generate_postgres(plan)
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use crate::compiler::{
        aggregation::{OrderByClause, OrderDirection},
        window_functions::*,
    };

    /// Ordering key on `field` in `direction`.
    fn order(field: &str, direction: OrderDirection) -> OrderByClause {
        OrderByClause {
            field: field.to_string(),
            direction,
        }
    }

    /// Plain `expression AS alias` projection column.
    fn column(expression: &str, alias: &str) -> SelectColumn {
        SelectColumn {
            expression: expression.to_string(),
            alias: alias.to_string(),
        }
    }

    /// Window function ordered by `occurred_at ASC`, with no partitioning.
    fn by_occurred_at(
        function: WindowFunctionType,
        alias: &str,
        frame: Option<WindowFrame>,
    ) -> WindowFunction {
        WindowFunction {
            function,
            alias: alias.to_string(),
            partition_by: vec![],
            order_by: vec![order("occurred_at", OrderDirection::Asc)],
            frame,
        }
    }

    /// Window function over the whole table: no partitioning, ordering or frame.
    fn whole_table(function: WindowFunctionType, alias: &str) -> WindowFunction {
        WindowFunction {
            function,
            alias: alias.to_string(),
            partition_by: vec![],
            order_by: vec![],
            frame: None,
        }
    }

    /// Plan over `tf_sales` with no filter, outer ordering or pagination.
    fn sales_plan(select: Vec<SelectColumn>, windows: Vec<WindowFunction>) -> WindowExecutionPlan {
        WindowExecutionPlan {
            table: "tf_sales".to_string(),
            select,
            windows,
            where_clause: None,
            order_by: vec![],
            limit: None,
            offset: None,
        }
    }

    /// Generated SQL text for `plan` on `database_type`; panics on generation errors.
    fn render(database_type: DatabaseType, plan: &WindowExecutionPlan) -> String {
        WindowSqlGenerator::new(database_type).generate(plan).unwrap().complete_sql
    }

    #[test]
    fn test_generate_row_number() {
        let ranking = WindowFunction {
            function: WindowFunctionType::RowNumber,
            alias: "rank".to_string(),
            partition_by: vec!["data->>'category'".to_string()],
            order_by: vec![order("revenue", OrderDirection::Desc)],
            frame: None,
        };
        let plan = sales_plan(vec![column("revenue", "revenue")], vec![ranking]);

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("ROW_NUMBER()"));
        assert!(sql.contains("PARTITION BY data->>'category'"));
        assert!(sql.contains("ORDER BY revenue DESC"));
    }

    #[test]
    fn test_generate_running_total() {
        let running_total = by_occurred_at(
            WindowFunctionType::Sum {
                field: "revenue".to_string(),
            },
            "running_total",
            Some(WindowFrame {
                frame_type: FrameType::Rows,
                start: FrameBoundary::UnboundedPreceding,
                end: FrameBoundary::CurrentRow,
                exclusion: None,
            }),
        );
        let plan = sales_plan(
            vec![column("occurred_at", "date"), column("revenue", "revenue")],
            vec![running_total],
        );

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("SUM(revenue) OVER"));
        assert!(sql.contains("ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_generate_lag_lead() {
        let previous = by_occurred_at(
            WindowFunctionType::Lag {
                field: "revenue".to_string(),
                offset: 1,
                default: Some(serde_json::json!(0)),
            },
            "prev_revenue",
            None,
        );
        let next = by_occurred_at(
            WindowFunctionType::Lead {
                field: "revenue".to_string(),
                offset: 1,
                default: None,
            },
            "next_revenue",
            None,
        );
        let plan = sales_plan(vec![], vec![previous, next]);

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("LAG(revenue, 1, 0)"));
        assert!(sql.contains("LEAD(revenue, 1)"));
    }

    #[test]
    fn test_frame_boundary_formatting() {
        let generator = WindowSqlGenerator::new(DatabaseType::PostgreSQL);
        let cases = [
            (FrameBoundary::UnboundedPreceding, "UNBOUNDED PRECEDING"),
            (FrameBoundary::NPreceding { n: 5 }, "5 PRECEDING"),
            (FrameBoundary::CurrentRow, "CURRENT ROW"),
            (FrameBoundary::NFollowing { n: 3 }, "3 FOLLOWING"),
            (FrameBoundary::UnboundedFollowing, "UNBOUNDED FOLLOWING"),
        ];

        for (boundary, expected) in &cases {
            assert_eq!(generator.format_frame_boundary(boundary), *expected);
        }
    }

    #[test]
    fn test_moving_average() {
        let moving_avg = by_occurred_at(
            WindowFunctionType::Avg {
                field: "revenue".to_string(),
            },
            "moving_avg_7d",
            Some(WindowFrame {
                frame_type: FrameType::Rows,
                start: FrameBoundary::NPreceding { n: 6 },
                end: FrameBoundary::CurrentRow,
                exclusion: None,
            }),
        );
        let plan = sales_plan(vec![], vec![moving_avg]);

        let sql = render(DatabaseType::PostgreSQL, &plan);

        assert!(sql.contains("AVG(revenue) OVER"));
        assert!(sql.contains("ROWS BETWEEN 6 PRECEDING AND CURRENT ROW"));
    }

    #[test]
    fn test_sqlserver_stddev_variance() {
        let plan = sales_plan(
            vec![],
            vec![
                whole_table(
                    WindowFunctionType::Stddev {
                        field: "revenue".to_string(),
                    },
                    "stddev",
                ),
                whole_table(
                    WindowFunctionType::Variance {
                        field: "revenue".to_string(),
                    },
                    "variance",
                ),
            ],
        );

        let sql = render(DatabaseType::SQLServer, &plan);

        assert!(sql.contains("STDEV(revenue)"));
        assert!(sql.contains("VAR(revenue)"));
    }
}