Skip to main content

laminar_sql/parser/
join_parser.rs

1//! Join query analysis and extraction
2//!
//! This module analyzes JOIN clauses to extract:
//! - Join type (INNER, LEFT, RIGHT, FULL, ASOF)
//! - Key columns for the join condition
//! - Time bounds for stream-stream joins
//! - Detection of lookup joins vs stream-stream joins
//! - ASOF match direction, time columns, and optional tolerance
8
9use std::time::Duration;
10
11use sqlparser::ast::{BinaryOperator, Expr, JoinConstraint, JoinOperator, Select, TableFactor};
12
13use super::window_rewriter::WindowRewriter;
14use super::ParseError;
15
/// Join type classification
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinType {
    /// INNER JOIN
    Inner,
    /// LEFT \[OUTER\] JOIN
    Left,
    /// RIGHT \[OUTER\] JOIN
    Right,
    /// FULL \[OUTER\] JOIN
    Full,
    /// ASOF JOIN (matched via a `MATCH_CONDITION` on time columns)
    AsOf,
}

/// Direction for ASOF JOIN time matching.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AsofSqlDirection {
    /// `left.ts >= right.ts` — find most recent right row
    Backward,
    /// `left.ts <= right.ts` — find next right row
    Forward,
}

/// Analysis result for a JOIN clause.
///
/// Values are built with one of [`JoinAnalysis::stream_stream`],
/// [`JoinAnalysis::lookup`] or [`JoinAnalysis::asof`]. Every constructor
/// leaves both alias fields as `None`; callers fill them in afterwards.
#[derive(Debug, Clone)]
pub struct JoinAnalysis {
    /// Type of join (inner, left, right, full, or ASOF)
    pub join_type: JoinType,
    /// Left side table name
    pub left_table: String,
    /// Right side table name
    pub right_table: String,
    /// Left side key column
    pub left_key_column: String,
    /// Right side key column
    pub right_key_column: String,
    /// Time bound for stream-stream joins (`None` for lookup and ASOF joins)
    pub time_bound: Option<Duration>,
    /// Whether this is a lookup join (a non-ASOF join without a time bound)
    pub is_lookup_join: bool,
    /// Left side alias (if any)
    pub left_alias: Option<String>,
    /// Right side alias (if any)
    pub right_alias: Option<String>,
    /// Whether this is an ASOF join
    pub is_asof_join: bool,
    /// ASOF join direction (Backward or Forward); `None` for non-ASOF joins
    pub asof_direction: Option<AsofSqlDirection>,
    /// Left side time column for ASOF join
    pub left_time_column: Option<String>,
    /// Right side time column for ASOF join
    pub right_time_column: Option<String>,
    /// ASOF join tolerance (max time difference)
    pub asof_tolerance: Option<Duration>,
}

impl JoinAnalysis {
    /// Shared skeleton for all constructors.
    ///
    /// Tables, keys and join type are set. The time bound, both flags, the
    /// aliases and every ASOF-specific field are cleared. Each public
    /// constructor overrides only the fields that distinguish its join kind.
    fn base(
        left_table: String,
        right_table: String,
        left_key: String,
        right_key: String,
        join_type: JoinType,
    ) -> Self {
        Self {
            join_type,
            left_table,
            right_table,
            left_key_column: left_key,
            right_key_column: right_key,
            time_bound: None,
            is_lookup_join: false,
            left_alias: None,
            right_alias: None,
            is_asof_join: false,
            asof_direction: None,
            left_time_column: None,
            right_time_column: None,
            asof_tolerance: None,
        }
    }

    /// Create a stream-stream join analysis.
    ///
    /// `time_bound` is stored as `Some(time_bound)`; `is_lookup_join` and
    /// `is_asof_join` are both `false`.
    #[must_use]
    pub fn stream_stream(
        left_table: String,
        right_table: String,
        left_key: String,
        right_key: String,
        time_bound: Duration,
        join_type: JoinType,
    ) -> Self {
        Self {
            time_bound: Some(time_bound),
            ..Self::base(left_table, right_table, left_key, right_key, join_type)
        }
    }

    /// Create a lookup join analysis.
    ///
    /// No time bound is recorded, and `is_lookup_join` is `true`.
    #[must_use]
    pub fn lookup(
        left_table: String,
        right_table: String,
        left_key: String,
        right_key: String,
        join_type: JoinType,
    ) -> Self {
        Self {
            is_lookup_join: true,
            ..Self::base(left_table, right_table, left_key, right_key, join_type)
        }
    }

    /// Create an ASOF join analysis.
    ///
    /// `join_type` is always [`JoinType::AsOf`] and `is_asof_join` is `true`.
    /// The direction and both time columns are always present. `tolerance`
    /// is stored as given (`None` means no tolerance was specified).
    #[must_use]
    #[allow(clippy::too_many_arguments)]
    pub fn asof(
        left_table: String,
        right_table: String,
        left_key: String,
        right_key: String,
        direction: AsofSqlDirection,
        left_time_col: String,
        right_time_col: String,
        tolerance: Option<Duration>,
    ) -> Self {
        Self {
            is_asof_join: true,
            asof_direction: Some(direction),
            left_time_column: Some(left_time_col),
            right_time_column: Some(right_time_col),
            asof_tolerance: tolerance,
            ..Self::base(left_table, right_table, left_key, right_key, JoinType::AsOf)
        }
    }
}
160
161/// Analyze a SELECT statement for join information.
162///
163/// # Errors
164///
165/// Returns `ParseError::StreamingError` if:
166/// - Join constraint is not supported
167/// - Cannot extract key columns
168pub fn analyze_join(select: &Select) -> Result<Option<JoinAnalysis>, ParseError> {
169    let from = &select.from;
170    if from.is_empty() {
171        return Ok(None);
172    }
173
174    let first_table = &from[0];
175    if first_table.joins.is_empty() {
176        return Ok(None);
177    }
178
179    // Extract left table information
180    let left_table = extract_table_name(&first_table.relation)?;
181    let left_alias = extract_table_alias(&first_table.relation);
182
183    // Analyze the first join
184    let join = &first_table.joins[0];
185    let right_table = extract_table_name(&join.relation)?;
186    let right_alias = extract_table_alias(&join.relation);
187
188    let join_type = map_join_operator(&join.join_operator);
189
190    // Handle ASOF JOIN specially
191    if let JoinOperator::AsOf {
192        match_condition,
193        constraint,
194    } = &join.join_operator
195    {
196        let (direction, left_time, right_time, tolerance) =
197            analyze_asof_match_condition(match_condition)?;
198
199        // Extract key columns from the ON constraint
200        let (left_key, right_key) = analyze_asof_constraint(constraint)?;
201
202        let mut analysis = JoinAnalysis::asof(
203            left_table,
204            right_table,
205            left_key,
206            right_key,
207            direction,
208            left_time,
209            right_time,
210            tolerance,
211        );
212        analysis.left_alias = left_alias;
213        analysis.right_alias = right_alias;
214        return Ok(Some(analysis));
215    }
216
217    // Analyze the join constraint
218    let (left_key, right_key, time_bound) = analyze_join_constraint(&join.join_operator)?;
219
220    let mut analysis = if let Some(tb) = time_bound {
221        JoinAnalysis::stream_stream(left_table, right_table, left_key, right_key, tb, join_type)
222    } else {
223        JoinAnalysis::lookup(left_table, right_table, left_key, right_key, join_type)
224    };
225
226    analysis.left_alias = left_alias;
227    analysis.right_alias = right_alias;
228
229    Ok(Some(analysis))
230}
231
232/// Extract table name from a TableFactor.
233fn extract_table_name(factor: &TableFactor) -> Result<String, ParseError> {
234    match factor {
235        TableFactor::Table { name, .. } => Ok(name.to_string()),
236        TableFactor::Derived { alias, .. } => {
237            if let Some(alias) = alias {
238                Ok(alias.name.value.clone())
239            } else {
240                Err(ParseError::StreamingError(
241                    "Derived table without alias not supported".to_string(),
242                ))
243            }
244        }
245        _ => Err(ParseError::StreamingError(
246            "Unsupported table factor type".to_string(),
247        )),
248    }
249}
250
251/// Extract table alias from a TableFactor.
252fn extract_table_alias(factor: &TableFactor) -> Option<String> {
253    match factor {
254        TableFactor::Table { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
255        TableFactor::Derived { alias, .. } => alias.as_ref().map(|a| a.name.value.clone()),
256        _ => None,
257    }
258}
259
260/// Map sqlparser JoinOperator to our JoinType.
261fn map_join_operator(op: &JoinOperator) -> JoinType {
262    match op {
263        JoinOperator::Inner(_)
264        | JoinOperator::Join(_)
265        | JoinOperator::CrossJoin(_)
266        | JoinOperator::CrossApply
267        | JoinOperator::OuterApply
268        | JoinOperator::StraightJoin(_) => JoinType::Inner,
269        JoinOperator::Left(_)
270        | JoinOperator::LeftOuter(_)
271        | JoinOperator::LeftSemi(_)
272        | JoinOperator::LeftAnti(_)
273        | JoinOperator::Semi(_) => JoinType::Left,
274        JoinOperator::AsOf { .. } => JoinType::AsOf,
275        JoinOperator::Right(_)
276        | JoinOperator::RightOuter(_)
277        | JoinOperator::RightSemi(_)
278        | JoinOperator::RightAnti(_)
279        | JoinOperator::Anti(_) => JoinType::Right,
280        JoinOperator::FullOuter(_) => JoinType::Full,
281    }
282}
283
284/// Analyze join constraint to extract key columns and time bound.
285fn analyze_join_constraint(
286    op: &JoinOperator,
287) -> Result<(String, String, Option<Duration>), ParseError> {
288    let constraint = get_join_constraint(op)?;
289
290    match constraint {
291        JoinConstraint::On(expr) => analyze_on_expression(expr),
292        JoinConstraint::Using(cols) => {
293            if cols.is_empty() {
294                return Err(ParseError::StreamingError(
295                    "USING clause requires at least one column".to_string(),
296                ));
297            }
298            // For USING, both sides have the same column name
299            // Use to_string() on the Ident to get the column name
300            let col = cols[0].to_string();
301            Ok((col.clone(), col, None))
302        }
303        JoinConstraint::Natural => Err(ParseError::StreamingError(
304            "NATURAL JOIN not supported for streaming".to_string(),
305        )),
306        JoinConstraint::None => Err(ParseError::StreamingError(
307            "JOIN without condition not supported for streaming".to_string(),
308        )),
309    }
310}
311
312/// Get the JoinConstraint from a JoinOperator.
313fn get_join_constraint(op: &JoinOperator) -> Result<&JoinConstraint, ParseError> {
314    match op {
315        JoinOperator::Inner(constraint)
316        | JoinOperator::Join(constraint)
317        | JoinOperator::Left(constraint)
318        | JoinOperator::LeftOuter(constraint)
319        | JoinOperator::Right(constraint)
320        | JoinOperator::RightOuter(constraint)
321        | JoinOperator::FullOuter(constraint)
322        | JoinOperator::LeftSemi(constraint)
323        | JoinOperator::RightSemi(constraint)
324        | JoinOperator::LeftAnti(constraint)
325        | JoinOperator::RightAnti(constraint)
326        | JoinOperator::Semi(constraint)
327        | JoinOperator::Anti(constraint)
328        | JoinOperator::StraightJoin(constraint)
329        | JoinOperator::AsOf { constraint, .. } => Ok(constraint),
330        JoinOperator::CrossJoin(_) | JoinOperator::CrossApply | JoinOperator::OuterApply => Err(
331            ParseError::StreamingError("CROSS JOIN not supported for streaming".to_string()),
332        ),
333    }
334}
335
336/// Analyze ON expression to extract key columns and time bound.
337fn analyze_on_expression(expr: &Expr) -> Result<(String, String, Option<Duration>), ParseError> {
338    // Handle compound expressions (AND)
339    match expr {
340        Expr::BinaryOp {
341            left,
342            op: BinaryOperator::And,
343            right,
344        } => {
345            // Recursively analyze both sides
346            let left_result = analyze_on_expression(left);
347            let right_result = analyze_on_expression(right);
348
349            // Combine results - one should have keys, the other might have time bound
350            match (left_result, right_result) {
351                (Ok((lk, rk, None)), Ok((_, _, time))) if !lk.is_empty() => Ok((lk, rk, time)),
352                (Ok((_, _, time)), Ok((lk, rk, None))) if !lk.is_empty() => Ok((lk, rk, time)),
353                (Ok(result), Err(_)) | (Err(_), Ok(result)) => Ok(result),
354                (Ok((lk, rk, t1)), Ok((_, _, t2))) => {
355                    // If both have keys, prefer the first
356                    Ok((lk, rk, t1.or(t2)))
357                }
358                (Err(e), Err(_)) => Err(e),
359            }
360        }
361        // Equality condition: a.col = b.col
362        Expr::BinaryOp {
363            left,
364            op: BinaryOperator::Eq,
365            right,
366        } => {
367            let left_col = extract_column_ref(left);
368            let right_col = extract_column_ref(right);
369
370            match (left_col, right_col) {
371                (Some(l), Some(r)) => Ok((l, r, None)),
372                _ => Err(ParseError::StreamingError(
373                    "Cannot extract column references from equality condition".to_string(),
374                )),
375            }
376        }
377        // BETWEEN clause for time bound: p.ts BETWEEN o.ts AND o.ts + INTERVAL
378        Expr::Between {
379            expr: _,
380            low: _,
381            high,
382            ..
383        } => {
384            // Try to extract time bound from high expression
385            let time_bound = extract_time_bound_from_expr(high).ok();
386            Ok((String::new(), String::new(), time_bound))
387        }
388        // Comparison operators for time bounds
389        Expr::BinaryOp {
390            left: _,
391            op:
392                BinaryOperator::LtEq | BinaryOperator::Lt | BinaryOperator::GtEq | BinaryOperator::Gt,
393            right,
394        } => {
395            // Try to extract time bound from right side
396            let time_bound = extract_time_bound_from_expr(right).ok();
397            Ok((String::new(), String::new(), time_bound))
398        }
399        _ => Err(ParseError::StreamingError(format!(
400            "Unsupported join condition expression: {expr:?}"
401        ))),
402    }
403}
404
405/// Extract column reference from expression (e.g., "a.id" -> "id")
406fn extract_column_ref(expr: &Expr) -> Option<String> {
407    match expr {
408        Expr::Identifier(ident) => Some(ident.value.clone()),
409        Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
410        _ => None,
411    }
412}
413
414/// Extract time bound from an expression like "o.ts + INTERVAL '1' HOUR"
415fn extract_time_bound_from_expr(expr: &Expr) -> Result<Duration, ParseError> {
416    match expr {
417        // Direct interval
418        Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(expr),
419        // Addition or subtraction: col +/- INTERVAL
420        Expr::BinaryOp {
421            left: _,
422            op: BinaryOperator::Plus | BinaryOperator::Minus,
423            right,
424        } => extract_time_bound_from_expr(right),
425        // Nested expression
426        Expr::Nested(inner) => extract_time_bound_from_expr(inner),
427        _ => Err(ParseError::StreamingError(format!(
428            "Cannot extract time bound from: {expr:?}"
429        ))),
430    }
431}
432
433/// Analyze ASOF JOIN MATCH_CONDITION expression.
434///
435/// Extracts direction, time column names, and optional tolerance.
436fn analyze_asof_match_condition(
437    expr: &Expr,
438) -> Result<(AsofSqlDirection, String, String, Option<Duration>), ParseError> {
439    if let Expr::BinaryOp {
440        left,
441        op: BinaryOperator::And,
442        right,
443    } = expr
444    {
445        // Try to get direction from left, tolerance from right
446        let dir_result = analyze_asof_direction(left);
447        let tol_result = extract_asof_tolerance(right);
448
449        match (dir_result, tol_result) {
450            (Ok((dir, lt, rt)), Ok(tol)) => Ok((dir, lt, rt, Some(tol))),
451            (Ok((dir, lt, rt)), Err(_)) => {
452                // Maybe tolerance is on left and direction on right
453                let dir2 = analyze_asof_direction(right);
454                let tol2 = extract_asof_tolerance(left);
455                match (dir2, tol2) {
456                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
457                    _ => Ok((dir, lt, rt, None)),
458                }
459            }
460            (Err(_), _) => {
461                // Try reversed
462                let dir2 = analyze_asof_direction(right);
463                let tol2 = extract_asof_tolerance(left);
464                match (dir2, tol2) {
465                    (Ok((d, l, r)), Ok(t)) => Ok((d, l, r, Some(t))),
466                    (Ok((d, l, r)), Err(_)) => Ok((d, l, r, None)),
467                    _ => Err(ParseError::StreamingError(
468                        "Cannot extract ASOF direction from MATCH_CONDITION".to_string(),
469                    )),
470                }
471            }
472        }
473    } else {
474        let (dir, lt, rt) = analyze_asof_direction(expr)?;
475        Ok((dir, lt, rt, None))
476    }
477}
478
479/// Extract ASOF direction and time columns from a comparison expression.
480fn analyze_asof_direction(expr: &Expr) -> Result<(AsofSqlDirection, String, String), ParseError> {
481    match expr {
482        Expr::BinaryOp {
483            left,
484            op: BinaryOperator::GtEq,
485            right,
486        } => {
487            let left_col = extract_column_ref(left).ok_or_else(|| {
488                ParseError::StreamingError(
489                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
490                )
491            })?;
492            let right_col = extract_column_ref(right).ok_or_else(|| {
493                ParseError::StreamingError(
494                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
495                )
496            })?;
497            Ok((AsofSqlDirection::Backward, left_col, right_col))
498        }
499        Expr::BinaryOp {
500            left,
501            op: BinaryOperator::LtEq,
502            right,
503        } => {
504            let left_col = extract_column_ref(left).ok_or_else(|| {
505                ParseError::StreamingError(
506                    "Cannot extract left time column from MATCH_CONDITION".to_string(),
507                )
508            })?;
509            let right_col = extract_column_ref(right).ok_or_else(|| {
510                ParseError::StreamingError(
511                    "Cannot extract right time column from MATCH_CONDITION".to_string(),
512                )
513            })?;
514            Ok((AsofSqlDirection::Forward, left_col, right_col))
515        }
516        _ => Err(ParseError::StreamingError(
517            "ASOF MATCH_CONDITION must be >= or <= comparison".to_string(),
518        )),
519    }
520}
521
522/// Extract tolerance duration from an ASOF tolerance expression.
523///
524/// Handles: `left - right <= value` or `left - right <= INTERVAL '...'`
525fn extract_asof_tolerance(expr: &Expr) -> Result<Duration, ParseError> {
526    match expr {
527        Expr::BinaryOp {
528            left: _,
529            op: BinaryOperator::LtEq,
530            right,
531        } => {
532            // right side is either a literal number or INTERVAL
533            match right.as_ref() {
534                Expr::Value(v) => {
535                    if let sqlparser::ast::Value::Number(n, _) = &v.value {
536                        let ms: u64 = n.parse().map_err(|_| {
537                            ParseError::StreamingError(format!(
538                                "Cannot parse tolerance as number: {n}"
539                            ))
540                        })?;
541                        Ok(Duration::from_millis(ms))
542                    } else {
543                        Err(ParseError::StreamingError(
544                            "ASOF tolerance must be a number or INTERVAL".to_string(),
545                        ))
546                    }
547                }
548                Expr::Interval(_) => WindowRewriter::parse_interval_to_duration(right),
549                _ => Err(ParseError::StreamingError(
550                    "ASOF tolerance must be a number or INTERVAL".to_string(),
551                )),
552            }
553        }
554        _ => Err(ParseError::StreamingError(
555            "ASOF tolerance expression must be <= comparison".to_string(),
556        )),
557    }
558}
559
560/// Extract key columns from an ASOF JOIN constraint (ON clause).
561fn analyze_asof_constraint(constraint: &JoinConstraint) -> Result<(String, String), ParseError> {
562    match constraint {
563        JoinConstraint::On(expr) => extract_equality_columns(expr),
564        JoinConstraint::Using(cols) => {
565            if cols.is_empty() {
566                return Err(ParseError::StreamingError(
567                    "USING clause requires at least one column".to_string(),
568                ));
569            }
570            let col = cols[0].to_string();
571            Ok((col.clone(), col))
572        }
573        _ => Err(ParseError::StreamingError(
574            "ASOF JOIN requires ON or USING constraint".to_string(),
575        )),
576    }
577}
578
579/// Extract left and right column names from an equality expression.
580fn extract_equality_columns(expr: &Expr) -> Result<(String, String), ParseError> {
581    match expr {
582        Expr::BinaryOp {
583            left,
584            op: BinaryOperator::Eq,
585            right,
586        } => {
587            let left_col = extract_column_ref(left).ok_or_else(|| {
588                ParseError::StreamingError("Cannot extract left key column".to_string())
589            })?;
590            let right_col = extract_column_ref(right).ok_or_else(|| {
591                ParseError::StreamingError("Cannot extract right key column".to_string())
592            })?;
593            Ok((left_col, right_col))
594        }
595        // If there's an AND, find the equality part
596        Expr::BinaryOp {
597            left,
598            op: BinaryOperator::And,
599            right,
600        } => extract_equality_columns(left).or_else(|_| extract_equality_columns(right)),
601        _ => Err(ParseError::StreamingError(
602            "ASOF JOIN ON clause must contain an equality condition".to_string(),
603        )),
604    }
605}
606
607/// Check if a SELECT contains a join.
608#[must_use]
609pub fn has_join(select: &Select) -> bool {
610    !select.from.is_empty() && !select.from[0].joins.is_empty()
611}
612
613/// Count the number of joins in a SELECT.
614#[must_use]
615pub fn count_joins(select: &Select) -> usize {
616    select
617        .from
618        .iter()
619        .map(|table_with_joins| table_with_joins.joins.len())
620        .sum()
621}
622
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::ast::{SetExpr, Statement};
    use sqlparser::dialect::{Dialect, GenericDialect, SnowflakeDialect};
    use sqlparser::parser::Parser;

    /// Parse `sql` with `dialect` and return its first statement as a SELECT.
    ///
    /// Panics if parsing fails or the statement is not a plain SELECT query.
    fn parse_select_with(dialect: &dyn Dialect, sql: &str) -> Select {
        let statements = Parser::parse_sql(dialect, sql).unwrap();
        if let Statement::Query(query) = &statements[0] {
            if let SetExpr::Select(select) = query.body.as_ref() {
                return (**select).clone();
            }
        }
        panic!("Expected SELECT query");
    }

    /// Parse with the generic dialect (standard JOIN syntax).
    fn parse_select(sql: &str) -> Select {
        parse_select_with(&GenericDialect {}, sql)
    }

    /// Parse with the Snowflake dialect, which supports `ASOF JOIN ... MATCH_CONDITION`.
    fn parse_select_snowflake(sql: &str) -> Select {
        parse_select_with(&SnowflakeDialect {}, sql)
    }

    #[test]
    fn test_analyze_inner_join() {
        let sql = "SELECT * FROM orders o INNER JOIN payments p ON o.order_id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Inner);
        assert_eq!(analysis.left_table, "orders");
        assert_eq!(analysis.right_table, "payments");
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
        assert!(analysis.is_lookup_join); // No time bound = lookup join
    }

    #[test]
    fn test_analyze_left_join() {
        let sql = "SELECT * FROM orders o LEFT JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Left);
        assert_eq!(analysis.left_key_column, "customer_id");
        assert_eq!(analysis.right_key_column, "id");
    }

    #[test]
    fn test_analyze_right_join() {
        let sql = "SELECT * FROM orders o RIGHT JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Right);
        assert_eq!(analysis.left_key_column, "customer_id");
        assert_eq!(analysis.right_key_column, "id");
    }

    #[test]
    fn test_analyze_full_outer_join() {
        let sql = "SELECT * FROM orders o FULL OUTER JOIN customers c ON o.customer_id = c.id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::Full);
        assert!(analysis.is_lookup_join);
    }

    #[test]
    fn test_analyze_join_using() {
        let sql = "SELECT * FROM orders o JOIN payments p USING (order_id)";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
    }

    #[test]
    fn test_cross_join_is_rejected() {
        let select = parse_select("SELECT * FROM a CROSS JOIN b");
        assert!(analyze_join(&select).is_err());
    }

    #[test]
    fn test_natural_join_is_rejected() {
        let select = parse_select("SELECT * FROM a NATURAL JOIN b");
        assert!(analyze_join(&select).is_err());
    }

    #[test]
    fn test_multiple_key_equalities_use_first() {
        let sql = "SELECT * FROM a JOIN b ON a.x = b.x AND a.y = b.y";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "x");
        assert_eq!(analysis.right_key_column, "x");
    }

    #[test]
    fn test_analyze_stream_stream_join_with_time_bound() {
        let sql = "SELECT * FROM orders o
                   JOIN payments p ON o.order_id = p.order_id
                   AND p.ts BETWEEN o.ts AND o.ts + INTERVAL '1' HOUR";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_lookup_join);
        assert!(analysis.time_bound.is_some());
        assert_eq!(analysis.time_bound.unwrap(), Duration::from_secs(3600));
    }

    #[test]
    fn test_analyze_stream_stream_join_with_range_comparisons() {
        let sql = "SELECT * FROM orders o
                   JOIN payments p ON o.order_id = p.order_id
                   AND p.ts >= o.ts AND p.ts <= o.ts + INTERVAL '2' HOUR";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(!analysis.is_lookup_join);
        assert_eq!(analysis.left_key_column, "order_id");
        assert_eq!(analysis.right_key_column, "order_id");
        assert_eq!(analysis.time_bound, Some(Duration::from_secs(7200)));
    }

    #[test]
    fn test_no_join() {
        let sql = "SELECT * FROM orders";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap();
        assert!(analysis.is_none());
    }

    #[test]
    fn test_has_join() {
        let sql_with_join = "SELECT * FROM orders o JOIN payments p ON o.id = p.order_id";
        let sql_without_join = "SELECT * FROM orders";

        let select_with = parse_select(sql_with_join);
        let select_without = parse_select(sql_without_join);

        assert!(has_join(&select_with));
        assert!(!has_join(&select_without));
    }

    #[test]
    fn test_count_joins() {
        let sql_one = "SELECT * FROM a JOIN b ON a.id = b.id";
        let sql_two = "SELECT * FROM a JOIN b ON a.id = b.id JOIN c ON b.id = c.id";
        let sql_zero = "SELECT * FROM a";

        assert_eq!(count_joins(&parse_select(sql_one)), 1);
        assert_eq!(count_joins(&parse_select(sql_two)), 2);
        assert_eq!(count_joins(&parse_select(sql_zero)), 0);
    }

    #[test]
    fn test_aliases() {
        let sql = "SELECT * FROM orders AS o JOIN payments AS p ON o.id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("o".to_string()));
        assert_eq!(analysis.right_alias, Some("p".to_string()));
    }

    #[test]
    fn test_derived_table_uses_alias_as_name() {
        let sql = "SELECT * FROM (SELECT * FROM orders) o JOIN payments p ON o.id = p.order_id";
        let select = parse_select(sql);

        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_table, "o");
        assert_eq!(analysis.left_alias, Some("o".to_string()));
        assert_eq!(analysis.right_table, "payments");
    }

    // -- ASOF JOIN tests --

    #[test]
    fn test_asof_join_backward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(analysis.asof_tolerance.is_none());
    }

    #[test]
    fn test_asof_join_forward() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts <= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Forward));
    }

    #[test]
    fn test_asof_join_with_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= 5000) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_millis(5000)));
    }

    #[test]
    fn test_asof_join_with_interval_tolerance() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts AND t.ts - q.ts <= INTERVAL '5' SECOND) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert!(analysis.is_asof_join);
        assert_eq!(analysis.asof_direction, Some(AsofSqlDirection::Backward));
        assert_eq!(analysis.asof_tolerance, Some(Duration::from_secs(5)));
    }

    #[test]
    fn test_asof_join_type_mapping() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.join_type, JoinType::AsOf);
        assert!(!analysis.is_lookup_join);
    }

    #[test]
    fn test_asof_join_extracts_time_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_time_column, Some("ts".to_string()));
        assert_eq!(analysis.right_time_column, Some("ts".to_string()));
    }

    #[test]
    fn test_asof_join_extracts_key_columns() {
        let sql = "SELECT * FROM trades t \
                    ASOF JOIN quotes q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_key_column, "symbol");
        assert_eq!(analysis.right_key_column, "symbol");
    }

    #[test]
    fn test_asof_join_aliases() {
        let sql = "SELECT * FROM trades AS t \
                    ASOF JOIN quotes AS q \
                    MATCH_CONDITION(t.ts >= q.ts) \
                    ON t.symbol = q.symbol";
        let select = parse_select_snowflake(sql);
        let analysis = analyze_join(&select).unwrap().unwrap();

        assert_eq!(analysis.left_alias, Some("t".to_string()));
        assert_eq!(analysis.right_alias, Some("q".to_string()));
        assert_eq!(analysis.left_table, "trades");
        assert_eq!(analysis.right_table, "quotes");
    }
}