Skip to main content

nodedb_sql/planner/
dml_update_delete.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! UPDATE, DELETE, and TRUNCATE planning — extracted from `dml.rs`.
4
5use nodedb_types::DatabaseId;
6use sqlparser::ast;
7
8use super::super::ast_helpers::{
9    flatten_and_expr, qualified_ident_pair, strip_and_convert_filters,
10};
11use super::super::dml_helpers::{extract_point_keys, extract_table_name_from_table_with_joins};
12use crate::engine_rules::{self, DeleteParams, UpdateFromParams, UpdateParams};
13use crate::error::{Result, SqlError};
14use crate::parser::normalize::{
15    SCHEMA_QUALIFIED_MSG, normalize_ident, normalize_object_name_checked,
16};
17use crate::resolver::expr::convert_expr;
18use crate::types::*;
19
20/// Plan an UPDATE statement.
21pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
22    let ast::Statement::Update(update) = stmt else {
23        return Err(SqlError::Parse {
24            detail: "expected UPDATE statement".into(),
25        });
26    };
27
28    // Delegate to the UPDATE...FROM path when a FROM clause is present.
29    if update.from.is_some() {
30        return plan_update_from(update, catalog);
31    }
32
33    let table_name = extract_table_name_from_table_with_joins(&update.table)?;
34    let info = catalog
35        .get_collection(DatabaseId::DEFAULT, &table_name)?
36        .ok_or_else(|| SqlError::UnknownTable {
37            name: table_name.clone(),
38        })?;
39
40    let assigns = convert_assignments(&update.assignments)?;
41
42    let filters = match &update.selection {
43        Some(expr) => super::super::select::convert_where_to_filters(expr)?,
44        None => Vec::new(),
45    };
46
47    let target_keys = extract_point_keys(update.selection.as_ref(), &info);
48
49    let rules = engine_rules::resolve_engine_rules(info.engine);
50    rules.plan_update(UpdateParams {
51        collection: table_name,
52        assignments: assigns,
53        filters,
54        target_keys,
55        returning: update.returning.is_some(),
56    })
57}
58
59/// Plan `UPDATE target SET ... FROM src WHERE target.col = src.col ...`.
60fn plan_update_from(update: &ast::Update, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
61    let target_name = extract_table_name_from_table_with_joins(&update.table)?;
62
63    // Extract alias for the target table if present.
64    let target_alias: Option<String> = match &update.table.relation {
65        ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
66        _ => None,
67    };
68    let target_ref = target_alias.as_deref().unwrap_or(target_name.as_str());
69
70    let from_kind = update.from.as_ref().expect("caller ensures from.is_some()");
71    let from_tables: &Vec<ast::TableWithJoins> = match from_kind {
72        ast::UpdateTableFromKind::AfterSet(tables)
73        | ast::UpdateTableFromKind::BeforeSet(tables) => tables,
74    };
75
76    // Reject multi-table FROM.
77    if from_tables.len() > 1 {
78        return Err(SqlError::Unsupported {
79            detail: format!(
80                "UPDATE ... FROM with {} source tables is not supported; \
81                 only a single FROM table is accepted",
82                from_tables.len()
83            ),
84        });
85    }
86    let from_table = from_tables.first().ok_or_else(|| SqlError::Parse {
87        detail: "UPDATE ... FROM requires at least one source table".into(),
88    })?;
89
90    // Reject subquery in FROM.
91    let source_name = match &from_table.relation {
92        ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name)?,
93        ast::TableFactor::Derived { .. } => {
94            return Err(SqlError::Unsupported {
95                detail: "UPDATE ... FROM (subquery) is not supported; \
96                     use a CTE: WITH cte AS (SELECT ...) UPDATE t SET ... FROM cte WHERE ..."
97                    .into(),
98            });
99        }
100        _ => {
101            return Err(SqlError::Unsupported {
102                detail: "non-table relation in UPDATE ... FROM is not supported".into(),
103            });
104        }
105    };
106    // Reject joins in the FROM source.
107    if !from_table.joins.is_empty() {
108        return Err(SqlError::Unsupported {
109            detail: "JOIN in UPDATE ... FROM source is not supported; \
110                     use a CTE to pre-join the source"
111                .into(),
112        });
113    }
114
115    let source_alias: Option<String> = match &from_table.relation {
116        ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
117        _ => None,
118    };
119    let source_ref = source_alias.as_deref().unwrap_or(source_name.as_str());
120
121    // Validate that the target and source collections exist.
122    let target_info = catalog
123        .get_collection(DatabaseId::DEFAULT, &target_name)?
124        .ok_or_else(|| SqlError::UnknownTable {
125            name: target_name.clone(),
126        })?;
127    let source_info = catalog
128        .get_collection(DatabaseId::DEFAULT, &source_name)?
129        .ok_or_else(|| SqlError::UnknownTable {
130            name: source_name.clone(),
131        })?;
132
133    let assigns = convert_assignments(&update.assignments)?;
134
135    // Split the WHERE clause into:
136    //   - one equi-join predicate linking target and source (required)
137    //   - remaining predicates that apply to target only
138    let (target_join_col, source_join_col, target_filters) = match &update.selection {
139        None => {
140            return Err(SqlError::Parse {
141                detail: "UPDATE ... FROM requires a WHERE clause with an equi-join predicate \
142                         linking the target and source tables"
143                    .into(),
144            });
145        }
146        Some(expr) => extract_join_predicate(expr, target_ref, source_ref)?,
147    };
148
149    // Plan the source as a simple scan (no filters — all filtering is via join key).
150    let source_rules = engine_rules::resolve_engine_rules(source_info.engine);
151    let source_plan = source_rules.plan_scan(crate::engine_rules::ScanParams {
152        collection: source_name,
153        alias: source_alias,
154        filters: Vec::new(),
155        projection: Vec::new(),
156        sort_keys: Vec::new(),
157        limit: None,
158        offset: 0,
159        distinct: false,
160        window_functions: Vec::new(),
161        indexes: Vec::new(),
162        temporal: crate::temporal::TemporalScope::default(),
163        bitemporal: source_info.bitemporal,
164    })?;
165
166    let rules = engine_rules::resolve_engine_rules(target_info.engine);
167    rules.plan_update_from(UpdateFromParams {
168        collection: target_name,
169        source: Box::new(source_plan),
170        target_join_col,
171        source_join_col,
172        assignments: assigns,
173        target_filters,
174        returning: update.returning.is_some(),
175    })
176}
177
178/// Extract a single equi-join predicate of the form `target_table.col = source_table.col`
179/// (or the reverse) from a WHERE expression, returning `(target_col, source_col, remaining_filters)`.
180///
181/// Also accepts `col = other_table.col` where `col` without a table qualifier is
182/// assumed to belong to the target (PostgreSQL behavior).
183fn extract_join_predicate(
184    expr: &ast::Expr,
185    target_ref: &str,
186    source_ref: &str,
187) -> Result<(String, String, Vec<Filter>)> {
188    // Flatten the top-level AND chain.
189    let mut conjuncts: Vec<ast::Expr> = Vec::new();
190    flatten_and_expr(expr, &mut conjuncts);
191
192    // Find the first conjunct that is an equi-join between target and source.
193    let mut join_idx: Option<usize> = None;
194    let mut target_col = String::new();
195    let mut source_col = String::new();
196
197    for (i, conjunct) in conjuncts.iter().enumerate() {
198        if let Some((tc, sc)) = try_equijoin_pair(conjunct, target_ref, source_ref) {
199            target_col = tc;
200            source_col = sc;
201            join_idx = Some(i);
202            break;
203        }
204    }
205
206    let join_idx = join_idx.ok_or_else(|| SqlError::Parse {
207        detail: format!(
208            "UPDATE ... FROM requires a WHERE clause equi-join predicate of the form \
209             `{target_ref}.col = {source_ref}.col`; none found"
210        ),
211    })?;
212
213    conjuncts.remove(join_idx);
214
215    // Remaining conjuncts become target_filters. Strip table qualifier so
216    // `uf_target.score` becomes `score` — documents store bare field names.
217    let target_filters = strip_and_convert_filters(conjuncts, target_ref)?;
218
219    Ok((target_col, source_col, target_filters))
220}
221
222/// Try to extract `(target_col, source_col)` from an equality expression
223/// where one side is `target_ref.col` and the other is `source_ref.col`.
224/// Also handles unqualified names by assuming they belong to the target.
225fn try_equijoin_pair(
226    expr: &ast::Expr,
227    target_ref: &str,
228    source_ref: &str,
229) -> Option<(String, String)> {
230    let ast::Expr::BinaryOp {
231        left,
232        op: ast::BinaryOperator::Eq,
233        right,
234    } = expr
235    else {
236        return None;
237    };
238
239    let lhs = qualified_ident_pair(left);
240    let rhs = qualified_ident_pair(right);
241
242    match (lhs, rhs) {
243        (Some((lt, lc)), Some((rt, rc))) => {
244            if lt == target_ref && rt == source_ref {
245                Some((lc, rc))
246            } else if lt == source_ref && rt == target_ref {
247                Some((rc, lc))
248            } else {
249                None
250            }
251        }
252        // One side is unqualified — treat it as belonging to target.
253        (Some((t, c)), None) if t == source_ref => {
254            if let ast::Expr::Identifier(ident) = right.as_ref() {
255                Some((normalize_ident(ident), c))
256            } else {
257                None
258            }
259        }
260        (None, Some((t, c))) if t == source_ref => {
261            if let ast::Expr::Identifier(ident) = left.as_ref() {
262                Some((normalize_ident(ident), c))
263            } else {
264                None
265            }
266        }
267        _ => None,
268    }
269}
270
271/// Convert `update.assignments` into `Vec<(col, SqlExpr)>`.
272fn convert_assignments(assignments: &[ast::Assignment]) -> Result<Vec<(String, SqlExpr)>> {
273    assignments
274        .iter()
275        .map(|a| {
276            let col = match &a.target {
277                ast::AssignmentTarget::ColumnName(name) => {
278                    if name.0.len() > 1 {
279                        return Err(SqlError::Unsupported {
280                            detail: format!(
281                                "qualified column name in SET target: {SCHEMA_QUALIFIED_MSG}"
282                            ),
283                        });
284                    }
285                    normalize_object_name_checked(name)?
286                }
287                ast::AssignmentTarget::Tuple(names) => names
288                    .iter()
289                    .map(normalize_object_name_checked)
290                    .collect::<Result<Vec<_>>>()?
291                    .join(","),
292            };
293            let val = convert_expr(&a.value)?;
294            Ok((col, val))
295        })
296        .collect()
297}
298
299/// Plan a DELETE statement.
300pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
301    let ast::Statement::Delete(delete) = stmt else {
302        return Err(SqlError::Parse {
303            detail: "expected DELETE statement".into(),
304        });
305    };
306
307    let from_tables = match &delete.from {
308        ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
309    };
310    let table_name =
311        extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
312            SqlError::Parse {
313                detail: "DELETE requires a FROM table".into(),
314            }
315        })?)?;
316    let info = catalog
317        .get_collection(DatabaseId::DEFAULT, &table_name)?
318        .ok_or_else(|| SqlError::UnknownTable {
319            name: table_name.clone(),
320        })?;
321
322    let filters = match &delete.selection {
323        Some(expr) => super::super::select::convert_where_to_filters(expr)?,
324        None => Vec::new(),
325    };
326
327    let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
328
329    let rules = engine_rules::resolve_engine_rules(info.engine);
330    rules.plan_delete(DeleteParams {
331        collection: table_name,
332        filters,
333        target_keys,
334    })
335}
336
337/// Plan a TRUNCATE statement.
338pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
339    let ast::Statement::Truncate(truncate) = stmt else {
340        return Err(SqlError::Parse {
341            detail: "expected TRUNCATE statement".into(),
342        });
343    };
344    let restart_identity = matches!(
345        truncate.identity,
346        Some(sqlparser::ast::TruncateIdentityOption::Restart)
347    );
348    truncate
349        .table_names
350        .iter()
351        .map(|t| {
352            Ok(SqlPlan::Truncate {
353                collection: normalize_object_name_checked(&t.name)?,
354                restart_identity,
355            })
356        })
357        .collect()
358}