1use std::time::Duration;
10
11use sqlparser::ast::{
12 BinaryOperator, Expr, JoinConstraint, JoinOperator, Select, TableFactor, TableVersion,
13};
14
15use super::window_rewriter::WindowRewriter;
16use super::ParseError;
17
/// Kind of join recognised by the join analyzer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left (outer) join.
    Left,
    /// Right (outer) join.
    Right,
    /// Full outer join.
    Full,
    /// Left semi join.
    LeftSemi,
    /// Left anti join.
    LeftAnti,
    /// Right semi join.
    RightSemi,
    /// Right anti join.
    RightAnti,
    /// ASOF join.
    AsOf,
}
40
/// Direction in which an ASOF join looks for its matching row.
///
/// Formats through [`std::fmt::Display`] as the upper-case keywords
/// `BACKWARD` and `FORWARD`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// Backward search.
    Backward,
    /// Forward search.
    Forward,
}

impl std::fmt::Display for AsofSqlDirection {
    /// Writes the direction as its upper-case keyword.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Backward => "BACKWARD",
            Self::Forward => "FORWARD",
        };
        f.write_str(keyword)
    }
}
58
59#[derive(Debug, Clone)]
61pub struct JoinAnalysis {
62 pub join_type: JoinType,
64 pub left_table: String,
66 pub right_table: String,
68 pub left_key_column: String,
70 pub right_key_column: String,
72 pub time_bound: Option<Duration>,
74 pub is_lookup_join: bool,
76 pub left_alias: Option<String>,
78 pub right_alias: Option<String>,
80 pub is_asof_join: bool,
82 pub asof_direction: Option<AsofSqlDirection>,
84 pub left_time_column: Option<String>,
86 pub right_time_column: Option<String>,
88 pub asof_tolerance: Option<Duration>,
90 pub is_temporal_join: bool,
92 pub temporal_version_column: Option<String>,
94 pub additional_key_columns: Vec<(String, String)>,
96}
97
98impl JoinAnalysis {
99 #[must_use]
101 pub fn stream_stream(
102 left_table: String,
103 right_table: String,
104 left_key: String,
105 right_key: String,
106 time_bound: Duration,
107 join_type: JoinType,
108 ) -> Self {
109 Self {
110 join_type,
111 left_table,
112 right_table,
113 left_key_column: left_key,
114 right_key_column: right_key,
115 time_bound: Some(time_bound),
116 is_lookup_join: false,
117 left_alias: None,
118 right_alias: None,
119 is_asof_join: false,
120 asof_direction: None,
121 left_time_column: None,
122 right_time_column: None,
123 asof_tolerance: None,
124 is_temporal_join: false,
125 temporal_version_column: None,
126 additional_key_columns: vec![],
127 }
128 }
129
130 #[must_use]
132 pub fn lookup(
133 left_table: String,
134 right_table: String,
135 left_key: String,
136 right_key: String,
137 join_type: JoinType,
138 ) -> Self {
139 Self {
140 join_type,
141 left_table,
142 right_table,
143 left_key_column: left_key,
144 right_key_column: right_key,
145 time_bound: None,
146 is_lookup_join: true,
147 left_alias: None,
148 right_alias: None,
149 is_asof_join: false,
150 asof_direction: None,
151 left_time_column: None,
152 right_time_column: None,
153 asof_tolerance: None,
154 is_temporal_join: false,
155 temporal_version_column: None,
156 additional_key_columns: vec![],
157 }
158 }
159
160 #[must_use]
162 #[allow(clippy::too_many_arguments)]
163 pub fn asof(
164 left_table: String,
165 right_table: String,
166 left_key: String,
167 right_key: String,
168 direction: AsofSqlDirection,
169 left_time_col: String,
170 right_time_col: String,
171 tolerance: Option<Duration>,
172 ) -> Self {
173 Self {
174 join_type: JoinType::AsOf,
175 left_table,
176 right_table,
177 left_key_column: left_key,
178 right_key_column: right_key,
179 time_bound: None,
180 is_lookup_join: false,
181 left_alias: None,
182 right_alias: None,
183 is_asof_join: true,
184 asof_direction: Some(direction),
185 left_time_column: Some(left_time_col),
186 right_time_column: Some(right_time_col),
187 asof_tolerance: tolerance,
188 is_temporal_join: false,
189 temporal_version_column: None,
190 additional_key_columns: vec![],
191 }
192 }
193
194 #[must_use]
196 pub fn temporal(
197 left_table: String,
198 right_table: String,
199 left_key: String,
200 right_key: String,
201 version_column: String,
202 join_type: JoinType,
203 ) -> Self {
204 Self {
205 join_type,
206 left_table,
207 right_table,
208 left_key_column: left_key,
209 right_key_column: right_key,
210 time_bound: None,
211 is_lookup_join: false,
212 left_alias: None,
213 right_alias: None,
214 is_asof_join: false,
215 asof_direction: None,
216 left_time_column: None,
217 right_time_column: None,
218 asof_tolerance: None,
219 is_temporal_join: true,
220 temporal_version_column: Some(version_column),
221 additional_key_columns: vec![],
222 }
223 }
224}
225
226pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
234 let from = &select.from;
235 if from.is_empty() {
236 return Ok(None);
237 }
238
239 let first_table = &from[0];
240 if first_table.joins.is_empty() {
241 return Ok(None);
242 }
243
244 let left_table = extract_table_name(&first_table.relation)?;
246 let left_alias = extract_table_alias(&first_table.relation);
247
248 let join = &first_table.joins[0];
250 let right_table = extract_table_name(&join.relation)?;
251 let right_alias = extract_table_alias(&join.relation);
252
253 let join_type = map_join_operator(&join.join_operator);
254
255 if let JoinOperator::AsOf {
257 match_condition,
258 constraint,
259 } = &join.join_operator
260 {
261 let (direction, left_time, right_time, tolerance) =
262 analyze_asof_match_condition(match_condition)?;
263
264 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
266
267 let mut analysis = JoinAnalysis::asof(
268 left_table,
269 right_table,
270 left_key,
271 right_key,
272 direction,
273 left_time,
274 right_time,
275 tolerance,
276 );
277 analysis.left_alias = left_alias;
278 analysis.right_alias = right_alias;
279 return Ok(Some(analysis));
280 }
281
282 if let Some(version_col) = extract_temporal_version(&join.relation) {
284 let (left_key, right_key, additional, _) = analyze_join_constraint(&join.join_operator)?;
285 let mut analysis = JoinAnalysis::temporal(
286 left_table,
287 right_table,
288 left_key,
289 right_key,
290 version_col,
291 join_type,
292 );
293 analysis.left_alias = left_alias;
294 analysis.right_alias = right_alias;
295 analysis.additional_key_columns = additional;
296 return Ok(Some(analysis));
297 }
298
299 let (left_key, right_key, additional, time_bound) =
301 analyze_join_constraint(&join.join_operator)?;
302
303 let mut analysis = if let Some(tb) = time_bound {
304 JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
305 } else {
306 JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
307 };
308
309 analysis.left_alias = left_alias;
310 analysis.right_alias = right_alias;
311 analysis.additional_key_columns = additional;
312
313 Ok(Some(analysis))
314}
315
316fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
318 match factor {
319 TableFactor::Table { name, .. } => Ok(name.to_string()),
320 TableFactor::Derived { alias, .. } => {
321 if let Some(alias) = alias {
322 Ok(alias.name.value.clone())
323 } else {
324 Err(ParseError::StreamingError(
325 "Derived table without alias not supported".to_string(),
326 ))
327 }
328 }
329 _ => Err(ParseError::StreamingError(
330 "Unsupported table factor type".to_string(),
331 )),
332 }
333}
334
335fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
340 if let TableFactor::Table {
341 version: Some(TableVersion::ForSystemTimeAsOf(expr)),
342 ..
343 } = factor
344 {
345 Some(extract_column_name_from_expr(expr))
346 } else {
347 None
348 }
349}
350
351fn extract_column_name_from_expr(expr: &Expr) -> String {
355 match expr {
356 Expr::Identifier(ident) => ident.value.clone(),
357 Expr::CompoundIdentifier(parts) => parts
358 .last()
359 .map_or_else(|| expr.to_string(), |p| p.value.clone()),
360 _ => expr.to_string(),
361 }
362}
363
364fn extract_table_alias(factor: &TableFactor) -> Option<String> {
366 match factor {
367 TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
368 TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
369 _ => None,
370 }
371}
372
373fn map_join_operator(op: &JoinOperator) -> JoinType {
375 match op {
376 JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
377 JoinType::Inner
378 }
379 JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
380 JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
381 JoinOperator::LeftAnti(_) => JoinType::LeftAnti,
382 JoinOperator::AsOf { .. } => JoinType::AsOf,
383 JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
384 JoinOperator::RightSemi(_) => JoinType::RightSemi,
385 JoinOperator::RightAnti(_) | JoinOperator::Anti(_) => JoinType::RightAnti,
386 JoinOperator::FullOuter(_) => JoinType::Full,
387 _ => JoinType::Inner,
389 }
390}
391
392#[allow(clippy::type_complexity)]
394fn analyze_join_constraint(
395 op: &JoinOperator,
396) -> Result<(String, String, Vec<(String, String)>, Option<Duration>), ParseError> {
397 let constraint = get_join_constraint(op)?;
398
399 match constraint {
400 JoinConstraint::On(expr) => {
401 let (key_pairs, time_bound) = analyze_on_expression(expr)?;
402 if key_pairs.is_empty() {
403 return Ok((String::new(), String::new(), vec![], time_bound));
404 }
405 let (first_left, first_right) = key_pairs[0].clone();
406 let additional = key_pairs[1..].to_vec();
407 Ok((first_left, first_right, additional, time_bound))
408 }
409 JoinConstraint::Using(cols) => {
410 if cols.is_empty() {
411 return Err(ParseError::StreamingError(
412 "USING clause requires at least one column".to_string(),
413 ));
414 }
415 let first_col = cols[0].to_string();
417 let additional: Vec<(String, String)> = cols[1..]
419 .iter()
420 .map(|c| {
421 let col = c.to_string();
422 (col.clone(), col)
423 })
424 .collect();
425 Ok((first_col.clone(), first_col, additional, None))
426 }
427 JoinConstraint::Natural => Err(ParseError::StreamingError(
428 "NATURAL JOIN not supported for streaming".to_string(),
429 )),
430 JoinConstraint::None => Err(ParseError::StreamingError(
431 "JOIN without condition not supported for streaming".to_string(),
432 )),
433 }
434}
435
436fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
438 match op {
439 JoinOperator::Inner(constraint)
440 | JoinOperator::Join(constraint)
441 | JoinOperator::Left(constraint)
442 | JoinOperator::LeftOuter(constraint)
443 | JoinOperator::Right(constraint)
444 | JoinOperator::RightOuter(constraint)
445 | JoinOperator::FullOuter(constraint)
446 | JoinOperator::LeftSemi(constraint)
447 | JoinOperator::RightSemi(constraint)
448 | JoinOperator::LeftAnti(constraint)
449 | JoinOperator::RightAnti(constraint)
450 | JoinOperator::Semi(constraint)
451 | JoinOperator::Anti(constraint)
452 | JoinOperator::StraightJoin(constraint)
453 | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
454 JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
455 ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
456 ),
457 }
458}
459
460#[allow(clippy::type_complexity)]
462fn analyze_on_expression(
463 expr: &Expr,
464) -> Result<(Vec<(String, String)>, Option<Duration>), ParseError> {
465 match expr {
467 Expr::BinaryOp {
468 left,
469 op: BinaryOperator::And,
470 right,
471 } => {
472 let left_result = analyze_on_expression(left);
474 let right_result = analyze_on_expression(right);
475
476 match (left_result, right_result) {
478 (Ok((mut lk, lt)), Ok((rk, rt))) => {
479 lk.extend(rk);
480 Ok((lk, lt.or(rt)))
481 }
482 (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
483 (Err(e), Err(_)) => Err(e),
484 }
485 }
486 Expr::BinaryOp {
488 left,
489 op: BinaryOperator::Eq,
490 right,
491 } => {
492 let left_col = extract_column_ref(left);
493 let right_col = extract_column_ref(right);
494
495 match (left_col, right_col) {
496 (Some(l), Some(r)) => Ok((vec![(l, r)], None)),
497 _ => Err(ParseError::StreamingError(
498 "Cannot extract column references from equality condition".to_string(),
499 )),
500 }
501 }
502 Expr::Between {
504 expr: _,
505 low: _,
506 high,
507 ..
508 } => {
509 let time_bound = extract_time_bound_from_expr(high).ok();
511 Ok((vec![], time_bound))
512 }
513 Expr::BinaryOp {
515 left: _,
516 op:
517 BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
518 right,
519 } => {
520 let time_bound = extract_time_bound_from_expr(right).ok();
522 Ok((vec![], time_bound))
523 }
524 _ => Err(ParseError::StreamingError(format!(
525 "Unsupported join condition expression: {expr:?}"
526 ))),
527 }
528}
529
530fn extract_column_ref(expr: &Expr) -> Option<String> {
532 match expr {
533 Expr::Identifier(ident) => Some(ident.value.clone()),
534 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
535 _ => None,
536 }
537}
538
539fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
541 match expr {
542 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
544 Expr::BinaryOp {
546 left: _,
547 op: BinaryOperator::Plus | BinaryOperator::Minus,
548 right,
549 } => extract_time_bound_from_expr(right),
550 Expr::Nested(inner) => extract_time_bound_from_expr(inner),
552 _ => Err(ParseError::StreamingError(format!(
553 "Cannot extract time bound from: {expr:?}"
554 ))),
555 }
556}
557
558fn analyze_asof_match_condition(
562 expr: &Expr,
563) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
564 if let Expr::BinaryOp {
565 left,
566 op: BinaryOperator::And,
567 right,
568 } = expr
569 {
570 let dir_result = analyze_asof_direction(left);
572 let tol_result = extract_asof_tolerance(right);
573
574 match (dir_result, tol_result) {
575 (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
576 (Ok((dir, lt, rt)), Err(_)) => {
577 let dir2 = analyze_asof_direction(right);
579 let tol2 = extract_asof_tolerance(left);
580 match (dir2, tol2) {
581 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
582 _ => Ok((dir, lt, rt, None)),
583 }
584 }
585 (Err(_), _) => {
586 let dir2 = analyze_asof_direction(right);
588 let tol2 = extract_asof_tolerance(left);
589 match (dir2, tol2) {
590 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
591 (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
592 _ => Err(ParseError::StreamingError(
593 "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
594 )),
595 }
596 }
597 }
598 } else {
599 let (dir, lt, rt) = analyze_asof_direction(expr)?;
600 Ok((dir, lt, rt, None))
601 }
602}
603
604fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
606 match expr {
607 Expr::BinaryOp {
608 left,
609 op: BinaryOperator::GtEq,
610 right,
611 } => {
612 let left_col = extract_column_ref(left).ok_or_else(|| {
613 ParseError::StreamingError(
614 "Cannot extract left time column from MATCH_CONDITION".to_string(),
615 )
616 })?;
617 let right_col = extract_column_ref(right).ok_or_else(|| {
618 ParseError::StreamingError(
619 "Cannot extract right time column from MATCH_CONDITION".to_string(),
620 )
621 })?;
622 Ok((AsofSqlDirection::Backward, left_col, right_col))
623 }
624 Expr::BinaryOp {
625 left,
626 op: BinaryOperator::LtEq,
627 right,
628 } => {
629 let left_col = extract_column_ref(left).ok_or_else(|| {
630 ParseError::StreamingError(
631 "Cannot extract left time column from MATCH_CONDITION".to_string(),
632 )
633 })?;
634 let right_col = extract_column_ref(right).ok_or_else(|| {
635 ParseError::StreamingError(
636 "Cannot extract right time column from MATCH_CONDITION".to_string(),
637 )
638 })?;
639 Ok((AsofSqlDirection::Forward, left_col, right_col))
640 }
641 _ => Err(ParseError::StreamingError(
642 "ASOF MATCH_CONDITION must be >= or <= comparison".to_string(),
643 )),
644 }
645}
646
647fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
651 match expr {
652 Expr::BinaryOp {
653 left: _,
654 op: BinaryOperator::LtEq,
655 right,
656 } => {
657 match right.as_ref() {
659 Expr::Value(v) => {
660 if let sqlparser::ast::Value::Number(n, _) = &v.value {
661 let ms: u64 = n.parse().map_err(|_| {
662 ParseError::StreamingError(format!(
663 "Cannot parse tolerance as number: {n}"
664 ))
665 })?;
666 Ok(Duration::from_millis(ms))
667 } else {
668 Err(ParseError::StreamingError(
669 "ASOF tolerance must be a number or INTERVAL".to_string(),
670 ))
671 }
672 }
673 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
674 _ => Err(ParseError::StreamingError(
675 "ASOF tolerance must be a number or INTERVAL".to_string(),
676 )),
677 }
678 }
679 _ => Err(ParseError::StreamingError(
680 "ASOF tolerance expression must be <= comparison".to_string(),
681 )),
682 }
683}
684
685fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
687 match constraint {
688 JoinConstraint::On(expr) => extract_equality_columns(expr),
689 JoinConstraint::Using(cols) => {
690 if cols.is_empty() {
691 return Err(ParseError::StreamingError(
692 "USING clause requires at least one column".to_string(),
693 ));
694 }
695 let col = cols[0].to_string();
696 Ok((col.clone(), col))
697 }
698 _ => Err(ParseError::StreamingError(
699 "ASOF JOIN requires ON or USING constraint".to_string(),
700 )),
701 }
702}
703
704fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
706 match expr {
707 Expr::BinaryOp {
708 left,
709 op: BinaryOperator::Eq,
710 right,
711 } => {
712 let left_col = extract_column_ref(left).ok_or_else(|| {
713 ParseError::StreamingError("Cannot extract left key column".to_string())
714 })?;
715 let right_col = extract_column_ref(right).ok_or_else(|| {
716 ParseError::StreamingError("Cannot extract right key column".to_string())
717 })?;
718 Ok((left_col, right_col))
719 }
720 Expr::BinaryOp {
722 left,
723 op: BinaryOperator::And,
724 right,
725 } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
726 _ => Err(ParseError::StreamingError(
727 "ASOF JOIN ON clause must contain an equality condition".to_string(),
728 )),
729 }
730}
731
732#[must_use]
734pub fn has_join(select: &Select) -> bool {
735 !select.from.is_empty() && !select.from[0].joins.is_empty()
736}
737
738#[must_use]
740pub fn count_joins(select: &Select) -> usize {
741 select
742 .from
743 .iter()
744 .map(|table_with_joins| table_with_joins.joins.len())
745 .sum()
746}
747
748#[derive(Debug, Clone)]
753pub struct MultiJoinAnalysis {
754 pub joins: Vec<JoinAnalysis>,
756 pub tables: Vec<String>,
758}
759
760impl MultiJoinAnalysis {
761 #[must_use]
763 pub fn len(&self) -> usize {
764 self.joins.len()
765 }
766
767 #[must_use]
769 pub fn is_empty(&self) -> bool {
770 self.joins.is_empty()
771 }
772
773 #[must_use]
775 pub fn is_single(&self) -> bool {
776 self.joins.len() == 1
777 }
778
779 #[must_use]
781 pub fn first(&self) -> Option<&JoinAnalysis> {
782 self.joins.first()
783 }
784}
785
786pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
797 let from = &select.from;
798 if from.is_empty() {
799 return Ok(None);
800 }
801
802 let first_table = &from[0];
803 if first_table.joins.is_empty() {
804 return Ok(None);
805 }
806
807 let base_table = extract_table_name(&first_table.relation)?;
809 let base_alias = extract_table_alias(&first_table.relation);
810
811 let mut join_steps = Vec::with_capacity(first_table.joins.len());
812 let mut tables = vec![base_table.clone()];
813
814 let mut prev_left_table = base_table;
816 let mut prev_left_alias = base_alias;
817
818 for join in &first_table.joins {
819 let right_table = extract_table_name(&join.relation)?;
820 let right_alias = extract_table_alias(&join.relation);
821 tables.push(right_table.clone());
822
823 let join_type = map_join_operator(&join.join_operator);
824
825 if let JoinOperator::AsOf {
827 match_condition,
828 constraint,
829 } = &join.join_operator
830 {
831 let (direction, left_time, right_time, tolerance) =
832 analyze_asof_match_condition(match_condition)?;
833 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
834
835 let mut analysis = JoinAnalysis::asof(
836 prev_left_table.clone(),
837 right_table.clone(),
838 left_key,
839 right_key,
840 direction,
841 left_time,
842 right_time,
843 tolerance,
844 );
845 analysis.left_alias.clone_from(&prev_left_alias);
846 analysis.right_alias = right_alias;
847 join_steps.push(analysis);
848 } else if let Some(version_col) = extract_temporal_version(&join.relation) {
849 let (left_key, right_key, additional, _) =
851 analyze_join_constraint(&join.join_operator)?;
852
853 let mut analysis = JoinAnalysis::temporal(
854 prev_left_table.clone(),
855 right_table.clone(),
856 left_key,
857 right_key,
858 version_col,
859 join_type,
860 );
861 analysis.left_alias.clone_from(&prev_left_alias);
862 analysis.right_alias = right_alias;
863 analysis.additional_key_columns = additional;
864 join_steps.push(analysis);
865 } else {
866 let (left_key, right_key, additional, time_bound) =
868 analyze_join_constraint(&join.join_operator)?;
869
870 let mut analysis = if let Some(tb) = time_bound {
871 JoinAnalysis::stream_stream(
872 prev_left_table.clone(),
873 right_table.clone(),
874 left_key,
875 right_key,
876 tb,
877 join_type,
878 )
879 } else {
880 JoinAnalysis::lookup(
881 prev_left_table.clone(),
882 right_table.clone(),
883 left_key,
884 right_key,
885 join_type,
886 )
887 };
888 analysis.left_alias.clone_from(&prev_left_alias);
889 analysis.right_alias = right_alias;
890 analysis.additional_key_columns = additional;
891 join_steps.push(analysis);
892 }
893
894 prev_left_table = right_table;
896 prev_left_alias = extract_table_alias(&join.relation);
897 }
898
899 Ok(Some(MultiJoinAnalysis {
900 joins: join_steps,
901 tables,
902 }))
903}
904
905#[cfg(test)]
906mod tests {
907 use super::*;
908 use sqlparser::ast::{SetExpr, Statement};
909 use sqlparser::dialect::GenericDialect;
910 use sqlparser::parser::Parser;
911
912 fn parse_select(sql: &str) -> Select {
913 let dialect = GenericDialect {};
914 let statements = Parser::parse_sql(&dialect, sql).unwrap();
915 if let Statement::Query(query) = &statements[0] {
916 if let SetExpr::Select(select) = query.body.as_ref() {
917 return *select.clone();
918 }
919 }
920 panic!("Expected SELECT query");
921 }
922
923 #[test]
924 fn test_analyze_inner_join() {
925 let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
926 let select = parse_select(sql);
927
928 let analysis = analyze_join(&select).unwrap().unwrap();
929
930 assert_eq!(analysis.join_type, JoinType::Inner);
931 assert_eq!(analysis.left_table, "orders");
932 assert_eq!(analysis.right_table, "payments");
933 assert_eq!(analysis.left_key_column, "order_id");
934 assert_eq!(analysis.right_key_column, "order_id");
935 assert!(analysis.is_lookup_join); }
937
938 #[test]
939 fn test_analyze_left_join() {
940 let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
941 let select = parse_select(sql);
942
943 let analysis = analyze_join(&select).unwrap().unwrap();
944
945 assert_eq!(analysis.join_type, JoinType::Left);
946 assert_eq!(analysis.left_key_column, "customer_id");
947 assert_eq!(analysis.right_key_column, "id");
948 }
949
950 #[test]
951 fn test_analyze_join_using() {
952 let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
953 let select = parse_select(sql);
954
955 let analysis = analyze_join(&select).unwrap().unwrap();
956
957 assert_eq!(analysis.left_key_column, "order_id");
958 assert_eq!(analysis.right_key_column, "order_id");
959 }
960
961 #[test]
962 fn test_analyze_stream_stream_join_with_time_bound() {
963 let sql = "SELECT * FROM orders o
964 JOIN payments p ON o.order_id = p.order_id
965 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
966 let select = parse_select(sql);
967
968 let analysis = analyze_join(&select).unwrap().unwrap();
969
970 assert!(!analysis.is_lookup_join);
971 assert!(analysis.time_bound.is_some());
972 assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
973 }
974
975 #[test]
976 fn test_no_join() {
977 let sql = "SELECT * FROM orders";
978 let select = parse_select(sql);
979
980 let analysis = analyze_join(&select).unwrap();
981 assert!(analysis.is_none());
982 }
983
984 #[test]
985 fn test_has_join() {
986 let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
987 let sql_without_join = "SELECT * FROM orders";
988
989 let select_with = parse_select(sql_with_join);
990 let select_without = parse_select(sql_without_join);
991
992 assert!(has_join(&select_with));
993 assert!(!has_join(&select_without));
994 }
995
996 #[test]
997 fn test_count_joins() {
998 let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
999 let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
1000 let sql_zero = "SELECT * FROM a";
1001
1002 assert_eq!(count_joins(&parse_select(sql_one)), 1);
1003 assert_eq!(count_joins(&parse_select(sql_two)), 2);
1004 assert_eq!(count_joins(&parse_select(sql_zero)), 0);
1005 }
1006
1007 #[test]
1008 fn test_aliases() {
1009 let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
1010 let select = parse_select(sql);
1011
1012 let analysis = analyze_join(&select).unwrap().unwrap();
1013
1014 assert_eq!(analysis.left_alias, Some("o".to_string()));
1015 assert_eq!(analysis.right_alias, Some("p".to_string()));
1016 }
1017
1018 fn parse_select_snowflake(sql: &str) -> Select {
1021 let dialect = sqlparser::dialect::SnowflakeDialect {};
1022 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1023 if let Statement::Query(query) = &statements[0] {
1024 if let SetExpr::Select(select) = query.body.as_ref() {
1025 return *select.clone();
1026 }
1027 }
1028 panic!("Expected SELECT query");
1029 }
1030
1031 fn parse_select_laminar(sql: &str) -> Select {
1032 let dialect = crate::parser::dialect::LaminarDialect::default();
1033 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1034 if let Statement::Query(query) = &statements[0] {
1035 if let SetExpr::Select(select) = query.body.as_ref() {
1036 return *select.clone();
1037 }
1038 }
1039 panic!("Expected SELECT query");
1040 }
1041
1042 #[test]
1043 fn test_asof_join_backward() {
1044 let sql = "SELECT * FROM trades t \
1045 ASOF JOIN quotes q \
1046 MATCH_CONDITION(t.ts >= q.ts) \
1047 ON t.symbol = q.symbol";
1048 let select = parse_select_snowflake(sql);
1049 let analysis = analyze_join(&select).unwrap().unwrap();
1050
1051 assert!(analysis.is_asof_join);
1052 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1053 assert_eq!(analysis.join_type, JoinType::AsOf);
1054 assert!(analysis.asof_tolerance.is_none());
1055 }
1056
1057 #[test]
1058 fn test_asof_join_forward() {
1059 let sql = "SELECT * FROM trades t \
1060 ASOF JOIN quotes q \
1061 MATCH_CONDITION(t.ts <= q.ts) \
1062 ON t.symbol = q.symbol";
1063 let select = parse_select_snowflake(sql);
1064 let analysis = analyze_join(&select).unwrap().unwrap();
1065
1066 assert!(analysis.is_asof_join);
1067 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
1068 }
1069
1070 #[test]
1071 fn test_asof_join_with_tolerance() {
1072 let sql = "SELECT * FROM trades t \
1073 ASOF JOIN quotes q \
1074 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
1075 ON t.symbol = q.symbol";
1076 let select = parse_select_snowflake(sql);
1077 let analysis = analyze_join(&select).unwrap().unwrap();
1078
1079 assert!(analysis.is_asof_join);
1080 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1081 assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
1082 }
1083
1084 #[test]
1085 fn test_asof_join_with_interval_tolerance() {
1086 let sql = "SELECT * FROM trades t \
1087 ASOF JOIN quotes q \
1088 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
1089 ON t.symbol = q.symbol";
1090 let select = parse_select_snowflake(sql);
1091 let analysis = analyze_join(&select).unwrap().unwrap();
1092
1093 assert!(analysis.is_asof_join);
1094 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1095 assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
1096 }
1097
1098 #[test]
1099 fn test_asof_join_type_mapping() {
1100 let sql = "SELECT * FROM trades t \
1101 ASOF JOIN quotes q \
1102 MATCH_CONDITION(t.ts >= q.ts) \
1103 ON t.symbol = q.symbol";
1104 let select = parse_select_snowflake(sql);
1105 let analysis = analyze_join(&select).unwrap().unwrap();
1106
1107 assert_eq!(analysis.join_type, JoinType::AsOf);
1108 assert!(!analysis.is_lookup_join);
1109 }
1110
1111 #[test]
1112 fn test_asof_join_extracts_time_columns() {
1113 let sql = "SELECT * FROM trades t \
1114 ASOF JOIN quotes q \
1115 MATCH_CONDITION(t.ts >= q.ts) \
1116 ON t.symbol = q.symbol";
1117 let select = parse_select_snowflake(sql);
1118 let analysis = analyze_join(&select).unwrap().unwrap();
1119
1120 assert_eq!(analysis.left_time_column, Some("ts".to_string()));
1121 assert_eq!(analysis.right_time_column, Some("ts".to_string()));
1122 }
1123
1124 #[test]
1125 fn test_asof_join_extracts_key_columns() {
1126 let sql = "SELECT * FROM trades t \
1127 ASOF JOIN quotes q \
1128 MATCH_CONDITION(t.ts >= q.ts) \
1129 ON t.symbol = q.symbol";
1130 let select = parse_select_snowflake(sql);
1131 let analysis = analyze_join(&select).unwrap().unwrap();
1132
1133 assert_eq!(analysis.left_key_column, "symbol");
1134 assert_eq!(analysis.right_key_column, "symbol");
1135 }
1136
1137 #[test]
1138 fn test_asof_join_aliases() {
1139 let sql = "SELECT * FROM trades AS t \
1140 ASOF JOIN quotes AS q \
1141 MATCH_CONDITION(t.ts >= q.ts) \
1142 ON t.symbol = q.symbol";
1143 let select = parse_select_snowflake(sql);
1144 let analysis = analyze_join(&select).unwrap().unwrap();
1145
1146 assert_eq!(analysis.left_alias, Some("t".to_string()));
1147 assert_eq!(analysis.right_alias, Some("q".to_string()));
1148 assert_eq!(analysis.left_table, "trades");
1149 assert_eq!(analysis.right_table, "quotes");
1150 }
1151
1152 #[test]
1155 fn test_multi_join_single_backward_compat() {
1156 let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1157 let select = parse_select(sql);
1158 let multi = analyze_joins(&select).unwrap().unwrap();
1159
1160 assert!(multi.is_single());
1161 assert_eq!(multi.len(), 1);
1162 assert!(!multi.is_empty());
1163 let first = multi.first().unwrap();
1164 assert_eq!(first.left_table, "orders");
1165 assert_eq!(first.right_table, "payments");
1166 }
1167
1168 #[test]
1169 fn test_multi_join_two_way() {
1170 let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1171 let select = parse_select(sql);
1172 let multi = analyze_joins(&select).unwrap().unwrap();
1173
1174 assert_eq!(multi.len(), 2);
1175 assert!(!multi.is_single());
1176
1177 assert_eq!(multi.joins[0].left_table, "a");
1178 assert_eq!(multi.joins[0].right_table, "b");
1179 assert_eq!(multi.joins[0].left_key_column, "id");
1180 assert_eq!(multi.joins[0].right_key_column, "a_id");
1181
1182 assert_eq!(multi.joins[1].left_table, "b");
1183 assert_eq!(multi.joins[1].right_table, "c");
1184 assert_eq!(multi.joins[1].left_key_column, "id");
1185 assert_eq!(multi.joins[1].right_key_column, "b_id");
1186 }
1187
1188 #[test]
1189 fn test_multi_join_three_way() {
1190 let sql = "SELECT * FROM a \
1191 JOIN b ON a.id = b.a_id \
1192 JOIN c ON b.id = c.b_id \
1193 JOIN d ON c.id = d.c_id";
1194 let select = parse_select(sql);
1195 let multi = analyze_joins(&select).unwrap().unwrap();
1196
1197 assert_eq!(multi.len(), 3);
1198 assert_eq!(multi.tables.len(), 4);
1199 assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
1200 }
1201
/// An ASOF join followed by a plain equi-join in the same query: the first
/// analysis must be flagged as ASOF and the second as a lookup join.
#[test]
fn test_multi_join_mixed_asof_and_lookup() {
    let query = "SELECT * FROM trades t \
                 ASOF JOIN quotes q \
                 MATCH_CONDITION(t.ts >= q.ts) \
                 ON t.symbol = q.symbol \
                 JOIN products p ON q.product_id = p.id";
    // ASOF / MATCH_CONDITION syntax is parsed with the Snowflake-flavoured helper.
    let parsed = parse_select_snowflake(query);
    let analysis = analyze_joins(&parsed).unwrap().unwrap();

    assert_eq!(analysis.len(), 2);

    let asof_step = &analysis.joins[0];
    assert!(asof_step.is_asof_join);

    let lookup_step = &analysis.joins[1];
    assert!(lookup_step.is_lookup_join);
}
1217
/// A time-bounded join (`BETWEEN ... AND ... + INTERVAL`) followed by a plain
/// equi-join: the first must carry a time bound and not be a lookup join,
/// the second must be a lookup join.
#[test]
fn test_multi_join_stream_stream_and_lookup() {
    let query = "SELECT * FROM orders o \
                 JOIN payments p ON o.id = p.order_id \
                 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
                 JOIN customers c ON o.customer_id = c.id";
    let parsed = parse_select(query);
    let analysis = analyze_joins(&parsed).unwrap().unwrap();

    assert_eq!(analysis.len(), 2);

    // Join with the BETWEEN time predicate.
    let bounded = &analysis.joins[0];
    assert!(!bounded.is_lookup_join);
    assert!(bounded.time_bound.is_some());

    // Join without any time predicate.
    let unbounded = &analysis.joins[1];
    assert!(unbounded.is_lookup_join);
}
1232
/// `tables` must list each table of a two-join chain once, in query order.
#[test]
fn test_multi_join_tables_list() {
    let parsed = parse_select("SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id");
    let analysis = analyze_joins(&parsed).unwrap().unwrap();

    let expected = vec!["a", "b", "c"];
    assert_eq!(analysis.tables, expected);
}
1241
/// Table aliases declared with `AS` must be recorded on both sides of every
/// join in the chain.
#[test]
fn test_multi_join_aliases() {
    let query = "SELECT * FROM orders AS o \
                 JOIN payments AS p ON o.id = p.order_id \
                 JOIN refunds AS r ON p.id = r.payment_id";
    let parsed = parse_select(query);
    let analysis = analyze_joins(&parsed).unwrap().unwrap();

    // orders AS o  JOIN  payments AS p
    let first = &analysis.joins[0];
    assert_eq!(first.left_alias.as_deref(), Some("o"));
    assert_eq!(first.right_alias.as_deref(), Some("p"));

    // payments AS p  JOIN  refunds AS r
    let second = &analysis.joins[1];
    assert_eq!(second.left_alias.as_deref(), Some("p"));
    assert_eq!(second.right_alias.as_deref(), Some("r"));
}
1255
/// A query without any `JOIN` must analyze successfully and report `None`.
#[test]
fn test_multi_join_no_join_returns_none() {
    let parsed = parse_select("SELECT * FROM orders");
    let outcome = analyze_joins(&parsed).unwrap();

    assert!(outcome.is_none());
}
1263
/// `FOR SYSTEM_TIME AS OF <col>` on the right-hand table must mark the join
/// as temporal, record the version column, and leave the lookup/ASOF flags
/// unset while still extracting tables and key columns.
#[test]
fn test_temporal_join_detected() {
    let query = "SELECT o.*, p.price \
                 FROM orders o \
                 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                 ON o.product_id = p.id";
    let parsed = parse_select_laminar(query);
    let join = analyze_join(&parsed).unwrap().unwrap();

    // Temporal classification and the column the version is taken from.
    assert!(join.is_temporal_join);
    assert_eq!(join.temporal_version_column.as_deref(), Some("order_time"));

    // Tables and equi-join keys.
    assert_eq!(join.left_table, "orders");
    assert_eq!(join.right_table, "products");
    assert_eq!(join.left_key_column, "product_id");
    assert_eq!(join.right_key_column, "id");

    // A temporal join must not also be reported as lookup or ASOF.
    assert!(!join.is_lookup_join);
    assert!(!join.is_asof_join);
}
1287
/// The multi-join entry point must report the same temporal classification
/// for a single `FOR SYSTEM_TIME AS OF` join.
#[test]
fn test_temporal_join_via_analyze_joins() {
    let query = "SELECT o.*, p.price \
                 FROM orders o \
                 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                 ON o.product_id = p.id";
    let parsed = parse_select_laminar(query);
    let analysis = analyze_joins(&parsed).unwrap().unwrap();

    assert_eq!(analysis.len(), 1);

    let only_join = analysis.first().unwrap();
    assert!(only_join.is_temporal_join);
    assert_eq!(
        only_join.temporal_version_column.as_deref(),
        Some("order_time")
    );
}
1305
/// A plain equi-join must not be flagged as temporal and must carry no
/// version column.
#[test]
fn test_non_temporal_join_not_flagged() {
    let parsed = parse_select("SELECT * FROM orders o JOIN payments p ON o.id = p.order_id");
    let join = analyze_join(&parsed).unwrap().unwrap();

    assert!(!join.is_temporal_join);
    assert_eq!(join.temporal_version_column, None);
}
1315
/// With two equality predicates in `ON`, the first pair becomes the primary
/// key columns and the second pair lands in `additional_key_columns`.
#[test]
fn test_composite_join_keys() {
    let query = "SELECT * FROM orders o \
                 JOIN shipments s \
                 ON o.order_id = s.order_id AND o.region = s.region";
    let parsed = parse_select(query);
    let join = analyze_join(&parsed).unwrap().unwrap();

    // Primary key pair comes from the first equality.
    assert_eq!(join.left_key_column, "order_id");
    assert_eq!(join.right_key_column, "order_id");

    // Remaining equality is kept as an extra key pair.
    let extra_keys = &join.additional_key_columns;
    assert_eq!(extra_keys.len(), 1, "Should have 1 additional key pair");

    let (extra_left, extra_right) = &extra_keys[0];
    assert_eq!(extra_left.as_str(), "region");
    assert_eq!(extra_right.as_str(), "region");
}
1337
/// `USING (order_id, region)` must behave like a composite key: the first
/// column is the primary key on both sides, the second is an additional pair.
#[test]
fn test_composite_using_clause() {
    let parsed = parse_select("SELECT * FROM orders o JOIN shipments s USING (order_id, region)");
    let join = analyze_join(&parsed).unwrap().unwrap();

    // First USING column is the primary key on both sides.
    assert_eq!(join.left_key_column, "order_id");
    assert_eq!(join.right_key_column, "order_id");

    // Second USING column is reported as the single extra key pair.
    let extra_keys = &join.additional_key_columns;
    assert_eq!(
        extra_keys.len(),
        1,
        "USING(order_id, region) should have 1 additional key"
    );

    let (extra_left, extra_right) = &extra_keys[0];
    assert_eq!(extra_left.as_str(), "region");
    assert_eq!(extra_right.as_str(), "region");
}
1357}