1use std::time::Duration;
10
11use sqlparser::ast::{BinaryOperator, Expr, JoinConstraint, JoinOperator, Select, TableFactor};
12
13use super::window_rewriter::WindowRewriter;
14use super::ParseError;
15
/// Kind of join reported by the join analyzer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left join.
    Left,
    /// Right join.
    Right,
    /// Full outer join.
    Full,
    /// ASOF join.
    AsOf,
}
30
/// Search direction of an ASOF join as expressed in the SQL match condition.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// Rendered as `BACKWARD`.
    Backward,
    /// Rendered as `FORWARD`.
    Forward,
}

impl std::fmt::Display for AsofSqlDirection {
    /// Writes the upper-case keyword for the direction.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Backward => "BACKWARD",
            Self::Forward => "FORWARD",
        };
        f.write_str(keyword)
    }
}
48
49#[derive(Debug, Clone)]
51pub struct JoinAnalysis {
52 pub join_type: JoinType,
54 pub left_table: String,
56 pub right_table: String,
58 pub left_key_column: String,
60 pub right_key_column: String,
62 pub time_bound: Option<Duration>,
64 pub is_lookup_join: bool,
66 pub left_alias: Option<String>,
68 pub right_alias: Option<String>,
70 pub is_asof_join: bool,
72 pub asof_direction: Option<AsofSqlDirection>,
74 pub left_time_column: Option<String>,
76 pub right_time_column: Option<String>,
78 pub asof_tolerance: Option<Duration>,
80}
81
82impl JoinAnalysis {
83 #[must_use]
85 pub fn stream_stream(
86 left_table: String,
87 right_table: String,
88 left_key: String,
89 right_key: String,
90 time_bound: Duration,
91 join_type: JoinType,
92 ) -> Self {
93 Self {
94 join_type,
95 left_table,
96 right_table,
97 left_key_column: left_key,
98 right_key_column: right_key,
99 time_bound: Some(time_bound),
100 is_lookup_join: false,
101 left_alias: None,
102 right_alias: None,
103 is_asof_join: false,
104 asof_direction: None,
105 left_time_column: None,
106 right_time_column: None,
107 asof_tolerance: None,
108 }
109 }
110
111 #[must_use]
113 pub fn lookup(
114 left_table: String,
115 right_table: String,
116 left_key: String,
117 right_key: String,
118 join_type: JoinType,
119 ) -> Self {
120 Self {
121 join_type,
122 left_table,
123 right_table,
124 left_key_column: left_key,
125 right_key_column: right_key,
126 time_bound: None,
127 is_lookup_join: true,
128 left_alias: None,
129 right_alias: None,
130 is_asof_join: false,
131 asof_direction: None,
132 left_time_column: None,
133 right_time_column: None,
134 asof_tolerance: None,
135 }
136 }
137
138 #[must_use]
140 #[allow(clippy::too_many_arguments)]
141 pub fn asof(
142 left_table: String,
143 right_table: String,
144 left_key: String,
145 right_key: String,
146 direction: AsofSqlDirection,
147 left_time_col: String,
148 right_time_col: String,
149 tolerance: Option<Duration>,
150 ) -> Self {
151 Self {
152 join_type: JoinType::AsOf,
153 left_table,
154 right_table,
155 left_key_column: left_key,
156 right_key_column: right_key,
157 time_bound: None,
158 is_lookup_join: false,
159 left_alias: None,
160 right_alias: None,
161 is_asof_join: true,
162 asof_direction: Some(direction),
163 left_time_column: Some(left_time_col),
164 right_time_column: Some(right_time_col),
165 asof_tolerance: tolerance,
166 }
167 }
168}
169
170pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
178 let from = &select.from;
179 if from.is_empty() {
180 return Ok(None);
181 }
182
183 let first_table = &from[0];
184 if first_table.joins.is_empty() {
185 return Ok(None);
186 }
187
188 let left_table = extract_table_name(&first_table.relation)?;
190 let left_alias = extract_table_alias(&first_table.relation);
191
192 let join = &first_table.joins[0];
194 let right_table = extract_table_name(&join.relation)?;
195 let right_alias = extract_table_alias(&join.relation);
196
197 let join_type = map_join_operator(&join.join_operator);
198
199 if let JoinOperator::AsOf {
201 match_condition,
202 constraint,
203 } = &join.join_operator
204 {
205 let (direction, left_time, right_time, tolerance) =
206 analyze_asof_match_condition(match_condition)?;
207
208 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
210
211 let mut analysis = JoinAnalysis::asof(
212 left_table,
213 right_table,
214 left_key,
215 right_key,
216 direction,
217 left_time,
218 right_time,
219 tolerance,
220 );
221 analysis.left_alias = left_alias;
222 analysis.right_alias = right_alias;
223 return Ok(Some(analysis));
224 }
225
226 let (left_key, right_key, time_bound) = analyze_join_constraint(&join.join_operator)?;
228
229 let mut analysis = if let Some(tb) = time_bound {
230 JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
231 } else {
232 JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
233 };
234
235 analysis.left_alias = left_alias;
236 analysis.right_alias = right_alias;
237
238 Ok(Some(analysis))
239}
240
241fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
243 match factor {
244 TableFactor::Table { name, .. } => Ok(name.to_string()),
245 TableFactor::Derived { alias, .. } => {
246 if let Some(alias) = alias {
247 Ok(alias.name.value.clone())
248 } else {
249 Err(ParseError::StreamingError(
250 "Derived table without alias not supported".to_string(),
251 ))
252 }
253 }
254 _ => Err(ParseError::StreamingError(
255 "Unsupported table factor type".to_string(),
256 )),
257 }
258}
259
260fn extract_table_alias(factor: &TableFactor) -> Option<String> {
262 match factor {
263 TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
264 TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
265 _ => None,
266 }
267}
268
269fn map_join_operator(op: &JoinOperator) -> JoinType {
271 match op {
272 JoinOperator::Inner(_)
273 | JoinOperator::Join(_)
274 | JoinOperator::CrossJoin(_)
275 | JoinOperator::CrossApply
276 | JoinOperator::OuterApply
277 | JoinOperator::StraightJoin(_) => JoinType::Inner,
278 JoinOperator::Left(_)
279 | JoinOperator::LeftOuter(_)
280 | JoinOperator::LeftSemi(_)
281 | JoinOperator::LeftAnti(_)
282 | JoinOperator::Semi(_) => JoinType::Left,
283 JoinOperator::AsOf { .. } => JoinType::AsOf,
284 JoinOperator::Right(_)
285 | JoinOperator::RightOuter(_)
286 | JoinOperator::RightSemi(_)
287 | JoinOperator::RightAnti(_)
288 | JoinOperator::Anti(_) => JoinType::Right,
289 JoinOperator::FullOuter(_) => JoinType::Full,
290 }
291}
292
293fn analyze_join_constraint(
295 op: &JoinOperator,
296) -> Result<(String, String, Option<Duration>), ParseError> {
297 let constraint = get_join_constraint(op)?;
298
299 match constraint {
300 JoinConstraint::On(expr) => analyze_on_expression(expr),
301 JoinConstraint::Using(cols) => {
302 if cols.is_empty() {
303 return Err(ParseError::StreamingError(
304 "USING clause requires at least one column".to_string(),
305 ));
306 }
307 let col = cols[0].to_string();
310 Ok((col.clone(), col, None))
311 }
312 JoinConstraint::Natural => Err(ParseError::StreamingError(
313 "NATURAL JOIN not supported for streaming".to_string(),
314 )),
315 JoinConstraint::None => Err(ParseError::StreamingError(
316 "JOIN without condition not supported for streaming".to_string(),
317 )),
318 }
319}
320
321fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
323 match op {
324 JoinOperator::Inner(constraint)
325 | JoinOperator::Join(constraint)
326 | JoinOperator::Left(constraint)
327 | JoinOperator::LeftOuter(constraint)
328 | JoinOperator::Right(constraint)
329 | JoinOperator::RightOuter(constraint)
330 | JoinOperator::FullOuter(constraint)
331 | JoinOperator::LeftSemi(constraint)
332 | JoinOperator::RightSemi(constraint)
333 | JoinOperator::LeftAnti(constraint)
334 | JoinOperator::RightAnti(constraint)
335 | JoinOperator::Semi(constraint)
336 | JoinOperator::Anti(constraint)
337 | JoinOperator::StraightJoin(constraint)
338 | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
339 JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
340 ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
341 ),
342 }
343}
344
345fn analyze_on_expression(expr: &Expr) -> Result<(String, String, Option<Duration>), ParseError> {
347 match expr {
349 Expr::BinaryOp {
350 left,
351 op: BinaryOperator::And,
352 right,
353 } => {
354 let left_result = analyze_on_expression(left);
356 let right_result = analyze_on_expression(right);
357
358 match (left_result, right_result) {
360 (Ok((lk, rk, None)), Ok((_, _, time))) if !lk.is_empty() => Ok((lk, rk, time)),
361 (Ok((_, _, time)), Ok((lk, rk, None))) if !lk.is_empty() => Ok((lk, rk, time)),
362 (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
363 (Ok((lk, rk, t1)), Ok((_, _, t2))) => {
364 Ok((lk, rk, t1.or(t2)))
366 }
367 (Err(e), Err(_)) => Err(e),
368 }
369 }
370 Expr::BinaryOp {
372 left,
373 op: BinaryOperator::Eq,
374 right,
375 } => {
376 let left_col = extract_column_ref(left);
377 let right_col = extract_column_ref(right);
378
379 match (left_col, right_col) {
380 (Some(l), Some(r)) => Ok((l, r, None)),
381 _ => Err(ParseError::StreamingError(
382 "Cannot extract column references from equality condition".to_string(),
383 )),
384 }
385 }
386 Expr::Between {
388 expr: _,
389 low: _,
390 high,
391 ..
392 } => {
393 let time_bound = extract_time_bound_from_expr(high).ok();
395 Ok((String::new(), String::new(), time_bound))
396 }
397 Expr::BinaryOp {
399 left: _,
400 op:
401 BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
402 right,
403 } => {
404 let time_bound = extract_time_bound_from_expr(right).ok();
406 Ok((String::new(), String::new(), time_bound))
407 }
408 _ => Err(ParseError::StreamingError(format!(
409 "Unsupported join condition expression: {expr:?}"
410 ))),
411 }
412}
413
414fn extract_column_ref(expr: &Expr) -> Option<String> {
416 match expr {
417 Expr::Identifier(ident) => Some(ident.value.clone()),
418 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
419 _ => None,
420 }
421}
422
423fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
425 match expr {
426 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
428 Expr::BinaryOp {
430 left: _,
431 op: BinaryOperator::Plus | BinaryOperator::Minus,
432 right,
433 } => extract_time_bound_from_expr(right),
434 Expr::Nested(inner) => extract_time_bound_from_expr(inner),
436 _ => Err(ParseError::StreamingError(format!(
437 "Cannot extract time bound from: {expr:?}"
438 ))),
439 }
440}
441
442fn analyze_asof_match_condition(
446 expr: &Expr,
447) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
448 if let Expr::BinaryOp {
449 left,
450 op: BinaryOperator::And,
451 right,
452 } = expr
453 {
454 let dir_result = analyze_asof_direction(left);
456 let tol_result = extract_asof_tolerance(right);
457
458 match (dir_result, tol_result) {
459 (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
460 (Ok((dir, lt, rt)), Err(_)) => {
461 let dir2 = analyze_asof_direction(right);
463 let tol2 = extract_asof_tolerance(left);
464 match (dir2, tol2) {
465 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
466 _ => Ok((dir, lt, rt, None)),
467 }
468 }
469 (Err(_), _) => {
470 let dir2 = analyze_asof_direction(right);
472 let tol2 = extract_asof_tolerance(left);
473 match (dir2, tol2) {
474 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
475 (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
476 _ => Err(ParseError::StreamingError(
477 "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
478 )),
479 }
480 }
481 }
482 } else {
483 let (dir, lt, rt) = analyze_asof_direction(expr)?;
484 Ok((dir, lt, rt, None))
485 }
486}
487
488fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
490 match expr {
491 Expr::BinaryOp {
492 left,
493 op: BinaryOperator::GtEq,
494 right,
495 } => {
496 let left_col = extract_column_ref(left).ok_or_else(|| {
497 ParseError::StreamingError(
498 "Cannot extract left time column from MATCH_CONDITION".to_string(),
499 )
500 })?;
501 let right_col = extract_column_ref(right).ok_or_else(|| {
502 ParseError::StreamingError(
503 "Cannot extract right time column from MATCH_CONDITION".to_string(),
504 )
505 })?;
506 Ok((AsofSqlDirection::Backward, left_col, right_col))
507 }
508 Expr::BinaryOp {
509 left,
510 op: BinaryOperator::LtEq,
511 right,
512 } => {
513 let left_col = extract_column_ref(left).ok_or_else(|| {
514 ParseError::StreamingError(
515 "Cannot extract left time column from MATCH_CONDITION".to_string(),
516 )
517 })?;
518 let right_col = extract_column_ref(right).ok_or_else(|| {
519 ParseError::StreamingError(
520 "Cannot extract right time column from MATCH_CONDITION".to_string(),
521 )
522 })?;
523 Ok((AsofSqlDirection::Forward, left_col, right_col))
524 }
525 _ => Err(ParseError::StreamingError(
526 "ASOF MATCH_CONDITION must be >= or <= comparison".to_string(),
527 )),
528 }
529}
530
531fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
535 match expr {
536 Expr::BinaryOp {
537 left: _,
538 op: BinaryOperator::LtEq,
539 right,
540 } => {
541 match right.as_ref() {
543 Expr::Value(v) => {
544 if let sqlparser::ast::Value::Number(n, _) = &v.value {
545 let ms: u64 = n.parse().map_err(|_| {
546 ParseError::StreamingError(format!(
547 "Cannot parse tolerance as number: {n}"
548 ))
549 })?;
550 Ok(Duration::from_millis(ms))
551 } else {
552 Err(ParseError::StreamingError(
553 "ASOF tolerance must be a number or INTERVAL".to_string(),
554 ))
555 }
556 }
557 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
558 _ => Err(ParseError::StreamingError(
559 "ASOF tolerance must be a number or INTERVAL".to_string(),
560 )),
561 }
562 }
563 _ => Err(ParseError::StreamingError(
564 "ASOF tolerance expression must be <= comparison".to_string(),
565 )),
566 }
567}
568
569fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
571 match constraint {
572 JoinConstraint::On(expr) => extract_equality_columns(expr),
573 JoinConstraint::Using(cols) => {
574 if cols.is_empty() {
575 return Err(ParseError::StreamingError(
576 "USING clause requires at least one column".to_string(),
577 ));
578 }
579 let col = cols[0].to_string();
580 Ok((col.clone(), col))
581 }
582 _ => Err(ParseError::StreamingError(
583 "ASOF JOIN requires ON or USING constraint".to_string(),
584 )),
585 }
586}
587
588fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
590 match expr {
591 Expr::BinaryOp {
592 left,
593 op: BinaryOperator::Eq,
594 right,
595 } => {
596 let left_col = extract_column_ref(left).ok_or_else(|| {
597 ParseError::StreamingError("Cannot extract left key column".to_string())
598 })?;
599 let right_col = extract_column_ref(right).ok_or_else(|| {
600 ParseError::StreamingError("Cannot extract right key column".to_string())
601 })?;
602 Ok((left_col, right_col))
603 }
604 Expr::BinaryOp {
606 left,
607 op: BinaryOperator::And,
608 right,
609 } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
610 _ => Err(ParseError::StreamingError(
611 "ASOF JOIN ON clause must contain an equality condition".to_string(),
612 )),
613 }
614}
615
616#[must_use]
618pub fn has_join(select: &Select) -> bool {
619 !select.from.is_empty() && !select.from[0].joins.is_empty()
620}
621
622#[must_use]
624pub fn count_joins(select: &Select) -> usize {
625 select
626 .from
627 .iter()
628 .map(|table_with_joins| table_with_joins.joins.len())
629 .sum()
630}
631
632#[derive(Debug, Clone)]
637pub struct MultiJoinAnalysis {
638 pub joins: Vec<JoinAnalysis>,
640 pub tables: Vec<String>,
642}
643
644impl MultiJoinAnalysis {
645 #[must_use]
647 pub fn len(&self) -> usize {
648 self.joins.len()
649 }
650
651 #[must_use]
653 pub fn is_empty(&self) -> bool {
654 self.joins.is_empty()
655 }
656
657 #[must_use]
659 pub fn is_single(&self) -> bool {
660 self.joins.len() == 1
661 }
662
663 #[must_use]
665 pub fn first(&self) -> Option<&JoinAnalysis> {
666 self.joins.first()
667 }
668}
669
670pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
681 let from = &select.from;
682 if from.is_empty() {
683 return Ok(None);
684 }
685
686 let first_table = &from[0];
687 if first_table.joins.is_empty() {
688 return Ok(None);
689 }
690
691 let base_table = extract_table_name(&first_table.relation)?;
693 let base_alias = extract_table_alias(&first_table.relation);
694
695 let mut join_steps = Vec::with_capacity(first_table.joins.len());
696 let mut tables = vec![base_table.clone()];
697
698 let mut prev_left_table = base_table;
700 let mut prev_left_alias = base_alias;
701
702 for join in &first_table.joins {
703 let right_table = extract_table_name(&join.relation)?;
704 let right_alias = extract_table_alias(&join.relation);
705 tables.push(right_table.clone());
706
707 let join_type = map_join_operator(&join.join_operator);
708
709 if let JoinOperator::AsOf {
711 match_condition,
712 constraint,
713 } = &join.join_operator
714 {
715 let (direction, left_time, right_time, tolerance) =
716 analyze_asof_match_condition(match_condition)?;
717 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
718
719 let mut analysis = JoinAnalysis::asof(
720 prev_left_table.clone(),
721 right_table.clone(),
722 left_key,
723 right_key,
724 direction,
725 left_time,
726 right_time,
727 tolerance,
728 );
729 analysis.left_alias.clone_from(&prev_left_alias);
730 analysis.right_alias = right_alias;
731 join_steps.push(analysis);
732 } else {
733 let (left_key, right_key, time_bound) = analyze_join_constraint(&join.join_operator)?;
735
736 let mut analysis = if let Some(tb) = time_bound {
737 JoinAnalysis::stream_stream(
738 prev_left_table.clone(),
739 right_table.clone(),
740 left_key,
741 right_key,
742 tb,
743 join_type,
744 )
745 } else {
746 JoinAnalysis::lookup(
747 prev_left_table.clone(),
748 right_table.clone(),
749 left_key,
750 right_key,
751 join_type,
752 )
753 };
754 analysis.left_alias.clone_from(&prev_left_alias);
755 analysis.right_alias = right_alias;
756 join_steps.push(analysis);
757 }
758
759 prev_left_table = right_table;
761 prev_left_alias = extract_table_alias(&join.relation);
762 }
763
764 Ok(Some(MultiJoinAnalysis {
765 joins: join_steps,
766 tables,
767 }))
768}
769
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::ast::{SetExpr, Statement};
    use sqlparser::dialect::GenericDialect;
    use sqlparser::parser::Parser;

    /// ASOF join with a backward (`>=`) match condition, shared by several tests.
    const ASOF_BACKWARD_SQL: &str = "SELECT * FROM trades t \
         ASOF JOIN quotes q \
         MATCH_CONDITION(t.ts >= q.ts) \
         ON t.symbol = q.symbol";

    /// Two chained lookup joins `a -> b -> c`.
    const CHAIN_ABC_SQL: &str = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";

    /// Pulls the `SELECT` body out of the first parsed statement.
    fn first_select(statements: &[Statement]) -> Select {
        if let Statement::Query(query) = &statements[0] {
            if let SetExpr::Select(select) = query.body.as_ref() {
                return Select::clone(select);
            }
        }
        panic!("Expected SELECT query");
    }

    /// Parses `sql` with the generic dialect.
    fn parse_select(sql: &str) -> Select {
        let statements = Parser::parse_sql(&GenericDialect {}, sql).unwrap();
        first_select(&statements)
    }

    /// Parses `sql` with the Snowflake dialect (needed for `ASOF JOIN`).
    fn parse_select_snowflake(sql: &str) -> Select {
        let statements =
            Parser::parse_sql(&sqlparser::dialect::SnowflakeDialect {}, sql).unwrap();
        first_select(&statements)
    }

    /// Analyzes the first join of a generic-dialect query, expecting one to exist.
    fn analyze(sql: &str) -> JoinAnalysis {
        analyze_join(&parse_select(sql)).unwrap().unwrap()
    }

    /// Analyzes the first join of a Snowflake-dialect query, expecting one to exist.
    fn analyze_snowflake(sql: &str) -> JoinAnalysis {
        analyze_join(&parse_select_snowflake(sql)).unwrap().unwrap()
    }

    /// Analyzes all joins of a generic-dialect query, expecting at least one.
    fn analyze_all(sql: &str) -> MultiJoinAnalysis {
        analyze_joins(&parse_select(sql)).unwrap().unwrap()
    }

    #[test]
    fn test_analyze_inner_join() {
        let analysis =
            analyze("SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id");

        assert_eq!(analysis.join_type, JoinType::Inner);
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "payments");
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
        assert!(analysis.is_lookup_join);
    }

    #[test]
    fn test_analyze_left_join() {
        let analysis =
            analyze("SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id");

        assert_eq!(analysis.join_type, JoinType::Left);
        assert_eq!(analysis.left_key_column, "customer_id");
        assert_eq!(analysis.right_key_column, "id");
    }

    #[test]
    fn test_analyze_join_using() {
        let analysis = analyze("SELECT * FROM orders o JOIN payments p USING (order_id)");

        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
    }

    #[test]
    fn test_analyze_stream_stream_join_with_time_bound() {
        let sql = "SELECT * FROM orders o
            JOIN payments p ON o.order_id = p.order_id
            AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
        let analysis = analyze(sql);

        assert!(!analysis.is_lookup_join);
        assert!(analysis.time_bound.is_some());
        assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
    }

    #[test]
    fn test_no_join() {
        let select = parse_select("SELECT * FROM orders");

        let analysis = analyze_join(&select).unwrap();
        assert!(analysis.is_none());
    }

    #[test]
    fn test_has_join() {
        let select_with = parse_select("SELECT * FROM orders o JOIN payments p ON o.id = p.order_id");
        let select_without = parse_select("SELECT * FROM orders");

        assert!(has_join(&select_with));
        assert!(!has_join(&select_without));
    }

    #[test]
    fn test_count_joins() {
        let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
        let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let sql_zero = "SELECT * FROM a";

        assert_eq!(count_joins(&parse_select(sql_one)), 1);
        assert_eq!(count_joins(&parse_select(sql_two)), 2);
        assert_eq!(count_joins(&parse_select(sql_zero)), 0);
    }

    #[test]
    fn test_aliases() {
        let analysis = analyze("SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id");

        assert_eq!(analysis.left_alias, Some("o".to_string()));
        assert_eq!(analysis.right_alias, Some("p".to_string()));
    }

    #[test]
    fn test_asof_join_backward() {
        let analysis = analyze_snowflake(ASOF_BACKWARD_SQL);

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_forward() {
        let sql = "SELECT * FROM trades t \
             ASOF JOIN quotes q \
             MATCH_CONDITION(t.ts <= q.ts) \
             ON t.symbol = q.symbol";
        let analysis = analyze_snowflake(sql);

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
    }

    #[test]
    fn test_asof_join_with_tolerance() {
        let sql = "SELECT * FROM trades t \
             ASOF JOIN quotes q \
             MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
             ON t.symbol = q.symbol";
        let analysis = analyze_snowflake(sql);

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
    }

    #[test]
    fn test_asof_join_with_interval_tolerance() {
        let sql = "SELECT * FROM trades t \
             ASOF JOIN quotes q \
             MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
             ON t.symbol = q.symbol";
        let analysis = analyze_snowflake(sql);

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_asof_join_type_mapping() {
        let analysis = analyze_snowflake(ASOF_BACKWARD_SQL);

        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(!analysis.is_lookup_join);
    }

    #[test]
    fn test_asof_join_extracts_time_columns() {
        let analysis = analyze_snowflake(ASOF_BACKWARD_SQL);

        assert_eq!(analysis.left_time_column, Some("ts".to_string()));
        assert_eq!(analysis.right_time_column, Some("ts".to_string()));
    }

    #[test]
    fn test_asof_join_extracts_key_columns() {
        let analysis = analyze_snowflake(ASOF_BACKWARD_SQL);

        assert_eq!(analysis.left_key_column, "symbol");
        assert_eq!(analysis.right_key_column, "symbol");
    }

    #[test]
    fn test_asof_join_aliases() {
        let sql = "SELECT * FROM trades AS t \
             ASOF JOIN quotes AS q \
             MATCH_CONDITION(t.ts >= q.ts) \
             ON t.symbol = q.symbol";
        let analysis = analyze_snowflake(sql);

        assert_eq!(analysis.left_alias, Some("t".to_string()));
        assert_eq!(analysis.right_alias, Some("q".to_string()));
        assert_eq!(analysis.left_table, "trades");
        assert_eq!(analysis.right_table, "quotes");
    }

    #[test]
    fn test_multi_join_single_backward_compat() {
        let multi = analyze_all("SELECT * FROM orders o JOIN payments p ON o.id = p.order_id");

        assert!(multi.is_single());
        assert_eq!(multi.len(), 1);
        assert!(!multi.is_empty());
        let first = multi.first().unwrap();
        assert_eq!(first.left_table, "orders");
        assert_eq!(first.right_table, "payments");
    }

    #[test]
    fn test_multi_join_two_way() {
        let multi = analyze_all(CHAIN_ABC_SQL);

        assert_eq!(multi.len(), 2);
        assert!(!multi.is_single());

        let (first, second) = (&multi.joins[0], &multi.joins[1]);

        assert_eq!(first.left_table, "a");
        assert_eq!(first.right_table, "b");
        assert_eq!(first.left_key_column, "id");
        assert_eq!(first.right_key_column, "a_id");

        assert_eq!(second.left_table, "b");
        assert_eq!(second.right_table, "c");
        assert_eq!(second.left_key_column, "id");
        assert_eq!(second.right_key_column, "b_id");
    }

    #[test]
    fn test_multi_join_three_way() {
        let sql = "SELECT * FROM a \
             JOIN b ON a.id = b.a_id \
             JOIN c ON b.id = c.b_id \
             JOIN d ON c.id = d.c_id";
        let multi = analyze_all(sql);

        assert_eq!(multi.len(), 3);
        assert_eq!(multi.tables.len(), 4);
        assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
    }

    #[test]
    fn test_multi_join_mixed_asof_and_lookup() {
        let sql = "SELECT * FROM trades t \
             ASOF JOIN quotes q \
             MATCH_CONDITION(t.ts >= q.ts) \
             ON t.symbol = q.symbol \
             JOIN products p ON q.product_id = p.id";
        let select = parse_select_snowflake(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(multi.joins[0].is_asof_join);
        assert!(multi.joins[1].is_lookup_join);
    }

    #[test]
    fn test_multi_join_stream_stream_and_lookup() {
        let sql = "SELECT * FROM orders o \
             JOIN payments p ON o.id = p.order_id \
             AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
             JOIN customers c ON o.customer_id = c.id";
        let multi = analyze_all(sql);

        assert_eq!(multi.len(), 2);
        // First step carries a time bound, second step does not.
        assert!(!multi.joins[0].is_lookup_join);
        assert!(multi.joins[0].time_bound.is_some());
        assert!(multi.joins[1].is_lookup_join);
    }

    #[test]
    fn test_multi_join_tables_list() {
        let multi = analyze_all(CHAIN_ABC_SQL);

        assert_eq!(multi.tables, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_multi_join_aliases() {
        let sql = "SELECT * FROM orders AS o \
             JOIN payments AS p ON o.id = p.order_id \
             JOIN refunds AS r ON p.id = r.payment_id";
        let multi = analyze_all(sql);

        assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
        assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
    }

    #[test]
    fn test_multi_join_no_join_returns_none() {
        let select = parse_select("SELECT * FROM orders");
        let multi = analyze_joins(&select).unwrap();
        assert!(multi.is_none());
    }
}