1use sqlparser::ast::{self, Query, Select, SetExpr};
8
9use crate::error::{Result, SqlError};
10use crate::functions::registry::{FunctionRegistry, SearchTrigger};
11use crate::parser::normalize::normalize_ident;
12use crate::resolver::columns::TableScope;
13use crate::resolver::expr::convert_expr;
14use crate::types::*;
15
16pub fn plan_query(
18 query: &Query,
19 catalog: &dyn SqlCatalog,
20 functions: &FunctionRegistry,
21) -> Result<SqlPlan> {
22 if let Some(with) = &query.with
24 && with.recursive
25 {
26 return super::cte::plan_recursive_cte(query, catalog, functions);
27 }
28 if let Some(with) = &query.with
30 && !with.cte_tables.is_empty()
31 {
32 let inner_query = Query {
33 with: None,
34 body: query.body.clone(),
35 order_by: query.order_by.clone(),
36 limit_clause: query.limit_clause.clone(),
37 fetch: query.fetch.clone(),
38 locks: query.locks.clone(),
39 for_clause: query.for_clause.clone(),
40 settings: query.settings.clone(),
41 format_clause: query.format_clause.clone(),
42 pipe_operators: query.pipe_operators.clone(),
43 };
44
45 let mut definitions = Vec::new();
47 let mut cte_names = Vec::new();
48 for cte in &with.cte_tables {
49 let name = normalize_ident(&cte.alias.name);
50 let cte_plan = plan_query(&cte.query, catalog, functions)?;
51 definitions.push((name.clone(), cte_plan));
52 cte_names.push(name);
53 }
54
55 let cte_catalog = CteCatalog {
57 inner: catalog,
58 cte_names,
59 };
60 let outer = plan_query(&inner_query, &cte_catalog, functions)?;
61
62 return Ok(SqlPlan::Cte {
63 definitions,
64 outer: Box::new(outer),
65 });
66 }
67
68 match &*query.body {
70 SetExpr::Select(select) => {
71 let mut plan = plan_select(select, catalog, functions)?;
72 if let Some(order_by) = &query.order_by {
73 plan = apply_order_by(&plan, order_by, functions)?;
74 }
75 plan = apply_limit(plan, &query.limit_clause);
76 Ok(plan)
77 }
78 SetExpr::SetOperation {
79 op,
80 left,
81 right,
82 set_quantifier,
83 } => super::union::plan_set_operation(op, left, right, set_quantifier, catalog, functions),
84 _ => Err(SqlError::Unsupported {
85 detail: format!("query body type: {}", query.body),
86 }),
87 }
88}
89
/// Plans a single `SELECT`. The query-level `ORDER BY` / `LIMIT` are applied
/// afterwards by `plan_query`.
///
/// Stages, each of which may return early:
/// 1. No `FROM`: the projection is evaluated once into a
///    `SqlPlan::ConstantResult`. Computed items are constant-folded via
///    `eval_constant_expr`; bare columns and wildcard items become `NULL`.
/// 2. A single `FROM` item with explicit JOINs is handed to `try_plan_join`.
/// 3. Otherwise exactly one table is required; if not, the result is
///    `SqlError::Unsupported` ("multi-table FROM without JOIN").
///    - Subqueries in `WHERE` are extracted into joins by
///      `super::subquery::extract_subqueries`.
///    - A search predicate found in the remaining `WHERE` short-circuits into a
///      search plan via `try_extract_where_search`.
/// 4. Aggregating selects (`has_aggregation`) go to
///    `super::aggregate::plan_aggregate`.
///    - Non-cross subquery joins are spliced between the aggregate and its input,
///      so they filter rows before aggregation.
///    - Cross subquery joins wrap the aggregate.
/// 5. Everything else becomes an engine-specific scan (`engine_rules`), wrapped
///    in one join per extracted subquery.
///
/// NOTE(review): the search short-circuit in stage 3 builds its plan from the
/// search predicate alone. Other `WHERE` conjuncts, extracted subquery joins and
/// the projection are not applied to it. Confirm this is intended.
fn plan_select(
    select: &Select,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<SqlPlan> {
    // Resolved up front (even for an empty FROM); the join and aggregate paths use it.
    let scope = TableScope::resolve_from(catalog, &select.from)?;

    if select.from.is_empty() {
        let projection = convert_projection(&select.projection)?;
        let mut columns = Vec::new();
        let mut values = Vec::new();
        for (i, proj) in projection.iter().enumerate() {
            match proj {
                Projection::Computed { expr, alias } => {
                    columns.push(alias.clone());
                    values.push(eval_constant_expr(expr, functions));
                }
                Projection::Column(name) => {
                    columns.push(name.clone());
                    values.push(SqlValue::Null);
                }
                // Wildcards (and anything else) get a positional name and NULL.
                _ => {
                    columns.push(format!("col{i}"));
                    values.push(SqlValue::Null);
                }
            }
        }
        return Ok(SqlPlan::ConstantResult { columns, values });
    }

    if let Some(plan) = try_plan_join(select, &scope, catalog, functions)? {
        return Ok(plan);
    }

    let table = scope.single_table().ok_or_else(|| SqlError::Unsupported {
        detail: "multi-table FROM without JOIN".into(),
    })?;

    // Split WHERE into subquery joins plus whatever predicate remains.
    let (subquery_joins, effective_where) = if let Some(expr) = &select.selection {
        let extraction = super::subquery::extract_subqueries(expr, catalog, functions)?;
        (extraction.joins, extraction.remaining_where)
    } else {
        (Vec::new(), None)
    };

    let filters = match &effective_where {
        Some(expr) => {
            // A text/spatial search predicate replaces the whole plan (see NOTE above).
            if let Some(plan) = try_extract_where_search(expr, table, functions)? {
                return Ok(plan);
            }
            convert_where_to_filters(expr)?
        }
        None => Vec::new(),
    };

    if has_aggregation(select, functions) {
        let mut plan =
            super::aggregate::plan_aggregate(select, table, &filters, &scope, functions)?;

        if let SqlPlan::Aggregate { input, .. } = &mut plan {
            // Move the real input out, leaving an empty placeholder, so it can be
            // re-wrapped in the non-cross subquery joins below the aggregate.
            let mut base_input = std::mem::replace(
                input,
                Box::new(SqlPlan::ConstantResult {
                    columns: Vec::new(),
                    values: Vec::new(),
                }),
            );
            for sq in subquery_joins
                .iter()
                .filter(|sq| sq.join_type != JoinType::Cross)
            {
                base_input = Box::new(SqlPlan::Join {
                    left: base_input,
                    right: Box::new(sq.inner_plan.clone()),
                    on: vec![(sq.outer_column.clone(), sq.inner_column.clone())],
                    join_type: sq.join_type,
                    condition: None,
                    // NOTE(review): fixed join `limit` of 10000, presumably a row cap. Confirm.
                    limit: 10000,
                    projection: Vec::new(),
                    filters: Vec::new(),
                });
            }
            *input = base_input;
        }

        // Cross joins are applied on top of the aggregate.
        for sq in subquery_joins
            .into_iter()
            .filter(|sq| sq.join_type == JoinType::Cross)
        {
            plan = SqlPlan::Join {
                left: Box::new(plan),
                right: Box::new(sq.inner_plan),
                on: vec![(sq.outer_column, sq.inner_column)],
                join_type: sq.join_type,
                condition: None,
                limit: 10000,
                projection: Vec::new(),
                filters: Vec::new(),
            };
        }
        return Ok(plan);
    }

    let projection = convert_projection(&select.projection)?;

    let window_functions = super::window::extract_window_functions(&select.projection, functions)?;

    // With subquery joins pending, the scan gets an empty projection (presumably
    // "all columns") and the real projection is applied on the outermost join.
    let scan_projection = if subquery_joins.is_empty() {
        projection.clone()
    } else {
        Vec::new()
    };

    let rules = crate::engine_rules::resolve_engine_rules(table.info.engine);
    let mut plan = rules.plan_scan(crate::engine_rules::ScanParams {
        collection: table.name.clone(),
        alias: table.alias.clone(),
        filters,
        projection: scan_projection,
        sort_keys: Vec::new(),
        limit: None,
        offset: 0,
        distinct: select.distinct.is_some(),
        window_functions,
        indexes: table.info.indexes.clone(),
    })?;

    for sq in subquery_joins {
        // For a cross join, scan filters that compare two columns are moved off the
        // scan and onto the join so they run post-join. This only happens while
        // `plan` is still the bare scan, i.e. for the first subquery join.
        let join_filters = if sq.join_type == JoinType::Cross {
            if let SqlPlan::Scan {
                ref mut filters, ..
            } = plan
            {
                let mut moved = Vec::new();
                filters.retain(|f| {
                    if has_column_ref_filter(&f.expr) {
                        moved.push(f.clone());
                        false
                    } else {
                        true
                    }
                });
                moved
            } else {
                Vec::new()
            }
        } else {
            Vec::new()
        };

        plan = SqlPlan::Join {
            left: Box::new(plan),
            right: Box::new(sq.inner_plan),
            on: vec![(sq.outer_column, sq.inner_column)],
            join_type: sq.join_type,
            condition: None,
            limit: 10000,
            projection: Vec::new(),
            filters: join_filters,
        };
    }

    // The deferred projection lands on the outermost join (no-op for a bare scan,
    // which already received it via `scan_projection`).
    if let SqlPlan::Join {
        projection: ref mut join_projection,
        ..
    } = plan
    {
        *join_projection = projection;
    }

    Ok(plan)
}
284
285fn has_column_ref_filter(expr: &FilterExpr) -> bool {
289 match expr {
290 FilterExpr::Expr(sql_expr) => has_column_comparison(sql_expr),
291 FilterExpr::And(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
292 FilterExpr::Or(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
293 _ => false,
294 }
295}
296
297fn has_column_comparison(expr: &SqlExpr) -> bool {
298 match expr {
299 SqlExpr::BinaryOp { left, right, .. } => {
300 let left_is_col = matches!(left.as_ref(), SqlExpr::Column { .. });
301 let right_is_col = matches!(right.as_ref(), SqlExpr::Column { .. });
302 if left_is_col && right_is_col {
303 return true;
304 }
305 has_column_comparison(left) || has_column_comparison(right)
306 }
307 _ => false,
308 }
309}
310
311fn has_aggregation(select: &Select, functions: &FunctionRegistry) -> bool {
313 let group_by_non_empty = match &select.group_by {
314 ast::GroupByExpr::All(_) => true,
315 ast::GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
316 };
317 if group_by_non_empty {
318 return true;
319 }
320 for item in &select.projection {
321 if let ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } =
322 item
323 && expr_contains_aggregate(expr, functions)
324 {
325 return true;
326 }
327 }
328 false
329}
330
331fn expr_contains_aggregate(expr: &ast::Expr, functions: &FunctionRegistry) -> bool {
333 match expr {
334 ast::Expr::Function(func) => {
335 let name = func
336 .name
337 .0
338 .iter()
339 .map(|p| match p {
340 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
341 _ => String::new(),
342 })
343 .collect::<Vec<_>>()
344 .join(".");
345 if functions.is_aggregate(&name) {
346 return true;
347 }
348 if let ast::FunctionArguments::List(args) = &func.args {
350 for arg in &args.args {
351 if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) = arg
352 && expr_contains_aggregate(e, functions)
353 {
354 return true;
355 }
356 }
357 }
358 false
359 }
360 ast::Expr::BinaryOp { left, right, .. } => {
361 expr_contains_aggregate(left, functions) || expr_contains_aggregate(right, functions)
362 }
363 ast::Expr::Nested(inner) => expr_contains_aggregate(inner, functions),
364 _ => false,
365 }
366}
367
368fn try_extract_where_search(
370 expr: &ast::Expr,
371 table: &crate::resolver::columns::ResolvedTable,
372 functions: &FunctionRegistry,
373) -> Result<Option<SqlPlan>> {
374 match expr {
375 ast::Expr::Function(func) => {
376 let name = func
377 .name
378 .0
379 .iter()
380 .map(|p| match p {
381 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
382 _ => String::new(),
383 })
384 .collect::<Vec<_>>()
385 .join(".");
386 match functions.search_trigger(&name) {
387 SearchTrigger::TextMatch => {
388 let args = extract_func_args(func)?;
389 if args.len() >= 2 {
390 let query_text = extract_string_literal(&args[1])?;
391 return Ok(Some(SqlPlan::TextSearch {
392 collection: table.name.clone(),
393 query: query_text,
394 top_k: 1000,
395 fuzzy: true,
396 filters: Vec::new(),
397 }));
398 }
399 }
400 SearchTrigger::SpatialDWithin
401 | SearchTrigger::SpatialContains
402 | SearchTrigger::SpatialIntersects
403 | SearchTrigger::SpatialWithin => {
404 return plan_spatial_from_where(&name, func, table);
405 }
406 _ => {}
407 }
408 }
409 ast::Expr::BinaryOp {
411 left,
412 op: ast::BinaryOperator::And,
413 right,
414 } => {
415 if let Some(plan) = try_extract_where_search(left, table, functions)? {
416 return Ok(Some(plan));
417 }
418 if let Some(plan) = try_extract_where_search(right, table, functions)? {
419 return Ok(Some(plan));
420 }
421 }
422 _ => {}
423 }
424 Ok(None)
425}
426
427fn plan_spatial_from_where(
428 name: &str,
429 func: &ast::Function,
430 table: &crate::resolver::columns::ResolvedTable,
431) -> Result<Option<SqlPlan>> {
432 let predicate = match name {
433 "st_dwithin" => SpatialPredicate::DWithin,
434 "st_contains" => SpatialPredicate::Contains,
435 "st_intersects" => SpatialPredicate::Intersects,
436 "st_within" => SpatialPredicate::Within,
437 _ => return Ok(None),
438 };
439 let args = extract_func_args(func)?;
440 if args.is_empty() {
441 return Err(SqlError::MissingField {
442 field: "geometry column".into(),
443 context: name.into(),
444 });
445 }
446 let field = extract_column_name(&args[0])?;
447 let geom_arg = args.get(1).ok_or_else(|| SqlError::MissingField {
448 field: "query geometry".into(),
449 context: name.into(),
450 })?;
451 let geom_str = extract_geometry_arg(geom_arg)?;
452 let distance = if args.len() >= 3 {
453 extract_float(&args[2]).unwrap_or(0.0)
454 } else {
455 0.0
456 };
457 Ok(Some(SqlPlan::SpatialScan {
458 collection: table.name.clone(),
459 field,
460 predicate,
461 query_geometry: geom_str.into_bytes(),
462 distance_meters: distance,
463 attribute_filters: Vec::new(),
464 limit: 1000,
465 projection: Vec::new(),
466 }))
467}
468
469fn apply_order_by(
471 plan: &SqlPlan,
472 order_by: &ast::OrderBy,
473 functions: &FunctionRegistry,
474) -> Result<SqlPlan> {
475 let exprs = match &order_by.kind {
476 ast::OrderByKind::Expressions(exprs) => exprs,
477 ast::OrderByKind::All(_) => return Ok(plan.clone()),
478 };
479
480 if exprs.is_empty() {
481 return Ok(plan.clone());
482 }
483
484 let first = &exprs[0];
486 if let Some(search_plan) = try_extract_sort_search(&first.expr, plan, functions)? {
487 return Ok(search_plan);
488 }
489
490 let sort_keys: Vec<SortKey> = exprs
492 .iter()
493 .map(|o| {
494 Ok(SortKey {
495 expr: convert_expr(&o.expr)?,
496 ascending: o.options.asc.unwrap_or(true),
497 nulls_first: o.options.nulls_first.unwrap_or(false),
498 })
499 })
500 .collect::<Result<_>>()?;
501
502 match plan {
503 SqlPlan::Scan {
504 collection,
505 alias,
506 engine,
507 filters,
508 projection,
509 limit,
510 offset,
511 distinct,
512 window_functions,
513 ..
514 } => Ok(SqlPlan::Scan {
515 collection: collection.clone(),
516 alias: alias.clone(),
517 engine: *engine,
518 filters: filters.clone(),
519 projection: projection.clone(),
520 sort_keys,
521 limit: *limit,
522 offset: *offset,
523 distinct: *distinct,
524 window_functions: window_functions.clone(),
525 }),
526 _ => Ok(plan.clone()),
527 }
528}
529
530fn try_extract_sort_search(
532 expr: &ast::Expr,
533 plan: &SqlPlan,
534 functions: &FunctionRegistry,
535) -> Result<Option<SqlPlan>> {
536 if let ast::Expr::Function(func) = expr {
537 let name = func
538 .name
539 .0
540 .iter()
541 .map(|p| match p {
542 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
543 _ => String::new(),
544 })
545 .collect::<Vec<_>>()
546 .join(".");
547 let collection = match plan {
548 SqlPlan::Scan { collection, .. } => collection.clone(),
549 _ => return Ok(None),
550 };
551 let args = extract_func_args(func)?;
552
553 match functions.search_trigger(&name) {
554 SearchTrigger::VectorSearch => {
555 if args.len() < 2 {
556 return Ok(None);
557 }
558 let field = extract_column_name(&args[0])?;
559 let vector = extract_float_array(&args[1])?;
560 let limit = match plan {
561 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
562 _ => 10,
563 };
564 return Ok(Some(SqlPlan::VectorSearch {
565 collection,
566 field,
567 query_vector: vector,
568 top_k: limit,
569 ef_search: limit * 2,
570 filters: match plan {
571 SqlPlan::Scan { filters, .. } => filters.clone(),
572 _ => Vec::new(),
573 },
574 }));
575 }
576 SearchTrigger::TextSearch if args.len() >= 2 => {
577 let query_text = extract_string_literal(&args[1])?;
578 let limit = match plan {
579 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
580 _ => 10,
581 };
582 return Ok(Some(SqlPlan::TextSearch {
583 collection,
584 query: query_text,
585 top_k: limit,
586 fuzzy: true,
587 filters: match plan {
588 SqlPlan::Scan { filters, .. } => filters.clone(),
589 _ => Vec::new(),
590 },
591 }));
592 }
593 SearchTrigger::TextSearch => {}
594 SearchTrigger::HybridSearch => {
595 return plan_hybrid_from_sort(&args, &collection, plan, functions);
596 }
597 _ => {}
598 }
599 }
600 Ok(None)
601}
602
603fn plan_hybrid_from_sort(
604 args: &[ast::Expr],
605 collection: &str,
606 plan: &SqlPlan,
607 _functions: &FunctionRegistry,
608) -> Result<Option<SqlPlan>> {
609 if args.len() < 2 {
611 return Ok(None);
612 }
613 let vector = match &args[0] {
614 ast::Expr::Function(f) => {
615 let inner_args = extract_func_args(f)?;
616 if inner_args.len() >= 2 {
617 extract_float_array(&inner_args[1]).unwrap_or_default()
618 } else {
619 Vec::new()
620 }
621 }
622 _ => Vec::new(),
623 };
624 let text = match &args[1] {
625 ast::Expr::Function(f) => {
626 let inner_args = extract_func_args(f)?;
627 if inner_args.len() >= 2 {
628 extract_string_literal(&inner_args[1]).unwrap_or_default()
629 } else {
630 String::new()
631 }
632 }
633 _ => String::new(),
634 };
635 let k1 = args
636 .get(2)
637 .and_then(|e| extract_float(e).ok())
638 .unwrap_or(60.0);
639 let k2 = args
640 .get(3)
641 .and_then(|e| extract_float(e).ok())
642 .unwrap_or(60.0);
643 let limit = match plan {
644 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
645 _ => 10,
646 };
647 let vector_weight = k2 as f32 / (k1 as f32 + k2 as f32);
648
649 Ok(Some(SqlPlan::HybridSearch {
650 collection: collection.into(),
651 query_vector: vector,
652 query_text: text,
653 top_k: limit,
654 ef_search: limit * 2,
655 vector_weight,
656 fuzzy: true,
657 }))
658}
659
660fn apply_limit(mut plan: SqlPlan, limit_clause: &Option<ast::LimitClause>) -> SqlPlan {
662 let (limit_val, offset_val) = match limit_clause {
663 None => (None, 0usize),
664 Some(ast::LimitClause::LimitOffset { limit, offset, .. }) => {
665 let lv = limit.as_ref().and_then(|e| match e {
666 ast::Expr::Value(v) => match &v.value {
667 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
668 _ => None,
669 },
670 _ => None,
671 });
672 let ov = offset
673 .as_ref()
674 .and_then(|o| match &o.value {
675 ast::Expr::Value(v) => match &v.value {
676 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
677 _ => None,
678 },
679 _ => None,
680 })
681 .unwrap_or(0);
682 (lv, ov)
683 }
684 Some(ast::LimitClause::OffsetCommaLimit { offset, limit }) => {
685 let lv = match limit {
686 ast::Expr::Value(v) => match &v.value {
687 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
688 _ => None,
689 },
690 _ => None,
691 };
692 let ov = match offset {
693 ast::Expr::Value(v) => match &v.value {
694 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
695 _ => None,
696 },
697 _ => None,
698 }
699 .unwrap_or(0);
700 (lv, ov)
701 }
702 };
703
704 match plan {
705 SqlPlan::Scan {
706 ref mut limit,
707 ref mut offset,
708 ..
709 } => {
710 *limit = limit_val;
711 *offset = offset_val;
712 }
713 SqlPlan::Aggregate {
714 limit: ref mut l, ..
715 } => {
716 if let Some(lv) = limit_val {
717 *l = lv;
718 }
719 }
720 _ => {}
721 }
722 plan
723}
724
725pub fn convert_projection(items: &[ast::SelectItem]) -> Result<Vec<Projection>> {
729 let mut result = Vec::new();
730 for item in items {
731 match item {
732 ast::SelectItem::UnnamedExpr(expr) => {
733 let sql_expr = convert_expr(expr)?;
734 match &sql_expr {
735 SqlExpr::Column { table, name } => {
736 result.push(Projection::Column(qualified_name(table.as_deref(), name)));
737 }
738 SqlExpr::Wildcard => {
739 result.push(Projection::Star);
740 }
741 _ => {
742 result.push(Projection::Computed {
743 expr: sql_expr,
744 alias: format!("{expr}"),
745 });
746 }
747 }
748 }
749 ast::SelectItem::ExprWithAlias { expr, alias } => {
750 let sql_expr = convert_expr(expr)?;
751 result.push(Projection::Computed {
752 expr: sql_expr,
753 alias: normalize_ident(alias),
754 });
755 }
756 ast::SelectItem::Wildcard(_) => {
757 result.push(Projection::Star);
758 }
759 ast::SelectItem::QualifiedWildcard(kind, _) => {
760 let table_name = match kind {
761 ast::SelectItemQualifiedWildcardKind::ObjectName(name) => {
762 crate::parser::normalize::normalize_object_name(name)
763 }
764 _ => String::new(),
765 };
766 result.push(Projection::QualifiedStar(table_name));
767 }
768 }
769 }
770 Ok(result)
771}
772
/// Joins an optional table qualifier and a column name as `table.name`.
///
/// Without a qualifier, the bare `name` is returned.
pub fn qualified_name(table: Option<&str>, name: &str) -> String {
    match table {
        Some(prefix) => format!("{prefix}.{name}"),
        None => name.to_owned(),
    }
}
777
778pub fn convert_where_to_filters(expr: &ast::Expr) -> Result<Vec<Filter>> {
780 let sql_expr = convert_expr(expr)?;
781 Ok(vec![Filter {
782 expr: FilterExpr::Expr(sql_expr),
783 }])
784}
785
786pub(crate) fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>> {
787 match &func.args {
788 ast::FunctionArguments::List(args) => Ok(args
789 .args
790 .iter()
791 .filter_map(|a| match a {
792 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e.clone()),
793 _ => None,
794 })
795 .collect()),
796 _ => Ok(Vec::new()),
797 }
798}
799
800fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
805 super::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
806}
807
808fn extract_geometry_arg(expr: &ast::Expr) -> Result<String> {
811 match expr {
812 ast::Expr::Function(func) => {
814 let name = func
815 .name
816 .0
817 .iter()
818 .map(|p| match p {
819 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
820 _ => String::new(),
821 })
822 .collect::<Vec<_>>()
823 .join(".");
824 let args = extract_func_args(func)?;
825 match name.as_str() {
826 "st_point" if args.len() >= 2 => {
827 let lon = extract_float(&args[0])?;
828 let lat = extract_float(&args[1])?;
829 Ok(format!(r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#))
830 }
831 "st_geomfromgeojson" if !args.is_empty() => extract_string_literal(&args[0]),
832 _ => Ok(format!("{expr}")),
833 }
834 }
835 _ => extract_string_literal(expr).or_else(|_| Ok(format!("{expr}"))),
837 }
838}
839
840fn extract_column_name(expr: &ast::Expr) -> Result<String> {
841 match expr {
842 ast::Expr::Identifier(ident) => Ok(normalize_ident(ident)),
843 ast::Expr::CompoundIdentifier(parts) => Ok(parts
844 .iter()
845 .map(normalize_ident)
846 .collect::<Vec<_>>()
847 .join(".")),
848 _ => Err(SqlError::Unsupported {
849 detail: format!("expected column name, got: {expr}"),
850 }),
851 }
852}
853
854pub(crate) fn extract_string_literal(expr: &ast::Expr) -> Result<String> {
855 match expr {
856 ast::Expr::Value(v) => match &v.value {
857 ast::Value::SingleQuotedString(s) | ast::Value::DoubleQuotedString(s) => Ok(s.clone()),
858 _ => Err(SqlError::Unsupported {
859 detail: format!("expected string literal, got: {expr}"),
860 }),
861 },
862 _ => Err(SqlError::Unsupported {
863 detail: format!("expected string literal, got: {expr}"),
864 }),
865 }
866}
867
868pub(crate) fn extract_float(expr: &ast::Expr) -> Result<f64> {
869 match expr {
870 ast::Expr::Value(v) => match &v.value {
871 ast::Value::Number(n, _) => n.parse::<f64>().map_err(|_| SqlError::TypeMismatch {
872 detail: format!("expected number: {n}"),
873 }),
874 _ => Err(SqlError::TypeMismatch {
875 detail: format!("expected number, got: {expr}"),
876 }),
877 },
878 ast::Expr::UnaryOp {
880 op: ast::UnaryOperator::Minus,
881 expr: inner,
882 } => extract_float(inner).map(|f| -f),
883 _ => Err(SqlError::TypeMismatch {
884 detail: format!("expected number, got: {expr}"),
885 }),
886 }
887}
888
889fn extract_float_array(expr: &ast::Expr) -> Result<Vec<f32>> {
891 match expr {
892 ast::Expr::Array(ast::Array { elem, .. }) => elem
893 .iter()
894 .map(|e| extract_float(e).map(|f| f as f32))
895 .collect(),
896 ast::Expr::Function(func) => {
897 let name = func
898 .name
899 .0
900 .iter()
901 .map(|p| match p {
902 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
903 _ => String::new(),
904 })
905 .collect::<Vec<_>>()
906 .join(".");
907 if name == "make_array" || name == "array" {
908 let args = extract_func_args(func)?;
909 args.iter()
910 .map(|e| extract_float(e).map(|f| f as f32))
911 .collect()
912 } else {
913 Err(SqlError::Unsupported {
914 detail: format!("expected array, got function: {name}"),
915 })
916 }
917 }
918 _ => Err(SqlError::Unsupported {
919 detail: format!("expected array literal, got: {expr}"),
920 }),
921 }
922}
923
924fn try_plan_join(
926 select: &Select,
927 scope: &TableScope,
928 catalog: &dyn SqlCatalog,
929 functions: &FunctionRegistry,
930) -> Result<Option<SqlPlan>> {
931 if select.from.len() != 1 {
932 return Ok(None);
933 }
934 let from = &select.from[0];
935 if from.joins.is_empty() {
936 return Ok(None);
937 }
938 super::join::plan_join_from_select(select, scope, catalog, functions)
939}
940
941struct CteCatalog<'a> {
943 inner: &'a dyn SqlCatalog,
944 cte_names: Vec<String>,
945}
946
947impl SqlCatalog for CteCatalog<'_> {
948 fn get_collection(
949 &self,
950 name: &str,
951 ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
952 if self.cte_names.iter().any(|n| n == name) {
954 return Ok(Some(CollectionInfo {
955 name: name.into(),
956 engine: EngineType::DocumentSchemaless,
957 columns: Vec::new(),
958 primary_key: Some("id".into()),
959 has_auto_tier: false,
960 indexes: Vec::new(),
961 }));
962 }
963 self.inner.get_collection(name)
964 }
965}
966
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::registry::FunctionRegistry;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// Fixed catalog for planner tests.
    ///
    /// `products`, `users`, `orders`, `docs` and `tags` are schemaless document
    /// collections keyed by `id`. `user_prefs` is a key-value collection keyed by
    /// `key`. Any other name is unknown.
    struct TestCatalog;

    /// Descriptor for a test collection with no declared columns or indexes.
    fn test_collection(name: &str, engine: EngineType, primary_key: &str) -> CollectionInfo {
        CollectionInfo {
            name: name.into(),
            engine,
            columns: Vec::new(),
            primary_key: Some(primary_key.into()),
            has_auto_tier: false,
            indexes: Vec::new(),
        }
    }

    impl SqlCatalog for TestCatalog {
        fn get_collection(
            &self,
            name: &str,
        ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
            let info = match name {
                "products" | "users" | "orders" | "docs" | "tags" => {
                    Some(test_collection(name, EngineType::DocumentSchemaless, "id"))
                }
                "user_prefs" => Some(test_collection(name, EngineType::KeyValue, "key")),
                _ => None,
            };
            Ok(info)
        }
    }

    /// Parses `sql`, requires the first statement to be a query, and plans it
    /// against `TestCatalog` with a fresh function registry.
    fn plan_select_sql(sql: &str) -> SqlPlan {
        let statements = parse_sql(sql).unwrap();
        match &statements[0] {
            Statement::Query(query) => {
                plan_query(query, &TestCatalog, &FunctionRegistry::new()).unwrap()
            }
            _ => panic!("expected query statement"),
        }
    }

    #[test]
    fn aggregate_subquery_join_filters_input_before_aggregation() {
        let plan = plan_select_sql(
            "SELECT AVG(price) FROM products WHERE category IN (SELECT DISTINCT category FROM products WHERE qty > 100)",
        );

        // The IN-subquery must become a semi-join underneath the aggregate.
        let SqlPlan::Aggregate { input, .. } = plan else {
            panic!("expected aggregate plan");
        };
        let SqlPlan::Join {
            left: probe,
            join_type,
            on,
            ..
        } = *input
        else {
            panic!("expected semi-join below aggregate");
        };

        assert_eq!(join_type, JoinType::Semi);
        assert_eq!(on, vec![("category".into(), "category".into())]);
        assert!(matches!(*probe, SqlPlan::Scan { .. }));
    }

    #[test]
    fn scalar_subquery_defers_projection_until_after_join_filter() {
        let plan = plan_select_sql(
            "SELECT user_id FROM orders WHERE amount > (SELECT AVG(amount) FROM orders)",
        );

        let SqlPlan::Join {
            left,
            projection,
            filters,
            ..
        } = plan
        else {
            panic!("expected join plan");
        };
        let SqlPlan::Scan {
            projection: scan_projection,
            ..
        } = *left
        else {
            panic!("expected scan on join left");
        };

        // The scan must not project before the post-join comparison runs.
        assert!(scan_projection.is_empty(), "scan projected too early");
        assert_eq!(projection.len(), 1);
        if let Projection::Column(name) = &projection[0] {
            assert_eq!(name, "user_id");
        } else {
            panic!("expected user_id projection, got {:?}", projection[0]);
        }
        assert!(
            !filters.is_empty(),
            "scalar comparison should stay post-join"
        );
    }

    #[test]
    fn chained_join_preserves_qualified_on_keys() {
        let plan = plan_select_sql(
            "SELECT d.name, t.tag, p.theme \
             FROM docs d \
             LEFT JOIN tags t ON d.id = t.doc_id \
             INNER JOIN user_prefs p ON d.id = p.key",
        );

        // Joins nest left-deep: the outermost join is the last one written.
        let SqlPlan::Join {
            left: nested,
            on: outer_on,
            ..
        } = plan
        else {
            panic!("expected outer join plan");
        };
        assert_eq!(outer_on, vec![("d.id".into(), "p.key".into())]);

        let SqlPlan::Join { on: inner_on, .. } = *nested else {
            panic!("expected nested left join");
        };
        assert_eq!(inner_on, vec![("d.id".into(), "t.doc_id".into())]);
    }
}
1124}