Skip to main content

laminar_sql/parser/
join_parser.rs

1//! Join query analysis and extraction
2//!
3//! This module analyzes JOIN clauses to extract:
4//! - Join type (INNER, LEFT, RIGHT, FULL, ASOF)
5//! - Key columns for join condition
6//! - Time bounds for stream-stream joins
7//! - Detection of lookup joins vs stream-stream joins
8
9use std::time::Duration;
10
11use sqlparser::ast::{BinaryOperator, Expr, JoinConstraint, JoinOperator, Select, TableFactor};
12
13use super::window_rewriter::WindowRewriter;
14use super::ParseError;
15
/// Classification of the join kinds this analyzer distinguishes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT [OUTER] JOIN`.
    Left,
    /// `RIGHT [OUTER] JOIN`.
    Right,
    /// `FULL [OUTER] JOIN`.
    Full,
    /// `ASOF JOIN`.
    AsOf,
}
30
/// Which way an ASOF JOIN looks in time when picking the matching right row.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// `left.ts >= right.ts` — match the most recent right row.
    Backward,
    /// `left.ts <= right.ts` — match the next right row.
    Forward,
}

impl std::fmt::Display for AsofSqlDirection {
    /// Writes the upper-case keyword for the direction (`BACKWARD` / `FORWARD`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keyword = match self {
            Self::Backward => "BACKWARD",
            Self::Forward => "FORWARD",
        };
        f.write_str(keyword)
    }
}
48
49/// Analysis result for a JOIN clause
50#[derive(Debug, Clone)]
51pub struct JoinAnalysis {
52    /// Type of join (inner, left, right, full)
53    pub join_type: JoinType,
54    /// Left side table name
55    pub left_table: String,
56    /// Right side table name
57    pub right_table: String,
58    /// Left side key column
59    pub left_key_column: String,
60    /// Right side key column
61    pub right_key_column: String,
62    /// Time bound for stream-stream joins (None for lookup joins)
63    pub time_bound: Option<Duration>,
64    /// Whether this is a lookup join (no time bound)
65    pub is_lookup_join: bool,
66    /// Left side alias (if any)
67    pub left_alias: Option<String>,
68    /// Right side alias (if any)
69    pub right_alias: Option<String>,
70    /// Whether this is an ASOF join
71    pub is_asof_join: bool,
72    /// ASOF join direction (Backward or Forward)
73    pub asof_direction: Option<AsofSqlDirection>,
74    /// Left side time column for ASOF join
75    pub left_time_column: Option<String>,
76    /// Right side time column for ASOF join
77    pub right_time_column: Option<String>,
78    /// ASOF join tolerance (max time difference)
79    pub asof_tolerance: Option<Duration>,
80}
81
82impl JoinAnalysis {
83    /// Create a stream-stream join analysis
84    #[must_use]
85    pub fn stream_stream(
86        left_table: String,
87        right_table: String,
88        left_key: String,
89        right_key: String,
90        time_bound: Duration,
91        join_type: JoinType,
92    ) -> Self {
93        Self {
94            join_type,
95            left_table,
96            right_table,
97            left_key_column: left_key,
98            right_key_column: right_key,
99            time_bound: Some(time_bound),
100            is_lookup_join: false,
101            left_alias: None,
102            right_alias: None,
103            is_asof_join: false,
104            asof_direction: None,
105            left_time_column: None,
106            right_time_column: None,
107            asof_tolerance: None,
108        }
109    }
110
111    /// Create a lookup join analysis
112    #[must_use]
113    pub fn lookup(
114        left_table: String,
115        right_table: String,
116        left_key: String,
117        right_key: String,
118        join_type: JoinType,
119    ) -> Self {
120        Self {
121            join_type,
122            left_table,
123            right_table,
124            left_key_column: left_key,
125            right_key_column: right_key,
126            time_bound: None,
127            is_lookup_join: true,
128            left_alias: None,
129            right_alias: None,
130            is_asof_join: false,
131            asof_direction: None,
132            left_time_column: None,
133            right_time_column: None,
134            asof_tolerance: None,
135        }
136    }
137
138    /// Create an ASOF join analysis
139    #[must_use]
140    #[allow(clippy::too_many_arguments)]
141    pub fn asof(
142        left_table: String,
143        right_table: String,
144        left_key: String,
145        right_key: String,
146        direction: AsofSqlDirection,
147        left_time_col: String,
148        right_time_col: String,
149        tolerance: Option<Duration>,
150    ) -> Self {
151        Self {
152            join_type: JoinType::AsOf,
153            left_table,
154            right_table,
155            left_key_column: left_key,
156            right_key_column: right_key,
157            time_bound: None,
158            is_lookup_join: false,
159            left_alias: None,
160            right_alias: None,
161            is_asof_join: true,
162            asof_direction: Some(direction),
163            left_time_column: Some(left_time_col),
164            right_time_column: Some(right_time_col),
165            asof_tolerance: tolerance,
166        }
167    }
168}
169
170/// Analyze a SELECT statement for join information.
171///
172/// # Errors
173///
174/// Returns `ParseError::StreamingError` if:
175/// - Join constraint is not supported
176/// - Cannot extract key columns
177pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
178    let from = &select.from;
179    if from.is_empty() {
180        return Ok(None);
181    }
182
183    let first_table = &from[0];
184    if first_table.joins.is_empty() {
185        return Ok(None);
186    }
187
188    // Extract left table information
189    let left_table = extract_table_name(&first_table.relation)?;
190    let left_alias = extract_table_alias(&first_table.relation);
191
192    // Analyze the first join
193    let join = &first_table.joins[0];
194    let right_table = extract_table_name(&join.relation)?;
195    let right_alias = extract_table_alias(&join.relation);
196
197    let join_type = map_join_operator(&join.join_operator);
198
199    // Handle ASOF JOIN specially
200    if let JoinOperator::AsOf {
201        match_condition,
202        constraint,
203    } = &join.join_operator
204    {
205        let (direction, left_time, right_time, tolerance) =
206            analyze_asof_match_condition(match_condition)?;
207
208        // Extract key columns from the ON constraint
209        let (left_key, right_key) = analyze_asof_constraint(constraint)?;
210
211        let mut analysis = JoinAnalysis::asof(
212            left_table,
213            right_table,
214            left_key,
215            right_key,
216            direction,
217            left_time,
218            right_time,
219            tolerance,
220        );
221        analysis.left_alias = left_alias;
222        analysis.right_alias = right_alias;
223        return Ok(Some(analysis));
224    }
225
226    // Analyze the join constraint
227    let (left_key, right_key, time_bound) = analyze_join_constraint(&join.join_operator)?;
228
229    let mut analysis = if let Some(tb) = time_bound {
230        JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
231    } else {
232        JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
233    };
234
235    analysis.left_alias = left_alias;
236    analysis.right_alias = right_alias;
237
238    Ok(Some(analysis))
239}
240
241/// Extract table name from a TableFactor.
242fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
243    match factor {
244        TableFactor::Table { name, .. } => Ok(name.to_string()),
245        TableFactor::Derived { alias, .. } => {
246            if let Some(alias) = alias {
247                Ok(alias.name.value.clone())
248            } else {
249                Err(ParseError::StreamingError(
250                    "Derived table without alias not supported".to_string(),
251                ))
252            }
253        }
254        _ => Err(ParseError::StreamingError(
255            "Unsupported table factor type".to_string(),
256        )),
257    }
258}
259
260/// Extract table alias from a TableFactor.
261fn extract_table_alias(factor: &TableFactor) -> Option<String> {
262    match factor {
263        TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
264        TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
265        _ => None,
266    }
267}
268
269/// Map sqlparser JoinOperator to our JoinType.
270fn map_join_operator(op: &JoinOperator) -> JoinType {
271    match op {
272        JoinOperator::Inner(_)
273        | JoinOperator::Join(_)
274        | JoinOperator::CrossJoin(_)
275        | JoinOperator::CrossApply
276        | JoinOperator::OuterApply
277        | JoinOperator::StraightJoin(_) => JoinType::Inner,
278        JoinOperator::Left(_)
279        | JoinOperator::LeftOuter(_)
280        | JoinOperator::LeftSemi(_)
281        | JoinOperator::LeftAnti(_)
282        | JoinOperator::Semi(_) => JoinType::Left,
283        JoinOperator::AsOf { .. } => JoinType::AsOf,
284        JoinOperator::Right(_)
285        | JoinOperator::RightOuter(_)
286        | JoinOperator::RightSemi(_)
287        | JoinOperator::RightAnti(_)
288        | JoinOperator::Anti(_) => JoinType::Right,
289        JoinOperator::FullOuter(_) => JoinType::Full,
290    }
291}
292
293/// Analyze join constraint to extract key columns and time bound.
294fn analyze_join_constraint(
295    op: &JoinOperator,
296) -> Result<(String, String, Option<Duration>), ParseError> {
297    let constraint = get_join_constraint(op)?;
298
299    match constraint {
300        JoinConstraint::On(expr) => analyze_on_expression(expr),
301        JoinConstraint::Using(cols) => {
302            if cols.is_empty() {
303                return Err(ParseError::StreamingError(
304                    "USING clause requires at least one column".to_string(),
305                ));
306            }
307            // For USING, both sides have the same column name
308            // Use to_string() on the Ident to get the column name
309            let col = cols[0].to_string();
310            Ok((col.clone(), col, None))
311        }
312        JoinConstraint::Natural => Err(ParseError::StreamingError(
313            "NATURAL JOIN not supported for streaming".to_string(),
314        )),
315        JoinConstraint::None => Err(ParseError::StreamingError(
316            "JOIN without condition not supported for streaming".to_string(),
317        )),
318    }
319}
320
321/// Get the JoinConstraint from a JoinOperator.
322fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
323    match op {
324        JoinOperator::Inner(constraint)
325        | JoinOperator::Join(constraint)
326        | JoinOperator::Left(constraint)
327        | JoinOperator::LeftOuter(constraint)
328        | JoinOperator::Right(constraint)
329        | JoinOperator::RightOuter(constraint)
330        | JoinOperator::FullOuter(constraint)
331        | JoinOperator::LeftSemi(constraint)
332        | JoinOperator::RightSemi(constraint)
333        | JoinOperator::LeftAnti(constraint)
334        | JoinOperator::RightAnti(constraint)
335        | JoinOperator::Semi(constraint)
336        | JoinOperator::Anti(constraint)
337        | JoinOperator::StraightJoin(constraint)
338        | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
339        JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
340            ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
341        ),
342    }
343}
344
345/// Analyze ON expression to extract key columns and time bound.
346fn analyze_on_expression(expr: &Expr) -> Result<(String, String, Option<Duration>), ParseError> {
347    // Handle compound expressions (AND)
348    match expr {
349        Expr::BinaryOp {
350            left,
351            op: BinaryOperator::And,
352            right,
353        } => {
354            // Recursively analyze both sides
355            let left_result = analyze_on_expression(left);
356            let right_result = analyze_on_expression(right);
357
358            // Combine results - one should have keys, the other might have time bound
359            match (left_result, right_result) {
360                (Ok((lk, rk, None)), Ok((_, _, time))) if !lk.is_empty() => Ok((lk, rk, time)),
361                (Ok((_, _, time)), Ok((lk, rk, None))) if !lk.is_empty() => Ok((lk, rk, time)),
362                (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
363                (Ok((lk, rk, t1)), Ok((_, _, t2))) => {
364                    // If both have keys, prefer the first
365                    Ok((lk, rk, t1.or(t2)))
366                }
367                (Err(e), Err(_)) => Err(e),
368            }
369        }
370        // Equality condition: a.col = b.col
371        Expr::BinaryOp {
372            left,
373            op: BinaryOperator::Eq,
374            right,
375        } => {
376            let left_col = extract_column_ref(left);
377            let right_col = extract_column_ref(right);
378
379            match (left_col, right_col) {
380                (Some(l), Some(r)) => Ok((l, r, None)),
381                _ => Err(ParseError::StreamingError(
382                    "Cannot extract column references from equality condition".to_string(),
383                )),
384            }
385        }
386        // BETWEEN clause for time bound: p.ts BETWEEN o.ts AND o.ts + INTERVAL
387        Expr::Between {
388            expr: _,
389            low: _,
390            high,
391            ..
392        } => {
393            // Try to extract time bound from high expression
394            let time_bound = extract_time_bound_from_expr(high).ok();
395            Ok((String::new(), String::new(), time_bound))
396        }
397        // Comparison operators for time bounds
398        Expr::BinaryOp {
399            left: _,
400            op:
401                BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
402            right,
403        } => {
404            // Try to extract time bound from right side
405            let time_bound = extract_time_bound_from_expr(right).ok();
406            Ok((String::new(), String::new(), time_bound))
407        }
408        _ => Err(ParseError::StreamingError(format!(
409            "Unsupported join condition expression: {expr:?}"
410        ))),
411    }
412}
413
414/// Extract column reference from expression (e.g., "a.id" -> "id")
415fn extract_column_ref(expr: &Expr) -> Option<String> {
416    match expr {
417        Expr::Identifier(ident) => Some(ident.value.clone()),
418        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
419        _ => None,
420    }
421}
422
423/// Extract time bound from an expression like "o.ts + INTERVAL '1' HOUR"
424fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
425    match expr {
426        // Direct interval
427        Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
428        // Addition or subtraction: col +/- INTERVAL
429        Expr::BinaryOp {
430            left: _,
431            op: BinaryOperator::Plus | BinaryOperator::Minus,
432            right,
433        } => extract_time_bound_from_expr(right),
434        // Nested expression
435        Expr::Nested(inner) => extract_time_bound_from_expr(inner),
436        _ => Err(ParseError::StreamingError(format!(
437            "Cannot extract time bound from: {expr:?}"
438        ))),
439    }
440}
441
442/// Analyze ASOF JOIN MATCH_CONDITION expression.
443///
444/// Extracts direction, time column names, and optional tolerance.
445fn analyze_asof_match_condition(
446    expr: &Expr,
447) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
448    if let Expr::BinaryOp {
449        left,
450        op: BinaryOperator::And,
451        right,
452    } = expr
453    {
454        // Try to get direction from left, tolerance from right
455        let dir_result = analyze_asof_direction(left);
456        let tol_result = extract_asof_tolerance(right);
457
458        match (dir_result, tol_result) {
459            (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
460            (Ok((dir, lt, rt)), Err(_)) => {
461                // Maybe tolerance is on left and direction on right
462                let dir2 = analyze_asof_direction(right);
463                let tol2 = extract_asof_tolerance(left);
464                match (dir2, tol2) {
465                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
466                    _ => Ok((dir, lt, rt, None)),
467                }
468            }
469            (Err(_), _) => {
470                // Try reversed
471                let dir2 = analyze_asof_direction(right);
472                let tol2 = extract_asof_tolerance(left);
473                match (dir2, tol2) {
474                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
475                    (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
476                    _ => Err(ParseError::StreamingError(
477                        "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
478                    )),
479                }
480            }
481        }
482    } else {
483        let (dir, lt, rt) = analyze_asof_direction(expr)?;
484        Ok((dir, lt, rt, None))
485    }
486}
487
488/// Extract ASOF direction and time columns from a comparison expression.
489fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
490    match expr {
491        Expr::BinaryOp {
492            left,
493            op: BinaryOperator::GtEq,
494            right,
495        } => {
496            let left_col = extract_column_ref(left).ok_or_else(|| {
497                ParseError::StreamingError(
498                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
499                )
500            })?;
501            let right_col = extract_column_ref(right).ok_or_else(|| {
502                ParseError::StreamingError(
503                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
504                )
505            })?;
506            Ok((AsofSqlDirection::Backward, left_col, right_col))
507        }
508        Expr::BinaryOp {
509            left,
510            op: BinaryOperator::LtEq,
511            right,
512        } => {
513            let left_col = extract_column_ref(left).ok_or_else(|| {
514                ParseError::StreamingError(
515                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
516                )
517            })?;
518            let right_col = extract_column_ref(right).ok_or_else(|| {
519                ParseError::StreamingError(
520                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
521                )
522            })?;
523            Ok((AsofSqlDirection::Forward, left_col, right_col))
524        }
525        _ => Err(ParseError::StreamingError(
526            "ASOF MATCH_CONDITION must be >= or <= comparison".to_string(),
527        )),
528    }
529}
530
531/// Extract tolerance duration from an ASOF tolerance expression.
532///
533/// Handles: `left - right <= value` or `left - right <= INTERVAL '...'`
534fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
535    match expr {
536        Expr::BinaryOp {
537            left: _,
538            op: BinaryOperator::LtEq,
539            right,
540        } => {
541            // right side is either a literal number or INTERVAL
542            match right.as_ref() {
543                Expr::Value(v) => {
544                    if let sqlparser::ast::Value::Number(n, _) = &v.value {
545                        let ms: u64 = n.parse().map_err(|_| {
546                            ParseError::StreamingError(format!(
547                                "Cannot parse tolerance as number: {n}"
548                            ))
549                        })?;
550                        Ok(Duration::from_millis(ms))
551                    } else {
552                        Err(ParseError::StreamingError(
553                            "ASOF tolerance must be a number or INTERVAL".to_string(),
554                        ))
555                    }
556                }
557                Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
558                _ => Err(ParseError::StreamingError(
559                    "ASOF tolerance must be a number or INTERVAL".to_string(),
560                )),
561            }
562        }
563        _ => Err(ParseError::StreamingError(
564            "ASOF tolerance expression must be <= comparison".to_string(),
565        )),
566    }
567}
568
569/// Extract key columns from an ASOF JOIN constraint (ON clause).
570fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
571    match constraint {
572        JoinConstraint::On(expr) => extract_equality_columns(expr),
573        JoinConstraint::Using(cols) => {
574            if cols.is_empty() {
575                return Err(ParseError::StreamingError(
576                    "USING clause requires at least one column".to_string(),
577                ));
578            }
579            let col = cols[0].to_string();
580            Ok((col.clone(), col))
581        }
582        _ => Err(ParseError::StreamingError(
583            "ASOF JOIN requires ON or USING constraint".to_string(),
584        )),
585    }
586}
587
588/// Extract left and right column names from an equality expression.
589fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
590    match expr {
591        Expr::BinaryOp {
592            left,
593            op: BinaryOperator::Eq,
594            right,
595        } => {
596            let left_col = extract_column_ref(left).ok_or_else(|| {
597                ParseError::StreamingError("Cannot extract left key column".to_string())
598            })?;
599            let right_col = extract_column_ref(right).ok_or_else(|| {
600                ParseError::StreamingError("Cannot extract right key column".to_string())
601            })?;
602            Ok((left_col, right_col))
603        }
604        // If there's an AND, find the equality part
605        Expr::BinaryOp {
606            left,
607            op: BinaryOperator::And,
608            right,
609        } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
610        _ => Err(ParseError::StreamingError(
611            "ASOF JOIN ON clause must contain an equality condition".to_string(),
612        )),
613    }
614}
615
616/// Check if a SELECT contains a join.
617#[must_use]
618pub fn has_join(select: &Select) -> bool {
619    !select.from.is_empty() && !select.from[0].joins.is_empty()
620}
621
622/// Count the number of joins in a SELECT.
623#[must_use]
624pub fn count_joins(select: &Select) -> usize {
625    select
626        .from
627        .iter()
628        .map(|table_with_joins| table_with_joins.joins.len())
629        .sum()
630}
631
632/// Analysis result for multi-way JOINs (e.g., `A JOIN B ... JOIN C ...`).
633///
634/// Each step represents one left-deep join: step 0 joins the base table with
635/// the first right table, step 1 joins the result with the next right table, etc.
636#[derive(Debug, Clone)]
637pub struct MultiJoinAnalysis {
638    /// Ordered join steps (left-to-right)
639    pub joins: Vec<JoinAnalysis>,
640    /// All referenced tables in order (base table first, then each right table)
641    pub tables: Vec<String>,
642}
643
644impl MultiJoinAnalysis {
645    /// Number of join steps.
646    #[must_use]
647    pub fn len(&self) -> usize {
648        self.joins.len()
649    }
650
651    /// Whether there are no join steps.
652    #[must_use]
653    pub fn is_empty(&self) -> bool {
654        self.joins.is_empty()
655    }
656
657    /// Whether this is a single join (backward-compatible case).
658    #[must_use]
659    pub fn is_single(&self) -> bool {
660        self.joins.len() == 1
661    }
662
663    /// The first join step (convenience for single-join queries).
664    #[must_use]
665    pub fn first(&self) -> Option<&JoinAnalysis> {
666        self.joins.first()
667    }
668}
669
670/// Analyze a SELECT statement for all join steps (multi-way).
671///
672/// Returns `None` if the query has no joins. For a single join this
673/// returns a `MultiJoinAnalysis` with one step, making it backward
674/// compatible with `analyze_join()`.
675///
676/// # Errors
677///
678/// Returns `ParseError::StreamingError` if any join constraint is
679/// not supported or key columns cannot be extracted.
680pub fn analyze_joins(select: &Select) -> Result<Option<MultiJoinAnalysis>, ParseError> {
681    let from = &select.from;
682    if from.is_empty() {
683        return Ok(None);
684    }
685
686    let first_table = &from[0];
687    if first_table.joins.is_empty() {
688        return Ok(None);
689    }
690
691    // Extract base table
692    let base_table = extract_table_name(&first_table.relation)?;
693    let base_alias = extract_table_alias(&first_table.relation);
694
695    let mut join_steps = Vec::with_capacity(first_table.joins.len());
696    let mut tables = vec![base_table.clone()];
697
698    // Track the left table name for left-deep chaining
699    let mut prev_left_table = base_table;
700    let mut prev_left_alias = base_alias;
701
702    for join in &first_table.joins {
703        let right_table = extract_table_name(&join.relation)?;
704        let right_alias = extract_table_alias(&join.relation);
705        tables.push(right_table.clone());
706
707        let join_type = map_join_operator(&join.join_operator);
708
709        // Handle ASOF JOIN
710        if let JoinOperator::AsOf {
711            match_condition,
712            constraint,
713        } = &join.join_operator
714        {
715            let (direction, left_time, right_time, tolerance) =
716                analyze_asof_match_condition(match_condition)?;
717            let (left_key, right_key) = analyze_asof_constraint(constraint)?;
718
719            let mut analysis = JoinAnalysis::asof(
720                prev_left_table.clone(),
721                right_table.clone(),
722                left_key,
723                right_key,
724                direction,
725                left_time,
726                right_time,
727                tolerance,
728            );
729            analysis.left_alias.clone_from(&prev_left_alias);
730            analysis.right_alias = right_alias;
731            join_steps.push(analysis);
732        } else {
733            // Regular join (inner, left, right, full)
734            let (left_key, right_key, time_bound) = analyze_join_constraint(&join.join_operator)?;
735
736            let mut analysis = if let Some(tb) = time_bound {
737                JoinAnalysis::stream_stream(
738                    prev_left_table.clone(),
739                    right_table.clone(),
740                    left_key,
741                    right_key,
742                    tb,
743                    join_type,
744                )
745            } else {
746                JoinAnalysis::lookup(
747                    prev_left_table.clone(),
748                    right_table.clone(),
749                    left_key,
750                    right_key,
751                    join_type,
752                )
753            };
754            analysis.left_alias.clone_from(&prev_left_alias);
755            analysis.right_alias = right_alias;
756            join_steps.push(analysis);
757        }
758
759        // Next step's left table is this step's right table (left-deep)
760        prev_left_table = right_table;
761        prev_left_alias = extract_table_alias(&join.relation);
762    }
763
764    Ok(Some(MultiJoinAnalysis {
765        joins: join_steps,
766        tables,
767    }))
768}
769
770#[cfg(test)]
771mod tests {
772    use super::*;
773    use sqlparser::ast::{SetExpr, Statement};
774    use sqlparser::dialect::GenericDialect;
775    use sqlparser::parser::Parser;
776
777    fn parse_select(sql: &str) -> Select {
778        let dialect = GenericDialect {};
779        let statements = Parser::parse_sql(&dialect, sql).unwrap();
780        if let Statement::Query(query) = &statements[0] {
781            if let SetExpr::Select(select) = query.body.as_ref() {
782                return *select.clone();
783            }
784        }
785        panic!("Expected SELECT query");
786    }
787
788    #[test]
789    fn test_analyze_inner_join() {
790        let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
791        let select = parse_select(sql);
792
793        let analysis = analyze_join(&select).unwrap().unwrap();
794
795        assert_eq!(analysis.join_type, JoinType::Inner);
796        assert_eq!(analysis.left_table, "orders");
797        assert_eq!(analysis.right_table, "payments");
798        assert_eq!(analysis.left_key_column, "order_id");
799        assert_eq!(analysis.right_key_column, "order_id");
800        assert!(analysis.is_lookup_join); // No time bound = lookup join
801    }
802
803    #[test]
804    fn test_analyze_left_join() {
805        let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
806        let select = parse_select(sql);
807
808        let analysis = analyze_join(&select).unwrap().unwrap();
809
810        assert_eq!(analysis.join_type, JoinType::Left);
811        assert_eq!(analysis.left_key_column, "customer_id");
812        assert_eq!(analysis.right_key_column, "id");
813    }
814
815    #[test]
816    fn test_analyze_join_using() {
817        let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
818        let select = parse_select(sql);
819
820        let analysis = analyze_join(&select).unwrap().unwrap();
821
822        assert_eq!(analysis.left_key_column, "order_id");
823        assert_eq!(analysis.right_key_column, "order_id");
824    }
825
826    #[test]
827    fn test_analyze_stream_stream_join_with_time_bound() {
828        let sql = "SELECT * FROM orders o
829                   JOIN payments p ON o.order_id = p.order_id
830                   AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
831        let select = parse_select(sql);
832
833        let analysis = analyze_join(&select).unwrap().unwrap();
834
835        assert!(!analysis.is_lookup_join);
836        assert!(analysis.time_bound.is_some());
837        assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
838    }
839
840    #[test]
841    fn test_no_join() {
842        let sql = "SELECT * FROM orders";
843        let select = parse_select(sql);
844
845        let analysis = analyze_join(&select).unwrap();
846        assert!(analysis.is_none());
847    }
848
849    #[test]
850    fn test_has_join() {
851        let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
852        let sql_without_join = "SELECT * FROM orders";
853
854        let select_with = parse_select(sql_with_join);
855        let select_without = parse_select(sql_without_join);
856
857        assert!(has_join(&select_with));
858        assert!(!has_join(&select_without));
859    }
860
861    #[test]
862    fn test_count_joins() {
863        let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
864        let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
865        let sql_zero = "SELECT * FROM a";
866
867        assert_eq!(count_joins(&parse_select(sql_one)), 1);
868        assert_eq!(count_joins(&parse_select(sql_two)), 2);
869        assert_eq!(count_joins(&parse_select(sql_zero)), 0);
870    }
871
872    #[test]
873    fn test_aliases() {
874        let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
875        let select = parse_select(sql);
876
877        let analysis = analyze_join(&select).unwrap().unwrap();
878
879        assert_eq!(analysis.left_alias, Some("o".to_string()));
880        assert_eq!(analysis.right_alias, Some("p".to_string()));
881    }
882
883    // -- ASOF JOIN tests --
884
885    fn parse_select_snowflake(sql: &str) -> Select {
886        let dialect = sqlparser::dialect::SnowflakeDialect {};
887        let statements = Parser::parse_sql(&dialect, sql).unwrap();
888        if let Statement::Query(query) = &statements[0] {
889            if let SetExpr::Select(select) = query.body.as_ref() {
890                return *select.clone();
891            }
892        }
893        panic!("Expected SELECT query");
894    }
895
896    #[test]
897    fn test_asof_join_backward() {
898        let sql = "SELECT * FROM trades t \
899                    ASOF JOIN quotes q \
900                    MATCH_CONDITION(t.ts >= q.ts) \
901                    ON t.symbol = q.symbol";
902        let select = parse_select_snowflake(sql);
903        let analysis = analyze_join(&select).unwrap().unwrap();
904
905        assert!(analysis.is_asof_join);
906        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
907        assert_eq!(analysis.join_type, JoinType::AsOf);
908        assert!(analysis.asof_tolerance.is_none());
909    }
910
911    #[test]
912    fn test_asof_join_forward() {
913        let sql = "SELECT * FROM trades t \
914                    ASOF JOIN quotes q \
915                    MATCH_CONDITION(t.ts <= q.ts) \
916                    ON t.symbol = q.symbol";
917        let select = parse_select_snowflake(sql);
918        let analysis = analyze_join(&select).unwrap().unwrap();
919
920        assert!(analysis.is_asof_join);
921        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
922    }
923
924    #[test]
925    fn test_asof_join_with_tolerance() {
926        let sql = "SELECT * FROM trades t \
927                    ASOF JOIN quotes q \
928                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
929                    ON t.symbol = q.symbol";
930        let select = parse_select_snowflake(sql);
931        let analysis = analyze_join(&select).unwrap().unwrap();
932
933        assert!(analysis.is_asof_join);
934        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
935        assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
936    }
937
938    #[test]
939    fn test_asof_join_with_interval_tolerance() {
940        let sql = "SELECT * FROM trades t \
941                    ASOF JOIN quotes q \
942                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
943                    ON t.symbol = q.symbol";
944        let select = parse_select_snowflake(sql);
945        let analysis = analyze_join(&select).unwrap().unwrap();
946
947        assert!(analysis.is_asof_join);
948        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
949        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
950    }
951
952    #[test]
953    fn test_asof_join_type_mapping() {
954        let sql = "SELECT * FROM trades t \
955                    ASOF JOIN quotes q \
956                    MATCH_CONDITION(t.ts >= q.ts) \
957                    ON t.symbol = q.symbol";
958        let select = parse_select_snowflake(sql);
959        let analysis = analyze_join(&select).unwrap().unwrap();
960
961        assert_eq!(analysis.join_type, JoinType::AsOf);
962        assert!(!analysis.is_lookup_join);
963    }
964
965    #[test]
966    fn test_asof_join_extracts_time_columns() {
967        let sql = "SELECT * FROM trades t \
968                    ASOF JOIN quotes q \
969                    MATCH_CONDITION(t.ts >= q.ts) \
970                    ON t.symbol = q.symbol";
971        let select = parse_select_snowflake(sql);
972        let analysis = analyze_join(&select).unwrap().unwrap();
973
974        assert_eq!(analysis.left_time_column, Some("ts".to_string()));
975        assert_eq!(analysis.right_time_column, Some("ts".to_string()));
976    }
977
978    #[test]
979    fn test_asof_join_extracts_key_columns() {
980        let sql = "SELECT * FROM trades t \
981                    ASOF JOIN quotes q \
982                    MATCH_CONDITION(t.ts >= q.ts) \
983                    ON t.symbol = q.symbol";
984        let select = parse_select_snowflake(sql);
985        let analysis = analyze_join(&select).unwrap().unwrap();
986
987        assert_eq!(analysis.left_key_column, "symbol");
988        assert_eq!(analysis.right_key_column, "symbol");
989    }
990
991    #[test]
992    fn test_asof_join_aliases() {
993        let sql = "SELECT * FROM trades AS t \
994                    ASOF JOIN quotes AS q \
995                    MATCH_CONDITION(t.ts >= q.ts) \
996                    ON t.symbol = q.symbol";
997        let select = parse_select_snowflake(sql);
998        let analysis = analyze_join(&select).unwrap().unwrap();
999
1000        assert_eq!(analysis.left_alias, Some("t".to_string()));
1001        assert_eq!(analysis.right_alias, Some("q".to_string()));
1002        assert_eq!(analysis.left_table, "trades");
1003        assert_eq!(analysis.right_table, "quotes");
1004    }
1005
1006    // -- Multi-way JOIN tests (F-SQL-005) --
1007
1008    #[test]
1009    fn test_multi_join_single_backward_compat() {
1010        let sql = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
1011        let select = parse_select(sql);
1012        let multi = analyze_joins(&select).unwrap().unwrap();
1013
1014        assert!(multi.is_single());
1015        assert_eq!(multi.len(), 1);
1016        assert!(!multi.is_empty());
1017        let first = multi.first().unwrap();
1018        assert_eq!(first.left_table, "orders");
1019        assert_eq!(first.right_table, "payments");
1020    }
1021
1022    #[test]
1023    fn test_multi_join_two_way() {
1024        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1025        let select = parse_select(sql);
1026        let multi = analyze_joins(&select).unwrap().unwrap();
1027
1028        assert_eq!(multi.len(), 2);
1029        assert!(!multi.is_single());
1030
1031        assert_eq!(multi.joins[0].left_table, "a");
1032        assert_eq!(multi.joins[0].right_table, "b");
1033        assert_eq!(multi.joins[0].left_key_column, "id");
1034        assert_eq!(multi.joins[0].right_key_column, "a_id");
1035
1036        assert_eq!(multi.joins[1].left_table, "b");
1037        assert_eq!(multi.joins[1].right_table, "c");
1038        assert_eq!(multi.joins[1].left_key_column, "id");
1039        assert_eq!(multi.joins[1].right_key_column, "b_id");
1040    }
1041
1042    #[test]
1043    fn test_multi_join_three_way() {
1044        let sql = "SELECT * FROM a \
1045                    JOIN b ON a.id = b.a_id \
1046                    JOIN c ON b.id = c.b_id \
1047                    JOIN d ON c.id = d.c_id";
1048        let select = parse_select(sql);
1049        let multi = analyze_joins(&select).unwrap().unwrap();
1050
1051        assert_eq!(multi.len(), 3);
1052        assert_eq!(multi.tables.len(), 4);
1053        assert_eq!(multi.tables, vec!["a", "b", "c", "d"]);
1054    }
1055
1056    #[test]
1057    fn test_multi_join_mixed_asof_and_lookup() {
1058        // ASOF first, then lookup (use Snowflake dialect for ASOF)
1059        let sql = "SELECT * FROM trades t \
1060                    ASOF JOIN quotes q \
1061                    MATCH_CONDITION(t.ts >= q.ts) \
1062                    ON t.symbol = q.symbol \
1063                    JOIN products p ON q.product_id = p.id";
1064        let select = parse_select_snowflake(sql);
1065        let multi = analyze_joins(&select).unwrap().unwrap();
1066
1067        assert_eq!(multi.len(), 2);
1068        assert!(multi.joins[0].is_asof_join);
1069        assert!(multi.joins[1].is_lookup_join);
1070    }
1071
1072    #[test]
1073    fn test_multi_join_stream_stream_and_lookup() {
1074        let sql = "SELECT * FROM orders o \
1075                    JOIN payments p ON o.id = p.order_id \
1076                        AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR \
1077                    JOIN customers c ON o.customer_id = c.id";
1078        let select = parse_select(sql);
1079        let multi = analyze_joins(&select).unwrap().unwrap();
1080
1081        assert_eq!(multi.len(), 2);
1082        assert!(!multi.joins[0].is_lookup_join); // stream-stream
1083        assert!(multi.joins[0].time_bound.is_some());
1084        assert!(multi.joins[1].is_lookup_join); // lookup
1085    }
1086
1087    #[test]
1088    fn test_multi_join_tables_list() {
1089        let sql = "SELECT * FROM a JOIN b ON a.id = b.a_id JOIN c ON b.id = c.b_id";
1090        let select = parse_select(sql);
1091        let multi = analyze_joins(&select).unwrap().unwrap();
1092
1093        assert_eq!(multi.tables, vec!["a", "b", "c"]);
1094    }
1095
1096    #[test]
1097    fn test_multi_join_aliases() {
1098        let sql = "SELECT * FROM orders AS o \
1099                    JOIN payments AS p ON o.id = p.order_id \
1100                    JOIN refunds AS r ON p.id = r.payment_id";
1101        let select = parse_select(sql);
1102        let multi = analyze_joins(&select).unwrap().unwrap();
1103
1104        assert_eq!(multi.joins[0].left_alias, Some("o".to_string()));
1105        assert_eq!(multi.joins[0].right_alias, Some("p".to_string()));
1106        assert_eq!(multi.joins[1].left_alias, Some("p".to_string()));
1107        assert_eq!(multi.joins[1].right_alias, Some("r".to_string()));
1108    }
1109
1110    #[test]
1111    fn test_multi_join_no_join_returns_none() {
1112        let sql = "SELECT * FROM orders";
1113        let select = parse_select(sql);
1114        let multi = analyze_joins(&select).unwrap();
1115        assert!(multi.is_none());
1116    }
1117}