1use sqlparser::ast::{self, Query, Select, SetExpr};
8
9use crate::error::{Result, SqlError};
10use crate::functions::registry::{FunctionRegistry, SearchTrigger};
11use crate::parser::normalize::normalize_ident;
12use crate::resolver::columns::TableScope;
13use crate::resolver::expr::convert_expr;
14use crate::types::*;
15
16pub fn plan_query(
18 query: &Query,
19 catalog: &dyn SqlCatalog,
20 functions: &FunctionRegistry,
21) -> Result<SqlPlan> {
22 if let Some(with) = &query.with
24 && with.recursive
25 {
26 return super::cte::plan_recursive_cte(query, catalog, functions);
27 }
28 if let Some(with) = &query.with
30 && !with.cte_tables.is_empty()
31 {
32 let inner_query = Query {
33 with: None,
34 body: query.body.clone(),
35 order_by: query.order_by.clone(),
36 limit_clause: query.limit_clause.clone(),
37 fetch: query.fetch.clone(),
38 locks: query.locks.clone(),
39 for_clause: query.for_clause.clone(),
40 settings: query.settings.clone(),
41 format_clause: query.format_clause.clone(),
42 pipe_operators: query.pipe_operators.clone(),
43 };
44
45 let mut definitions = Vec::new();
47 let mut cte_names = Vec::new();
48 for cte in &with.cte_tables {
49 let name = normalize_ident(&cte.alias.name);
50 let cte_plan = plan_query(&cte.query, catalog, functions)?;
51 definitions.push((name.clone(), cte_plan));
52 cte_names.push(name);
53 }
54
55 let cte_catalog = CteCatalog {
57 inner: catalog,
58 cte_names,
59 };
60 let outer = plan_query(&inner_query, &cte_catalog, functions)?;
61
62 return Ok(SqlPlan::Cte {
63 definitions,
64 outer: Box::new(outer),
65 });
66 }
67
68 match &*query.body {
70 SetExpr::Select(select) => {
71 let mut plan = plan_select(select, catalog, functions)?;
72 if let Some(order_by) = &query.order_by {
73 plan = apply_order_by(&plan, order_by, functions)?;
74 }
75 plan = apply_limit(plan, &query.limit_clause);
76 Ok(plan)
77 }
78 SetExpr::SetOperation {
79 op,
80 left,
81 right,
82 set_quantifier,
83 } => super::union::plan_set_operation(op, left, right, set_quantifier, catalog, functions),
84 _ => Err(SqlError::Unsupported {
85 detail: format!("query body type: {}", query.body),
86 }),
87 }
88}
89
/// Plans a single `SELECT` (its `ORDER BY` / `LIMIT` are applied by `plan_query`).
///
/// Dispatch order, as implemented below:
/// 1. no `FROM` → `SqlPlan::ConstantResult` built from constant-folded projections;
/// 2. explicit `JOIN`s in the single `FROM` item → delegated via `try_plan_join`;
/// 3. otherwise exactly one table is required, and subqueries in `WHERE` are
///    extracted into join descriptors (`super::subquery::extract_subqueries`);
/// 4. a recognised search predicate in the remaining `WHERE` returns a search plan;
/// 5. if the query aggregates (`has_aggregation`) → `super::aggregate::plan_aggregate`,
///    with subquery joins spliced in;
/// 6. otherwise an engine-specific scan (`resolve_engine_rules(..).plan_scan`),
///    wrapped in one `SqlPlan::Join` per extracted subquery.
///
/// # Errors
///
/// Propagates resolution / conversion errors from the helpers, and returns
/// `SqlError::Unsupported` for a multi-table `FROM` without explicit joins.
fn plan_select(
    select: &Select,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<SqlPlan> {
    let scope = TableScope::resolve_from(catalog, &select.from)?;

    // `SELECT <exprs>` with no FROM: produce a single constant row.
    if select.from.is_empty() {
        let projection = convert_projection(&select.projection)?;
        let mut columns = Vec::new();
        let mut values = Vec::new();
        for (i, proj) in projection.iter().enumerate() {
            match proj {
                Projection::Computed { expr, alias } => {
                    columns.push(alias.clone());
                    // Expressions that cannot be folded become NULL (see eval_constant_expr).
                    values.push(eval_constant_expr(expr, functions));
                }
                Projection::Column(name) => {
                    // A bare column with no table to read from evaluates to NULL.
                    columns.push(name.clone());
                    values.push(SqlValue::Null);
                }
                _ => {
                    // Stars and other projections get a positional placeholder name.
                    columns.push(format!("col{i}"));
                    values.push(SqlValue::Null);
                }
            }
        }
        return Ok(SqlPlan::ConstantResult { columns, values });
    }

    if let Some(plan) = try_plan_join(select, &scope, catalog, functions)? {
        return Ok(plan);
    }

    let table = scope.single_table().ok_or_else(|| SqlError::Unsupported {
        detail: "multi-table FROM without JOIN".into(),
    })?;

    // Pull subqueries out of WHERE; what is left is planned as ordinary filters.
    let (subquery_joins, effective_where) = if let Some(expr) = &select.selection {
        let extraction = super::subquery::extract_subqueries(expr, catalog, functions)?;
        (extraction.joins, extraction.remaining_where)
    } else {
        (Vec::new(), None)
    };

    let filters = match &effective_where {
        Some(expr) => {
            // NOTE(review): this early return bypasses `subquery_joins`, aggregation
            // and the SELECT projection — confirm that is intended for search queries.
            if let Some(plan) = try_extract_where_search(expr, table, functions)? {
                return Ok(plan);
            }
            convert_where_to_filters(expr)?
        }
        None => Vec::new(),
    };

    if has_aggregation(select, functions) {
        let mut plan =
            super::aggregate::plan_aggregate(select, table, &filters, &scope, functions)?;

        // Non-cross subquery joins (e.g. `IN (SELECT ..)` → semi-join) must restrict rows
        // *before* they are aggregated, so they are spliced in underneath the Aggregate's
        // input rather than wrapped around the Aggregate.
        // NOTE(review): if plan_aggregate returns anything other than `SqlPlan::Aggregate`,
        // these non-cross joins are silently skipped.
        if let SqlPlan::Aggregate { input, .. } = &mut plan {
            // Take ownership of the current input; the empty placeholder is overwritten
            // by `*input = base_input` below.
            let mut base_input = std::mem::replace(
                input,
                Box::new(SqlPlan::ConstantResult {
                    columns: Vec::new(),
                    values: Vec::new(),
                }),
            );
            for sq in subquery_joins
                .iter()
                .filter(|sq| sq.join_type != JoinType::Cross)
            {
                base_input = Box::new(SqlPlan::Join {
                    left: base_input,
                    right: Box::new(sq.inner_plan.clone()),
                    on: vec![(sq.outer_column.clone(), sq.inner_column.clone())],
                    join_type: sq.join_type,
                    condition: None,
                    // Hard-coded join limit; its exact meaning is defined by the executor.
                    limit: 10000,
                    projection: Vec::new(),
                    filters: Vec::new(),
                });
            }
            *input = base_input;
        }

        // Cross-type subquery joins wrap the aggregate's output instead.
        for sq in subquery_joins
            .into_iter()
            .filter(|sq| sq.join_type == JoinType::Cross)
        {
            plan = SqlPlan::Join {
                left: Box::new(plan),
                right: Box::new(sq.inner_plan),
                on: vec![(sq.outer_column, sq.inner_column)],
                join_type: sq.join_type,
                condition: None,
                limit: 10000,
                projection: Vec::new(),
                filters: Vec::new(),
            };
        }
        return Ok(plan);
    }

    let projection = convert_projection(&select.projection)?;

    let window_functions = super::window::extract_window_functions(&select.projection, functions)?;

    // With subquery joins present the scan must keep every column (the join and its
    // post-join filters may need them); the projection is then applied on the outermost
    // Join below instead.
    let scan_projection = if subquery_joins.is_empty() {
        projection.clone()
    } else {
        Vec::new()
    };

    let rules = crate::engine_rules::resolve_engine_rules(table.info.engine);
    let mut plan = rules.plan_scan(crate::engine_rules::ScanParams {
        collection: table.name.clone(),
        alias: table.alias.clone(),
        filters,
        projection: scan_projection,
        sort_keys: Vec::new(),
        limit: None,
        offset: 0,
        distinct: select.distinct.is_some(),
        window_functions,
    })?;

    // Each extracted subquery wraps the current plan in another Join, in extraction order.
    for sq in subquery_joins {
        let join_filters = if sq.join_type == JoinType::Cross {
            // For a cross join, scan filters that compare two columns (e.g. `amount > <the
            // subquery's column>`) can only be evaluated once both sides are joined, so they
            // are moved off the scan and onto the join. This only happens while `plan` is
            // still the bare Scan, i.e. for the first such join.
            if let SqlPlan::Scan {
                ref mut filters, ..
            } = plan
            {
                let mut moved = Vec::new();
                // Partition in place: keep single-sided filters, collect column-vs-column ones.
                filters.retain(|f| {
                    if has_column_ref_filter(&f.expr) {
                        moved.push(f.clone());
                        false
                    } else {
                        true
                    }
                });
                moved
            } else {
                Vec::new()
            }
        } else {
            Vec::new()
        };

        plan = SqlPlan::Join {
            left: Box::new(plan),
            right: Box::new(sq.inner_plan),
            on: vec![(sq.outer_column, sq.inner_column)],
            join_type: sq.join_type,
            condition: None,
            limit: 10000,
            projection: Vec::new(),
            filters: join_filters,
        };
    }

    // The deferred SELECT projection is applied on the outermost Join (if any).
    if let SqlPlan::Join {
        projection: ref mut join_projection,
        ..
    } = plan
    {
        *join_projection = projection;
    }

    Ok(plan)
}
283
284fn has_column_ref_filter(expr: &FilterExpr) -> bool {
288 match expr {
289 FilterExpr::Expr(sql_expr) => has_column_comparison(sql_expr),
290 FilterExpr::And(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
291 FilterExpr::Or(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
292 _ => false,
293 }
294}
295
296fn has_column_comparison(expr: &SqlExpr) -> bool {
297 match expr {
298 SqlExpr::BinaryOp { left, right, .. } => {
299 let left_is_col = matches!(left.as_ref(), SqlExpr::Column { .. });
300 let right_is_col = matches!(right.as_ref(), SqlExpr::Column { .. });
301 if left_is_col && right_is_col {
302 return true;
303 }
304 has_column_comparison(left) || has_column_comparison(right)
305 }
306 _ => false,
307 }
308}
309
310fn has_aggregation(select: &Select, functions: &FunctionRegistry) -> bool {
312 let group_by_non_empty = match &select.group_by {
313 ast::GroupByExpr::All(_) => true,
314 ast::GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
315 };
316 if group_by_non_empty {
317 return true;
318 }
319 for item in &select.projection {
320 if let ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } =
321 item
322 && expr_contains_aggregate(expr, functions)
323 {
324 return true;
325 }
326 }
327 false
328}
329
330fn expr_contains_aggregate(expr: &ast::Expr, functions: &FunctionRegistry) -> bool {
332 match expr {
333 ast::Expr::Function(func) => {
334 let name = func
335 .name
336 .0
337 .iter()
338 .map(|p| match p {
339 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
340 _ => String::new(),
341 })
342 .collect::<Vec<_>>()
343 .join(".");
344 if functions.is_aggregate(&name) {
345 return true;
346 }
347 if let ast::FunctionArguments::List(args) = &func.args {
349 for arg in &args.args {
350 if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) = arg
351 && expr_contains_aggregate(e, functions)
352 {
353 return true;
354 }
355 }
356 }
357 false
358 }
359 ast::Expr::BinaryOp { left, right, .. } => {
360 expr_contains_aggregate(left, functions) || expr_contains_aggregate(right, functions)
361 }
362 ast::Expr::Nested(inner) => expr_contains_aggregate(inner, functions),
363 _ => false,
364 }
365}
366
367fn try_extract_where_search(
369 expr: &ast::Expr,
370 table: &crate::resolver::columns::ResolvedTable,
371 functions: &FunctionRegistry,
372) -> Result<Option<SqlPlan>> {
373 match expr {
374 ast::Expr::Function(func) => {
375 let name = func
376 .name
377 .0
378 .iter()
379 .map(|p| match p {
380 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
381 _ => String::new(),
382 })
383 .collect::<Vec<_>>()
384 .join(".");
385 match functions.search_trigger(&name) {
386 SearchTrigger::TextMatch => {
387 let args = extract_func_args(func)?;
388 if args.len() >= 2 {
389 let query_text = extract_string_literal(&args[1])?;
390 return Ok(Some(SqlPlan::TextSearch {
391 collection: table.name.clone(),
392 query: query_text,
393 top_k: 1000,
394 fuzzy: true,
395 filters: Vec::new(),
396 }));
397 }
398 }
399 SearchTrigger::SpatialDWithin
400 | SearchTrigger::SpatialContains
401 | SearchTrigger::SpatialIntersects
402 | SearchTrigger::SpatialWithin => {
403 return plan_spatial_from_where(&name, func, table);
404 }
405 _ => {}
406 }
407 }
408 ast::Expr::BinaryOp {
410 left,
411 op: ast::BinaryOperator::And,
412 right,
413 } => {
414 if let Some(plan) = try_extract_where_search(left, table, functions)? {
415 return Ok(Some(plan));
416 }
417 if let Some(plan) = try_extract_where_search(right, table, functions)? {
418 return Ok(Some(plan));
419 }
420 }
421 _ => {}
422 }
423 Ok(None)
424}
425
426fn plan_spatial_from_where(
427 name: &str,
428 func: &ast::Function,
429 table: &crate::resolver::columns::ResolvedTable,
430) -> Result<Option<SqlPlan>> {
431 let predicate = match name {
432 "st_dwithin" => SpatialPredicate::DWithin,
433 "st_contains" => SpatialPredicate::Contains,
434 "st_intersects" => SpatialPredicate::Intersects,
435 "st_within" => SpatialPredicate::Within,
436 _ => return Ok(None),
437 };
438 let args = extract_func_args(func)?;
439 if args.is_empty() {
440 return Err(SqlError::MissingField {
441 field: "geometry column".into(),
442 context: name.into(),
443 });
444 }
445 let field = extract_column_name(&args[0])?;
446 let geom_arg = args.get(1).ok_or_else(|| SqlError::MissingField {
447 field: "query geometry".into(),
448 context: name.into(),
449 })?;
450 let geom_str = extract_geometry_arg(geom_arg)?;
451 let distance = if args.len() >= 3 {
452 extract_float(&args[2]).unwrap_or(0.0)
453 } else {
454 0.0
455 };
456 Ok(Some(SqlPlan::SpatialScan {
457 collection: table.name.clone(),
458 field,
459 predicate,
460 query_geometry: geom_str.into_bytes(),
461 distance_meters: distance,
462 attribute_filters: Vec::new(),
463 limit: 1000,
464 projection: Vec::new(),
465 }))
466}
467
468fn apply_order_by(
470 plan: &SqlPlan,
471 order_by: &ast::OrderBy,
472 functions: &FunctionRegistry,
473) -> Result<SqlPlan> {
474 let exprs = match &order_by.kind {
475 ast::OrderByKind::Expressions(exprs) => exprs,
476 ast::OrderByKind::All(_) => return Ok(plan.clone()),
477 };
478
479 if exprs.is_empty() {
480 return Ok(plan.clone());
481 }
482
483 let first = &exprs[0];
485 if let Some(search_plan) = try_extract_sort_search(&first.expr, plan, functions)? {
486 return Ok(search_plan);
487 }
488
489 let sort_keys: Vec<SortKey> = exprs
491 .iter()
492 .map(|o| {
493 Ok(SortKey {
494 expr: convert_expr(&o.expr)?,
495 ascending: o.options.asc.unwrap_or(true),
496 nulls_first: o.options.nulls_first.unwrap_or(false),
497 })
498 })
499 .collect::<Result<_>>()?;
500
501 match plan {
502 SqlPlan::Scan {
503 collection,
504 alias,
505 engine,
506 filters,
507 projection,
508 limit,
509 offset,
510 distinct,
511 window_functions,
512 ..
513 } => Ok(SqlPlan::Scan {
514 collection: collection.clone(),
515 alias: alias.clone(),
516 engine: *engine,
517 filters: filters.clone(),
518 projection: projection.clone(),
519 sort_keys,
520 limit: *limit,
521 offset: *offset,
522 distinct: *distinct,
523 window_functions: window_functions.clone(),
524 }),
525 _ => Ok(plan.clone()),
526 }
527}
528
529fn try_extract_sort_search(
531 expr: &ast::Expr,
532 plan: &SqlPlan,
533 functions: &FunctionRegistry,
534) -> Result<Option<SqlPlan>> {
535 if let ast::Expr::Function(func) = expr {
536 let name = func
537 .name
538 .0
539 .iter()
540 .map(|p| match p {
541 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
542 _ => String::new(),
543 })
544 .collect::<Vec<_>>()
545 .join(".");
546 let collection = match plan {
547 SqlPlan::Scan { collection, .. } => collection.clone(),
548 _ => return Ok(None),
549 };
550 let args = extract_func_args(func)?;
551
552 match functions.search_trigger(&name) {
553 SearchTrigger::VectorSearch => {
554 if args.len() < 2 {
555 return Ok(None);
556 }
557 let field = extract_column_name(&args[0])?;
558 let vector = extract_float_array(&args[1])?;
559 let limit = match plan {
560 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
561 _ => 10,
562 };
563 return Ok(Some(SqlPlan::VectorSearch {
564 collection,
565 field,
566 query_vector: vector,
567 top_k: limit,
568 ef_search: limit * 2,
569 filters: match plan {
570 SqlPlan::Scan { filters, .. } => filters.clone(),
571 _ => Vec::new(),
572 },
573 }));
574 }
575 SearchTrigger::TextSearch => {
576 if args.len() >= 2 {
577 let query_text = extract_string_literal(&args[1])?;
578 let limit = match plan {
579 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
580 _ => 10,
581 };
582 return Ok(Some(SqlPlan::TextSearch {
583 collection,
584 query: query_text,
585 top_k: limit,
586 fuzzy: true,
587 filters: match plan {
588 SqlPlan::Scan { filters, .. } => filters.clone(),
589 _ => Vec::new(),
590 },
591 }));
592 }
593 }
594 SearchTrigger::HybridSearch => {
595 return plan_hybrid_from_sort(&args, &collection, plan, functions);
596 }
597 _ => {}
598 }
599 }
600 Ok(None)
601}
602
603fn plan_hybrid_from_sort(
604 args: &[ast::Expr],
605 collection: &str,
606 plan: &SqlPlan,
607 _functions: &FunctionRegistry,
608) -> Result<Option<SqlPlan>> {
609 if args.len() < 2 {
611 return Ok(None);
612 }
613 let vector = match &args[0] {
614 ast::Expr::Function(f) => {
615 let inner_args = extract_func_args(f)?;
616 if inner_args.len() >= 2 {
617 extract_float_array(&inner_args[1]).unwrap_or_default()
618 } else {
619 Vec::new()
620 }
621 }
622 _ => Vec::new(),
623 };
624 let text = match &args[1] {
625 ast::Expr::Function(f) => {
626 let inner_args = extract_func_args(f)?;
627 if inner_args.len() >= 2 {
628 extract_string_literal(&inner_args[1]).unwrap_or_default()
629 } else {
630 String::new()
631 }
632 }
633 _ => String::new(),
634 };
635 let k1 = args
636 .get(2)
637 .and_then(|e| extract_float(e).ok())
638 .unwrap_or(60.0);
639 let k2 = args
640 .get(3)
641 .and_then(|e| extract_float(e).ok())
642 .unwrap_or(60.0);
643 let limit = match plan {
644 SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
645 _ => 10,
646 };
647 let vector_weight = k2 as f32 / (k1 as f32 + k2 as f32);
648
649 Ok(Some(SqlPlan::HybridSearch {
650 collection: collection.into(),
651 query_vector: vector,
652 query_text: text,
653 top_k: limit,
654 ef_search: limit * 2,
655 vector_weight,
656 fuzzy: true,
657 }))
658}
659
660fn apply_limit(mut plan: SqlPlan, limit_clause: &Option<ast::LimitClause>) -> SqlPlan {
662 let (limit_val, offset_val) = match limit_clause {
663 None => (None, 0usize),
664 Some(ast::LimitClause::LimitOffset { limit, offset, .. }) => {
665 let lv = limit.as_ref().and_then(|e| match e {
666 ast::Expr::Value(v) => match &v.value {
667 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
668 _ => None,
669 },
670 _ => None,
671 });
672 let ov = offset
673 .as_ref()
674 .and_then(|o| match &o.value {
675 ast::Expr::Value(v) => match &v.value {
676 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
677 _ => None,
678 },
679 _ => None,
680 })
681 .unwrap_or(0);
682 (lv, ov)
683 }
684 Some(ast::LimitClause::OffsetCommaLimit { offset, limit }) => {
685 let lv = match limit {
686 ast::Expr::Value(v) => match &v.value {
687 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
688 _ => None,
689 },
690 _ => None,
691 };
692 let ov = match offset {
693 ast::Expr::Value(v) => match &v.value {
694 ast::Value::Number(n, _) => n.parse::<usize>().ok(),
695 _ => None,
696 },
697 _ => None,
698 }
699 .unwrap_or(0);
700 (lv, ov)
701 }
702 };
703
704 match plan {
705 SqlPlan::Scan {
706 ref mut limit,
707 ref mut offset,
708 ..
709 } => {
710 *limit = limit_val;
711 *offset = offset_val;
712 }
713 SqlPlan::Aggregate {
714 limit: ref mut l, ..
715 } => {
716 if let Some(lv) = limit_val {
717 *l = lv;
718 }
719 }
720 _ => {}
721 }
722 plan
723}
724
725pub fn convert_projection(items: &[ast::SelectItem]) -> Result<Vec<Projection>> {
729 let mut result = Vec::new();
730 for item in items {
731 match item {
732 ast::SelectItem::UnnamedExpr(expr) => {
733 let sql_expr = convert_expr(expr)?;
734 match &sql_expr {
735 SqlExpr::Column { table, name } => {
736 result.push(Projection::Column(qualified_name(table.as_deref(), name)));
737 }
738 SqlExpr::Wildcard => {
739 result.push(Projection::Star);
740 }
741 _ => {
742 result.push(Projection::Computed {
743 expr: sql_expr,
744 alias: format!("{expr}"),
745 });
746 }
747 }
748 }
749 ast::SelectItem::ExprWithAlias { expr, alias } => {
750 let sql_expr = convert_expr(expr)?;
751 result.push(Projection::Computed {
752 expr: sql_expr,
753 alias: normalize_ident(alias),
754 });
755 }
756 ast::SelectItem::Wildcard(_) => {
757 result.push(Projection::Star);
758 }
759 ast::SelectItem::QualifiedWildcard(kind, _) => {
760 let table_name = match kind {
761 ast::SelectItemQualifiedWildcardKind::ObjectName(name) => {
762 crate::parser::normalize::normalize_object_name(name)
763 }
764 _ => String::new(),
765 };
766 result.push(Projection::QualifiedStar(table_name));
767 }
768 }
769 }
770 Ok(result)
771}
772
/// Joins an optional table qualifier and a column name as `table.name`.
///
/// Without a qualifier the bare column name is returned.
pub fn qualified_name(table: Option<&str>, name: &str) -> String {
    match table {
        Some(qualifier) => format!("{qualifier}.{name}"),
        None => name.to_owned(),
    }
}
777
778pub fn convert_where_to_filters(expr: &ast::Expr) -> Result<Vec<Filter>> {
780 let sql_expr = convert_expr(expr)?;
781 Ok(vec![Filter {
782 expr: FilterExpr::Expr(sql_expr),
783 }])
784}
785
786pub(crate) fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>> {
787 match &func.args {
788 ast::FunctionArguments::List(args) => Ok(args
789 .args
790 .iter()
791 .filter_map(|a| match a {
792 ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e.clone()),
793 _ => None,
794 })
795 .collect()),
796 _ => Ok(Vec::new()),
797 }
798}
799
800fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
805 super::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
806}
807
808fn extract_geometry_arg(expr: &ast::Expr) -> Result<String> {
811 match expr {
812 ast::Expr::Function(func) => {
814 let name = func
815 .name
816 .0
817 .iter()
818 .map(|p| match p {
819 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
820 _ => String::new(),
821 })
822 .collect::<Vec<_>>()
823 .join(".");
824 let args = extract_func_args(func)?;
825 match name.as_str() {
826 "st_point" if args.len() >= 2 => {
827 let lon = extract_float(&args[0])?;
828 let lat = extract_float(&args[1])?;
829 Ok(format!(r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#))
830 }
831 "st_geomfromgeojson" if !args.is_empty() => extract_string_literal(&args[0]),
832 _ => Ok(format!("{expr}")),
833 }
834 }
835 _ => extract_string_literal(expr).or_else(|_| Ok(format!("{expr}"))),
837 }
838}
839
840fn extract_column_name(expr: &ast::Expr) -> Result<String> {
841 match expr {
842 ast::Expr::Identifier(ident) => Ok(normalize_ident(ident)),
843 ast::Expr::CompoundIdentifier(parts) => Ok(parts
844 .iter()
845 .map(normalize_ident)
846 .collect::<Vec<_>>()
847 .join(".")),
848 _ => Err(SqlError::Unsupported {
849 detail: format!("expected column name, got: {expr}"),
850 }),
851 }
852}
853
854pub(crate) fn extract_string_literal(expr: &ast::Expr) -> Result<String> {
855 match expr {
856 ast::Expr::Value(v) => match &v.value {
857 ast::Value::SingleQuotedString(s) | ast::Value::DoubleQuotedString(s) => Ok(s.clone()),
858 _ => Err(SqlError::Unsupported {
859 detail: format!("expected string literal, got: {expr}"),
860 }),
861 },
862 _ => Err(SqlError::Unsupported {
863 detail: format!("expected string literal, got: {expr}"),
864 }),
865 }
866}
867
868pub(crate) fn extract_float(expr: &ast::Expr) -> Result<f64> {
869 match expr {
870 ast::Expr::Value(v) => match &v.value {
871 ast::Value::Number(n, _) => n.parse::<f64>().map_err(|_| SqlError::TypeMismatch {
872 detail: format!("expected number: {n}"),
873 }),
874 _ => Err(SqlError::TypeMismatch {
875 detail: format!("expected number, got: {expr}"),
876 }),
877 },
878 ast::Expr::UnaryOp {
880 op: ast::UnaryOperator::Minus,
881 expr: inner,
882 } => extract_float(inner).map(|f| -f),
883 _ => Err(SqlError::TypeMismatch {
884 detail: format!("expected number, got: {expr}"),
885 }),
886 }
887}
888
889fn extract_float_array(expr: &ast::Expr) -> Result<Vec<f32>> {
891 match expr {
892 ast::Expr::Array(ast::Array { elem, .. }) => elem
893 .iter()
894 .map(|e| extract_float(e).map(|f| f as f32))
895 .collect(),
896 ast::Expr::Function(func) => {
897 let name = func
898 .name
899 .0
900 .iter()
901 .map(|p| match p {
902 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
903 _ => String::new(),
904 })
905 .collect::<Vec<_>>()
906 .join(".");
907 if name == "make_array" || name == "array" {
908 let args = extract_func_args(func)?;
909 args.iter()
910 .map(|e| extract_float(e).map(|f| f as f32))
911 .collect()
912 } else {
913 Err(SqlError::Unsupported {
914 detail: format!("expected array, got function: {name}"),
915 })
916 }
917 }
918 _ => Err(SqlError::Unsupported {
919 detail: format!("expected array literal, got: {expr}"),
920 }),
921 }
922}
923
924fn try_plan_join(
926 select: &Select,
927 scope: &TableScope,
928 catalog: &dyn SqlCatalog,
929 functions: &FunctionRegistry,
930) -> Result<Option<SqlPlan>> {
931 if select.from.len() != 1 {
932 return Ok(None);
933 }
934 let from = &select.from[0];
935 if from.joins.is_empty() {
936 return Ok(None);
937 }
938 super::join::plan_join_from_select(select, scope, catalog, functions)
939}
940
/// Catalog wrapper used by `plan_query` when planning queries that have a
/// non-recursive `WITH` clause.
///
/// Names listed in `cte_names` resolve to a synthetic collection (see the
/// `SqlCatalog` impl); every other name is looked up in `inner`.
struct CteCatalog<'a> {
    /// The wrapped catalog, consulted for every name that is not a CTE.
    inner: &'a dyn SqlCatalog,
    /// CTE names, as produced by `normalize_ident` in `plan_query`.
    cte_names: Vec<String>,
}
946
947impl SqlCatalog for CteCatalog<'_> {
948 fn get_collection(
949 &self,
950 name: &str,
951 ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
952 if self.cte_names.iter().any(|n| n == name) {
954 return Ok(Some(CollectionInfo {
955 name: name.into(),
956 engine: EngineType::DocumentSchemaless,
957 columns: Vec::new(),
958 primary_key: Some("id".into()),
959 has_auto_tier: false,
960 }));
961 }
962 self.inner.get_collection(name)
963 }
964}
965
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::registry::FunctionRegistry;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// In-memory catalog: five schemaless document collections keyed by `id`
    /// plus one key-value collection (`user_prefs`) keyed by `key`.
    struct TestCatalog;

    impl SqlCatalog for TestCatalog {
        fn get_collection(
            &self,
            name: &str,
        ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
            let (engine, primary_key) = match name {
                "products" | "users" | "orders" | "docs" | "tags" => {
                    (EngineType::DocumentSchemaless, "id")
                }
                "user_prefs" => (EngineType::KeyValue, "key"),
                _ => return Ok(None),
            };
            Ok(Some(CollectionInfo {
                name: name.into(),
                engine,
                columns: Vec::new(),
                primary_key: Some(primary_key.into()),
                has_auto_tier: false,
            }))
        }
    }

    /// Parses `sql` and plans its first statement against `TestCatalog`.
    fn plan_select_sql(sql: &str) -> SqlPlan {
        let statements = parse_sql(sql).unwrap();
        match &statements[0] {
            Statement::Query(query) => {
                plan_query(query, &TestCatalog, &FunctionRegistry::new()).unwrap()
            }
            _ => panic!("expected query statement"),
        }
    }

    #[test]
    fn aggregate_subquery_join_filters_input_before_aggregation() {
        let plan = plan_select_sql(
            "SELECT AVG(price) FROM products WHERE category IN (SELECT DISTINCT category FROM products WHERE qty > 100)",
        );

        let SqlPlan::Aggregate { input, .. } = plan else {
            panic!("expected aggregate plan");
        };
        let SqlPlan::Join {
            left,
            join_type,
            on,
            ..
        } = *input
        else {
            panic!("expected semi-join below aggregate");
        };

        assert_eq!(join_type, JoinType::Semi);
        assert_eq!(on, vec![("category".into(), "category".into())]);
        assert!(matches!(*left, SqlPlan::Scan { .. }));
    }

    #[test]
    fn scalar_subquery_defers_projection_until_after_join_filter() {
        let plan = plan_select_sql(
            "SELECT user_id FROM orders WHERE amount > (SELECT AVG(amount) FROM orders)",
        );

        let SqlPlan::Join {
            left,
            projection,
            filters,
            ..
        } = plan
        else {
            panic!("expected join plan");
        };
        let SqlPlan::Scan {
            projection: scan_projection,
            ..
        } = *left
        else {
            panic!("expected scan on join left");
        };

        assert!(scan_projection.is_empty(), "scan projected too early");
        assert_eq!(projection.len(), 1);
        let Projection::Column(name) = &projection[0] else {
            panic!("expected user_id projection, got {:?}", &projection[0]);
        };
        assert_eq!(name, "user_id");
        assert!(
            !filters.is_empty(),
            "scalar comparison should stay post-join"
        );
    }

    #[test]
    fn chained_join_preserves_qualified_on_keys() {
        let plan = plan_select_sql(
            "SELECT d.name, t.tag, p.theme \
             FROM docs d \
             LEFT JOIN tags t ON d.id = t.doc_id \
             INNER JOIN user_prefs p ON d.id = p.key",
        );

        let SqlPlan::Join {
            left: nested,
            on: outer_on,
            ..
        } = plan
        else {
            panic!("expected outer join plan");
        };
        assert_eq!(outer_on, vec![("d.id".into(), "p.key".into())]);

        let SqlPlan::Join { on: inner_on, .. } = *nested else {
            panic!("expected nested left join");
        };
        assert_eq!(inner_on, vec![("d.id".into(), "t.doc_id".into())]);
    }
}