use std::collections::{BTreeMap, HashSet};
use narwhal_core::{Column, Row, Value};
use narwhal_sql::Dialect;
use crate::cell_edit::{placeholder, quote_ident, quote_qualified};
/// Identifies a table, optionally qualified by a schema.
///
/// An empty `schema` means the table is unqualified; [`TableId::display`] then shows only
/// the bare table name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TableId {
    /// Schema name; empty when the table is not schema-qualified.
    pub schema: String,
    /// Table name.
    pub table: String,
}

impl TableId {
    /// Builds a `TableId` from anything convertible into owned strings.
    pub fn new(schema: impl Into<String>, table: impl Into<String>) -> Self {
        let schema = schema.into();
        let table = table.into();
        Self { schema, table }
    }

    /// Human-readable name: `schema.table`, or just `table` when the schema is empty.
    ///
    /// No identifier quoting is applied, so the result is for display only and must not be
    /// spliced into SQL.
    pub fn display(&self) -> String {
        match self.schema.as_str() {
            "" => self.table.clone(),
            schema => format!("{schema}.{}", self.table),
        }
    }
}
/// A single queued edit (row insert, single-cell update, or row delete) awaiting
/// compilation into SQL.
///
/// Every variant carries the table it targets plus the data `compile` reads to build a
/// parameterised statement.
#[derive(Debug, Clone)]
pub enum PendingMutation {
    /// Insert a new row.
    Insert {
        /// Table receiving the row.
        target: TableId,
        /// Columns of the target table; every key in `values` must be named here or
        /// compilation fails.
        columns: Vec<Column>,
        /// Column name -> value for the new row. An empty map compiles to a row made
        /// entirely of column defaults.
        values: BTreeMap<String, Value>,
    },
    /// Change one cell of an existing row.
    Update {
        /// Table containing the row.
        target: TableId,
        /// Columns of the target table. Not read by this module's compile or summary code.
        columns: Vec<Column>,
        /// Name of the column being edited.
        column_name: String,
        /// The cell's value before the edit; the compiled UPDATE only matches the row if the
        /// cell still holds this value.
        old_value: Value,
        /// The value to write into the cell.
        new_value: Value,
        /// Primary-key column -> value identifying the row. Must be non-empty and contain no
        /// NULLs for compilation to succeed.
        pk_values: BTreeMap<String, Value>,
    },
    /// Delete an existing row.
    Delete {
        /// Table containing the row.
        target: TableId,
        /// Columns of the target table. Not read by this module's compile or summary code.
        columns: Vec<Column>,
        /// Primary-key column -> value identifying the row; must be non-empty and NULL-free.
        pk_values: BTreeMap<String, Value>,
        /// NOTE(review): not read by this module; presumably the full row as it looked
        /// before deletion, positionally aligned with `column_order` (e.g. for undo) —
        /// confirm with callers.
        snapshot: Row,
        /// NOTE(review): not read by this module; presumably the column names for each
        /// position of `snapshot` — confirm.
        column_order: Vec<String>,
    },
}
impl PendingMutation {
    /// The table this mutation applies to.
    pub const fn target(&self) -> &TableId {
        let (Self::Insert { target, .. } | Self::Update { target, .. } | Self::Delete { target, .. }) =
            self;
        target
    }

    /// A one-line, human-readable description of the mutation.
    ///
    /// Uses unquoted names and rendered values, so the result is for display only and is
    /// not executable SQL.
    pub fn summary(&self) -> String {
        match self {
            Self::Insert { target, values, .. } => {
                let table = target.display();
                if values.is_empty() {
                    return format!("INSERT INTO {table} (defaults)");
                }
                let names = values
                    .keys()
                    .map(String::as_str)
                    .collect::<Vec<&str>>()
                    .join(", ");
                format!("INSERT INTO {table} ({names})")
            }
            Self::Update {
                target,
                column_name,
                old_value,
                new_value,
                ..
            } => {
                let table = target.display();
                let new_text = new_value.render();
                let old_text = old_value.render();
                format!("UPDATE {table} SET {column_name} = {new_text} (was {old_text})")
            }
            Self::Delete {
                target, pk_values, ..
            } => {
                let predicate = pk_values
                    .iter()
                    .map(|(key, value)| format!("{key}={}", value.render()))
                    .collect::<Vec<String>>()
                    .join(" AND ");
                format!("DELETE FROM {} WHERE {predicate}", target.display())
            }
        }
    }
}
/// An ordered queue of [`PendingMutation`]s.
///
/// Mutations are kept in push order: `compile_all` compiles them in that order and `pop`
/// removes the most recently pushed one.
#[derive(Debug, Default, Clone)]
pub struct PendingChanges {
    mutations: Vec<PendingMutation>,
}
impl PendingChanges {
    /// Creates an empty queue. Being `const`, it can also be used in constant contexts.
    pub const fn new() -> Self {
        Self { mutations: Vec::new() }
    }

    /// Appends `mutation` to the end of the queue.
    pub fn push(&mut self, mutation: PendingMutation) {
        self.mutations.push(mutation)
    }

    /// Removes and returns the most recently pushed mutation, or `None` when the queue is
    /// empty.
    pub fn pop(&mut self) -> Option<PendingMutation> {
        self.mutations.pop()
    }

    /// Discards every queued mutation.
    pub fn clear(&mut self) {
        self.mutations.clear()
    }

    /// Number of queued mutations.
    pub fn len(&self) -> usize {
        self.mutations.as_slice().len()
    }

    /// `true` when nothing is queued.
    pub fn is_empty(&self) -> bool {
        self.mutations.as_slice().is_empty()
    }

    /// Iterates the queued mutations, oldest first.
    pub fn iter(&self) -> std::slice::Iter<'_, PendingMutation> {
        self.mutations.as_slice().iter()
    }

    /// Mutably iterates the queued mutations, oldest first.
    pub fn iter_mut(&mut self) -> std::slice::IterMut<'_, PendingMutation> {
        self.mutations.as_mut_slice().iter_mut()
    }
}
// Enables `for m in &changes`, visiting mutations oldest first.
impl<'a> IntoIterator for &'a PendingChanges {
    type Item = &'a PendingMutation;
    type IntoIter = std::slice::Iter<'a, PendingMutation>;

    fn into_iter(self) -> Self::IntoIter {
        // Delegate to the inherent iterator so both entry points stay in sync.
        self.iter()
    }
}
// Enables `for m in &mut changes`, visiting mutations oldest first.
impl<'a> IntoIterator for &'a mut PendingChanges {
    type Item = &'a mut PendingMutation;
    type IntoIter = std::slice::IterMut<'a, PendingMutation>;

    fn into_iter(self) -> Self::IntoIter {
        // Delegate to the inherent iterator so both entry points stay in sync.
        self.iter_mut()
    }
}
impl PendingChanges {
    /// Read-only view of the queued mutations, oldest first.
    pub fn as_slice(&self) -> &[PendingMutation] {
        self.mutations.as_slice()
    }

    /// Mutable access to the mutation at `index`, or `None` when `index` is out of range.
    pub fn get_mut(&mut self, index: usize) -> Option<&mut PendingMutation> {
        self.mutations.get_mut(index)
    }

    /// Compiles every queued mutation for `dialect`, in queue order.
    ///
    /// # Errors
    /// Stops at the first mutation that fails to compile. Returns a [`CompileError`] carrying
    /// that mutation's zero-based position in the queue and the reason.
    pub fn compile_all(&self, dialect: Dialect) -> Result<Vec<CompiledMutation>, CompileError> {
        // Collecting into `Result` short-circuits on the first `Err`, like an early return.
        self.mutations
            .iter()
            .enumerate()
            .map(|(index, mutation)| {
                compile(mutation, dialect).map_err(|reason| CompileError { index, reason })
            })
            .collect()
    }
}
/// A mutation lowered to dialect-specific SQL, together with its bound parameters.
#[derive(Debug, Clone)]
pub struct CompiledMutation {
    /// Statement text, with placeholders in the dialect's syntax (produced by `placeholder`).
    pub sql: String,
    /// Values to bind, in placeholder order: element 0 binds to placeholder 1.
    pub params: Vec<Value>,
    /// Affected-row expectation attached to this statement.
    pub expects: ExpectedRows,
    /// Short human-readable description. It uses unquoted names and rendered values, so it is
    /// not executable SQL.
    pub summary: String,
}
/// Affected-row expectation attached to a compiled statement.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExpectedRows {
    /// Attached to every compiled INSERT. This module assigns no row count to it.
    /// NOTE(review): presumably the executor only checks that the insert succeeded — confirm.
    Insert,
    /// The statement should affect exactly this many rows. UPDATE and DELETE compile with
    /// `Exactly(1)`, presumably so a caller can detect that the row vanished or (for
    /// UPDATE) that the old value no longer matches.
    Exactly(u64),
}
/// A queued mutation that could not be compiled.
#[derive(Debug, Clone)]
pub struct CompileError {
    /// Zero-based position of the offending mutation in its queue.
    pub index: usize,
    /// Human-readable explanation of why compilation was refused.
    pub reason: String,
}

impl std::fmt::Display for CompileError {
    /// Renders as `mutation #N: reason`, where `N` is the one-based queue position.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { index, reason } = self;
        let ordinal = index + 1;
        write!(f, "mutation #{ordinal}: {reason}")
    }
}

impl std::error::Error for CompileError {}
/// Lowers one pending mutation into parameterised SQL for `dialect`.
///
/// Dispatches on the variant. Fields the SQL builders do not need are ignored by name
/// rather than with `..`: `columns` on updates and deletes, and the delete `snapshot` and
/// `column_order`. Adding a field therefore forces a decision here.
///
/// # Errors
/// Propagates the builder's human-readable reason when the mutation cannot be compiled.
fn compile(m: &PendingMutation, dialect: Dialect) -> Result<CompiledMutation, String> {
    match m {
        PendingMutation::Insert {
            target,
            columns,
            values,
        } => compile_insert(target, columns, values, dialect),
        PendingMutation::Update {
            target,
            columns: _,
            column_name,
            old_value,
            new_value,
            pk_values,
        } => compile_update(target, column_name, old_value, new_value, pk_values, dialect),
        PendingMutation::Delete {
            target,
            columns: _,
            pk_values,
            snapshot: _,
            column_order: _,
        } => compile_delete(target, pk_values, dialect),
    }
}
fn compile_insert(
target: &TableId,
columns: &[Column],
values: &BTreeMap<String, Value>,
dialect: Dialect,
) -> Result<CompiledMutation, String> {
if values.is_empty() {
let sql = format!(
"INSERT INTO {} DEFAULT VALUES",
quote_qualified(&target.schema, &target.table, dialect),
);
return Ok(CompiledMutation {
sql,
params: Vec::new(),
expects: ExpectedRows::Insert,
summary: format!("INSERT INTO {} DEFAULT VALUES", target.display()),
});
}
let known: HashSet<&str> = columns.iter().map(|c| c.name.as_str()).collect();
for col in values.keys() {
if !known.contains(col.as_str()) {
return Err(format!(
"column '{col}' not declared on {}",
target.display()
));
}
}
let mut col_names = Vec::with_capacity(values.len());
let mut placeholders = Vec::with_capacity(values.len());
let mut params = Vec::with_capacity(values.len());
for (i, (col, value)) in values.iter().enumerate() {
col_names.push(quote_ident(col, dialect));
placeholders.push(placeholder(i + 1, dialect));
params.push(value.clone());
}
let sql = format!(
"INSERT INTO {} ({}) VALUES ({})",
quote_qualified(&target.schema, &target.table, dialect),
col_names.join(", "),
placeholders.join(", "),
);
Ok(CompiledMutation {
sql,
params,
expects: ExpectedRows::Insert,
summary: format!(
"INSERT INTO {} ({})",
target.display(),
values.keys().cloned().collect::<Vec<_>>().join(", ")
),
})
}
/// Compiles a single-cell UPDATE guarded by an optimistic-concurrency check.
///
/// The row is located by every entry of `pk_values`. The WHERE clause also requires
/// `column_name` to still hold `old_value`; `IS NULL` is used when that value is NULL,
/// since `= NULL` never matches. Placeholder 1 is always the new value. The PK values
/// follow, then the old value when it is non-NULL.
///
/// # Errors
/// Refuses when `pk_values` is empty or any PK value is NULL. In either case the WHERE
/// clause could not reliably identify a single row.
fn compile_update(
    target: &TableId,
    column_name: &str,
    old_value: &Value,
    new_value: &Value,
    pk_values: &BTreeMap<String, Value>,
    dialect: Dialect,
) -> Result<CompiledMutation, String> {
    let table = target.display();
    if pk_values.is_empty() {
        return Err(format!("{table}: no primary key recorded, refusing UPDATE"));
    }

    let edited = quote_ident(column_name, dialect);
    let mut params = vec![new_value.clone()];
    params.reserve(pk_values.len() + 1);
    let mut predicates = Vec::with_capacity(pk_values.len() + 1);

    for (key, key_value) in pk_values {
        if key_value.is_null() {
            return Err(format!(
                "PK column '{key}' is NULL on {table}; refusing UPDATE"
            ));
        }
        params.push(key_value.clone());
        // After the push, `params.len()` is the 1-based position of this value.
        predicates.push(format!(
            "{} = {}",
            quote_ident(key, dialect),
            placeholder(params.len(), dialect)
        ));
    }

    if old_value.is_null() {
        predicates.push(format!("{edited} IS NULL"));
    } else {
        params.push(old_value.clone());
        predicates.push(format!("{edited} = {}", placeholder(params.len(), dialect)));
    }

    let sql = format!(
        "UPDATE {} SET {edited} = {} WHERE {}",
        quote_qualified(&target.schema, &target.table, dialect),
        placeholder(1, dialect),
        predicates.join(" AND "),
    );
    let summary = format!(
        "UPDATE {table} SET {column_name} = {} (was {})",
        new_value.render(),
        old_value.render(),
    );
    Ok(CompiledMutation {
        sql,
        params,
        expects: ExpectedRows::Exactly(1),
        summary,
    })
}
/// Compiles a DELETE of the single row identified by its primary key.
///
/// Each `pk_values` entry becomes a `col = <placeholder>` predicate. The predicates are
/// AND-ed in the map's key order, and each value is bound to its matching placeholder.
///
/// # Errors
/// Refuses when `pk_values` is empty, since nothing would identify the row. Also refuses
/// when any PK value is NULL, since `col = NULL` never matches.
fn compile_delete(
    target: &TableId,
    pk_values: &BTreeMap<String, Value>,
    dialect: Dialect,
) -> Result<CompiledMutation, String> {
    let table = target.display();
    if pk_values.is_empty() {
        return Err(format!("{table}: no primary key recorded, refusing DELETE"));
    }

    let mut params = Vec::with_capacity(pk_values.len());
    let mut predicates = Vec::with_capacity(pk_values.len());
    let mut readable = Vec::with_capacity(pk_values.len());
    for (key, key_value) in pk_values {
        if key_value.is_null() {
            return Err(format!(
                "PK column '{key}' is NULL on {table}; refusing DELETE"
            ));
        }
        params.push(key_value.clone());
        // After the push, `params.len()` is the 1-based position of this value.
        predicates.push(format!(
            "{} = {}",
            quote_ident(key, dialect),
            placeholder(params.len(), dialect)
        ));
        readable.push(format!("{key}={}", key_value.render()));
    }

    Ok(CompiledMutation {
        sql: format!(
            "DELETE FROM {} WHERE {}",
            quote_qualified(&target.schema, &target.table, dialect),
            predicates.join(" AND "),
        ),
        params,
        expects: ExpectedRows::Exactly(1),
        summary: format!("DELETE FROM {table} WHERE {}", readable.join(" AND ")),
    })
}
#[cfg(test)]
mod tests {
    // Unit tests for mutation compilation. They cover exact SQL text, parameter counts,
    // expected row counts, refusals (unknown column, missing or NULL primary key) and the
    // queue helpers.
    use super::*;

    /// A non-nullable integer primary-key column named `name`.
    fn pk(name: &str) -> Column {
        Column {
            name: name.into(),
            data_type: "integer".into(),
            nullable: false,
            primary_key: true,
            default: None,
        }
    }

    /// A nullable, non-key text column named `name`.
    fn col(name: &str) -> Column {
        Column {
            name: name.into(),
            data_type: "text".into(),
            nullable: true,
            primary_key: false,
            default: None,
        }
    }

    /// The schema-qualified table most tests write to.
    fn target() -> TableId {
        TableId::new("public", "items")
    }

    #[test]
    fn insert_with_explicit_columns_postgres() {
        let mut values = BTreeMap::new();
        values.insert("label".into(), Value::String("hi".into()));
        let m = PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id"), col("label")],
            values,
        };
        let compiled = compile(&m, Dialect::Postgres).unwrap();
        assert_eq!(
            compiled.sql,
            "INSERT INTO \"public\".\"items\" (\"label\") VALUES ($1)"
        );
        assert_eq!(compiled.params.len(), 1);
        assert_eq!(compiled.expects, ExpectedRows::Insert);
    }

    #[test]
    fn insert_with_no_values_uses_default_values() {
        let m = PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id"), col("label")],
            values: BTreeMap::new(),
        };
        let compiled = compile(&m, Dialect::Postgres).unwrap();
        assert!(compiled.sql.contains("DEFAULT VALUES"));
        assert!(compiled.params.is_empty());
    }

    #[test]
    fn insert_rejects_unknown_column() {
        let mut values = BTreeMap::new();
        values.insert("nonsense".into(), Value::String("x".into()));
        let m = PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id"), col("label")],
            values,
        };
        let err = compile(&m, Dialect::Postgres).unwrap_err();
        assert!(err.contains("nonsense"));
    }

    #[test]
    fn insert_summary_without_values_mentions_defaults() {
        let m = PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id")],
            values: BTreeMap::new(),
        };
        assert_eq!(m.summary(), "INSERT INTO public.items (defaults)");
        assert_eq!(m.target(), &target());
    }

    #[test]
    fn update_uses_optimistic_old_value_in_where() {
        let mut pk_values = BTreeMap::new();
        pk_values.insert("id".into(), Value::Int(7));
        let m = PendingMutation::Update {
            target: target(),
            columns: vec![pk("id"), col("label")],
            column_name: "label".into(),
            old_value: Value::String("old".into()),
            new_value: Value::String("new".into()),
            pk_values,
        };
        let compiled = compile(&m, Dialect::Postgres).unwrap();
        assert_eq!(
            compiled.sql,
            "UPDATE \"public\".\"items\" SET \"label\" = $1 WHERE \"id\" = $2 AND \"label\" = $3"
        );
        assert_eq!(compiled.params.len(), 3);
        assert_eq!(compiled.expects, ExpectedRows::Exactly(1));
    }

    #[test]
    fn update_uses_is_null_when_old_value_is_null() {
        let mut pk_values = BTreeMap::new();
        pk_values.insert("id".into(), Value::Int(7));
        let m = PendingMutation::Update {
            target: target(),
            columns: vec![pk("id"), col("label")],
            column_name: "label".into(),
            old_value: Value::Null,
            new_value: Value::String("x".into()),
            pk_values,
        };
        let compiled = compile(&m, Dialect::Postgres).unwrap();
        assert!(compiled.sql.contains("\"label\" IS NULL"));
        assert_eq!(compiled.params.len(), 2);
    }

    #[test]
    fn update_with_null_old_value_mysql() {
        let mut pk_values = BTreeMap::new();
        pk_values.insert("a".into(), Value::Int(1));
        let m = PendingMutation::Update {
            target: TableId::new("", "t"),
            columns: vec![pk("a"), col("c")],
            column_name: "c".into(),
            old_value: Value::Null,
            new_value: Value::String("x".into()),
            pk_values,
        };
        let compiled = compile(&m, Dialect::MySql).unwrap();
        assert_eq!(
            compiled.sql,
            "UPDATE `t` SET `c` = ? WHERE `a` = ? AND `c` IS NULL"
        );
        assert_eq!(compiled.params.len(), 2);
        assert_eq!(compiled.expects, ExpectedRows::Exactly(1));
    }

    #[test]
    fn update_rejects_empty_pk() {
        let m = PendingMutation::Update {
            target: target(),
            columns: vec![col("label")],
            column_name: "label".into(),
            old_value: Value::Null,
            new_value: Value::String("x".into()),
            pk_values: BTreeMap::new(),
        };
        let err = compile(&m, Dialect::Postgres).unwrap_err();
        assert!(err.contains("primary key"));
    }

    #[test]
    fn update_rejects_null_pk_value() {
        let mut pk_values = BTreeMap::new();
        pk_values.insert("id".into(), Value::Null);
        let m = PendingMutation::Update {
            target: target(),
            columns: vec![pk("id"), col("label")],
            column_name: "label".into(),
            old_value: Value::String("old".into()),
            new_value: Value::String("new".into()),
            pk_values,
        };
        let err = compile(&m, Dialect::Sqlite).unwrap_err();
        assert!(err.contains("NULL"));
    }

    #[test]
    fn delete_with_composite_pk_mysql() {
        let mut pk_values = BTreeMap::new();
        pk_values.insert("a".into(), Value::Int(1));
        pk_values.insert("b".into(), Value::Int(2));
        let m = PendingMutation::Delete {
            target: TableId::new("", "t"),
            columns: vec![pk("a"), pk("b"), col("c")],
            pk_values,
            snapshot: Row(vec![
                Value::Int(1),
                Value::Int(2),
                Value::String("x".into()),
            ]),
            column_order: vec!["a".into(), "b".into(), "c".into()],
        };
        let compiled = compile(&m, Dialect::MySql).unwrap();
        assert_eq!(compiled.sql, "DELETE FROM `t` WHERE `a` = ? AND `b` = ?");
        assert_eq!(compiled.params.len(), 2);
        assert_eq!(compiled.expects, ExpectedRows::Exactly(1));
    }

    #[test]
    fn delete_rejects_null_pk_value() {
        let mut pk_values = BTreeMap::new();
        pk_values.insert("id".into(), Value::Null);
        let m = PendingMutation::Delete {
            target: target(),
            columns: vec![pk("id")],
            pk_values,
            snapshot: Row(vec![Value::Null]),
            column_order: vec!["id".into()],
        };
        let err = compile(&m, Dialect::Sqlite).unwrap_err();
        assert!(err.contains("NULL"));
    }

    #[test]
    fn delete_rejects_empty_pk() {
        let m = PendingMutation::Delete {
            target: target(),
            columns: vec![col("a")],
            pk_values: BTreeMap::new(),
            snapshot: Row(vec![]),
            column_order: vec![],
        };
        let err = compile(&m, Dialect::Sqlite).unwrap_err();
        assert!(err.contains("primary key"));
    }

    #[test]
    fn compile_all_preserves_order_and_reports_offending_index() {
        let mut queue = PendingChanges::new();
        let mut values = BTreeMap::new();
        values.insert("label".into(), Value::String("ok".into()));
        queue.push(PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id"), col("label")],
            values,
        });
        queue.push(PendingMutation::Delete {
            target: target(),
            columns: vec![col("a")],
            pk_values: BTreeMap::new(),
            snapshot: Row(vec![]),
            column_order: vec![],
        });
        let err = queue.compile_all(Dialect::Postgres).unwrap_err();
        assert_eq!(err.index, 1, "second mutation should be flagged");
        assert!(err.to_string().contains("#2"));
    }

    #[test]
    fn compile_all_keeps_queue_order_on_success() {
        let mut queue = PendingChanges::new();
        let mut values = BTreeMap::new();
        values.insert("label".into(), Value::String("first".into()));
        queue.push(PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id"), col("label")],
            values,
        });
        queue.push(PendingMutation::Insert {
            target: target(),
            columns: vec![pk("id"), col("label")],
            values: BTreeMap::new(),
        });
        let compiled = queue.compile_all(Dialect::Postgres).unwrap();
        assert_eq!(compiled.len(), 2);
        assert!(compiled[0].sql.contains("(\"label\")"));
        assert!(compiled[1].sql.contains("DEFAULT VALUES"));
    }

    #[test]
    fn pending_changes_pop_is_lifo() {
        let mut queue = PendingChanges::new();
        assert!(queue.is_empty());
        for name in ["first", "second"] {
            queue.push(PendingMutation::Insert {
                target: TableId::new("", name),
                columns: vec![],
                values: BTreeMap::new(),
            });
        }
        assert_eq!(queue.len(), 2);
        let popped = queue.pop().expect("queue holds two mutations");
        assert_eq!(popped.target().table, "second");
        assert_eq!(queue.len(), 1);
        queue.clear();
        assert!(queue.is_empty());
        assert!(queue.pop().is_none());
    }

    #[test]
    fn table_id_display_omits_empty_schema() {
        assert_eq!(target().display(), "public.items");
        assert_eq!(TableId::new("", "t").display(), "t");
    }
}