Skip to main content

laminar_sql/parser/
join_parser.rs

//! Join query analysis and extraction
//!
//! This module analyzes JOIN clauses to extract:
//! - Join type (INNER, LEFT, RIGHT, FULL, LEFT/RIGHT SEMI and ANTI, ASOF)
//! - Key columns for the join condition, including composite keys
//! - Time bounds for stream-stream joins
//! - Detection of lookup, stream-stream, ASOF, and temporal
//!   (`FOR SYSTEM_TIME AS OF`) joins
8
9use std::time::Duration;
10
11use sqlparser::ast::{
12    BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
13    JoinOperator, Select, TableFactor, TableVersion,
14};
15
16use super::window_rewriter::WindowRewriter;
17use super::ParseError;
18
/// Kind of join performed by an analyzed JOIN clause.
///
/// Variant order is significant (it fixes the discriminant values) and must
/// not be changed.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum JoinType {
    /// `INNER JOIN` (plain `JOIN` and `STRAIGHT_JOIN` are classified here too)
    Inner,
    /// `LEFT [OUTER] JOIN`
    Left,
    /// `RIGHT [OUTER] JOIN`
    Right,
    /// `FULL [OUTER] JOIN`
    Full,
    /// `LEFT SEMI JOIN`: emits each left row that has at least one match
    LeftSemi,
    /// `LEFT ANTI JOIN`: emits each left row that has no match
    LeftAnti,
    /// `RIGHT SEMI JOIN`: emits each right row that has at least one match
    RightSemi,
    /// `RIGHT ANTI JOIN`: emits each right row that has no match
    RightAnti,
    /// `ASOF JOIN`
    AsOf,
}
41
/// How an ASOF JOIN picks the matching right-side row in time.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// Derived from `left.ts >= right.ts`: take the latest right row at or before the left row
    Backward,
    /// Derived from `left.ts <= right.ts`: take the earliest right row at or after the left row
    Forward,
    /// Derived from `NEAREST(left.ts, right.ts)`: take the right row closest in time
    Nearest,
}

/// Renders the direction as its upper-case SQL keyword.
///
/// Width/fill flags in the format spec are not applied.
impl std::fmt::Display for AsofSqlDirection {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Backward => "BACKWARD",
            Self::Forward => "FORWARD",
            Self::Nearest => "NEAREST",
        };
        f.write_str(keyword)
    }
}
62
63/// Analysis result for a JOIN clause
64#[derive(Debug, Clone)]
65pub struct JoinAnalysis {
66    /// Type of join (inner, left, right, full)
67    pub join_type: JoinType,
68    /// Left side table name
69    pub left_table: String,
70    /// Right side table name
71    pub right_table: String,
72    /// Left side key column
73    pub left_key_column: String,
74    /// Right side key column
75    pub right_key_column: String,
76    /// Time bound for stream-stream joins (None for lookup joins)
77    pub time_bound: Option<Duration>,
78    /// Whether this is a lookup join (no time bound)
79    pub is_lookup_join: bool,
80    /// Left side alias (if any)
81    pub left_alias: Option<String>,
82    /// Right side alias (if any)
83    pub right_alias: Option<String>,
84    /// Whether this is an ASOF join
85    pub is_asof_join: bool,
86    /// ASOF join direction (Backward or Forward)
87    pub asof_direction: Option<AsofSqlDirection>,
88    /// Left side time column for ASOF join
89    pub left_time_column: Option<String>,
90    /// Right side time column for ASOF join
91    pub right_time_column: Option<String>,
92    /// ASOF join tolerance (max time difference)
93    pub asof_tolerance: Option<Duration>,
94    /// Whether this is a temporal join (FOR SYSTEM_TIME AS OF)
95    pub is_temporal_join: bool,
96    /// The version column from FOR SYSTEM_TIME AS OF (e.g., `order_time`)
97    pub temporal_version_column: Option<String>,
98    /// Additional key columns for composite join keys (beyond the primary key pair)
99    pub additional_key_columns: Vec<(String, String)>,
100}
101
102impl JoinAnalysis {
103    /// Create a stream-stream join analysis
104    #[must_use]
105    pub fn stream_stream(
106        left_table: String,
107        right_table: String,
108        left_key: String,
109        right_key: String,
110        time_bound: Duration,
111        join_type: JoinType,
112    ) -> Self {
113        Self {
114            join_type,
115            left_table,
116            right_table,
117            left_key_column: left_key,
118            right_key_column: right_key,
119            time_bound: Some(time_bound),
120            is_lookup_join: false,
121            left_alias: None,
122            right_alias: None,
123            is_asof_join: false,
124            asof_direction: None,
125            left_time_column: None,
126            right_time_column: None,
127            asof_tolerance: None,
128            is_temporal_join: false,
129            temporal_version_column: None,
130            additional_key_columns: vec![],
131        }
132    }
133
134    /// Create a lookup join analysis
135    #[must_use]
136    pub fn lookup(
137        left_table: String,
138        right_table: String,
139        left_key: String,
140        right_key: String,
141        join_type: JoinType,
142    ) -> Self {
143        Self {
144            join_type,
145            left_table,
146            right_table,
147            left_key_column: left_key,
148            right_key_column: right_key,
149            time_bound: None,
150            is_lookup_join: true,
151            left_alias: None,
152            right_alias: None,
153            is_asof_join: false,
154            asof_direction: None,
155            left_time_column: None,
156            right_time_column: None,
157            asof_tolerance: None,
158            is_temporal_join: false,
159            temporal_version_column: None,
160            additional_key_columns: vec![],
161        }
162    }
163
164    /// Create an ASOF join analysis
165    #[must_use]
166    #[allow(clippy::too_many_arguments)]
167    pub fn asof(
168        left_table: String,
169        right_table: String,
170        left_key: String,
171        right_key: String,
172        direction: AsofSqlDirection,
173        left_time_col: String,
174        right_time_col: String,
175        tolerance: Option<Duration>,
176    ) -> Self {
177        Self {
178            join_type: JoinType::AsOf,
179            left_table,
180            right_table,
181            left_key_column: left_key,
182            right_key_column: right_key,
183            time_bound: None,
184            is_lookup_join: false,
185            left_alias: None,
186            right_alias: None,
187            is_asof_join: true,
188            asof_direction: Some(direction),
189            left_time_column: Some(left_time_col),
190            right_time_column: Some(right_time_col),
191            asof_tolerance: tolerance,
192            is_temporal_join: false,
193            temporal_version_column: None,
194            additional_key_columns: vec![],
195        }
196    }
197
198    /// Create a temporal join analysis (FOR SYSTEM_TIME AS OF).
199    #[must_use]
200    pub fn temporal(
201        left_table: String,
202        right_table: String,
203        left_key: String,
204        right_key: String,
205        version_column: String,
206        join_type: JoinType,
207    ) -> Self {
208        Self {
209            join_type,
210            left_table,
211            right_table,
212            left_key_column: left_key,
213            right_key_column: right_key,
214            time_bound: None,
215            is_lookup_join: false,
216            left_alias: None,
217            right_alias: None,
218            is_asof_join: false,
219            asof_direction: None,
220            left_time_column: None,
221            right_time_column: None,
222            asof_tolerance: None,
223            is_temporal_join: true,
224            temporal_version_column: Some(version_column),
225            additional_key_columns: vec![],
226        }
227    }
228}
229
230/// Analyze a SELECT statement for join information.
231///
232/// # Errors
233///
234/// Returns `ParseError::StreamingError` if:
235/// - Join constraint is not supported
236/// - Cannot extract key columns
237pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
238    let from = &select.from;
239    if from.is_empty() {
240        return Ok(None);
241    }
242
243    let first_table = &from[0];
244    if first_table.joins.is_empty() {
245        return Ok(None);
246    }
247
248    // Extract left table information
249    let left_table = extract_table_name(&first_table.relation)?;
250    let left_alias = extract_table_alias(&first_table.relation);
251
252    // Analyze the first join
253    let join = &first_table.joins[0];
254    let right_table = extract_table_name(&join.relation)?;
255    let right_alias = extract_table_alias(&join.relation);
256
257    let join_type = map_join_operator(&join.join_operator);
258
259    // Handle ASOF JOIN specially
260    if let JoinOperator::AsOf {
261        match_condition,
262        constraint,
263    } = &join.join_operator
264    {
265        let (direction, left_time, right_time, tolerance) =
266            analyze_asof_match_condition(match_condition)?;
267
268        // Extract key columns from the ON constraint
269        let (left_key, right_key) = analyze_asof_constraint(constraint)?;
270
271        let mut analysis = JoinAnalysis::asof(
272            left_table,
273            right_table,
274            left_key,
275            right_key,
276            direction,
277            left_time,
278            right_time,
279            tolerance,
280        );
281        analysis.left_alias = left_alias;
282        analysis.right_alias = right_alias;
283        return Ok(Some(analysis));
284    }
285
286    // Check for temporal join (FOR SYSTEM_TIME AS OF)
287    if let Some(version_col) = extract_temporal_version(&join.relation) {
288        let (left_key, right_key, additional, _) = analyze_join_constraint(&join.join_operator)?;
289        let mut analysis = JoinAnalysis::temporal(
290            left_table,
291            right_table,
292            left_key,
293            right_key,
294            version_col,
295            join_type,
296        );
297        analysis.left_alias = left_alias;
298        analysis.right_alias = right_alias;
299        analysis.additional_key_columns = additional;
300        return Ok(Some(analysis));
301    }
302
303    // Analyze the join constraint
304    let (left_key, right_key, additional, time_bound) =
305        analyze_join_constraint(&join.join_operator)?;
306
307    let mut analysis = if let Some(tb) = time_bound {
308        JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
309    } else {
310        JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
311    };
312
313    analysis.left_alias = left_alias;
314    analysis.right_alias = right_alias;
315    analysis.additional_key_columns = additional;
316
317    Ok(Some(analysis))
318}
319
320/// Extract table name from a TableFactor.
321fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
322    match factor {
323        TableFactor::Table { name, .. } => Ok(name.to_string()),
324        TableFactor::Derived { alias, .. } => {
325            if let Some(alias) = alias {
326                Ok(alias.name.value.clone())
327            } else {
328                Err(ParseError::StreamingError(
329                    "Derived table without alias not supported".to_string(),
330                ))
331            }
332        }
333        _ => Err(ParseError::StreamingError(
334            "Unsupported table factor type".to_string(),
335        )),
336    }
337}
338
339/// Extract the version column from a temporal join's `FOR SYSTEM_TIME AS OF` clause.
340///
341/// Returns `Some(column_name)` if the table factor has a temporal version qualifier,
342/// `None` otherwise.
343fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
344    if let TableFactor::Table {
345        version: Some(TableVersion::ForSystemTimeAsOf(expr)),
346        ..
347    } = factor
348    {
349        Some(extract_column_name_from_expr(expr))
350    } else {
351        None
352    }
353}
354
355/// Extract a column name from an expression (e.g., `o.order_time` → `order_time`).
356///
357/// Falls back to the full expression string for complex expressions.
358fn extract_column_name_from_expr(expr: &Expr) -> String {
359    match expr {
360        Expr::Identifier(ident) => ident.value.clone(),
361        Expr::CompoundIdentifier(parts) => parts
362            .last()
363            .map_or_else(|| expr.to_string(), |p| p.value.clone()),
364        _ => expr.to_string(),
365    }
366}
367
368/// Extract table alias from a TableFactor.
369fn extract_table_alias(factor: &TableFactor) -> Option<String> {
370    match factor {
371        TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
372        TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
373        _ => None,
374    }
375}
376
377/// Map sqlparser `JoinOperator` to our `JoinType`.
378fn map_join_operator(op: &JoinOperator) -> JoinType {
379    match op {
380        JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
381            JoinType::Inner
382        }
383        JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
384        JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
385        JoinOperator::LeftAnti(_) | JoinOperator::Anti(_) => JoinType::LeftAnti,
386        JoinOperator::AsOf { .. } => JoinType::AsOf,
387        JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
388        JoinOperator::RightSemi(_) => JoinType::RightSemi,
389        JoinOperator::RightAnti(_) => JoinType::RightAnti,
390        JoinOperator::FullOuter(_) => JoinType::Full,
391        // CrossJoin, CrossApply, OuterApply are rejected by get_join_constraint()
392        _ => JoinType::Inner,
393    }
394}
395
396/// Analyze join constraint to extract key columns, additional key columns, and time bound.
397#[allow(clippy::type_complexity)]
398fn analyze_join_constraint(
399    op: &JoinOperator,
400) -> Result<(String, String, Vec<(String, String)>, Option<Duration>), ParseError> {
401    let constraint = get_join_constraint(op)?;
402
403    match constraint {
404        JoinConstraint::On(expr) => {
405            let (key_pairs, time_bound) = analyze_on_expression(expr)?;
406            if key_pairs.is_empty() {
407                return Ok((String::new(), String::new(), vec![], time_bound));
408            }
409            let (first_left, first_right) = key_pairs[0].clone();
410            let additional = key_pairs[1..].to_vec();
411            Ok((first_left, first_right, additional, time_bound))
412        }
413        JoinConstraint::Using(cols) => {
414            if cols.is_empty() {
415                return Err(ParseError::StreamingError(
416                    "USING clause requires at least one column".to_string(),
417                ));
418            }
419            // First column is the primary key pair
420            let first_col = cols[0].to_string();
421            // Remaining columns are additional key pairs
422            let additional: Vec<(String, String)> = cols[1..]
423                .iter()
424                .map(|c| {
425                    let col = c.to_string();
426                    (col.clone(), col)
427                })
428                .collect();
429            Ok((first_col.clone(), first_col, additional, None))
430        }
431        JoinConstraint::Natural => Err(ParseError::StreamingError(
432            "NATURAL JOIN not supported for streaming".to_string(),
433        )),
434        JoinConstraint::None => Err(ParseError::StreamingError(
435            "JOIN without condition not supported for streaming".to_string(),
436        )),
437    }
438}
439
440/// Get the JoinConstraint from a JoinOperator.
441fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
442    match op {
443        JoinOperator::Inner(constraint)
444        | JoinOperator::Join(constraint)
445        | JoinOperator::Left(constraint)
446        | JoinOperator::LeftOuter(constraint)
447        | JoinOperator::Right(constraint)
448        | JoinOperator::RightOuter(constraint)
449        | JoinOperator::FullOuter(constraint)
450        | JoinOperator::LeftSemi(constraint)
451        | JoinOperator::RightSemi(constraint)
452        | JoinOperator::LeftAnti(constraint)
453        | JoinOperator::RightAnti(constraint)
454        | JoinOperator::Semi(constraint)
455        | JoinOperator::Anti(constraint)
456        | JoinOperator::StraightJoin(constraint)
457        | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
458        JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
459            ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
460        ),
461    }
462}
463
464/// Analyze ON expression to extract all key column pairs and time bound.
465#[allow(clippy::type_complexity)]
466fn analyze_on_expression(
467    expr: &Expr,
468) -> Result<(Vec<(String, String)>, Option<Duration>), ParseError> {
469    // Handle compound expressions (AND)
470    match expr {
471        Expr::BinaryOp {
472            left,
473            op: BinaryOperator::And,
474            right,
475        } => {
476            // Recursively analyze both sides
477            let left_result = analyze_on_expression(left);
478            let right_result = analyze_on_expression(right);
479
480            // Combine results - collect all key pairs and time bounds
481            match (left_result, right_result) {
482                (Ok((mut lk, lt)), Ok((rk, rt))) => {
483                    lk.extend(rk);
484                    Ok((lk, lt.or(rt)))
485                }
486                (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
487                (Err(e), Err(_)) => Err(e),
488            }
489        }
490        // Equality condition: a.col = b.col
491        Expr::BinaryOp {
492            left,
493            op: BinaryOperator::Eq,
494            right,
495        } => {
496            let left_col = extract_column_ref(left);
497            let right_col = extract_column_ref(right);
498
499            match (left_col, right_col) {
500                (Some(l), Some(r)) => Ok((vec![(l, r)], None)),
501                _ => Err(ParseError::StreamingError(
502                    "Cannot extract column references from equality condition".to_string(),
503                )),
504            }
505        }
506        // BETWEEN clause for time bound: p.ts BETWEEN o.ts AND o.ts + INTERVAL
507        Expr::Between {
508            expr: _,
509            low: _,
510            high,
511            ..
512        } => {
513            // Try to extract time bound from high expression
514            let time_bound = extract_time_bound_from_expr(high).ok();
515            Ok((vec![], time_bound))
516        }
517        // Comparison operators for time bounds
518        Expr::BinaryOp {
519            left: _,
520            op:
521                BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
522            right,
523        } => {
524            // Try to extract time bound from right side
525            let time_bound = extract_time_bound_from_expr(right).ok();
526            Ok((vec![], time_bound))
527        }
528        _ => Err(ParseError::StreamingError(format!(
529            "Unsupported join condition expression: {expr:?}"
530        ))),
531    }
532}
533
534/// Extract column reference from a function argument.
535fn extract_column_from_func_arg(arg: &FunctionArg) -> Option<String> {
536    let (FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
537    | FunctionArg::Named {
538        arg: FunctionArgExpr::Expr(expr),
539        ..
540    }
541    | FunctionArg::ExprNamed {
542        arg: FunctionArgExpr::Expr(expr),
543        ..
544    }) = arg
545    else {
546        return None;
547    };
548    extract_column_ref(expr)
549}
550
551/// Extract column reference from expression (e.g., "a.id" -> "id")
552fn extract_column_ref(expr: &Expr) -> Option<String> {
553    match expr {
554        Expr::Identifier(ident) => Some(ident.value.clone()),
555        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
556        _ => None,
557    }
558}
559
560/// Extract time bound from an expression like "o.ts + INTERVAL '1' HOUR"
561fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
562    match expr {
563        // Direct interval
564        Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
565        // Addition or subtraction: col +/- INTERVAL
566        Expr::BinaryOp {
567            left: _,
568            op: BinaryOperator::Plus | BinaryOperator::Minus,
569            right,
570        } => extract_time_bound_from_expr(right),
571        // Nested expression
572        Expr::Nested(inner) => extract_time_bound_from_expr(inner),
573        _ => Err(ParseError::StreamingError(format!(
574            "Cannot extract time bound from: {expr:?}"
575        ))),
576    }
577}
578
579/// Analyze ASOF JOIN MATCH_CONDITION expression.
580///
581/// Extracts direction, time column names, and optional tolerance.
582fn analyze_asof_match_condition(
583    expr: &Expr,
584) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
585    if let Expr::BinaryOp {
586        left,
587        op: BinaryOperator::And,
588        right,
589    } = expr
590    {
591        // Try to get direction from left, tolerance from right
592        let dir_result = analyze_asof_direction(left);
593        let tol_result = extract_asof_tolerance(right);
594
595        match (dir_result, tol_result) {
596            (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
597            (Ok((dir, lt, rt)), Err(_)) => {
598                // Maybe tolerance is on left and direction on right
599                let dir2 = analyze_asof_direction(right);
600                let tol2 = extract_asof_tolerance(left);
601                match (dir2, tol2) {
602                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
603                    _ => Ok((dir, lt, rt, None)),
604                }
605            }
606            (Err(_), _) => {
607                // Try reversed
608                let dir2 = analyze_asof_direction(right);
609                let tol2 = extract_asof_tolerance(left);
610                match (dir2, tol2) {
611                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
612                    (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
613                    _ => Err(ParseError::StreamingError(
614                        "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
615                    )),
616                }
617            }
618        }
619    } else {
620        let (dir, lt, rt) = analyze_asof_direction(expr)?;
621        Ok((dir, lt, rt, None))
622    }
623}
624
625/// Extract ASOF direction and time columns from a comparison expression.
626fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
627    match expr {
628        Expr::BinaryOp {
629            left,
630            op: BinaryOperator::GtEq,
631            right,
632        } => {
633            let left_col = extract_column_ref(left).ok_or_else(|| {
634                ParseError::StreamingError(
635                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
636                )
637            })?;
638            let right_col = extract_column_ref(right).ok_or_else(|| {
639                ParseError::StreamingError(
640                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
641                )
642            })?;
643            Ok((AsofSqlDirection::Backward, left_col, right_col))
644        }
645        Expr::BinaryOp {
646            left,
647            op: BinaryOperator::LtEq,
648            right,
649        } => {
650            let left_col = extract_column_ref(left).ok_or_else(|| {
651                ParseError::StreamingError(
652                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
653                )
654            })?;
655            let right_col = extract_column_ref(right).ok_or_else(|| {
656                ParseError::StreamingError(
657                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
658                )
659            })?;
660            Ok((AsofSqlDirection::Forward, left_col, right_col))
661        }
662        // NEAREST(left_col, right_col) — function-style syntax
663        Expr::Function(func) => {
664            let name = func.name.to_string().to_uppercase();
665            if name != "NEAREST" {
666                return Err(ParseError::StreamingError(format!(
667                    "Unknown ASOF MATCH_CONDITION function: {name}"
668                )));
669            }
670            let args = match &func.args {
671                FunctionArguments::List(arg_list) => &arg_list.args,
672                _ => {
673                    return Err(ParseError::StreamingError(
674                        "NEAREST() requires exactly 2 column arguments".to_string(),
675                    ))
676                }
677            };
678            if args.len() != 2 {
679                return Err(ParseError::StreamingError(format!(
680                    "NEAREST() requires exactly 2 arguments, got {}",
681                    args.len()
682                )));
683            }
684            let left_col = extract_column_from_func_arg(&args[0]).ok_or_else(|| {
685                ParseError::StreamingError(
686                    "Cannot extract left time column from NEAREST()".to_string(),
687                )
688            })?;
689            let right_col = extract_column_from_func_arg(&args[1]).ok_or_else(|| {
690                ParseError::StreamingError(
691                    "Cannot extract right time column from NEAREST()".to_string(),
692                )
693            })?;
694            Ok((AsofSqlDirection::Nearest, left_col, right_col))
695        }
696        _ => Err(ParseError::StreamingError(
697            "ASOF MATCH_CONDITION must be >= or <= comparison, or NEAREST()".to_string(),
698        )),
699    }
700}
701
702/// Extract tolerance duration from an ASOF tolerance expression.
703///
704/// Handles: `left - right <= value` or `left - right <= INTERVAL '...'`
705fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
706    match expr {
707        Expr::BinaryOp {
708            left: _,
709            op: BinaryOperator::LtEq,
710            right,
711        } => {
712            // right side is either a literal number or INTERVAL
713            match right.as_ref() {
714                Expr::Value(v) => {
715                    if let sqlparser::ast::Value::Number(n, _) = &v.value {
716                        let ms: u64 = n.parse().map_err(|_| {
717                            ParseError::StreamingError(format!(
718                                "Cannot parse tolerance as number: {n}"
719                            ))
720                        })?;
721                        Ok(Duration::from_millis(ms))
722                    } else {
723                        Err(ParseError::StreamingError(
724                            "ASOF tolerance must be a number or INTERVAL".to_string(),
725                        ))
726                    }
727                }
728                Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
729                _ => Err(ParseError::StreamingError(
730                    "ASOF tolerance must be a number or INTERVAL".to_string(),
731                )),
732            }
733        }
734        _ => Err(ParseError::StreamingError(
735            "ASOF tolerance expression must be <= comparison".to_string(),
736        )),
737    }
738}
739
740/// Extract key columns from an ASOF JOIN constraint (ON clause).
741fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
742    match constraint {
743        JoinConstraint::On(expr) => extract_equality_columns(expr),
744        JoinConstraint::Using(cols) => {
745            if cols.is_empty() {
746                return Err(ParseError::StreamingError(
747                    "USING clause requires at least one column".to_string(),
748                ));
749            }
750            let col = cols[0].to_string();
751            Ok((col.clone(), col))
752        }
753        _ => Err(ParseError::StreamingError(
754            "ASOF JOIN requires ON or USING constraint".to_string(),
755        )),
756    }
757}
758
759/// Extract left and right column names from an equality expression.
760fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
761    match expr {
762        Expr::BinaryOp {
763            left,
764            op: BinaryOperator::Eq,
765            right,
766        } => {
767            let left_col = extract_column_ref(left).ok_or_else(|| {
768                ParseError::StreamingError("Cannot extract left key column".to_string())
769            })?;
770            let right_col = extract_column_ref(right).ok_or_else(|| {
771                ParseError::StreamingError("Cannot extract right key column".to_string())
772            })?;
773            Ok((left_col, right_col))
774        }
775        // If there's an AND, find the equality part
776        Expr::BinaryOp {
777            left,
778            op: BinaryOperator::And,
779            right,
780        } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
781        _ => Err(ParseError::StreamingError(
782            "ASOF JOIN ON clause must contain an equality condition".to_string(),
783        )),
784    }
785}
786
787/// Check if a SELECT contains a join.
788#[must_use]
789pub fn has_join(select: &Select) -> bool {
790    !select.from.is_empty() && !select.from[0].joins.is_empty()
791}
792
793/// Count the number of joins in a SELECT.
794#[must_use]
795pub fn count_joins(select: &Select) -> usize {
796    select
797        .from
798        .iter()
799        .map(|table_with_joins| table_with_joins.joins.len())
800        .sum()
801}
802
803/// Analysis result for multi-way JOINs (e.g., `A JOIN B ... JOIN C ...`).
804///
805/// Each step represents one left-deep join: step 0 joins the base table with
806/// the first right table, step 1 joins the result with the next right table, etc.
807#[derive(Debug, Clone)]
808pub struct MultiJoinAnalysis {
809    /// Ordered join steps (left-to-right)
810    pub joins: Vec<JoinAnalysis>,
811    /// All referenced tables in order (base table first, then each right table)
812    pub tables: Vec<String>,
813}
814
815impl MultiJoinAnalysis {
816    /// Number of join steps.
817    #[must_use]
818    pub fn len(&self) -> usize {
819        self.joins.len()
820    }
821
822    /// Whether there are no join steps.
823    #[must_use]
824    pub fn is_empty(&self) -> bool {
825        self.joins.is_empty()
826    }
827
828    /// Whether this is a single join (backward-compatible case).
829    #[must_use]
830    pub fn is_single(&self) -> bool {
831        self.joins.len() == 1
832    }
833
834    /// The first join step (convenience for single-join queries).
835    #[must_use]
836    pub fn first(&self) -> Option<&JoinAnalysis> {
837        self.joins.first()
838    }
839}
840
841/// Analyze a SELECT statement for all join steps (multi-way).
842///
843/// Returns `None` if the query has no joins. For a single join this
844/// returns a `MultiJoinAnalysis` with one step, making it backward
845/// compatible with `analyze_join()`.
846///
847/// # Errors
848///
849/// Returns `ParseError::StreamingError` if any join constraint is
850/// not supported or key columns cannot be extracted.
851pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
852    let from = &select.from;
853    if from.is_empty() {
854        return Ok(None);
855    }
856
857    let first_table = &from[0];
858    if first_table.joins.is_empty() {
859        return Ok(None);
860    }
861
862    // Extract base table
863    let base_table = extract_table_name(&first_table.relation)?;
864    let base_alias = extract_table_alias(&first_table.relation);
865
866    let mut join_steps = Vec::with_capacity(first_table.joins.len());
867    let mut tables = vec![base_table.clone()];
868
869    // Track the left table name for left-deep chaining
870    let mut prev_left_table = base_table;
871    let mut prev_left_alias = base_alias;
872
873    for join in &first_table.joins {
874        let right_table = extract_table_name(&join.relation)?;
875        let right_alias = extract_table_alias(&join.relation);
876        tables.push(right_table.clone());
877
878        let join_type = map_join_operator(&join.join_operator);
879
880        // Handle ASOF JOIN
881        if let JoinOperator::AsOf {
882            match_condition,
883            constraint,
884        } = &join.join_operator
885        {
886            let (direction, left_time, right_time, tolerance) =
887                analyze_asof_match_condition(match_condition)?;
888            let (left_key, right_key) = analyze_asof_constraint(constraint)?;
889
890            let mut analysis = JoinAnalysis::asof(
891                prev_left_table.clone(),
892                right_table.clone(),
893                left_key,
894                right_key,
895                direction,
896                left_time,
897                right_time,
898                tolerance,
899            );
900            analysis.left_alias.clone_from(&prev_left_alias);
901            analysis.right_alias = right_alias;
902            join_steps.push(analysis);
903        } else if let Some(version_col) = extract_temporal_version(&join.relation) {
904            // Temporal join: right side has FOR SYSTEM_TIME AS OF
905            let (left_key, right_key, additional, _) =
906                analyze_join_constraint(&join.join_operator)?;
907
908            let mut analysis = JoinAnalysis::temporal(
909                prev_left_table.clone(),
910                right_table.clone(),
911                left_key,
912                right_key,
913                version_col,
914                join_type,
915            );
916            analysis.left_alias.clone_from(&prev_left_alias);
917            analysis.right_alias = right_alias;
918            analysis.additional_key_columns = additional;
919            join_steps.push(analysis);
920        } else {
921            // Regular join (inner, left, right, full)
922            let (left_key, right_key, additional, time_bound) =
923                analyze_join_constraint(&join.join_operator)?;
924
925            let mut analysis = if let Some(tb) = time_bound {
926                JoinAnalysis::stream_stream(
927                    prev_left_table.clone(),
928                    right_table.clone(),
929                    left_key,
930                    right_key,
931                    tb,
932                    join_type,
933                )
934            } else {
935                JoinAnalysis::lookup(
936                    prev_left_table.clone(),
937                    right_table.clone(),
938                    left_key,
939                    right_key,
940                    join_type,
941                )
942            };
943            analysis.left_alias.clone_from(&prev_left_alias);
944            analysis.right_alias = right_alias;
945            analysis.additional_key_columns = additional;
946            join_steps.push(analysis);
947        }
948
949        // Next step's left table is this step's right table (left-deep)
950        prev_left_table = right_table;
951        prev_left_alias = extract_table_alias(&join.relation);
952    }
953
954    Ok(Some(MultiJoinAnalysis {
955        joins: join_steps,
956        tables,
957    }))
958}
959
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::ast::{SetExpr, Statement};
    use sqlparser::dialect::{Dialect, GenericDialect, SnowflakeDialect};
    use sqlparser::parser::Parser;

    /// Parse `sql` with `dialect` and return its first statement as a plain SELECT.
    ///
    /// Panics (failing the test) if parsing fails or the first statement is not a
    /// `SELECT` query.
    fn parse_select_with(dialect: &dyn Dialect, sql: &str) -> Select {
        let statements = Parser::parse_sql(dialect, sql).unwrap();
        if let Statement::Query(query) = &statements[0] {
            if let SetExpr::Select(select) = query.body.as_ref() {
                return *select.clone();
            }
        }
        panic!("Expected SELECT query");
    }

    /// Parse with sqlparser's generic dialect.
    fn parse_select(sql: &str) -> Select {
        parse_select_with(&GenericDialect {}, sql)
    }

    #[test]
    fn test_analyze_inner_join() {
        let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Inner);
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "payments");
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
        assert!(analysis.is_lookup_join); // No time bound = lookup join
    }

    #[test]
    fn test_analyze_left_join() {
        let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Left);
        assert_eq!(analysis.left_key_column, "customer_id");
        assert_eq!(analysis.right_key_column, "id");
    }

    #[test]
    fn test_analyze_join_using() {
        let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
    }

    #[test]
    fn test_analyze_stream_stream_join_with_time_bound() {
        let sql = "SELECT * FROM orders o
                   JOIN payments p ON o.order_id = p.order_id
                   AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_lookup_join);
        // Compare the Option directly: a missing bound fails with a readable diff
        // instead of an `unwrap` panic.
        assert_eq!(analysis.time_bound, Some(Duration::from_secs(3600)));
    }

    #[test]
    fn test_no_join() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap();
        assert!(analysis.is_none());
    }

    #[test]
    fn test_has_join() {
        let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let sql_without_join = "SELECT * FROM orders";

        let select_with = parse_select(sql_with_join);
        let select_without = parse_select(sql_without_join);

        assert!(has_join(&select_with));
        assert!(!has_join(&select_without));
    }

    #[test]
    fn test_count_joins() {
        let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
        let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let sql_zero = "SELECT * FROM a";

        assert_eq!(count_joins(&parse_select(sql_one)), 1);
        assert_eq!(count_joins(&parse_select(sql_two)), 2);
        assert_eq!(count_joins(&parse_select(sql_zero)), 0);
    }

    #[test]
    fn test_count_joins_spans_all_from_items() {
        // Two comma-separated FROM entries, each with one join: both are counted.
        let sql = "SELECT * FROM a JOIN b ON a.id = b.id, c JOIN d ON c.id = d.id";
        let select = parse_select(sql);

        assert_eq!(count_joins(&select), 2);
        assert!(has_join(&select));
    }

    #[test]
    fn test_aliases() {
        let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("o".to_string()));
        assert_eq!(analysis.right_alias, Some("p".to_string()));
    }

    // -- ASOF JOIN tests --

    /// Parse with the Snowflake dialect, which these tests rely on for
    /// `ASOF JOIN ... MATCH_CONDITION(...)` syntax.
    fn parse_select_snowflake(sql: &str) -> Select {
        parse_select_with(&SnowflakeDialect {}, sql)
    }

    /// Parse with the project's Laminar dialect, used here for
    /// `FOR SYSTEM_TIME AS OF` temporal joins.
    fn parse_select_laminar(sql: &str) -> Select {
        parse_select_with(&crate::parser::dialect::LaminarDialect::default(), sql)
    }

    #[test]
    fn test_asof_join_backward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_forward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts <= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
    }

    #[test]
    fn test_asof_join_nearest() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(NEAREST(t.ts, q.ts)) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Nearest));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_with_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
    }

    #[test]
    fn test_asof_join_with_interval_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_asof_join_type_mapping() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(!analysis.is_lookup_join);
    }

    #[test]
    fn test_asof_join_extracts_time_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_time_column, Some("ts".to_string()));
        assert_eq!(analysis.right_time_column, Some("ts".to_string()));
    }

    #[test]
    fn test_asof_join_extracts_key_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "symbol");
        assert_eq!(analysis.right_key_column, "symbol");
    }

    #[test]
    fn test_asof_join_aliases() {
        let sql = "SELECT * FROM trades AS t \
                    ASOF JOIN quotes AS q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("t".to_string()));
        assert_eq!(analysis.right_alias, Some("q".to_string()));
        assert_eq!(analysis.left_table, "trades");
        assert_eq!(analysis.right_table, "quotes");
    }

    // -- Multi-way JOIN tests --

    #[test]
    fn test_multi_join_single_backward_compat() {
        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert!(multi.is_single());
        assert_eq!(multi.len(), 1);
        assert!(!multi.is_empty());
        let first = multi.first().unwrap();
        assert_eq!(first.left_table, "orders");
        assert_eq!(first.right_table, "payments");
    }

    #[test]
    fn test_multi_join_two_way() {
        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(!multi.is_single());

        assert_eq!(multi.joins[0].left_table, "a");
        assert_eq!(multi.joins[0].right_table, "b");
        assert_eq!(multi.joins[0].left_key_column, "id");
        assert_eq!(multi.joins[0].right_key_column, "a_id");

        assert_eq!(multi.joins[1].left_table, "b");
        assert_eq!(multi.joins[1].right_table, "c");
        assert_eq!(multi.joins[1].left_key_column, "id");
        assert_eq!(multi.joins[1].right_key_column, "b_id");
    }

    #[test]
    fn test_multi_join_three_way() {
        let sql = "SELECT * FROM a \
                    JOIN b ON a.id = b.a_id \
                    JOIN c ON b.id = c.b_id \
                    JOIN d ON c.id = d.c_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 3);
        assert_eq!(multi.tables.len(), 4);
        assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);

        // Left-deep chaining: each step's left table is the previous step's right table.
        assert_eq!(multi.joins[1].left_table, "b");
        assert_eq!(multi.joins[2].left_table, "c");
        assert_eq!(multi.joins[2].right_table, "d");
        assert_eq!(multi.joins[2].right_key_column, "c_id");
    }

    #[test]
    fn test_multi_join_mixed_asof_and_lookup() {
        // ASOF first, then lookup (use Snowflake dialect for ASOF)
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol \
                    JOIN products p ON q.product_id = p.id";
        let select = parse_select_snowflake(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(multi.joins[0].is_asof_join);
        assert!(multi.joins[1].is_lookup_join);
    }

    #[test]
    fn test_multi_join_asof_chains_aliases_and_tables() {
        // The alias/table chaining must also hold when the first step is an ASOF join.
        let sql = "SELECT * FROM trades AS t \
                    ASOF JOIN quotes AS q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol \
                    JOIN products AS p ON q.product_id = p.id";
        let select = parse_select_snowflake(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.tables, vec!["trades", "quotes", "products"]);

        assert_eq!(multi.joins[0].left_alias, Some("t".to_string()));
        assert_eq!(multi.joins[0].right_alias, Some("q".to_string()));

        assert_eq!(multi.joins[1].left_table, "quotes");
        assert_eq!(multi.joins[1].right_table, "products");
        assert_eq!(multi.joins[1].left_alias, Some("q".to_string()));
        assert_eq!(multi.joins[1].right_alias, Some("p".to_string()));
    }

    #[test]
    fn test_multi_join_stream_stream_and_lookup() {
        let sql = "SELECT * FROM orders o \
                    JOIN payments p ON o.id = p.order_id \
                        AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
                    JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(!multi.joins[0].is_lookup_join); // stream-stream
        assert!(multi.joins[0].time_bound.is_some());
        assert!(multi.joins[1].is_lookup_join); // lookup
    }

    #[test]
    fn test_multi_join_tables_list() {
        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.tables, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_multi_join_aliases() {
        let sql = "SELECT * FROM orders AS o \
                    JOIN payments AS p ON o.id = p.order_id \
                    JOIN refunds AS r ON p.id = r.payment_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
        assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
    }

    #[test]
    fn test_multi_join_no_join_returns_none() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap();
        assert!(multi.is_none());
    }

    // -- Temporal JOIN tests (FOR SYSTEM_TIME AS OF) --

    #[test]
    fn test_temporal_join_detected() {
        let sql = "SELECT o.*, p.price \
                    FROM orders o \
                    JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                    ON o.product_id = p.id";
        let select = parse_select_laminar(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_temporal_join);
        assert_eq!(
            analysis.temporal_version_column,
            Some("order_time".to_string())
        );
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "products");
        assert_eq!(analysis.left_key_column, "product_id");
        assert_eq!(analysis.right_key_column, "id");
        assert!(!analysis.is_lookup_join);
        assert!(!analysis.is_asof_join);
    }

    #[test]
    fn test_temporal_join_via_analyze_joins() {
        let sql = "SELECT o.*, p.price \
                    FROM orders o \
                    JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                    ON o.product_id = p.id";
        let select = parse_select_laminar(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 1);
        let first = multi.first().unwrap();
        assert!(first.is_temporal_join);
        assert_eq!(
            first.temporal_version_column,
            Some("order_time".to_string())
        );
    }

    #[test]
    fn test_non_temporal_join_not_flagged() {
        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_temporal_join);
        assert!(analysis.temporal_version_column.is_none());
    }

    #[test]
    fn test_unqualified_anti_maps_to_left_anti() {
        let sql = "SELECT * FROM orders o ANTI JOIN returns r ON o.id = r.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();
        assert_eq!(analysis.join_type, JoinType::LeftAnti);
    }

    #[test]
    fn test_unqualified_semi_maps_to_left_semi() {
        let sql = "SELECT * FROM orders o SEMI JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();
        assert_eq!(analysis.join_type, JoinType::LeftSemi);
    }

    #[test]
    fn test_composite_join_keys() {
        let sql = "SELECT * FROM orders o \
                    JOIN shipments s \
                    ON o.order_id = s.order_id AND o.region = s.region";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        // First key pair is the primary key
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");

        // Second key pair should be in additional_key_columns
        assert_eq!(
            analysis.additional_key_columns.len(),
            1,
            "Should have 1 additional key pair"
        );
        assert_eq!(analysis.additional_key_columns[0].0, "region");
        assert_eq!(analysis.additional_key_columns[0].1, "region");
    }

    #[test]
    fn test_composite_using_clause() {
        let sql = "SELECT * FROM orders o JOIN shipments s USING (order_id, region)";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        // First column becomes primary key
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");

        // Additional columns
        assert_eq!(
            analysis.additional_key_columns.len(),
            1,
            "USING(order_id, region) should have 1 additional key"
        );
        assert_eq!(analysis.additional_key_columns[0].0, "region");
        assert_eq!(analysis.additional_key_columns[0].1, "region");
    }
}