Skip to main content

nodedb_sql/planner/
dml.rs

1//! INSERT, UPDATE, DELETE planning.
2
3use sqlparser::ast::{self};
4
5use crate::engine_rules::{self, DeleteParams, InsertParams, UpdateParams};
6use crate::error::{Result, SqlError};
7use crate::parser::normalize::{normalize_ident, normalize_object_name};
8use crate::resolver::expr::{convert_expr, convert_value};
9use crate::types::*;
10
11/// Extract `ON CONFLICT (...) DO UPDATE SET` assignments from an AST
12/// insert, or `None` if this is a plain INSERT.
13fn extract_on_conflict_updates(ins: &ast::Insert) -> Result<Option<Vec<(String, SqlExpr)>>> {
14    let Some(on) = ins.on.as_ref() else {
15        return Ok(None);
16    };
17    let ast::OnInsert::OnConflict(oc) = on else {
18        return Ok(None);
19    };
20    let ast::OnConflictAction::DoUpdate(do_update) = &oc.action else {
21        // DO NOTHING maps to "ignore conflict" — currently unsupported.
22        return Err(SqlError::Unsupported {
23            detail: "ON CONFLICT DO NOTHING is not yet supported".into(),
24        });
25    };
26    let mut pairs = Vec::with_capacity(do_update.assignments.len());
27    for a in &do_update.assignments {
28        let name = match &a.target {
29            ast::AssignmentTarget::ColumnName(obj) => normalize_object_name(obj),
30            _ => {
31                return Err(SqlError::Unsupported {
32                    detail: "ON CONFLICT DO UPDATE SET target must be a column name".into(),
33                });
34            }
35        };
36        let expr = convert_expr(&a.value)?;
37        pairs.push((name, expr));
38    }
39    Ok(Some(pairs))
40}
41
42/// Plan an INSERT statement.
43pub fn plan_insert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
44    // `INSERT ... ON CONFLICT DO UPDATE SET` reroutes to the upsert path
45    // with the assignments carried through. Detected before any other
46    // work so both planning paths share the `ast::Insert` decode below.
47    if let Some(on_conflict_updates) = extract_on_conflict_updates(ins)? {
48        return plan_upsert_with_on_conflict(ins, catalog, on_conflict_updates);
49    }
50    let table_name = match &ins.table {
51        ast::TableObject::TableName(name) => normalize_object_name(name),
52        ast::TableObject::TableFunction(_) => {
53            return Err(SqlError::Unsupported {
54                detail: "INSERT INTO table function not supported".into(),
55            });
56        }
57    };
58    let info = catalog
59        .get_collection(&table_name)?
60        .ok_or_else(|| SqlError::UnknownTable {
61            name: table_name.clone(),
62        })?;
63
64    let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
65
66    // Check for INSERT...SELECT.
67    if let Some(source) = &ins.source
68        && let ast::SetExpr::Select(_select) = &*source.body
69    {
70        let source_plan = super::select::plan_query(
71            source,
72            catalog,
73            &crate::functions::registry::FunctionRegistry::new(),
74        )?;
75        return Ok(vec![SqlPlan::InsertSelect {
76            target: table_name,
77            source: Box::new(source_plan),
78            limit: 0,
79        }]);
80    }
81
82    // VALUES clause.
83    let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
84        detail: "INSERT requires VALUES or SELECT".into(),
85    })?;
86
87    let rows_ast = match &*source.body {
88        ast::SetExpr::Values(values) => &values.rows,
89        _ => {
90            return Err(SqlError::Unsupported {
91                detail: "INSERT source must be VALUES or SELECT".into(),
92            });
93        }
94    };
95
96    // KV engine: key and value are fundamentally separate — handle directly.
97    if info.engine == EngineType::KeyValue {
98        let key_idx = columns.iter().position(|c| c == "key");
99        let ttl_idx = columns.iter().position(|c| c == "ttl");
100        let mut entries = Vec::with_capacity(rows_ast.len());
101        let mut ttl_secs: u64 = 0;
102        for row_exprs in rows_ast {
103            let key_val = match key_idx {
104                Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
105                None => SqlValue::String(String::new()),
106            };
107            // Extract TTL if present (in seconds).
108            if let Some(idx) = ttl_idx {
109                match expr_to_sql_value(&row_exprs[idx]) {
110                    Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
111                    Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
112                    _ => {}
113                }
114            }
115            let value_cols: Vec<(String, SqlValue)> = columns
116                .iter()
117                .enumerate()
118                .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
119                .map(|(i, col)| {
120                    let val = expr_to_sql_value(&row_exprs[i])?;
121                    Ok((col.clone(), val))
122                })
123                .collect::<Result<Vec<_>>>()?;
124            entries.push((key_val, value_cols));
125        }
126        return Ok(vec![SqlPlan::KvInsert {
127            collection: table_name,
128            entries,
129            ttl_secs,
130        }]);
131    }
132
133    // All other engines: delegate to engine rules.
134    let rows = convert_value_rows(&columns, rows_ast)?;
135    let column_defaults: Vec<(String, String)> = info
136        .columns
137        .iter()
138        .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
139        .collect();
140    let rules = engine_rules::resolve_engine_rules(info.engine);
141    rules.plan_insert(InsertParams {
142        collection: table_name,
143        columns,
144        rows,
145        column_defaults,
146    })
147}
148
149/// Plan an UPSERT statement (pre-processed from `UPSERT INTO` to `INSERT INTO`).
150///
151/// Same parsing as INSERT but routes through `engine_rules.plan_upsert()`.
152pub fn plan_upsert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
153    let table_name = match &ins.table {
154        ast::TableObject::TableName(name) => normalize_object_name(name),
155        ast::TableObject::TableFunction(_) => {
156            return Err(SqlError::Unsupported {
157                detail: "UPSERT INTO table function not supported".into(),
158            });
159        }
160    };
161    let info = catalog
162        .get_collection(&table_name)?
163        .ok_or_else(|| SqlError::UnknownTable {
164            name: table_name.clone(),
165        })?;
166
167    let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
168
169    let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
170        detail: "UPSERT requires VALUES".into(),
171    })?;
172
173    let rows_ast = match &*source.body {
174        ast::SetExpr::Values(values) => &values.rows,
175        _ => {
176            return Err(SqlError::Unsupported {
177                detail: "UPSERT source must be VALUES".into(),
178            });
179        }
180    };
181
182    // KV: upsert is just a PUT (natural overwrite).
183    if info.engine == EngineType::KeyValue {
184        let key_idx = columns.iter().position(|c| c == "key");
185        let ttl_idx = columns.iter().position(|c| c == "ttl");
186        let mut entries = Vec::with_capacity(rows_ast.len());
187        let mut ttl_secs: u64 = 0;
188        for row_exprs in rows_ast {
189            let key_val = match key_idx {
190                Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
191                None => SqlValue::String(String::new()),
192            };
193            if let Some(idx) = ttl_idx {
194                match expr_to_sql_value(&row_exprs[idx]) {
195                    Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
196                    Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
197                    _ => {}
198                }
199            }
200            let value_cols: Vec<(String, SqlValue)> = columns
201                .iter()
202                .enumerate()
203                .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
204                .map(|(i, col)| {
205                    let val = expr_to_sql_value(&row_exprs[i])?;
206                    Ok((col.clone(), val))
207                })
208                .collect::<Result<Vec<_>>>()?;
209            entries.push((key_val, value_cols));
210        }
211        return Ok(vec![SqlPlan::KvInsert {
212            collection: table_name,
213            entries,
214            ttl_secs,
215        }]);
216    }
217
218    let rows = convert_value_rows(&columns, rows_ast)?;
219    let column_defaults: Vec<(String, String)> = info
220        .columns
221        .iter()
222        .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
223        .collect();
224    let rules = engine_rules::resolve_engine_rules(info.engine);
225    rules.plan_upsert(engine_rules::UpsertParams {
226        collection: table_name,
227        columns,
228        rows,
229        column_defaults,
230        on_conflict_updates: Vec::new(),
231    })
232}
233
234/// Plan an `INSERT ... ON CONFLICT DO UPDATE SET` statement. Identical to
235/// `plan_upsert` except the assignments are carried onto the upsert plan
236/// so the Data Plane can evaluate them against the existing row instead
237/// of merging the would-be-inserted values.
238fn plan_upsert_with_on_conflict(
239    ins: &ast::Insert,
240    catalog: &dyn SqlCatalog,
241    on_conflict_updates: Vec<(String, SqlExpr)>,
242) -> Result<Vec<SqlPlan>> {
243    let table_name = match &ins.table {
244        ast::TableObject::TableName(name) => normalize_object_name(name),
245        ast::TableObject::TableFunction(_) => {
246            return Err(SqlError::Unsupported {
247                detail: "INSERT ... ON CONFLICT on a table function is not supported".into(),
248            });
249        }
250    };
251    let info = catalog
252        .get_collection(&table_name)?
253        .ok_or_else(|| SqlError::UnknownTable {
254            name: table_name.clone(),
255        })?;
256
257    let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
258
259    let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
260        detail: "INSERT ... ON CONFLICT requires VALUES".into(),
261    })?;
262    let rows_ast = match &*source.body {
263        ast::SetExpr::Values(values) => &values.rows,
264        _ => {
265            return Err(SqlError::Unsupported {
266                detail: "INSERT ... ON CONFLICT source must be VALUES".into(),
267            });
268        }
269    };
270
271    let rows = convert_value_rows(&columns, rows_ast)?;
272    let column_defaults: Vec<(String, String)> = info
273        .columns
274        .iter()
275        .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
276        .collect();
277    let rules = engine_rules::resolve_engine_rules(info.engine);
278    rules.plan_upsert(engine_rules::UpsertParams {
279        collection: table_name,
280        columns,
281        rows,
282        column_defaults,
283        on_conflict_updates,
284    })
285}
286
287/// Plan an UPDATE statement.
288pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
289    let ast::Statement::Update(update) = stmt else {
290        return Err(SqlError::Parse {
291            detail: "expected UPDATE statement".into(),
292        });
293    };
294
295    let table_name = extract_table_name_from_table_with_joins(&update.table)?;
296    let info = catalog
297        .get_collection(&table_name)?
298        .ok_or_else(|| SqlError::UnknownTable {
299            name: table_name.clone(),
300        })?;
301
302    let assigns: Vec<(String, SqlExpr)> = update
303        .assignments
304        .iter()
305        .map(|a| {
306            let col = match &a.target {
307                ast::AssignmentTarget::ColumnName(name) => normalize_object_name(name),
308                ast::AssignmentTarget::Tuple(names) => names
309                    .iter()
310                    .map(normalize_object_name)
311                    .collect::<Vec<_>>()
312                    .join(","),
313            };
314            let val = convert_expr(&a.value)?;
315            Ok((col, val))
316        })
317        .collect::<Result<_>>()?;
318
319    let filters = match &update.selection {
320        Some(expr) => super::select::convert_where_to_filters(expr)?,
321        None => Vec::new(),
322    };
323
324    // Detect point updates (WHERE pk = literal).
325    let target_keys = extract_point_keys(update.selection.as_ref(), &info);
326
327    let rules = engine_rules::resolve_engine_rules(info.engine);
328    rules.plan_update(UpdateParams {
329        collection: table_name,
330        assignments: assigns,
331        filters,
332        target_keys,
333        returning: update.returning.is_some(),
334    })
335}
336
337/// Plan a DELETE statement.
338pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
339    let ast::Statement::Delete(delete) = stmt else {
340        return Err(SqlError::Parse {
341            detail: "expected DELETE statement".into(),
342        });
343    };
344
345    let from_tables = match &delete.from {
346        ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
347    };
348    let table_name =
349        extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
350            SqlError::Parse {
351                detail: "DELETE requires a FROM table".into(),
352            }
353        })?)?;
354    let info = catalog
355        .get_collection(&table_name)?
356        .ok_or_else(|| SqlError::UnknownTable {
357            name: table_name.clone(),
358        })?;
359
360    let filters = match &delete.selection {
361        Some(expr) => super::select::convert_where_to_filters(expr)?,
362        None => Vec::new(),
363    };
364
365    let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
366
367    let rules = engine_rules::resolve_engine_rules(info.engine);
368    rules.plan_delete(DeleteParams {
369        collection: table_name,
370        filters,
371        target_keys,
372    })
373}
374
375/// Plan a TRUNCATE statement.
376pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
377    let ast::Statement::Truncate(truncate) = stmt else {
378        return Err(SqlError::Parse {
379            detail: "expected TRUNCATE statement".into(),
380        });
381    };
382    let restart_identity = matches!(
383        truncate.identity,
384        Some(sqlparser::ast::TruncateIdentityOption::Restart)
385    );
386    truncate
387        .table_names
388        .iter()
389        .map(|t| {
390            Ok(SqlPlan::Truncate {
391                collection: normalize_object_name(&t.name),
392                restart_identity,
393            })
394        })
395        .collect()
396}
397
398// ── Helpers ──
399
400fn convert_value_rows(
401    columns: &[String],
402    rows: &[Vec<ast::Expr>],
403) -> Result<Vec<Vec<(String, SqlValue)>>> {
404    rows.iter()
405        .map(|row| {
406            row.iter()
407                .enumerate()
408                .map(|(i, expr)| {
409                    let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
410                    let val = expr_to_sql_value(expr)?;
411                    Ok((col, val))
412                })
413                .collect::<Result<Vec<_>>>()
414        })
415        .collect()
416}
417
418fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
419    match expr {
420        ast::Expr::Value(v) => convert_value(&v.value),
421        ast::Expr::UnaryOp {
422            op: ast::UnaryOperator::Minus,
423            expr: inner,
424        } => {
425            let val = expr_to_sql_value(inner)?;
426            match val {
427                SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
428                SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
429                _ => Err(SqlError::TypeMismatch {
430                    detail: "cannot negate non-numeric value".into(),
431                }),
432            }
433        }
434        ast::Expr::Array(ast::Array { elem, .. }) => {
435            let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
436            Ok(SqlValue::Array(vals))
437        }
438        ast::Expr::Function(func) => {
439            let func_name = func
440                .name
441                .0
442                .iter()
443                .map(|p| match p {
444                    ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
445                    _ => String::new(),
446                })
447                .collect::<Vec<_>>()
448                .join(".")
449                .to_lowercase();
450            match func_name.as_str() {
451                "st_point" => {
452                    // ST_Point(lon, lat) → GeoJSON string at plan time.
453                    let args = super::select::extract_func_args(func)?;
454                    if args.len() >= 2 {
455                        let lon = super::select::extract_float(&args[0])?;
456                        let lat = super::select::extract_float(&args[1])?;
457                        Ok(SqlValue::String(format!(
458                            r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
459                        )))
460                    } else {
461                        Ok(SqlValue::String(format!("{expr}")))
462                    }
463                }
464                "st_geomfromgeojson" => {
465                    let args = super::select::extract_func_args(func)?;
466                    if !args.is_empty() {
467                        let s = super::select::extract_string_literal(&args[0])?;
468                        Ok(SqlValue::String(s))
469                    } else {
470                        Ok(SqlValue::String(format!("{expr}")))
471                    }
472                }
473                _ => {
474                    // Try folding via the shared scalar evaluator. Handles
475                    // `now()`, `current_timestamp`, `date_add(now(),'1h')`,
476                    // etc. — Postgres semantics: one snapshot per statement.
477                    // Unknown or non-foldable functions fall back to the
478                    // legacy string passthrough so existing behavior for
479                    // other callers is preserved.
480                    if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
481                        && let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
482                    {
483                        Ok(v)
484                    } else {
485                        Ok(SqlValue::String(format!("{expr}")))
486                    }
487                }
488            }
489        }
490        _ => Err(SqlError::Unsupported {
491            detail: format!("value expression: {expr}"),
492        }),
493    }
494}
495
496fn extract_table_name_from_table_with_joins(table: &ast::TableWithJoins) -> Result<String> {
497    match &table.relation {
498        ast::TableFactor::Table { name, .. } => Ok(normalize_object_name(name)),
499        _ => Err(SqlError::Unsupported {
500            detail: "non-table target in DML".into(),
501        }),
502    }
503}
504
505/// Extract point-operation keys from WHERE clause (WHERE pk = literal OR pk IN (...)).
506fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
507    let pk = match &info.primary_key {
508        Some(pk) => pk.clone(),
509        None => return Vec::new(),
510    };
511
512    let expr = match selection {
513        Some(e) => e,
514        None => return Vec::new(),
515    };
516
517    let mut keys = Vec::new();
518    collect_pk_equalities(expr, &pk, &mut keys);
519    keys
520}
521
522fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
523    match expr {
524        ast::Expr::BinaryOp {
525            left,
526            op: ast::BinaryOperator::Eq,
527            right,
528        } => {
529            if is_column(left, pk)
530                && let Ok(v) = expr_to_sql_value(right)
531            {
532                keys.push(v);
533            } else if is_column(right, pk)
534                && let Ok(v) = expr_to_sql_value(left)
535            {
536                keys.push(v);
537            }
538        }
539        ast::Expr::BinaryOp {
540            left,
541            op: ast::BinaryOperator::Or,
542            right,
543        } => {
544            collect_pk_equalities(left, pk, keys);
545            collect_pk_equalities(right, pk, keys);
546        }
547        ast::Expr::InList {
548            expr: inner,
549            list,
550            negated: false,
551        } if is_column(inner, pk) => {
552            for item in list {
553                if let Ok(v) = expr_to_sql_value(item) {
554                    keys.push(v);
555                }
556            }
557        }
558        _ => {}
559    }
560}
561
562fn is_column(expr: &ast::Expr, name: &str) -> bool {
563    match expr {
564        ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
565        ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
566            normalize_ident(&parts[1]) == name
567        }
568        _ => false,
569    }
570}