Skip to main content

laminar_sql/parser/
join_parser.rs

//! Join query analysis and extraction
//!
//! This module analyzes JOIN clauses to extract:
//! - Join type (INNER, LEFT, RIGHT, FULL, SEMI, ANTI, ASOF)
//! - Key columns for the join condition
//! - Time bounds for stream-stream joins
//! - Detection of lookup joins vs stream-stream joins
8
9use std::time::Duration;
10
11use sqlparser::ast::{
12    BinaryOperator, Expr, FunctionArg, FunctionArgExpr, FunctionArguments, JoinConstraint,
13    JoinOperator, Select, TableFactor, TableVersion,
14};
15
16use super::window_rewriter::WindowRewriter;
17use super::ParseError;
18
/// The kind of join named by the SQL join operator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT \[OUTER\] JOIN`.
    Left,
    /// `RIGHT \[OUTER\] JOIN`.
    Right,
    /// `FULL \[OUTER\] JOIN`.
    Full,
    /// `LEFT SEMI JOIN`: a left row is emitted when it has at least one match.
    LeftSemi,
    /// `LEFT ANTI JOIN`: a left row is emitted when it has no match.
    LeftAnti,
    /// `RIGHT SEMI JOIN`: a right row is emitted when it has at least one match.
    RightSemi,
    /// `RIGHT ANTI JOIN`: a right row is emitted when it has no match.
    RightAnti,
    /// `ASOF JOIN`.
    AsOf,
}
41
/// How an ASOF JOIN pairs a left row with a right row along the time axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// `left.ts >= right.ts`: pick the most recent right row.
    Backward,
    /// `left.ts <= right.ts`: pick the next right row.
    Forward,
    /// Pick the right row with the smallest absolute time difference.
    Nearest,
}
52
53impl std::fmt::Display for AsofSqlDirection {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        match self {
56            AsofSqlDirection::Backward => write!(f, "BACKWARD"),
57            AsofSqlDirection::Forward => write!(f, "FORWARD"),
58            AsofSqlDirection::Nearest => write!(f, "NEAREST"),
59        }
60    }
61}
62
/// Time column references lifted from a `BETWEEN` clause before it is known
/// which one belongs to the left table and which to the right.
///
/// For `p.ts BETWEEN o.ts AND ...`, the `expr_*` fields describe `p.ts` and
/// the `low_*` fields describe `o.ts`.
#[derive(Debug, Clone)]
struct RawTimeCols {
    /// Table qualifier of the tested expression, if it was written as `q.col`.
    expr_qualifier: Option<String>,
    /// Column name of the tested expression.
    expr_col: String,
    /// Table qualifier of the lower bound, if it was written as `q.col`.
    low_qualifier: Option<String>,
    /// Column name of the lower bound.
    low_col: String,
}
71
72/// Resolve BETWEEN time columns to `(left_time_col, right_time_col)` using
73/// table qualifiers. Falls back to positional (low=left, expr=right).
74fn resolve_time_cols(
75    raw: &RawTimeCols,
76    left_table: &str,
77    right_table: &str,
78    left_alias: Option<&str>,
79    right_alias: Option<&str>,
80) -> (String, String) {
81    let matches_left = |q: &Option<String>| -> bool {
82        q.as_ref()
83            .is_some_and(|t| t == left_table || left_alias.is_some_and(|a| a == t))
84    };
85    let matches_right = |q: &Option<String>| -> bool {
86        q.as_ref()
87            .is_some_and(|t| t == right_table || right_alias.is_some_and(|a| a == t))
88    };
89
90    if matches_right(&raw.expr_qualifier) && matches_left(&raw.low_qualifier) {
91        (raw.low_col.clone(), raw.expr_col.clone())
92    } else if matches_left(&raw.expr_qualifier) && matches_right(&raw.low_qualifier) {
93        (raw.expr_col.clone(), raw.low_col.clone())
94    } else {
95        (raw.low_col.clone(), raw.expr_col.clone())
96    }
97}
98
99/// Analysis result for a JOIN clause
100#[derive(Debug, Clone)]
101pub struct JoinAnalysis {
102    /// Type of join (inner, left, right, full)
103    pub join_type: JoinType,
104    /// Left side table name
105    pub left_table: String,
106    /// Right side table name
107    pub right_table: String,
108    /// Left side key column
109    pub left_key_column: String,
110    /// Right side key column
111    pub right_key_column: String,
112    /// Time bound for stream-stream joins (None for lookup joins)
113    pub time_bound: Option<Duration>,
114    /// Whether this is a lookup join (no time bound)
115    pub is_lookup_join: bool,
116    /// Left side alias (if any)
117    pub left_alias: Option<String>,
118    /// Right side alias (if any)
119    pub right_alias: Option<String>,
120    /// Whether this is an ASOF join
121    pub is_asof_join: bool,
122    /// ASOF join direction (Backward or Forward)
123    pub asof_direction: Option<AsofSqlDirection>,
124    /// Left side time column for ASOF join
125    pub left_time_column: Option<String>,
126    /// Right side time column for ASOF join
127    pub right_time_column: Option<String>,
128    /// ASOF join tolerance (max time difference)
129    pub asof_tolerance: Option<Duration>,
130    /// Whether this is a temporal join (FOR SYSTEM_TIME AS OF)
131    pub is_temporal_join: bool,
132    /// The version column from FOR SYSTEM_TIME AS OF (e.g., `order_time`)
133    pub temporal_version_column: Option<String>,
134    /// Additional key columns for composite join keys (beyond the primary key pair)
135    pub additional_key_columns: Vec<(String, String)>,
136}
137
138impl JoinAnalysis {
139    /// Create a stream-stream join analysis
140    #[must_use]
141    pub fn stream_stream(
142        left_table: String,
143        right_table: String,
144        left_key: String,
145        right_key: String,
146        time_bound: Duration,
147        join_type: JoinType,
148    ) -> Self {
149        Self {
150            join_type,
151            left_table,
152            right_table,
153            left_key_column: left_key,
154            right_key_column: right_key,
155            time_bound: Some(time_bound),
156            is_lookup_join: false,
157            left_alias: None,
158            right_alias: None,
159            is_asof_join: false,
160            asof_direction: None,
161            left_time_column: None,
162            right_time_column: None,
163            asof_tolerance: None,
164            is_temporal_join: false,
165            temporal_version_column: None,
166            additional_key_columns: vec![],
167        }
168    }
169
170    /// Create a lookup join analysis
171    #[must_use]
172    pub fn lookup(
173        left_table: String,
174        right_table: String,
175        left_key: String,
176        right_key: String,
177        join_type: JoinType,
178    ) -> Self {
179        Self {
180            join_type,
181            left_table,
182            right_table,
183            left_key_column: left_key,
184            right_key_column: right_key,
185            time_bound: None,
186            is_lookup_join: true,
187            left_alias: None,
188            right_alias: None,
189            is_asof_join: false,
190            asof_direction: None,
191            left_time_column: None,
192            right_time_column: None,
193            asof_tolerance: None,
194            is_temporal_join: false,
195            temporal_version_column: None,
196            additional_key_columns: vec![],
197        }
198    }
199
200    /// Create an ASOF join analysis
201    #[must_use]
202    #[allow(clippy::too_many_arguments)]
203    pub fn asof(
204        left_table: String,
205        right_table: String,
206        left_key: String,
207        right_key: String,
208        direction: AsofSqlDirection,
209        left_time_col: String,
210        right_time_col: String,
211        tolerance: Option<Duration>,
212    ) -> Self {
213        Self {
214            join_type: JoinType::AsOf,
215            left_table,
216            right_table,
217            left_key_column: left_key,
218            right_key_column: right_key,
219            time_bound: None,
220            is_lookup_join: false,
221            left_alias: None,
222            right_alias: None,
223            is_asof_join: true,
224            asof_direction: Some(direction),
225            left_time_column: Some(left_time_col),
226            right_time_column: Some(right_time_col),
227            asof_tolerance: tolerance,
228            is_temporal_join: false,
229            temporal_version_column: None,
230            additional_key_columns: vec![],
231        }
232    }
233
234    /// Create a temporal join analysis (FOR SYSTEM_TIME AS OF).
235    #[must_use]
236    pub fn temporal(
237        left_table: String,
238        right_table: String,
239        left_key: String,
240        right_key: String,
241        version_column: String,
242        join_type: JoinType,
243    ) -> Self {
244        Self {
245            join_type,
246            left_table,
247            right_table,
248            left_key_column: left_key,
249            right_key_column: right_key,
250            time_bound: None,
251            is_lookup_join: false,
252            left_alias: None,
253            right_alias: None,
254            is_asof_join: false,
255            asof_direction: None,
256            left_time_column: None,
257            right_time_column: None,
258            asof_tolerance: None,
259            is_temporal_join: true,
260            temporal_version_column: Some(version_column),
261            additional_key_columns: vec![],
262        }
263    }
264}
265
266/// Analyze a SELECT statement for join information.
267///
268/// # Errors
269///
270/// Returns `ParseError::StreamingError` if:
271/// - Join constraint is not supported
272/// - Cannot extract key columns
273pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
274    let from = &select.from;
275    if from.is_empty() {
276        return Ok(None);
277    }
278
279    let first_table = &from[0];
280    if first_table.joins.is_empty() {
281        return Ok(None);
282    }
283
284    // Extract left table information
285    let left_table = extract_table_name(&first_table.relation)?;
286    let left_alias = extract_table_alias(&first_table.relation);
287
288    // Analyze the first join
289    let join = &first_table.joins[0];
290    let right_table = extract_table_name(&join.relation)?;
291    let right_alias = extract_table_alias(&join.relation);
292
293    let join_type = map_join_operator(&join.join_operator);
294
295    // Handle ASOF JOIN specially
296    if let JoinOperator::AsOf {
297        match_condition,
298        constraint,
299    } = &join.join_operator
300    {
301        let (direction, left_time, right_time, tolerance) =
302            analyze_asof_match_condition(match_condition)?;
303
304        // Extract key columns from the ON constraint
305        let (left_key, right_key) = analyze_asof_constraint(constraint)?;
306
307        let mut analysis = JoinAnalysis::asof(
308            left_table,
309            right_table,
310            left_key,
311            right_key,
312            direction,
313            left_time,
314            right_time,
315            tolerance,
316        );
317        analysis.left_alias = left_alias;
318        analysis.right_alias = right_alias;
319        return Ok(Some(analysis));
320    }
321
322    // Check for temporal join (FOR SYSTEM_TIME AS OF)
323    if let Some(version_col) = extract_temporal_version(&join.relation) {
324        let (left_key, right_key, additional, _, _) = analyze_join_constraint(&join.join_operator)?;
325        let mut analysis = JoinAnalysis::temporal(
326            left_table,
327            right_table,
328            left_key,
329            right_key,
330            version_col,
331            join_type,
332        );
333        analysis.left_alias = left_alias;
334        analysis.right_alias = right_alias;
335        analysis.additional_key_columns = additional;
336        return Ok(Some(analysis));
337    }
338
339    // Analyze the join constraint
340    let (left_key, right_key, additional, time_bound, time_cols) =
341        analyze_join_constraint(&join.join_operator)?;
342
343    let mut analysis = if let Some(tb) = time_bound {
344        JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
345    } else {
346        JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
347    };
348
349    analysis.left_alias.clone_from(&left_alias);
350    analysis.right_alias.clone_from(&right_alias);
351    analysis.additional_key_columns = additional;
352
353    if let Some(ref raw) = time_cols {
354        let (lt, rt) = resolve_time_cols(
355            raw,
356            &analysis.left_table,
357            &analysis.right_table,
358            left_alias.as_deref(),
359            right_alias.as_deref(),
360        );
361        analysis.left_time_column = Some(lt);
362        analysis.right_time_column = Some(rt);
363    }
364
365    Ok(Some(analysis))
366}
367
368/// Extract table name from a TableFactor.
369fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
370    match factor {
371        TableFactor::Table { name, .. } => Ok(name.to_string()),
372        TableFactor::Derived { alias, .. } => {
373            if let Some(alias) = alias {
374                Ok(alias.name.value.clone())
375            } else {
376                Err(ParseError::StreamingError(
377                    "Derived table without alias not supported".to_string(),
378                ))
379            }
380        }
381        _ => Err(ParseError::StreamingError(
382            "Unsupported table factor type".to_string(),
383        )),
384    }
385}
386
387/// Extract the version column from a temporal join's `FOR SYSTEM_TIME AS OF` clause.
388///
389/// Returns `Some(column_name)` if the table factor has a temporal version qualifier,
390/// `None` otherwise.
391fn extract_temporal_version(factor: &TableFactor) -> Option<String> {
392    if let TableFactor::Table {
393        version: Some(TableVersion::ForSystemTimeAsOf(expr)),
394        ..
395    } = factor
396    {
397        Some(extract_column_name_from_expr(expr))
398    } else {
399        None
400    }
401}
402
403/// Extract a column name from an expression (e.g., `o.order_time` → `order_time`).
404///
405/// Falls back to the full expression string for complex expressions.
406fn extract_column_name_from_expr(expr: &Expr) -> String {
407    match expr {
408        Expr::Identifier(ident) => ident.value.clone(),
409        Expr::CompoundIdentifier(parts) => parts
410            .last()
411            .map_or_else(|| expr.to_string(), |p| p.value.clone()),
412        _ => expr.to_string(),
413    }
414}
415
416/// Extract table alias from a TableFactor.
417fn extract_table_alias(factor: &TableFactor) -> Option<String> {
418    match factor {
419        TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
420        TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
421        _ => None,
422    }
423}
424
425/// Map sqlparser `JoinOperator` to our `JoinType`.
426fn map_join_operator(op: &JoinOperator) -> JoinType {
427    match op {
428        JoinOperator::Inner(_) | JoinOperator::Join(_) | JoinOperator::StraightJoin(_) => {
429            JoinType::Inner
430        }
431        JoinOperator::Left(_) | JoinOperator::LeftOuter(_) => JoinType::Left,
432        JoinOperator::LeftSemi(_) | JoinOperator::Semi(_) => JoinType::LeftSemi,
433        JoinOperator::LeftAnti(_) | JoinOperator::Anti(_) => JoinType::LeftAnti,
434        JoinOperator::AsOf { .. } => JoinType::AsOf,
435        JoinOperator::Right(_) | JoinOperator::RightOuter(_) => JoinType::Right,
436        JoinOperator::RightSemi(_) => JoinType::RightSemi,
437        JoinOperator::RightAnti(_) => JoinType::RightAnti,
438        JoinOperator::FullOuter(_) => JoinType::Full,
439        // CrossJoin, CrossApply, OuterApply are rejected by get_join_constraint()
440        _ => JoinType::Inner,
441    }
442}
443
444/// Analyze join constraint to extract key columns, additional key columns,
445/// time bound, and optional time column pair.
446#[allow(clippy::type_complexity)]
447fn analyze_join_constraint(
448    op: &JoinOperator,
449) -> Result<
450    (
451        String,
452        String,
453        Vec<(String, String)>,
454        Option<Duration>,
455        Option<RawTimeCols>,
456    ),
457    ParseError,
458> {
459    let constraint = get_join_constraint(op)?;
460
461    match constraint {
462        JoinConstraint::On(expr) => {
463            let (key_pairs, time_bound, time_cols) = analyze_on_expression(expr)?;
464            if key_pairs.is_empty() {
465                return Ok((String::new(), String::new(), vec![], time_bound, time_cols));
466            }
467            let (first_left, first_right) = key_pairs[0].clone();
468            let additional = key_pairs[1..].to_vec();
469            Ok((first_left, first_right, additional, time_bound, time_cols))
470        }
471        JoinConstraint::Using(cols) => {
472            if cols.is_empty() {
473                return Err(ParseError::StreamingError(
474                    "USING clause requires at least one column".to_string(),
475                ));
476            }
477            // First column is the primary key pair
478            let first_col = cols[0].to_string();
479            // Remaining columns are additional key pairs
480            let additional: Vec<(String, String)> = cols[1..]
481                .iter()
482                .map(|c| {
483                    let col = c.to_string();
484                    (col.clone(), col)
485                })
486                .collect();
487            Ok((first_col.clone(), first_col, additional, None, None))
488        }
489        JoinConstraint::Natural => Err(ParseError::StreamingError(
490            "NATURAL JOIN not supported for streaming".to_string(),
491        )),
492        JoinConstraint::None => Err(ParseError::StreamingError(
493            "JOIN without condition not supported for streaming".to_string(),
494        )),
495    }
496}
497
498/// Get the JoinConstraint from a JoinOperator.
499fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
500    match op {
501        JoinOperator::Inner(constraint)
502        | JoinOperator::Join(constraint)
503        | JoinOperator::Left(constraint)
504        | JoinOperator::LeftOuter(constraint)
505        | JoinOperator::Right(constraint)
506        | JoinOperator::RightOuter(constraint)
507        | JoinOperator::FullOuter(constraint)
508        | JoinOperator::LeftSemi(constraint)
509        | JoinOperator::RightSemi(constraint)
510        | JoinOperator::LeftAnti(constraint)
511        | JoinOperator::RightAnti(constraint)
512        | JoinOperator::Semi(constraint)
513        | JoinOperator::Anti(constraint)
514        | JoinOperator::StraightJoin(constraint)
515        | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
516        JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
517            ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
518        ),
519    }
520}
521
522/// Analyze ON expression to extract all key column pairs, time bound,
523/// and optional time column pair for stream-stream joins.
524#[allow(clippy::type_complexity)]
525fn analyze_on_expression(
526    expr: &Expr,
527) -> Result<(Vec<(String, String)>, Option<Duration>, Option<RawTimeCols>), ParseError> {
528    // Handle compound expressions (AND)
529    match expr {
530        Expr::BinaryOp {
531            left,
532            op: BinaryOperator::And,
533            right,
534        } => {
535            // Recursively analyze both sides
536            let left_result = analyze_on_expression(left);
537            let right_result = analyze_on_expression(right);
538
539            // Combine results - collect all key pairs and time bounds
540            match (left_result, right_result) {
541                (Ok((mut lk, lt, ltc)), Ok((rk, rt, rtc))) => {
542                    lk.extend(rk);
543                    Ok((lk, lt.or(rt), ltc.or(rtc)))
544                }
545                (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
546                (Err(e), Err(_)) => Err(e),
547            }
548        }
549        // Equality condition: a.col = b.col
550        Expr::BinaryOp {
551            left,
552            op: BinaryOperator::Eq,
553            right,
554        } => {
555            let left_col = extract_column_ref(left);
556            let right_col = extract_column_ref(right);
557
558            match (left_col, right_col) {
559                (Some(l), Some(r)) => Ok((vec![(l, r)], None, None)),
560                _ => Err(ParseError::StreamingError(
561                    "Cannot extract column references from equality condition".to_string(),
562                )),
563            }
564        }
565        // BETWEEN clause for time bound: p.ts BETWEEN o.ts AND o.ts + INTERVAL
566        Expr::Between {
567            expr: between_expr,
568            low,
569            high,
570            ..
571        } => {
572            // Try to extract time bound from high expression
573            let time_bound = extract_time_bound_from_expr(high).ok();
574            let between_col = extract_qualified_column_ref(between_expr);
575            let low_col = extract_qualified_column_ref(low);
576            let time_cols = if let (Some((bt, bc)), Some((lt, lc))) = (between_col, low_col) {
577                Some(RawTimeCols {
578                    expr_qualifier: bt,
579                    expr_col: bc,
580                    low_qualifier: lt,
581                    low_col: lc,
582                })
583            } else {
584                if time_bound.is_some() {
585                    tracing::warn!(
586                        "BETWEEN clause has time bound but time column references \
587                         could not be extracted (expressions must be simple column refs)"
588                    );
589                }
590                None
591            };
592            Ok((vec![], time_bound, time_cols))
593        }
594        // Comparison operators for time bounds
595        Expr::BinaryOp {
596            left: _,
597            op:
598                BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
599            right,
600        } => {
601            // Try to extract time bound from right side
602            let time_bound = extract_time_bound_from_expr(right).ok();
603            Ok((vec![], time_bound, None))
604        }
605        _ => Err(ParseError::StreamingError(format!(
606            "Unsupported join condition expression: {expr:?}"
607        ))),
608    }
609}
610
611/// Extract column reference from a function argument.
612fn extract_column_from_func_arg(arg: &FunctionArg) -> Option<String> {
613    let (FunctionArg::Unnamed(FunctionArgExpr::Expr(expr))
614    | FunctionArg::Named {
615        arg: FunctionArgExpr::Expr(expr),
616        ..
617    }
618    | FunctionArg::ExprNamed {
619        arg: FunctionArgExpr::Expr(expr),
620        ..
621    }) = arg
622    else {
623        return None;
624    };
625    extract_column_ref(expr)
626}
627
628/// Extract column reference from expression (e.g., "a.id" -> "id")
629fn extract_column_ref(expr: &Expr) -> Option<String> {
630    match expr {
631        Expr::Identifier(ident) => Some(ident.value.clone()),
632        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
633        _ => None,
634    }
635}
636
637fn extract_qualified_column_ref(expr: &Expr) -> Option<(Option<String>, String)> {
638    match expr {
639        Expr::Identifier(ident) => Some((None, ident.value.clone())),
640        Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
641            Some((Some(parts[0].value.clone()), parts[1].value.clone()))
642        }
643        Expr::CompoundIdentifier(parts) => parts.last().map(|p| (None, p.value.clone())),
644        _ => None,
645    }
646}
647
648/// Extract time bound from an expression like "o.ts + INTERVAL '1' HOUR"
649fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
650    match expr {
651        // Direct interval
652        Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
653        // Addition or subtraction: col +/- INTERVAL
654        Expr::BinaryOp {
655            left: _,
656            op: BinaryOperator::Plus | BinaryOperator::Minus,
657            right,
658        } => extract_time_bound_from_expr(right),
659        // Nested expression
660        Expr::Nested(inner) => extract_time_bound_from_expr(inner),
661        _ => Err(ParseError::StreamingError(format!(
662            "Cannot extract time bound from: {expr:?}"
663        ))),
664    }
665}
666
667/// Analyze ASOF JOIN MATCH_CONDITION expression.
668///
669/// Extracts direction, time column names, and optional tolerance.
670fn analyze_asof_match_condition(
671    expr: &Expr,
672) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
673    if let Expr::BinaryOp {
674        left,
675        op: BinaryOperator::And,
676        right,
677    } = expr
678    {
679        // Try to get direction from left, tolerance from right
680        let dir_result = analyze_asof_direction(left);
681        let tol_result = extract_asof_tolerance(right);
682
683        match (dir_result, tol_result) {
684            (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
685            (Ok((dir, lt, rt)), Err(_)) => {
686                // Maybe tolerance is on left and direction on right
687                let dir2 = analyze_asof_direction(right);
688                let tol2 = extract_asof_tolerance(left);
689                match (dir2, tol2) {
690                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
691                    _ => Ok((dir, lt, rt, None)),
692                }
693            }
694            (Err(_), _) => {
695                // Try reversed
696                let dir2 = analyze_asof_direction(right);
697                let tol2 = extract_asof_tolerance(left);
698                match (dir2, tol2) {
699                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
700                    (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
701                    _ => Err(ParseError::StreamingError(
702                        "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
703                    )),
704                }
705            }
706        }
707    } else {
708        let (dir, lt, rt) = analyze_asof_direction(expr)?;
709        Ok((dir, lt, rt, None))
710    }
711}
712
713/// Extract ASOF direction and time columns from a comparison expression.
714fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
715    match expr {
716        Expr::BinaryOp {
717            left,
718            op: BinaryOperator::GtEq,
719            right,
720        } => {
721            let left_col = extract_column_ref(left).ok_or_else(|| {
722                ParseError::StreamingError(
723                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
724                )
725            })?;
726            let right_col = extract_column_ref(right).ok_or_else(|| {
727                ParseError::StreamingError(
728                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
729                )
730            })?;
731            Ok((AsofSqlDirection::Backward, left_col, right_col))
732        }
733        Expr::BinaryOp {
734            left,
735            op: BinaryOperator::LtEq,
736            right,
737        } => {
738            let left_col = extract_column_ref(left).ok_or_else(|| {
739                ParseError::StreamingError(
740                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
741                )
742            })?;
743            let right_col = extract_column_ref(right).ok_or_else(|| {
744                ParseError::StreamingError(
745                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
746                )
747            })?;
748            Ok((AsofSqlDirection::Forward, left_col, right_col))
749        }
750        // NEAREST(left_col, right_col) — function-style syntax
751        Expr::Function(func) => {
752            let name = func.name.to_string().to_uppercase();
753            if name != "NEAREST" {
754                return Err(ParseError::StreamingError(format!(
755                    "Unknown ASOF MATCH_CONDITION function: {name}"
756                )));
757            }
758            let args = match &func.args {
759                FunctionArguments::List(arg_list) => &arg_list.args,
760                _ => {
761                    return Err(ParseError::StreamingError(
762                        "NEAREST() requires exactly 2 column arguments".to_string(),
763                    ))
764                }
765            };
766            if args.len() != 2 {
767                return Err(ParseError::StreamingError(format!(
768                    "NEAREST() requires exactly 2 arguments, got {}",
769                    args.len()
770                )));
771            }
772            let left_col = extract_column_from_func_arg(&args[0]).ok_or_else(|| {
773                ParseError::StreamingError(
774                    "Cannot extract left time column from NEAREST()".to_string(),
775                )
776            })?;
777            let right_col = extract_column_from_func_arg(&args[1]).ok_or_else(|| {
778                ParseError::StreamingError(
779                    "Cannot extract right time column from NEAREST()".to_string(),
780                )
781            })?;
782            Ok((AsofSqlDirection::Nearest, left_col, right_col))
783        }
784        _ => Err(ParseError::StreamingError(
785            "ASOF MATCH_CONDITION must be >= or <= comparison, or NEAREST()".to_string(),
786        )),
787    }
788}
789
790/// Extract tolerance duration from an ASOF tolerance expression.
791///
792/// Handles: `left - right <= value` or `left - right <= INTERVAL '...'`
793fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
794    match expr {
795        Expr::BinaryOp {
796            left: _,
797            op: BinaryOperator::LtEq,
798            right,
799        } => {
800            // right side is either a literal number or INTERVAL
801            match right.as_ref() {
802                Expr::Value(v) => {
803                    if let sqlparser::ast::Value::Number(n, _) = &v.value {
804                        let ms: u64 = n.parse().map_err(|_| {
805                            ParseError::StreamingError(format!(
806                                "Cannot parse tolerance as number: {n}"
807                            ))
808                        })?;
809                        Ok(Duration::from_millis(ms))
810                    } else {
811                        Err(ParseError::StreamingError(
812                            "ASOF tolerance must be a number or INTERVAL".to_string(),
813                        ))
814                    }
815                }
816                Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
817                _ => Err(ParseError::StreamingError(
818                    "ASOF tolerance must be a number or INTERVAL".to_string(),
819                )),
820            }
821        }
822        _ => Err(ParseError::StreamingError(
823            "ASOF tolerance expression must be <= comparison".to_string(),
824        )),
825    }
826}
827
828/// Extract key columns from an ASOF JOIN constraint (ON clause).
829fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
830    match constraint {
831        JoinConstraint::On(expr) => extract_equality_columns(expr),
832        JoinConstraint::Using(cols) => {
833            if cols.is_empty() {
834                return Err(ParseError::StreamingError(
835                    "USING clause requires at least one column".to_string(),
836                ));
837            }
838            let col = cols[0].to_string();
839            Ok((col.clone(), col))
840        }
841        _ => Err(ParseError::StreamingError(
842            "ASOF JOIN requires ON or USING constraint".to_string(),
843        )),
844    }
845}
846
847/// Extract left and right column names from an equality expression.
848fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
849    match expr {
850        Expr::BinaryOp {
851            left,
852            op: BinaryOperator::Eq,
853            right,
854        } => {
855            let left_col = extract_column_ref(left).ok_or_else(|| {
856                ParseError::StreamingError("Cannot extract left key column".to_string())
857            })?;
858            let right_col = extract_column_ref(right).ok_or_else(|| {
859                ParseError::StreamingError("Cannot extract right key column".to_string())
860            })?;
861            Ok((left_col, right_col))
862        }
863        // If there's an AND, find the equality part
864        Expr::BinaryOp {
865            left,
866            op: BinaryOperator::And,
867            right,
868        } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
869        _ => Err(ParseError::StreamingError(
870            "ASOF JOIN ON clause must contain an equality condition".to_string(),
871        )),
872    }
873}
874
875/// Check if a SELECT contains a join.
876#[must_use]
877pub fn has_join(select: &Select) -> bool {
878    !select.from.is_empty() && !select.from[0].joins.is_empty()
879}
880
881/// Count the number of joins in a SELECT.
882#[must_use]
883pub fn count_joins(select: &Select) -> usize {
884    select
885        .from
886        .iter()
887        .map(|table_with_joins| table_with_joins.joins.len())
888        .sum()
889}
890
/// Analysis result for multi-way JOINs (e.g., `A JOIN B ... JOIN C ...`).
///
/// Each step represents one join in a left-deep chain. As built by
/// `analyze_joins`, step 0 has the base table as its left table, and every
/// later step records the previous step's right table as its left table.
#[derive(Debug, Clone)]
pub struct MultiJoinAnalysis {
    /// Ordered join steps (left-to-right), one per join in the chain.
    pub joins: Vec<JoinAnalysis>,
    /// All referenced tables in order (base table first, then each right table).
    /// When produced by `analyze_joins` this holds one more entry than `joins`.
    pub tables: Vec<String>,
}
902
903impl MultiJoinAnalysis {
904    /// Number of join steps.
905    #[must_use]
906    pub fn len(&self) -> usize {
907        self.joins.len()
908    }
909
910    /// Whether there are no join steps.
911    #[must_use]
912    pub fn is_empty(&self) -> bool {
913        self.joins.is_empty()
914    }
915
916    /// Whether this is a single join (backward-compatible case).
917    #[must_use]
918    pub fn is_single(&self) -> bool {
919        self.joins.len() == 1
920    }
921
922    /// The first join step (convenience for single-join queries).
923    #[must_use]
924    pub fn first(&self) -> Option<&JoinAnalysis> {
925        self.joins.first()
926    }
927}
928
929/// Analyze a SELECT statement for all join steps (multi-way).
930///
931/// Returns `None` if the query has no joins. For a single join this
932/// returns a `MultiJoinAnalysis` with one step, making it backward
933/// compatible with `analyze_join()`.
934///
935/// # Errors
936///
937/// Returns `ParseError::StreamingError` if any join constraint is
938/// not supported or key columns cannot be extracted.
939pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
940    let from = &select.from;
941    if from.is_empty() {
942        return Ok(None);
943    }
944
945    let first_table = &from[0];
946    if first_table.joins.is_empty() {
947        return Ok(None);
948    }
949
950    // Extract base table
951    let base_table = extract_table_name(&first_table.relation)?;
952    let base_alias = extract_table_alias(&first_table.relation);
953
954    let mut join_steps = Vec::with_capacity(first_table.joins.len());
955    let mut tables = vec![base_table.clone()];
956
957    // Track the left table name for left-deep chaining
958    let mut prev_left_table = base_table;
959    let mut prev_left_alias = base_alias;
960
961    for join in &first_table.joins {
962        let right_table = extract_table_name(&join.relation)?;
963        let right_alias = extract_table_alias(&join.relation);
964        tables.push(right_table.clone());
965
966        let join_type = map_join_operator(&join.join_operator);
967
968        // Handle ASOF JOIN
969        if let JoinOperator::AsOf {
970            match_condition,
971            constraint,
972        } = &join.join_operator
973        {
974            let (direction, left_time, right_time, tolerance) =
975                analyze_asof_match_condition(match_condition)?;
976            let (left_key, right_key) = analyze_asof_constraint(constraint)?;
977
978            let mut analysis = JoinAnalysis::asof(
979                prev_left_table.clone(),
980                right_table.clone(),
981                left_key,
982                right_key,
983                direction,
984                left_time,
985                right_time,
986                tolerance,
987            );
988            analysis.left_alias.clone_from(&prev_left_alias);
989            analysis.right_alias = right_alias;
990            join_steps.push(analysis);
991        } else if let Some(version_col) = extract_temporal_version(&join.relation) {
992            // Temporal join: right side has FOR SYSTEM_TIME AS OF
993            let (left_key, right_key, additional, _, _) =
994                analyze_join_constraint(&join.join_operator)?;
995
996            let mut analysis = JoinAnalysis::temporal(
997                prev_left_table.clone(),
998                right_table.clone(),
999                left_key,
1000                right_key,
1001                version_col,
1002                join_type,
1003            );
1004            analysis.left_alias.clone_from(&prev_left_alias);
1005            analysis.right_alias = right_alias;
1006            analysis.additional_key_columns = additional;
1007            join_steps.push(analysis);
1008        } else {
1009            // Regular join (inner, left, right, full)
1010            let (left_key, right_key, additional, time_bound, time_cols) =
1011                analyze_join_constraint(&join.join_operator)?;
1012
1013            let mut analysis = if let Some(tb) = time_bound {
1014                JoinAnalysis::stream_stream(
1015                    prev_left_table.clone(),
1016                    right_table.clone(),
1017                    left_key,
1018                    right_key,
1019                    tb,
1020                    join_type,
1021                )
1022            } else {
1023                JoinAnalysis::lookup(
1024                    prev_left_table.clone(),
1025                    right_table.clone(),
1026                    left_key,
1027                    right_key,
1028                    join_type,
1029                )
1030            };
1031            analysis.left_alias.clone_from(&prev_left_alias);
1032            analysis.right_alias.clone_from(&right_alias);
1033            analysis.additional_key_columns = additional;
1034
1035            if let Some(ref raw) = time_cols {
1036                let (lt, rt) = resolve_time_cols(
1037                    raw,
1038                    &analysis.left_table,
1039                    &analysis.right_table,
1040                    prev_left_alias.as_deref(),
1041                    right_alias.as_deref(),
1042                );
1043                analysis.left_time_column = Some(lt);
1044                analysis.right_time_column = Some(rt);
1045            }
1046            join_steps.push(analysis);
1047        }
1048
1049        // Next step's left table is this step's right table (left-deep)
1050        prev_left_table = right_table;
1051        prev_left_alias = extract_table_alias(&join.relation);
1052    }
1053
1054    Ok(Some(MultiJoinAnalysis {
1055        joins: join_steps,
1056        tables,
1057    }))
1058}
1059
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::ast::{SetExpr, Statement};
    use sqlparser::dialect::{Dialect, GenericDialect};
    use sqlparser::parser::Parser;

    /// Parse `sql` with `dialect` and return the SELECT body of the first statement.
    ///
    /// Panics if parsing fails, if no statement was produced, or if the first
    /// statement is not a plain SELECT query.
    fn parse_select_with(dialect: &dyn Dialect, sql: &str) -> Select {
        let statements = Parser::parse_sql(dialect, sql).unwrap();
        if let Statement::Query(query) = &statements[0] {
            if let SetExpr::Select(select) = query.body.as_ref() {
                return (**select).clone();
            }
        }
        panic!("Expected SELECT query");
    }

    /// Parse with the generic dialect.
    fn parse_select(sql: &str) -> Select {
        parse_select_with(&GenericDialect {}, sql)
    }

    #[test]
    fn test_analyze_inner_join() {
        let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Inner);
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "payments");
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
        assert!(analysis.is_lookup_join); // No time bound = lookup join
    }

    #[test]
    fn test_analyze_left_join() {
        let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Left);
        assert_eq!(analysis.left_key_column, "customer_id");
        assert_eq!(analysis.right_key_column, "id");
    }

    #[test]
    fn test_analyze_join_using() {
        let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
    }

    #[test]
    fn test_analyze_stream_stream_join_with_time_bound() {
        let sql = "SELECT * FROM orders o
                   JOIN payments p ON o.order_id = p.order_id
                   AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_lookup_join);
        assert!(analysis.time_bound.is_some());
        assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
    }

    #[test]
    fn test_no_join() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap();
        assert!(analysis.is_none());
    }

    #[test]
    fn test_has_join() {
        let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let sql_without_join = "SELECT * FROM orders";

        let select_with = parse_select(sql_with_join);
        let select_without = parse_select(sql_without_join);

        assert!(has_join(&select_with));
        assert!(!has_join(&select_without));
    }

    #[test]
    fn test_count_joins() {
        let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
        let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let sql_zero = "SELECT * FROM a";

        assert_eq!(count_joins(&parse_select(sql_one)), 1);
        assert_eq!(count_joins(&parse_select(sql_two)), 2);
        assert_eq!(count_joins(&parse_select(sql_zero)), 0);
    }

    #[test]
    fn test_aliases() {
        let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("o".to_string()));
        assert_eq!(analysis.right_alias, Some("p".to_string()));
    }

    // -- ASOF JOIN tests --

    /// Parse with the Snowflake dialect (used for `ASOF JOIN ... MATCH_CONDITION`).
    fn parse_select_snowflake(sql: &str) -> Select {
        parse_select_with(&sqlparser::dialect::SnowflakeDialect {}, sql)
    }

    /// Parse with the project's own dialect (used for `FOR SYSTEM_TIME AS OF`).
    fn parse_select_laminar(sql: &str) -> Select {
        parse_select_with(&crate::parser::dialect::LaminarDialect::default(), sql)
    }

    #[test]
    fn test_asof_join_backward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_forward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts <= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
    }

    #[test]
    fn test_asof_join_nearest() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(NEAREST(t.ts, q.ts)) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Nearest));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_with_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
    }

    #[test]
    fn test_asof_join_with_interval_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_asof_join_type_mapping() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(!analysis.is_lookup_join);
    }

    #[test]
    fn test_asof_join_extracts_time_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_time_column, Some("ts".to_string()));
        assert_eq!(analysis.right_time_column, Some("ts".to_string()));
    }

    #[test]
    fn test_asof_join_extracts_key_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "symbol");
        assert_eq!(analysis.right_key_column, "symbol");
    }

    #[test]
    fn test_asof_join_aliases() {
        let sql = "SELECT * FROM trades AS t \
                    ASOF JOIN quotes AS q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("t".to_string()));
        assert_eq!(analysis.right_alias, Some("q".to_string()));
        assert_eq!(analysis.left_table, "trades");
        assert_eq!(analysis.right_table, "quotes");
    }

    // -- Multi-way JOIN tests --

    #[test]
    fn test_multi_join_single_backward_compat() {
        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert!(multi.is_single());
        assert_eq!(multi.len(), 1);
        assert!(!multi.is_empty());
        let first = multi.first().unwrap();
        assert_eq!(first.left_table, "orders");
        assert_eq!(first.right_table, "payments");
    }

    #[test]
    fn test_multi_join_two_way() {
        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(!multi.is_single());

        assert_eq!(multi.joins[0].left_table, "a");
        assert_eq!(multi.joins[0].right_table, "b");
        assert_eq!(multi.joins[0].left_key_column, "id");
        assert_eq!(multi.joins[0].right_key_column, "a_id");

        assert_eq!(multi.joins[1].left_table, "b");
        assert_eq!(multi.joins[1].right_table, "c");
        assert_eq!(multi.joins[1].left_key_column, "id");
        assert_eq!(multi.joins[1].right_key_column, "b_id");
    }

    #[test]
    fn test_multi_join_three_way() {
        let sql = "SELECT * FROM a \
                    JOIN b ON a.id = b.a_id \
                    JOIN c ON b.id = c.b_id \
                    JOIN d ON c.id = d.c_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 3);
        assert_eq!(multi.tables.len(), 4);
        assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
    }

    #[test]
    fn test_multi_join_mixed_asof_and_lookup() {
        // ASOF first, then lookup (use Snowflake dialect for ASOF)
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol \
                    JOIN products p ON q.product_id = p.id";
        let select = parse_select_snowflake(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(multi.joins[0].is_asof_join);
        assert!(multi.joins[1].is_lookup_join);
    }

    #[test]
    fn test_multi_join_stream_stream_and_lookup() {
        let sql = "SELECT * FROM orders o \
                    JOIN payments p ON o.id = p.order_id \
                        AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
                    JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 2);
        assert!(!multi.joins[0].is_lookup_join); // stream-stream
        assert!(multi.joins[0].time_bound.is_some());
        assert!(multi.joins[1].is_lookup_join); // lookup
    }

    #[test]
    fn test_multi_join_tables_list() {
        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.tables, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_multi_join_aliases() {
        let sql = "SELECT * FROM orders AS o \
                    JOIN payments AS p ON o.id = p.order_id \
                    JOIN refunds AS r ON p.id = r.payment_id";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
        assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
        assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
    }

    #[test]
    fn test_multi_join_no_join_returns_none() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);
        let multi = analyze_joins(&select).unwrap();
        assert!(multi.is_none());
    }

    // -- Temporal JOIN tests (FOR SYSTEM_TIME AS OF) --

    #[test]
    fn test_temporal_join_detected() {
        let sql = "SELECT o.*, p.price \
                    FROM orders o \
                    JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                    ON o.product_id = p.id";
        let select = parse_select_laminar(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_temporal_join);
        assert_eq!(
            analysis.temporal_version_column,
            Some("order_time".to_string())
        );
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "products");
        assert_eq!(analysis.left_key_column, "product_id");
        assert_eq!(analysis.right_key_column, "id");
        assert!(!analysis.is_lookup_join);
        assert!(!analysis.is_asof_join);
    }

    #[test]
    fn test_temporal_join_via_analyze_joins() {
        let sql = "SELECT o.*, p.price \
                    FROM orders o \
                    JOIN products FOR SYSTEM_TIME AS OF o.order_time AS p \
                    ON o.product_id = p.id";
        let select = parse_select_laminar(sql);
        let multi = analyze_joins(&select).unwrap().unwrap();

        assert_eq!(multi.len(), 1);
        let first = multi.first().unwrap();
        assert!(first.is_temporal_join);
        assert_eq!(
            first.temporal_version_column,
            Some("order_time".to_string())
        );
    }

    #[test]
    fn test_non_temporal_join_not_flagged() {
        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_temporal_join);
        assert!(analysis.temporal_version_column.is_none());
    }

    #[test]
    fn test_unqualified_anti_maps_to_left_anti() {
        let sql = "SELECT * FROM orders o ANTI JOIN returns r ON o.id = r.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();
        assert_eq!(analysis.join_type, JoinType::LeftAnti);
    }

    #[test]
    fn test_unqualified_semi_maps_to_left_semi() {
        let sql = "SELECT * FROM orders o SEMI JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();
        assert_eq!(analysis.join_type, JoinType::LeftSemi);
    }

    #[test]
    fn test_composite_join_keys() {
        let sql = "SELECT * FROM orders o \
                    JOIN shipments s \
                    ON o.order_id = s.order_id AND o.region = s.region";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        // First key pair is the primary key
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");

        // Second key pair should be in additional_key_columns
        assert_eq!(
            analysis.additional_key_columns.len(),
            1,
            "Should have 1 additional key pair"
        );
        assert_eq!(analysis.additional_key_columns[0].0, "region");
        assert_eq!(analysis.additional_key_columns[0].1, "region");
    }

    #[test]
    fn test_composite_using_clause() {
        let sql = "SELECT * FROM orders o JOIN shipments s USING (order_id, region)";
        let select = parse_select(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        // First column becomes primary key
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");

        // Additional columns
        assert_eq!(
            analysis.additional_key_columns.len(),
            1,
            "USING(order_id, region) should have 1 additional key"
        );
        assert_eq!(analysis.additional_key_columns[0].0, "region");
        assert_eq!(analysis.additional_key_columns[0].1, "region");
    }
}