Skip to main content

nodedb_sql/planner/
select.rs

1//! SELECT query planning: FROM → WHERE → GROUP BY → HAVING → SELECT → ORDER BY → LIMIT.
2//!
3//! This is the main entry point for SELECT statement conversion. It detects
4//! search patterns (vector, text, hybrid, spatial) directly from the AST
5//! instead of reverse-engineering an optimizer's output.
6
7use sqlparser::ast::{self, Query, Select, SetExpr};
8
9use crate::error::{Result, SqlError};
10use crate::functions::registry::{FunctionRegistry, SearchTrigger};
11use crate::parser::normalize::normalize_ident;
12use crate::resolver::columns::TableScope;
13use crate::resolver::expr::convert_expr;
14use crate::types::*;
15
16/// Plan a SELECT query.
17pub fn plan_query(
18    query: &Query,
19    catalog: &dyn SqlCatalog,
20    functions: &FunctionRegistry,
21) -> Result<SqlPlan> {
22    // Handle CTEs (WITH clause).
23    if let Some(with) = &query.with
24        && with.recursive
25    {
26        return super::cte::plan_recursive_cte(query, catalog, functions);
27    }
28    // Non-recursive CTEs: plan each CTE subquery and the outer query.
29    if let Some(with) = &query.with
30        && !with.cte_tables.is_empty()
31    {
32        let inner_query = Query {
33            with: None,
34            body: query.body.clone(),
35            order_by: query.order_by.clone(),
36            limit_clause: query.limit_clause.clone(),
37            fetch: query.fetch.clone(),
38            locks: query.locks.clone(),
39            for_clause: query.for_clause.clone(),
40            settings: query.settings.clone(),
41            format_clause: query.format_clause.clone(),
42            pipe_operators: query.pipe_operators.clone(),
43        };
44
45        // Plan each CTE subquery.
46        let mut definitions = Vec::new();
47        let mut cte_names = Vec::new();
48        for cte in &with.cte_tables {
49            let name = normalize_ident(&cte.alias.name);
50            let cte_plan = plan_query(&cte.query, catalog, functions)?;
51            definitions.push((name.clone(), cte_plan));
52            cte_names.push(name);
53        }
54
55        // Build CTE-aware catalog so the outer query can reference CTE names.
56        let cte_catalog = CteCatalog {
57            inner: catalog,
58            cte_names,
59        };
60        let outer = plan_query(&inner_query, &cte_catalog, functions)?;
61
62        return Ok(SqlPlan::Cte {
63            definitions,
64            outer: Box::new(outer),
65        });
66    }
67
68    // Handle UNION.
69    match &*query.body {
70        SetExpr::Select(select) => {
71            let mut plan = plan_select(select, catalog, functions)?;
72            if let Some(order_by) = &query.order_by {
73                plan = apply_order_by(&plan, order_by, functions)?;
74            }
75            plan = apply_limit(plan, &query.limit_clause);
76            Ok(plan)
77        }
78        SetExpr::SetOperation {
79            op,
80            left,
81            right,
82            set_quantifier,
83        } => super::union::plan_set_operation(op, left, right, set_quantifier, catalog, functions),
84        _ => Err(SqlError::Unsupported {
85            detail: format!("query body type: {}", query.body),
86        }),
87    }
88}
89
/// Plan a single SELECT statement (no UNION, no CTE wrapper).
///
/// The function returns as soon as one of the specialised shapes matches:
/// a FROM-less constant query, a JOIN, a search-triggering WHERE predicate,
/// or an aggregation. Otherwise it builds a single-table scan through the
/// engine rules and wraps it with the joins produced by subquery extraction.
///
/// ORDER BY and LIMIT are not handled here; the scan is created with empty
/// sort keys, no limit and offset 0.
///
/// # Errors
///
/// Returns `SqlError::Unsupported` when FROM names several tables without a
/// JOIN, and propagates errors from resolution, conversion and sub-planners.
fn plan_select(
    select: &Select,
    catalog: &dyn SqlCatalog,
    functions: &FunctionRegistry,
) -> Result<SqlPlan> {
    // 1. Resolve FROM tables.
    let scope = TableScope::resolve_from(catalog, &select.from)?;

    // 2. Handle constant queries (no FROM clause): SELECT 1, SELECT 'hello', etc.
    if select.from.is_empty() {
        let projection = convert_projection(&select.projection)?;
        let mut columns = Vec::new();
        let mut values = Vec::new();
        for (i, proj) in projection.iter().enumerate() {
            match proj {
                // Computed expressions are constant-folded; a failed fold
                // yields NULL (see `eval_constant_expr`).
                Projection::Computed { expr, alias } => {
                    columns.push(alias.clone());
                    values.push(eval_constant_expr(expr, functions));
                }
                // A bare column has no table to read from, so it is NULL.
                Projection::Column(name) => {
                    columns.push(name.clone());
                    values.push(SqlValue::Null);
                }
                // Any other projection gets a positional name and NULL.
                _ => {
                    columns.push(format!("col{i}"));
                    values.push(SqlValue::Null);
                }
            }
        }
        return Ok(SqlPlan::ConstantResult { columns, values });
    }

    // 3. Check for JOINs.
    if let Some(plan) = try_plan_join(select, &scope, catalog, functions)? {
        return Ok(plan);
    }

    // 4. Single-table query.
    let table = scope.single_table().ok_or_else(|| SqlError::Unsupported {
        detail: "multi-table FROM without JOIN".into(),
    })?;

    // 5. Extract subqueries from WHERE and rewrite as semi/anti joins.
    let (subquery_joins, effective_where) = if let Some(expr) = &select.selection {
        let extraction = super::subquery::extract_subqueries(expr, catalog, functions)?;
        (extraction.joins, extraction.remaining_where)
    } else {
        (Vec::new(), None)
    };

    // 6. Convert remaining WHERE filters.
    let filters = match &effective_where {
        Some(expr) => {
            // Check for search-triggering functions in WHERE.
            // A match returns the search plan immediately; the subquery
            // joins extracted above are not attached to it.
            if let Some(plan) = try_extract_where_search(expr, table, functions)? {
                return Ok(plan);
            }
            convert_where_to_filters(expr)?
        }
        None => Vec::new(),
    };

    // 7. Check for GROUP BY / aggregation.
    if has_aggregation(select, functions) {
        let mut plan =
            super::aggregate::plan_aggregate(select, table, &filters, &scope, functions)?;

        // Semi/anti subquery joins belong below the aggregate so they filter
        // the input rows before grouping. Scalar subqueries remain above the
        // aggregate because their column-vs-column comparison is evaluated
        // after the cross join materializes the scalar result row.
        if let SqlPlan::Aggregate { input, .. } = &mut plan {
            // Temporarily swap in an empty placeholder so the boxed input
            // can be moved out, wrapped, and written back.
            let mut base_input = std::mem::replace(
                input,
                Box::new(SqlPlan::ConstantResult {
                    columns: Vec::new(),
                    values: Vec::new(),
                }),
            );
            for sq in subquery_joins
                .iter()
                .filter(|sq| sq.join_type != JoinType::Cross)
            {
                base_input = Box::new(SqlPlan::Join {
                    left: base_input,
                    right: Box::new(sq.inner_plan.clone()),
                    on: vec![(sq.outer_column.clone(), sq.inner_column.clone())],
                    join_type: sq.join_type,
                    condition: None,
                    limit: 10000,
                    projection: Vec::new(),
                    filters: Vec::new(),
                });
            }
            *input = base_input;
        }

        // Cross joins (scalar subqueries) wrap the aggregate itself.
        for sq in subquery_joins
            .into_iter()
            .filter(|sq| sq.join_type == JoinType::Cross)
        {
            plan = SqlPlan::Join {
                left: Box::new(plan),
                right: Box::new(sq.inner_plan),
                on: vec![(sq.outer_column, sq.inner_column)],
                join_type: sq.join_type,
                condition: None,
                limit: 10000,
                projection: Vec::new(),
                filters: Vec::new(),
            };
        }
        return Ok(plan);
    }

    // 8. Convert projection.
    let projection = convert_projection(&select.projection)?;

    // 9. Convert window functions (SELECT with OVER).
    let window_functions = super::window::extract_window_functions(&select.projection, functions)?;

    // 10. Build base scan plan.
    // When subquery joins will wrap the scan, the projection is applied on
    // the outermost join instead (see the end of this function), so the scan
    // itself gets an empty projection.
    let scan_projection = if subquery_joins.is_empty() {
        projection.clone()
    } else {
        Vec::new()
    };

    let rules = crate::engine_rules::resolve_engine_rules(table.info.engine);
    let mut plan = rules.plan_scan(crate::engine_rules::ScanParams {
        collection: table.name.clone(),
        alias: table.alias.clone(),
        filters,
        projection: scan_projection,
        sort_keys: Vec::new(),
        limit: None,
        offset: 0,
        distinct: select.distinct.is_some(),
        window_functions,
        indexes: table.info.indexes.clone(),
    })?;

    // 11. Wrap with subquery joins (semi/anti/cross) if any.
    for sq in subquery_joins {
        // For cross-joins (scalar subqueries), move column-referencing filters
        // from the base scan to the join's post-filters. The filter compares
        // a field from the base scan with a field from the subquery result,
        // so it can only be evaluated after the join merges both sides.
        // Filters are only moved while `plan` is still a bare scan, i.e. for
        // the first join in this loop.
        let join_filters = if sq.join_type == JoinType::Cross {
            if let SqlPlan::Scan {
                ref mut filters, ..
            } = plan
            {
                // Move filters that reference the scalar result column to the join.
                let mut moved = Vec::new();
                filters.retain(|f| {
                    if has_column_ref_filter(&f.expr) {
                        moved.push(f.clone());
                        false
                    } else {
                        true
                    }
                });
                moved
            } else {
                Vec::new()
            }
        } else {
            Vec::new()
        };

        plan = SqlPlan::Join {
            left: Box::new(plan),
            right: Box::new(sq.inner_plan),
            on: vec![(sq.outer_column, sq.inner_column)],
            join_type: sq.join_type,
            condition: None,
            limit: 10000,
            projection: Vec::new(),
            filters: join_filters,
        };
    }

    // If the plan ended up as a join, the SELECT projection goes on the
    // outermost join; a bare scan already carries it.
    if let SqlPlan::Join {
        projection: ref mut join_projection,
        ..
    } = plan
    {
        *join_projection = projection;
    }

    Ok(plan)
}
284
285/// Check if a filter expression contains a column-vs-column comparison
286/// (from scalar subquery rewriting). These filters must be evaluated
287/// post-join, not pre-join, since one column comes from the subquery result.
288fn has_column_ref_filter(expr: &FilterExpr) -> bool {
289    match expr {
290        FilterExpr::Expr(sql_expr) => has_column_comparison(sql_expr),
291        FilterExpr::And(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
292        FilterExpr::Or(filters) => filters.iter().any(|f| has_column_ref_filter(&f.expr)),
293        _ => false,
294    }
295}
296
297fn has_column_comparison(expr: &SqlExpr) -> bool {
298    match expr {
299        SqlExpr::BinaryOp { left, right, .. } => {
300            let left_is_col = matches!(left.as_ref(), SqlExpr::Column { .. });
301            let right_is_col = matches!(right.as_ref(), SqlExpr::Column { .. });
302            if left_is_col && right_is_col {
303                return true;
304            }
305            has_column_comparison(left) || has_column_comparison(right)
306        }
307        _ => false,
308    }
309}
310
311/// Check if a SELECT has aggregation (GROUP BY or aggregate functions in projection).
312fn has_aggregation(select: &Select, functions: &FunctionRegistry) -> bool {
313    let group_by_non_empty = match &select.group_by {
314        ast::GroupByExpr::All(_) => true,
315        ast::GroupByExpr::Expressions(exprs, _) => !exprs.is_empty(),
316    };
317    if group_by_non_empty {
318        return true;
319    }
320    for item in &select.projection {
321        if let ast::SelectItem::UnnamedExpr(expr) | ast::SelectItem::ExprWithAlias { expr, .. } =
322            item
323            && crate::aggregate_walk::contains_aggregate(expr, functions)
324        {
325            return true;
326        }
327    }
328    false
329}
330
331/// Try to detect search-triggering patterns in WHERE clause.
332fn try_extract_where_search(
333    expr: &ast::Expr,
334    table: &crate::resolver::columns::ResolvedTable,
335    functions: &FunctionRegistry,
336) -> Result<Option<SqlPlan>> {
337    match expr {
338        ast::Expr::Function(func) => {
339            let name = func
340                .name
341                .0
342                .iter()
343                .map(|p| match p {
344                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
345                    _ => String::new(),
346                })
347                .collect::<Vec<_>>()
348                .join(".");
349            match functions.search_trigger(&name) {
350                SearchTrigger::TextMatch => {
351                    let args = extract_func_args(func)?;
352                    if args.len() >= 2 {
353                        let query_text = extract_string_literal(&args[1])?;
354                        return Ok(Some(SqlPlan::TextSearch {
355                            collection: table.name.clone(),
356                            query: query_text,
357                            top_k: 1000,
358                            fuzzy: true,
359                            filters: Vec::new(),
360                        }));
361                    }
362                }
363                SearchTrigger::SpatialDWithin
364                | SearchTrigger::SpatialContains
365                | SearchTrigger::SpatialIntersects
366                | SearchTrigger::SpatialWithin => {
367                    return plan_spatial_from_where(&name, func, table);
368                }
369                _ => {}
370            }
371        }
372        // AND: check left and right for search triggers, combine non-search as filters.
373        ast::Expr::BinaryOp {
374            left,
375            op: ast::BinaryOperator::And,
376            right,
377        } => {
378            if let Some(plan) = try_extract_where_search(left, table, functions)? {
379                return Ok(Some(plan));
380            }
381            if let Some(plan) = try_extract_where_search(right, table, functions)? {
382                return Ok(Some(plan));
383            }
384        }
385        _ => {}
386    }
387    Ok(None)
388}
389
390fn plan_spatial_from_where(
391    name: &str,
392    func: &ast::Function,
393    table: &crate::resolver::columns::ResolvedTable,
394) -> Result<Option<SqlPlan>> {
395    let predicate = match name {
396        "st_dwithin" => SpatialPredicate::DWithin,
397        "st_contains" => SpatialPredicate::Contains,
398        "st_intersects" => SpatialPredicate::Intersects,
399        "st_within" => SpatialPredicate::Within,
400        _ => return Ok(None),
401    };
402    let args = extract_func_args(func)?;
403    if args.is_empty() {
404        return Err(SqlError::MissingField {
405            field: "geometry column".into(),
406            context: name.into(),
407        });
408    }
409    let field = extract_column_name(&args[0])?;
410    let geom_arg = args.get(1).ok_or_else(|| SqlError::MissingField {
411        field: "query geometry".into(),
412        context: name.into(),
413    })?;
414    let geom_str = extract_geometry_arg(geom_arg)?;
415    let distance = if args.len() >= 3 {
416        extract_float(&args[2]).unwrap_or(0.0)
417    } else {
418        0.0
419    };
420    Ok(Some(SqlPlan::SpatialScan {
421        collection: table.name.clone(),
422        field,
423        predicate,
424        query_geometry: geom_str.into_bytes(),
425        distance_meters: distance,
426        attribute_filters: Vec::new(),
427        limit: 1000,
428        projection: Vec::new(),
429    }))
430}
431
432/// Apply ORDER BY, detecting search-triggering sort expressions.
433fn apply_order_by(
434    plan: &SqlPlan,
435    order_by: &ast::OrderBy,
436    functions: &FunctionRegistry,
437) -> Result<SqlPlan> {
438    let exprs = match &order_by.kind {
439        ast::OrderByKind::Expressions(exprs) => exprs,
440        ast::OrderByKind::All(_) => return Ok(plan.clone()),
441    };
442
443    if exprs.is_empty() {
444        return Ok(plan.clone());
445    }
446
447    // Check first ORDER BY expression for search triggers.
448    let first = &exprs[0];
449    if let Some(search_plan) = try_extract_sort_search(&first.expr, plan, functions)? {
450        return Ok(search_plan);
451    }
452
453    // Normal sort keys.
454    let sort_keys: Vec<SortKey> = exprs
455        .iter()
456        .map(|o| {
457            Ok(SortKey {
458                expr: convert_expr(&o.expr)?,
459                ascending: o.options.asc.unwrap_or(true),
460                nulls_first: o.options.nulls_first.unwrap_or(false),
461            })
462        })
463        .collect::<Result<_>>()?;
464
465    match plan {
466        SqlPlan::Scan {
467            collection,
468            alias,
469            engine,
470            filters,
471            projection,
472            limit,
473            offset,
474            distinct,
475            window_functions,
476            ..
477        } => Ok(SqlPlan::Scan {
478            collection: collection.clone(),
479            alias: alias.clone(),
480            engine: *engine,
481            filters: filters.clone(),
482            projection: projection.clone(),
483            sort_keys,
484            limit: *limit,
485            offset: *offset,
486            distinct: *distinct,
487            window_functions: window_functions.clone(),
488        }),
489        _ => Ok(plan.clone()),
490    }
491}
492
493/// Try to detect search-triggering ORDER BY expressions.
494fn try_extract_sort_search(
495    expr: &ast::Expr,
496    plan: &SqlPlan,
497    functions: &FunctionRegistry,
498) -> Result<Option<SqlPlan>> {
499    if let ast::Expr::Function(func) = expr {
500        let name = func
501            .name
502            .0
503            .iter()
504            .map(|p| match p {
505                ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
506                _ => String::new(),
507            })
508            .collect::<Vec<_>>()
509            .join(".");
510        let collection = match plan {
511            SqlPlan::Scan { collection, .. } => collection.clone(),
512            _ => return Ok(None),
513        };
514        let args = extract_func_args(func)?;
515
516        match functions.search_trigger(&name) {
517            SearchTrigger::VectorSearch => {
518                if args.len() < 2 {
519                    return Ok(None);
520                }
521                let field = extract_column_name(&args[0])?;
522                let vector = extract_float_array(&args[1])?;
523                let limit = match plan {
524                    SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
525                    _ => 10,
526                };
527                return Ok(Some(SqlPlan::VectorSearch {
528                    collection,
529                    field,
530                    query_vector: vector,
531                    top_k: limit,
532                    ef_search: limit * 2,
533                    filters: match plan {
534                        SqlPlan::Scan { filters, .. } => filters.clone(),
535                        _ => Vec::new(),
536                    },
537                }));
538            }
539            SearchTrigger::TextSearch if args.len() >= 2 => {
540                let query_text = extract_string_literal(&args[1])?;
541                let limit = match plan {
542                    SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
543                    _ => 10,
544                };
545                return Ok(Some(SqlPlan::TextSearch {
546                    collection,
547                    query: query_text,
548                    top_k: limit,
549                    fuzzy: true,
550                    filters: match plan {
551                        SqlPlan::Scan { filters, .. } => filters.clone(),
552                        _ => Vec::new(),
553                    },
554                }));
555            }
556            SearchTrigger::TextSearch => {}
557            SearchTrigger::HybridSearch => {
558                return plan_hybrid_from_sort(&args, &collection, plan, functions);
559            }
560            _ => {}
561        }
562    }
563    Ok(None)
564}
565
566fn plan_hybrid_from_sort(
567    args: &[ast::Expr],
568    collection: &str,
569    plan: &SqlPlan,
570    _functions: &FunctionRegistry,
571) -> Result<Option<SqlPlan>> {
572    // rrf_score(vector_distance(...), bm25_score(...), k1, k2)
573    if args.len() < 2 {
574        return Ok(None);
575    }
576    let vector = match &args[0] {
577        ast::Expr::Function(f) => {
578            let inner_args = extract_func_args(f)?;
579            if inner_args.len() >= 2 {
580                extract_float_array(&inner_args[1]).unwrap_or_default()
581            } else {
582                Vec::new()
583            }
584        }
585        _ => Vec::new(),
586    };
587    let text = match &args[1] {
588        ast::Expr::Function(f) => {
589            let inner_args = extract_func_args(f)?;
590            if inner_args.len() >= 2 {
591                extract_string_literal(&inner_args[1]).unwrap_or_default()
592            } else {
593                String::new()
594            }
595        }
596        _ => String::new(),
597    };
598    let k1 = args
599        .get(2)
600        .and_then(|e| extract_float(e).ok())
601        .unwrap_or(60.0);
602    let k2 = args
603        .get(3)
604        .and_then(|e| extract_float(e).ok())
605        .unwrap_or(60.0);
606    let limit = match plan {
607        SqlPlan::Scan { limit, .. } => limit.unwrap_or(10),
608        _ => 10,
609    };
610    let vector_weight = k2 as f32 / (k1 as f32 + k2 as f32);
611
612    Ok(Some(SqlPlan::HybridSearch {
613        collection: collection.into(),
614        query_vector: vector,
615        query_text: text,
616        top_k: limit,
617        ef_search: limit * 2,
618        vector_weight,
619        fuzzy: true,
620    }))
621}
622
623/// Apply LIMIT and OFFSET to a plan.
624fn apply_limit(mut plan: SqlPlan, limit_clause: &Option<ast::LimitClause>) -> SqlPlan {
625    let (limit_val, offset_val) = match limit_clause {
626        None => (None, 0usize),
627        Some(ast::LimitClause::LimitOffset { limit, offset, .. }) => {
628            let lv = limit
629                .as_ref()
630                .and_then(crate::coerce::expr_as_usize_literal);
631            let ov = offset
632                .as_ref()
633                .and_then(|o| crate::coerce::expr_as_usize_literal(&o.value))
634                .unwrap_or(0);
635            (lv, ov)
636        }
637        Some(ast::LimitClause::OffsetCommaLimit { offset, limit }) => {
638            let lv = crate::coerce::expr_as_usize_literal(limit);
639            let ov = crate::coerce::expr_as_usize_literal(offset).unwrap_or(0);
640            (lv, ov)
641        }
642    };
643
644    match plan {
645        SqlPlan::Scan {
646            ref mut limit,
647            ref mut offset,
648            ..
649        } => {
650            *limit = limit_val;
651            *offset = offset_val;
652        }
653        SqlPlan::Aggregate {
654            limit: ref mut l, ..
655        } => {
656            if let Some(lv) = limit_val {
657                *l = lv;
658            }
659        }
660        _ => {}
661    }
662    plan
663}
664
665// ── Helpers ──
666
667/// Convert SELECT projection items.
668pub fn convert_projection(items: &[ast::SelectItem]) -> Result<Vec<Projection>> {
669    let mut result = Vec::new();
670    for item in items {
671        match item {
672            ast::SelectItem::UnnamedExpr(expr) => {
673                let sql_expr = convert_expr(expr)?;
674                match &sql_expr {
675                    SqlExpr::Column { table, name } => {
676                        result.push(Projection::Column(qualified_name(table.as_deref(), name)));
677                    }
678                    SqlExpr::Wildcard => {
679                        result.push(Projection::Star);
680                    }
681                    _ => {
682                        result.push(Projection::Computed {
683                            expr: sql_expr,
684                            alias: format!("{expr}"),
685                        });
686                    }
687                }
688            }
689            ast::SelectItem::ExprWithAlias { expr, alias } => {
690                let sql_expr = convert_expr(expr)?;
691                result.push(Projection::Computed {
692                    expr: sql_expr,
693                    alias: normalize_ident(alias),
694                });
695            }
696            ast::SelectItem::Wildcard(_) => {
697                result.push(Projection::Star);
698            }
699            ast::SelectItem::QualifiedWildcard(kind, _) => {
700                let table_name = match kind {
701                    ast::SelectItemQualifiedWildcardKind::ObjectName(name) => {
702                        crate::parser::normalize::normalize_object_name(name)
703                    }
704                    _ => String::new(),
705                };
706                result.push(Projection::QualifiedStar(table_name));
707            }
708        }
709    }
710    Ok(result)
711}
712
/// Render a column reference as `table.name`, or as just `name` when no
/// table qualifier is given.
pub fn qualified_name(table: Option<&str>, name: &str) -> String {
    match table {
        Some(qualifier) => format!("{qualifier}.{name}"),
        None => name.to_owned(),
    }
}
717
718/// Convert a WHERE expression into a list of Filter.
719pub fn convert_where_to_filters(expr: &ast::Expr) -> Result<Vec<Filter>> {
720    let sql_expr = convert_expr(expr)?;
721    Ok(vec![Filter {
722        expr: FilterExpr::Expr(sql_expr),
723    }])
724}
725
726pub(crate) fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>> {
727    match &func.args {
728        ast::FunctionArguments::List(args) => Ok(args
729            .args
730            .iter()
731            .filter_map(|a| match a {
732                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e.clone()),
733                _ => None,
734            })
735            .collect()),
736        _ => Ok(Vec::new()),
737    }
738}
739
740/// Evaluate a constant SqlExpr to a SqlValue. Delegates to the shared
741/// `const_fold::fold_constant` helper so that zero-arg scalar functions
742/// like `now()` and `current_timestamp` go through the same evaluator
743/// as the runtime expression path.
744fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
745    super::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
746}
747
748/// Extract a geometry argument: handles ST_Point(lon, lat), ST_GeomFromGeoJSON('...'),
749/// or a raw string literal containing GeoJSON.
750fn extract_geometry_arg(expr: &ast::Expr) -> Result<String> {
751    match expr {
752        // ST_Point(lon, lat) → GeoJSON Point
753        ast::Expr::Function(func) => {
754            let name = func
755                .name
756                .0
757                .iter()
758                .map(|p| match p {
759                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
760                    _ => String::new(),
761                })
762                .collect::<Vec<_>>()
763                .join(".");
764            let args = extract_func_args(func)?;
765            match name.as_str() {
766                "st_point" if args.len() >= 2 => {
767                    let lon = extract_float(&args[0])?;
768                    let lat = extract_float(&args[1])?;
769                    Ok(format!(r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#))
770                }
771                "st_geomfromgeojson" if !args.is_empty() => extract_string_literal(&args[0]),
772                _ => Ok(format!("{expr}")),
773            }
774        }
775        // Raw string literal: assumed to be GeoJSON.
776        _ => extract_string_literal(expr).or_else(|_| Ok(format!("{expr}"))),
777    }
778}
779
780fn extract_column_name(expr: &ast::Expr) -> Result<String> {
781    match expr {
782        ast::Expr::Identifier(ident) => Ok(normalize_ident(ident)),
783        ast::Expr::CompoundIdentifier(parts) => Ok(parts
784            .iter()
785            .map(normalize_ident)
786            .collect::<Vec<_>>()
787            .join(".")),
788        _ => Err(SqlError::Unsupported {
789            detail: format!("expected column name, got: {expr}"),
790        }),
791    }
792}
793
794pub(crate) fn extract_string_literal(expr: &ast::Expr) -> Result<String> {
795    match expr {
796        ast::Expr::Value(v) => match &v.value {
797            ast::Value::SingleQuotedString(s) | ast::Value::DoubleQuotedString(s) => Ok(s.clone()),
798            _ => Err(SqlError::Unsupported {
799                detail: format!("expected string literal, got: {expr}"),
800            }),
801        },
802        _ => Err(SqlError::Unsupported {
803            detail: format!("expected string literal, got: {expr}"),
804        }),
805    }
806}
807
808pub(crate) fn extract_float(expr: &ast::Expr) -> Result<f64> {
809    match expr {
810        ast::Expr::Value(v) => match &v.value {
811            ast::Value::Number(n, _) => n.parse::<f64>().map_err(|_| SqlError::TypeMismatch {
812                detail: format!("expected number: {n}"),
813            }),
814            _ => Err(SqlError::TypeMismatch {
815                detail: format!("expected number, got: {expr}"),
816            }),
817        },
818        // Handle negative numbers: -73.9855 is parsed as UnaryOp { Minus, 73.9855 }
819        ast::Expr::UnaryOp {
820            op: ast::UnaryOperator::Minus,
821            expr: inner,
822        } => extract_float(inner).map(|f| -f),
823        _ => Err(SqlError::TypeMismatch {
824            detail: format!("expected number, got: {expr}"),
825        }),
826    }
827}
828
829/// Extract a float array from ARRAY[...] or make_array(...) expression.
830fn extract_float_array(expr: &ast::Expr) -> Result<Vec<f32>> {
831    match expr {
832        ast::Expr::Array(ast::Array { elem, .. }) => elem
833            .iter()
834            .map(|e| extract_float(e).map(|f| f as f32))
835            .collect(),
836        ast::Expr::Function(func) => {
837            let name = func
838                .name
839                .0
840                .iter()
841                .map(|p| match p {
842                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
843                    _ => String::new(),
844                })
845                .collect::<Vec<_>>()
846                .join(".");
847            if name == "make_array" || name == "array" {
848                let args = extract_func_args(func)?;
849                args.iter()
850                    .map(|e| extract_float(e).map(|f| f as f32))
851                    .collect()
852            } else {
853                Err(SqlError::Unsupported {
854                    detail: format!("expected array, got function: {name}"),
855                })
856            }
857        }
858        _ => Err(SqlError::Unsupported {
859            detail: format!("expected array literal, got: {expr}"),
860        }),
861    }
862}
863
864/// Check if a SELECT has the DISTINCT keyword.
865fn try_plan_join(
866    select: &Select,
867    scope: &TableScope,
868    catalog: &dyn SqlCatalog,
869    functions: &FunctionRegistry,
870) -> Result<Option<SqlPlan>> {
871    if select.from.len() != 1 {
872        return Ok(None);
873    }
874    let from = &select.from[0];
875    if from.joins.is_empty() {
876        return Ok(None);
877    }
878    super::join::plan_join_from_select(select, scope, catalog, functions)
879}
880
881/// Catalog wrapper that resolves CTE names as schemaless document collections.
882struct CteCatalog<'a> {
883    inner: &'a dyn SqlCatalog,
884    cte_names: Vec<String>,
885}
886
887impl SqlCatalog for CteCatalog<'_> {
888    fn get_collection(
889        &self,
890        name: &str,
891    ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
892        // Check CTE names first.
893        if self.cte_names.iter().any(|n| n == name) {
894            return Ok(Some(CollectionInfo {
895                name: name.into(),
896                engine: EngineType::DocumentSchemaless,
897                columns: Vec::new(),
898                primary_key: Some("id".into()),
899                has_auto_tier: false,
900                indexes: Vec::new(),
901            }));
902        }
903        self.inner.get_collection(name)
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::functions::registry::FunctionRegistry;
    use crate::parser::statement::parse_sql;
    use sqlparser::ast::Statement;

    /// Fixed catalog exposing the handful of collections the tests query.
    struct TestCatalog;

    /// Build fixture metadata; only the name, engine and primary key differ
    /// between the test collections.
    fn fixture_collection(
        name: &str,
        engine: EngineType,
        primary_key: &'static str,
    ) -> CollectionInfo {
        CollectionInfo {
            name: name.into(),
            engine,
            columns: Vec::new(),
            primary_key: Some(primary_key.into()),
            has_auto_tier: false,
            indexes: Vec::new(),
        }
    }

    impl SqlCatalog for TestCatalog {
        fn get_collection(
            &self,
            name: &str,
        ) -> std::result::Result<Option<CollectionInfo>, SqlCatalogError> {
            Ok(match name {
                // Schemaless document collections keyed by `id`.
                "products" | "users" | "orders" | "docs" | "tags" => Some(fixture_collection(
                    name,
                    EngineType::DocumentSchemaless,
                    "id",
                )),
                // The only key-value collection, keyed by `key`.
                "user_prefs" => Some(fixture_collection(name, EngineType::KeyValue, "key")),
                _ => None,
            })
        }
    }

    /// Parse `sql`, require its first statement to be a query, and plan it
    /// against `TestCatalog` with a fresh function registry.
    fn plan_select_sql(sql: &str) -> SqlPlan {
        let statements = parse_sql(sql).unwrap();
        let query = match &statements[0] {
            Statement::Query(query) => query,
            _ => panic!("expected query statement"),
        };
        let functions = FunctionRegistry::new();
        plan_query(query, &TestCatalog, &functions).unwrap()
    }

    #[test]
    fn aggregate_subquery_join_filters_input_before_aggregation() {
        let sql = "SELECT AVG(price) FROM products WHERE category IN (SELECT DISTINCT category FROM products WHERE qty > 100)";

        let aggregate_input = match plan_select_sql(sql) {
            SqlPlan::Aggregate { input, .. } => input,
            _ => panic!("expected aggregate plan"),
        };

        match *aggregate_input {
            SqlPlan::Join {
                left,
                join_type,
                on,
                ..
            } => {
                assert_eq!(join_type, JoinType::Semi);
                assert_eq!(on, vec![("category".into(), "category".into())]);
                assert!(matches!(*left, SqlPlan::Scan { .. }));
            }
            _ => panic!("expected semi-join below aggregate"),
        }
    }

    #[test]
    fn scalar_subquery_defers_projection_until_after_join_filter() {
        let sql = "SELECT user_id FROM orders WHERE amount > (SELECT AVG(amount) FROM orders)";

        let (left, projection, filters) = match plan_select_sql(sql) {
            SqlPlan::Join {
                left,
                projection,
                filters,
                ..
            } => (left, projection, filters),
            _ => panic!("expected join plan"),
        };

        let scan_projection = match *left {
            SqlPlan::Scan { projection, .. } => projection,
            _ => panic!("expected scan on join left"),
        };

        // The scan must not project yet; the projection belongs to the join.
        assert!(scan_projection.is_empty(), "scan projected too early");
        assert_eq!(projection.len(), 1);
        match &projection[0] {
            Projection::Column(name) => assert_eq!(name, "user_id"),
            other => panic!("expected user_id projection, got {other:?}"),
        }
        assert!(
            !filters.is_empty(),
            "scalar comparison should stay post-join"
        );
    }

    #[test]
    fn chained_join_preserves_qualified_on_keys() {
        let sql = "SELECT d.name, t.tag, p.theme \
             FROM docs d \
             LEFT JOIN tags t ON d.id = t.doc_id \
             INNER JOIN user_prefs p ON d.id = p.key";

        // Outermost node is the INNER JOIN against `user_prefs`.
        let SqlPlan::Join {
            left: nested,
            on: outer_on,
            ..
        } = plan_select_sql(sql)
        else {
            panic!("expected outer join plan");
        };
        assert_eq!(outer_on, vec![("d.id".into(), "p.key".into())]);

        // Its left input is the LEFT JOIN between `docs` and `tags`.
        let SqlPlan::Join { on: inner_on, .. } = *nested else {
            panic!("expected nested left join");
        };
        assert_eq!(inner_on, vec![("d.id".into(), "t.doc_id".into())]);
    }
}