Skip to main content

nodedb_sql/planner/
dml_helpers.rs

1// SPDX-License-Identifier: Apache-2.0
2
3use sqlparser::ast;
4
5use crate::error::{Result, SqlError};
6use crate::parser::normalize::{normalize_ident, normalize_object_name_checked};
7use crate::resolver::expr::convert_value;
8use crate::types::*;
9
10pub(super) fn convert_value_rows(
11    columns: &[String],
12    rows: &[Vec<ast::Expr>],
13) -> Result<Vec<Vec<(String, SqlValue)>>> {
14    rows.iter()
15        .map(|row| {
16            row.iter()
17                .enumerate()
18                .map(|(i, expr)| {
19                    let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
20                    let val = expr_to_sql_value(expr)?;
21                    Ok((col, val))
22                })
23                .collect::<Result<Vec<_>>>()
24        })
25        .collect()
26}
27
28pub(super) fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
29    match expr {
30        ast::Expr::Value(v) => convert_value(&v.value),
31        ast::Expr::UnaryOp {
32            op: ast::UnaryOperator::Minus,
33            expr: inner,
34        } => {
35            let val = expr_to_sql_value(inner)?;
36            match val {
37                SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
38                SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
39                SqlValue::Decimal(d) => Ok(SqlValue::Decimal(-d)),
40                _ => Err(SqlError::TypeMismatch {
41                    detail: "cannot negate non-numeric value".into(),
42                }),
43            }
44        }
45        ast::Expr::Array(ast::Array { elem, .. }) => {
46            let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
47            Ok(SqlValue::Array(vals))
48        }
49        ast::Expr::Function(func) => {
50            let func_name = func
51                .name
52                .0
53                .iter()
54                .map(|p| match p {
55                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
56                    _ => String::new(),
57                })
58                .collect::<Vec<_>>()
59                .join(".")
60                .to_lowercase();
61            match func_name.as_str() {
62                "st_point" => {
63                    let args = super::select::extract_func_args(func)?;
64                    if args.len() >= 2 {
65                        let lon = super::select::extract_float(&args[0])?;
66                        let lat = super::select::extract_float(&args[1])?;
67                        Ok(SqlValue::String(format!(
68                            r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
69                        )))
70                    } else {
71                        Ok(SqlValue::String(format!("{expr}")))
72                    }
73                }
74                "st_geomfromgeojson" => {
75                    let args = super::select::extract_func_args(func)?;
76                    if !args.is_empty() {
77                        let s = super::select::extract_string_literal(&args[0])?;
78                        Ok(SqlValue::String(s))
79                    } else {
80                        Ok(SqlValue::String(format!("{expr}")))
81                    }
82                }
83                _ => {
84                    if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
85                        && let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
86                    {
87                        Ok(v)
88                    } else {
89                        Ok(SqlValue::String(format!("{expr}")))
90                    }
91                }
92            }
93        }
94        _ => Err(SqlError::Unsupported {
95            detail: format!("value expression: {expr}"),
96        }),
97    }
98}
99
100pub(super) fn extract_table_name_from_table_with_joins(
101    table: &ast::TableWithJoins,
102) -> Result<String> {
103    match &table.relation {
104        ast::TableFactor::Table { name, .. } => Ok(normalize_object_name_checked(name)?),
105        _ => Err(SqlError::Unsupported {
106            detail: "non-table target in DML".into(),
107        }),
108    }
109}
110
111/// Extract point-operation keys from WHERE clause (WHERE pk = literal OR pk IN (...)).
112pub fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
113    let pk = match &info.primary_key {
114        Some(pk) => pk.clone(),
115        None => return Vec::new(),
116    };
117
118    let expr = match selection {
119        Some(e) => e,
120        None => return Vec::new(),
121    };
122
123    let mut keys = Vec::new();
124    collect_pk_equalities(expr, &pk, &mut keys);
125    keys
126}
127
128fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
129    match expr {
130        ast::Expr::BinaryOp {
131            left,
132            op: ast::BinaryOperator::Eq,
133            right,
134        } => {
135            if is_column(left, pk)
136                && let Ok(v) = expr_to_sql_value(right)
137            {
138                keys.push(v);
139            } else if is_column(right, pk)
140                && let Ok(v) = expr_to_sql_value(left)
141            {
142                keys.push(v);
143            }
144        }
145        ast::Expr::BinaryOp {
146            left,
147            op: ast::BinaryOperator::Or,
148            right,
149        } => {
150            collect_pk_equalities(left, pk, keys);
151            collect_pk_equalities(right, pk, keys);
152        }
153        ast::Expr::InList {
154            expr: inner,
155            list,
156            negated: false,
157        } if is_column(inner, pk) => {
158            for item in list {
159                if let Ok(v) = expr_to_sql_value(item) {
160                    keys.push(v);
161                }
162            }
163        }
164        _ => {}
165    }
166}
167
168fn is_column(expr: &ast::Expr, name: &str) -> bool {
169    match expr {
170        ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
171        // Three or more parts: schema.table.col — never matches a plain pk name.
172        ast::Expr::CompoundIdentifier(parts) if parts.len() >= 3 => false,
173        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
174            normalize_ident(&parts[1]) == name
175        }
176        _ => false,
177    }
178}
179
180/// Build a `SqlPlan::VectorPrimaryInsert` from parsed rows.
181///
182/// Extracts the vector-field column into `vector: Vec<f32>` and collects
183/// all remaining columns into `payload_fields`. Rows missing the vector
184/// column are rejected.
185pub(super) fn build_vector_primary_insert_plan(
186    collection: &str,
187    vpc: &nodedb_types::VectorPrimaryConfig,
188    _columns: &[String],
189    rows: Vec<Vec<(String, SqlValue)>>,
190) -> Result<Vec<SqlPlan>> {
191    let mut result_rows = Vec::with_capacity(rows.len());
192    for row in rows {
193        let mut vector: Option<Vec<f32>> = None;
194        let mut payload_fields = std::collections::HashMap::new();
195
196        for (col, val) in row {
197            if col == vpc.vector_field {
198                match val {
199                    SqlValue::Array(items) => {
200                        let floats: Result<Vec<f32>> = items
201                            .iter()
202                            .map(|v| match v {
203                                SqlValue::Float(f) => Ok(*f as f32),
204                                SqlValue::Int(i) => Ok(*i as f32),
205                                SqlValue::Decimal(d) => {
206                                    use rust_decimal::prelude::ToPrimitive;
207                                    d.to_f32().ok_or_else(|| SqlError::Parse {
208                                        detail: format!(
209                                            "vector element decimal '{d}' is out of f32 range"
210                                        ),
211                                    })
212                                }
213                                other => Err(SqlError::Parse {
214                                    detail: format!(
215                                        "vector field must contain numbers, got {other:?}"
216                                    ),
217                                }),
218                            })
219                            .collect();
220                        vector = Some(floats?);
221                    }
222                    other => {
223                        return Err(SqlError::Parse {
224                            detail: format!(
225                                "vector field '{}' must be an array literal, got {other:?}",
226                                vpc.vector_field
227                            ),
228                        });
229                    }
230                }
231            } else {
232                payload_fields.insert(col, val);
233            }
234        }
235
236        let vector = vector.ok_or_else(|| SqlError::Parse {
237            detail: format!(
238                "vector-primary INSERT missing required vector field '{}'",
239                vpc.vector_field
240            ),
241        })?;
242
243        result_rows.push(VectorPrimaryRow {
244            surrogate: nodedb_types::Surrogate::ZERO,
245            vector,
246            payload_fields,
247        });
248    }
249
250    Ok(vec![SqlPlan::VectorPrimaryInsert {
251        collection: collection.to_string(),
252        field: vpc.vector_field.clone(),
253        quantization: vpc.quantization,
254        payload_indexes: vpc.payload_indexes.clone(),
255        rows: result_rows,
256    }])
257}
258
259/// Build a `SqlPlan::KvInsert` from a VALUES clause. Shared by plain INSERT,
260/// UPSERT, and `INSERT ... ON CONFLICT (key) DO UPDATE` — the three paths
261/// differ only in `intent` and `on_conflict_updates`, never in how entries
262/// are extracted from the row exprs.
263///
264/// `pk_col` is the schema-defined primary-key column name from
265/// `CollectionInfo::primary_key`.  When supplied, that column is used as
266/// the KV key regardless of whether it is named `"key"`.  Falls back to
267/// the literal name `"key"` when `pk_col` is `None` (legacy / generic
268/// KV collections that use the built-in key/value column convention).
269pub(super) fn build_kv_insert_plan(
270    table_name: String,
271    columns: &[String],
272    rows_ast: &[Vec<ast::Expr>],
273    intent: KvInsertIntent,
274    on_conflict_updates: Vec<(String, SqlExpr)>,
275    pk_col: Option<&str>,
276) -> Result<Vec<SqlPlan>> {
277    let key_col_name = pk_col.unwrap_or("key");
278    let key_idx = columns.iter().position(|c| c == key_col_name);
279    let ttl_idx = columns.iter().position(|c| c == "ttl");
280    // When using a named primary-key column (e.g. `k STRING PRIMARY KEY`), we
281    // store the key bytes in the KV key slot AND also keep the column in the
282    // value map.  This allows scan filters on the primary-key column (e.g.
283    // `WHERE k = 'x'`) and projection (e.g. `SELECT k FROM ...`) to work
284    // without teaching the KV scan handler to inspect the raw key bytes.
285    // The only column we exclude from the value map is the built-in `"key"`
286    // sentinel (used by raw key/value KV collections) and `"ttl"`.
287    let exclude_from_value: std::collections::HashSet<usize> = {
288        let mut s = std::collections::HashSet::new();
289        // Exclude the raw "key" sentinel column (not a named PK column).
290        if key_col_name == "key"
291            && let Some(idx) = key_idx
292        {
293            s.insert(idx);
294        }
295        if let Some(idx) = ttl_idx {
296            s.insert(idx);
297        }
298        s
299    };
300    let mut entries = Vec::with_capacity(rows_ast.len());
301    let mut ttl_secs: u64 = 0;
302    for row_exprs in rows_ast {
303        let key_val = match key_idx {
304            Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
305            None => SqlValue::String(String::new()),
306        };
307        if let Some(idx) = ttl_idx {
308            match expr_to_sql_value(&row_exprs[idx]) {
309                Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
310                Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
311                _ => {}
312            }
313        }
314        let value_cols: Vec<(String, SqlValue)> = columns
315            .iter()
316            .enumerate()
317            .filter(|(i, _)| !exclude_from_value.contains(i))
318            .map(|(i, col)| {
319                let val = expr_to_sql_value(&row_exprs[i])?;
320                Ok((col.clone(), val))
321            })
322            .collect::<Result<Vec<_>>>()?;
323        entries.push((key_val, value_cols));
324    }
325    Ok(vec![SqlPlan::KvInsert {
326        collection: table_name,
327        entries,
328        ttl_secs,
329        intent,
330        on_conflict_updates,
331    }])
332}