use nodedb_types::DatabaseId;
use sqlparser::ast;
use super::super::ast_helpers::{
flatten_and_expr, qualified_ident_pair, strip_and_convert_filters,
};
use super::super::dml_helpers::{extract_point_keys, extract_table_name_from_table_with_joins};
use crate::engine_rules::{self, DeleteParams, UpdateFromParams, UpdateParams};
use crate::error::{Result, SqlError};
use crate::parser::normalize::{
SCHEMA_QUALIFIED_MSG, normalize_ident, normalize_object_name_checked,
};
use crate::resolver::expr::convert_expr;
use crate::types::*;
pub fn plan_update(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
let ast::Statement::Update(update) = stmt else {
return Err(SqlError::Parse {
detail: "expected UPDATE statement".into(),
});
};
if update.from.is_some() {
return plan_update_from(update, catalog);
}
let table_name = extract_table_name_from_table_with_joins(&update.table)?;
let info = catalog
.get_collection(DatabaseId::DEFAULT, &table_name)?
.ok_or_else(|| SqlError::UnknownTable {
name: table_name.clone(),
})?;
let assigns = convert_assignments(&update.assignments)?;
let filters = match &update.selection {
Some(expr) => super::super::select::convert_where_to_filters(expr)?,
None => Vec::new(),
};
let target_keys = extract_point_keys(update.selection.as_ref(), &info);
let rules = engine_rules::resolve_engine_rules(info.engine);
rules.plan_update(UpdateParams {
collection: table_name,
assignments: assigns,
filters,
target_keys,
returning: update.returning.is_some(),
})
}
fn plan_update_from(update: &ast::Update, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
let target_name = extract_table_name_from_table_with_joins(&update.table)?;
let target_alias: Option<String> = match &update.table.relation {
ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
_ => None,
};
let target_ref = target_alias.as_deref().unwrap_or(target_name.as_str());
let from_kind = update.from.as_ref().expect("caller ensures from.is_some()");
let from_tables: &Vec<ast::TableWithJoins> = match from_kind {
ast::UpdateTableFromKind::AfterSet(tables)
| ast::UpdateTableFromKind::BeforeSet(tables) => tables,
};
if from_tables.len() > 1 {
return Err(SqlError::Unsupported {
detail: format!(
"UPDATE ... FROM with {} source tables is not supported; \
only a single FROM table is accepted",
from_tables.len()
),
});
}
let from_table = from_tables.first().ok_or_else(|| SqlError::Parse {
detail: "UPDATE ... FROM requires at least one source table".into(),
})?;
let source_name = match &from_table.relation {
ast::TableFactor::Table { name, .. } => normalize_object_name_checked(name)?,
ast::TableFactor::Derived { .. } => {
return Err(SqlError::Unsupported {
detail: "UPDATE ... FROM (subquery) is not supported; \
use a CTE: WITH cte AS (SELECT ...) UPDATE t SET ... FROM cte WHERE ..."
.into(),
});
}
_ => {
return Err(SqlError::Unsupported {
detail: "non-table relation in UPDATE ... FROM is not supported".into(),
});
}
};
if !from_table.joins.is_empty() {
return Err(SqlError::Unsupported {
detail: "JOIN in UPDATE ... FROM source is not supported; \
use a CTE to pre-join the source"
.into(),
});
}
let source_alias: Option<String> = match &from_table.relation {
ast::TableFactor::Table { alias, .. } => alias.as_ref().map(|a| normalize_ident(&a.name)),
_ => None,
};
let source_ref = source_alias.as_deref().unwrap_or(source_name.as_str());
let target_info = catalog
.get_collection(DatabaseId::DEFAULT, &target_name)?
.ok_or_else(|| SqlError::UnknownTable {
name: target_name.clone(),
})?;
let source_info = catalog
.get_collection(DatabaseId::DEFAULT, &source_name)?
.ok_or_else(|| SqlError::UnknownTable {
name: source_name.clone(),
})?;
let assigns = convert_assignments(&update.assignments)?;
let (target_join_col, source_join_col, target_filters) = match &update.selection {
None => {
return Err(SqlError::Parse {
detail: "UPDATE ... FROM requires a WHERE clause with an equi-join predicate \
linking the target and source tables"
.into(),
});
}
Some(expr) => extract_join_predicate(expr, target_ref, source_ref)?,
};
let source_rules = engine_rules::resolve_engine_rules(source_info.engine);
let source_plan = source_rules.plan_scan(crate::engine_rules::ScanParams {
collection: source_name,
alias: source_alias,
filters: Vec::new(),
projection: Vec::new(),
sort_keys: Vec::new(),
limit: None,
offset: 0,
distinct: false,
window_functions: Vec::new(),
indexes: Vec::new(),
temporal: crate::temporal::TemporalScope::default(),
bitemporal: source_info.bitemporal,
})?;
let rules = engine_rules::resolve_engine_rules(target_info.engine);
rules.plan_update_from(UpdateFromParams {
collection: target_name,
source: Box::new(source_plan),
target_join_col,
source_join_col,
assignments: assigns,
target_filters,
returning: update.returning.is_some(),
})
}
fn extract_join_predicate(
expr: &ast::Expr,
target_ref: &str,
source_ref: &str,
) -> Result<(String, String, Vec<Filter>)> {
let mut conjuncts: Vec<ast::Expr> = Vec::new();
flatten_and_expr(expr, &mut conjuncts);
let mut join_idx: Option<usize> = None;
let mut target_col = String::new();
let mut source_col = String::new();
for (i, conjunct) in conjuncts.iter().enumerate() {
if let Some((tc, sc)) = try_equijoin_pair(conjunct, target_ref, source_ref) {
target_col = tc;
source_col = sc;
join_idx = Some(i);
break;
}
}
let join_idx = join_idx.ok_or_else(|| SqlError::Parse {
detail: format!(
"UPDATE ... FROM requires a WHERE clause equi-join predicate of the form \
`{target_ref}.col = {source_ref}.col`; none found"
),
})?;
conjuncts.remove(join_idx);
let target_filters = strip_and_convert_filters(conjuncts, target_ref)?;
Ok((target_col, source_col, target_filters))
}
fn try_equijoin_pair(
expr: &ast::Expr,
target_ref: &str,
source_ref: &str,
) -> Option<(String, String)> {
let ast::Expr::BinaryOp {
left,
op: ast::BinaryOperator::Eq,
right,
} = expr
else {
return None;
};
let lhs = qualified_ident_pair(left);
let rhs = qualified_ident_pair(right);
match (lhs, rhs) {
(Some((lt, lc)), Some((rt, rc))) => {
if lt == target_ref && rt == source_ref {
Some((lc, rc))
} else if lt == source_ref && rt == target_ref {
Some((rc, lc))
} else {
None
}
}
(Some((t, c)), None) if t == source_ref => {
if let ast::Expr::Identifier(ident) = right.as_ref() {
Some((normalize_ident(ident), c))
} else {
None
}
}
(None, Some((t, c))) if t == source_ref => {
if let ast::Expr::Identifier(ident) = left.as_ref() {
Some((normalize_ident(ident), c))
} else {
None
}
}
_ => None,
}
}
fn convert_assignments(assignments: &[ast::Assignment]) -> Result<Vec<(String, SqlExpr)>> {
assignments
.iter()
.map(|a| {
let col = match &a.target {
ast::AssignmentTarget::ColumnName(name) => {
if name.0.len() > 1 {
return Err(SqlError::Unsupported {
detail: format!(
"qualified column name in SET target: {SCHEMA_QUALIFIED_MSG}"
),
});
}
normalize_object_name_checked(name)?
}
ast::AssignmentTarget::Tuple(names) => names
.iter()
.map(normalize_object_name_checked)
.collect::<Result<Vec<_>>>()?
.join(","),
};
let val = convert_expr(&a.value)?;
Ok((col, val))
})
.collect()
}
pub fn plan_delete(stmt: &ast::Statement, catalog: &dyn SqlCatalog) -> Result<Vec<SqlPlan>> {
let ast::Statement::Delete(delete) = stmt else {
return Err(SqlError::Parse {
detail: "expected DELETE statement".into(),
});
};
let from_tables = match &delete.from {
ast::FromTable::WithFromKeyword(tables) | ast::FromTable::WithoutKeyword(tables) => tables,
};
let table_name =
extract_table_name_from_table_with_joins(from_tables.first().ok_or_else(|| {
SqlError::Parse {
detail: "DELETE requires a FROM table".into(),
}
})?)?;
let info = catalog
.get_collection(DatabaseId::DEFAULT, &table_name)?
.ok_or_else(|| SqlError::UnknownTable {
name: table_name.clone(),
})?;
let filters = match &delete.selection {
Some(expr) => super::super::select::convert_where_to_filters(expr)?,
None => Vec::new(),
};
let target_keys = extract_point_keys(delete.selection.as_ref(), &info);
let rules = engine_rules::resolve_engine_rules(info.engine);
rules.plan_delete(DeleteParams {
collection: table_name,
filters,
target_keys,
})
}
pub fn plan_truncate_stmt(stmt: &ast::Statement) -> Result<Vec<SqlPlan>> {
let ast::Statement::Truncate(truncate) = stmt else {
return Err(SqlError::Parse {
detail: "expected TRUNCATE statement".into(),
});
};
let restart_identity = matches!(
truncate.identity,
Some(sqlparser::ast::TruncateIdentityOption::Restart)
);
truncate
.table_names
.iter()
.map(|t| {
Ok(SqlPlan::Truncate {
collection: normalize_object_name_checked(&t.name)?,
restart_identity,
})
})
.collect()
}