1use sqlparser::ast::{
9 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, Query, Select, SelectItem,
10 SetExpr, Statement,
11};
12
13use super::{ParseError, WindowFunction};
14
/// Rewrites parsed SQL so that a query grouping by a streaming window function
/// (`TUMBLE`, `HOP`/`SLIDE`, `SESSION`) also projects `window_start` and `window_end`.
///
/// The type carries no state; all functionality is exposed as associated functions.
pub struct WindowRewriter;
17
18impl WindowRewriter {
19 pub fn rewrite_statement(stmt: &mut Statement) -> Result<(), ParseError> {
41 if let Statement::Query(query) = stmt {
42 Self::rewrite_query(query)?;
43 }
44 Ok(())
45 }
46
47 fn rewrite_query(query: &mut Query) -> Result<(), ParseError> {
49 if let SetExpr::Select(select) = &mut *query.body {
50 Self::rewrite_select(select)?;
51 }
52 Ok(())
53 }
54
55 fn rewrite_select(select: &mut Select) -> Result<(), ParseError> {
60 let window_func = Self::find_window_in_group_by(select)?;
62
63 if let Some(_window) = window_func {
64 Self::ensure_window_columns_in_projection(select);
66 }
67
68 Ok(())
69 }
70
71 fn find_window_in_group_by(select: &Select) -> Result<Option<WindowFunction>, ParseError> {
73 match &select.group_by {
75 sqlparser::ast::GroupByExpr::Expressions(exprs, _modifiers) => {
76 for expr in exprs {
77 if let Some(window) = Self::extract_window_function(expr)? {
78 return Ok(Some(window));
79 }
80 }
81 }
82 sqlparser::ast::GroupByExpr::All(_) => {}
83 }
84 Ok(None)
85 }
86
87 fn ensure_window_columns_in_projection(select: &mut Select) {
89 let has_window_start = Self::has_projection_column(select, "window_start");
90 let has_window_end = Self::has_projection_column(select, "window_end");
91
92 if !has_window_start {
94 select.projection.insert(
95 0,
96 SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_start"))),
97 );
98 }
99
100 if !has_window_end {
102 select.projection.insert(
103 1,
104 SelectItem::UnnamedExpr(Expr::Identifier(Ident::new("window_end"))),
105 );
106 }
107 }
108
109 fn has_projection_column(select: &Select, name: &str) -> bool {
111 select.projection.iter().any(|item| {
112 if let SelectItem::UnnamedExpr(Expr::Identifier(ident)) = item {
113 ident.value.eq_ignore_ascii_case(name)
114 } else if let SelectItem::ExprWithAlias { alias, .. } = item {
115 alias.value.eq_ignore_ascii_case(name)
116 } else {
117 false
118 }
119 })
120 }
121
122 #[must_use]
124 pub fn contains_window_function(expr: &Expr) -> bool {
125 match expr {
126 Expr::Function(func) => {
127 if let Some(name) = func.name.0.last() {
128 let func_name = name.to_string().to_uppercase();
129 matches!(func_name.as_str(), "TUMBLE" | "HOP" | "SLIDE" | "SESSION")
130 } else {
131 false
132 }
133 }
134 _ => false,
135 }
136 }
137
138 pub fn extract_window_function(expr: &Expr) -> Result<Option<WindowFunction>, ParseError> {
156 match expr {
157 Expr::Function(func) => {
158 let name =
159 func.name.0.last().ok_or_else(|| {
160 ParseError::WindowError("Empty function name".to_string())
161 })?;
162
163 let func_name = name.to_string().to_uppercase();
164
165 let args = Self::extract_function_args(&func.args)?;
167
168 match func_name.as_str() {
169 "TUMBLE" => Self::parse_tumble_args(&args),
170 "HOP" | "SLIDE" => Self::parse_hop_args(&args),
171 "SESSION" => Self::parse_session_args(&args),
172 _ => Ok(None),
173 }
174 }
175 _ => Ok(None),
176 }
177 }
178
179 fn extract_function_args(args: &FunctionArguments) -> Result<Vec<Expr>, ParseError> {
181 match args {
182 FunctionArguments::List(arg_list) => {
183 let mut result = Vec::new();
184 for arg in &arg_list.args {
185 if let Some(expr) = Self::extract_arg_expr(arg) {
186 result.push(expr);
187 }
188 }
189 Ok(result)
190 }
191 FunctionArguments::None => Ok(vec![]),
192 FunctionArguments::Subquery(_) => Err(ParseError::WindowError(
193 "Subquery arguments not supported for window functions".to_string(),
194 )),
195 }
196 }
197
198 fn extract_arg_expr(arg: &FunctionArg) -> Option<Expr> {
200 match arg {
201 FunctionArg::Unnamed(arg_expr) => match arg_expr {
202 FunctionArgExpr::Expr(expr) => Some(expr.clone()),
203 FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
204 },
205 FunctionArg::Named { arg, .. } | FunctionArg::ExprNamed { arg, .. } => match arg {
206 FunctionArgExpr::Expr(expr) => Some(expr.clone()),
207 FunctionArgExpr::Wildcard | FunctionArgExpr::QualifiedWildcard(_) => None,
208 },
209 }
210 }
211
212 fn parse_tumble_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
214 if args.len() != 2 {
215 return Err(ParseError::WindowError(format!(
216 "TUMBLE requires 2 arguments (time_column, interval), got {}",
217 args.len()
218 )));
219 }
220
221 Ok(Some(WindowFunction::Tumble {
222 time_column: Box::new(args[0].clone()),
223 interval: Box::new(args[1].clone()),
224 }))
225 }
226
227 fn parse_hop_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
229 if args.len() != 3 {
230 return Err(ParseError::WindowError(format!(
231 "HOP/SLIDE requires 3 arguments (time_column, slide_interval, window_size), got {}",
232 args.len()
233 )));
234 }
235
236 Ok(Some(WindowFunction::Hop {
237 time_column: Box::new(args[0].clone()),
238 slide_interval: Box::new(args[1].clone()),
239 window_interval: Box::new(args[2].clone()),
240 }))
241 }
242
243 fn parse_session_args(args: &[Expr]) -> Result<Option<WindowFunction>, ParseError> {
245 if args.len() != 2 {
246 return Err(ParseError::WindowError(format!(
247 "SESSION requires 2 arguments (time_column, gap_interval), got {}",
248 args.len()
249 )));
250 }
251
252 Ok(Some(WindowFunction::Session {
253 time_column: Box::new(args[0].clone()),
254 gap_interval: Box::new(args[1].clone()),
255 }))
256 }
257
258 #[must_use]
262 pub fn get_time_column_name(window: &WindowFunction) -> Option<String> {
263 let expr = match window {
264 WindowFunction::Tumble { time_column, .. }
265 | WindowFunction::Hop { time_column, .. }
266 | WindowFunction::Session { time_column, .. } => time_column.as_ref(),
267 };
268
269 match expr {
270 Expr::Identifier(ident) => Some(ident.value.clone()),
271 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
272 _ => None,
273 }
274 }
275
276 pub fn parse_interval_to_duration(expr: &Expr) -> Result<std::time::Duration, ParseError> {
284 match expr {
285 Expr::Interval(interval) => {
286 let value = Self::extract_interval_value(&interval.value)?;
288
289 let unit = interval
291 .leading_field
292 .clone()
293 .unwrap_or(sqlparser::ast::DateTimeField::Second);
294
295 let seconds =
296 match unit {
297 sqlparser::ast::DateTimeField::Second
298 | sqlparser::ast::DateTimeField::Seconds => value,
299 sqlparser::ast::DateTimeField::Minute
300 | sqlparser::ast::DateTimeField::Minutes => value * 60,
301 sqlparser::ast::DateTimeField::Hour
302 | sqlparser::ast::DateTimeField::Hours => value * 3600,
303 sqlparser::ast::DateTimeField::Day
304 | sqlparser::ast::DateTimeField::Days => value * 86400,
305 _ => {
306 return Err(ParseError::WindowError(format!(
307 "Unsupported interval unit: {unit:?}"
308 )))
309 }
310 };
311
312 Ok(std::time::Duration::from_secs(seconds))
313 }
314 Expr::Value(value_with_span) => {
316 use sqlparser::ast::Value;
317 if let Value::SingleQuotedString(s) = &value_with_span.value {
318 Self::parse_interval_string(s)
319 } else {
320 Err(ParseError::WindowError(format!(
321 "Expected string value, got: {value_with_span:?}"
322 )))
323 }
324 }
325 Expr::Identifier(ident) => Self::parse_interval_string(&ident.value),
327 _ => Err(ParseError::WindowError(format!(
328 "Expected INTERVAL expression, got: {expr:?}"
329 ))),
330 }
331 }
332
333 fn extract_interval_value(expr: &Expr) -> Result<u64, ParseError> {
335 match expr {
336 Expr::Value(value_with_span) => {
337 use sqlparser::ast::Value;
338 match &value_with_span.value {
339 Value::Number(n, _) => n.parse::<u64>().map_err(|_| {
340 ParseError::WindowError(format!("Invalid interval value: {n}"))
341 }),
342 Value::SingleQuotedString(s) => {
343 let num_str = s.split_whitespace().next().unwrap_or(s);
345 num_str.parse::<u64>().map_err(|_| {
346 ParseError::WindowError(format!("Invalid interval value: {s}"))
347 })
348 }
349 _ => Err(ParseError::WindowError(format!(
350 "Unsupported value type in interval: {value_with_span:?}"
351 ))),
352 }
353 }
354 _ => Err(ParseError::WindowError(format!(
355 "Cannot extract interval value from: {expr:?}"
356 ))),
357 }
358 }
359
360 fn parse_interval_string(s: &str) -> Result<std::time::Duration, ParseError> {
362 let parts: Vec<&str> = s.split_whitespace().collect();
363 if parts.is_empty() {
364 return Err(ParseError::WindowError("Empty interval string".to_string()));
365 }
366
367 let value: u64 = parts[0].parse().map_err(|_| {
368 ParseError::WindowError(format!("Invalid interval value: {}", parts[0]))
369 })?;
370
371 let unit = if parts.len() > 1 {
372 parts[1].to_uppercase()
373 } else {
374 "SECOND".to_string()
375 };
376
377 let seconds = match unit.as_str() {
378 "SECOND" | "SECONDS" | "S" => value,
379 "MINUTE" | "MINUTES" | "M" => value * 60,
380 "HOUR" | "HOURS" | "H" => value * 3600,
381 "DAY" | "DAYS" | "D" => value * 86400,
382 _ => {
383 return Err(ParseError::WindowError(format!(
384 "Unsupported interval unit: {unit}"
385 )))
386 }
387 };
388
389 Ok(std::time::Duration::from_secs(seconds))
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// Parses `sql`, which must contain exactly one statement, with the generic dialect.
    fn parse_one(sql: &str) -> Statement {
        let dialect = GenericDialect {};
        let mut statements = Parser::parse_sql(&dialect, sql).unwrap();
        assert_eq!(statements.len(), 1, "expected a single statement in: {sql}");
        statements.remove(0)
    }

    /// Borrows the top-level `SELECT` of `stmt`. Panics on any other shape, so a test
    /// can never pass vacuously because a pattern failed to match.
    fn select_of(stmt: &Statement) -> &Select {
        match stmt {
            Statement::Query(query) => match &*query.body {
                SetExpr::Select(select) => &**select,
                other => panic!("expected a plain SELECT body, got: {other:?}"),
            },
            other => panic!("expected a query statement, got: {other:?}"),
        }
    }

    /// First projection item of `sql`, which must be an unaliased expression.
    fn first_projection_expr(sql: &str) -> Expr {
        let stmt = parse_one(sql);
        match &select_of(&stmt).projection[0] {
            SelectItem::UnnamedExpr(expr) => expr.clone(),
            other => panic!("expected an unaliased projection, got: {other:?}"),
        }
    }

    /// First `GROUP BY` expression of `sql`.
    fn first_group_by_expr(sql: &str) -> Expr {
        let stmt = parse_one(sql);
        match &select_of(&stmt).group_by {
            sqlparser::ast::GroupByExpr::Expressions(exprs, _) => exprs
                .first()
                .cloned()
                .expect("GROUP BY should have at least one expression"),
            other => panic!("expected GROUP BY expressions, got: {other:?}"),
        }
    }

    /// Extracts the window function from `expr`, failing if there is none.
    fn window_of(expr: &Expr) -> WindowFunction {
        WindowRewriter::extract_window_function(expr)
            .unwrap()
            .expect("expression should be a window function")
    }

    /// Display form of every projection item of `stmt`.
    fn projection_strings(stmt: &Statement) -> Vec<String> {
        select_of(stmt)
            .projection
            .iter()
            .map(ToString::to_string)
            .collect()
    }

    #[test]
    fn test_contains_window_function() {
        let expr =
            first_projection_expr("SELECT TUMBLE(event_time, INTERVAL '5' MINUTE) FROM events");
        assert!(WindowRewriter::contains_window_function(&expr));
    }

    #[test]
    fn test_contains_window_function_rejects_other_functions() {
        let expr = first_projection_expr("SELECT COUNT(*) FROM events");
        assert!(!WindowRewriter::contains_window_function(&expr));
    }

    #[test]
    fn test_rewrite_statement() {
        let mut stmt = parse_one("SELECT COUNT(*) FROM events GROUP BY event_time");
        let before = projection_strings(&stmt);

        WindowRewriter::rewrite_statement(&mut stmt).unwrap();

        // No window function in GROUP BY: the projection is left untouched.
        assert_eq!(projection_strings(&stmt), before);
    }

    #[test]
    fn test_rewrite_adds_window_columns() {
        let mut stmt =
            parse_one("SELECT COUNT(*) FROM events GROUP BY TUMBLE(ts, INTERVAL '5' MINUTE)");

        WindowRewriter::rewrite_statement(&mut stmt).unwrap();
        assert_eq!(
            projection_strings(&stmt),
            ["window_start", "window_end", "COUNT(*)"]
        );

        // A second pass must not duplicate the columns.
        WindowRewriter::rewrite_statement(&mut stmt).unwrap();
        assert_eq!(
            projection_strings(&stmt),
            ["window_start", "window_end", "COUNT(*)"]
        );
    }

    #[test]
    fn test_extract_tumble_with_actual_args() {
        let expr =
            first_projection_expr("SELECT TUMBLE(order_time, INTERVAL '10' MINUTE) FROM orders");
        match window_of(&expr) {
            WindowFunction::Tumble {
                time_column,
                interval,
            } => {
                assert_eq!(time_column.to_string(), "order_time");
                assert!(interval.to_string().contains("10"));
            }
            _ => panic!("Expected Tumble window"),
        }
    }

    #[test]
    fn test_extract_hop_with_actual_args() {
        let expr = first_projection_expr(
            "SELECT HOP(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM readings",
        );
        match window_of(&expr) {
            WindowFunction::Hop {
                time_column,
                slide_interval,
                window_interval,
            } => {
                assert_eq!(time_column.to_string(), "ts");
                assert!(slide_interval.to_string().contains('1'));
                assert!(window_interval.to_string().contains('5'));
            }
            _ => panic!("Expected Hop window"),
        }
    }

    #[test]
    fn test_extract_session_with_actual_args() {
        let expr =
            first_projection_expr("SELECT SESSION(click_time, INTERVAL '30' MINUTE) FROM clicks");
        match window_of(&expr) {
            WindowFunction::Session {
                time_column,
                gap_interval,
            } => {
                assert_eq!(time_column.to_string(), "click_time");
                assert!(gap_interval.to_string().contains("30"));
            }
            _ => panic!("Expected Session window"),
        }
    }

    #[test]
    fn test_tumble_wrong_args_count() {
        let expr = first_projection_expr("SELECT TUMBLE(ts) FROM events");
        let err = WindowRewriter::extract_window_function(&expr).unwrap_err();
        assert!(err.to_string().contains("2 arguments"));
    }

    #[test]
    fn test_hop_wrong_args_count() {
        let expr = first_projection_expr("SELECT HOP(ts, INTERVAL '1' MINUTE) FROM events");
        let err = WindowRewriter::extract_window_function(&expr).unwrap_err();
        assert!(err.to_string().contains("3 arguments"));
    }

    #[test]
    fn test_slide_alias_for_hop() {
        let expr = first_projection_expr(
            "SELECT SLIDE(ts, INTERVAL '1' MINUTE, INTERVAL '5' MINUTE) FROM events",
        );
        assert!(matches!(window_of(&expr), WindowFunction::Hop { .. }));
    }

    #[test]
    fn test_non_window_function_is_not_extracted() {
        let expr = first_projection_expr("SELECT COUNT(*) FROM events");
        assert!(WindowRewriter::extract_window_function(&expr)
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_get_time_column_name() {
        let expr =
            first_projection_expr("SELECT TUMBLE(my_timestamp, INTERVAL '5' MINUTE) FROM events");
        let window = window_of(&expr);
        assert_eq!(
            WindowRewriter::get_time_column_name(&window),
            Some("my_timestamp".to_string())
        );
    }

    #[test]
    fn test_parse_interval_to_duration() {
        let expr = first_group_by_expr(
            "SELECT COUNT(*) FROM events GROUP BY TUMBLE(ts, INTERVAL '5' MINUTE)",
        );
        match window_of(&expr) {
            WindowFunction::Tumble { interval, .. } => {
                let duration = WindowRewriter::parse_interval_to_duration(&interval).unwrap();
                assert_eq!(duration, std::time::Duration::from_secs(300));
            }
            _ => panic!("Expected Tumble window"),
        }
    }

    #[test]
    fn test_parse_interval_string_formats() {
        let cases = [
            ("5 MINUTE", 300),
            ("5 MINUTES", 300),
            ("1 HOUR", 3600),
            ("2 HOURS", 7200),
            ("10 SECOND", 10),
            ("1 DAY", 86400),
            ("7", 7),
        ];

        for (input, expected_secs) in cases {
            let result = WindowRewriter::parse_interval_string(input).unwrap();
            assert_eq!(
                result,
                std::time::Duration::from_secs(expected_secs),
                "Failed for input: {input}"
            );
        }
    }

    #[test]
    fn test_parse_interval_string_rejects_bad_input() {
        for input in ["", "   ", "abc MINUTE", "5 FORTNIGHT"] {
            assert!(
                WindowRewriter::parse_interval_string(input).is_err(),
                "should reject: {input:?}"
            );
        }
    }

    #[test]
    fn test_window_in_group_by() {
        let stmt = parse_one(
            "SELECT user_id, COUNT(*) FROM events GROUP BY TUMBLE(event_time, INTERVAL '1' HOUR), user_id",
        );
        let window = WindowRewriter::find_window_in_group_by(select_of(&stmt))
            .unwrap()
            .expect("GROUP BY should contain a window function");

        match window {
            WindowFunction::Tumble { time_column, .. } => {
                assert_eq!(time_column.to_string(), "event_time");
            }
            _ => panic!("Expected Tumble window"),
        }
    }

    #[test]
    fn test_no_window_in_group_by() {
        let stmt = parse_one("SELECT user_id, COUNT(*) FROM events GROUP BY user_id");
        assert!(WindowRewriter::find_window_in_group_by(select_of(&stmt))
            .unwrap()
            .is_none());
    }
}