1use sqlparser::ast::{self, Query, Select, SetExpr};
8
9use crate::error::{Result, SqlError};
10use crate::functions::registry::{FunctionRegistry, SearchTrigger};
11use crate::parser::normalize::normalize_ident;
12use crate::resolver::columns::TableScope;
13use crate::resolver::expr::convert_expr;
14use crate::types::*;
15
16pub fn plan_query(
18 query: &Query,
19 catalog: &dyn SqlCatalog,
20 functions: &FunctionRegistry,
21) -> Result<SqlPlan> {
22 if let Some(with) = &query.with
24 && with.recursive
25 {
26 return super::cte::plan_recursive_cte(query, catalog, functions);
27 }
28 if let Some(with) = &query.with
30 && !with.cte_tables.is_empty()
31 {
32 let inner_query = Query {
33 with: None,
34 body: query.body.clone(),
35 order_by: query.order_by.clone(),
36 limit_clause: query.limit_clause.clone(),
37 fetch: query.fetch.clone(),
38 locks: query.locks.clone(),
39 for_clause: query.for_clause.clone(),
40 settings: query.settings.clone(),
41 format_clause: query.format_clause.clone(),
42 pipe_operators: query.pipe_operators.clone(),
43 };
44
45 let mut definitions = Vec::new();
47 let mut cte_names = Vec::new();
48 for cte in &with.cte_tables {
49 let name = normalize_ident(&cte.alias.name);
50 let cte_plan = plan_query(&cte.query, catalog, functions)?;
51 definitions.push((name.clone(), cte_plan));
52 cte_names.push(name);
53 }
54
55 let cte_catalog = CteCatalog {
57 inner: catalog,
58 cte_names,
59 };
60 let outer = plan_query(&inner_query, &cte_catalog, functions)?;
61
62 return Ok(SqlPlan::Cte {
63 definitions,
64 outer: Box::new(outer),
65 });
66 }
67
68 match &*query.body {
70 SetExpr::Select(select) => {
71 let mut plan = plan_select(select, catalog, functions)?;
72 if let Some(order_by) = &query.order_by {
73 plan = apply_order_by(&plan, order_by, functions)?;
74 }
75 plan = apply_limit(plan, &query.limit_clause);
76 Ok(plan)
77 }
78 SetExpr::SetOperation {
79 op,
80 left,
81 right,
82 set_quantifier,
83 } => super::union::plan_set_operation(op, left, right, set_quantifier, catalog, functions),
84 _ => Err(SqlError::Unsupported {
85 detail: format!("query body type: {}", query.body),
86 }),
87 }
88}
89
90fn plan_select(
92 select: &Select,
93 catalog: &dyn SqlCatalog,
94 functions: &FunctionRegistry,
95) -> Result<SqlPlan> {
96 let scope = TableScope::resolve_from(catalog, &select.from)?;
98
99 if select.from.is_empty() {
101 let projection = convert_projection(&select.projection)?;
102 let mut columns = Vec::new();
103 let mut values = Vec::new();
104 for (i, proj) in projection.iter().enumerate() {
105 match proj {
106 Projection::Computed { expr, alias } => {
107 columns.push(alias.clone());
108 values.push(eval_constant_expr(expr, functions));
109 }
110 Projection::Column(name) => {
111 columns.push(name.clone());
112 values.push(SqlValue::Null);
113 }
114 _ => {
115 columns.push(format!("col{i}"));
116 values.push(SqlValue::Null);
117 }
118 }
119 }
120 return Ok(SqlPlan::ConstantResult { columns, values });
121 }
122
123 if let Some(plan) = try_plan_join(select, &scope, catalog, functions)? {
125 return Ok(plan);
126 }
127
128 let table = scope.single_table().ok_or_else(|| SqlError::Unsupported {
130 detail: "multi-table FROM without JOIN".into(),
131 })?;
132
133 let (subquery_joins, effective_where) = if let Some(expr) = &select.selection {
135 let extraction = super::subquery::extract_subqueries(expr, catalog, functions)?;
136 (extraction.joins, extraction.remaining_where)
137 } else {
138 (Vec::new(), None)
139 };
140
141 let filters = match &effective_where {
143 Some(expr) => {
144 if let Some(plan) = try_extract_where_search(expr, table, functions)? {
146 return Ok(plan);
147 }
148 convert_where_to_filters(expr)?
149 }
150 None => Vec::new(),
151 };
152
153 if has_aggregation(select, functions) {
155 let mut plan =
156 super::aggregate::plan_aggregate(select, table, &filters, &scope, functions)?;
157
158 if let SqlPlan::Aggregate { input, .. } = &mut plan {
163 let mut base_input = std::mem::replace(
164 input,
165 Box::new(SqlPlan::ConstantResult {
166 columns: Vec::new(),
167 values: Vec::new(),
168 }),
169 );
170 for sq in subquery_joins
171 .iter()
172 .filter(|sq| sq.join_type != JoinType::Cross)
173 {
174 base_input = Box::new(SqlPlan::Join {
175 left: base_input,
176 right: Box::new(sq.inner_plan.clone()),
177 on: vec![(sq.outer_column.clone(), sq.inner_column.clone())],
178 join_type: sq.join_type,
179 condition: None,
180 limit: 10000,
181 projection: Vec::new(),
182 filters: Vec::new(),
183 });
184 }
185 *input = base_input;
186 }
187
188 for sq in subquery_joins
189 .into_iter()
190 .filter(|sq| sq.join_type == JoinType::Cross)
191 {
192 plan = SqlPlan::Join {
193 left: Box::new(plan),
194 right: Box::new(sq.inner_plan),
195 on: vec![(sq.outer_column, sq.inner_column)],
196 join_type: sq.join_type,
197 condition: None,
198 limit: 10000,
199 projection: Vec::new(),
200 filters: Vec::new(),
201 };
202 }
203 return Ok(plan);
204 }
205
206 let projection = convert_projection(&select.projection)?;
208
209 let window_functions = super::window::extract_window_functions(&select.projection, functions)?;
211
212 let scan_projection = if subquery_joins.is_empty() {
214 projection.clone()
215 } else {
216 Vec::new()
217 };
218
219 let rules = crate::engine_rules::resolve_engine_rules(table.info.engine);
220 let mut plan = rules.plan_scan(crate::engine_rules::ScanParams {
221 collection: table.name.clone(),
222 alias: table.alias.clone(),
223 filters,
224 projection: scan_projection,
225 sort_keys: Vec::new(),
226 limit: None,
227 offset: 0,
228 distinct: select.distinct.is_some(),
229 window_functions,
230 indexes: table.info.indexes.clone(),
231 })?;
232
233 for sq in subquery_joins {
235 let join_filters = if sq.join_type == JoinType::Cross {
240 if let SqlPlan::Scan {
241 ref mut filters, ..
242 } = plan
243 {
244 let mut moved = Vec::new();
246 filters.retain(|f| {
247 if has_column_ref_filter(&f.expr) {
248 moved.push(f.clone());
249 false
250 } else {
251 true
252 }
253 });
254 moved
255 } else {
256 Vec::new()
257 }
258 } else {
259 Vec::new()
260 };
261
262 plan = SqlPlan::Join {
263 left: Box::new(plan),
264 right: Box::new(sq.inner_plan),
265 on: vec![(sq.outer_column, sq.inner_column)],
266 join_type: sq.join_type,
267 condition: None,
268 limit: 10000,
269 projection: Vec::new(),
270 filters: join_filters,
271 };
272 }
273
274 if let SqlPlan::Join {
275 projection: ref mut join_projection,
276 ..
277 } = plan
278 {
279 *join_projection = projection;
280 }
281
282 Ok(plan)
283}
284
285fn has_column_ref_filter(expr: &FilterExpr) -> bool {
289 match expr {
290 FilterExpr::Expr(sql_expr) => has_column_comparison(sql_expr),
291 FilterExpr::And(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
292 FilterExpr::Or(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
293 _ => false,
294 }
295}
296
297fn has_column_comparison(expr: &SqlExpr) -> bool {
298 match expr {
299 SqlExpr::BinaryOp { left, right, .. } => {
300 let left_is_col = matches!(left.as_ref(), SqlExpr::Column { .. });
301 let right_is_col = matches!(right.as_ref(), SqlExpr::Column { .. });
302 if left_is_col && right_is_col {
303 return true;
304 }
305 has_column_comparison(left) || has_column_comparison(right)
306 }
307 _ => false,
308 }
309}
310
311fn has_aggregation(select: &Select, functions: &FunctionRegistry) -> bool {
313 let group_by_non_empty = match &select.group_by {
314 ast::GroupByExpr::All(_) => true,
315 ast::GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
316 };
317 if group_by_non_empty {
318 return true;
319 }
320 for item in &select.projection {
321 if let ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } =
322 item
323 && crate::aggregate_walk::contains_aggregate(expr, functions)
324 {
325 return true;
326 }
327 }
328 false
329}
330
331fn try_extract_where_search(
333 expr: &ast::Expr,
334 table: &crate::resolver::columns::ResolvedTable,
335 functions: &FunctionRegistry,
336) -> Result<Option<SqlPlan>> {
337 match expr {
338 ast::Expr::Function(func) => {
339 let name = func
340 .name
341 .0
342 .iter()
343 .map(|p| match p {
344 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
345 _ => String::new(),
346 })
347 .collect::<Vec<_>>()
348 .join(".");
349 match functions.search_trigger(&name) {
350 SearchTrigger::TextMatch => {
351 let args = extract_func_args(func)?;
352 if args.len() >= 2 {
353 let query_text = extract_string_literal(&args[1])?;
354 return Ok(Some(SqlPlan::TextSearch {
355 collection: table.name.clone(),
356 query: query_text,
357 top_k: 1000,
358 fuzzy: true,
359 filters: Vec::new(),
360 }));
361 }
362 }
363 SearchTrigger::SpatialDWithin
364 | SearchTrigger::SpatialContains
365 | SearchTrigger::SpatialIntersects
366 | SearchTrigger::SpatialWithin => {
367 return plan_spatial_from_where(&name, func, table);
368 }
369 _ => {}
370 }
371 }
372 ast::Expr::BinaryOp {
374 left,
375 op: ast::BinaryOperator::And,
376 right,
377 } => {
378 if let Some(plan) = try_extract_where_search(left, table, functions)? {
379 return Ok(Some(plan));
380 }
381 if let Some(plan) = try_extract_where_search(right, table, functions)? {
382 return Ok(Some(plan));
383 }
384 }
385 _ => {}
386 }
387 Ok(None)
388}
389
390fn plan_spatial_from_where(
391 name: &str,
392 func: &ast::Function,
393 table: &crate::resolver::columns::ResolvedTable,
394) -> Result<Option<SqlPlan>> {
395 let predicate = match name {
396 "st_dwithin" => SpatialPredicate::DWithin,
397 "st_contains" => SpatialPredicate::Contains,
398 "st_intersects" => SpatialPredicate::Intersects,
399 "st_within" => SpatialPredicate::Within,
400 _ => return Ok(None),
401 };
402 let args = extract_func_args(func)?;
403 if args.is_empty() {
404 return Err(SqlError::MissingField {
405 field: "geometry column".into(),
406 context: name.into(),
407 });
408 }
409 let field = extract_column_name(&args[0])?;
410 let geom_arg = args.get(1).ok_or_else(|| SqlError::MissingField {
411 field: "query geometry".into(),
412 context: name.into(),
413 })?;
414 let geom_str = extract_geometry_arg(geom_arg)?;
415 let distance = if args.len() >= 3 {
416 extract_float(&args[2]).unwrap_or(0.0)
417 } else {
418 0.0
419 };
420 Ok(Some(SqlPlan::SpatialScan {
421 collection: table.name.clone(),
422 field,
423 predicate,
424 query_geometry: geom_str.into_bytes(),
425 distance_meters: distance,
426 attribute_filters: Vec::new(),
427 limit: 1000,
428 projection: Vec::new(),
429 }))
430}
431
432fn apply_order_by(
434 plan: &SqlPlan,
435 order_by: &ast::OrderBy,
436 functions: &FunctionRegistry,
437) -> Result<SqlPlan> {
438 let exprs = match &order_by.kind {
439 ast::OrderByKind::Expressions(exprs) => exprs,
440 ast::OrderByKind::All(_) => return Ok(plan.clone()),
441 };
442
443 if exprs.is_empty() {
444 return Ok(plan.clone());
445 }
446
447 let first = &exprs[0];
449 if let Some(search_plan) = try_extract_sort_search(&first.expr, plan, functions)? {
450 return Ok(search_plan);
451 }
452
453 let sort_keys: Vec<SortKey> = exprs
455 .iter()
456 .map(|o| {
457 Ok(SortKey {
458 expr: convert_expr(&o.expr)?,
459 ascending: o.options.asc.unwrap_or(true),
460 nulls_first: o.options.nulls_first.unwrap_or(false),
461 })
462 })
463 .collect::<Result<_>>()?;
464
465 match plan {
466 SqlPlan::Scan {
467 collection,
468 alias,
469 engine,
470 filters,
471 projection,
472 limit,
473 offset,
474 distinct,
475 window_functions,
476 ..
477 } => Ok(SqlPlan::Scan {
478 collection: collection.clone(),
479 alias: alias.clone(),
480 engine: *engine,
481 filters: filters.clone(),
482 projection: projection.clone(),
483 sort_keys,
484 limit: *limit,
485 offset: *offset,
486 distinct: *distinct,
487 window_functions: window_functions.clone(),
488 }),
489 _ => Ok(plan.clone()),
490 }
491}
492
493fn try_extract_sort_search(
495 expr: &ast::Expr,
496 plan: &SqlPlan,
497 functions: &FunctionRegistry,
498) -> Result<Option<SqlPlan>> {
499 if let ast::Expr::Function(func) = expr {
500 let name = func
501 .name
502 .0
503 .iter()
504 .map(|p| match p {
505 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
506 _ => String::new(),
507 })
508 .collect::<Vec<_>>()
509 .join(".");
510 let collection = match plan {
511 SqlPlan::Scan { collection, .. } => collection.clone(),
512 _ => return Ok(None),
513 };
514 let args = extract_func_args(func)?;
515
516 match functions.search_trigger(&name) {
517 SearchTrigger::VectorSearch => {
518 if args.len() < 2 {
519 return Ok(None);
520 }
521 let field = extract_column_name(&args[0])?;
522 let vector = extract_float_array(&args[1])?;
523 let limit = match plan {
524 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
525 _ => 10,
526 };
527 return Ok(Some(SqlPlan::VectorSearch {
528 collection,
529 field,
530 query_vector: vector,
531 top_k: limit,
532 ef_search: limit * 2,
533 filters: match plan {
534 SqlPlan::Scan { filters, .. } => filters.clone(),
535 _ => Vec::new(),
536 },
537 }));
538 }
539 SearchTrigger::TextSearch if args.len() >= 2 => {
540 let query_text = extract_string_literal(&args[1])?;
541 let limit = match plan {
542 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
543 _ => 10,
544 };
545 return Ok(Some(SqlPlan::TextSearch {
546 collection,
547 query: query_text,
548 top_k: limit,
549 fuzzy: true,
550 filters: match plan {
551 SqlPlan::Scan { filters, .. } => filters.clone(),
552 _ => Vec::new(),
553 },
554 }));
555 }
556 SearchTrigger::TextSearch => {}
557 SearchTrigger::HybridSearch => {
558 return plan_hybrid_from_sort(&args, &collection, plan, functions);
559 }
560 _ => {}
561 }
562 }
563 Ok(None)
564}
565
566fn plan_hybrid_from_sort(
567 args: &[ast::Expr],
568 collection: &str,
569 plan: &SqlPlan,
570 _functions: &FunctionRegistry,
571) -> Result<Option<SqlPlan>> {
572 if args.len() < 2 {
574 return Ok(None);
575 }
576 let vector = match &args[0] {
577 ast::Expr::Function(f) => {
578 let inner_args = extract_func_args(f)?;
579 if inner_args.len() >= 2 {
580 extract_float_array(&inner_args[1]).unwrap_or_default()
581 } else {
582 Vec::new()
583 }
584 }
585 _ => Vec::new(),
586 };
587 let text = match &args[1] {
588 ast::Expr::Function(f) => {
589 let inner_args = extract_func_args(f)?;
590 if inner_args.len() >= 2 {
591 extract_string_literal(&inner_args[1]).unwrap_or_default()
592 } else {
593 String::new()
594 }
595 }
596 _ => String::new(),
597 };
598 let k1 = args
599 .get(2)
600 .and_then(|e| extract_float(e).ok())
601 .unwrap_or(60.0);
602 let k2 = args
603 .get(3)
604 .and_then(|e| extract_float(e).ok())
605 .unwrap_or(60.0);
606 let limit = match plan {
607 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
608 _ => 10,
609 };
610 let vector_weight = k2 as f32 / (k1 as f32 + k2 as f32);
611
612 Ok(Some(SqlPlan::HybridSearch {
613 collection: collection.into(),
614 query_vector: vector,
615 query_text: text,
616 top_k: limit,
617 ef_search: limit * 2,
618 vector_weight,
619 fuzzy: true,
620 }))
621}
622
623fn apply_limit(mut plan: SqlPlan, limit_clause: &Option<ast::LimitClause>) -> SqlPlan {
625 let (limit_val, offset_val) = match limit_clause {
626 None => (None, 0usize),
627 Some(ast::LimitClause::LimitOffset { limit, offset, .. }) => {
628 let lv = limit
629 .as_ref()
630 .and_then(crate::coerce::expr_as_usize_literal);
631 let ov = offset
632 .as_ref()
633 .and_then(|o| crate::coerce::expr_as_usize_literal(&o.value))
634 .unwrap_or(0);
635 (lv, ov)
636 }
637 Some(ast::LimitClause::OffsetCommaLimit { offset, limit }) => {
638 let lv = crate::coerce::expr_as_usize_literal(limit);
639 let ov = crate::coerce::expr_as_usize_literal(offset).unwrap_or(0);
640 (lv, ov)
641 }
642 };
643
644 match plan {
645 SqlPlan::Scan {
646 ref mut limit,
647 ref mut offset,
648 ..
649 } => {
650 *limit = limit_val;
651 *offset = offset_val;
652 }
653 SqlPlan::Aggregate {
654 limit: ref mut l, ..
655 } => {
656 if let Some(lv) = limit_val {
657 *l = lv;
658 }
659 }
660 _ => {}
661 }
662 plan
663}
664
665pub fn convert_projection(items: &[ast::SelectItem]) -> Result<Vec<Projection>> {
669 let mut result = Vec::new();
670 for item in items {
671 match item {
672 ast::SelectItem::UnnamedExpr(expr) => {
673 let sql_expr = convert_expr(expr)?;
674 match &sql_expr {
675 SqlExpr::Column { table, name } => {
676 result.push(Projection::Column(qualified_name(table.as_deref(), name)));
677 }
678 SqlExpr::Wildcard => {
679 result.push(Projection::Star);
680 }
681 _ => {
682 result.push(Projection::Computed {
683 expr: sql_expr,
684 alias: format!("{expr}"),
685 });
686 }
687 }
688 }
689 ast::SelectItem::ExprWithAlias { expr, alias } => {
690 let sql_expr = convert_expr(expr)?;
691 result.push(Projection::Computed {
692 expr: sql_expr,
693 alias: normalize_ident(alias),
694 });
695 }
696 ast::SelectItem::Wildcard(_) => {
697 result.push(Projection::Star);
698 }
699 ast::SelectItem::QualifiedWildcard(kind, _) => {
700 let table_name = match kind {
701 ast::SelectItemQualifiedWildcardKind::ObjectName(name) => {
702 crate::parser::normalize::normalize_object_name(name)
703 }
704 _ => String::new(),
705 };
706 result.push(Projection::QualifiedStar(table_name));
707 }
708 }
709 }
710 Ok(result)
711}
712
/// Join an optional table qualifier and a column name as `table.name`.
///
/// Without a qualifier the bare `name` is returned.
pub fn qualified_name(table: Option<&str>, name: &str) -> String {
    match table {
        Some(qualifier) => format!("{qualifier}.{name}"),
        None => name.to_owned(),
    }
}
717
718pub fn convert_where_to_filters(expr: &ast::Expr) -> Result<Vec<Filter>> {
720 let sql_expr = convert_expr(expr)?;
721 Ok(vec![Filter {
722 expr: FilterExpr::Expr(sql_expr),
723 }])
724}
725
726pub(crate) fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>> {
727 match &func.args {
728 ast::FunctionArguments::List(args) => Ok(args
729 .args
730 .iter()
731 .filter_map(|a| match a {
732 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e.clone()),
733 _ => None,
734 })
735 .collect()),
736 _ => Ok(Vec::new()),
737 }
738}
739
740fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
745 super::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
746}
747
748fn extract_geometry_arg(expr: &ast::Expr) -> Result<String> {
751 match expr {
752 ast::Expr::Function(func) => {
754 let name = func
755 .name
756 .0
757 .iter()
758 .map(|p| match p {
759 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
760 _ => String::new(),
761 })
762 .collect::<Vec<_>>()
763 .join(".");
764 let args = extract_func_args(func)?;
765 match name.as_str() {
766 "st_point" if args.len() >= 2 => {
767 let lon = extract_float(&args[0])?;
768 let lat = extract_float(&args[1])?;
769 Ok(format!(r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#))
770 }
771 "st_geomfromgeojson" if !args.is_empty() => extract_string_literal(&args[0]),
772 _ => Ok(format!("{expr}")),
773 }
774 }
775 _ => extract_string_literal(expr).or_else(|_| Ok(format!("{expr}"))),
777 }
778}
779
780fn extract_column_name(expr: &ast::Expr) -> Result<String> {
781 match expr {
782 ast::Expr::Identifier(ident) => Ok(normalize_ident(ident)),
783 ast::Expr::CompoundIdentifier(parts) => Ok(parts
784 .iter()
785 .map(normalize_ident)
786 .collect::<Vec<_>>()
787 .join(".")),
788 _ => Err(SqlError::Unsupported {
789 detail: format!("expected column name, got: {expr}"),
790 }),
791 }
792}
793
794pub(crate) fn extract_string_literal(expr: &ast::Expr) -> Result<String> {
795 match expr {
796 ast::Expr::Value(v) => match &v.value {
797 ast::Value::SingleQuotedString(s) | ast::Value::DoubleQuotedString(s) => Ok(s.clone()),
798 _ => Err(SqlError::Unsupported {
799 detail: format!("expected string literal, got: {expr}"),
800 }),
801 },
802 _ => Err(SqlError::Unsupported {
803 detail: format!("expected string literal, got: {expr}"),
804 }),
805 }
806}
807
808pub(crate) fn extract_float(expr: &ast::Expr) -> Result<f64> {
809 match expr {
810 ast::Expr::Value(v) => match &v.value {
811 ast::Value::Number(n, _) => n.parse::<f64>().map_err(|_| SqlError::TypeMismatch {
812 detail: format!("expected number: {n}"),
813 }),
814 _ => Err(SqlError::TypeMismatch {
815 detail: format!("expected number, got: {expr}"),
816 }),
817 },
818 ast::Expr::UnaryOp {
820 op: ast::UnaryOperator::Minus,
821 expr: inner,
822 } => extract_float(inner).map(|f| -f),
823 _ => Err(SqlError::TypeMismatch {
824 detail: format!("expected number, got: {expr}"),
825 }),
826 }
827}
828
829fn extract_float_array(expr: &ast::Expr) -> Result<Vec<f32>> {
831 match expr {
832 ast::Expr::Array(ast::Array { elem, .. }) => elem
833 .iter()
834 .map(|e| extract_float(e).map(|f| f as f32))
835 .collect(),
836 ast::Expr::Function(func) => {
837 let name = func
838 .name
839 .0
840 .iter()
841 .map(|p| match p {
842 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
843 _ => String::new(),
844 })
845 .collect::<Vec<_>>()
846 .join(".");
847 if name == "make_array" || name == "array" {
848 let args = extract_func_args(func)?;
849 args.iter()
850 .map(|e| extract_float(e).map(|f| f as f32))
851 .collect()
852 } else {
853 Err(SqlError::Unsupported {
854 detail: format!("expected array, got function: {name}"),
855 })
856 }
857 }
858 _ => Err(SqlError::Unsupported {
859 detail: format!("expected array literal, got: {expr}"),
860 }),
861 }
862}
863
864fn try_plan_join(
866 select: &Select,
867 scope: &TableScope,
868 catalog: &dyn SqlCatalog,
869 functions: &FunctionRegistry,
870) -> Result<Option<SqlPlan>> {
871 if select.from.len() != 1 {
872 return Ok(None);
873 }
874 let from = &select.from[0];
875 if from.joins.is_empty() {
876 return Ok(None);
877 }
878 super::join::plan_join_from_select(select, scope, catalog, functions)
879}
880
881struct CteCatalog<'a> {
883 inner: &'a dyn SqlCatalog,
884 cte_names: Vec<String>,
885}
886
887impl SqlCatalog for CteCatalog<'_> {
888 fn get_collection(
889 &self,
890 name: &str,
891 ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
892 if self.cte_names.iter().any(|n| n == name) {
894 return Ok(Some(CollectionInfo {
895 name: name.into(),
896 engine: EngineType::DocumentSchemaless,
897 columns: Vec::new(),
898 primary_key: Some("id".into()),
899 has_auto_tier: false,
900 indexes: Vec::new(),
901 }));
902 }
903 self.inner.get_collection(name)
904 }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::registry::FunctionRegistry;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// Fixed catalog: five schemaless document collections keyed by `id`, plus the
    /// key-value collection `user_prefs` keyed by `key`.
    struct TestCatalog;

    /// Build a column-less, index-less `CollectionInfo` for the test catalog.
    fn test_collection(name: &str, engine: EngineType, primary_key: &str) -> CollectionInfo {
        CollectionInfo {
            name: name.into(),
            engine,
            columns: Vec::new(),
            primary_key: Some(primary_key.into()),
            has_auto_tier: false,
            indexes: Vec::new(),
        }
    }

    impl SqlCatalog for TestCatalog {
        fn get_collection(
            &self,
            name: &str,
        ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
            Ok(match name {
                "products" | "users" | "orders" | "docs" | "tags" => Some(test_collection(
                    name,
                    EngineType::DocumentSchemaless,
                    "id",
                )),
                "user_prefs" => Some(test_collection(name, EngineType::KeyValue, "key")),
                _ => None,
            })
        }
    }

    /// Parse `sql`, expect a single query statement, and plan it against `TestCatalog`.
    fn plan_select_sql(sql: &str) -> SqlPlan {
        let statements = parse_sql(sql).unwrap();
        match &statements[0] {
            Statement::Query(query) => {
                plan_query(query, &TestCatalog, &FunctionRegistry::new()).unwrap()
            }
            _ => panic!("expected query statement"),
        }
    }

    #[test]
    fn aggregate_subquery_join_filters_input_before_aggregation() {
        let plan = plan_select_sql(
            "SELECT AVG(price) FROM products WHERE category IN (SELECT DISTINCT category FROM products WHERE qty > 100)",
        );

        let input = match plan {
            SqlPlan::Aggregate { input, .. } => input,
            _ => panic!("expected aggregate plan"),
        };

        match *input {
            SqlPlan::Join {
                left,
                join_type,
                on,
                ..
            } => {
                assert_eq!(join_type, JoinType::Semi);
                assert_eq!(on, vec![("category".into(), "category".into())]);
                assert!(matches!(*left, SqlPlan::Scan { .. }));
            }
            _ => panic!("expected semi-join below aggregate"),
        }
    }

    #[test]
    fn scalar_subquery_defers_projection_until_after_join_filter() {
        let plan = plan_select_sql(
            "SELECT user_id FROM orders WHERE amount > (SELECT AVG(amount) FROM orders)",
        );

        let (left, projection, filters) = match plan {
            SqlPlan::Join {
                left,
                projection,
                filters,
                ..
            } => (left, projection, filters),
            _ => panic!("expected join plan"),
        };

        let scan_projection = match *left {
            SqlPlan::Scan { projection, .. } => projection,
            _ => panic!("expected scan on join left"),
        };

        assert!(scan_projection.is_empty(), "scan projected too early");
        assert_eq!(projection.len(), 1);
        match &projection[0] {
            Projection::Column(name) => assert_eq!(name, "user_id"),
            other => panic!("expected user_id projection, got {other:?}"),
        }
        assert!(
            !filters.is_empty(),
            "scalar comparison should stay post-join"
        );
    }

    #[test]
    fn chained_join_preserves_qualified_on_keys() {
        let plan = plan_select_sql(
            "SELECT d.name, t.tag, p.theme \
             FROM docs d \
             LEFT JOIN tags t ON d.id = t.doc_id \
             INNER JOIN user_prefs p ON d.id = p.key",
        );

        let (left, on) = match plan {
            SqlPlan::Join { left, on, .. } => (left, on),
            _ => panic!("expected outer join plan"),
        };
        assert_eq!(on, vec![("d.id".into(), "p.key".into())]);

        let inner_on = match *left {
            SqlPlan::Join { on, .. } => on,
            _ => panic!("expected nested left join"),
        };
        assert_eq!(inner_on, vec![("d.id".into(), "t.doc_id".into())]);
    }
}