1use std::time::Duration;
10
11use sqlparser::ast::{
12 BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
13 JoinOperator, Select, TableFactor, TableVersion,
14};
15
16use super::window_rewriter::WindowRewriter;
17use super::ParseError;
18
/// Join type recognized by the streaming join analyzer.
///
/// Produced from the parsed SQL join operator by `map_join_operator`;
/// operators without a dedicated variant (e.g. `CROSS JOIN`) map to
/// [`JoinType::Inner`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `[INNER] JOIN` (also `STRAIGHT_JOIN`).
    Inner,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
    /// Full outer join (sqlparser `JoinOperator::FullOuter`).
    Full,
    /// `LEFT SEMI JOIN`, or a plain `SEMI JOIN`.
    LeftSemi,
    /// `LEFT ANTI JOIN`, or a plain `ANTI JOIN`.
    LeftAnti,
    /// `RIGHT SEMI JOIN`.
    RightSemi,
    /// `RIGHT ANTI JOIN`.
    RightAnti,
    /// `ASOF JOIN ... MATCH_CONDITION(...)`.
    AsOf,
}
41
/// Match direction of an ASOF join, derived from its `MATCH_CONDITION`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// Written as `MATCH_CONDITION(left_col >= right_col)`.
    Backward,
    /// Written as `MATCH_CONDITION(left_col <= right_col)`.
    Forward,
    /// Written as `MATCH_CONDITION(NEAREST(left_col, right_col))`.
    Nearest,
}
52
53impl std::fmt::Display for AsofSqlDirection {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 match self {
56 AsofSqlDirection::Backward => write!(f, "BACKWARD"),
57 AsofSqlDirection::Forward => write!(f, "FORWARD"),
58 AsofSqlDirection::Nearest => write!(f, "NEAREST"),
59 }
60 }
61}
62
/// Time-column references captured from a `x BETWEEN low AND high` join
/// condition, before they are assigned to the left/right join inputs.
///
/// Qualifiers are kept (when a reference was written as `qualifier.column`)
/// so that `resolve_time_cols` can match them against table names and aliases.
#[derive(Debug, Clone)]
struct RawTimeCols {
    /// Qualifier of the `BETWEEN` subject `x`, when it was written as `q.col`.
    expr_qualifier: Option<String>,
    /// Column name of the `BETWEEN` subject `x`.
    expr_col: String,
    /// Qualifier of the lower bound `low`, when it was written as `q.col`.
    low_qualifier: Option<String>,
    /// Column name of the lower bound `low`.
    low_col: String,
}
71
72fn resolve_time_cols(
75 raw: &RawTimeCols,
76 left_table: &str,
77 right_table: &str,
78 left_alias: Option<&str>,
79 right_alias: Option<&str>,
80) -> (String, String) {
81 let matches_left = |q: &Option<String>| -> bool {
82 q.as_ref()
83 .is_some_and(|t| t == left_table || left_alias.is_some_and(|a| a == t))
84 };
85 let matches_right = |q: &Option<String>| -> bool {
86 q.as_ref()
87 .is_some_and(|t| t == right_table || right_alias.is_some_and(|a| a == t))
88 };
89
90 if matches_right(&raw.expr_qualifier) && matches_left(&raw.low_qualifier) {
91 (raw.low_col.clone(), raw.expr_col.clone())
92 } else if matches_left(&raw.expr_qualifier) && matches_right(&raw.low_qualifier) {
93 (raw.expr_col.clone(), raw.low_col.clone())
94 } else {
95 (raw.low_col.clone(), raw.expr_col.clone())
96 }
97}
98
/// Result of analyzing a single JOIN step of a `SELECT`.
///
/// Built by one of the constructors in `impl JoinAnalysis` (`stream_stream`,
/// `lookup`, `asof`, `temporal`). The analysis functions then fill in aliases,
/// extra keys and time columns.
#[derive(Debug, Clone)]
pub struct JoinAnalysis {
    /// Join type mapped from the SQL operator; always `JoinType::AsOf` for
    /// ASOF joins.
    pub join_type: JoinType,
    /// Left input of this step.
    ///
    /// This is the table name, or the alias of a derived table. In
    /// `analyze_joins` it is the previous step's right table.
    pub left_table: String,
    /// Right input of this step.
    ///
    /// This is the table name, or the alias of a derived table.
    pub right_table: String,
    /// First equi-join key for the left side, with its qualifier dropped.
    ///
    /// It comes from the left operand of the first `=` in `ON` (not resolved by
    /// qualifier), or from the first `USING` column. It is an empty string when
    /// a non-ASOF `ON` clause contains no usable equality.
    pub left_key_column: String,
    /// First equi-join key for the right side.
    ///
    /// It is taken from the right operand of the same equality as
    /// `left_key_column`.
    pub right_key_column: String,
    /// Time bound of a stream-stream join; set only by `stream_stream`.
    pub time_bound: Option<Duration>,
    /// True when built via `lookup` (no time bound was found).
    pub is_lookup_join: bool,
    /// Alias of the left input, if one was given.
    pub left_alias: Option<String>,
    /// Alias of the right input, if one was given.
    pub right_alias: Option<String>,
    /// True when built via `asof`.
    pub is_asof_join: bool,
    /// ASOF match direction; `Some` only for ASOF joins.
    pub asof_direction: Option<AsofSqlDirection>,
    /// Left time column.
    ///
    /// For ASOF joins it is the left operand of `MATCH_CONDITION`. For other
    /// joins it is resolved from a `BETWEEN` condition whose operands are
    /// column references.
    pub left_time_column: Option<String>,
    /// Right time column; counterpart of `left_time_column`.
    pub right_time_column: Option<String>,
    /// Optional ASOF tolerance taken from `MATCH_CONDITION`.
    pub asof_tolerance: Option<Duration>,
    /// True when built via `temporal` (`FOR SYSTEM_TIME AS OF`).
    pub is_temporal_join: bool,
    /// Column named in `FOR SYSTEM_TIME AS OF`; `Some` only for temporal joins.
    pub temporal_version_column: Option<String>,
    /// Equi-join key pairs `(left, right)` beyond the first one.
    pub additional_key_columns: Vec<(String, String)>,
}
137
138impl JoinAnalysis {
139 #[must_use]
141 pub fn stream_stream(
142 left_table: String,
143 right_table: String,
144 left_key: String,
145 right_key: String,
146 time_bound: Duration,
147 join_type: JoinType,
148 ) -> Self {
149 Self {
150 join_type,
151 left_table,
152 right_table,
153 left_key_column: left_key,
154 right_key_column: right_key,
155 time_bound: Some(time_bound),
156 is_lookup_join: false,
157 left_alias: None,
158 right_alias: None,
159 is_asof_join: false,
160 asof_direction: None,
161 left_time_column: None,
162 right_time_column: None,
163 asof_tolerance: None,
164 is_temporal_join: false,
165 temporal_version_column: None,
166 additional_key_columns: vec![],
167 }
168 }
169
170 #[must_use]
172 pub fn lookup(
173 left_table: String,
174 right_table: String,
175 left_key: String,
176 right_key: String,
177 join_type: JoinType,
178 ) -> Self {
179 Self {
180 join_type,
181 left_table,
182 right_table,
183 left_key_column: left_key,
184 right_key_column: right_key,
185 time_bound: None,
186 is_lookup_join: true,
187 left_alias: None,
188 right_alias: None,
189 is_asof_join: false,
190 asof_direction: None,
191 left_time_column: None,
192 right_time_column: None,
193 asof_tolerance: None,
194 is_temporal_join: false,
195 temporal_version_column: None,
196 additional_key_columns: vec![],
197 }
198 }
199
200 #[must_use]
202 #[allow(clippy::too_many_arguments)]
203 pub fn asof(
204 left_table: String,
205 right_table: String,
206 left_key: String,
207 right_key: String,
208 direction: AsofSqlDirection,
209 left_time_col: String,
210 right_time_col: String,
211 tolerance: Option<Duration>,
212 ) -> Self {
213 Self {
214 join_type: JoinType::AsOf,
215 left_table,
216 right_table,
217 left_key_column: left_key,
218 right_key_column: right_key,
219 time_bound: None,
220 is_lookup_join: false,
221 left_alias: None,
222 right_alias: None,
223 is_asof_join: true,
224 asof_direction: Some(direction),
225 left_time_column: Some(left_time_col),
226 right_time_column: Some(right_time_col),
227 asof_tolerance: tolerance,
228 is_temporal_join: false,
229 temporal_version_column: None,
230 additional_key_columns: vec![],
231 }
232 }
233
234 #[must_use]
236 pub fn temporal(
237 left_table: String,
238 right_table: String,
239 left_key: String,
240 right_key: String,
241 version_column: String,
242 join_type: JoinType,
243 ) -> Self {
244 Self {
245 join_type,
246 left_table,
247 right_table,
248 left_key_column: left_key,
249 right_key_column: right_key,
250 time_bound: None,
251 is_lookup_join: false,
252 left_alias: None,
253 right_alias: None,
254 is_asof_join: false,
255 asof_direction: None,
256 left_time_column: None,
257 right_time_column: None,
258 asof_tolerance: None,
259 is_temporal_join: true,
260 temporal_version_column: Some(version_column),
261 additional_key_columns: vec![],
262 }
263 }
264}
265
266pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
274 let from = &select.from;
275 if from.is_empty() {
276 return Ok(None);
277 }
278
279 let first_table = &from[0];
280 if first_table.joins.is_empty() {
281 return Ok(None);
282 }
283
284 let left_table = extract_table_name(&first_table.relation)?;
286 let left_alias = extract_table_alias(&first_table.relation);
287
288 let join = &first_table.joins[0];
290 let right_table = extract_table_name(&join.relation)?;
291 let right_alias = extract_table_alias(&join.relation);
292
293 let join_type = map_join_operator(&join.join_operator);
294
295 if let JoinOperator::AsOf {
297 match_condition,
298 constraint,
299 } = &join.join_operator
300 {
301 let (direction, left_time, right_time, tolerance) =
302 analyze_asof_match_condition(match_condition)?;
303
304 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
306
307 let mut analysis = JoinAnalysis::asof(
308 left_table,
309 right_table,
310 left_key,
311 right_key,
312 direction,
313 left_time,
314 right_time,
315 tolerance,
316 );
317 analysis.left_alias = left_alias;
318 analysis.right_alias = right_alias;
319 return Ok(Some(analysis));
320 }
321
322 if let Some(version_col) = extract_temporal_version(&join.relation) {
324 let (left_key, right_key, additional, _, _) = analyze_join_constraint(&join.join_operator)?;
325 let mut analysis = JoinAnalysis::temporal(
326 left_table,
327 right_table,
328 left_key,
329 right_key,
330 version_col,
331 join_type,
332 );
333 analysis.left_alias = left_alias;
334 analysis.right_alias = right_alias;
335 analysis.additional_key_columns = additional;
336 return Ok(Some(analysis));
337 }
338
339 let (left_key, right_key, additional, time_bound, time_cols) =
341 analyze_join_constraint(&join.join_operator)?;
342
343 let mut analysis = if let Some(tb) = time_bound {
344 JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
345 } else {
346 JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
347 };
348
349 analysis.left_alias.clone_from(&left_alias);
350 analysis.right_alias.clone_from(&right_alias);
351 analysis.additional_key_columns = additional;
352
353 if let Some(ref raw) = time_cols {
354 let (lt, rt) = resolve_time_cols(
355 raw,
356 &analysis.left_table,
357 &analysis.right_table,
358 left_alias.as_deref(),
359 right_alias.as_deref(),
360 );
361 analysis.left_time_column = Some(lt);
362 analysis.right_time_column = Some(rt);
363 }
364
365 Ok(Some(analysis))
366}
367
368fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
370 match factor {
371 TableFactor::Table { name, .. } => Ok(name.to_string()),
372 TableFactor::Derived { alias, .. } => {
373 if let Some(alias) = alias {
374 Ok(alias.name.value.clone())
375 } else {
376 Err(ParseError::StreamingError(
377 "Derived table without alias not supported".to_string(),
378 ))
379 }
380 }
381 _ => Err(ParseError::StreamingError(
382 "Unsupported table factor type".to_string(),
383 )),
384 }
385}
386
387fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
392 if let TableFactor::Table {
393 version: Some(TableVersion::ForSystemTimeAsOf(expr)),
394 ..
395 } = factor
396 {
397 Some(extract_column_name_from_expr(expr))
398 } else {
399 None
400 }
401}
402
403fn extract_column_name_from_expr(expr: &Expr) -> String {
407 match expr {
408 Expr::Identifier(ident) => ident.value.clone(),
409 Expr::CompoundIdentifier(parts) => parts
410 .last()
411 .map_or_else(|| expr.to_string(), |p| p.value.clone()),
412 _ => expr.to_string(),
413 }
414}
415
416fn extract_table_alias(factor: &TableFactor) -> Option<String> {
418 match factor {
419 TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
420 TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
421 _ => None,
422 }
423}
424
425fn map_join_operator(op: &JoinOperator) -> JoinType {
427 match op {
428 JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
429 JoinType::Inner
430 }
431 JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
432 JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
433 JoinOperator::LeftAnti(_) | JoinOperator::Anti(_) => JoinType::LeftAnti,
434 JoinOperator::AsOf { .. } => JoinType::AsOf,
435 JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
436 JoinOperator::RightSemi(_) => JoinType::RightSemi,
437 JoinOperator::RightAnti(_) => JoinType::RightAnti,
438 JoinOperator::FullOuter(_) => JoinType::Full,
439 _ => JoinType::Inner,
441 }
442}
443
444#[allow(clippy::type_complexity)]
447fn analyze_join_constraint(
448 op: &JoinOperator,
449) -> Result<
450 (
451 String,
452 String,
453 Vec<(String, String)>,
454 Option<Duration>,
455 Option<RawTimeCols>,
456 ),
457 ParseError,
458> {
459 let constraint = get_join_constraint(op)?;
460
461 match constraint {
462 JoinConstraint::On(expr) => {
463 let (key_pairs, time_bound, time_cols) = analyze_on_expression(expr)?;
464 if key_pairs.is_empty() {
465 return Ok((String::new(), String::new(), vec![], time_bound, time_cols));
466 }
467 let (first_left, first_right) = key_pairs[0].clone();
468 let additional = key_pairs[1..].to_vec();
469 Ok((first_left, first_right, additional, time_bound, time_cols))
470 }
471 JoinConstraint::Using(cols) => {
472 if cols.is_empty() {
473 return Err(ParseError::StreamingError(
474 "USING clause requires at least one column".to_string(),
475 ));
476 }
477 let first_col = cols[0].to_string();
479 let additional: Vec<(String, String)> = cols[1..]
481 .iter()
482 .map(|c| {
483 let col = c.to_string();
484 (col.clone(), col)
485 })
486 .collect();
487 Ok((first_col.clone(), first_col, additional, None, None))
488 }
489 JoinConstraint::Natural => Err(ParseError::StreamingError(
490 "NATURAL JOIN not supported for streaming".to_string(),
491 )),
492 JoinConstraint::None => Err(ParseError::StreamingError(
493 "JOIN without condition not supported for streaming".to_string(),
494 )),
495 }
496}
497
498fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
500 match op {
501 JoinOperator::Inner(constraint)
502 | JoinOperator::Join(constraint)
503 | JoinOperator::Left(constraint)
504 | JoinOperator::LeftOuter(constraint)
505 | JoinOperator::Right(constraint)
506 | JoinOperator::RightOuter(constraint)
507 | JoinOperator::FullOuter(constraint)
508 | JoinOperator::LeftSemi(constraint)
509 | JoinOperator::RightSemi(constraint)
510 | JoinOperator::LeftAnti(constraint)
511 | JoinOperator::RightAnti(constraint)
512 | JoinOperator::Semi(constraint)
513 | JoinOperator::Anti(constraint)
514 | JoinOperator::StraightJoin(constraint)
515 | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
516 JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
517 ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
518 ),
519 }
520}
521
522#[allow(clippy::type_complexity)]
525fn analyze_on_expression(
526 expr: &Expr,
527) -> Result<(Vec<(String, String)>, Option<Duration>, Option<RawTimeCols>), ParseError> {
528 match expr {
530 Expr::BinaryOp {
531 left,
532 op: BinaryOperator::And,
533 right,
534 } => {
535 let left_result = analyze_on_expression(left);
537 let right_result = analyze_on_expression(right);
538
539 match (left_result, right_result) {
541 (Ok((mut lk, lt, ltc)), Ok((rk, rt, rtc))) => {
542 lk.extend(rk);
543 Ok((lk, lt.or(rt), ltc.or(rtc)))
544 }
545 (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
546 (Err(e), Err(_)) => Err(e),
547 }
548 }
549 Expr::BinaryOp {
551 left,
552 op: BinaryOperator::Eq,
553 right,
554 } => {
555 let left_col = extract_column_ref(left);
556 let right_col = extract_column_ref(right);
557
558 match (left_col, right_col) {
559 (Some(l), Some(r)) => Ok((vec![(l, r)], None, None)),
560 _ => Err(ParseError::StreamingError(
561 "Cannot extract column references from equality condition".to_string(),
562 )),
563 }
564 }
565 Expr::Between {
567 expr: between_expr,
568 low,
569 high,
570 ..
571 } => {
572 let time_bound = extract_time_bound_from_expr(high).ok();
574 let between_col = extract_qualified_column_ref(between_expr);
575 let low_col = extract_qualified_column_ref(low);
576 let time_cols = if let (Some((bt, bc)), Some((lt, lc))) = (between_col, low_col) {
577 Some(RawTimeCols {
578 expr_qualifier: bt,
579 expr_col: bc,
580 low_qualifier: lt,
581 low_col: lc,
582 })
583 } else {
584 if time_bound.is_some() {
585 tracing::warn!(
586 "BETWEEN clause has time bound but time column references \
587 could not be extracted (expressions must be simple column refs)"
588 );
589 }
590 None
591 };
592 Ok((vec![], time_bound, time_cols))
593 }
594 Expr::BinaryOp {
596 left: _,
597 op:
598 BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
599 right,
600 } => {
601 let time_bound = extract_time_bound_from_expr(right).ok();
603 Ok((vec![], time_bound, None))
604 }
605 _ => Err(ParseError::StreamingError(format!(
606 "Unsupported join condition expression: {expr:?}"
607 ))),
608 }
609}
610
611fn extract_column_from_func_arg(arg: &FunctionArg) -> Option<String> {
613 let (FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
614 | FunctionArg::Named {
615 arg: FunctionArgExpr::Expr(expr),
616 ..
617 }
618 | FunctionArg::ExprNamed {
619 arg: FunctionArgExpr::Expr(expr),
620 ..
621 }) = arg
622 else {
623 return None;
624 };
625 extract_column_ref(expr)
626}
627
628fn extract_column_ref(expr: &Expr) -> Option<String> {
630 match expr {
631 Expr::Identifier(ident) => Some(ident.value.clone()),
632 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
633 _ => None,
634 }
635}
636
637fn extract_qualified_column_ref(expr: &Expr) -> Option<(Option<String>, String)> {
638 match expr {
639 Expr::Identifier(ident) => Some((None, ident.value.clone())),
640 Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
641 Some((Some(parts[0].value.clone()), parts[1].value.clone()))
642 }
643 Expr::CompoundIdentifier(parts) => parts.last().map(|p| (None, p.value.clone())),
644 _ => None,
645 }
646}
647
648fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
650 match expr {
651 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
653 Expr::BinaryOp {
655 left: _,
656 op: BinaryOperator::Plus | BinaryOperator::Minus,
657 right,
658 } => extract_time_bound_from_expr(right),
659 Expr::Nested(inner) => extract_time_bound_from_expr(inner),
661 _ => Err(ParseError::StreamingError(format!(
662 "Cannot extract time bound from: {expr:?}"
663 ))),
664 }
665}
666
667fn analyze_asof_match_condition(
671 expr: &Expr,
672) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
673 if let Expr::BinaryOp {
674 left,
675 op: BinaryOperator::And,
676 right,
677 } = expr
678 {
679 let dir_result = analyze_asof_direction(left);
681 let tol_result = extract_asof_tolerance(right);
682
683 match (dir_result, tol_result) {
684 (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
685 (Ok((dir, lt, rt)), Err(_)) => {
686 let dir2 = analyze_asof_direction(right);
688 let tol2 = extract_asof_tolerance(left);
689 match (dir2, tol2) {
690 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
691 _ => Ok((dir, lt, rt, None)),
692 }
693 }
694 (Err(_), _) => {
695 let dir2 = analyze_asof_direction(right);
697 let tol2 = extract_asof_tolerance(left);
698 match (dir2, tol2) {
699 (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
700 (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
701 _ => Err(ParseError::StreamingError(
702 "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
703 )),
704 }
705 }
706 }
707 } else {
708 let (dir, lt, rt) = analyze_asof_direction(expr)?;
709 Ok((dir, lt, rt, None))
710 }
711}
712
713fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
715 match expr {
716 Expr::BinaryOp {
717 left,
718 op: BinaryOperator::GtEq,
719 right,
720 } => {
721 let left_col = extract_column_ref(left).ok_or_else(|| {
722 ParseError::StreamingError(
723 "Cannot extract left time column from MATCH_CONDITION".to_string(),
724 )
725 })?;
726 let right_col = extract_column_ref(right).ok_or_else(|| {
727 ParseError::StreamingError(
728 "Cannot extract right time column from MATCH_CONDITION".to_string(),
729 )
730 })?;
731 Ok((AsofSqlDirection::Backward, left_col, right_col))
732 }
733 Expr::BinaryOp {
734 left,
735 op: BinaryOperator::LtEq,
736 right,
737 } => {
738 let left_col = extract_column_ref(left).ok_or_else(|| {
739 ParseError::StreamingError(
740 "Cannot extract left time column from MATCH_CONDITION".to_string(),
741 )
742 })?;
743 let right_col = extract_column_ref(right).ok_or_else(|| {
744 ParseError::StreamingError(
745 "Cannot extract right time column from MATCH_CONDITION".to_string(),
746 )
747 })?;
748 Ok((AsofSqlDirection::Forward, left_col, right_col))
749 }
750 Expr::Function(func) => {
752 let name = func.name.to_string().to_uppercase();
753 if name != "NEAREST" {
754 return Err(ParseError::StreamingError(format!(
755 "Unknown ASOF MATCH_CONDITION function: {name}"
756 )));
757 }
758 let args = match &func.args {
759 FunctionArguments::List(arg_list) => &arg_list.args,
760 _ => {
761 return Err(ParseError::StreamingError(
762 "NEAREST() requires exactly 2 column arguments".to_string(),
763 ))
764 }
765 };
766 if args.len() != 2 {
767 return Err(ParseError::StreamingError(format!(
768 "NEAREST() requires exactly 2 arguments, got {}",
769 args.len()
770 )));
771 }
772 let left_col = extract_column_from_func_arg(&args[0]).ok_or_else(|| {
773 ParseError::StreamingError(
774 "Cannot extract left time column from NEAREST()".to_string(),
775 )
776 })?;
777 let right_col = extract_column_from_func_arg(&args[1]).ok_or_else(|| {
778 ParseError::StreamingError(
779 "Cannot extract right time column from NEAREST()".to_string(),
780 )
781 })?;
782 Ok((AsofSqlDirection::Nearest, left_col, right_col))
783 }
784 _ => Err(ParseError::StreamingError(
785 "ASOF MATCH_CONDITION must be >= or <= comparison, or NEAREST()".to_string(),
786 )),
787 }
788}
789
790fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
794 match expr {
795 Expr::BinaryOp {
796 left: _,
797 op: BinaryOperator::LtEq,
798 right,
799 } => {
800 match right.as_ref() {
802 Expr::Value(v) => {
803 if let sqlparser::ast::Value::Number(n, _) = &v.value {
804 let ms: u64 = n.parse().map_err(|_| {
805 ParseError::StreamingError(format!(
806 "Cannot parse tolerance as number: {n}"
807 ))
808 })?;
809 Ok(Duration::from_millis(ms))
810 } else {
811 Err(ParseError::StreamingError(
812 "ASOF tolerance must be a number or INTERVAL".to_string(),
813 ))
814 }
815 }
816 Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
817 _ => Err(ParseError::StreamingError(
818 "ASOF tolerance must be a number or INTERVAL".to_string(),
819 )),
820 }
821 }
822 _ => Err(ParseError::StreamingError(
823 "ASOF tolerance expression must be <= comparison".to_string(),
824 )),
825 }
826}
827
828fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
830 match constraint {
831 JoinConstraint::On(expr) => extract_equality_columns(expr),
832 JoinConstraint::Using(cols) => {
833 if cols.is_empty() {
834 return Err(ParseError::StreamingError(
835 "USING clause requires at least one column".to_string(),
836 ));
837 }
838 let col = cols[0].to_string();
839 Ok((col.clone(), col))
840 }
841 _ => Err(ParseError::StreamingError(
842 "ASOF JOIN requires ON or USING constraint".to_string(),
843 )),
844 }
845}
846
847fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
849 match expr {
850 Expr::BinaryOp {
851 left,
852 op: BinaryOperator::Eq,
853 right,
854 } => {
855 let left_col = extract_column_ref(left).ok_or_else(|| {
856 ParseError::StreamingError("Cannot extract left key column".to_string())
857 })?;
858 let right_col = extract_column_ref(right).ok_or_else(|| {
859 ParseError::StreamingError("Cannot extract right key column".to_string())
860 })?;
861 Ok((left_col, right_col))
862 }
863 Expr::BinaryOp {
865 left,
866 op: BinaryOperator::And,
867 right,
868 } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
869 _ => Err(ParseError::StreamingError(
870 "ASOF JOIN ON clause must contain an equality condition".to_string(),
871 )),
872 }
873}
874
875#[must_use]
877pub fn has_join(select: &Select) -> bool {
878 !select.from.is_empty() && !select.from[0].joins.is_empty()
879}
880
881#[must_use]
883pub fn count_joins(select: &Select) -> usize {
884 select
885 .from
886 .iter()
887 .map(|table_with_joins| table_with_joins.joins.len())
888 .sum()
889}
890
/// Analysis of a chain of JOINs attached to the first FROM item, as produced
/// by `analyze_joins`.
#[derive(Debug, Clone)]
pub struct MultiJoinAnalysis {
    /// One analysis per JOIN, in the order they appear in the FROM clause.
    ///
    /// Each step's left table is the previous step's right table.
    pub joins: Vec<JoinAnalysis>,
    /// The base table followed by each joined table, in FROM-clause order.
    pub tables: Vec<String>,
}
902
903impl MultiJoinAnalysis {
904 #[must_use]
906 pub fn len(&self) -> usize {
907 self.joins.len()
908 }
909
910 #[must_use]
912 pub fn is_empty(&self) -> bool {
913 self.joins.is_empty()
914 }
915
916 #[must_use]
918 pub fn is_single(&self) -> bool {
919 self.joins.len() == 1
920 }
921
922 #[must_use]
924 pub fn first(&self) -> Option<&JoinAnalysis> {
925 self.joins.first()
926 }
927}
928
929pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
940 let from = &select.from;
941 if from.is_empty() {
942 return Ok(None);
943 }
944
945 let first_table = &from[0];
946 if first_table.joins.is_empty() {
947 return Ok(None);
948 }
949
950 let base_table = extract_table_name(&first_table.relation)?;
952 let base_alias = extract_table_alias(&first_table.relation);
953
954 let mut join_steps = Vec::with_capacity(first_table.joins.len());
955 let mut tables = vec![base_table.clone()];
956
957 let mut prev_left_table = base_table;
959 let mut prev_left_alias = base_alias;
960
961 for join in &first_table.joins {
962 let right_table = extract_table_name(&join.relation)?;
963 let right_alias = extract_table_alias(&join.relation);
964 tables.push(right_table.clone());
965
966 let join_type = map_join_operator(&join.join_operator);
967
968 if let JoinOperator::AsOf {
970 match_condition,
971 constraint,
972 } = &join.join_operator
973 {
974 let (direction, left_time, right_time, tolerance) =
975 analyze_asof_match_condition(match_condition)?;
976 let (left_key, right_key) = analyze_asof_constraint(constraint)?;
977
978 let mut analysis = JoinAnalysis::asof(
979 prev_left_table.clone(),
980 right_table.clone(),
981 left_key,
982 right_key,
983 direction,
984 left_time,
985 right_time,
986 tolerance,
987 );
988 analysis.left_alias.clone_from(&prev_left_alias);
989 analysis.right_alias = right_alias;
990 join_steps.push(analysis);
991 } else if let Some(version_col) = extract_temporal_version(&join.relation) {
992 let (left_key, right_key, additional, _, _) =
994 analyze_join_constraint(&join.join_operator)?;
995
996 let mut analysis = JoinAnalysis::temporal(
997 prev_left_table.clone(),
998 right_table.clone(),
999 left_key,
1000 right_key,
1001 version_col,
1002 join_type,
1003 );
1004 analysis.left_alias.clone_from(&prev_left_alias);
1005 analysis.right_alias = right_alias;
1006 analysis.additional_key_columns = additional;
1007 join_steps.push(analysis);
1008 } else {
1009 let (left_key, right_key, additional, time_bound, time_cols) =
1011 analyze_join_constraint(&join.join_operator)?;
1012
1013 let mut analysis = if let Some(tb) = time_bound {
1014 JoinAnalysis::stream_stream(
1015 prev_left_table.clone(),
1016 right_table.clone(),
1017 left_key,
1018 right_key,
1019 tb,
1020 join_type,
1021 )
1022 } else {
1023 JoinAnalysis::lookup(
1024 prev_left_table.clone(),
1025 right_table.clone(),
1026 left_key,
1027 right_key,
1028 join_type,
1029 )
1030 };
1031 analysis.left_alias.clone_from(&prev_left_alias);
1032 analysis.right_alias.clone_from(&right_alias);
1033 analysis.additional_key_columns = additional;
1034
1035 if let Some(ref raw) = time_cols {
1036 let (lt, rt) = resolve_time_cols(
1037 raw,
1038 &analysis.left_table,
1039 &analysis.right_table,
1040 prev_left_alias.as_deref(),
1041 right_alias.as_deref(),
1042 );
1043 analysis.left_time_column = Some(lt);
1044 analysis.right_time_column = Some(rt);
1045 }
1046 join_steps.push(analysis);
1047 }
1048
1049 prev_left_table = right_table;
1051 prev_left_alias = extract_table_alias(&join.relation);
1052 }
1053
1054 Ok(Some(MultiJoinAnalysis {
1055 joins: join_steps,
1056 tables,
1057 }))
1058}
1059
1060#[cfg(test)]
1061mod tests {
1062 use super::*;
1063 use sqlparser::ast::{SetExpr, Statement};
1064 use sqlparser::dialect::GenericDialect;
1065 use sqlparser::parser::Parser;
1066
1067 fn parse_select(sql: &str) -> Select {
1068 let dialect = GenericDialect {};
1069 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1070 if let Statement::Query(query) = &statements[0] {
1071 if let SetExpr::Select(select) = query.body.as_ref() {
1072 return *select.clone();
1073 }
1074 }
1075 panic!("Expected SELECT query");
1076 }
1077
1078 #[test]
1079 fn test_analyze_inner_join() {
1080 let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
1081 let select = parse_select(sql);
1082
1083 let analysis = analyze_join(&select).unwrap().unwrap();
1084
1085 assert_eq!(analysis.join_type, JoinType::Inner);
1086 assert_eq!(analysis.left_table, "orders");
1087 assert_eq!(analysis.right_table, "payments");
1088 assert_eq!(analysis.left_key_column, "order_id");
1089 assert_eq!(analysis.right_key_column, "order_id");
1090 assert!(analysis.is_lookup_join); }
1092
1093 #[test]
1094 fn test_analyze_left_join() {
1095 let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
1096 let select = parse_select(sql);
1097
1098 let analysis = analyze_join(&select).unwrap().unwrap();
1099
1100 assert_eq!(analysis.join_type, JoinType::Left);
1101 assert_eq!(analysis.left_key_column, "customer_id");
1102 assert_eq!(analysis.right_key_column, "id");
1103 }
1104
1105 #[test]
1106 fn test_analyze_join_using() {
1107 let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
1108 let select = parse_select(sql);
1109
1110 let analysis = analyze_join(&select).unwrap().unwrap();
1111
1112 assert_eq!(analysis.left_key_column, "order_id");
1113 assert_eq!(analysis.right_key_column, "order_id");
1114 }
1115
1116 #[test]
1117 fn test_analyze_stream_stream_join_with_time_bound() {
1118 let sql = "SELECT * FROM orders o
1119 JOIN payments p ON o.order_id = p.order_id
1120 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
1121 let select = parse_select(sql);
1122
1123 let analysis = analyze_join(&select).unwrap().unwrap();
1124
1125 assert!(!analysis.is_lookup_join);
1126 assert!(analysis.time_bound.is_some());
1127 assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
1128 }
1129
1130 #[test]
1131 fn test_no_join() {
1132 let sql = "SELECT * FROM orders";
1133 let select = parse_select(sql);
1134
1135 let analysis = analyze_join(&select).unwrap();
1136 assert!(analysis.is_none());
1137 }
1138
1139 #[test]
1140 fn test_has_join() {
1141 let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1142 let sql_without_join = "SELECT * FROM orders";
1143
1144 let select_with = parse_select(sql_with_join);
1145 let select_without = parse_select(sql_without_join);
1146
1147 assert!(has_join(&select_with));
1148 assert!(!has_join(&select_without));
1149 }
1150
1151 #[test]
1152 fn test_count_joins() {
1153 let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
1154 let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
1155 let sql_zero = "SELECT * FROM a";
1156
1157 assert_eq!(count_joins(&parse_select(sql_one)), 1);
1158 assert_eq!(count_joins(&parse_select(sql_two)), 2);
1159 assert_eq!(count_joins(&parse_select(sql_zero)), 0);
1160 }
1161
1162 #[test]
1163 fn test_aliases() {
1164 let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
1165 let select = parse_select(sql);
1166
1167 let analysis = analyze_join(&select).unwrap().unwrap();
1168
1169 assert_eq!(analysis.left_alias, Some("o".to_string()));
1170 assert_eq!(analysis.right_alias, Some("p".to_string()));
1171 }
1172
1173 fn parse_select_snowflake(sql: &str) -> Select {
1176 let dialect = sqlparser::dialect::SnowflakeDialect {};
1177 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1178 if let Statement::Query(query) = &statements[0] {
1179 if let SetExpr::Select(select) = query.body.as_ref() {
1180 return *select.clone();
1181 }
1182 }
1183 panic!("Expected SELECT query");
1184 }
1185
1186 fn parse_select_laminar(sql: &str) -> Select {
1187 let dialect = crate::parser::dialect::LaminarDialect::default();
1188 let statements = Parser::parse_sql(&dialect, sql).unwrap();
1189 if let Statement::Query(query) = &statements[0] {
1190 if let SetExpr::Select(select) = query.body.as_ref() {
1191 return *select.clone();
1192 }
1193 }
1194 panic!("Expected SELECT query");
1195 }
1196
1197 #[test]
1198 fn test_asof_join_backward() {
1199 let sql = "SELECT * FROM trades t \
1200 ASOF JOIN quotes q \
1201 MATCH_CONDITION(t.ts >= q.ts) \
1202 ON t.symbol = q.symbol";
1203 let select = parse_select_snowflake(sql);
1204 let analysis = analyze_join(&select).unwrap().unwrap();
1205
1206 assert!(analysis.is_asof_join);
1207 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1208 assert_eq!(analysis.join_type, JoinType::AsOf);
1209 assert!(analysis.asof_tolerance.is_none());
1210 }
1211
1212 #[test]
1213 fn test_asof_join_forward() {
1214 let sql = "SELECT * FROM trades t \
1215 ASOF JOIN quotes q \
1216 MATCH_CONDITION(t.ts <= q.ts) \
1217 ON t.symbol = q.symbol";
1218 let select = parse_select_snowflake(sql);
1219 let analysis = analyze_join(&select).unwrap().unwrap();
1220
1221 assert!(analysis.is_asof_join);
1222 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
1223 }
1224
1225 #[test]
1226 fn test_asof_join_nearest() {
1227 let sql = "SELECT * FROM trades t \
1228 ASOF JOIN quotes q \
1229 MATCH_CONDITION(NEAREST(t.ts, q.ts)) \
1230 ON t.symbol = q.symbol";
1231 let select = parse_select_snowflake(sql);
1232 let analysis = analyze_join(&select).unwrap().unwrap();
1233
1234 assert!(analysis.is_asof_join);
1235 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Nearest));
1236 assert_eq!(analysis.join_type, JoinType::AsOf);
1237 assert!(analysis.asof_tolerance.is_none());
1238 }
1239
1240 #[test]
1241 fn test_asof_join_with_tolerance() {
1242 let sql = "SELECT * FROM trades t \
1243 ASOF JOIN quotes q \
1244 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
1245 ON t.symbol = q.symbol";
1246 let select = parse_select_snowflake(sql);
1247 let analysis = analyze_join(&select).unwrap().unwrap();
1248
1249 assert!(analysis.is_asof_join);
1250 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1251 assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
1252 }
1253
1254 #[test]
1255 fn test_asof_join_with_interval_tolerance() {
1256 let sql = "SELECT * FROM trades t \
1257 ASOF JOIN quotes q \
1258 MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
1259 ON t.symbol = q.symbol";
1260 let select = parse_select_snowflake(sql);
1261 let analysis = analyze_join(&select).unwrap().unwrap();
1262
1263 assert!(analysis.is_asof_join);
1264 assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
1265 assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
1266 }
1267
1268 #[test]
1269 fn test_asof_join_type_mapping() {
1270 let sql = "SELECT * FROM trades t \
1271 ASOF JOIN quotes q \
1272 MATCH_CONDITION(t.ts >= q.ts) \
1273 ON t.symbol = q.symbol";
1274 let select = parse_select_snowflake(sql);
1275 let analysis = analyze_join(&select).unwrap().unwrap();
1276
1277 assert_eq!(analysis.join_type, JoinType::AsOf);
1278 assert!(!analysis.is_lookup_join);
1279 }
1280
1281 #[test]
1282 fn test_asof_join_extracts_time_columns() {
1283 let sql = "SELECT * FROM trades t \
1284 ASOF JOIN quotes q \
1285 MATCH_CONDITION(t.ts >= q.ts) \
1286 ON t.symbol = q.symbol";
1287 let select = parse_select_snowflake(sql);
1288 let analysis = analyze_join(&select).unwrap().unwrap();
1289
1290 assert_eq!(analysis.left_time_column, Some("ts".to_string()));
1291 assert_eq!(analysis.right_time_column, Some("ts".to_string()));
1292 }
1293
1294 #[test]
1295 fn test_asof_join_extracts_key_columns() {
1296 let sql = "SELECT * FROM trades t \
1297 ASOF JOIN quotes q \
1298 MATCH_CONDITION(t.ts >= q.ts) \
1299 ON t.symbol = q.symbol";
1300 let select = parse_select_snowflake(sql);
1301 let analysis = analyze_join(&select).unwrap().unwrap();
1302
1303 assert_eq!(analysis.left_key_column, "symbol");
1304 assert_eq!(analysis.right_key_column, "symbol");
1305 }
1306
1307 #[test]
1308 fn test_asof_join_aliases() {
1309 let sql = "SELECT * FROM trades AS t \
1310 ASOF JOIN quotes AS q \
1311 MATCH_CONDITION(t.ts >= q.ts) \
1312 ON t.symbol = q.symbol";
1313 let select = parse_select_snowflake(sql);
1314 let analysis = analyze_join(&select).unwrap().unwrap();
1315
1316 assert_eq!(analysis.left_alias, Some("t".to_string()));
1317 assert_eq!(analysis.right_alias, Some("q".to_string()));
1318 assert_eq!(analysis.left_table, "trades");
1319 assert_eq!(analysis.right_table, "quotes");
1320 }
1321
1322 #[test]
1325 fn test_multi_join_single_backward_compat() {
1326 let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1327 let select = parse_select(sql);
1328 let multi = analyze_joins(&select).unwrap().unwrap();
1329
1330 assert!(multi.is_single());
1331 assert_eq!(multi.len(), 1);
1332 assert!(!multi.is_empty());
1333 let first = multi.first().unwrap();
1334 assert_eq!(first.left_table, "orders");
1335 assert_eq!(first.right_table, "payments");
1336 }
1337
1338 #[test]
1339 fn test_multi_join_two_way() {
1340 let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1341 let select = parse_select(sql);
1342 let multi = analyze_joins(&select).unwrap().unwrap();
1343
1344 assert_eq!(multi.len(), 2);
1345 assert!(!multi.is_single());
1346
1347 assert_eq!(multi.joins[0].left_table, "a");
1348 assert_eq!(multi.joins[0].right_table, "b");
1349 assert_eq!(multi.joins[0].left_key_column, "id");
1350 assert_eq!(multi.joins[0].right_key_column, "a_id");
1351
1352 assert_eq!(multi.joins[1].left_table, "b");
1353 assert_eq!(multi.joins[1].right_table, "c");
1354 assert_eq!(multi.joins[1].left_key_column, "id");
1355 assert_eq!(multi.joins[1].right_key_column, "b_id");
1356 }
1357
1358 #[test]
1359 fn test_multi_join_three_way() {
1360 let sql = "SELECT * FROM a \
1361 JOIN b ON a.id = b.a_id \
1362 JOIN c ON b.id = c.b_id \
1363 JOIN d ON c.id = d.c_id";
1364 let select = parse_select(sql);
1365 let multi = analyze_joins(&select).unwrap().unwrap();
1366
1367 assert_eq!(multi.len(), 3);
1368 assert_eq!(multi.tables.len(), 4);
1369 assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
1370 }
1371
1372 #[test]
1373 fn test_multi_join_mixed_asof_and_lookup() {
1374 let sql = "SELECT * FROM trades t \
1376 ASOF JOIN quotes q \
1377 MATCH_CONDITION(t.ts >= q.ts) \
1378 ON t.symbol = q.symbol \
1379 JOIN products p ON q.product_id = p.id";
1380 let select = parse_select_snowflake(sql);
1381 let multi = analyze_joins(&select).unwrap().unwrap();
1382
1383 assert_eq!(multi.len(), 2);
1384 assert!(multi.joins[0].is_asof_join);
1385 assert!(multi.joins[1].is_lookup_join);
1386 }
1387
1388 #[test]
1389 fn test_multi_join_stream_stream_and_lookup() {
1390 let sql = "SELECT * FROM orders o \
1391 JOIN payments p ON o.id = p.order_id \
1392 AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
1393 JOIN customers c ON o.customer_id = c.id";
1394 let select = parse_select(sql);
1395 let multi = analyze_joins(&select).unwrap().unwrap();
1396
1397 assert_eq!(multi.len(), 2);
1398 assert!(!multi.joins[0].is_lookup_join); assert!(multi.joins[0].time_bound.is_some());
1400 assert!(multi.joins[1].is_lookup_join); }
1402
1403 #[test]
1404 fn test_multi_join_tables_list() {
1405 let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1406 let select = parse_select(sql);
1407 let multi = analyze_joins(&select).unwrap().unwrap();
1408
1409 assert_eq!(multi.tables, vec!["a", "b", "c"]);
1410 }
1411
1412 #[test]
1413 fn test_multi_join_aliases() {
1414 let sql = "SELECT * FROM orders AS o \
1415 JOIN payments AS p ON o.id = p.order_id \
1416 JOIN refunds AS r ON p.id = r.payment_id";
1417 let select = parse_select(sql);
1418 let multi = analyze_joins(&select).unwrap().unwrap();
1419
1420 assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
1421 assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
1422 assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
1423 assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
1424 }
1425
1426 #[test]
1427 fn test_multi_join_no_join_returns_none() {
1428 let sql = "SELECT * FROM orders";
1429 let select = parse_select(sql);
1430 let multi = analyze_joins(&select).unwrap();
1431 assert!(multi.is_none());
1432 }
1433
1434 #[test]
1437 fn test_temporal_join_detected() {
1438 let sql = "SELECT o.*, p.price \
1439 FROM orders o \
1440 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
1441 ON o.product_id = p.id";
1442 let select = parse_select_laminar(sql);
1443 let analysis = analyze_join(&select).unwrap().unwrap();
1444
1445 assert!(analysis.is_temporal_join);
1446 assert_eq!(
1447 analysis.temporal_version_column,
1448 Some("order_time".to_string())
1449 );
1450 assert_eq!(analysis.left_table, "orders");
1451 assert_eq!(analysis.right_table, "products");
1452 assert_eq!(analysis.left_key_column, "product_id");
1453 assert_eq!(analysis.right_key_column, "id");
1454 assert!(!analysis.is_lookup_join);
1455 assert!(!analysis.is_asof_join);
1456 }
1457
1458 #[test]
1459 fn test_temporal_join_via_analyze_joins() {
1460 let sql = "SELECT o.*, p.price \
1461 FROM orders o \
1462 JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
1463 ON o.product_id = p.id";
1464 let select = parse_select_laminar(sql);
1465 let multi = analyze_joins(&select).unwrap().unwrap();
1466
1467 assert_eq!(multi.len(), 1);
1468 let first = multi.first().unwrap();
1469 assert!(first.is_temporal_join);
1470 assert_eq!(
1471 first.temporal_version_column,
1472 Some("order_time".to_string())
1473 );
1474 }
1475
1476 #[test]
1477 fn test_non_temporal_join_not_flagged() {
1478 let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1479 let select = parse_select(sql);
1480 let analysis = analyze_join(&select).unwrap().unwrap();
1481
1482 assert!(!analysis.is_temporal_join);
1483 assert!(analysis.temporal_version_column.is_none());
1484 }
1485
1486 #[test]
1487 fn test_unqualified_anti_maps_to_left_anti() {
1488 let sql = "SELECT * FROM orders o ANTI JOIN returns r ON o.id = r.order_id";
1489 let select = parse_select(sql);
1490 let analysis = analyze_join(&select).unwrap().unwrap();
1491 assert_eq!(analysis.join_type, JoinType::LeftAnti);
1492 }
1493
1494 #[test]
1495 fn test_unqualified_semi_maps_to_left_semi() {
1496 let sql = "SELECT * FROM orders o SEMI JOIN payments p ON o.id = p.order_id";
1497 let select = parse_select(sql);
1498 let analysis = analyze_join(&select).unwrap().unwrap();
1499 assert_eq!(analysis.join_type, JoinType::LeftSemi);
1500 }
1501
1502 #[test]
1503 fn test_composite_join_keys() {
1504 let sql = "SELECT * FROM orders o \
1505 JOIN shipments s \
1506 ON o.order_id = s.order_id AND o.region = s.region";
1507 let select = parse_select(sql);
1508 let analysis = analyze_join(&select).unwrap().unwrap();
1509
1510 assert_eq!(analysis.left_key_column, "order_id");
1512 assert_eq!(analysis.right_key_column, "order_id");
1513
1514 assert_eq!(
1516 analysis.additional_key_columns.len(),
1517 1,
1518 "Should have 1 additional key pair"
1519 );
1520 assert_eq!(analysis.additional_key_columns[0].0, "region");
1521 assert_eq!(analysis.additional_key_columns[0].1, "region");
1522 }
1523
1524 #[test]
1525 fn test_composite_using_clause() {
1526 let sql = "SELECT * FROM orders o JOIN shipments s USING (order_id, region)";
1527 let select = parse_select(sql);
1528 let analysis = analyze_join(&select).unwrap().unwrap();
1529
1530 assert_eq!(analysis.left_key_column, "order_id");
1532 assert_eq!(analysis.right_key_column, "order_id");
1533
1534 assert_eq!(
1536 analysis.additional_key_columns.len(),
1537 1,
1538 "USING(order_id, region) should have 1 additional key"
1539 );
1540 assert_eq!(analysis.additional_key_columns[0].0, "region");
1541 assert_eq!(analysis.additional_key_columns[0].1, "region");
1542 }
1543}