1use std::collections::HashSet;
12
13use datafusion::sql::sqlparser::ast::visit_relations;
14use datafusion::sql::sqlparser::ast::{
15 Expr, FunctionArg, FunctionArgExpr, FunctionArguments, Query, Select, SelectItem, SetExpr,
16 Statement,
17};
18use datafusion::sql::sqlparser::dialect::GenericDialect;
19use datafusion::sql::sqlparser::parser::Parser;
20
21use crate::{SqlError, SqlResult};
22
/// The syntactic form in which a subquery appears inside an expression.
///
/// The type is a plain fieldless enum, so it is `Copy` and `Hash` in addition
/// to being comparable; that lets callers pass it by value and use it as a
/// key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SubqueryKind {
    /// `expr IN (SELECT ...)`.
    InSubquery,
    /// `expr NOT IN (SELECT ...)`.
    NotInSubquery,
    /// `EXISTS (SELECT ...)`.
    Exists,
    /// `NOT EXISTS (SELECT ...)`.
    NotExists,
    /// A parenthesized query used as a value, e.g. in a projection or comparison.
    Scalar,
}
40
41#[derive(Debug, Clone)]
43pub struct DetectedSubquery {
44 pub kind: SubqueryKind,
45 pub inner_query: String,
47}
48
49pub fn detect_subqueries(sql: &str) -> SqlResult<Vec<DetectedSubquery>> {
56 let dialect = GenericDialect {};
57 let stmts = Parser::parse_sql(&dialect, sql).map_err(|e| SqlError::Unsupported {
58 feature: format!("subquery detection: parse error: {e}"),
59 })?;
60
61 let mut found = Vec::new();
62
63 for stmt in &stmts {
64 if let Statement::Query(q) = stmt {
65 collect_subqueries_from_query(q, &mut found);
66 }
67 }
68
69 Ok(found)
70}
71
72fn collect_subqueries_from_query(query: &Query, out: &mut Vec<DetectedSubquery>) {
73 if let SetExpr::Select(sel) = query.body.as_ref() {
74 collect_from_select(sel, out);
75 }
76}
77
78fn collect_from_select(sel: &Select, out: &mut Vec<DetectedSubquery>) {
79 for item in &sel.projection {
80 match item {
81 SelectItem::UnnamedExpr(e) | SelectItem::ExprWithAlias { expr: e, .. } => {
82 collect_from_expr(e, out);
83 }
84 _ => {}
85 }
86 }
87 if let Some(e) = &sel.selection {
88 collect_from_expr(e, out);
89 }
90 if let Some(e) = &sel.having {
91 collect_from_expr(e, out);
92 }
93}
94
95fn collect_from_expr(expr: &Expr, out: &mut Vec<DetectedSubquery>) {
96 match expr {
97 Expr::InSubquery {
98 subquery, negated, ..
99 } => {
100 let kind = if *negated {
101 SubqueryKind::NotInSubquery
102 } else {
103 SubqueryKind::InSubquery
104 };
105 out.push(DetectedSubquery {
106 kind,
107 inner_query: subquery.to_string(),
108 });
109 collect_subqueries_from_query(subquery, out);
110 }
111 Expr::Exists { subquery, negated } => {
112 let kind = if *negated {
113 SubqueryKind::NotExists
114 } else {
115 SubqueryKind::Exists
116 };
117 out.push(DetectedSubquery {
118 kind,
119 inner_query: subquery.to_string(),
120 });
121 collect_subqueries_from_query(subquery, out);
122 }
123 Expr::Subquery(q) => {
124 out.push(DetectedSubquery {
125 kind: SubqueryKind::Scalar,
126 inner_query: q.to_string(),
127 });
128 collect_subqueries_from_query(q, out);
129 }
130 Expr::BinaryOp { left, right, .. } => {
131 collect_from_expr(left, out);
132 collect_from_expr(right, out);
133 }
134 Expr::UnaryOp { expr, .. } => collect_from_expr(expr, out),
135 Expr::IsNull(e) | Expr::IsNotNull(e) => collect_from_expr(e, out),
136 Expr::Between {
137 expr, low, high, ..
138 } => {
139 collect_from_expr(expr, out);
140 collect_from_expr(low, out);
141 collect_from_expr(high, out);
142 }
143 Expr::Case {
144 operand,
145 conditions,
146 else_result,
147 ..
148 } => {
149 if let Some(e) = operand {
150 collect_from_expr(e, out);
151 }
152 for cw in conditions {
153 collect_from_expr(&cw.condition, out);
154 collect_from_expr(&cw.result, out);
155 }
156 if let Some(e) = else_result {
157 collect_from_expr(e, out);
158 }
159 }
160 Expr::Function(f) => {
161 if let FunctionArguments::List(list) = &f.args {
162 for fa in &list.args {
163 let inner = match fa {
164 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
165 FunctionArg::Named {
166 arg: FunctionArgExpr::Expr(e),
167 ..
168 } => Some(e),
169 _ => None,
170 };
171 if let Some(e) = inner {
172 collect_from_expr(e, out);
173 }
174 }
175 }
176 }
177 _ => {}
178 }
179}
180
181pub fn validate_no_streaming_subqueries(
193 sql: &str,
194 streaming_tables: &HashSet<String>,
195) -> SqlResult<()> {
196 if streaming_tables.is_empty() {
197 return Ok(());
198 }
199
200 let lower_tables: HashSet<String> = streaming_tables.iter().map(|s| s.to_lowercase()).collect();
203
204 let dialect = GenericDialect {};
205 let stmts = match Parser::parse_sql(&dialect, sql) {
206 Ok(s) => s,
207 Err(_) => return Ok(()), };
209
210 for stmt in &stmts {
211 if let Statement::Query(q) = stmt {
212 let mut subqueries = Vec::new();
213 collect_subqueries_from_query(q, &mut subqueries);
214 for sq in &subqueries {
215 let inner_stmts =
216 Parser::parse_sql(&GenericDialect {}, &sq.inner_query).unwrap_or_default();
217 for s in &inner_stmts {
218 if let Statement::Query(iq) = s {
219 let names = extract_table_names_from_query(iq);
220 if names.iter().any(|t| lower_tables.contains(t)) {
221 return Err(SqlError::Unsupported {
222 feature: "correlated subquery over a streaming (unbounded) table \
223 is not supported; use a streaming join or MATCH_RECOGNIZE \
224 for event-pattern matching"
225 .into(),
226 });
227 }
228 }
229 }
230 }
231 }
232 }
233 Ok(())
234}
235
236fn extract_table_names_from_query(query: &Query) -> HashSet<String> {
237 let mut names = HashSet::new();
238 let _ = visit_relations(query, |relation| {
239 names.insert(relation.to_string().to_lowercase());
240 std::ops::ControlFlow::<()>::Continue(())
241 });
242 names
243}
244
245pub fn explain_subqueries(sql: &str) -> Option<String> {
251 let found = detect_subqueries(sql).unwrap_or_default();
252 if found.is_empty() {
253 return None;
254 }
255 let summary = found
256 .iter()
257 .map(|sq| match sq.kind {
258 SubqueryKind::InSubquery => "IN-subquery → semi-join",
259 SubqueryKind::NotInSubquery => "NOT IN-subquery → anti-join",
260 SubqueryKind::Exists => "EXISTS → semi-join",
261 SubqueryKind::NotExists => "NOT EXISTS → anti-join",
262 SubqueryKind::Scalar => "scalar subquery → cross-apply",
263 })
264 .collect::<Vec<_>>()
265 .join(", ");
266 Some(format!("subqueries: [{summary}]"))
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs detection on `sql` and returns only the kinds, in discovery order.
    fn kinds(sql: &str) -> Vec<SubqueryKind> {
        detect_subqueries(sql)
            .unwrap()
            .into_iter()
            .map(|sq| sq.kind)
            .collect()
    }

    /// Builds a streaming-table set from string literals.
    fn streaming(names: &[&str]) -> HashSet<String> {
        names.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn detects_in_subquery() {
        let found = kinds("SELECT * FROM orders WHERE customer_id IN (SELECT id FROM vip_customers)");
        assert_eq!(found, vec![SubqueryKind::InSubquery]);
    }

    #[test]
    fn detects_not_in_subquery() {
        let found = kinds("SELECT * FROM orders WHERE customer_id NOT IN (SELECT id FROM banned)");
        assert_eq!(found, vec![SubqueryKind::NotInSubquery]);
    }

    #[test]
    fn detects_exists_subquery() {
        let found = kinds(
            "SELECT * FROM orders o WHERE EXISTS (SELECT 1 FROM payments p WHERE p.order_id = o.id)",
        );
        assert_eq!(found, vec![SubqueryKind::Exists]);
    }

    #[test]
    fn detects_not_exists_subquery() {
        let found = kinds(
            "SELECT * FROM orders o WHERE NOT EXISTS (SELECT 1 FROM payments p WHERE p.order_id = o.id)",
        );
        assert_eq!(found, vec![SubqueryKind::NotExists]);
    }

    #[test]
    fn detects_scalar_subquery() {
        let found = kinds(
            "SELECT id, (SELECT MAX(amount) FROM payments WHERE order_id = o.id) as max_payment FROM orders o",
        );
        assert_eq!(found, vec![SubqueryKind::Scalar]);
    }

    #[test]
    fn detects_nested_subqueries() {
        let found =
            kinds("SELECT * FROM a WHERE x IN (SELECT y FROM b WHERE y NOT IN (SELECT z FROM c))");
        assert!(found.len() >= 2);
        assert!(found.contains(&SubqueryKind::InSubquery));
        assert!(found.contains(&SubqueryKind::NotInSubquery));
    }

    #[test]
    fn no_subqueries_returns_empty() {
        assert!(kinds("SELECT id, amount FROM orders WHERE status = 'completed'").is_empty());
    }

    #[test]
    fn streaming_guard_passes_when_no_streaming_tables() {
        let outcome = validate_no_streaming_subqueries(
            "SELECT * FROM t WHERE id IN (SELECT id FROM s)",
            &streaming(&[]),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn streaming_guard_rejects_subquery_over_streaming_table() {
        let outcome = validate_no_streaming_subqueries(
            "SELECT * FROM events WHERE id IN (SELECT id FROM live_stream)",
            &streaming(&["live_stream"]),
        );
        assert!(matches!(outcome.unwrap_err(), SqlError::Unsupported { .. }));
    }

    #[test]
    fn streaming_guard_passes_for_batch_tables() {
        let outcome = validate_no_streaming_subqueries(
            "SELECT * FROM events WHERE id IN (SELECT id FROM reference_table)",
            &streaming(&["live_stream"]),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn explain_subqueries_returns_none_for_plain_sql() {
        assert!(explain_subqueries("SELECT 1").is_none());
    }

    #[test]
    fn explain_subqueries_describes_kinds() {
        let summary = explain_subqueries("SELECT * FROM t WHERE x IN (SELECT y FROM s)").unwrap();
        assert!(summary.contains("semi-join"));
    }

    #[test]
    fn case_expression_does_not_panic() {
        assert!(kinds("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t").is_empty());
    }
}