1use nodedb_types::DatabaseId;
6use sqlparser::ast;
7
8use super::super::ast_helpers::{
9 flatten_and_expr, qualified_ident_pair, strip_and_convert_filters,
10};
11use super::super::dml_helpers::{extract_point_keys, extract_table_name_from_table_with_joins};
12use crate::engine_rules::{self, DeleteParams, UpdateFromParams, UpdateParams};
13use crate::error::{Result, SqlError};
14use crate::parser::normalize::{
15 SCHEMA_QUALIFIED_MSG, normalize_ident, normalize_object_name_checked,
16};
17use crate::resolver::expr::convert_expr;
18use crate::types::*;
19
20pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
22 let ast::Statement::Update(update) = stmt else {
23 return Err(SqlError::Parse {
24 detail: "expected UPDATE statement".into(),
25 });
26 };
27
28 if update.from.is_some() {
30 return plan_update_from(update, catalog);
31 }
32
33 let table_name = extract_table_name_from_table_with_joins(&update.table)?;
34 let info = catalog
35 .get_collection(DatabaseId::DEFAULT, &table_name)?
36 .ok_or_else(|| SqlError::UnknownTable {
37 name: table_name.clone(),
38 })?;
39
40 let assigns = convert_assignments(&update.assignments)?;
41
42 let filters = match &update.selection {
43 Some(expr) => super::super::select::convert_where_to_filters(expr)?,
44 None => Vec::new(),
45 };
46
47 let target_keys = extract_point_keys(update.selection.as_ref(), &info);
48
49 let rules = engine_rules::resolve_engine_rules(info.engine);
50 rules.plan_update(UpdateParams {
51 collection: table_name,
52 assignments: assigns,
53 filters,
54 target_keys,
55 returning: update.returning.is_some(),
56 })
57}
58
59fn plan_update_from(update: &ast::Update, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
61 let target_name = extract_table_name_from_table_with_joins(&update.table)?;
62
63 let target_alias: Option<String> = match &update.table.relation {
65 ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
66 _ => None,
67 };
68 let target_ref = target_alias.as_deref().unwrap_or(target_name.as_str());
69
70 let from_kind = update.from.as_ref().expect("caller ensures from.is_some()");
71 let from_tables: &Vec<ast::TableWithJoins> = match from_kind {
72 ast::UpdateTableFromKind::AfterSet(tables)
73 | ast::UpdateTableFromKind::BeforeSet(tables) => tables,
74 };
75
76 if from_tables.len() > 1 {
78 return Err(SqlError::Unsupported {
79 detail: format!(
80 "UPDATE ... FROM with {} source tables is not supported; \
81 only a single FROM table is accepted",
82 from_tables.len()
83 ),
84 });
85 }
86 let from_table = from_tables.first().ok_or_else(|| SqlError::Parse {
87 detail: "UPDATE ... FROM requires at least one source table".into(),
88 })?;
89
90 let source_name = match &from_table.relation {
92 ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name)?,
93 ast::TableFactor::Derived { .. } => {
94 return Err(SqlError::Unsupported {
95 detail: "UPDATE ... FROM (subquery) is not supported; \
96 use a CTE: WITH cte AS (SELECT ...) UPDATE t SET ... FROM cte WHERE ..."
97 .into(),
98 });
99 }
100 _ => {
101 return Err(SqlError::Unsupported {
102 detail: "non-table relation in UPDATE ... FROM is not supported".into(),
103 });
104 }
105 };
106 if !from_table.joins.is_empty() {
108 return Err(SqlError::Unsupported {
109 detail: "JOIN in UPDATE ... FROM source is not supported; \
110 use a CTE to pre-join the source"
111 .into(),
112 });
113 }
114
115 let source_alias: Option<String> = match &from_table.relation {
116 ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
117 _ => None,
118 };
119 let source_ref = source_alias.as_deref().unwrap_or(source_name.as_str());
120
121 let target_info = catalog
123 .get_collection(DatabaseId::DEFAULT, &target_name)?
124 .ok_or_else(|| SqlError::UnknownTable {
125 name: target_name.clone(),
126 })?;
127 let source_info = catalog
128 .get_collection(DatabaseId::DEFAULT, &source_name)?
129 .ok_or_else(|| SqlError::UnknownTable {
130 name: source_name.clone(),
131 })?;
132
133 let assigns = convert_assignments(&update.assignments)?;
134
135 let (target_join_col, source_join_col, target_filters) = match &update.selection {
139 None => {
140 return Err(SqlError::Parse {
141 detail: "UPDATE ... FROM requires a WHERE clause with an equi-join predicate \
142 linking the target and source tables"
143 .into(),
144 });
145 }
146 Some(expr) => extract_join_predicate(expr, target_ref, source_ref)?,
147 };
148
149 let source_rules = engine_rules::resolve_engine_rules(source_info.engine);
151 let source_plan = source_rules.plan_scan(crate::engine_rules::ScanParams {
152 collection: source_name,
153 alias: source_alias,
154 filters: Vec::new(),
155 projection: Vec::new(),
156 sort_keys: Vec::new(),
157 limit: None,
158 offset: 0,
159 distinct: false,
160 window_functions: Vec::new(),
161 indexes: Vec::new(),
162 temporal: crate::temporal::TemporalScope::default(),
163 bitemporal: source_info.bitemporal,
164 })?;
165
166 let rules = engine_rules::resolve_engine_rules(target_info.engine);
167 rules.plan_update_from(UpdateFromParams {
168 collection: target_name,
169 source: Box::new(source_plan),
170 target_join_col,
171 source_join_col,
172 assignments: assigns,
173 target_filters,
174 returning: update.returning.is_some(),
175 })
176}
177
178fn extract_join_predicate(
184 expr: &ast::Expr,
185 target_ref: &str,
186 source_ref: &str,
187) -> Result<(String, String, Vec<Filter>)> {
188 let mut conjuncts: Vec<ast::Expr> = Vec::new();
190 flatten_and_expr(expr, &mut conjuncts);
191
192 let mut join_idx: Option<usize> = None;
194 let mut target_col = String::new();
195 let mut source_col = String::new();
196
197 for (i, conjunct) in conjuncts.iter().enumerate() {
198 if let Some((tc, sc)) = try_equijoin_pair(conjunct, target_ref, source_ref) {
199 target_col = tc;
200 source_col = sc;
201 join_idx = Some(i);
202 break;
203 }
204 }
205
206 let join_idx = join_idx.ok_or_else(|| SqlError::Parse {
207 detail: format!(
208 "UPDATE ... FROM requires a WHERE clause equi-join predicate of the form \
209 `{target_ref}.col = {source_ref}.col`; none found"
210 ),
211 })?;
212
213 conjuncts.remove(join_idx);
214
215 let target_filters = strip_and_convert_filters(conjuncts, target_ref)?;
218
219 Ok((target_col, source_col, target_filters))
220}
221
222fn try_equijoin_pair(
226 expr: &ast::Expr,
227 target_ref: &str,
228 source_ref: &str,
229) -> Option<(String, String)> {
230 let ast::Expr::BinaryOp {
231 left,
232 op: ast::BinaryOperator::Eq,
233 right,
234 } = expr
235 else {
236 return None;
237 };
238
239 let lhs = qualified_ident_pair(left);
240 let rhs = qualified_ident_pair(right);
241
242 match (lhs, rhs) {
243 (Some((lt, lc)), Some((rt, rc))) => {
244 if lt == target_ref && rt == source_ref {
245 Some((lc, rc))
246 } else if lt == source_ref && rt == target_ref {
247 Some((rc, lc))
248 } else {
249 None
250 }
251 }
252 (Some((t, c)), None) if t == source_ref => {
254 if let ast::Expr::Identifier(ident) = right.as_ref() {
255 Some((normalize_ident(ident), c))
256 } else {
257 None
258 }
259 }
260 (None, Some((t, c))) if t == source_ref => {
261 if let ast::Expr::Identifier(ident) = left.as_ref() {
262 Some((normalize_ident(ident), c))
263 } else {
264 None
265 }
266 }
267 _ => None,
268 }
269}
270
271fn convert_assignments(assignments: &[ast::Assignment]) -> Result<Vec<(String, SqlExpr)>> {
273 assignments
274 .iter()
275 .map(|a| {
276 let col = match &a.target {
277 ast::AssignmentTarget::ColumnName(name) => {
278 if name.0.len() > 1 {
279 return Err(SqlError::Unsupported {
280 detail: format!(
281 "qualified column name in SET target: {SCHEMA_QUALIFIED_MSG}"
282 ),
283 });
284 }
285 normalize_object_name_checked(name)?
286 }
287 ast::AssignmentTarget::Tuple(names) => names
288 .iter()
289 .map(normalize_object_name_checked)
290 .collect::<Result<Vec<_>>>()?
291 .join(","),
292 };
293 let val = convert_expr(&a.value)?;
294 Ok((col, val))
295 })
296 .collect()
297}
298
299pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
301 let ast::Statement::Delete(delete) = stmt else {
302 return Err(SqlError::Parse {
303 detail: "expected DELETE statement".into(),
304 });
305 };
306
307 let from_tables = match &delete.from {
308 ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
309 };
310 let table_name =
311 extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
312 SqlError::Parse {
313 detail: "DELETE requires a FROM table".into(),
314 }
315 })?)?;
316 let info = catalog
317 .get_collection(DatabaseId::DEFAULT, &table_name)?
318 .ok_or_else(|| SqlError::UnknownTable {
319 name: table_name.clone(),
320 })?;
321
322 let filters = match &delete.selection {
323 Some(expr) => super::super::select::convert_where_to_filters(expr)?,
324 None => Vec::new(),
325 };
326
327 let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
328
329 let rules = engine_rules::resolve_engine_rules(info.engine);
330 rules.plan_delete(DeleteParams {
331 collection: table_name,
332 filters,
333 target_keys,
334 })
335}
336
337pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
339 let ast::Statement::Truncate(truncate) = stmt else {
340 return Err(SqlError::Parse {
341 detail: "expected TRUNCATE statement".into(),
342 });
343 };
344 let restart_identity = matches!(
345 truncate.identity,
346 Some(sqlparser::ast::TruncateIdentityOption::Restart)
347 );
348 truncate
349 .table_names
350 .iter()
351 .map(|t| {
352 Ok(SqlPlan::Truncate {
353 collection: normalize_object_name_checked(&t.name)?,
354 restart_identity,
355 })
356 })
357 .collect()
358}