Skip to main content

nodedb_sql/planner/
dml_helpers.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use sqlparser::ast;
4
5use crate::error::{Result, SqlError};
6use crate::parser::normalize::{normalize_ident, normalize_object_name_checked};
7use crate::resolver::expr::convert_value;
8use crate::types::*;
9
10pub(super) fn convert_value_rows(
11    columns: &[String],
12    rows: &[Vec<ast::Expr>],
13) -> Result<Vec<Vec<(String, SqlValue)>>> {
14    rows.iter()
15        .map(|row| {
16            row.iter()
17                .enumerate()
18                .map(|(i, expr)| {
19                    let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
20                    let val = expr_to_sql_value(expr)?;
21                    Ok((col, val))
22                })
23                .collect::<Result<Vec<_>>>()
24        })
25        .collect()
26}
27
28pub(super) fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
29    match expr {
30        ast::Expr::Value(v) => convert_value(&v.value),
31        // Array literals lower element-wise into `SqlValue::Array`; there is
32        // no array-literal `SqlValue` the constant folder could produce.
33        ast::Expr::Array(ast::Array { elem, .. }) => {
34            let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
35            Ok(SqlValue::Array(vals))
36        }
37        // `ST_Point(...)` / `ST_GeomFromGeoJSON(...)` synthesise a GeoJSON
38        // string in place rather than resolving as registered scalar
39        // functions, so they keep their bespoke handling.
40        ast::Expr::Function(func) => match SpatialConstructor::from_function(func) {
41            Some(ctor) => spatial_constructor_to_value(ctor, func),
42            // Non-spatial functions (`now()`, `date_add(...)`, registered
43            // scalars) fold through the shared pipeline below.
44            None => fold_constant_value(expr),
45        },
46        // Everything else — `::TYPE` / `CAST(... AS TYPE)` casts, arithmetic,
47        // string concatenation, parenthesised literals — goes through the
48        // same resolver and constant folder the `SELECT` projection path
49        // uses, so the two surfaces never drift. Only genuinely row- or
50        // runtime-dependent expressions (column refs, subqueries, unknown
51        // functions) fail here.
52        _ => fold_constant_value(expr),
53    }
54}
55
56fn fold_constant_value(expr: &ast::Expr) -> Result<SqlValue> {
57    let sql_expr = crate::resolver::expr::convert_expr(expr)?;
58    super::const_fold::fold_constant_default(&sql_expr).ok_or_else(|| SqlError::Unsupported {
59        detail: format!("value expression: {expr}"),
60    })
61}
62
/// Spatial constructors that synthesise a GeoJSON string literal directly
/// in value position (rather than going through the registered scalar
/// evaluator). Closed set — adding a new constructor requires a new variant,
/// which forces handling in `spatial_constructor_to_value` (its `match` over
/// this enum has no catch-all arm).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum SpatialConstructor {
    /// `ST_Point(longitude, latitude)`.
    Point,
    /// `ST_GeomFromGeoJSON(text)`.
    GeomFromGeoJson,
}
72
73impl SpatialConstructor {
74    fn from_function(func: &ast::Function) -> Option<Self> {
75        let name = func
76            .name
77            .0
78            .iter()
79            .map(|p| match p {
80                ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
81                _ => String::new(),
82            })
83            .collect::<Vec<_>>()
84            .join(".")
85            .to_lowercase();
86        match name.as_str() {
87            "st_point" => Some(Self::Point),
88            "st_geomfromgeojson" => Some(Self::GeomFromGeoJson),
89            _ => None,
90        }
91    }
92
93    fn display_name(self) -> &'static str {
94        match self {
95            Self::Point => "ST_Point",
96            Self::GeomFromGeoJson => "ST_GeomFromGeoJSON",
97        }
98    }
99}
100
101fn spatial_constructor_to_value(
102    ctor: SpatialConstructor,
103    func: &ast::Function,
104) -> Result<SqlValue> {
105    let args = super::select::extract_func_args(func)?;
106    match ctor {
107        SpatialConstructor::Point => {
108            if args.len() < 2 {
109                return Err(SqlError::InvalidFunction {
110                    detail: format!(
111                        "{} requires 2 arguments (longitude, latitude), got {}",
112                        ctor.display_name(),
113                        args.len()
114                    ),
115                });
116            }
117            let lon = super::select::extract_float(&args[0])?;
118            let lat = super::select::extract_float(&args[1])?;
119            Ok(SqlValue::String(format!(
120                r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
121            )))
122        }
123        SpatialConstructor::GeomFromGeoJson => {
124            if args.is_empty() {
125                return Err(SqlError::InvalidFunction {
126                    detail: format!(
127                        "{} requires 1 argument (GeoJSON string)",
128                        ctor.display_name()
129                    ),
130                });
131            }
132            let s = super::select::extract_string_literal(&args[0])?;
133            Ok(SqlValue::String(s))
134        }
135    }
136}
137
138pub(super) fn extract_table_name_from_table_with_joins(
139    table: &ast::TableWithJoins,
140) -> Result<String> {
141    match &table.relation {
142        ast::TableFactor::Table { name, .. } => Ok(normalize_object_name_checked(name)?),
143        _ => Err(SqlError::Unsupported {
144            detail: "non-table target in DML".into(),
145        }),
146    }
147}
148
149/// Extract point-operation keys from WHERE clause (WHERE pk = literal OR pk IN (...)).
150pub fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
151    let pk = match &info.primary_key {
152        Some(pk) => pk.clone(),
153        None => return Vec::new(),
154    };
155
156    let expr = match selection {
157        Some(e) => e,
158        None => return Vec::new(),
159    };
160
161    let mut keys = Vec::new();
162    collect_pk_equalities(expr, &pk, &mut keys);
163    keys
164}
165
166fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
167    match expr {
168        ast::Expr::BinaryOp {
169            left,
170            op: ast::BinaryOperator::Eq,
171            right,
172        } => {
173            if is_column(left, pk)
174                && let Ok(v) = expr_to_sql_value(right)
175            {
176                keys.push(v);
177            } else if is_column(right, pk)
178                && let Ok(v) = expr_to_sql_value(left)
179            {
180                keys.push(v);
181            }
182        }
183        ast::Expr::BinaryOp {
184            left,
185            op: ast::BinaryOperator::Or,
186            right,
187        } => {
188            collect_pk_equalities(left, pk, keys);
189            collect_pk_equalities(right, pk, keys);
190        }
191        ast::Expr::InList {
192            expr: inner,
193            list,
194            negated: false,
195        } if is_column(inner, pk) => {
196            for item in list {
197                if let Ok(v) = expr_to_sql_value(item) {
198                    keys.push(v);
199                }
200            }
201        }
202        _ => {}
203    }
204}
205
206fn is_column(expr: &ast::Expr, name: &str) -> bool {
207    match expr {
208        ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
209        // Three or more parts: schema.table.col — never matches a plain pk name.
210        ast::Expr::CompoundIdentifier(parts) if parts.len() >= 3 => false,
211        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
212            normalize_ident(&parts[1]) == name
213        }
214        _ => false,
215    }
216}
217
218/// Build a `SqlPlan::VectorPrimaryInsert` from parsed rows.
219///
220/// Extracts the vector-field column into `vector: Vec<f32>` and collects
221/// all remaining columns into `payload_fields`. Rows missing the vector
222/// column are rejected.
223pub(super) fn build_vector_primary_insert_plan(
224    collection: &str,
225    vpc: &nodedb_types::VectorPrimaryConfig,
226    _columns: &[String],
227    rows: Vec<Vec<(String, SqlValue)>>,
228) -> Result<Vec<SqlPlan>> {
229    let mut result_rows = Vec::with_capacity(rows.len());
230    for row in rows {
231        let mut vector: Option<Vec<f32>> = None;
232        let mut payload_fields = std::collections::HashMap::new();
233
234        for (col, val) in row {
235            if col == vpc.vector_field {
236                match val {
237                    SqlValue::Array(items) => {
238                        let floats: Result<Vec<f32>> = items
239                            .iter()
240                            .map(|v| match v {
241                                SqlValue::Float(f) => Ok(*f as f32),
242                                SqlValue::Int(i) => Ok(*i as f32),
243                                SqlValue::Decimal(d) => {
244                                    use rust_decimal::prelude::ToPrimitive;
245                                    d.to_f32().ok_or_else(|| SqlError::Parse {
246                                        detail: format!(
247                                            "vector element decimal '{d}' is out of f32 range"
248                                        ),
249                                    })
250                                }
251                                other => Err(SqlError::Parse {
252                                    detail: format!(
253                                        "vector field must contain numbers, got {other:?}"
254                                    ),
255                                }),
256                            })
257                            .collect();
258                        vector = Some(floats?);
259                    }
260                    other => {
261                        return Err(SqlError::Parse {
262                            detail: format!(
263                                "vector field '{}' must be an array literal, got {other:?}",
264                                vpc.vector_field
265                            ),
266                        });
267                    }
268                }
269            } else {
270                payload_fields.insert(col, val);
271            }
272        }
273
274        let vector = vector.ok_or_else(|| SqlError::Parse {
275            detail: format!(
276                "vector-primary INSERT missing required vector field '{}'",
277                vpc.vector_field
278            ),
279        })?;
280
281        result_rows.push(VectorPrimaryRow {
282            surrogate: nodedb_types::Surrogate::ZERO,
283            vector,
284            payload_fields,
285        });
286    }
287
288    Ok(vec![SqlPlan::VectorPrimaryInsert {
289        collection: collection.to_string(),
290        field: vpc.vector_field.clone(),
291        quantization: vpc.quantization,
292        payload_indexes: vpc.payload_indexes.clone(),
293        rows: result_rows,
294    }])
295}
296
297/// Build a `SqlPlan::KvInsert` from a VALUES clause. Shared by plain INSERT,
298/// UPSERT, and `INSERT ... ON CONFLICT (key) DO UPDATE` — the three paths
299/// differ only in `intent` and `on_conflict_updates`, never in how entries
300/// are extracted from the row exprs.
301///
302/// `pk_col` is the schema-defined primary-key column name from
303/// `CollectionInfo::primary_key`.  When supplied, that column is used as
304/// the KV key regardless of whether it is named `"key"`.  Falls back to
305/// the literal name `"key"` when `pk_col` is `None` (legacy / generic
306/// KV collections that use the built-in key/value column convention).
307pub(super) fn build_kv_insert_plan(
308    table_name: String,
309    columns: &[String],
310    rows_ast: &[Vec<ast::Expr>],
311    intent: KvInsertIntent,
312    on_conflict_updates: Vec<(String, SqlExpr)>,
313    pk_col: Option<&str>,
314) -> Result<Vec<SqlPlan>> {
315    let key_col_name = pk_col.unwrap_or("key");
316    let key_idx = columns.iter().position(|c| c == key_col_name);
317    let ttl_idx = columns.iter().position(|c| c == "ttl");
318    // When using a named primary-key column (e.g. `k STRING PRIMARY KEY`), we
319    // store the key bytes in the KV key slot AND also keep the column in the
320    // value map.  This allows scan filters on the primary-key column (e.g.
321    // `WHERE k = 'x'`) and projection (e.g. `SELECT k FROM ...`) to work
322    // without teaching the KV scan handler to inspect the raw key bytes.
323    // The only column we exclude from the value map is the built-in `"key"`
324    // sentinel (used by raw key/value KV collections) and `"ttl"`.
325    let exclude_from_value: std::collections::HashSet<usize> = {
326        let mut s = std::collections::HashSet::new();
327        // Exclude the raw "key" sentinel column (not a named PK column).
328        if key_col_name == "key"
329            && let Some(idx) = key_idx
330        {
331            s.insert(idx);
332        }
333        if let Some(idx) = ttl_idx {
334            s.insert(idx);
335        }
336        s
337    };
338    let mut entries = Vec::with_capacity(rows_ast.len());
339    let mut ttl_secs: u64 = 0;
340    for row_exprs in rows_ast {
341        let key_val = match key_idx {
342            Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
343            None => SqlValue::String(String::new()),
344        };
345        if let Some(idx) = ttl_idx {
346            match expr_to_sql_value(&row_exprs[idx]) {
347                Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
348                Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
349                _ => {}
350            }
351        }
352        let value_cols: Vec<(String, SqlValue)> = columns
353            .iter()
354            .enumerate()
355            .filter(|(i, _)| !exclude_from_value.contains(i))
356            .map(|(i, col)| {
357                let val = expr_to_sql_value(&row_exprs[i])?;
358                Ok((col.clone(), val))
359            })
360            .collect::<Result<Vec<_>>>()?;
361        entries.push((key_val, value_cols));
362    }
363    Ok(vec![SqlPlan::KvInsert {
364        collection: table_name,
365        entries,
366        ttl_secs,
367        intent,
368        on_conflict_updates,
369    }])
370}