1use sqlparser::ast::{self};
4
5use crate::engine_rules::{self, DeleteParams, InsertParams, UpdateParams};
6use crate::error::{Result, SqlError};
7use crate::parser::normalize::{normalize_ident, normalize_object_name};
8use crate::resolver::expr::{convert_expr, convert_value};
9use crate::types::*;
10
11fn extract_on_conflict_updates(ins: &ast::Insert) -> Result<Option<Vec<(String, SqlExpr)>>> {
14 let Some(on) = ins.on.as_ref() else {
15 return Ok(None);
16 };
17 let ast::OnInsert::OnConflict(oc) = on else {
18 return Ok(None);
19 };
20 let ast::OnConflictAction::DoUpdate(do_update) = &oc.action else {
21 return Err(SqlError::Unsupported {
23 detail: "ON CONFLICT DO NOTHING is not yet supported".into(),
24 });
25 };
26 let mut pairs = Vec::with_capacity(do_update.assignments.len());
27 for a in &do_update.assignments {
28 let name = match &a.target {
29 ast::AssignmentTarget::ColumnName(obj) => normalize_object_name(obj),
30 _ => {
31 return Err(SqlError::Unsupported {
32 detail: "ON CONFLICT DO UPDATE SET target must be a column name".into(),
33 });
34 }
35 };
36 let expr = convert_expr(&a.value)?;
37 pairs.push((name, expr));
38 }
39 Ok(Some(pairs))
40}
41
42pub fn plan_insert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
44 if let Some(on_conflict_updates) = extract_on_conflict_updates(ins)? {
48 return plan_upsert_with_on_conflict(ins, catalog, on_conflict_updates);
49 }
50 let table_name = match &ins.table {
51 ast::TableObject::TableName(name) => normalize_object_name(name),
52 ast::TableObject::TableFunction(_) => {
53 return Err(SqlError::Unsupported {
54 detail: "INSERT INTO table function not supported".into(),
55 });
56 }
57 };
58 let info = catalog
59 .get_collection(&table_name)?
60 .ok_or_else(|| SqlError::UnknownTable {
61 name: table_name.clone(),
62 })?;
63
64 let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
65
66 if let Some(source) = &ins.source
68 && let ast::SetExpr::Select(_select) = &*source.body
69 {
70 let source_plan = super::select::plan_query(
71 source,
72 catalog,
73 &crate::functions::registry::FunctionRegistry::new(),
74 )?;
75 return Ok(vec![SqlPlan::InsertSelect {
76 target: table_name,
77 source: Box::new(source_plan),
78 limit: 0,
79 }]);
80 }
81
82 let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
84 detail: "INSERT requires VALUES or SELECT".into(),
85 })?;
86
87 let rows_ast = match &*source.body {
88 ast::SetExpr::Values(values) => &values.rows,
89 _ => {
90 return Err(SqlError::Unsupported {
91 detail: "INSERT source must be VALUES or SELECT".into(),
92 });
93 }
94 };
95
96 if info.engine == EngineType::KeyValue {
98 let key_idx = columns.iter().position(|c| c == "key");
99 let ttl_idx = columns.iter().position(|c| c == "ttl");
100 let mut entries = Vec::with_capacity(rows_ast.len());
101 let mut ttl_secs: u64 = 0;
102 for row_exprs in rows_ast {
103 let key_val = match key_idx {
104 Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
105 None => SqlValue::String(String::new()),
106 };
107 if let Some(idx) = ttl_idx {
109 match expr_to_sql_value(&row_exprs[idx]) {
110 Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
111 Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
112 _ => {}
113 }
114 }
115 let value_cols: Vec<(String, SqlValue)> = columns
116 .iter()
117 .enumerate()
118 .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
119 .map(|(i, col)| {
120 let val = expr_to_sql_value(&row_exprs[i])?;
121 Ok((col.clone(), val))
122 })
123 .collect::<Result<Vec<_>>>()?;
124 entries.push((key_val, value_cols));
125 }
126 return Ok(vec![SqlPlan::KvInsert {
127 collection: table_name,
128 entries,
129 ttl_secs,
130 }]);
131 }
132
133 let rows = convert_value_rows(&columns, rows_ast)?;
135 let column_defaults: Vec<(String, String)> = info
136 .columns
137 .iter()
138 .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
139 .collect();
140 let rules = engine_rules::resolve_engine_rules(info.engine);
141 rules.plan_insert(InsertParams {
142 collection: table_name,
143 columns,
144 rows,
145 column_defaults,
146 })
147}
148
149pub fn plan_upsert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
153 let table_name = match &ins.table {
154 ast::TableObject::TableName(name) => normalize_object_name(name),
155 ast::TableObject::TableFunction(_) => {
156 return Err(SqlError::Unsupported {
157 detail: "UPSERT INTO table function not supported".into(),
158 });
159 }
160 };
161 let info = catalog
162 .get_collection(&table_name)?
163 .ok_or_else(|| SqlError::UnknownTable {
164 name: table_name.clone(),
165 })?;
166
167 let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
168
169 let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
170 detail: "UPSERT requires VALUES".into(),
171 })?;
172
173 let rows_ast = match &*source.body {
174 ast::SetExpr::Values(values) => &values.rows,
175 _ => {
176 return Err(SqlError::Unsupported {
177 detail: "UPSERT source must be VALUES".into(),
178 });
179 }
180 };
181
182 if info.engine == EngineType::KeyValue {
184 let key_idx = columns.iter().position(|c| c == "key");
185 let ttl_idx = columns.iter().position(|c| c == "ttl");
186 let mut entries = Vec::with_capacity(rows_ast.len());
187 let mut ttl_secs: u64 = 0;
188 for row_exprs in rows_ast {
189 let key_val = match key_idx {
190 Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
191 None => SqlValue::String(String::new()),
192 };
193 if let Some(idx) = ttl_idx {
194 match expr_to_sql_value(&row_exprs[idx]) {
195 Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
196 Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
197 _ => {}
198 }
199 }
200 let value_cols: Vec<(String, SqlValue)> = columns
201 .iter()
202 .enumerate()
203 .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
204 .map(|(i, col)| {
205 let val = expr_to_sql_value(&row_exprs[i])?;
206 Ok((col.clone(), val))
207 })
208 .collect::<Result<Vec<_>>>()?;
209 entries.push((key_val, value_cols));
210 }
211 return Ok(vec![SqlPlan::KvInsert {
212 collection: table_name,
213 entries,
214 ttl_secs,
215 }]);
216 }
217
218 let rows = convert_value_rows(&columns, rows_ast)?;
219 let column_defaults: Vec<(String, String)> = info
220 .columns
221 .iter()
222 .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
223 .collect();
224 let rules = engine_rules::resolve_engine_rules(info.engine);
225 rules.plan_upsert(engine_rules::UpsertParams {
226 collection: table_name,
227 columns,
228 rows,
229 column_defaults,
230 on_conflict_updates: Vec::new(),
231 })
232}
233
234fn plan_upsert_with_on_conflict(
239 ins: &ast::Insert,
240 catalog: &dyn SqlCatalog,
241 on_conflict_updates: Vec<(String, SqlExpr)>,
242) -> Result<Vec<SqlPlan>> {
243 let table_name = match &ins.table {
244 ast::TableObject::TableName(name) => normalize_object_name(name),
245 ast::TableObject::TableFunction(_) => {
246 return Err(SqlError::Unsupported {
247 detail: "INSERT ... ON CONFLICT on a table function is not supported".into(),
248 });
249 }
250 };
251 let info = catalog
252 .get_collection(&table_name)?
253 .ok_or_else(|| SqlError::UnknownTable {
254 name: table_name.clone(),
255 })?;
256
257 let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
258
259 let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
260 detail: "INSERT ... ON CONFLICT requires VALUES".into(),
261 })?;
262 let rows_ast = match &*source.body {
263 ast::SetExpr::Values(values) => &values.rows,
264 _ => {
265 return Err(SqlError::Unsupported {
266 detail: "INSERT ... ON CONFLICT source must be VALUES".into(),
267 });
268 }
269 };
270
271 let rows = convert_value_rows(&columns, rows_ast)?;
272 let column_defaults: Vec<(String, String)> = info
273 .columns
274 .iter()
275 .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
276 .collect();
277 let rules = engine_rules::resolve_engine_rules(info.engine);
278 rules.plan_upsert(engine_rules::UpsertParams {
279 collection: table_name,
280 columns,
281 rows,
282 column_defaults,
283 on_conflict_updates,
284 })
285}
286
287pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
289 let ast::Statement::Update(update) = stmt else {
290 return Err(SqlError::Parse {
291 detail: "expected UPDATE statement".into(),
292 });
293 };
294
295 let table_name = extract_table_name_from_table_with_joins(&update.table)?;
296 let info = catalog
297 .get_collection(&table_name)?
298 .ok_or_else(|| SqlError::UnknownTable {
299 name: table_name.clone(),
300 })?;
301
302 let assigns: Vec<(String, SqlExpr)> = update
303 .assignments
304 .iter()
305 .map(|a| {
306 let col = match &a.target {
307 ast::AssignmentTarget::ColumnName(name) => normalize_object_name(name),
308 ast::AssignmentTarget::Tuple(names) => names
309 .iter()
310 .map(normalize_object_name)
311 .collect::<Vec<_>>()
312 .join(","),
313 };
314 let val = convert_expr(&a.value)?;
315 Ok((col, val))
316 })
317 .collect::<Result<_>>()?;
318
319 let filters = match &update.selection {
320 Some(expr) => super::select::convert_where_to_filters(expr)?,
321 None => Vec::new(),
322 };
323
324 let target_keys = extract_point_keys(update.selection.as_ref(), &info);
326
327 let rules = engine_rules::resolve_engine_rules(info.engine);
328 rules.plan_update(UpdateParams {
329 collection: table_name,
330 assignments: assigns,
331 filters,
332 target_keys,
333 returning: update.returning.is_some(),
334 })
335}
336
337pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
339 let ast::Statement::Delete(delete) = stmt else {
340 return Err(SqlError::Parse {
341 detail: "expected DELETE statement".into(),
342 });
343 };
344
345 let from_tables = match &delete.from {
346 ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
347 };
348 let table_name =
349 extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
350 SqlError::Parse {
351 detail: "DELETE requires a FROM table".into(),
352 }
353 })?)?;
354 let info = catalog
355 .get_collection(&table_name)?
356 .ok_or_else(|| SqlError::UnknownTable {
357 name: table_name.clone(),
358 })?;
359
360 let filters = match &delete.selection {
361 Some(expr) => super::select::convert_where_to_filters(expr)?,
362 None => Vec::new(),
363 };
364
365 let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
366
367 let rules = engine_rules::resolve_engine_rules(info.engine);
368 rules.plan_delete(DeleteParams {
369 collection: table_name,
370 filters,
371 target_keys,
372 })
373}
374
375pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
377 let ast::Statement::Truncate(truncate) = stmt else {
378 return Err(SqlError::Parse {
379 detail: "expected TRUNCATE statement".into(),
380 });
381 };
382 let restart_identity = matches!(
383 truncate.identity,
384 Some(sqlparser::ast::TruncateIdentityOption::Restart)
385 );
386 truncate
387 .table_names
388 .iter()
389 .map(|t| {
390 Ok(SqlPlan::Truncate {
391 collection: normalize_object_name(&t.name),
392 restart_identity,
393 })
394 })
395 .collect()
396}
397
398fn convert_value_rows(
401 columns: &[String],
402 rows: &[Vec<ast::Expr>],
403) -> Result<Vec<Vec<(String, SqlValue)>>> {
404 rows.iter()
405 .map(|row| {
406 row.iter()
407 .enumerate()
408 .map(|(i, expr)| {
409 let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
410 let val = expr_to_sql_value(expr)?;
411 Ok((col, val))
412 })
413 .collect::<Result<Vec<_>>>()
414 })
415 .collect()
416}
417
418fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
419 match expr {
420 ast::Expr::Value(v) => convert_value(&v.value),
421 ast::Expr::UnaryOp {
422 op: ast::UnaryOperator::Minus,
423 expr: inner,
424 } => {
425 let val = expr_to_sql_value(inner)?;
426 match val {
427 SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
428 SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
429 _ => Err(SqlError::TypeMismatch {
430 detail: "cannot negate non-numeric value".into(),
431 }),
432 }
433 }
434 ast::Expr::Array(ast::Array { elem, .. }) => {
435 let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
436 Ok(SqlValue::Array(vals))
437 }
438 ast::Expr::Function(func) => {
439 let func_name = func
440 .name
441 .0
442 .iter()
443 .map(|p| match p {
444 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
445 _ => String::new(),
446 })
447 .collect::<Vec<_>>()
448 .join(".")
449 .to_lowercase();
450 match func_name.as_str() {
451 "st_point" => {
452 let args = super::select::extract_func_args(func)?;
454 if args.len() >= 2 {
455 let lon = super::select::extract_float(&args[0])?;
456 let lat = super::select::extract_float(&args[1])?;
457 Ok(SqlValue::String(format!(
458 r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
459 )))
460 } else {
461 Ok(SqlValue::String(format!("{expr}")))
462 }
463 }
464 "st_geomfromgeojson" => {
465 let args = super::select::extract_func_args(func)?;
466 if !args.is_empty() {
467 let s = super::select::extract_string_literal(&args[0])?;
468 Ok(SqlValue::String(s))
469 } else {
470 Ok(SqlValue::String(format!("{expr}")))
471 }
472 }
473 _ => {
474 if let Ok(sql_expr) = crate::resolver::expr::convert_expr(expr)
481 && let Some(v) = super::const_fold::fold_constant_default(&sql_expr)
482 {
483 Ok(v)
484 } else {
485 Ok(SqlValue::String(format!("{expr}")))
486 }
487 }
488 }
489 }
490 _ => Err(SqlError::Unsupported {
491 detail: format!("value expression: {expr}"),
492 }),
493 }
494}
495
496fn extract_table_name_from_table_with_joins(table: &ast::TableWithJoins) -> Result<String> {
497 match &table.relation {
498 ast::TableFactor::Table { name, .. } => Ok(normalize_object_name(name)),
499 _ => Err(SqlError::Unsupported {
500 detail: "non-table target in DML".into(),
501 }),
502 }
503}
504
505fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
507 let pk = match &info.primary_key {
508 Some(pk) => pk.clone(),
509 None => return Vec::new(),
510 };
511
512 let expr = match selection {
513 Some(e) => e,
514 None => return Vec::new(),
515 };
516
517 let mut keys = Vec::new();
518 collect_pk_equalities(expr, &pk, &mut keys);
519 keys
520}
521
522fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
523 match expr {
524 ast::Expr::BinaryOp {
525 left,
526 op: ast::BinaryOperator::Eq,
527 right,
528 } => {
529 if is_column(left, pk)
530 && let Ok(v) = expr_to_sql_value(right)
531 {
532 keys.push(v);
533 } else if is_column(right, pk)
534 && let Ok(v) = expr_to_sql_value(left)
535 {
536 keys.push(v);
537 }
538 }
539 ast::Expr::BinaryOp {
540 left,
541 op: ast::BinaryOperator::Or,
542 right,
543 } => {
544 collect_pk_equalities(left, pk, keys);
545 collect_pk_equalities(right, pk, keys);
546 }
547 ast::Expr::InList {
548 expr: inner,
549 list,
550 negated: false,
551 } if is_column(inner, pk) => {
552 for item in list {
553 if let Ok(v) = expr_to_sql_value(item) {
554 keys.push(v);
555 }
556 }
557 }
558 _ => {}
559 }
560}
561
562fn is_column(expr: &ast::Expr, name: &str) -> bool {
563 match expr {
564 ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
565 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
566 normalize_ident(&parts[1]) == name
567 }
568 _ => false,
569 }
570}