1use std::time::Duration;
10
11use sqlparser::ast::{
12 BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
13 JoinOperator, Select, TableFactor, TableVersion,
14};
15
16use super::window_rewriter::WindowRewriter;
17use super::ParseError;
18
/// Logical join kind recognised by the streaming JOIN analyzer.
///
/// Produced from the parsed `sqlparser` join operator by `map_join_operator`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `JOIN` / `INNER JOIN` / `STRAIGHT_JOIN`.
    Inner,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
    /// `FULL OUTER JOIN`.
    Full,
    /// `LEFT SEMI JOIN` / `SEMI JOIN`.
    LeftSemi,
    /// `LEFT ANTI JOIN` / `ANTI JOIN`.
    LeftAnti,
    /// `RIGHT SEMI JOIN`.
    RightSemi,
    /// `RIGHT ANTI JOIN`.
    RightAnti,
    /// `ASOF JOIN ... MATCH_CONDITION(...)`.
    AsOf,
}
41
/// Matching direction of an ASOF join, derived from its `MATCH_CONDITION`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// `left_ts >= right_ts`: match the latest right row at or before the left row.
    Backward,
    /// `left_ts <= right_ts`: match the earliest right row at or after the left row.
    Forward,
    /// `NEAREST(left_ts, right_ts)`: match the closest right row in either direction.
    Nearest,
}
52
53impl std::fmt::Display for AsofSqlDirection {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 AsofSqlDirection::Backward => write!(f, "BACKWARD"),
57 AsofSqlDirection::Forward => write!(f, "FORWARD"),
58 AsofSqlDirection::Nearest => write!(f, "NEAREST"),
59 }
60 }
61}
62
63#[derive(Debug, Clone)]
65pub struct JoinAnalysis {
66 pub join_type: JoinType,
68 pub left_table: String,
70 pub right_table: String,
72 pub left_key_column: String,
74 pub right_key_column: String,
76 pub time_bound: Option<Duration>,
78 pub is_lookup_join: bool,
80 pub left_alias: Option<String>,
82 pub right_alias: Option<String>,
84 pub is_asof_join: bool,
86 pub asof_direction: Option<AsofSqlDirection>,
88 pub left_time_column: Option<String>,
90 pub right_time_column: Option<String>,
92 pub asof_tolerance: Option<Duration>,
94 pub is_temporal_join: bool,
96 pub temporal_version_column: Option<String>,
98 pub additional_key_columns: Vec<(String, String)>,
100}
101
102impl JoinAnalysis {
103 #[must_use]
105 pub fn stream_stream(
106 left_table: String,
107 right_table: String,
108 left_key: String,
109 right_key: String,
110 time_bound: Duration,
111 join_type: JoinType,
112 ) -> Self {
113 Self {
114 join_type,
115 left_table,
116 right_table,
117 left_key_column: left_key,
118 right_key_column: right_key,
119 time_bound: Some(time_bound),
120 is_lookup_join: false,
121 left_alias: None,
122 right_alias: None,
123 is_asof_join: false,
124 asof_direction: None,
125 left_time_column: None,
126 right_time_column: None,
127 asof_tolerance: None,
128 is_temporal_join: false,
129 temporal_version_column: None,
130 additional_key_columns: vec![],
131 }
132 }
133
134 #[must_use]
136 pub fn lookup(
137 left_table: String,
138 right_table: String,
139 left_key: String,
140 right_key: String,
141 join_type: JoinType,
142 ) -> Self {
143 Self {
144 join_type,
145 left_table,
146 right_table,
147 left_key_column: left_key,
148 right_key_column: right_key,
149 time_bound: None,
150 is_lookup_join: true,
151 left_alias: None,
152 right_alias: None,
153 is_asof_join: false,
154 asof_direction: None,
155 left_time_column: None,
156 right_time_column: None,
157 asof_tolerance: None,
158 is_temporal_join: false,
159 temporal_version_column: None,
160 additional_key_columns: vec![],
161 }
162 }
163
164 #[must_use]
166 #[allow(clippy::too_many_arguments)]
167 pub fn asof(
168 left_table: String,
169 right_table: String,
170 left_key: String,
171 right_key: String,
172 direction: AsofSqlDirection,
173 left_time_col: String,
174 right_time_col: String,
175 tolerance: Option<Duration>,
176 ) -> Self {
177 Self {
178 join_type: JoinType::AsOf,
179 left_table,
180 right_table,
181 left_key_column: left_key,
182 right_key_column: right_key,
183 time_bound: None,
184 is_lookup_join: false,
185 left_alias: None,
186 right_alias: None,
187 is_asof_join: true,
188 asof_direction: Some(direction),
189 left_time_column: Some(left_time_col),
190 right_time_column: Some(right_time_col),
191 asof_tolerance: tolerance,
192 is_temporal_join: false,
193 temporal_version_column: None,
194 additional_key_columns: vec![],
195 }
196 }
197
198 #[must_use]
200 pub fn temporal(
201 left_table: String,
202 right_table: String,
203 left_key: String,
204 right_key: String,
205 version_column: String,
206 join_type: JoinType,
207 ) -> Self {
208 Self {
209 join_type,
210 left_table,
211 right_table,
212 left_key_column: left_key,
213 right_key_column: right_key,
214 time_bound: None,
215 is_lookup_join: false,
216 left_alias: None,
217 right_alias: None,
218 is_asof_join: false,
219 asof_direction: None,
220 left_time_column: None,
221 right_time_column: None,
222 asof_tolerance: None,
223 is_temporal_join: true,
224 temporal_version_column: Some(version_column),
225 additional_key_columns: vec![],
226 }
227 }
228}
229
230pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
238 let from = &select.from;
239 if from.is_empty() {
240 return Ok(None);
241 }
242
243 let first_table = &from[0];
244 if first_table.joins.is_empty() {
245 return Ok(None);
246 }
247
248 let left_table = extract_table_name(&first_table.relation)?;
250 let left_alias = extract_table_alias(&first_table.relation);
251
252 let join = &first_table.joins[0];
254 let right_table = extract_table_name(&join.relation)?;
255 let right_alias = extract_table_alias(&join.relation);
256
257 let join_type = map_join_operator(&join.join_operator);
258
259 if let JoinOperator::AsOf {
261 match_condition,
262 constraint,
263 } = &join.join_operator
264 {
265 let (direction, left_time, right_time, tolerance) =
266 analyze_asof_match_condition(match_condition)?;
267
268 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
270
271 let mut analysis = JoinAnalysis::asof(
272 left_table,
273 right_table,
274 left_key,
275 right_key,
276 direction,
277 left_time,
278 right_time,
279 tolerance,
280 );
281 analysis.left_alias = left_alias;
282 analysis.right_alias = right_alias;
283 return Ok(Some(analysis));
284 }
285
286 if let Some(version_col) = extract_temporal_version(&join.relation) {
288 let (left_key, right_key, additional, _) = analyze_join_constraint(&join.join_operator)?;
289 let mut analysis = JoinAnalysis::temporal(
290 left_table,
291 right_table,
292 left_key,
293 right_key,
294 version_col,
295 join_type,
296 );
297 analysis.left_alias = left_alias;
298 analysis.right_alias = right_alias;
299 analysis.additional_key_columns = additional;
300 return Ok(Some(analysis));
301 }
302
303 let (left_key, right_key, additional, time_bound) =
305 analyze_join_constraint(&join.join_operator)?;
306
307 let mut analysis = if let Some(tb) = time_bound {
308 JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
309 } else {
310 JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
311 };
312
313 analysis.left_alias = left_alias;
314 analysis.right_alias = right_alias;
315 analysis.additional_key_columns = additional;
316
317 Ok(Some(analysis))
318}
319
320fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
322 match factor {
323 TableFactor::Table { name, .. } => Ok(name.to_string()),
324 TableFactor::Derived { alias, .. } => {
325 if let Some(alias) = alias {
326 Ok(alias.name.value.clone())
327 } else {
328 Err(ParseError::StreamingError(
329 "Derived table without alias not supported".to_string(),
330 ))
331 }
332 }
333 _ => Err(ParseError::StreamingError(
334 "Unsupported table factor type".to_string(),
335 )),
336 }
337}
338
339fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
344 if let TableFactor::Table {
345 version: Some(TableVersion::ForSystemTimeAsOf(expr)),
346 ..
347 } = factor
348 {
349 Some(extract_column_name_from_expr(expr))
350 } else {
351 None
352 }
353}
354
355fn extract_column_name_from_expr(expr: &Expr) -> String {
359 match expr {
360 Expr::Identifier(ident) => ident.value.clone(),
361 Expr::CompoundIdentifier(parts) => parts
362 .last()
363 .map_or_else(|| expr.to_string(), |p| p.value.clone()),
364 _ => expr.to_string(),
365 }
366}
367
368fn extract_table_alias(factor: &TableFactor) -> Option<String> {
370 match factor {
371 TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
372 TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
373 _ => None,
374 }
375}
376
377fn map_join_operator(op: &JoinOperator) -> JoinType {
379 match op {
380 JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
381 JoinType::Inner
382 }
383 JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
384 JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
385 JoinOperator::LeftAnti(_) | JoinOperator::Anti(_) => JoinType::LeftAnti,
386 JoinOperator::AsOf { .. } => JoinType::AsOf,
387 JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
388 JoinOperator::RightSemi(_) => JoinType::RightSemi,
389 JoinOperator::RightAnti(_) => JoinType::RightAnti,
390 JoinOperator::FullOuter(_) => JoinType::Full,
391 _ => JoinType::Inner,
393 }
394}
395
396#[allow(clippy::type_complexity)]
398fn analyze_join_constraint(
399 op: &JoinOperator,
400) -> Result<(String, String, Vec<(String, String)>, Option<Duration>), ParseError> {
401 let constraint = get_join_constraint(op)?;
402
403 match constraint {
404 JoinConstraint::On(expr) => {
405 let (key_pairs, time_bound) = analyze_on_expression(expr)?;
406 if key_pairs.is_empty() {
407 return Ok((String::new(), String::new(), vec![], time_bound));
408 }
409 let (first_left, first_right) = key_pairs[0].clone();
410 let additional = key_pairs[1..].to_vec();
411 Ok((first_left, first_right, additional, time_bound))
412 }
413 JoinConstraint::Using(cols) => {
414 if cols.is_empty() {
415 return Err(ParseError::StreamingError(
416 "USING clause requires at least one column".to_string(),
417 ));
418 }
419 let first_col = cols[0].to_string();
421 let additional: Vec<(String, String)> = cols[1..]
423 .iter()
424 .map(|c| {
425 let col = c.to_string();
426 (col.clone(), col)
427 })
428 .collect();
429 Ok((first_col.clone(), first_col, additional, None))
430 }
431 JoinConstraint::Natural => Err(ParseError::StreamingError(
432 "NATURAL JOIN not supported for streaming".to_string(),
433 )),
434 JoinConstraint::None => Err(ParseError::StreamingError(
435 "JOIN without condition not supported for streaming".to_string(),
436 )),
437 }
438}
439
440fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
442 match op {
443 JoinOperator::Inner(constraint)
444 | JoinOperator::Join(constraint)
445 | JoinOperator::Left(constraint)
446 | JoinOperator::LeftOuter(constraint)
447 | JoinOperator::Right(constraint)
448 | JoinOperator::RightOuter(constraint)
449 | JoinOperator::FullOuter(constraint)
450 | JoinOperator::LeftSemi(constraint)
451 | JoinOperator::RightSemi(constraint)
452 | JoinOperator::LeftAnti(constraint)
453 | JoinOperator::RightAnti(constraint)
454 | JoinOperator::Semi(constraint)
455 | JoinOperator::Anti(constraint)
456 | JoinOperator::StraightJoin(constraint)
457 | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
458 JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
459 ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
460 ),
461 }
462}
463
464#[allow(clippy::type_complexity)]
466fn analyze_on_expression(
467 expr: &Expr,
468) -> Result<(Vec<(String, String)>, Option<Duration>), ParseError> {
469 match expr {
471 Expr::BinaryOp {
472 left,
473 op: BinaryOperator::And,
474 right,
475 } => {
476 let left_result = analyze_on_expression(left);
478 let right_result = analyze_on_expression(right);
479
480 match (left_result, right_result) {
482 (Ok((mut lk, lt)), Ok((rk, rt))) => {
483 lk.extend(rk);
484 Ok((lk, lt.or(rt)))
485 }
486 (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
487 (Err(e), Err(_)) => Err(e),
488 }
489 }
490 Expr::BinaryOp {
492 left,
493 op: BinaryOperator::Eq,
494 right,
495 } => {
496 let left_col = extract_column_ref(left);
497 let right_col = extract_column_ref(right);
498
499 match (left_col, right_col) {
500 (Some(l), Some(r)) => Ok((vec![(l, r)], None)),
501 _ => Err(ParseError::StreamingError(
502 "Cannot extract column references from equality condition".to_string(),
503 )),
504 }
505 }
506 Expr::Between {
508 expr: _,
509 low: _,
510 high,
511 ..
512 } => {
513 let time_bound = extract_time_bound_from_expr(high).ok();
515 Ok((vec![], time_bound))
516 }
517 Expr::BinaryOp {
519 left: _,
520 op:
521 BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
522 right,
523 } => {
524 let time_bound = extract_time_bound_from_expr(right).ok();
526 Ok((vec![], time_bound))
527 }
528 _ => Err(ParseError::StreamingError(format!(
529 "Unsupported join condition expression: {expr:?}"
530 ))),
531 }
532}
533
534fn extract_column_from_func_arg(arg: &FunctionArg) -> Option<String> {
536 let (FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
537 | FunctionArg::Named {
538 arg: FunctionArgExpr::Expr(expr),
539 ..
540 }
541 | FunctionArg::ExprNamed {
542 arg: FunctionArgExpr::Expr(expr),
543 ..
544 }) = arg
545 else {
546 return None;
547 };
548 extract_column_ref(expr)
549}
550
551fn extract_column_ref(expr: &Expr) -> Option<String> {
553 match expr {
554 Expr::Identifier(ident) => Some(ident.value.clone()),
555 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
556 _ => None,
557 }
558}
559
560fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
562 match expr {
563 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
565 Expr::BinaryOp {
567 left: _,
568 op: BinaryOperator::Plus | BinaryOperator::Minus,
569 right,
570 } => extract_time_bound_from_expr(right),
571 Expr::Nested(inner) => extract_time_bound_from_expr(inner),
573 _ => Err(ParseError::StreamingError(format!(
574 "Cannot extract time bound from: {expr:?}"
575 ))),
576 }
577}
578
579fn analyze_asof_match_condition(
583 expr: &Expr,
584) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
585 if let Expr::BinaryOp {
586 left,
587 op: BinaryOperator::And,
588 right,
589 } = expr
590 {
591 let dir_result = analyze_asof_direction(left);
593 let tol_result = extract_asof_tolerance(right);
594
595 match (dir_result, tol_result) {
596 (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
597 (Ok((dir, lt, rt)), Err(_)) => {
598 let dir2 = analyze_asof_direction(right);
600 let tol2 = extract_asof_tolerance(left);
601 match (dir2, tol2) {
602 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
603 _ => Ok((dir, lt, rt, None)),
604 }
605 }
606 (Err(_), _) => {
607 let dir2 = analyze_asof_direction(right);
609 let tol2 = extract_asof_tolerance(left);
610 match (dir2, tol2) {
611 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
612 (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
613 _ => Err(ParseError::StreamingError(
614 "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
615 )),
616 }
617 }
618 }
619 } else {
620 let (dir, lt, rt) = analyze_asof_direction(expr)?;
621 Ok((dir, lt, rt, None))
622 }
623}
624
625fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
627 match expr {
628 Expr::BinaryOp {
629 left,
630 op: BinaryOperator::GtEq,
631 right,
632 } => {
633 let left_col = extract_column_ref(left).ok_or_else(|| {
634 ParseError::StreamingError(
635 "Cannot extract left time column from MATCH_CONDITION".to_string(),
636 )
637 })?;
638 let right_col = extract_column_ref(right).ok_or_else(|| {
639 ParseError::StreamingError(
640 "Cannot extract right time column from MATCH_CONDITION".to_string(),
641 )
642 })?;
643 Ok((AsofSqlDirection::Backward, left_col, right_col))
644 }
645 Expr::BinaryOp {
646 left,
647 op: BinaryOperator::LtEq,
648 right,
649 } => {
650 let left_col = extract_column_ref(left).ok_or_else(|| {
651 ParseError::StreamingError(
652 "Cannot extract left time column from MATCH_CONDITION".to_string(),
653 )
654 })?;
655 let right_col = extract_column_ref(right).ok_or_else(|| {
656 ParseError::StreamingError(
657 "Cannot extract right time column from MATCH_CONDITION".to_string(),
658 )
659 })?;
660 Ok((AsofSqlDirection::Forward, left_col, right_col))
661 }
662 Expr::Function(func) => {
664 let name = func.name.to_string().to_uppercase();
665 if name != "NEAREST" {
666 return Err(ParseError::StreamingError(format!(
667 "Unknown ASOF MATCH_CONDITION function: {name}"
668 )));
669 }
670 let args = match &func.args {
671 FunctionArguments::List(arg_list) => &arg_list.args,
672 _ => {
673 return Err(ParseError::StreamingError(
674 "NEAREST() requires exactly 2 column arguments".to_string(),
675 ))
676 }
677 };
678 if args.len() != 2 {
679 return Err(ParseError::StreamingError(format!(
680 "NEAREST() requires exactly 2 arguments, got {}",
681 args.len()
682 )));
683 }
684 let left_col = extract_column_from_func_arg(&args[0]).ok_or_else(|| {
685 ParseError::StreamingError(
686 "Cannot extract left time column from NEAREST()".to_string(),
687 )
688 })?;
689 let right_col = extract_column_from_func_arg(&args[1]).ok_or_else(|| {
690 ParseError::StreamingError(
691 "Cannot extract right time column from NEAREST()".to_string(),
692 )
693 })?;
694 Ok((AsofSqlDirection::Nearest, left_col, right_col))
695 }
696 _ => Err(ParseError::StreamingError(
697 "ASOF MATCH_CONDITION must be >= or <= comparison, or NEAREST()".to_string(),
698 )),
699 }
700}
701
702fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
706 match expr {
707 Expr::BinaryOp {
708 left: _,
709 op: BinaryOperator::LtEq,
710 right,
711 } => {
712 match right.as_ref() {
714 Expr::Value(v) => {
715 if let sqlparser::ast::Value::Number(n, _) = &v.value {
716 let ms: u64 = n.parse().map_err(|_| {
717 ParseError::StreamingError(format!(
718 "Cannot parse tolerance as number: {n}"
719 ))
720 })?;
721 Ok(Duration::from_millis(ms))
722 } else {
723 Err(ParseError::StreamingError(
724 "ASOF tolerance must be a number or INTERVAL".to_string(),
725 ))
726 }
727 }
728 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
729 _ => Err(ParseError::StreamingError(
730 "ASOF tolerance must be a number or INTERVAL".to_string(),
731 )),
732 }
733 }
734 _ => Err(ParseError::StreamingError(
735 "ASOF tolerance expression must be <= comparison".to_string(),
736 )),
737 }
738}
739
740fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
742 match constraint {
743 JoinConstraint::On(expr) => extract_equality_columns(expr),
744 JoinConstraint::Using(cols) => {
745 if cols.is_empty() {
746 return Err(ParseError::StreamingError(
747 "USING clause requires at least one column".to_string(),
748 ));
749 }
750 let col = cols[0].to_string();
751 Ok((col.clone(), col))
752 }
753 _ => Err(ParseError::StreamingError(
754 "ASOF JOIN requires ON or USING constraint".to_string(),
755 )),
756 }
757}
758
759fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
761 match expr {
762 Expr::BinaryOp {
763 left,
764 op: BinaryOperator::Eq,
765 right,
766 } => {
767 let left_col = extract_column_ref(left).ok_or_else(|| {
768 ParseError::StreamingError("Cannot extract left key column".to_string())
769 })?;
770 let right_col = extract_column_ref(right).ok_or_else(|| {
771 ParseError::StreamingError("Cannot extract right key column".to_string())
772 })?;
773 Ok((left_col, right_col))
774 }
775 Expr::BinaryOp {
777 left,
778 op: BinaryOperator::And,
779 right,
780 } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
781 _ => Err(ParseError::StreamingError(
782 "ASOF JOIN ON clause must contain an equality condition".to_string(),
783 )),
784 }
785}
786
787#[must_use]
789pub fn has_join(select: &Select) -> bool {
790 !select.from.is_empty() && !select.from[0].joins.is_empty()
791}
792
793#[must_use]
795pub fn count_joins(select: &Select) -> usize {
796 select
797 .from
798 .iter()
799 .map(|table_with_joins| table_with_joins.joins.len())
800 .sum()
801}
802
803#[derive(Debug, Clone)]
808pub struct MultiJoinAnalysis {
809 pub joins: Vec<JoinAnalysis>,
811 pub tables: Vec<String>,
813}
814
815impl MultiJoinAnalysis {
816 #[must_use]
818 pub fn len(&self) -> usize {
819 self.joins.len()
820 }
821
822 #[must_use]
824 pub fn is_empty(&self) -> bool {
825 self.joins.is_empty()
826 }
827
828 #[must_use]
830 pub fn is_single(&self) -> bool {
831 self.joins.len() == 1
832 }
833
834 #[must_use]
836 pub fn first(&self) -> Option<&JoinAnalysis> {
837 self.joins.first()
838 }
839}
840
841pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
852 let from = &select.from;
853 if from.is_empty() {
854 return Ok(None);
855 }
856
857 let first_table = &from[0];
858 if first_table.joins.is_empty() {
859 return Ok(None);
860 }
861
862 let base_table = extract_table_name(&first_table.relation)?;
864 let base_alias = extract_table_alias(&first_table.relation);
865
866 let mut join_steps = Vec::with_capacity(first_table.joins.len());
867 let mut tables = vec![base_table.clone()];
868
869 let mut prev_left_table = base_table;
871 let mut prev_left_alias = base_alias;
872
873 for join in &first_table.joins {
874 let right_table = extract_table_name(&join.relation)?;
875 let right_alias = extract_table_alias(&join.relation);
876 tables.push(right_table.clone());
877
878 let join_type = map_join_operator(&join.join_operator);
879
880 if let JoinOperator::AsOf {
882 match_condition,
883 constraint,
884 } = &join.join_operator
885 {
886 let (direction, left_time, right_time, tolerance) =
887 analyze_asof_match_condition(match_condition)?;
888 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
889
890 let mut analysis = JoinAnalysis::asof(
891 prev_left_table.clone(),
892 right_table.clone(),
893 left_key,
894 right_key,
895 direction,
896 left_time,
897 right_time,
898 tolerance,
899 );
900 analysis.left_alias.clone_from(&prev_left_alias);
901 analysis.right_alias = right_alias;
902 join_steps.push(analysis);
903 } else if let Some(version_col) = extract_temporal_version(&join.relation) {
904 let (left_key, right_key, additional, _) =
906 analyze_join_constraint(&join.join_operator)?;
907
908 let mut analysis = JoinAnalysis::temporal(
909 prev_left_table.clone(),
910 right_table.clone(),
911 left_key,
912 right_key,
913 version_col,
914 join_type,
915 );
916 analysis.left_alias.clone_from(&prev_left_alias);
917 analysis.right_alias = right_alias;
918 analysis.additional_key_columns = additional;
919 join_steps.push(analysis);
920 } else {
921 let (left_key, right_key, additional, time_bound) =
923 analyze_join_constraint(&join.join_operator)?;
924
925 let mut analysis = if let Some(tb) = time_bound {
926 JoinAnalysis::stream_stream(
927 prev_left_table.clone(),
928 right_table.clone(),
929 left_key,
930 right_key,
931 tb,
932 join_type,
933 )
934 } else {
935 JoinAnalysis::lookup(
936 prev_left_table.clone(),
937 right_table.clone(),
938 left_key,
939 right_key,
940 join_type,
941 )
942 };
943 analysis.left_alias.clone_from(&prev_left_alias);
944 analysis.right_alias = right_alias;
945 analysis.additional_key_columns = additional;
946 join_steps.push(analysis);
947 }
948
949 prev_left_table = right_table;
951 prev_left_alias = extract_table_alias(&join.relation);
952 }
953
954 Ok(Some(MultiJoinAnalysis {
955 joins: join_steps,
956 tables,
957 }))
958}
959
960#[cfg(test)]
961mod tests {
962 use super::*;
963 use sqlparser::ast::{SetExpr, Statement};
964 use sqlparser::dialect::GenericDialect;
965 use sqlparser::parser::Parser;
966
967 fn parse_select(sql: &str) -> Select {
968 let dialect = GenericDialect {};
969 let statements = Parser::parse_sql(&dialect, sql).unwrap();
970 if let Statement::Query(query) = &statements[0] {
971 if let SetExpr::Select(select) = query.body.as_ref() {
972 return *select.clone();
973 }
974 }
975 panic!("Expected SELECT query");
976 }
977
978 #[test]
979 fn test_analyze_inner_join() {
980 let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
981 let select = parse_select(sql);
982
983 let analysis = analyze_join(&select).unwrap().unwrap();
984
985 assert_eq!(analysis.join_type, JoinType::Inner);
986 assert_eq!(analysis.left_table, "orders");
987 assert_eq!(analysis.right_table, "payments");
988 assert_eq!(analysis.left_key_column, "order_id");
989 assert_eq!(analysis.right_key_column, "order_id");
990 assert!(analysis.is_lookup_join); }
992
993 #[test]
994 fn test_analyze_left_join() {
995 let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
996 let select = parse_select(sql);
997
998 let analysis = analyze_join(&select).unwrap().unwrap();
999
1000 assert_eq!(analysis.join_type, JoinType::Left);
1001 assert_eq!(analysis.left_key_column, "customer_id");
1002 assert_eq!(analysis.right_key_column, "id");
1003 }
1004
1005 #[test]
1006 fn test_analyze_join_using() {
1007 let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
1008 let select = parse_select(sql);
1009
1010 let analysis = analyze_join(&select).unwrap().unwrap();
1011
1012 assert_eq!(analysis.left_key_column, "order_id");
1013 assert_eq!(analysis.right_key_column, "order_id");
1014 }
1015
1016 #[test]
1017 fn test_analyze_stream_stream_join_with_time_bound() {
1018 let sql = "SELECT * FROM orders o
1019 JOIN payments p ON o.order_id = p.order_id
1020 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
1021 let select = parse_select(sql);
1022
1023 let analysis = analyze_join(&select).unwrap().unwrap();
1024
1025 assert!(!analysis.is_lookup_join);
1026 assert!(analysis.time_bound.is_some());
1027 assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
1028 }
1029
1030 #[test]
1031 fn test_no_join() {
1032 let sql = "SELECT * FROM orders";
1033 let select = parse_select(sql);
1034
1035 let analysis = analyze_join(&select).unwrap();
1036 assert!(analysis.is_none());
1037 }
1038
1039 #[test]
1040 fn test_has_join() {
1041 let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1042 let sql_without_join = "SELECT * FROM orders";
1043
1044 let select_with = parse_select(sql_with_join);
1045 let select_without = parse_select(sql_without_join);
1046
1047 assert!(has_join(&select_with));
1048 assert!(!has_join(&select_without));
1049 }
1050
1051 #[test]
1052 fn test_count_joins() {
1053 let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
1054 let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
1055 let sql_zero = "SELECT * FROM a";
1056
1057 assert_eq!(count_joins(&parse_select(sql_one)), 1);
1058 assert_eq!(count_joins(&parse_select(sql_two)), 2);
1059 assert_eq!(count_joins(&parse_select(sql_zero)), 0);
1060 }
1061
1062 #[test]
1063 fn test_aliases() {
1064 let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
1065 let select = parse_select(sql);
1066
1067 let analysis = analyze_join(&select).unwrap().unwrap();
1068
1069 assert_eq!(analysis.left_alias, Some("o".to_string()));
1070 assert_eq!(analysis.right_alias, Some("p".to_string()));
1071 }
1072
1073 fn parse_select_snowflake(sql: &str) -> Select {
1076 let dialect = sqlparser::dialect::SnowflakeDialect {};
1077 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1078 if let Statement::Query(query) = &statements[0] {
1079 if let SetExpr::Select(select) = query.body.as_ref() {
1080 return *select.clone();
1081 }
1082 }
1083 panic!("Expected SELECT query");
1084 }
1085
1086 fn parse_select_laminar(sql: &str) -> Select {
1087 let dialect = crate::parser::dialect::LaminarDialect::default();
1088 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1089 if let Statement::Query(query) = &statements[0] {
1090 if let SetExpr::Select(select) = query.body.as_ref() {
1091 return *select.clone();
1092 }
1093 }
1094 panic!("Expected SELECT query");
1095 }
1096
1097 #[test]
1098 fn test_asof_join_backward() {
1099 let sql = "SELECT * FROM trades t \
1100 ASOF JOIN quotes q \
1101 MATCH_CONDITION(t.ts >= q.ts) \
1102 ON t.symbol = q.symbol";
1103 let select = parse_select_snowflake(sql);
1104 let analysis = analyze_join(&select).unwrap().unwrap();
1105
1106 assert!(analysis.is_asof_join);
1107 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1108 assert_eq!(analysis.join_type, JoinType::AsOf);
1109 assert!(analysis.asof_tolerance.is_none());
1110 }
1111
1112 #[test]
1113 fn test_asof_join_forward() {
1114 let sql = "SELECT * FROM trades t \
1115 ASOF JOIN quotes q \
1116 MATCH_CONDITION(t.ts <= q.ts) \
1117 ON t.symbol = q.symbol";
1118 let select = parse_select_snowflake(sql);
1119 let analysis = analyze_join(&select).unwrap().unwrap();
1120
1121 assert!(analysis.is_asof_join);
1122 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
1123 }
1124
1125 #[test]
1126 fn test_asof_join_nearest() {
1127 let sql = "SELECT * FROM trades t \
1128 ASOF JOIN quotes q \
1129 MATCH_CONDITION(NEAREST(t.ts, q.ts)) \
1130 ON t.symbol = q.symbol";
1131 let select = parse_select_snowflake(sql);
1132 let analysis = analyze_join(&select).unwrap().unwrap();
1133
1134 assert!(analysis.is_asof_join);
1135 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Nearest));
1136 assert_eq!(analysis.join_type, JoinType::AsOf);
1137 assert!(analysis.asof_tolerance.is_none());
1138 }
1139
1140 #[test]
1141 fn test_asof_join_with_tolerance() {
1142 let sql = "SELECT * FROM trades t \
1143 ASOF JOIN quotes q \
1144 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
1145 ON t.symbol = q.symbol";
1146 let select = parse_select_snowflake(sql);
1147 let analysis = analyze_join(&select).unwrap().unwrap();
1148
1149 assert!(analysis.is_asof_join);
1150 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1151 assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
1152 }
1153
1154 #[test]
1155 fn test_asof_join_with_interval_tolerance() {
1156 let sql = "SELECT * FROM trades t \
1157 ASOF JOIN quotes q \
1158 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
1159 ON t.symbol = q.symbol";
1160 let select = parse_select_snowflake(sql);
1161 let analysis = analyze_join(&select).unwrap().unwrap();
1162
1163 assert!(analysis.is_asof_join);
1164 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1165 assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
1166 }
1167
1168 #[test]
1169 fn test_asof_join_type_mapping() {
1170 let sql = "SELECT * FROM trades t \
1171 ASOF JOIN quotes q \
1172 MATCH_CONDITION(t.ts >= q.ts) \
1173 ON t.symbol = q.symbol";
1174 let select = parse_select_snowflake(sql);
1175 let analysis = analyze_join(&select).unwrap().unwrap();
1176
1177 assert_eq!(analysis.join_type, JoinType::AsOf);
1178 assert!(!analysis.is_lookup_join);
1179 }
1180
1181 #[test]
1182 fn test_asof_join_extracts_time_columns() {
1183 let sql = "SELECT * FROM trades t \
1184 ASOF JOIN quotes q \
1185 MATCH_CONDITION(t.ts >= q.ts) \
1186 ON t.symbol = q.symbol";
1187 let select = parse_select_snowflake(sql);
1188 let analysis = analyze_join(&select).unwrap().unwrap();
1189
1190 assert_eq!(analysis.left_time_column, Some("ts".to_string()));
1191 assert_eq!(analysis.right_time_column, Some("ts".to_string()));
1192 }
1193
1194 #[test]
1195 fn test_asof_join_extracts_key_columns() {
1196 let sql = "SELECT * FROM trades t \
1197 ASOF JOIN quotes q \
1198 MATCH_CONDITION(t.ts >= q.ts) \
1199 ON t.symbol = q.symbol";
1200 let select = parse_select_snowflake(sql);
1201 let analysis = analyze_join(&select).unwrap().unwrap();
1202
1203 assert_eq!(analysis.left_key_column, "symbol");
1204 assert_eq!(analysis.right_key_column, "symbol");
1205 }
1206
/// With explicit `AS` aliases, the analysis keeps the real table names and
/// records each alias separately.
#[test]
fn test_asof_join_aliases() {
    let query = "SELECT * FROM trades AS t \
                 ASOF JOIN quotes AS q \
                 MATCH_CONDITION(t.ts >= q.ts) \
                 ON t.symbol = q.symbol";
    let parsed = parse_select_snowflake(query);
    let join = analyze_join(&parsed)
        .expect("ASOF join should analyze without error")
        .expect("query contains a join");

    assert_eq!(join.left_alias.as_deref(), Some("t"));
    assert_eq!(join.right_alias.as_deref(), Some("q"));
    assert_eq!(
        (join.left_table.as_str(), join.right_table.as_str()),
        ("trades", "quotes")
    );
}
1221
/// A query with exactly one join still produces a valid multi-join analysis
/// containing a single entry.
#[test]
fn test_multi_join_single_backward_compat() {
    let query = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
    let parsed = parse_select(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");

    assert!(!chain.is_empty());
    assert_eq!(chain.len(), 1);
    assert!(chain.is_single());

    let only = chain.first().expect("one join should be present");
    assert_eq!(only.left_table, "orders");
    assert_eq!(only.right_table, "payments");
}
1237
/// Two chained joins are analyzed left-to-right. Each step records its own
/// table pair and key columns.
#[test]
fn test_multi_join_two_way() {
    let query = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
    let parsed = parse_select(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains joins");

    assert!(!chain.is_single());
    assert_eq!(chain.len(), 2);

    // (left table, right table, left key, right key) for each join, in order.
    let expected = [("a", "b", "id", "a_id"), ("b", "c", "id", "b_id")];
    for (join, &(left, right, left_key, right_key)) in chain.joins.iter().zip(expected.iter()) {
        assert_eq!(join.left_table, left);
        assert_eq!(join.right_table, right);
        assert_eq!(join.left_key_column, left_key);
        assert_eq!(join.right_key_column, right_key);
    }
}
1257
/// Three chained joins over four tables produce three join entries. All four
/// table names are listed in order.
#[test]
fn test_multi_join_three_way() {
    let query = "SELECT * FROM a \
                 JOIN b ON a.id = b.a_id \
                 JOIN c ON b.id = c.b_id \
                 JOIN d ON c.id = d.c_id";
    let parsed = parse_select(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains joins");

    assert_eq!(chain.len(), 3);
    assert_eq!(chain.tables.len(), 4);
    assert_eq!(chain.tables, ["a", "b", "c", "d"]);
}
1271
/// An ASOF join followed by a plain equi-join yields an ASOF entry, then a
/// lookup entry.
#[test]
fn test_multi_join_mixed_asof_and_lookup() {
    let query = "SELECT * FROM trades t \
                 ASOF JOIN quotes q \
                 MATCH_CONDITION(t.ts >= q.ts) \
                 ON t.symbol = q.symbol \
                 JOIN products p ON q.product_id = p.id";
    let parsed = parse_select_snowflake(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains joins");

    assert_eq!(chain.len(), 2);
    let (asof, lookup) = (&chain.joins[0], &chain.joins[1]);
    assert!(asof.is_asof_join);
    assert!(lookup.is_lookup_join);
}
1287
/// A time-bounded stream-stream join followed by an equi-join on a third table.
///
/// The first join (orders/payments) keeps the time bound derived from its
/// `BETWEEN` predicate. The second join (with customers) is expected to be a
/// lookup join.
#[test]
fn test_multi_join_stream_stream_and_lookup() {
    let sql = "SELECT * FROM orders o \
               JOIN payments p ON o.id = p.order_id \
               AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
               JOIN customers c ON o.customer_id = c.id";
    let select = parse_select(sql);
    let multi = analyze_joins(&select).unwrap().unwrap();

    assert_eq!(multi.len(), 2);

    // orders/payments: stream-stream join bounded by the BETWEEN interval.
    assert!(
        !multi.joins[0].is_lookup_join,
        "orders/payments should be a stream-stream join, not a lookup"
    );
    assert!(
        multi.joins[0].time_bound.is_some(),
        "orders/payments should carry a time bound from the BETWEEN predicate"
    );

    // customers: plain equi-join without a time predicate.
    assert!(
        multi.joins[1].is_lookup_join,
        "join with customers should be a lookup join"
    );
}
1302
/// `tables` lists the participating tables in the order they appear in the query.
#[test]
fn test_multi_join_tables_list() {
    let query = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
    let parsed = parse_select(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains joins");

    let expected = vec!["a", "b", "c"];
    assert_eq!(chain.tables, expected);
}
1311
/// Aliases are tracked per join step. The right alias of one join reappears
/// as the left alias of the next.
#[test]
fn test_multi_join_aliases() {
    let query = "SELECT * FROM orders AS o \
                 JOIN payments AS p ON o.id = p.order_id \
                 JOIN refunds AS r ON p.id = r.payment_id";
    let parsed = parse_select(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains joins");

    let aliases: Vec<(Option<&str>, Option<&str>)> = chain
        .joins
        .iter()
        .map(|join| (join.left_alias.as_deref(), join.right_alias.as_deref()))
        .collect();
    assert_eq!(aliases[0], (Some("o"), Some("p")));
    assert_eq!(aliases[1], (Some("p"), Some("r")));
}
1325
/// A query without any JOIN succeeds but yields no multi-join analysis.
#[test]
fn test_multi_join_no_join_returns_none() {
    let parsed = parse_select("SELECT * FROM orders");
    let outcome = analyze_joins(&parsed).expect("analysis of a join-free query should succeed");
    assert!(matches!(outcome, None));
}
1333
/// `FOR SYSTEM_TIME AS OF <col>` on the right table marks the join as temporal
/// and records the version column.
#[test]
fn test_temporal_join_detected() {
    let query = "SELECT o.*, p.price \
                 FROM orders o \
                 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                 ON o.product_id = p.id";
    let parsed = parse_select_laminar(query);
    let join = analyze_join(&parsed)
        .expect("temporal join should analyze without error")
        .expect("query contains a join");

    // Temporal flag and the column named in `AS OF o.order_time`.
    assert!(join.is_temporal_join);
    assert_eq!(join.temporal_version_column.as_deref(), Some("order_time"));

    // Underlying relation names (not aliases) and the equi-join keys.
    assert_eq!(join.left_table, "orders");
    assert_eq!(join.right_table, "products");
    assert_eq!(join.left_key_column, "product_id");
    assert_eq!(join.right_key_column, "id");

    // A temporal join is classified as neither a lookup nor an ASOF join.
    assert!(!join.is_lookup_join);
    assert!(!join.is_asof_join);
}
1357
/// The multi-join entry point reports the same temporal-join details as the
/// single-join analyzer.
#[test]
fn test_temporal_join_via_analyze_joins() {
    let query = "SELECT o.*, p.price \
                 FROM orders o \
                 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                 ON o.product_id = p.id";
    let parsed = parse_select_laminar(query);
    let chain = analyze_joins(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");

    assert_eq!(chain.len(), 1);
    let temporal = chain.first().expect("one join should be present");
    assert!(temporal.is_temporal_join);
    assert_eq!(temporal.temporal_version_column.as_deref(), Some("order_time"));
}
1375
/// A plain equi-join without `FOR SYSTEM_TIME AS OF` is not flagged as temporal.
#[test]
fn test_non_temporal_join_not_flagged() {
    let query = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
    let parsed = parse_select(query);
    let join = analyze_join(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");

    assert!(join.temporal_version_column.is_none());
    assert!(!join.is_temporal_join);
}
1385
/// A bare `ANTI JOIN` (no LEFT/RIGHT qualifier) defaults to `JoinType::LeftAnti`.
#[test]
fn test_unqualified_anti_maps_to_left_anti() {
    let parsed = parse_select("SELECT * FROM orders o ANTI JOIN returns r ON o.id = r.order_id");
    let join = analyze_join(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");
    assert_eq!(join.join_type, JoinType::LeftAnti);
}
1393
/// A bare `SEMI JOIN` (no LEFT/RIGHT qualifier) defaults to `JoinType::LeftSemi`.
#[test]
fn test_unqualified_semi_maps_to_left_semi() {
    let parsed = parse_select("SELECT * FROM orders o SEMI JOIN payments p ON o.id = p.order_id");
    let join = analyze_join(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");
    assert_eq!(join.join_type, JoinType::LeftSemi);
}
1401
/// An `ON a = b AND c = d` condition is split into a primary key pair and
/// additional key pairs.
#[test]
fn test_composite_join_keys() {
    let query = "SELECT * FROM orders o \
                 JOIN shipments s \
                 ON o.order_id = s.order_id AND o.region = s.region";
    let parsed = parse_select(query);
    let join = analyze_join(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");

    // The first equality becomes the primary key pair.
    assert_eq!(join.left_key_column, "order_id");
    assert_eq!(join.right_key_column, "order_id");

    // The second equality is recorded as the single additional key pair.
    assert_eq!(
        join.additional_key_columns.len(),
        1,
        "Should have 1 additional key pair"
    );
    let (extra_left, extra_right) = &join.additional_key_columns[0];
    assert_eq!(extra_left, "region");
    assert_eq!(extra_right, "region");
}
1423
/// `USING (c1, c2)` maps its first column to the primary key pair and the rest
/// to additional pairs, with the same name on both sides.
#[test]
fn test_composite_using_clause() {
    let query = "SELECT * FROM orders o JOIN shipments s USING (order_id, region)";
    let parsed = parse_select(query);
    let join = analyze_join(&parsed)
        .expect("join analysis should succeed")
        .expect("query contains a join");

    // First USING column -> primary key pair.
    assert_eq!(join.left_key_column, "order_id");
    assert_eq!(join.right_key_column, "order_id");

    // Second USING column -> the single additional key pair.
    assert_eq!(
        join.additional_key_columns.len(),
        1,
        "USING(order_id, region) should have 1 additional key"
    );
    let (extra_left, extra_right) = &join.additional_key_columns[0];
    assert_eq!(extra_left, "region");
    assert_eq!(extra_right, "region");
}
1443}