1use sqlparser::ast::{self, Expr, SetExpr};
12
13use crate::error::{Result, SqlError};
14use crate::functions::registry::FunctionRegistry;
15use crate::parser::normalize::normalize_ident;
16use crate::types::*;
17
/// Outcome of scanning a `WHERE` expression for subqueries.
pub struct SubqueryExtraction {
    /// Joins built from the subqueries that were recognized and planned.
    pub joins: Vec<SubqueryJoin>,
    /// The part of the original predicate that was left in place, or `None`
    /// when every conjunct was turned into a join.
    pub remaining_where: Option<Expr>,
}
25
/// A planned subquery together with the columns and join kind used to attach
/// it to the outer query.
pub struct SubqueryJoin {
    /// Normalized column name on the outer side of the join. Left empty for
    /// scalar subqueries, which are attached with a cross join.
    pub outer_column: String,
    /// Plan produced for the subquery itself.
    pub inner_plan: SqlPlan,
    /// Normalized column name on the subquery side of the join. Left empty
    /// for scalar subqueries.
    pub inner_column: String,
    /// `Semi` for `IN` / `EXISTS`, `Anti` for their negated forms, and
    /// `Cross` for scalar subqueries.
    pub join_type: JoinType,
}
37
/// Renders an aggregate call as `function(field)`.
fn canonical_aggregate_key(function: &str, field: &str) -> String {
    // Two extra bytes for the surrounding parentheses.
    let mut key = String::with_capacity(function.len() + field.len() + 2);
    key.push_str(function);
    key.push('(');
    key.push_str(field);
    key.push(')');
    key
}
41
42pub fn extract_subqueries(
48 expr: &Expr,
49 catalog: &dyn SqlCatalog,
50 functions: &FunctionRegistry,
51) -> Result<SubqueryExtraction> {
52 let mut joins = Vec::new();
53 let remaining = extract_recursive(expr, &mut joins, catalog, functions)?;
54 Ok(SubqueryExtraction {
55 joins,
56 remaining_where: remaining,
57 })
58}
59
60fn extract_recursive(
65 expr: &Expr,
66 joins: &mut Vec<SubqueryJoin>,
67 catalog: &dyn SqlCatalog,
68 functions: &FunctionRegistry,
69) -> Result<Option<Expr>> {
70 match expr {
71 Expr::BinaryOp {
73 left,
74 op: ast::BinaryOperator::And,
75 right,
76 } => {
77 let left_remaining = extract_recursive(left, joins, catalog, functions)?;
78 let right_remaining = extract_recursive(right, joins, catalog, functions)?;
79 match (left_remaining, right_remaining) {
80 (None, None) => Ok(None),
81 (Some(l), None) => Ok(Some(l)),
82 (None, Some(r)) => Ok(Some(r)),
83 (Some(l), Some(r)) => Ok(Some(Expr::BinaryOp {
84 left: Box::new(l),
85 op: ast::BinaryOperator::And,
86 right: Box::new(r),
87 })),
88 }
89 }
90
91 Expr::InSubquery {
93 expr: outer_expr,
94 subquery,
95 negated,
96 } => {
97 if let Some(join) =
98 try_plan_in_subquery(outer_expr, subquery, *negated, catalog, functions)?
99 {
100 joins.push(join);
101 Ok(None) } else {
103 Ok(Some(expr.clone()))
105 }
106 }
107
108 Expr::BinaryOp { left, op, right } if is_comparison_op(op) => {
110 if let Expr::Subquery(subquery) = right.as_ref() {
111 if let Some(scalar) = try_plan_scalar_subquery(subquery, catalog, functions)? {
112 joins.push(scalar.join);
113 Ok(Some(Expr::BinaryOp {
114 left: left.clone(),
115 op: op.clone(),
116 right: Box::new(scalar.replacement_expr),
117 }))
118 } else {
119 Ok(Some(expr.clone()))
120 }
121 } else {
122 Ok(Some(expr.clone()))
123 }
124 }
125
126 Expr::Exists { subquery, negated } => {
129 if let Some(join) = try_plan_exists_subquery(subquery, *negated, catalog, functions)? {
130 joins.push(join);
131 Ok(None)
132 } else {
133 Ok(Some(expr.clone()))
134 }
135 }
136
137 Expr::Nested(inner) => extract_recursive(inner, joins, catalog, functions),
139
140 _ => Ok(Some(expr.clone())),
142 }
143}
144
145fn try_plan_in_subquery(
147 outer_expr: &Expr,
148 subquery: &ast::Query,
149 negated: bool,
150 catalog: &dyn SqlCatalog,
151 functions: &FunctionRegistry,
152) -> Result<Option<SubqueryJoin>> {
153 let outer_col = match outer_expr {
155 Expr::Identifier(ident) => normalize_ident(ident),
156 Expr::CompoundIdentifier(parts) if parts.len() == 2 => normalize_ident(&parts[1]),
157 _ => return Ok(None), };
159
160 let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
162
163 let inner_col = extract_single_projected_column(subquery)?;
165
166 Ok(Some(SubqueryJoin {
167 outer_column: outer_col,
168 inner_plan,
169 inner_column: inner_col,
170 join_type: if negated {
171 JoinType::Anti
172 } else {
173 JoinType::Semi
174 },
175 }))
176}
177
178fn extract_single_projected_column(query: &ast::Query) -> Result<String> {
182 let select = match &*query.body {
183 SetExpr::Select(s) => s,
184 _ => {
185 return Err(SqlError::Unsupported {
186 detail: "subquery must be a simple SELECT".into(),
187 });
188 }
189 };
190
191 if select.projection.len() != 1 {
192 return Err(SqlError::Unsupported {
193 detail: format!(
194 "subquery must select exactly 1 column, got {}",
195 select.projection.len()
196 ),
197 });
198 }
199
200 match &select.projection[0] {
201 ast::SelectItem::UnnamedExpr(expr) => match expr {
202 Expr::Identifier(ident) => Ok(normalize_ident(ident)),
203 Expr::CompoundIdentifier(parts) if parts.len() == 2 => Ok(normalize_ident(&parts[1])),
204 _ => Err(SqlError::Unsupported {
205 detail: "subquery projection must be a column reference".into(),
206 }),
207 },
208 ast::SelectItem::ExprWithAlias { alias, .. } => Ok(normalize_ident(alias)),
209 _ => Err(SqlError::Unsupported {
210 detail: "subquery projection must be a column reference".into(),
211 }),
212 }
213}
214
215fn try_plan_exists_subquery(
219 subquery: &ast::Query,
220 negated: bool,
221 catalog: &dyn SqlCatalog,
222 functions: &FunctionRegistry,
223) -> Result<Option<SubqueryJoin>> {
224 let select = match &*subquery.body {
225 SetExpr::Select(s) => s,
226 _ => return Ok(None),
227 };
228
229 let (outer_col, inner_col) = match &select.selection {
231 Some(expr) => match extract_correlated_eq(expr) {
232 Some(pair) => pair,
233 None => return Ok(None),
234 },
235 None => return Ok(None),
236 };
237
238 let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
240
241 Ok(Some(SubqueryJoin {
242 outer_column: outer_col,
243 inner_plan,
244 inner_column: inner_col,
245 join_type: if negated {
246 JoinType::Anti
247 } else {
248 JoinType::Semi
249 },
250 }))
251}
252
253fn extract_correlated_eq(expr: &Expr) -> Option<(String, String)> {
259 match expr {
260 Expr::BinaryOp {
261 left,
262 op: ast::BinaryOperator::Eq,
263 right,
264 } => {
265 let left_parts = extract_qualified_column(left);
266 let right_parts = extract_qualified_column(right);
267 match (left_parts, right_parts) {
268 (Some((_lt, lc)), Some((_rt, rc))) => {
269 Some((rc, lc))
272 }
273 _ => None,
274 }
275 }
276 Expr::BinaryOp {
278 left,
279 op: ast::BinaryOperator::And,
280 right,
281 } => extract_correlated_eq(left).or_else(|| extract_correlated_eq(right)),
282 Expr::Nested(inner) => extract_correlated_eq(inner),
283 _ => None,
284 }
285}
286
287fn extract_qualified_column(expr: &Expr) -> Option<(String, String)> {
289 match expr {
290 Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
291 Some((normalize_ident(&parts[0]), normalize_ident(&parts[1])))
292 }
293 Expr::Identifier(ident) => Some((String::new(), normalize_ident(ident))),
294 _ => None,
295 }
296}
297
298fn is_comparison_op(op: &ast::BinaryOperator) -> bool {
299 matches!(
300 op,
301 ast::BinaryOperator::Gt
302 | ast::BinaryOperator::GtEq
303 | ast::BinaryOperator::Lt
304 | ast::BinaryOperator::LtEq
305 | ast::BinaryOperator::Eq
306 | ast::BinaryOperator::NotEq
307 )
308}
309
/// What planning a scalar subquery yields.
struct ScalarSubqueryResult {
    /// Cross join that carries the subquery's plan.
    join: SubqueryJoin,
    /// Identifier expression that takes the subquery's place in the
    /// comparison.
    replacement_expr: Expr,
}
315
316fn try_plan_scalar_subquery(
324 subquery: &ast::Query,
325 catalog: &dyn SqlCatalog,
326 functions: &FunctionRegistry,
327) -> Result<Option<ScalarSubqueryResult>> {
328 let inner_plan = super::select::plan_query(subquery, catalog, functions)?;
329
330 let result_col = match extract_scalar_column(subquery) {
332 Some(col) => col,
333 None => return Ok(None),
334 };
335
336 let replacement = Expr::Identifier(ast::Ident::new(&result_col));
337
338 Ok(Some(ScalarSubqueryResult {
339 join: SubqueryJoin {
340 outer_column: String::new(),
341 inner_plan,
342 inner_column: String::new(),
343 join_type: JoinType::Cross,
344 },
345 replacement_expr: replacement,
346 }))
347}
348
349fn extract_scalar_column(query: &ast::Query) -> Option<String> {
355 let select = match &*query.body {
356 SetExpr::Select(s) => s,
357 _ => return None,
358 };
359 if select.projection.len() != 1 {
360 return None;
361 }
362 match &select.projection[0] {
363 ast::SelectItem::ExprWithAlias { alias, .. } => Some(normalize_ident(alias)),
364 ast::SelectItem::UnnamedExpr(expr) => match expr {
365 Expr::Identifier(ident) => Some(normalize_ident(ident)),
366 Expr::CompoundIdentifier(parts) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
367 Expr::Function(func) => {
368 let func_name = func
369 .name
370 .0
371 .iter()
372 .map(|p| match p {
373 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
374 _ => String::new(),
375 })
376 .collect::<Vec<_>>()
377 .join(".")
378 .to_lowercase();
379 let arg = match &func.args {
380 ast::FunctionArguments::List(arg_list) => arg_list
381 .args
382 .first()
383 .and_then(|a| match a {
384 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
385 Expr::Identifier(ident),
386 )) => Some(normalize_ident(ident)),
387 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
388 Expr::CompoundIdentifier(parts),
389 )) if parts.len() == 2 => Some(normalize_ident(&parts[1])),
390 ast::FunctionArg::Unnamed(
391 ast::FunctionArgExpr::Wildcard
392 | ast::FunctionArgExpr::QualifiedWildcard(_),
393 ) => Some("all".to_string()),
394 _ => None,
395 })
396 .unwrap_or_else(|| "*".to_string()),
397 _ => "*".to_string(),
398 };
399 Some(canonical_aggregate_key(&func_name, &arg))
400 }
401 _ => None,
402 },
403 _ => None,
404 }
405}
406
#[cfg(test)]
mod tests {
    use super::extract_scalar_column;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// An aggregate projected without an alias is named `function(field)`.
    #[test]
    fn unaliased_scalar_aggregate_uses_canonical_aggregate_key() {
        let parsed = parse_sql("SELECT AVG(amount) FROM orders").unwrap();
        match &parsed[0] {
            Statement::Query(query) => {
                let column = extract_scalar_column(query);
                assert_eq!(column.as_deref(), Some("avg(amount)"));
            }
            _ => panic!("expected query"),
        }
    }
}