Skip to main content

nodedb_sql/planner/
select.rs

1//! SELECT query planning: FROM → WHERE → GROUP BY → HAVING → SELECT → ORDER BY → LIMIT.
2//!
3//! This is the main entry point for SELECT statement conversion. It detects
4//! search patterns (vector, text, hybrid, spatial) directly from the AST
5//! instead of reverse-engineering an optimizer's output.
6
7use sqlparser::ast::{self, Query, Select, SetExpr};
8
9use crate::error::{Result, SqlError};
10use crate::functions::registry::{FunctionRegistry, SearchTrigger};
11use crate::parser::normalize::normalize_ident;
12use crate::resolver::columns::TableScope;
13use crate::resolver::expr::convert_expr;
14use crate::types::*;
15
16/// Plan a SELECT query.
17pub fn plan_query(
18    query: &Query,
19    catalog: &dyn SqlCatalog,
20    functions: &FunctionRegistry,
21) -> Result<SqlPlan> {
22    // Handle CTEs (WITH clause).
23    if let Some(with) = &query.with
24        && with.recursive
25    {
26        return super::cte::plan_recursive_cte(query, catalog, functions);
27    }
28    // Non-recursive CTEs: plan each CTE subquery and the outer query.
29    if let Some(with) = &query.with
30        && !with.cte_tables.is_empty()
31    {
32        let inner_query = Query {
33            with: None,
34            body: query.body.clone(),
35            order_by: query.order_by.clone(),
36            limit_clause: query.limit_clause.clone(),
37            fetch: query.fetch.clone(),
38            locks: query.locks.clone(),
39            for_clause: query.for_clause.clone(),
40            settings: query.settings.clone(),
41            format_clause: query.format_clause.clone(),
42            pipe_operators: query.pipe_operators.clone(),
43        };
44
45        // Plan each CTE subquery.
46        let mut definitions = Vec::new();
47        let mut cte_names = Vec::new();
48        for cte in &with.cte_tables {
49            let name = normalize_ident(&cte.alias.name);
50            let cte_plan = plan_query(&cte.query, catalog, functions)?;
51            definitions.push((name.clone(), cte_plan));
52            cte_names.push(name);
53        }
54
55        // Build CTE-aware catalog so the outer query can reference CTE names.
56        let cte_catalog = CteCatalog {
57            inner: catalog,
58            cte_names,
59        };
60        let outer = plan_query(&inner_query, &cte_catalog, functions)?;
61
62        return Ok(SqlPlan::Cte {
63            definitions,
64            outer: Box::new(outer),
65        });
66    }
67
68    // Handle UNION.
69    match &*query.body {
70        SetExpr::Select(select) => {
71            let mut plan = plan_select(select, catalog, functions)?;
72            if let Some(order_by) = &query.order_by {
73                plan = apply_order_by(&plan, order_by, functions)?;
74            }
75            plan = apply_limit(plan, &query.limit_clause);
76            Ok(plan)
77        }
78        SetExpr::SetOperation {
79            op,
80            left,
81            right,
82            set_quantifier,
83        } => super::union::plan_set_operation(op, left, right, set_quantifier, catalog, functions),
84        _ => Err(SqlError::Unsupported {
85            detail: format!("query body type: {}", query.body),
86        }),
87    }
88}
89
90/// Plan a single SELECT statement (no UNION, no CTE wrapper).
91fn plan_select(
92    select: &Select,
93    catalog: &dyn SqlCatalog,
94    functions: &FunctionRegistry,
95) -> Result<SqlPlan> {
96    // 1. Resolve FROM tables.
97    let scope = TableScope::resolve_from(catalog, &select.from)?;
98
99    // 2. Handle constant queries (no FROM clause): SELECT 1, SELECT 'hello', etc.
100    if select.from.is_empty() {
101        let projection = convert_projection(&select.projection)?;
102        let mut columns = Vec::new();
103        let mut values = Vec::new();
104        for (i, proj) in projection.iter().enumerate() {
105            match proj {
106                Projection::Computed { expr, alias } => {
107                    columns.push(alias.clone());
108                    values.push(eval_constant_expr(expr, functions));
109                }
110                Projection::Column(name) => {
111                    columns.push(name.clone());
112                    values.push(SqlValue::Null);
113                }
114                _ => {
115                    columns.push(format!("col{i}"));
116                    values.push(SqlValue::Null);
117                }
118            }
119        }
120        return Ok(SqlPlan::ConstantResult { columns, values });
121    }
122
123    // 3. Check for JOINs.
124    if let Some(plan) = try_plan_join(select, &scope, catalog, functions)? {
125        return Ok(plan);
126    }
127
128    // 4. Single-table query.
129    let table = scope.single_table().ok_or_else(|| SqlError::Unsupported {
130        detail: "multi-table FROM without JOIN".into(),
131    })?;
132
133    // 4. Extract subqueries from WHERE and rewrite as semi/anti joins.
134    let (subquery_joins, effective_where) = if let Some(expr) = &select.selection {
135        let extraction = super::subquery::extract_subqueries(expr, catalog, functions)?;
136        (extraction.joins, extraction.remaining_where)
137    } else {
138        (Vec::new(), None)
139    };
140
141    // 5. Convert remaining WHERE filters.
142    let filters = match &effective_where {
143        Some(expr) => {
144            // Check for search-triggering functions in WHERE.
145            if let Some(plan) = try_extract_where_search(expr, table, functions)? {
146                return Ok(plan);
147            }
148            convert_where_to_filters(expr)?
149        }
150        None => Vec::new(),
151    };
152
153    // 6. Check for GROUP BY / aggregation.
154    if has_aggregation(select, functions) {
155        let mut plan =
156            super::aggregate::plan_aggregate(select, table, &filters, &scope, functions)?;
157
158        // Semi/anti subquery joins belong below the aggregate so they filter
159        // the input rows before grouping. Scalar subqueries remain above the
160        // aggregate because their column-vs-column comparison is evaluated
161        // after the cross join materializes the scalar result row.
162        if let SqlPlan::Aggregate { input, .. } = &mut plan {
163            let mut base_input = std::mem::replace(
164                input,
165                Box::new(SqlPlan::ConstantResult {
166                    columns: Vec::new(),
167                    values: Vec::new(),
168                }),
169            );
170            for sq in subquery_joins
171                .iter()
172                .filter(|sq| sq.join_type != JoinType::Cross)
173            {
174                base_input = Box::new(SqlPlan::Join {
175                    left: base_input,
176                    right: Box::new(sq.inner_plan.clone()),
177                    on: vec![(sq.outer_column.clone(), sq.inner_column.clone())],
178                    join_type: sq.join_type,
179                    condition: None,
180                    limit: 10000,
181                    projection: Vec::new(),
182                    filters: Vec::new(),
183                });
184            }
185            *input = base_input;
186        }
187
188        for sq in subquery_joins
189            .into_iter()
190            .filter(|sq| sq.join_type == JoinType::Cross)
191        {
192            plan = SqlPlan::Join {
193                left: Box::new(plan),
194                right: Box::new(sq.inner_plan),
195                on: vec![(sq.outer_column, sq.inner_column)],
196                join_type: sq.join_type,
197                condition: None,
198                limit: 10000,
199                projection: Vec::new(),
200                filters: Vec::new(),
201            };
202        }
203        return Ok(plan);
204    }
205
206    // 7. Convert projection.
207    let projection = convert_projection(&select.projection)?;
208
209    // 8. Convert window functions (SELECT with OVER).
210    let window_functions = super::window::extract_window_functions(&select.projection, functions)?;
211
212    // 9. Build base scan plan.
213    let scan_projection = if subquery_joins.is_empty() {
214        projection.clone()
215    } else {
216        Vec::new()
217    };
218
219    let rules = crate::engine_rules::resolve_engine_rules(table.info.engine);
220    let mut plan = rules.plan_scan(crate::engine_rules::ScanParams {
221        collection: table.name.clone(),
222        alias: table.alias.clone(),
223        filters,
224        projection: scan_projection,
225        sort_keys: Vec::new(),
226        limit: None,
227        offset: 0,
228        distinct: select.distinct.is_some(),
229        window_functions,
230    })?;
231
232    // 10. Wrap with subquery joins (semi/anti/cross) if any.
233    for sq in subquery_joins {
234        // For cross-joins (scalar subqueries), move column-referencing filters
235        // from the base scan to the join's post-filters. The filter compares
236        // a field from the base scan with a field from the subquery result,
237        // so it can only be evaluated after the join merges both sides.
238        let join_filters = if sq.join_type == JoinType::Cross {
239            if let SqlPlan::Scan {
240                ref mut filters, ..
241            } = plan
242            {
243                // Move filters that reference the scalar result column to the join.
244                let mut moved = Vec::new();
245                filters.retain(|f| {
246                    if has_column_ref_filter(&f.expr) {
247                        moved.push(f.clone());
248                        false
249                    } else {
250                        true
251                    }
252                });
253                moved
254            } else {
255                Vec::new()
256            }
257        } else {
258            Vec::new()
259        };
260
261        plan = SqlPlan::Join {
262            left: Box::new(plan),
263            right: Box::new(sq.inner_plan),
264            on: vec![(sq.outer_column, sq.inner_column)],
265            join_type: sq.join_type,
266            condition: None,
267            limit: 10000,
268            projection: Vec::new(),
269            filters: join_filters,
270        };
271    }
272
273    if let SqlPlan::Join {
274        projection: ref mut join_projection,
275        ..
276    } = plan
277    {
278        *join_projection = projection;
279    }
280
281    Ok(plan)
282}
283
284/// Check if a filter expression contains a column-vs-column comparison
285/// (from scalar subquery rewriting). These filters must be evaluated
286/// post-join, not pre-join, since one column comes from the subquery result.
287fn has_column_ref_filter(expr: &FilterExpr) -> bool {
288    match expr {
289        FilterExpr::Expr(sql_expr) => has_column_comparison(sql_expr),
290        FilterExpr::And(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
291        FilterExpr::Or(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
292        _ => false,
293    }
294}
295
296fn has_column_comparison(expr: &SqlExpr) -> bool {
297    match expr {
298        SqlExpr::BinaryOp { left, right, .. } => {
299            let left_is_col = matches!(left.as_ref(), SqlExpr::Column { .. });
300            let right_is_col = matches!(right.as_ref(), SqlExpr::Column { .. });
301            if left_is_col && right_is_col {
302                return true;
303            }
304            has_column_comparison(left) || has_column_comparison(right)
305        }
306        _ => false,
307    }
308}
309
310/// Check if a SELECT has aggregation (GROUP BY or aggregate functions in projection).
311fn has_aggregation(select: &Select, functions: &FunctionRegistry) -> bool {
312    let group_by_non_empty = match &select.group_by {
313        ast::GroupByExpr::All(_) => true,
314        ast::GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
315    };
316    if group_by_non_empty {
317        return true;
318    }
319    for item in &select.projection {
320        if let ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } =
321            item
322            && expr_contains_aggregate(expr, functions)
323        {
324            return true;
325        }
326    }
327    false
328}
329
330/// Check if an expression contains an aggregate function call.
331fn expr_contains_aggregate(expr: &ast::Expr, functions: &FunctionRegistry) -> bool {
332    match expr {
333        ast::Expr::Function(func) => {
334            let name = func
335                .name
336                .0
337                .iter()
338                .map(|p| match p {
339                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
340                    _ => String::new(),
341                })
342                .collect::<Vec<_>>()
343                .join(".");
344            if functions.is_aggregate(&name) {
345                return true;
346            }
347            // Check args recursively.
348            if let ast::FunctionArguments::List(args) = &func.args {
349                for arg in &args.args {
350                    if let ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) = arg
351                        && expr_contains_aggregate(e, functions)
352                    {
353                        return true;
354                    }
355                }
356            }
357            false
358        }
359        ast::Expr::BinaryOp { left, right, .. } => {
360            expr_contains_aggregate(left, functions) || expr_contains_aggregate(right, functions)
361        }
362        ast::Expr::Nested(inner) => expr_contains_aggregate(inner, functions),
363        _ => false,
364    }
365}
366
367/// Try to detect search-triggering patterns in WHERE clause.
368fn try_extract_where_search(
369    expr: &ast::Expr,
370    table: &crate::resolver::columns::ResolvedTable,
371    functions: &FunctionRegistry,
372) -> Result<Option<SqlPlan>> {
373    match expr {
374        ast::Expr::Function(func) => {
375            let name = func
376                .name
377                .0
378                .iter()
379                .map(|p| match p {
380                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
381                    _ => String::new(),
382                })
383                .collect::<Vec<_>>()
384                .join(".");
385            match functions.search_trigger(&name) {
386                SearchTrigger::TextMatch => {
387                    let args = extract_func_args(func)?;
388                    if args.len() >= 2 {
389                        let query_text = extract_string_literal(&args[1])?;
390                        return Ok(Some(SqlPlan::TextSearch {
391                            collection: table.name.clone(),
392                            query: query_text,
393                            top_k: 1000,
394                            fuzzy: true,
395                            filters: Vec::new(),
396                        }));
397                    }
398                }
399                SearchTrigger::SpatialDWithin
400                | SearchTrigger::SpatialContains
401                | SearchTrigger::SpatialIntersects
402                | SearchTrigger::SpatialWithin => {
403                    return plan_spatial_from_where(&name, func, table);
404                }
405                _ => {}
406            }
407        }
408        // AND: check left and right for search triggers, combine non-search as filters.
409        ast::Expr::BinaryOp {
410            left,
411            op: ast::BinaryOperator::And,
412            right,
413        } => {
414            if let Some(plan) = try_extract_where_search(left, table, functions)? {
415                return Ok(Some(plan));
416            }
417            if let Some(plan) = try_extract_where_search(right, table, functions)? {
418                return Ok(Some(plan));
419            }
420        }
421        _ => {}
422    }
423    Ok(None)
424}
425
426fn plan_spatial_from_where(
427    name: &str,
428    func: &ast::Function,
429    table: &crate::resolver::columns::ResolvedTable,
430) -> Result<Option<SqlPlan>> {
431    let predicate = match name {
432        "st_dwithin" => SpatialPredicate::DWithin,
433        "st_contains" => SpatialPredicate::Contains,
434        "st_intersects" => SpatialPredicate::Intersects,
435        "st_within" => SpatialPredicate::Within,
436        _ => return Ok(None),
437    };
438    let args = extract_func_args(func)?;
439    if args.is_empty() {
440        return Err(SqlError::MissingField {
441            field: "geometry column".into(),
442            context: name.into(),
443        });
444    }
445    let field = extract_column_name(&args[0])?;
446    let geom_arg = args.get(1).ok_or_else(|| SqlError::MissingField {
447        field: "query geometry".into(),
448        context: name.into(),
449    })?;
450    let geom_str = extract_geometry_arg(geom_arg)?;
451    let distance = if args.len() >= 3 {
452        extract_float(&args[2]).unwrap_or(0.0)
453    } else {
454        0.0
455    };
456    Ok(Some(SqlPlan::SpatialScan {
457        collection: table.name.clone(),
458        field,
459        predicate,
460        query_geometry: geom_str.into_bytes(),
461        distance_meters: distance,
462        attribute_filters: Vec::new(),
463        limit: 1000,
464        projection: Vec::new(),
465    }))
466}
467
468/// Apply ORDER BY, detecting search-triggering sort expressions.
469fn apply_order_by(
470    plan: &SqlPlan,
471    order_by: &ast::OrderBy,
472    functions: &FunctionRegistry,
473) -> Result<SqlPlan> {
474    let exprs = match &order_by.kind {
475        ast::OrderByKind::Expressions(exprs) => exprs,
476        ast::OrderByKind::All(_) => return Ok(plan.clone()),
477    };
478
479    if exprs.is_empty() {
480        return Ok(plan.clone());
481    }
482
483    // Check first ORDER BY expression for search triggers.
484    let first = &exprs[0];
485    if let Some(search_plan) = try_extract_sort_search(&first.expr, plan, functions)? {
486        return Ok(search_plan);
487    }
488
489    // Normal sort keys.
490    let sort_keys: Vec<SortKey> = exprs
491        .iter()
492        .map(|o| {
493            Ok(SortKey {
494                expr: convert_expr(&o.expr)?,
495                ascending: o.options.asc.unwrap_or(true),
496                nulls_first: o.options.nulls_first.unwrap_or(false),
497            })
498        })
499        .collect::<Result<_>>()?;
500
501    match plan {
502        SqlPlan::Scan {
503            collection,
504            alias,
505            engine,
506            filters,
507            projection,
508            limit,
509            offset,
510            distinct,
511            window_functions,
512            ..
513        } => Ok(SqlPlan::Scan {
514            collection: collection.clone(),
515            alias: alias.clone(),
516            engine: *engine,
517            filters: filters.clone(),
518            projection: projection.clone(),
519            sort_keys,
520            limit: *limit,
521            offset: *offset,
522            distinct: *distinct,
523            window_functions: window_functions.clone(),
524        }),
525        _ => Ok(plan.clone()),
526    }
527}
528
529/// Try to detect search-triggering ORDER BY expressions.
530fn try_extract_sort_search(
531    expr: &ast::Expr,
532    plan: &SqlPlan,
533    functions: &FunctionRegistry,
534) -> Result<Option<SqlPlan>> {
535    if let ast::Expr::Function(func) = expr {
536        let name = func
537            .name
538            .0
539            .iter()
540            .map(|p| match p {
541                ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
542                _ => String::new(),
543            })
544            .collect::<Vec<_>>()
545            .join(".");
546        let collection = match plan {
547            SqlPlan::Scan { collection, .. } => collection.clone(),
548            _ => return Ok(None),
549        };
550        let args = extract_func_args(func)?;
551
552        match functions.search_trigger(&name) {
553            SearchTrigger::VectorSearch => {
554                if args.len() < 2 {
555                    return Ok(None);
556                }
557                let field = extract_column_name(&args[0])?;
558                let vector = extract_float_array(&args[1])?;
559                let limit = match plan {
560                    SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
561                    _ => 10,
562                };
563                return Ok(Some(SqlPlan::VectorSearch {
564                    collection,
565                    field,
566                    query_vector: vector,
567                    top_k: limit,
568                    ef_search: limit * 2,
569                    filters: match plan {
570                        SqlPlan::Scan { filters, .. } => filters.clone(),
571                        _ => Vec::new(),
572                    },
573                }));
574            }
575            SearchTrigger::TextSearch => {
576                if args.len() >= 2 {
577                    let query_text = extract_string_literal(&args[1])?;
578                    let limit = match plan {
579                        SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
580                        _ => 10,
581                    };
582                    return Ok(Some(SqlPlan::TextSearch {
583                        collection,
584                        query: query_text,
585                        top_k: limit,
586                        fuzzy: true,
587                        filters: match plan {
588                            SqlPlan::Scan { filters, .. } => filters.clone(),
589                            _ => Vec::new(),
590                        },
591                    }));
592                }
593            }
594            SearchTrigger::HybridSearch => {
595                return plan_hybrid_from_sort(&args, &collection, plan, functions);
596            }
597            _ => {}
598        }
599    }
600    Ok(None)
601}
602
603fn plan_hybrid_from_sort(
604    args: &[ast::Expr],
605    collection: &str,
606    plan: &SqlPlan,
607    _functions: &FunctionRegistry,
608) -> Result<Option<SqlPlan>> {
609    // rrf_score(vector_distance(...), bm25_score(...), k1, k2)
610    if args.len() < 2 {
611        return Ok(None);
612    }
613    let vector = match &args[0] {
614        ast::Expr::Function(f) => {
615            let inner_args = extract_func_args(f)?;
616            if inner_args.len() >= 2 {
617                extract_float_array(&inner_args[1]).unwrap_or_default()
618            } else {
619                Vec::new()
620            }
621        }
622        _ => Vec::new(),
623    };
624    let text = match &args[1] {
625        ast::Expr::Function(f) => {
626            let inner_args = extract_func_args(f)?;
627            if inner_args.len() >= 2 {
628                extract_string_literal(&inner_args[1]).unwrap_or_default()
629            } else {
630                String::new()
631            }
632        }
633        _ => String::new(),
634    };
635    let k1 = args
636        .get(2)
637        .and_then(|e| extract_float(e).ok())
638        .unwrap_or(60.0);
639    let k2 = args
640        .get(3)
641        .and_then(|e| extract_float(e).ok())
642        .unwrap_or(60.0);
643    let limit = match plan {
644        SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
645        _ => 10,
646    };
647    let vector_weight = k2 as f32 / (k1 as f32 + k2 as f32);
648
649    Ok(Some(SqlPlan::HybridSearch {
650        collection: collection.into(),
651        query_vector: vector,
652        query_text: text,
653        top_k: limit,
654        ef_search: limit * 2,
655        vector_weight,
656        fuzzy: true,
657    }))
658}
659
660/// Apply LIMIT and OFFSET to a plan.
661fn apply_limit(mut plan: SqlPlan, limit_clause: &Option<ast::LimitClause>) -> SqlPlan {
662    let (limit_val, offset_val) = match limit_clause {
663        None => (None, 0usize),
664        Some(ast::LimitClause::LimitOffset { limit, offset, .. }) => {
665            let lv = limit.as_ref().and_then(|e| match e {
666                ast::Expr::Value(v) => match &v.value {
667                    ast::Value::Number(n, _) => n.parse::<usize>().ok(),
668                    _ => None,
669                },
670                _ => None,
671            });
672            let ov = offset
673                .as_ref()
674                .and_then(|o| match &o.value {
675                    ast::Expr::Value(v) => match &v.value {
676                        ast::Value::Number(n, _) => n.parse::<usize>().ok(),
677                        _ => None,
678                    },
679                    _ => None,
680                })
681                .unwrap_or(0);
682            (lv, ov)
683        }
684        Some(ast::LimitClause::OffsetCommaLimit { offset, limit }) => {
685            let lv = match limit {
686                ast::Expr::Value(v) => match &v.value {
687                    ast::Value::Number(n, _) => n.parse::<usize>().ok(),
688                    _ => None,
689                },
690                _ => None,
691            };
692            let ov = match offset {
693                ast::Expr::Value(v) => match &v.value {
694                    ast::Value::Number(n, _) => n.parse::<usize>().ok(),
695                    _ => None,
696                },
697                _ => None,
698            }
699            .unwrap_or(0);
700            (lv, ov)
701        }
702    };
703
704    match plan {
705        SqlPlan::Scan {
706            ref mut limit,
707            ref mut offset,
708            ..
709        } => {
710            *limit = limit_val;
711            *offset = offset_val;
712        }
713        SqlPlan::Aggregate {
714            limit: ref mut l, ..
715        } => {
716            if let Some(lv) = limit_val {
717                *l = lv;
718            }
719        }
720        _ => {}
721    }
722    plan
723}
724
725// ── Helpers ──
726
727/// Convert SELECT projection items.
728pub fn convert_projection(items: &[ast::SelectItem]) -> Result<Vec<Projection>> {
729    let mut result = Vec::new();
730    for item in items {
731        match item {
732            ast::SelectItem::UnnamedExpr(expr) => {
733                let sql_expr = convert_expr(expr)?;
734                match &sql_expr {
735                    SqlExpr::Column { table, name } => {
736                        result.push(Projection::Column(qualified_name(table.as_deref(), name)));
737                    }
738                    SqlExpr::Wildcard => {
739                        result.push(Projection::Star);
740                    }
741                    _ => {
742                        result.push(Projection::Computed {
743                            expr: sql_expr,
744                            alias: format!("{expr}"),
745                        });
746                    }
747                }
748            }
749            ast::SelectItem::ExprWithAlias { expr, alias } => {
750                let sql_expr = convert_expr(expr)?;
751                result.push(Projection::Computed {
752                    expr: sql_expr,
753                    alias: normalize_ident(alias),
754                });
755            }
756            ast::SelectItem::Wildcard(_) => {
757                result.push(Projection::Star);
758            }
759            ast::SelectItem::QualifiedWildcard(kind, _) => {
760                let table_name = match kind {
761                    ast::SelectItemQualifiedWildcardKind::ObjectName(name) => {
762                        crate::parser::normalize::normalize_object_name(name)
763                    }
764                    _ => String::new(),
765                };
766                result.push(Projection::QualifiedStar(table_name));
767            }
768        }
769    }
770    Ok(result)
771}
772
/// Build a column reference: `table.name` when a table qualifier is given,
/// otherwise just `name`. The qualifier is used verbatim, even when empty.
pub fn qualified_name(table: Option<&str>, name: &str) -> String {
    match table {
        Some(qualifier) => format!("{qualifier}.{name}"),
        None => name.to_string(),
    }
}
777
778/// Convert a WHERE expression into a list of Filter.
779pub fn convert_where_to_filters(expr: &ast::Expr) -> Result<Vec<Filter>> {
780    let sql_expr = convert_expr(expr)?;
781    Ok(vec![Filter {
782        expr: FilterExpr::Expr(sql_expr),
783    }])
784}
785
786pub(crate) fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>> {
787    match &func.args {
788        ast::FunctionArguments::List(args) => Ok(args
789            .args
790            .iter()
791            .filter_map(|a| match a {
792                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e.clone()),
793                _ => None,
794            })
795            .collect()),
796        _ => Ok(Vec::new()),
797    }
798}
799
800/// Evaluate a constant SqlExpr to a SqlValue. Delegates to the shared
801/// `const_fold::fold_constant` helper so that zero-arg scalar functions
802/// like `now()` and `current_timestamp` go through the same evaluator
803/// as the runtime expression path.
804fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
805    super::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
806}
807
808/// Extract a geometry argument: handles ST_Point(lon, lat), ST_GeomFromGeoJSON('...'),
809/// or a raw string literal containing GeoJSON.
810fn extract_geometry_arg(expr: &ast::Expr) -> Result<String> {
811    match expr {
812        // ST_Point(lon, lat) → GeoJSON Point
813        ast::Expr::Function(func) => {
814            let name = func
815                .name
816                .0
817                .iter()
818                .map(|p| match p {
819                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
820                    _ => String::new(),
821                })
822                .collect::<Vec<_>>()
823                .join(".");
824            let args = extract_func_args(func)?;
825            match name.as_str() {
826                "st_point" if args.len() >= 2 => {
827                    let lon = extract_float(&args[0])?;
828                    let lat = extract_float(&args[1])?;
829                    Ok(format!(r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#))
830                }
831                "st_geomfromgeojson" if !args.is_empty() => extract_string_literal(&args[0]),
832                _ => Ok(format!("{expr}")),
833            }
834        }
835        // Raw string literal: assumed to be GeoJSON.
836        _ => extract_string_literal(expr).or_else(|_| Ok(format!("{expr}"))),
837    }
838}
839
840fn extract_column_name(expr: &ast::Expr) -> Result<String> {
841    match expr {
842        ast::Expr::Identifier(ident) => Ok(normalize_ident(ident)),
843        ast::Expr::CompoundIdentifier(parts) => Ok(parts
844            .iter()
845            .map(normalize_ident)
846            .collect::<Vec<_>>()
847            .join(".")),
848        _ => Err(SqlError::Unsupported {
849            detail: format!("expected column name, got: {expr}"),
850        }),
851    }
852}
853
854pub(crate) fn extract_string_literal(expr: &ast::Expr) -> Result<String> {
855    match expr {
856        ast::Expr::Value(v) => match &v.value {
857            ast::Value::SingleQuotedString(s) | ast::Value::DoubleQuotedString(s) => Ok(s.clone()),
858            _ => Err(SqlError::Unsupported {
859                detail: format!("expected string literal, got: {expr}"),
860            }),
861        },
862        _ => Err(SqlError::Unsupported {
863            detail: format!("expected string literal, got: {expr}"),
864        }),
865    }
866}
867
868pub(crate) fn extract_float(expr: &ast::Expr) -> Result<f64> {
869    match expr {
870        ast::Expr::Value(v) => match &v.value {
871            ast::Value::Number(n, _) => n.parse::<f64>().map_err(|_| SqlError::TypeMismatch {
872                detail: format!("expected number: {n}"),
873            }),
874            _ => Err(SqlError::TypeMismatch {
875                detail: format!("expected number, got: {expr}"),
876            }),
877        },
878        // Handle negative numbers: -73.9855 is parsed as UnaryOp { Minus, 73.9855 }
879        ast::Expr::UnaryOp {
880            op: ast::UnaryOperator::Minus,
881            expr: inner,
882        } => extract_float(inner).map(|f| -f),
883        _ => Err(SqlError::TypeMismatch {
884            detail: format!("expected number, got: {expr}"),
885        }),
886    }
887}
888
889/// Extract a float array from ARRAY[...] or make_array(...) expression.
890fn extract_float_array(expr: &ast::Expr) -> Result<Vec<f32>> {
891    match expr {
892        ast::Expr::Array(ast::Array { elem, .. }) => elem
893            .iter()
894            .map(|e| extract_float(e).map(|f| f as f32))
895            .collect(),
896        ast::Expr::Function(func) => {
897            let name = func
898                .name
899                .0
900                .iter()
901                .map(|p| match p {
902                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
903                    _ => String::new(),
904                })
905                .collect::<Vec<_>>()
906                .join(".");
907            if name == "make_array" || name == "array" {
908                let args = extract_func_args(func)?;
909                args.iter()
910                    .map(|e| extract_float(e).map(|f| f as f32))
911                    .collect()
912            } else {
913                Err(SqlError::Unsupported {
914                    detail: format!("expected array, got function: {name}"),
915                })
916            }
917        }
918        _ => Err(SqlError::Unsupported {
919            detail: format!("expected array literal, got: {expr}"),
920        }),
921    }
922}
923
924/// Check if a SELECT has the DISTINCT keyword.
925fn try_plan_join(
926    select: &Select,
927    scope: &TableScope,
928    catalog: &dyn SqlCatalog,
929    functions: &FunctionRegistry,
930) -> Result<Option<SqlPlan>> {
931    if select.from.len() != 1 {
932        return Ok(None);
933    }
934    let from = &select.from[0];
935    if from.joins.is_empty() {
936        return Ok(None);
937    }
938    super::join::plan_join_from_select(select, scope, catalog, functions)
939}
940
941/// Catalog wrapper that resolves CTE names as schemaless document collections.
942struct CteCatalog<'a> {
943    inner: &'a dyn SqlCatalog,
944    cte_names: Vec<String>,
945}
946
947impl SqlCatalog for CteCatalog<'_> {
948    fn get_collection(
949        &self,
950        name: &str,
951    ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
952        // Check CTE names first.
953        if self.cte_names.iter().any(|n| n == name) {
954            return Ok(Some(CollectionInfo {
955                name: name.into(),
956                engine: EngineType::DocumentSchemaless,
957                columns: Vec::new(),
958                primary_key: Some("id".into()),
959                has_auto_tier: false,
960            }));
961        }
962        self.inner.get_collection(name)
963    }
964}
965
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::registry::FunctionRegistry;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// Fixed catalog used by the planner tests below.
    ///
    /// `products`, `users`, `orders`, `docs` and `tags` are schemaless document
    /// collections keyed by `id`. `user_prefs` is a key-value collection keyed
    /// by `key`. Every other name is unknown.
    struct TestCatalog;

    /// Build a test collection with no declared columns and auto-tiering off.
    fn test_collection(name: &str, engine: EngineType, primary_key: &str) -> CollectionInfo {
        CollectionInfo {
            name: name.into(),
            engine,
            columns: Vec::new(),
            primary_key: Some(primary_key.into()),
            has_auto_tier: false,
        }
    }

    impl SqlCatalog for TestCatalog {
        fn get_collection(
            &self,
            name: &str,
        ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
            let info = match name {
                "products" | "users" | "orders" | "docs" | "tags" => {
                    test_collection(name, EngineType::DocumentSchemaless, "id")
                }
                "user_prefs" => test_collection(name, EngineType::KeyValue, "key"),
                _ => return Ok(None),
            };
            Ok(Some(info))
        }
    }

    /// Parse `sql` and plan its first statement against [`TestCatalog`] with a
    /// fresh [`FunctionRegistry`].
    ///
    /// Panics if parsing fails, if that statement is not a query, or if
    /// planning fails.
    fn plan_select_sql(sql: &str) -> SqlPlan {
        let statements = parse_sql(sql).unwrap();
        match &statements[0] {
            Statement::Query(query) => {
                plan_query(query, &TestCatalog, &FunctionRegistry::new()).unwrap()
            }
            _ => panic!("expected query statement"),
        }
    }

    /// `IN (subquery)` under an aggregate must become a semi-join placed
    /// beneath the aggregate, so rows are filtered before aggregation.
    #[test]
    fn aggregate_subquery_join_filters_input_before_aggregation() {
        let plan = plan_select_sql(
            "SELECT AVG(price) FROM products WHERE category IN (SELECT DISTINCT category FROM products WHERE qty > 100)",
        );

        let SqlPlan::Aggregate { input, .. } = plan else {
            panic!("expected aggregate plan");
        };
        let SqlPlan::Join {
            left,
            join_type,
            on,
            ..
        } = *input
        else {
            panic!("expected semi-join below aggregate");
        };

        assert_eq!(join_type, JoinType::Semi);
        assert_eq!(on, vec![("category".into(), "category".into())]);
        assert!(matches!(*left, SqlPlan::Scan { .. }));
    }

    /// A scalar-subquery comparison must be evaluated after the join. The
    /// left scan therefore must not project columns away before the filter.
    #[test]
    fn scalar_subquery_defers_projection_until_after_join_filter() {
        let plan = plan_select_sql(
            "SELECT user_id FROM orders WHERE amount > (SELECT AVG(amount) FROM orders)",
        );

        let SqlPlan::Join {
            left,
            projection,
            filters,
            ..
        } = plan
        else {
            panic!("expected join plan");
        };
        let SqlPlan::Scan {
            projection: scan_projection,
            ..
        } = *left
        else {
            panic!("expected scan on join left");
        };

        assert!(scan_projection.is_empty(), "scan projected too early");
        assert_eq!(projection.len(), 1);
        match &projection[0] {
            Projection::Column(name) => assert_eq!(name, "user_id"),
            other => panic!("expected user_id projection, got {other:?}"),
        }
        assert!(
            !filters.is_empty(),
            "scalar comparison should stay post-join"
        );
    }

    /// Chained joins nest left-deep, and each level keeps its alias-qualified
    /// ON keys.
    #[test]
    fn chained_join_preserves_qualified_on_keys() {
        let plan = plan_select_sql(
            "SELECT d.name, t.tag, p.theme \
             FROM docs d \
             LEFT JOIN tags t ON d.id = t.doc_id \
             INNER JOIN user_prefs p ON d.id = p.key",
        );

        let SqlPlan::Join { left, on, .. } = plan else {
            panic!("expected outer join plan");
        };
        assert_eq!(on, vec![("d.id".into(), "p.key".into())]);

        let SqlPlan::Join { on: inner_on, .. } = *left else {
            panic!("expected nested left join");
        };
        assert_eq!(inner_on, vec![("d.id".into(), "t.doc_id".into())]);
    }
}
1113            panic!("expected nested left join");
1114        };
1115        assert_eq!(inner_on, vec![("d.id".into(), "t.doc_id".into())]);
1116    }
1117}