Skip to main content

nodedb_sql/planner/
dml.rs

//! INSERT, UPSERT, UPDATE, DELETE, and TRUNCATE planning.
2
3use sqlparser::ast::{self};
4
5use crate::engine_rules::{self, DeleteParams, InsertParams, UpdateParams};
6use crate::error::{Result, SqlError};
7use crate::parser::normalize::{normalize_ident, normalize_object_name};
8use crate::resolver::expr::{convert_expr, convert_value};
9use crate::types::*;
10
11/// Plan an INSERT statement.
12pub fn plan_insert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
13    let table_name = match &ins.table {
14        ast::TableObject::TableName(name) => normalize_object_name(name),
15        ast::TableObject::TableFunction(_) => {
16            return Err(SqlError::Unsupported {
17                detail: "INSERT INTO table function not supported".into(),
18            });
19        }
20    };
21    let info = catalog
22        .get_collection(&table_name)
23        .ok_or_else(|| SqlError::UnknownTable {
24            name: table_name.clone(),
25        })?;
26
27    let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
28
29    // Check for INSERT...SELECT.
30    if let Some(source) = &ins.source
31        && let ast::SetExpr::Select(_select) = &*source.body
32    {
33        let source_plan = super::select::plan_query(
34            source,
35            catalog,
36            &crate::functions::registry::FunctionRegistry::new(),
37        )?;
38        return Ok(vec![SqlPlan::InsertSelect {
39            target: table_name,
40            source: Box::new(source_plan),
41            limit: 0,
42        }]);
43    }
44
45    // VALUES clause.
46    let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
47        detail: "INSERT requires VALUES or SELECT".into(),
48    })?;
49
50    let rows_ast = match &*source.body {
51        ast::SetExpr::Values(values) => &values.rows,
52        _ => {
53            return Err(SqlError::Unsupported {
54                detail: "INSERT source must be VALUES or SELECT".into(),
55            });
56        }
57    };
58
59    // KV engine: key and value are fundamentally separate — handle directly.
60    if info.engine == EngineType::KeyValue {
61        let key_idx = columns.iter().position(|c| c == "key");
62        let ttl_idx = columns.iter().position(|c| c == "ttl");
63        let mut entries = Vec::with_capacity(rows_ast.len());
64        let mut ttl_secs: u64 = 0;
65        for row_exprs in rows_ast {
66            let key_val = match key_idx {
67                Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
68                None => SqlValue::String(String::new()),
69            };
70            // Extract TTL if present (in seconds).
71            if let Some(idx) = ttl_idx {
72                match expr_to_sql_value(&row_exprs[idx]) {
73                    Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
74                    Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
75                    _ => {}
76                }
77            }
78            let value_cols: Vec<(String, SqlValue)> = columns
79                .iter()
80                .enumerate()
81                .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
82                .map(|(i, col)| {
83                    let val = expr_to_sql_value(&row_exprs[i])?;
84                    Ok((col.clone(), val))
85                })
86                .collect::<Result<Vec<_>>>()?;
87            entries.push((key_val, value_cols));
88        }
89        return Ok(vec![SqlPlan::KvInsert {
90            collection: table_name,
91            entries,
92            ttl_secs,
93        }]);
94    }
95
96    // All other engines: delegate to engine rules.
97    let rows = convert_value_rows(&columns, rows_ast)?;
98    let column_defaults: Vec<(String, String)> = info
99        .columns
100        .iter()
101        .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
102        .collect();
103    let rules = engine_rules::resolve_engine_rules(info.engine);
104    rules.plan_insert(InsertParams {
105        collection: table_name,
106        columns,
107        rows,
108        column_defaults,
109    })
110}
111
112/// Plan an UPSERT statement (pre-processed from `UPSERT INTO` to `INSERT INTO`).
113///
114/// Same parsing as INSERT but routes through `engine_rules.plan_upsert()`.
115pub fn plan_upsert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
116    let table_name = match &ins.table {
117        ast::TableObject::TableName(name) => normalize_object_name(name),
118        ast::TableObject::TableFunction(_) => {
119            return Err(SqlError::Unsupported {
120                detail: "UPSERT INTO table function not supported".into(),
121            });
122        }
123    };
124    let info = catalog
125        .get_collection(&table_name)
126        .ok_or_else(|| SqlError::UnknownTable {
127            name: table_name.clone(),
128        })?;
129
130    let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
131
132    let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
133        detail: "UPSERT requires VALUES".into(),
134    })?;
135
136    let rows_ast = match &*source.body {
137        ast::SetExpr::Values(values) => &values.rows,
138        _ => {
139            return Err(SqlError::Unsupported {
140                detail: "UPSERT source must be VALUES".into(),
141            });
142        }
143    };
144
145    // KV: upsert is just a PUT (natural overwrite).
146    if info.engine == EngineType::KeyValue {
147        let key_idx = columns.iter().position(|c| c == "key");
148        let ttl_idx = columns.iter().position(|c| c == "ttl");
149        let mut entries = Vec::with_capacity(rows_ast.len());
150        let mut ttl_secs: u64 = 0;
151        for row_exprs in rows_ast {
152            let key_val = match key_idx {
153                Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
154                None => SqlValue::String(String::new()),
155            };
156            if let Some(idx) = ttl_idx {
157                match expr_to_sql_value(&row_exprs[idx]) {
158                    Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
159                    Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
160                    _ => {}
161                }
162            }
163            let value_cols: Vec<(String, SqlValue)> = columns
164                .iter()
165                .enumerate()
166                .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
167                .map(|(i, col)| {
168                    let val = expr_to_sql_value(&row_exprs[i])?;
169                    Ok((col.clone(), val))
170                })
171                .collect::<Result<Vec<_>>>()?;
172            entries.push((key_val, value_cols));
173        }
174        return Ok(vec![SqlPlan::KvInsert {
175            collection: table_name,
176            entries,
177            ttl_secs,
178        }]);
179    }
180
181    let rows = convert_value_rows(&columns, rows_ast)?;
182    let column_defaults: Vec<(String, String)> = info
183        .columns
184        .iter()
185        .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
186        .collect();
187    let rules = engine_rules::resolve_engine_rules(info.engine);
188    rules.plan_upsert(engine_rules::UpsertParams {
189        collection: table_name,
190        columns,
191        rows,
192        column_defaults,
193    })
194}
195
196/// Plan an UPDATE statement.
197pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
198    let ast::Statement::Update(update) = stmt else {
199        return Err(SqlError::Parse {
200            detail: "expected UPDATE statement".into(),
201        });
202    };
203
204    let table_name = extract_table_name_from_table_with_joins(&update.table)?;
205    let info = catalog
206        .get_collection(&table_name)
207        .ok_or_else(|| SqlError::UnknownTable {
208            name: table_name.clone(),
209        })?;
210
211    let assigns: Vec<(String, SqlExpr)> = update
212        .assignments
213        .iter()
214        .map(|a| {
215            let col = match &a.target {
216                ast::AssignmentTarget::ColumnName(name) => normalize_object_name(name),
217                ast::AssignmentTarget::Tuple(names) => names
218                    .iter()
219                    .map(normalize_object_name)
220                    .collect::<Vec<_>>()
221                    .join(","),
222            };
223            let val = convert_expr(&a.value)?;
224            Ok((col, val))
225        })
226        .collect::<Result<_>>()?;
227
228    let filters = match &update.selection {
229        Some(expr) => super::select::convert_where_to_filters(expr)?,
230        None => Vec::new(),
231    };
232
233    // Detect point updates (WHERE pk = literal).
234    let target_keys = extract_point_keys(update.selection.as_ref(), &info);
235
236    let rules = engine_rules::resolve_engine_rules(info.engine);
237    rules.plan_update(UpdateParams {
238        collection: table_name,
239        assignments: assigns,
240        filters,
241        target_keys,
242        returning: update.returning.is_some(),
243    })
244}
245
246/// Plan a DELETE statement.
247pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
248    let ast::Statement::Delete(delete) = stmt else {
249        return Err(SqlError::Parse {
250            detail: "expected DELETE statement".into(),
251        });
252    };
253
254    let from_tables = match &delete.from {
255        ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
256    };
257    let table_name =
258        extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
259            SqlError::Parse {
260                detail: "DELETE requires a FROM table".into(),
261            }
262        })?)?;
263    let info = catalog
264        .get_collection(&table_name)
265        .ok_or_else(|| SqlError::UnknownTable {
266            name: table_name.clone(),
267        })?;
268
269    let filters = match &delete.selection {
270        Some(expr) => super::select::convert_where_to_filters(expr)?,
271        None => Vec::new(),
272    };
273
274    let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
275
276    let rules = engine_rules::resolve_engine_rules(info.engine);
277    rules.plan_delete(DeleteParams {
278        collection: table_name,
279        filters,
280        target_keys,
281    })
282}
283
284/// Plan a TRUNCATE statement.
285pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
286    let ast::Statement::Truncate(truncate) = stmt else {
287        return Err(SqlError::Parse {
288            detail: "expected TRUNCATE statement".into(),
289        });
290    };
291    let restart_identity = matches!(
292        truncate.identity,
293        Some(sqlparser::ast::TruncateIdentityOption::Restart)
294    );
295    truncate
296        .table_names
297        .iter()
298        .map(|t| {
299            Ok(SqlPlan::Truncate {
300                collection: normalize_object_name(&t.name),
301                restart_identity,
302            })
303        })
304        .collect()
305}
306
307// ── Helpers ──
308
309fn convert_value_rows(
310    columns: &[String],
311    rows: &[Vec<ast::Expr>],
312) -> Result<Vec<Vec<(String, SqlValue)>>> {
313    rows.iter()
314        .map(|row| {
315            row.iter()
316                .enumerate()
317                .map(|(i, expr)| {
318                    let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
319                    let val = expr_to_sql_value(expr)?;
320                    Ok((col, val))
321                })
322                .collect::<Result<Vec<_>>>()
323        })
324        .collect()
325}
326
327fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
328    match expr {
329        ast::Expr::Value(v) => convert_value(&v.value),
330        ast::Expr::UnaryOp {
331            op: ast::UnaryOperator::Minus,
332            expr: inner,
333        } => {
334            let val = expr_to_sql_value(inner)?;
335            match val {
336                SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
337                SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
338                _ => Err(SqlError::TypeMismatch {
339                    detail: "cannot negate non-numeric value".into(),
340                }),
341            }
342        }
343        ast::Expr::Array(ast::Array { elem, .. }) => {
344            let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
345            Ok(SqlValue::Array(vals))
346        }
347        ast::Expr::Function(func) => {
348            let func_name = func
349                .name
350                .0
351                .iter()
352                .map(|p| match p {
353                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
354                    _ => String::new(),
355                })
356                .collect::<Vec<_>>()
357                .join(".")
358                .to_lowercase();
359            match func_name.as_str() {
360                "st_point" => {
361                    // ST_Point(lon, lat) → GeoJSON string at plan time.
362                    let args = super::select::extract_func_args(func)?;
363                    if args.len() >= 2 {
364                        let lon = super::select::extract_float(&args[0])?;
365                        let lat = super::select::extract_float(&args[1])?;
366                        Ok(SqlValue::String(format!(
367                            r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
368                        )))
369                    } else {
370                        Ok(SqlValue::String(format!("{expr}")))
371                    }
372                }
373                "st_geomfromgeojson" => {
374                    let args = super::select::extract_func_args(func)?;
375                    if !args.is_empty() {
376                        let s = super::select::extract_string_literal(&args[0])?;
377                        Ok(SqlValue::String(s))
378                    } else {
379                        Ok(SqlValue::String(format!("{expr}")))
380                    }
381                }
382                _ => {
383                    // Other functions like now() — store as string for runtime eval.
384                    Ok(SqlValue::String(format!("{expr}")))
385                }
386            }
387        }
388        _ => Err(SqlError::Unsupported {
389            detail: format!("value expression: {expr}"),
390        }),
391    }
392}
393
394fn extract_table_name_from_table_with_joins(table: &ast::TableWithJoins) -> Result<String> {
395    match &table.relation {
396        ast::TableFactor::Table { name, .. } => Ok(normalize_object_name(name)),
397        _ => Err(SqlError::Unsupported {
398            detail: "non-table target in DML".into(),
399        }),
400    }
401}
402
403/// Extract point-operation keys from WHERE clause (WHERE pk = literal OR pk IN (...)).
404fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
405    let pk = match &info.primary_key {
406        Some(pk) => pk.clone(),
407        None => return Vec::new(),
408    };
409
410    let expr = match selection {
411        Some(e) => e,
412        None => return Vec::new(),
413    };
414
415    let mut keys = Vec::new();
416    collect_pk_equalities(expr, &pk, &mut keys);
417    keys
418}
419
420fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
421    match expr {
422        ast::Expr::BinaryOp {
423            left,
424            op: ast::BinaryOperator::Eq,
425            right,
426        } => {
427            if is_column(left, pk)
428                && let Ok(v) = expr_to_sql_value(right)
429            {
430                keys.push(v);
431            } else if is_column(right, pk)
432                && let Ok(v) = expr_to_sql_value(left)
433            {
434                keys.push(v);
435            }
436        }
437        ast::Expr::BinaryOp {
438            left,
439            op: ast::BinaryOperator::Or,
440            right,
441        } => {
442            collect_pk_equalities(left, pk, keys);
443            collect_pk_equalities(right, pk, keys);
444        }
445        ast::Expr::InList {
446            expr: inner,
447            list,
448            negated: false,
449        } => {
450            if is_column(inner, pk) {
451                for item in list {
452                    if let Ok(v) = expr_to_sql_value(item) {
453                        keys.push(v);
454                    }
455                }
456            }
457        }
458        _ => {}
459    }
460}
461
462fn is_column(expr: &ast::Expr, name: &str) -> bool {
463    match expr {
464        ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
465        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
466            normalize_ident(&parts[1]) == name
467        }
468        _ => false,
469    }
470}