Skip to main content

krishiv_sql/
subquery.rs

1//! E5.1 — Correlated subquery decorrelation: EXISTS/IN/scalar subquery analysis.
2//!
3//! DataFusion 53 already handles subquery decorrelation for batch queries via
4//! the `DecorrelatePredicateSubquery` optimizer rule. This module adds:
5//!
6//! 1. **AST-level detection** of EXISTS/IN/NOT IN/scalar subquery patterns.
7//! 2. **Streaming guard**: rejects correlated subqueries that reference a
8//!    registered streaming table — DataFusion does not handle these.
9//! 3. **Kind classification** so callers can adapt error messages and explain output.
10
use std::collections::HashSet;

use datafusion::sql::sqlparser::ast::visit_expressions;
use datafusion::sql::sqlparser::ast::visit_relations;
use datafusion::sql::sqlparser::ast::{
    Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
    Statement, TableFactor,
};
use datafusion::sql::sqlparser::dialect::GenericDialect;
use datafusion::sql::sqlparser::parser::Parser;

use crate::{SqlError, SqlResult};
22
23// ── Subquery kind ─────────────────────────────────────────────────────────────
24
/// The syntactic form of a subquery found in a SQL statement.
///
/// Each variant notes the plan shape DataFusion's decorrelation rules
/// (`DecorrelatePredicateSubquery` for IN/EXISTS, `ScalarSubqueryToJoin` for
/// scalar subqueries) rewrite it into.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum SubqueryKind {
    /// `expr IN (SELECT ...)`; becomes a left-semi join.
    InSubquery,
    /// `expr NOT IN (SELECT ...)`; becomes a left-anti join.
    NotInSubquery,
    /// `EXISTS (SELECT ...)`; becomes a left-semi join.
    Exists,
    /// `NOT EXISTS (SELECT ...)`; becomes a left-anti join.
    NotExists,
    /// `(SELECT ...)` used where a single value is expected; becomes a join
    /// against the inner query's single-value result.
    Scalar,
}
40
41/// A subquery occurrence found in a SQL statement.
42#[derive(Debug, Clone)]
43pub struct DetectedSubquery {
44    pub kind: SubqueryKind,
45    /// The inner query text (as rendered by the AST `Display` impl).
46    pub inner_query: String,
47}
48
49// ── Detection ─────────────────────────────────────────────────────────────────
50
51/// Analyse `sql` and return every subquery occurrence.
52///
53/// Returns an empty vec if the SQL contains no subqueries.
54/// Returns a parse error only when the SQL is syntactically invalid.
55pub fn detect_subqueries(sql: &str) -> SqlResult<Vec<DetectedSubquery>> {
56    let dialect = GenericDialect {};
57    let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Unsupported {
58        feature: format!("subquery detection: parse error: {e}"),
59    })?;
60
61    let mut found = Vec::new();
62
63    for stmt in &stmts {
64        if let Statement::Query(q) = stmt {
65            collect_subqueries_from_query(q, &mut found);
66        }
67    }
68
69    Ok(found)
70}
71
72fn collect_subqueries_from_query(query: &Query, out: &mut Vec<DetectedSubquery>) {
73    if let SetExpr::Select(sel) = query.body.as_ref() {
74        collect_from_select(sel, out);
75    }
76}
77
78fn collect_from_select(sel: &Select, out: &mut Vec<DetectedSubquery>) {
79    for item in &sel.projection {
80        match item {
81            SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
82                collect_from_expr(e, out);
83            }
84            _ => {}
85        }
86    }
87    if let Some(e) = &sel.selection {
88        collect_from_expr(e, out);
89    }
90    if let Some(e) = &sel.having {
91        collect_from_expr(e, out);
92    }
93}
94
95fn collect_from_expr(expr: &Expr, out: &mut Vec<DetectedSubquery>) {
96    match expr {
97        Expr::InSubquery {
98            subquery, negated, ..
99        } => {
100            let kind = if *negated {
101                SubqueryKind::NotInSubquery
102            } else {
103                SubqueryKind::InSubquery
104            };
105            out.push(DetectedSubquery {
106                kind,
107                inner_query: subquery.to_string(),
108            });
109            collect_subqueries_from_query(subquery, out);
110        }
111        Expr::Exists { subquery, negated } => {
112            let kind = if *negated {
113                SubqueryKind::NotExists
114            } else {
115                SubqueryKind::Exists
116            };
117            out.push(DetectedSubquery {
118                kind,
119                inner_query: subquery.to_string(),
120            });
121            collect_subqueries_from_query(subquery, out);
122        }
123        Expr::Subquery(q) => {
124            out.push(DetectedSubquery {
125                kind: SubqueryKind::Scalar,
126                inner_query: q.to_string(),
127            });
128            collect_subqueries_from_query(q, out);
129        }
130        Expr::BinaryOp { left, right, .. } => {
131            collect_from_expr(left, out);
132            collect_from_expr(right, out);
133        }
134        Expr::UnaryOp { expr, .. } => collect_from_expr(expr, out),
135        Expr::IsNull(e) | Expr::IsNotNull(e) => collect_from_expr(e, out),
136        Expr::Between {
137            expr, low, high, ..
138        } => {
139            collect_from_expr(expr, out);
140            collect_from_expr(low, out);
141            collect_from_expr(high, out);
142        }
143        Expr::Case {
144            operand,
145            conditions,
146            else_result,
147            ..
148        } => {
149            if let Some(e) = operand {
150                collect_from_expr(e, out);
151            }
152            for cw in conditions {
153                collect_from_expr(&cw.condition, out);
154                collect_from_expr(&cw.result, out);
155            }
156            if let Some(e) = else_result {
157                collect_from_expr(e, out);
158            }
159        }
160        Expr::Function(f) => {
161            if let FunctionArguments::List(list) = &f.args {
162                for fa in &list.args {
163                    let inner = match fa {
164                        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
165                        FunctionArg::Named {
166                            arg: FunctionArgExpr::Expr(e),
167                            ..
168                        } => Some(e),
169                        _ => None,
170                    };
171                    if let Some(e) = inner {
172                        collect_from_expr(e, out);
173                    }
174                }
175            }
176        }
177        _ => {}
178    }
179}
180
181// ── Streaming guard ───────────────────────────────────────────────────────────
182
183/// Validate that `sql` contains no subqueries that reference a streaming table.
184///
185/// Returns `Ok(())` when either:
186/// - No subqueries are present, or
187/// - No subquery body references a name in `streaming_tables`.
188///
189/// Returns `Err` when a subquery body contains a streaming table name (case-
190/// insensitive), because DataFusion's decorrelation rules do not handle unbounded
191/// inputs.
192pub fn validate_no_streaming_subqueries(
193    sql: &str,
194    streaming_tables: &HashSet<String>,
195) -> SqlResult<()> {
196    if streaming_tables.is_empty() {
197        return Ok(());
198    }
199
200    // Normalize to lowercase for case-insensitive matching against the SQL
201    // identifier names produced by extract_table_names_from_query.
202    let lower_tables: HashSet<String> = streaming_tables.iter().map(|s| s.to_lowercase()).collect();
203
204    let dialect = GenericDialect {};
205    let stmts = match Parser::parse_sql(&dialect, sql) {
206        Ok(s) => s,
207        Err(_) => return Ok(()), // parse errors are surfaced later by DataFusion
208    };
209
210    for stmt in &stmts {
211        if let Statement::Query(q) = stmt {
212            let mut subqueries = Vec::new();
213            collect_subqueries_from_query(q, &mut subqueries);
214            for sq in &subqueries {
215                let inner_stmts =
216                    Parser::parse_sql(&GenericDialect {}, &sq.inner_query).unwrap_or_default();
217                for s in &inner_stmts {
218                    if let Statement::Query(iq) = s {
219                        let names = extract_table_names_from_query(iq);
220                        if names.iter().any(|t| lower_tables.contains(t)) {
221                            return Err(SqlError::Unsupported {
222                                feature: "correlated subquery over a streaming (unbounded) table \
223                                          is not supported; use a streaming join or MATCH_RECOGNIZE \
224                                          for event-pattern matching"
225                                    .into(),
226                            });
227                        }
228                    }
229                }
230            }
231        }
232    }
233    Ok(())
234}
235
236fn extract_table_names_from_query(query: &Query) -> HashSet<String> {
237    let mut names = HashSet::new();
238    let _ = visit_relations(query, |relation| {
239        names.insert(relation.to_string().to_lowercase());
240        std::ops::ControlFlow::<()>::Continue(())
241    });
242    names
243}
244
245// ── Explain helpers ───────────────────────────────────────────────────────────
246
247/// Return a human-readable summary of subquery kinds found in `sql`.
248///
249/// Returns `None` when `sql` has no subqueries.
250pub fn explain_subqueries(sql: &str) -> Option<String> {
251    let found = detect_subqueries(sql).unwrap_or_default();
252    if found.is_empty() {
253        return None;
254    }
255    let summary = found
256        .iter()
257        .map(|sq| match sq.kind {
258            SubqueryKind::InSubquery => "IN-subquery → semi-join",
259            SubqueryKind::NotInSubquery => "NOT IN-subquery → anti-join",
260            SubqueryKind::Exists => "EXISTS → semi-join",
261            SubqueryKind::NotExists => "NOT EXISTS → anti-join",
262            SubqueryKind::Scalar => "scalar subquery → cross-apply",
263        })
264        .collect::<Vec<_>>()
265        .join(", ");
266    Some(format!("subqueries: [{summary}]"))
267}
268
269// ── Tests ─────────────────────────────────────────────────────────────────────
270
#[cfg(test)]
mod tests {
    use super::*;

    /// Kinds of every detected subquery, in detection order.
    fn kinds(sql: &str) -> Vec<SubqueryKind> {
        detect_subqueries(sql)
            .unwrap()
            .into_iter()
            .map(|s| s.kind)
            .collect()
    }

    #[test]
    fn detects_in_subquery() {
        let sql = "SELECT * FROM orders WHERE customer_id IN (SELECT id FROM vip_customers)";
        let found = detect_subqueries(sql).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, SubqueryKind::InSubquery);
    }

    #[test]
    fn detects_not_in_subquery() {
        let sql = "SELECT * FROM orders WHERE customer_id NOT IN (SELECT id FROM banned)";
        let found = detect_subqueries(sql).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, SubqueryKind::NotInSubquery);
    }

    #[test]
    fn detects_exists_subquery() {
        let sql = "SELECT * FROM orders o WHERE EXISTS (SELECT 1 FROM payments p WHERE p.order_id = o.id)";
        let found = detect_subqueries(sql).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, SubqueryKind::Exists);
    }

    #[test]
    fn detects_not_exists_subquery() {
        let sql = "SELECT * FROM orders o WHERE NOT EXISTS (SELECT 1 FROM payments p WHERE p.order_id = o.id)";
        let found = detect_subqueries(sql).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, SubqueryKind::NotExists);
    }

    #[test]
    fn detects_scalar_subquery() {
        let sql = "SELECT id, (SELECT MAX(amount) FROM payments WHERE order_id = o.id) as max_payment FROM orders o";
        let found = detect_subqueries(sql).unwrap();
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].kind, SubqueryKind::Scalar);
    }

    #[test]
    fn detects_nested_subqueries() {
        // Outer subquery is reported before the one nested inside it.
        let sql = "SELECT * FROM a WHERE x IN (SELECT y FROM b WHERE y NOT IN (SELECT z FROM c))";
        assert_eq!(
            kinds(sql),
            vec![SubqueryKind::InSubquery, SubqueryKind::NotInSubquery]
        );
    }

    #[test]
    fn detects_projection_before_where() {
        let sql = "SELECT (SELECT 1) AS one FROM t WHERE EXISTS (SELECT 1 FROM s)";
        assert_eq!(kinds(sql), vec![SubqueryKind::Scalar, SubqueryKind::Exists]);
    }

    #[test]
    fn no_subqueries_returns_empty() {
        let sql = "SELECT id, amount FROM orders WHERE status = 'completed'";
        let found = detect_subqueries(sql).unwrap();
        assert!(found.is_empty());
    }

    #[test]
    fn invalid_sql_is_a_parse_error() {
        let err = detect_subqueries("SELEC * FROM t").unwrap_err();
        assert!(matches!(err, SqlError::Unsupported { .. }));
    }

    #[test]
    fn streaming_guard_passes_when_no_streaming_tables() {
        let sql = "SELECT * FROM t WHERE id IN (SELECT id FROM s)";
        let streaming: HashSet<String> = HashSet::new();
        assert!(validate_no_streaming_subqueries(sql, &streaming).is_ok());
    }

    #[test]
    fn streaming_guard_rejects_subquery_over_streaming_table() {
        let sql = "SELECT * FROM events WHERE id IN (SELECT id FROM live_stream)";
        let mut streaming = HashSet::new();
        streaming.insert("live_stream".into());
        let err = validate_no_streaming_subqueries(sql, &streaming).unwrap_err();
        assert!(matches!(err, SqlError::Unsupported { .. }));
    }

    #[test]
    fn streaming_guard_is_case_insensitive() {
        let sql = "SELECT * FROM events WHERE id IN (SELECT id FROM LIVE_STREAM)";
        let streaming = HashSet::from(["Live_Stream".to_string()]);
        let err = validate_no_streaming_subqueries(sql, &streaming).unwrap_err();
        assert!(matches!(err, SqlError::Unsupported { .. }));
    }

    #[test]
    fn streaming_guard_rejects_exists_over_streaming_table() {
        let sql =
            "SELECT * FROM events e WHERE EXISTS (SELECT 1 FROM live_stream s WHERE s.id = e.id)";
        let streaming = HashSet::from(["live_stream".to_string()]);
        assert!(validate_no_streaming_subqueries(sql, &streaming).is_err());
    }

    #[test]
    fn streaming_guard_rejects_scalar_subquery_over_streaming_table() {
        let sql = "SELECT id, (SELECT MAX(v) FROM live_stream) AS m FROM events";
        let streaming = HashSet::from(["live_stream".to_string()]);
        assert!(validate_no_streaming_subqueries(sql, &streaming).is_err());
    }

    #[test]
    fn streaming_guard_allows_streaming_table_in_outer_query() {
        let sql = "SELECT * FROM live_stream WHERE id IN (SELECT id FROM reference_table)";
        let streaming = HashSet::from(["live_stream".to_string()]);
        assert!(validate_no_streaming_subqueries(sql, &streaming).is_ok());
    }

    #[test]
    fn streaming_guard_ignores_unparseable_sql() {
        let streaming = HashSet::from(["live_stream".to_string()]);
        assert!(validate_no_streaming_subqueries("SELEC * FROM live_stream", &streaming).is_ok());
    }

    #[test]
    fn streaming_guard_passes_for_batch_tables() {
        let sql = "SELECT * FROM events WHERE id IN (SELECT id FROM reference_table)";
        let mut streaming = HashSet::new();
        streaming.insert("live_stream".into());
        assert!(validate_no_streaming_subqueries(sql, &streaming).is_ok());
    }

    #[test]
    fn explain_subqueries_returns_none_for_plain_sql() {
        assert!(explain_subqueries("SELECT 1").is_none());
    }

    #[test]
    fn explain_subqueries_returns_none_for_invalid_sql() {
        assert!(explain_subqueries("SELEC * FROM t").is_none());
    }

    #[test]
    fn explain_subqueries_describes_kinds() {
        let sql = "SELECT * FROM t WHERE x IN (SELECT y FROM s)";
        let desc = explain_subqueries(sql).unwrap();
        assert!(desc.contains("semi-join"));
    }

    #[test]
    fn explain_subqueries_formats_summary_in_order() {
        let sql = "SELECT * FROM t WHERE x IN (SELECT y FROM s) AND NOT EXISTS (SELECT 1 FROM u)";
        assert_eq!(
            explain_subqueries(sql).as_deref(),
            Some("subqueries: [IN-subquery → semi-join, NOT EXISTS → anti-join]")
        );
    }

    #[test]
    fn case_expression_does_not_panic() {
        let sql = "SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t";
        let found = detect_subqueries(sql).unwrap();
        assert!(found.is_empty());
    }
}