1use sqlparser::ast::{self, Expr, SetExpr};
14
15use crate::error::{Result, SqlError};
16use crate::functions::registry::FunctionRegistry;
17use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
18use crate::types::*;
19
20pub struct SubqueryExtraction {
22 pub joins: Vec<SubqueryJoin>,
24 pub remaining_where: Option<Expr>,
26}
27
28pub struct SubqueryJoin {
30 pub outer_column: String,
32 pub inner_plan: SqlPlan,
34 pub inner_column: String,
36 pub join_type: JoinType,
38}
39
/// Builds the canonical output-column key for an aggregate call.
///
/// For example, `("avg", "amount")` becomes `"avg(amount)"`.
fn canonical_aggregate_key(function: &str, field: &str) -> String {
    let mut key = String::with_capacity(function.len() + field.len() + 2);
    key.push_str(function);
    key.push('(');
    key.push_str(field);
    key.push(')');
    key
}
43
44pub fn extract_subqueries(
50 expr: &Expr,
51 catalog: &dyn SqlCatalog,
52 functions: &FunctionRegistry,
53 temporal: crate::TemporalScope,
54) -> Result<SubqueryExtraction> {
55 let mut joins = Vec::new();
56 let remaining = extract_recursive(expr, &mut joins, catalog, functions, temporal)?;
57 Ok(SubqueryExtraction {
58 joins,
59 remaining_where: remaining,
60 })
61}
62
63fn extract_recursive(
68 expr: &Expr,
69 joins: &mut Vec<SubqueryJoin>,
70 catalog: &dyn SqlCatalog,
71 functions: &FunctionRegistry,
72 temporal: crate::TemporalScope,
73) -> Result<Option<Expr>> {
74 match expr {
75 Expr::BinaryOp {
77 left,
78 op: ast::BinaryOperator::And,
79 right,
80 } => {
81 let left_remaining = extract_recursive(left, joins, catalog, functions, temporal)?;
82 let right_remaining = extract_recursive(right, joins, catalog, functions, temporal)?;
83 match (left_remaining, right_remaining) {
84 (None, None) => Ok(None),
85 (Some(l), None) => Ok(Some(l)),
86 (None, Some(r)) => Ok(Some(r)),
87 (Some(l), Some(r)) => Ok(Some(Expr::BinaryOp {
88 left: Box::new(l),
89 op: ast::BinaryOperator::And,
90 right: Box::new(r),
91 })),
92 }
93 }
94
95 Expr::InSubquery {
97 expr: outer_expr,
98 subquery,
99 negated,
100 } => {
101 if let Some(join) =
102 try_plan_in_subquery(outer_expr, subquery, *negated, catalog, functions, temporal)?
103 {
104 joins.push(join);
105 Ok(None) } else {
107 Ok(Some(expr.clone()))
109 }
110 }
111
112 Expr::BinaryOp { left, op, right } if is_comparison_op(op) => {
114 if let Expr::Subquery(subquery) = right.as_ref() {
115 if let Some(scalar) =
116 try_plan_scalar_subquery(subquery, catalog, functions, temporal)?
117 {
118 joins.push(scalar.join);
119 Ok(Some(Expr::BinaryOp {
120 left: left.clone(),
121 op: op.clone(),
122 right: Box::new(scalar.replacement_expr),
123 }))
124 } else {
125 Ok(Some(expr.clone()))
126 }
127 } else {
128 Ok(Some(expr.clone()))
129 }
130 }
131
132 Expr::Exists { subquery, negated } => {
135 if let Some(join) =
136 try_plan_exists_subquery(subquery, *negated, catalog, functions, temporal)?
137 {
138 joins.push(join);
139 Ok(None)
140 } else {
141 Ok(Some(expr.clone()))
142 }
143 }
144
145 Expr::Nested(inner) => extract_recursive(inner, joins, catalog, functions, temporal),
147
148 _ => Ok(Some(expr.clone())),
150 }
151}
152
153fn try_plan_in_subquery(
155 outer_expr: &Expr,
156 subquery: &ast::Query,
157 negated: bool,
158 catalog: &dyn SqlCatalog,
159 functions: &FunctionRegistry,
160 temporal: crate::TemporalScope,
161) -> Result<Option<SubqueryJoin>> {
162 let outer_col = match outer_expr {
164 Expr::Identifier(ident) => normalize_ident(ident),
165 Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
166 let qualified: String = parts
167 .iter()
168 .map(normalize_ident)
169 .collect::<Vec<_>>()
170 .join(".");
171 return Err(SqlError::Unsupported {
172 detail: format!(
173 "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
174 ),
175 });
176 }
177 Expr::CompoundIdentifier(parts) if parts.len() == 2 => normalize_ident(&parts[1]),
178 _ => return Ok(None), };
180
181 let inner_plan = super::select::plan_query(subquery, catalog, functions, temporal)?;
183
184 let inner_col = extract_single_projected_column(subquery)?;
186
187 Ok(Some(SubqueryJoin {
188 outer_column: outer_col,
189 inner_plan,
190 inner_column: inner_col,
191 join_type: if negated {
192 JoinType::Anti
193 } else {
194 JoinType::Semi
195 },
196 }))
197}
198
199fn extract_single_projected_column(query: &ast::Query) -> Result<String> {
203 let select = match &*query.body {
204 SetExpr::Select(s) => s,
205 _ => {
206 return Err(SqlError::Unsupported {
207 detail: "subquery must be a simple SELECT".into(),
208 });
209 }
210 };
211
212 if select.projection.len() != 1 {
213 return Err(SqlError::Unsupported {
214 detail: format!(
215 "subquery must select exactly 1 column, got {}",
216 select.projection.len()
217 ),
218 });
219 }
220
221 match &select.projection[0] {
222 ast::SelectItem::UnnamedExpr(expr) => match expr {
223 Expr::Identifier(ident) => Ok(normalize_ident(ident)),
224 Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
225 let qualified: String = parts
226 .iter()
227 .map(normalize_ident)
228 .collect::<Vec<_>>()
229 .join(".");
230 Err(SqlError::Unsupported {
231 detail: format!(
232 "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
233 ),
234 })
235 }
236 Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(normalize_ident(&parts[1])),
237 _ => Err(SqlError::Unsupported {
238 detail: "subquery projection must be a column reference".into(),
239 }),
240 },
241 ast::SelectItem::ExprWithAlias { alias, .. } => Ok(normalize_ident(alias)),
242 _ => Err(SqlError::Unsupported {
243 detail: "subquery projection must be a column reference".into(),
244 }),
245 }
246}
247
248fn try_plan_exists_subquery(
252 subquery: &ast::Query,
253 negated: bool,
254 catalog: &dyn SqlCatalog,
255 functions: &FunctionRegistry,
256 temporal: crate::TemporalScope,
257) -> Result<Option<SubqueryJoin>> {
258 let select = match &*subquery.body {
259 SetExpr::Select(s) => s,
260 _ => return Ok(None),
261 };
262
263 let (outer_col, inner_col) = match &select.selection {
265 Some(expr) => match extract_correlated_eq(expr) {
266 Some(pair) => pair,
267 None => return Ok(None),
268 },
269 None => return Ok(None),
270 };
271
272 let inner_plan = super::select::plan_query(subquery, catalog, functions, temporal)?;
274
275 Ok(Some(SubqueryJoin {
276 outer_column: outer_col,
277 inner_plan,
278 inner_column: inner_col,
279 join_type: if negated {
280 JoinType::Anti
281 } else {
282 JoinType::Semi
283 },
284 }))
285}
286
287fn extract_correlated_eq(expr: &Expr) -> Option<(String, String)> {
293 match expr {
294 Expr::BinaryOp {
295 left,
296 op: ast::BinaryOperator::Eq,
297 right,
298 } => {
299 let left_parts = extract_qualified_column(left);
300 let right_parts = extract_qualified_column(right);
301 match (left_parts, right_parts) {
302 (Some((_lt, lc)), Some((_rt, rc))) => {
303 Some((rc, lc))
306 }
307 _ => None,
308 }
309 }
310 Expr::BinaryOp {
312 left,
313 op: ast::BinaryOperator::And,
314 right,
315 } => extract_correlated_eq(left).or_else(|| extract_correlated_eq(right)),
316 Expr::Nested(inner) => extract_correlated_eq(inner),
317 _ => None,
318 }
319}
320
321fn extract_qualified_column(expr: &Expr) -> Option<(String, String)> {
326 match expr {
327 Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
328 None
331 }
332 Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
333 Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
334 }
335 Expr::Identifier(ident) => Some((String::new(), normalize_ident(ident))),
336 _ => None,
337 }
338}
339
340fn is_comparison_op(op: &ast::BinaryOperator) -> bool {
341 matches!(
342 op,
343 ast::BinaryOperator::Gt
344 | ast::BinaryOperator::GtEq
345 | ast::BinaryOperator::Lt
346 | ast::BinaryOperator::LtEq
347 | ast::BinaryOperator::Eq
348 | ast::BinaryOperator::NotEq
349 )
350}
351
352struct ScalarSubqueryResult {
354 join: SubqueryJoin,
355 replacement_expr: Expr,
356}
357
358fn try_plan_scalar_subquery(
366 subquery: &ast::Query,
367 catalog: &dyn SqlCatalog,
368 functions: &FunctionRegistry,
369 temporal: crate::TemporalScope,
370) -> Result<Option<ScalarSubqueryResult>> {
371 let inner_plan = super::select::plan_query(subquery, catalog, functions, temporal)?;
372
373 let result_col = match extract_scalar_column(subquery) {
375 Some(col) => col,
376 None => return Ok(None),
377 };
378
379 let replacement = Expr::Identifier(ast::Ident::new(&result_col));
380
381 Ok(Some(ScalarSubqueryResult {
382 join: SubqueryJoin {
383 outer_column: String::new(),
384 inner_plan,
385 inner_column: String::new(),
386 join_type: JoinType::Cross,
387 },
388 replacement_expr: replacement,
389 }))
390}
391
392fn extract_scalar_column(query: &ast::Query) -> Option<String> {
398 let select = match &*query.body {
399 SetExpr::Select(s) => s,
400 _ => return None,
401 };
402 if select.projection.len() != 1 {
403 return None;
404 }
405 match &select.projection[0] {
406 ast::SelectItem::ExprWithAlias { alias, .. } => Some(normalize_ident(alias)),
407 ast::SelectItem::UnnamedExpr(expr) => match expr {
408 Expr::Identifier(ident) => Some(normalize_ident(ident)),
409 Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
410 None
412 }
413 Expr::CompoundIdentifier(parts) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
414 Expr::Function(func) => {
415 let func_name = func
416 .name
417 .0
418 .iter()
419 .map(|p| match p {
420 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
421 _ => String::new(),
422 })
423 .collect::<Vec<_>>()
424 .join(".")
425 .to_lowercase();
426 let arg = match &func.args {
427 ast::FunctionArguments::List(arg_list) => arg_list
428 .args
429 .first()
430 .and_then(|a| match a {
431 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
432 Expr::Identifier(ident),
433 )) => Some(normalize_ident(ident)),
434 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
435 Expr::CompoundIdentifier(parts),
436 )) if parts.len() >= 3 => None,
437 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
438 Expr::CompoundIdentifier(parts),
439 )) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
440 ast::FunctionArg::Unnamed(
441 ast::FunctionArgExpr::Wildcard
442 | ast::FunctionArgExpr::QualifiedWildcard(_),
443 ) => Some("all".to_string()),
444 _ => None,
445 })
446 .unwrap_or_else(|| "*".to_string()),
447 _ => "*".to_string(),
448 };
449 Some(canonical_aggregate_key(&func_name, &arg))
450 }
451 _ => None,
452 },
453 _ => None,
454 }
455}
456
#[cfg(test)]
mod tests {
    use super::{canonical_aggregate_key, extract_scalar_column};
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// Parses a single SELECT statement and returns the scalar column name
    /// derived from its projection.
    fn scalar_column_of(sql: &str) -> Option<String> {
        let statements = parse_sql(sql).unwrap();
        let Statement::Query(query) = &statements[0] else {
            panic!("expected query");
        };
        extract_scalar_column(query)
    }

    #[test]
    fn canonical_aggregate_key_formats_call() {
        assert_eq!(canonical_aggregate_key("sum", "amount"), "sum(amount)");
    }

    #[test]
    fn unaliased_scalar_aggregate_uses_canonical_aggregate_key() {
        assert_eq!(
            scalar_column_of("SELECT AVG(amount) FROM orders"),
            Some("avg(amount)".into())
        );
    }

    #[test]
    fn alias_names_the_scalar_column() {
        assert_eq!(
            scalar_column_of("SELECT AVG(amount) AS avg_amount FROM orders"),
            Some("avg_amount".into())
        );
    }

    #[test]
    fn table_qualifier_is_dropped() {
        assert_eq!(
            scalar_column_of("SELECT o.amount FROM orders o"),
            Some("amount".into())
        );
        assert_eq!(
            scalar_column_of("SELECT MAX(o.amount) FROM orders o"),
            Some("max(amount)".into())
        );
    }

    #[test]
    fn wildcard_argument_is_keyed_as_all() {
        assert_eq!(
            scalar_column_of("SELECT COUNT(*) FROM orders"),
            Some("count(all)".into())
        );
    }

    #[test]
    fn unsupported_projections_yield_none() {
        assert_eq!(scalar_column_of("SELECT amount, id FROM orders"), None);
        assert_eq!(
            scalar_column_of("SELECT amount FROM orders UNION SELECT amount FROM archive"),
            None
        );
    }
}