Skip to main content

nodedb_sql/planner/select/
helpers.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! Shared helpers for SELECT planning: projection conversion, WHERE
4//! filter conversion, and AST literal extraction utilities.
5
6use sqlparser::ast;
7
8use crate::error::{Result, SqlError};
9use crate::functions::registry::FunctionRegistry;
10use crate::parser::normalize::{SCHEMA_QUALIFIED_MSG, normalize_ident};
11use crate::resolver::expr::convert_expr;
12use crate::types::*;
13
14/// Convert SELECT projection items.
15pub fn convert_projection(items: &[ast::SelectItem]) -> Result<Vec<Projection>> {
16    let mut result = Vec::new();
17    for item in items {
18        match item {
19            ast::SelectItem::UnnamedExpr(expr) => {
20                let sql_expr = convert_expr(expr)?;
21                match &sql_expr {
22                    SqlExpr::Column { table, name } => {
23                        result.push(Projection::Column(qualified_name(table.as_deref(), name)));
24                    }
25                    SqlExpr::Wildcard => {
26                        result.push(Projection::Star);
27                    }
28                    _ => {
29                        result.push(Projection::Computed {
30                            expr: sql_expr,
31                            alias: format!("{expr}").to_lowercase(),
32                        });
33                    }
34                }
35            }
36            ast::SelectItem::ExprWithAlias { expr, alias } => {
37                let sql_expr = convert_expr(expr)?;
38                result.push(Projection::Computed {
39                    expr: sql_expr,
40                    alias: normalize_ident(alias),
41                });
42            }
43            ast::SelectItem::Wildcard(_) => {
44                result.push(Projection::Star);
45            }
46            ast::SelectItem::QualifiedWildcard(kind, _) => {
47                let table_name = match kind {
48                    ast::SelectItemQualifiedWildcardKind::ObjectName(name) => {
49                        crate::parser::normalize::normalize_object_name_checked(name)?
50                    }
51                    _ => String::new(),
52                };
53                result.push(Projection::QualifiedStar(table_name));
54            }
55        }
56    }
57    Ok(result)
58}
59
/// Render a column reference as `table.name` when a qualifier is present,
/// or as plain `name` otherwise.
pub fn qualified_name(table: Option<&str>, name: &str) -> String {
    match table {
        Some(qualifier) => format!("{qualifier}.{name}"),
        None => name.to_owned(),
    }
}
64
65/// Convert a WHERE expression into a list of Filter.
66pub fn convert_where_to_filters(expr: &ast::Expr) -> Result<Vec<Filter>> {
67    let sql_expr = convert_expr(expr)?;
68    Ok(vec![Filter {
69        expr: FilterExpr::Expr(sql_expr),
70    }])
71}
72
73pub fn extract_func_args(func: &ast::Function) -> Result<Vec<ast::Expr>> {
74    match &func.args {
75        ast::FunctionArguments::List(args) => Ok(args
76            .args
77            .iter()
78            .filter_map(|a| match a {
79                ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e)) => Some(e.clone()),
80                _ => None,
81            })
82            .collect()),
83        _ => Ok(Vec::new()),
84    }
85}
86
87/// Evaluate a constant SqlExpr to a SqlValue. Delegates to the shared
88/// `const_fold::fold_constant` helper so that zero-arg scalar functions
89/// like `now()` and `current_timestamp` go through the same evaluator
90/// as the runtime expression path.
91pub(super) fn eval_constant_expr(expr: &SqlExpr, functions: &FunctionRegistry) -> SqlValue {
92    crate::planner::const_fold::fold_constant(expr, functions).unwrap_or(SqlValue::Null)
93}
94
95/// Extract a geometry argument: handles ST_Point(lon, lat), ST_GeomFromGeoJSON('...'),
96/// or a raw string literal containing GeoJSON.
97pub(super) fn extract_geometry_arg(expr: &ast::Expr) -> Result<String> {
98    match expr {
99        // ST_Point(lon, lat) → GeoJSON Point
100        ast::Expr::Function(func) => {
101            let name = func
102                .name
103                .0
104                .iter()
105                .map(|p| match p {
106                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
107                    _ => String::new(),
108                })
109                .collect::<Vec<_>>()
110                .join(".");
111            let args = extract_func_args(func)?;
112            match name.as_str() {
113                "st_point" if args.len() >= 2 => {
114                    let lon = extract_float(&args[0])?;
115                    let lat = extract_float(&args[1])?;
116                    Ok(format!(r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#))
117                }
118                "st_geomfromgeojson" if !args.is_empty() => extract_string_literal(&args[0]),
119                _ => Ok(format!("{expr}")),
120            }
121        }
122        // Raw string literal: assumed to be GeoJSON.
123        _ => extract_string_literal(expr).or_else(|_| Ok(format!("{expr}"))),
124    }
125}
126
127pub(super) fn extract_column_name(expr: &ast::Expr) -> Result<String> {
128    match expr {
129        ast::Expr::Identifier(ident) => Ok(normalize_ident(ident)),
130        ast::Expr::CompoundIdentifier(parts) if parts.len() >= 3 => {
131            let qualified: String = parts
132                .iter()
133                .map(normalize_ident)
134                .collect::<Vec<_>>()
135                .join(".");
136            Err(SqlError::Unsupported {
137                detail: format!(
138                    "schema-qualified column reference '{qualified}': {SCHEMA_QUALIFIED_MSG}"
139                ),
140            })
141        }
142        ast::Expr::CompoundIdentifier(parts) => Ok(parts
143            .iter()
144            .map(normalize_ident)
145            .collect::<Vec<_>>()
146            .join(".")),
147        _ => Err(SqlError::Unsupported {
148            detail: format!("expected column name, got: {expr}"),
149        }),
150    }
151}
152
153pub fn extract_string_literal(expr: &ast::Expr) -> Result<String> {
154    match expr {
155        ast::Expr::Value(v) => match &v.value {
156            ast::Value::SingleQuotedString(s) => Ok(s.clone()),
157            _ => Err(SqlError::Unsupported {
158                detail: format!("expected string literal, got: {expr}"),
159            }),
160        },
161        _ => Err(SqlError::Unsupported {
162            detail: format!("expected string literal, got: {expr}"),
163        }),
164    }
165}
166
167pub fn extract_float(expr: &ast::Expr) -> Result<f64> {
168    match expr {
169        ast::Expr::Value(v) => match &v.value {
170            ast::Value::Number(n, _) => n.parse::<f64>().map_err(|_| SqlError::TypeMismatch {
171                detail: format!("expected number: {n}"),
172            }),
173            _ => Err(SqlError::TypeMismatch {
174                detail: format!("expected number, got: {expr}"),
175            }),
176        },
177        // Handle negative numbers: -73.9855 is parsed as UnaryOp { Minus, 73.9855 }
178        ast::Expr::UnaryOp {
179            op: ast::UnaryOperator::Minus,
180            expr: inner,
181        } => extract_float(inner).map(|f| -f),
182        _ => Err(SqlError::TypeMismatch {
183            detail: format!("expected number, got: {expr}"),
184        }),
185    }
186}
187
188/// Map a vector distance function name to its `DistanceMetric`.
189///
190/// `vector_distance` (and the rewritten `<->` operator) → L2;
191/// `vector_cosine_distance` (and `<=>`) → Cosine;
192/// `vector_neg_inner_product` (and `<#>`) → InnerProduct.
193/// Unknown names default to L2 — callers must gate on a `VectorSearch`
194/// search-trigger before invoking this so unknown names cannot leak in.
195pub(super) fn metric_from_func_name(name: &str) -> DistanceMetric {
196    if name.eq_ignore_ascii_case("vector_cosine_distance") {
197        DistanceMetric::Cosine
198    } else if name.eq_ignore_ascii_case("vector_neg_inner_product") {
199        DistanceMetric::InnerProduct
200    } else {
201        DistanceMetric::L2
202    }
203}
204
205/// Extract a float array from ARRAY[...] or make_array(...) expression.
206pub(super) fn extract_float_array(expr: &ast::Expr) -> Result<Vec<f32>> {
207    match expr {
208        ast::Expr::Array(ast::Array { elem, .. }) => elem
209            .iter()
210            .map(|e| extract_float(e).map(|f| f as f32))
211            .collect(),
212        ast::Expr::Function(func) => {
213            let name = func
214                .name
215                .0
216                .iter()
217                .map(|p| match p {
218                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
219                    _ => String::new(),
220                })
221                .collect::<Vec<_>>()
222                .join(".");
223            if name == "make_array" || name == "array" {
224                let args = extract_func_args(func)?;
225                args.iter()
226                    .map(|e| extract_float(e).map(|f| f as f32))
227                    .collect()
228            } else {
229                Err(SqlError::Unsupported {
230                    detail: format!("expected array, got function: {name}"),
231                })
232            }
233        }
234        _ => Err(SqlError::Unsupported {
235            detail: format!("expected array literal, got: {expr}"),
236        }),
237    }
238}