use std::marker::PhantomData;
use crate::error::QueryResult;
use crate::filter::{Filter, FilterValue};
use crate::inputs::WriteOp;
use crate::nested::NestedWriteOp;
use crate::traits::{Model, QueryEngine};
use crate::types::Select;
/// Returns the value bound to `pk_col` when `filter` is exactly one equality
/// test on that column (`Filter::Equals(pk_col, value)`).
///
/// Any other filter shape — a different column, a non-equality operator, or a
/// compound filter — yields `None`; compound filters are not searched.
pub(crate) fn extract_pk_from_filter(filter: &Filter, pk_col: &str) -> Option<FilterValue> {
    if let Filter::Equals(column, pk_value) = filter {
        if column.as_ref() == pk_col {
            return Some(pk_value.clone());
        }
    }
    None
}
/// Builder for an `UPDATE` statement against `M::TABLE_NAME` that hands back
/// the updated rows (projected through a RETURNING clause rendered by the
/// engine's dialect).
pub struct UpdateOperation<E: QueryEngine, M: Model> {
    /// Engine the statement is executed on.
    engine: E,
    /// Accumulated WHERE filter; `Filter::None` means no WHERE clause is emitted.
    filter: Filter,
    /// `(column, op)` pairs rendered into the SET list in insertion order.
    updates: Vec<(String, WriteOp)>,
    /// Projection handed to the dialect's RETURNING clause.
    select: Select,
    /// Relation writes that `exec` runs after the parent UPDATE, inside one transaction.
    nested: Vec<NestedWriteOp>,
    /// Ties the builder to its model type without storing a value of it.
    _model: PhantomData<M>,
}
impl<E: QueryEngine, M: Model + crate::row::FromRow> UpdateOperation<E, M> {
pub fn new(engine: E) -> Self {
Self {
engine,
filter: Filter::None,
updates: Vec::new(),
select: Select::All,
nested: Vec::new(),
_model: PhantomData,
}
}
pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
let new_filter = filter.into();
self.filter = self.filter.and_then(new_filter);
self
}
pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
self.updates
.push((column.into(), WriteOp::Set(value.into())));
self
}
pub fn set_many(
mut self,
values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
) -> Self {
for (col, val) in values {
self.updates.push((col.into(), WriteOp::Set(val.into())));
}
self
}
pub fn increment(mut self, column: impl Into<String>, amount: i64) -> Self {
self.updates
.push((column.into(), WriteOp::Increment(FilterValue::Int(amount))));
self
}
pub fn set_op(mut self, column: impl Into<String>, op: WriteOp) -> Self {
self.updates.push((column.into(), op));
self
}
pub fn select(mut self, select: impl Into<Select>) -> Self {
self.select = select.into();
self
}
pub fn build_sql(
&self,
dialect: &dyn crate::dialect::SqlDialect,
) -> (String, Vec<FilterValue>) {
let mut sql = String::new();
let mut params = Vec::new();
let mut param_idx = 1;
sql.push_str("UPDATE ");
sql.push_str(M::TABLE_NAME);
sql.push_str(" SET ");
let set_parts: Vec<String> = self
.updates
.iter()
.map(|(col, op)| {
let placeholder = dialect.placeholder(param_idx);
let (fragment, value) = op.to_set_fragment(col, &placeholder);
if let Some(v) = value {
params.push(v);
param_idx += 1;
}
fragment
})
.collect();
sql.push_str(&set_parts.join(", "));
if !self.filter.is_none() {
let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
sql.push_str(" WHERE ");
sql.push_str(&where_sql);
params.extend(where_params);
}
sql.push_str(&dialect.returning_clause(&self.select.to_sql()));
(sql, params)
}
pub fn with(mut self, nw: NestedWriteOp) -> Self
where
E: crate::capabilities::SupportsNestedWrites,
{
self.nested.push(nw);
self
}
pub async fn exec(self) -> QueryResult<Vec<M>>
where
M: Send + 'static,
{
if self.nested.is_empty() {
let dialect = self.engine.dialect();
let (sql, params) = self.build_sql(dialect);
return self.engine.execute_update::<M>(&sql, params).await;
}
let parent_pk =
extract_pk_from_filter(&self.filter, M::PRIMARY_KEY[0]).ok_or_else(|| {
crate::error::QueryError::invalid_input(
"where",
"nested writes inside `update!` require the `where:` clause to equal-match \
the primary-key column",
)
.with_help(format!(
"expected `where: {{ {pk}: <value> }}` on `{table}` — non-PK unique \
columns are not yet supported for nested writes inside update!. \
Lift this restriction by running the nested ops in a separate operation \
after looking up the row's PK.",
pk = M::PRIMARY_KEY[0],
table = M::TABLE_NAME,
))
})?;
let UpdateOperation {
engine,
filter,
updates,
select,
nested,
_model,
} = self;
engine
.transaction(move |tx| async move {
let dialect = tx.dialect();
let (sql, params) = Self::build_sql_parts(&filter, &updates, &select, dialect);
let parent: Vec<M> = tx.execute_update::<M>(&sql, params).await?;
let mut idx = 0;
while idx < nested.len() {
if let NestedWriteOp::Connect {
target_table: run_table,
foreign_key: run_fk,
target_pk: run_target_pk,
..
} = &nested[idx]
{
let run_table = *run_table;
let run_fk = *run_fk;
let run_target_pk = *run_target_pk;
let mut end = idx + 1;
while end < nested.len() {
match &nested[end] {
NestedWriteOp::Connect {
target_table,
foreign_key,
target_pk,
..
} if *target_table == run_table
&& *foreign_key == run_fk
&& *target_pk == run_target_pk =>
{
end += 1;
}
_ => break,
}
}
if end - idx == 1 {
let op = nested[idx].clone();
op.execute(&tx, &parent_pk).await?;
} else {
let expected = (end - idx) as u64;
let mut pks: Vec<FilterValue> = Vec::with_capacity(end - idx + 1);
pks.push(parent_pk.clone());
for op in &nested[idx..end] {
if let NestedWriteOp::Connect { pk, .. } = op {
pks.push(pk.clone());
}
}
let placeholders: Vec<String> =
(2..=pks.len()).map(|i| dialect.placeholder(i)).collect();
let sql = format!(
"UPDATE {} SET {} = {} WHERE {} IN ({})",
dialect.quote_ident(run_table),
dialect.quote_ident(run_fk),
dialect.placeholder(1),
dialect.quote_ident(run_target_pk),
placeholders.join(", "),
);
let affected = tx.execute_raw(&sql, pks).await?;
if affected != expected {
return Err(crate::error::QueryError::not_found(run_table)
.with_context("Nested Connect batch")
.with_help(format!(
"Expected {} matching rows but UPDATE affected {}",
expected, affected
)));
}
}
idx = end;
} else {
let op = nested[idx].clone();
op.execute(&tx, &parent_pk).await?;
idx += 1;
}
}
Ok(parent)
})
.await
}
fn build_sql_parts(
filter: &Filter,
updates: &[(String, WriteOp)],
select: &Select,
dialect: &dyn crate::dialect::SqlDialect,
) -> (String, Vec<FilterValue>) {
let mut sql = String::new();
let mut params = Vec::new();
let mut param_idx = 1;
sql.push_str("UPDATE ");
sql.push_str(M::TABLE_NAME);
sql.push_str(" SET ");
let set_parts: Vec<String> = updates
.iter()
.map(|(col, op)| {
let placeholder = dialect.placeholder(param_idx);
let (fragment, value) = op.to_set_fragment(col, &placeholder);
if let Some(v) = value {
params.push(v);
param_idx += 1;
}
fragment
})
.collect();
sql.push_str(&set_parts.join(", "));
if !filter.is_none() {
let (where_sql, where_params) = filter.to_sql(param_idx - 1, dialect);
sql.push_str(" WHERE ");
sql.push_str(&where_sql);
params.extend(where_params);
}
sql.push_str(&dialect.returning_clause(&select.to_sql()));
(sql, params)
}
pub async fn exec_one(self) -> QueryResult<M>
where
M: Send + 'static,
{
let dialect = self.engine.dialect();
let (sql, params) = self.build_sql(dialect);
self.engine.query_one::<M>(&sql, params).await
}
pub fn with_where_input<W: crate::inputs::WhereUniqueInput<Model = M>>(mut self, w: W) -> Self {
let f = w.into_ir();
self.filter = self.filter.and_then(f);
self
}
pub fn with_select_input<S: crate::inputs::SelectInput<Model = M>>(mut self, s: S) -> Self {
self.select = s.into_ir();
self
}
pub fn with_update_input<I>(mut self, input: I) -> Self
where
I: crate::inputs::UpdateInput<Model = M, Data = crate::inputs::UpdatePayload>,
{
let data: crate::inputs::UpdatePayload = input.into_ir();
for (col, op) in data {
self.updates.push((col, op));
}
self
}
#[doc(hidden)]
pub fn filter_for_test(&self) -> &Filter {
&self.filter
}
}
/// Builder for a bulk `UPDATE` against `M::TABLE_NAME` with no RETURNING
/// clause; executing it yields the count reported by `QueryEngine::execute_raw`.
pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
    /// Engine the statement is executed on.
    engine: E,
    /// Accumulated WHERE filter; `Filter::None` means no WHERE clause is emitted.
    filter: Filter,
    /// `(column, op)` pairs rendered into the SET list in insertion order.
    updates: Vec<(String, WriteOp)>,
    /// Ties the builder to its model type without storing a value of it.
    _model: PhantomData<M>,
}
impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
    /// Creates an empty bulk-update builder bound to `engine`, with no filter
    /// (`Filter::None`) and no SET columns.
    pub fn new(engine: E) -> Self {
        Self {
            engine,
            filter: Filter::None,
            updates: Vec::new(),
            _model: PhantomData,
        }
    }

    /// Adds a WHERE condition, combined with any previously added filter
    /// through `Filter::and_then`.
    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
        let new_filter = filter.into();
        self.filter = self.filter.and_then(new_filter);
        self
    }

    /// Queues `column = value` (a `WriteOp::Set`).
    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
        self.updates
            .push((column.into(), WriteOp::Set(value.into())));
        self
    }

    /// Queues an arbitrary write operation for `column`.
    pub fn set_op(mut self, column: impl Into<String>, op: WriteOp) -> Self {
        self.updates.push((column.into(), op));
        self
    }

    /// ANDs a generated where input into the filter (same combination rule as
    /// `r#where`).
    pub fn with_where_input<W: crate::inputs::WhereInput<Model = M>>(mut self, w: W) -> Self {
        let f = w.into_ir();
        self.filter = self.filter.and_then(f);
        self
    }

    /// Appends every `(column, op)` produced by a generated update input, in order.
    pub fn with_update_input<I>(mut self, input: I) -> Self
    where
        I: crate::inputs::UpdateInput<Model = M, Data = crate::inputs::UpdatePayload>,
    {
        let data: crate::inputs::UpdatePayload = input.into_ir();
        self.updates.extend(data);
        self
    }

    /// Renders the statement and its bind parameters for `dialect`.
    ///
    /// SET placeholders are numbered first and WHERE placeholders continue
    /// after them. No RETURNING clause is emitted.
    ///
    /// NOTE(review): with no queued updates the SET list is empty and the
    /// rendered `UPDATE <table> SET ` is not valid SQL.
    pub fn build_sql(
        &self,
        dialect: &dyn crate::dialect::SqlDialect,
    ) -> (String, Vec<FilterValue>) {
        let mut params: Vec<FilterValue> = Vec::new();
        let mut set_parts: Vec<String> = Vec::with_capacity(self.updates.len());
        for (col, op) in &self.updates {
            // The next free placeholder is one past the values bound so far; ops
            // that bind nothing (e.g. `Unset`) do not consume a number.
            let placeholder = dialect.placeholder(params.len() + 1);
            let (fragment, value) = op.to_set_fragment(col, &placeholder);
            if let Some(v) = value {
                params.push(v);
            }
            set_parts.push(fragment);
        }
        let mut sql = format!("UPDATE {} SET {}", M::TABLE_NAME, set_parts.join(", "));
        if !self.filter.is_none() {
            // WHERE placeholders continue after the SET parameters.
            let (where_sql, where_params) = self.filter.to_sql(params.len(), dialect);
            sql.push_str(" WHERE ");
            sql.push_str(&where_sql);
            params.extend(where_params);
        }
        (sql, params)
    }

    /// Executes the statement through `execute_raw` and returns the count it
    /// reports.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the engine.
    pub async fn exec(self) -> QueryResult<u64> {
        let dialect = self.engine.dialect();
        let (sql, params) = self.build_sql(dialect);
        self.engine.execute_raw(&sql, params).await
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `UpdateOperation` / `UpdateManyOperation`: SQL rendering
    // against the Postgres dialect, plus execution against in-memory mock engines.
    use super::*;
    use crate::error::QueryError;
    use crate::types::Select;
    // Minimal model; the builders only read its table and key metadata.
    struct TestModel;
    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }
    // Ignores the row and always yields a unit `TestModel`.
    impl crate::row::FromRow for TestModel {
        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
            Ok(TestModel)
        }
    }
    // Stub engine on the Postgres dialect. Every method returns a canned result;
    // only `execute_raw`'s return value is configurable (via `return_count`).
    #[derive(Clone)]
    struct MockEngine {
        return_count: u64,
    }
    impl MockEngine {
        fn new() -> Self {
            Self { return_count: 0 }
        }
        fn with_count(count: u64) -> Self {
            Self {
                return_count: count,
            }
        }
    }
    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }
        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        // Always fails, so `exec_one` observes an error.
        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }
        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }
        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }
        // Succeeds with no rows, so `exec` yields an empty Vec.
        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
        // Reports the configured `return_count` regardless of the statement.
        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }
        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }
    // With no SET columns and no filter, the statement skeleton and the default
    // `RETURNING *` are still rendered, with no bind parameters.
    #[test]
    fn test_update_new() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }
    // One SET plus one WHERE equality: two bind parameters, SET numbered first.
    #[test]
    fn test_update_basic() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }
    // Successive `set` calls get consecutive placeholders in call order.
    #[test]
    fn test_update_many_fields() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .set("email", "updated@example.com");
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }
    // `set_many` numbers placeholders in the iterator's order.
    #[test]
    fn test_update_with_set_many() {
        let updates = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set_many(updates);
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }
    // `increment` renders `col = col + $n` and binds the amount.
    #[test]
    fn test_update_increment() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .increment("counter", 5);
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(
            sql.contains("counter = counter + $1"),
            "expected `counter = counter + $1`, got: {sql}"
        );
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }
    // A field selection replaces the default `RETURNING *`.
    #[test]
    fn test_update_with_select() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));
        let (sql, _) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("RETURNING id, name"));
    }
    // Two `where` calls are combined with AND.
    #[test]
    fn test_update_with_complex_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
            .set("verified", FilterValue::Bool(true));
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        // 1 SET parameter + 2 WHERE parameters.
        assert_eq!(params.len(), 3);
    }
    // No filter means no WHERE clause at all.
    #[test]
    fn test_update_without_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("status", "updated");
        let (sql, _) = op.build_sql(&crate::dialect::Postgres);
        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }
    // A `Set` to NULL is bound as a parameter rather than inlined.
    #[test]
    fn test_update_with_null_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("deleted_at", FilterValue::Null);
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }
    // Parameter order follows `set` call order.
    #[test]
    fn test_update_with_boolean() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));
        let (_sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }
    // `exec` without nested writes returns whatever `execute_update` yields.
    #[tokio::test]
    async fn test_update_exec() {
        let op =
            UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Updated");
        let result = op.exec().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }
    // The mock's `query_one` always fails, so `exec_one` must surface an error.
    #[tokio::test]
    async fn test_update_exec_one() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");
        let result = op.exec_one().await;
        assert!(result.is_err());
    }
    // Bulk updates never emit a RETURNING clause.
    #[test]
    fn test_update_many_new() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("UPDATE test_models SET"));
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }
    // An IN filter binds one parameter per listed value.
    #[test]
    fn test_update_many_basic() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "id".into(),
                vec![
                    FilterValue::Int(1),
                    FilterValue::Int(2),
                    FilterValue::Int(3),
                ],
            ))
            .set("status", "processed");
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        // 1 SET parameter + 3 IN-list parameters.
        assert_eq!(params.len(), 4);
    }
    // Two `where` calls on a bulk update are combined with AND.
    #[test]
    fn test_update_many_with_multiple_conditions() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "department".into(),
                FilterValue::String("engineering".to_string()),
            ))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .set("reviewed", FilterValue::Bool(true));
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }
    // A bulk update with no filter renders without a WHERE clause.
    #[test]
    fn test_update_many_without_where() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("reset_password", FilterValue::Bool(true));
        let (sql, _) = op.build_sql(&crate::dialect::Postgres);
        assert!(!sql.contains("WHERE"));
    }
    // `exec` passes through the count reported by `execute_raw`.
    #[tokio::test]
    async fn test_update_many_exec() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
            .set("status", "updated");
        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
    }
    // WHERE placeholders continue after the SET placeholders even when `where`
    // is called last; the WHERE column is rendered quoted.
    #[test]
    fn test_update_param_ordering() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains(r#""id" = $3"#));
        assert_eq!(params.len(), 3);
    }
    // Same placeholder continuation for bulk updates.
    #[test]
    fn test_update_many_param_ordering() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains(r#""id" = $2"#));
        assert_eq!(params.len(), 2);
    }
    #[test]
    fn test_update_with_float_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("price", FilterValue::Float(99.99));
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }
    // JSON values are bound unchanged.
    #[test]
    fn test_update_with_json_value() {
        let json_value = serde_json::json!({"key": "value"});
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json_value.clone()));
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(json_value));
    }
    // `UpdateInput` stand-in that hands back a prebuilt payload verbatim.
    struct MockUpdateInput(Vec<(String, WriteOp)>);
    impl crate::inputs::UpdateInput for MockUpdateInput {
        type Model = TestModel;
        type Data = crate::inputs::UpdatePayload;
        fn into_ir(self) -> Self::Data {
            self.0
        }
    }
    // Payload ops are appended in order and share placeholder numbering with WHERE.
    #[test]
    fn with_update_input_appends_set_ops() {
        let input = MockUpdateInput(vec![
            (
                "name".into(),
                WriteOp::Set(FilterValue::String("Bob".into())),
            ),
            ("age".into(), WriteOp::Increment(FilterValue::Int(1))),
        ]);
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .with_update_input(input);
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("name = $1"), "got: {sql}");
        assert!(sql.contains("age = age + $2"), "got: {sql}");
        assert_eq!(params.len(), 3);
        assert_eq!(params[0], FilterValue::String("Bob".into()));
        assert_eq!(params[1], FilterValue::Int(1));
    }
    // `Unset` renders a literal NULL and binds nothing.
    #[test]
    fn with_update_input_unset_emits_null_no_param() {
        let input = MockUpdateInput(vec![("nickname".into(), WriteOp::Unset)]);
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .with_update_input(input);
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("nickname = NULL"), "got: {sql}");
        assert!(params.is_empty(), "expected no params, got: {params:?}");
    }
    #[test]
    fn update_many_with_update_input_appends() {
        let input = MockUpdateInput(vec![(
            "name".into(),
            WriteOp::Set(FilterValue::String("Bob".into())),
        )]);
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .with_update_input(input);
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"), "got: {sql}");
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 2);
    }
    use std::sync::{Arc, Mutex};
    // Ordered log of (sql, params) for every recorded statement.
    type StatementLog = Arc<Mutex<Vec<(String, Vec<FilterValue>)>>>;
    // Engine that logs `execute_update` and `execute_raw` statements so nested-write
    // tests can inspect what ran. `affected` holds optional overrides for
    // `execute_raw`'s return value, popped from the end of the Vec.
    // NOTE(review): `transaction` is not overridden, so these tests presumably rely
    // on a QueryEngine default that runs the closure against this same engine.
    #[derive(Clone)]
    struct RecordingEngine {
        recorded: StatementLog,
        affected: Arc<Mutex<Vec<u64>>>,
    }
    impl RecordingEngine {
        fn new() -> Self {
            Self {
                recorded: Arc::new(Mutex::new(Vec::new())),
                affected: Arc::new(Mutex::new(Vec::new())),
            }
        }
        fn statements(&self) -> Vec<(String, Vec<FilterValue>)> {
            self.recorded.lock().unwrap().clone()
        }
    }
    // Opts the engine into `UpdateOperation::with`.
    impl crate::capabilities::SupportsNestedWrites for RecordingEngine {}
    impl QueryEngine for RecordingEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }
        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }
        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }
        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }
        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }
        // Logs the statement and returns no rows.
        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            let recorded = self.recorded.clone();
            let sql = sql.to_string();
            Box::pin(async move {
                recorded.lock().unwrap().push((sql, params));
                Ok(Vec::new())
            })
        }
        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
        // Logs the statement. Unless an override is queued, reports one affected
        // row, or for `IN (...)` batches one row per parameter after the first
        // (the first parameter being the FK value, not a target PK).
        fn execute_raw(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let recorded = self.recorded.clone();
            let affected = self.affected.clone();
            let sql_string = sql.to_string();
            let default = if sql.contains(" IN (") {
                (params.len() as u64).saturating_sub(1)
            } else {
                1
            };
            Box::pin(async move {
                recorded.lock().unwrap().push((sql_string, params));
                Ok(affected.lock().unwrap().pop().unwrap_or(default))
            })
        }
        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }
    // A nested Create runs after the parent UPDATE and inserts into the child
    // table using the relation's foreign key.
    #[tokio::test]
    async fn update_with_nested_create_runs_parent_then_child_insert() {
        let engine = RecordingEngine::new();
        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(7)))
            .set("name", "Renamed")
            .with(NestedWriteOp::Create {
                relation: "posts",
                target_table: "posts",
                foreign_key: "author_id",
                payload: vec![vec![("title".into(), FilterValue::String("p1".into()))]],
            });
        let _ = op.exec().await.expect("update + nested create");
        let stmts = engine.statements();
        assert_eq!(
            stmts.len(),
            2,
            "parent UPDATE + nested child INSERT; got {stmts:#?}"
        );
        assert!(
            stmts[0].0.contains("UPDATE test_models"),
            "got: {}",
            stmts[0].0
        );
        assert!(stmts[1].0.contains("INSERT INTO"), "got: {}", stmts[1].0);
        assert!(stmts[1].0.contains("posts"), "got: {}", stmts[1].0);
        assert!(stmts[1].0.contains("author_id"), "got: {}", stmts[1].0);
    }
    // A nested Disconnect emits an UPDATE that NULLs the child's foreign key,
    // binding only the child's PK.
    #[tokio::test]
    async fn update_with_nested_disconnect_emits_set_null_update() {
        let engine = RecordingEngine::new();
        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(7)))
            .set("name", "Renamed")
            .with(NestedWriteOp::Disconnect {
                relation: "posts",
                target_table: "posts",
                foreign_key: "author_id",
                target_pk: "id",
                pk: FilterValue::Int(42),
            });
        let _ = op.exec().await.expect("update + nested disconnect");
        let stmts = engine.statements();
        assert_eq!(stmts.len(), 2, "got {stmts:#?}");
        assert!(
            stmts[0].0.contains("UPDATE test_models"),
            "got: {}",
            stmts[0].0
        );
        let (sql, params) = &stmts[1];
        assert!(sql.contains("UPDATE"), "got: {sql}");
        assert!(sql.contains("posts"), "got: {sql}");
        assert!(sql.contains("author_id"), "got: {sql}");
        assert!(sql.contains("NULL"), "got: {sql}");
        assert_eq!(params, &vec![FilterValue::Int(42)]);
    }
    // Nested writes behind a non-PK filter are rejected before any SQL is sent.
    #[tokio::test]
    async fn update_nested_requires_pk_in_where_filter() {
        let engine = RecordingEngine::new();
        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("a@x.com".into()),
            ))
            .set("name", "Renamed")
            .with(NestedWriteOp::Disconnect {
                relation: "posts",
                target_table: "posts",
                foreign_key: "author_id",
                target_pk: "id",
                pk: FilterValue::Int(42),
            });
        let result = op.exec().await;
        let err = result.err().expect("non-PK where must error");
        let msg = err.to_string();
        assert!(
            msg.contains("primary-key column") || msg.contains("primary key"),
            "expected PK-required diagnostic, got: {msg}"
        );
        assert!(
            engine.statements().is_empty(),
            "no SQL should run: {:#?}",
            engine.statements()
        );
    }
}