1use sqlparser::ast::{self};
4
5use crate::engine_rules::{self, DeleteParams, InsertParams, UpdateParams};
6use crate::error::{Result, SqlError};
7use crate::parser::normalize::{normalize_ident, normalize_object_name};
8use crate::resolver::expr::{convert_expr, convert_value};
9use crate::types::*;
10
11pub fn plan_insert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
13 let table_name = match &ins.table {
14 ast::TableObject::TableName(name) => normalize_object_name(name),
15 ast::TableObject::TableFunction(_) => {
16 return Err(SqlError::Unsupported {
17 detail: "INSERT INTO table function not supported".into(),
18 });
19 }
20 };
21 let info = catalog
22 .get_collection(&table_name)?
23 .ok_or_else(|| SqlError::UnknownTable {
24 name: table_name.clone(),
25 })?;
26
27 let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
28
29 if let Some(source) = &ins.source
31 && let ast::SetExpr::Select(_select) = &*source.body
32 {
33 let source_plan = super::select::plan_query(
34 source,
35 catalog,
36 &crate::functions::registry::FunctionRegistry::new(),
37 )?;
38 return Ok(vec![SqlPlan::InsertSelect {
39 target: table_name,
40 source: Box::new(source_plan),
41 limit: 0,
42 }]);
43 }
44
45 let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
47 detail: "INSERT requires VALUES or SELECT".into(),
48 })?;
49
50 let rows_ast = match &*source.body {
51 ast::SetExpr::Values(values) => &values.rows,
52 _ => {
53 return Err(SqlError::Unsupported {
54 detail: "INSERT source must be VALUES or SELECT".into(),
55 });
56 }
57 };
58
59 if info.engine == EngineType::KeyValue {
61 let key_idx = columns.iter().position(|c| c == "key");
62 let ttl_idx = columns.iter().position(|c| c == "ttl");
63 let mut entries = Vec::with_capacity(rows_ast.len());
64 let mut ttl_secs: u64 = 0;
65 for row_exprs in rows_ast {
66 let key_val = match key_idx {
67 Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
68 None => SqlValue::String(String::new()),
69 };
70 if let Some(idx) = ttl_idx {
72 match expr_to_sql_value(&row_exprs[idx]) {
73 Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
74 Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
75 _ => {}
76 }
77 }
78 let value_cols: Vec<(String, SqlValue)> = columns
79 .iter()
80 .enumerate()
81 .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
82 .map(|(i, col)| {
83 let val = expr_to_sql_value(&row_exprs[i])?;
84 Ok((col.clone(), val))
85 })
86 .collect::<Result<Vec<_>>>()?;
87 entries.push((key_val, value_cols));
88 }
89 return Ok(vec![SqlPlan::KvInsert {
90 collection: table_name,
91 entries,
92 ttl_secs,
93 }]);
94 }
95
96 let rows = convert_value_rows(&columns, rows_ast)?;
98 let column_defaults: Vec<(String, String)> = info
99 .columns
100 .iter()
101 .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
102 .collect();
103 let rules = engine_rules::resolve_engine_rules(info.engine);
104 rules.plan_insert(InsertParams {
105 collection: table_name,
106 columns,
107 rows,
108 column_defaults,
109 })
110}
111
112pub fn plan_upsert(ins: &ast::Insert, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
116 let table_name = match &ins.table {
117 ast::TableObject::TableName(name) => normalize_object_name(name),
118 ast::TableObject::TableFunction(_) => {
119 return Err(SqlError::Unsupported {
120 detail: "UPSERT INTO table function not supported".into(),
121 });
122 }
123 };
124 let info = catalog
125 .get_collection(&table_name)?
126 .ok_or_else(|| SqlError::UnknownTable {
127 name: table_name.clone(),
128 })?;
129
130 let columns: Vec<String> = ins.columns.iter().map(normalize_ident).collect();
131
132 let source = ins.source.as_ref().ok_or_else(|| SqlError::Parse {
133 detail: "UPSERT requires VALUES".into(),
134 })?;
135
136 let rows_ast = match &*source.body {
137 ast::SetExpr::Values(values) => &values.rows,
138 _ => {
139 return Err(SqlError::Unsupported {
140 detail: "UPSERT source must be VALUES".into(),
141 });
142 }
143 };
144
145 if info.engine == EngineType::KeyValue {
147 let key_idx = columns.iter().position(|c| c == "key");
148 let ttl_idx = columns.iter().position(|c| c == "ttl");
149 let mut entries = Vec::with_capacity(rows_ast.len());
150 let mut ttl_secs: u64 = 0;
151 for row_exprs in rows_ast {
152 let key_val = match key_idx {
153 Some(idx) => expr_to_sql_value(&row_exprs[idx])?,
154 None => SqlValue::String(String::new()),
155 };
156 if let Some(idx) = ttl_idx {
157 match expr_to_sql_value(&row_exprs[idx]) {
158 Ok(SqlValue::Int(n)) => ttl_secs = n.max(0) as u64,
159 Ok(SqlValue::Float(f)) => ttl_secs = f.max(0.0) as u64,
160 _ => {}
161 }
162 }
163 let value_cols: Vec<(String, SqlValue)> = columns
164 .iter()
165 .enumerate()
166 .filter(|(i, _)| Some(*i) != key_idx && Some(*i) != ttl_idx)
167 .map(|(i, col)| {
168 let val = expr_to_sql_value(&row_exprs[i])?;
169 Ok((col.clone(), val))
170 })
171 .collect::<Result<Vec<_>>>()?;
172 entries.push((key_val, value_cols));
173 }
174 return Ok(vec![SqlPlan::KvInsert {
175 collection: table_name,
176 entries,
177 ttl_secs,
178 }]);
179 }
180
181 let rows = convert_value_rows(&columns, rows_ast)?;
182 let column_defaults: Vec<(String, String)> = info
183 .columns
184 .iter()
185 .filter_map(|c| c.default.as_ref().map(|d| (c.name.clone(), d.clone())))
186 .collect();
187 let rules = engine_rules::resolve_engine_rules(info.engine);
188 rules.plan_upsert(engine_rules::UpsertParams {
189 collection: table_name,
190 columns,
191 rows,
192 column_defaults,
193 })
194}
195
196pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
198 let ast::Statement::Update(update) = stmt else {
199 return Err(SqlError::Parse {
200 detail: "expected UPDATE statement".into(),
201 });
202 };
203
204 let table_name = extract_table_name_from_table_with_joins(&update.table)?;
205 let info = catalog
206 .get_collection(&table_name)?
207 .ok_or_else(|| SqlError::UnknownTable {
208 name: table_name.clone(),
209 })?;
210
211 let assigns: Vec<(String, SqlExpr)> = update
212 .assignments
213 .iter()
214 .map(|a| {
215 let col = match &a.target {
216 ast::AssignmentTarget::ColumnName(name) => normalize_object_name(name),
217 ast::AssignmentTarget::Tuple(names) => names
218 .iter()
219 .map(normalize_object_name)
220 .collect::<Vec<_>>()
221 .join(","),
222 };
223 let val = convert_expr(&a.value)?;
224 Ok((col, val))
225 })
226 .collect::<Result<_>>()?;
227
228 let filters = match &update.selection {
229 Some(expr) => super::select::convert_where_to_filters(expr)?,
230 None => Vec::new(),
231 };
232
233 let target_keys = extract_point_keys(update.selection.as_ref(), &info);
235
236 let rules = engine_rules::resolve_engine_rules(info.engine);
237 rules.plan_update(UpdateParams {
238 collection: table_name,
239 assignments: assigns,
240 filters,
241 target_keys,
242 returning: update.returning.is_some(),
243 })
244}
245
246pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
248 let ast::Statement::Delete(delete) = stmt else {
249 return Err(SqlError::Parse {
250 detail: "expected DELETE statement".into(),
251 });
252 };
253
254 let from_tables = match &delete.from {
255 ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
256 };
257 let table_name =
258 extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
259 SqlError::Parse {
260 detail: "DELETE requires a FROM table".into(),
261 }
262 })?)?;
263 let info = catalog
264 .get_collection(&table_name)?
265 .ok_or_else(|| SqlError::UnknownTable {
266 name: table_name.clone(),
267 })?;
268
269 let filters = match &delete.selection {
270 Some(expr) => super::select::convert_where_to_filters(expr)?,
271 None => Vec::new(),
272 };
273
274 let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
275
276 let rules = engine_rules::resolve_engine_rules(info.engine);
277 rules.plan_delete(DeleteParams {
278 collection: table_name,
279 filters,
280 target_keys,
281 })
282}
283
284pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
286 let ast::Statement::Truncate(truncate) = stmt else {
287 return Err(SqlError::Parse {
288 detail: "expected TRUNCATE statement".into(),
289 });
290 };
291 let restart_identity = matches!(
292 truncate.identity,
293 Some(sqlparser::ast::TruncateIdentityOption::Restart)
294 );
295 truncate
296 .table_names
297 .iter()
298 .map(|t| {
299 Ok(SqlPlan::Truncate {
300 collection: normalize_object_name(&t.name),
301 restart_identity,
302 })
303 })
304 .collect()
305}
306
307fn convert_value_rows(
310 columns: &[String],
311 rows: &[Vec<ast::Expr>],
312) -> Result<Vec<Vec<(String, SqlValue)>>> {
313 rows.iter()
314 .map(|row| {
315 row.iter()
316 .enumerate()
317 .map(|(i, expr)| {
318 let col = columns.get(i).cloned().unwrap_or_else(|| format!("col{i}"));
319 let val = expr_to_sql_value(expr)?;
320 Ok((col, val))
321 })
322 .collect::<Result<Vec<_>>>()
323 })
324 .collect()
325}
326
327fn expr_to_sql_value(expr: &ast::Expr) -> Result<SqlValue> {
328 match expr {
329 ast::Expr::Value(v) => convert_value(&v.value),
330 ast::Expr::UnaryOp {
331 op: ast::UnaryOperator::Minus,
332 expr: inner,
333 } => {
334 let val = expr_to_sql_value(inner)?;
335 match val {
336 SqlValue::Int(n) => Ok(SqlValue::Int(-n)),
337 SqlValue::Float(f) => Ok(SqlValue::Float(-f)),
338 _ => Err(SqlError::TypeMismatch {
339 detail: "cannot negate non-numeric value".into(),
340 }),
341 }
342 }
343 ast::Expr::Array(ast::Array { elem, .. }) => {
344 let vals = elem.iter().map(expr_to_sql_value).collect::<Result<_>>()?;
345 Ok(SqlValue::Array(vals))
346 }
347 ast::Expr::Function(func) => {
348 let func_name = func
349 .name
350 .0
351 .iter()
352 .map(|p| match p {
353 ast::ObjectNamePart::Identifier(ident) => normalize_ident(ident),
354 _ => String::new(),
355 })
356 .collect::<Vec<_>>()
357 .join(".")
358 .to_lowercase();
359 match func_name.as_str() {
360 "st_point" => {
361 let args = super::select::extract_func_args(func)?;
363 if args.len() >= 2 {
364 let lon = super::select::extract_float(&args[0])?;
365 let lat = super::select::extract_float(&args[1])?;
366 Ok(SqlValue::String(format!(
367 r#"{{"type":"Point","coordinates":[{lon},{lat}]}}"#
368 )))
369 } else {
370 Ok(SqlValue::String(format!("{expr}")))
371 }
372 }
373 "st_geomfromgeojson" => {
374 let args = super::select::extract_func_args(func)?;
375 if !args.is_empty() {
376 let s = super::select::extract_string_literal(&args[0])?;
377 Ok(SqlValue::String(s))
378 } else {
379 Ok(SqlValue::String(format!("{expr}")))
380 }
381 }
382 _ => {
383 Ok(SqlValue::String(format!("{expr}")))
385 }
386 }
387 }
388 _ => Err(SqlError::Unsupported {
389 detail: format!("value expression: {expr}"),
390 }),
391 }
392}
393
394fn extract_table_name_from_table_with_joins(table: &ast::TableWithJoins) -> Result<String> {
395 match &table.relation {
396 ast::TableFactor::Table { name, .. } => Ok(normalize_object_name(name)),
397 _ => Err(SqlError::Unsupported {
398 detail: "non-table target in DML".into(),
399 }),
400 }
401}
402
403fn extract_point_keys(selection: Option<&ast::Expr>, info: &CollectionInfo) -> Vec<SqlValue> {
405 let pk = match &info.primary_key {
406 Some(pk) => pk.clone(),
407 None => return Vec::new(),
408 };
409
410 let expr = match selection {
411 Some(e) => e,
412 None => return Vec::new(),
413 };
414
415 let mut keys = Vec::new();
416 collect_pk_equalities(expr, &pk, &mut keys);
417 keys
418}
419
420fn collect_pk_equalities(expr: &ast::Expr, pk: &str, keys: &mut Vec<SqlValue>) {
421 match expr {
422 ast::Expr::BinaryOp {
423 left,
424 op: ast::BinaryOperator::Eq,
425 right,
426 } => {
427 if is_column(left, pk)
428 && let Ok(v) = expr_to_sql_value(right)
429 {
430 keys.push(v);
431 } else if is_column(right, pk)
432 && let Ok(v) = expr_to_sql_value(left)
433 {
434 keys.push(v);
435 }
436 }
437 ast::Expr::BinaryOp {
438 left,
439 op: ast::BinaryOperator::Or,
440 right,
441 } => {
442 collect_pk_equalities(left, pk, keys);
443 collect_pk_equalities(right, pk, keys);
444 }
445 ast::Expr::InList {
446 expr: inner,
447 list,
448 negated: false,
449 } => {
450 if is_column(inner, pk) {
451 for item in list {
452 if let Ok(v) = expr_to_sql_value(item) {
453 keys.push(v);
454 }
455 }
456 }
457 }
458 _ => {}
459 }
460}
461
462fn is_column(expr: &ast::Expr, name: &str) -> bool {
463 match expr {
464 ast::Expr::Identifier(ident) => normalize_ident(ident) == name,
465 ast::Expr::CompoundIdentifier(parts) if parts.len() == 2 => {
466 normalize_ident(&parts[1]) == name
467 }
468 _ => false,
469 }
470}