1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::inputs::WriteOp;
8use crate::nested::NestedWriteOp;
9use crate::traits::{Model, QueryEngine};
10use crate::types::Select;
11
12pub(crate) fn extract_pk_from_filter(filter: &Filter, pk_col: &str) -> Option<FilterValue> {
19 match filter {
20 Filter::Equals(name, value) if name.as_ref() == pk_col => Some(value.clone()),
21 _ => None,
22 }
23}
24
25pub struct UpdateOperation<E: QueryEngine, M: Model> {
39 engine: E,
40 filter: Filter,
41 updates: Vec<(String, WriteOp)>,
42 select: Select,
43 nested: Vec<NestedWriteOp>,
47 _model: PhantomData<M>,
48}
49
50impl<E: QueryEngine, M: Model + crate::row::FromRow> UpdateOperation<E, M> {
51 pub fn new(engine: E) -> Self {
53 Self {
54 engine,
55 filter: Filter::None,
56 updates: Vec::new(),
57 select: Select::All,
58 nested: Vec::new(),
59 _model: PhantomData,
60 }
61 }
62
63 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
65 let new_filter = filter.into();
66 self.filter = self.filter.and_then(new_filter);
67 self
68 }
69
70 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
72 self.updates
73 .push((column.into(), WriteOp::Set(value.into())));
74 self
75 }
76
77 pub fn set_many(
79 mut self,
80 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
81 ) -> Self {
82 for (col, val) in values {
83 self.updates.push((col.into(), WriteOp::Set(val.into())));
84 }
85 self
86 }
87
88 pub fn increment(mut self, column: impl Into<String>, amount: i64) -> Self {
90 self.updates
91 .push((column.into(), WriteOp::Increment(FilterValue::Int(amount))));
92 self
93 }
94
95 pub fn set_op(mut self, column: impl Into<String>, op: WriteOp) -> Self {
102 self.updates.push((column.into(), op));
103 self
104 }
105
106 pub fn select(mut self, select: impl Into<Select>) -> Self {
108 self.select = select.into();
109 self
110 }
111
112 pub fn build_sql(
114 &self,
115 dialect: &dyn crate::dialect::SqlDialect,
116 ) -> (String, Vec<FilterValue>) {
117 let mut sql = String::new();
118 let mut params = Vec::new();
119 let mut param_idx = 1;
120
121 sql.push_str("UPDATE ");
123 sql.push_str(M::TABLE_NAME);
124
125 sql.push_str(" SET ");
127 let set_parts: Vec<String> = self
128 .updates
129 .iter()
130 .map(|(col, op)| {
131 let placeholder = dialect.placeholder(param_idx);
132 let (fragment, value) = op.to_set_fragment(col, &placeholder);
133 if let Some(v) = value {
134 params.push(v);
135 param_idx += 1;
136 }
137 fragment
138 })
139 .collect();
140 sql.push_str(&set_parts.join(", "));
141
142 if !self.filter.is_none() {
144 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
145 sql.push_str(" WHERE ");
146 sql.push_str(&where_sql);
147 params.extend(where_params);
148 }
149
150 sql.push_str(&dialect.returning_clause(&self.select.to_sql()));
152
153 (sql, params)
154 }
155
156 pub fn with(mut self, nw: NestedWriteOp) -> Self
168 where
169 E: crate::capabilities::SupportsNestedWrites,
170 {
171 self.nested.push(nw);
172 self
173 }
174
175 pub async fn exec(self) -> QueryResult<Vec<M>>
177 where
178 M: Send + 'static,
179 {
180 if self.nested.is_empty() {
182 let dialect = self.engine.dialect();
183 let (sql, params) = self.build_sql(dialect);
184 return self.engine.execute_update::<M>(&sql, params).await;
185 }
186
187 let parent_pk =
190 extract_pk_from_filter(&self.filter, M::PRIMARY_KEY[0]).ok_or_else(|| {
191 crate::error::QueryError::invalid_input(
192 "where",
193 "nested writes inside `update!` require the `where:` clause to equal-match \
194 the primary-key column",
195 )
196 .with_help(format!(
197 "expected `where: {{ {pk}: <value> }}` on `{table}` — non-PK unique \
198 columns are not yet supported for nested writes inside update!. \
199 Lift this restriction by running the nested ops in a separate operation \
200 after looking up the row's PK.",
201 pk = M::PRIMARY_KEY[0],
202 table = M::TABLE_NAME,
203 ))
204 })?;
205
206 let UpdateOperation {
207 engine,
208 filter,
209 updates,
210 select,
211 nested,
212 _model,
213 } = self;
214
215 engine
216 .transaction(move |tx| async move {
217 let dialect = tx.dialect();
218 let (sql, params) = Self::build_sql_parts(&filter, &updates, &select, dialect);
219 let parent: Vec<M> = tx.execute_update::<M>(&sql, params).await?;
220
221 let mut idx = 0;
223 while idx < nested.len() {
224 if let NestedWriteOp::Connect {
225 target_table: run_table,
226 foreign_key: run_fk,
227 target_pk: run_target_pk,
228 ..
229 } = &nested[idx]
230 {
231 let run_table = *run_table;
232 let run_fk = *run_fk;
233 let run_target_pk = *run_target_pk;
234 let mut end = idx + 1;
235 while end < nested.len() {
236 match &nested[end] {
237 NestedWriteOp::Connect {
238 target_table,
239 foreign_key,
240 target_pk,
241 ..
242 } if *target_table == run_table
243 && *foreign_key == run_fk
244 && *target_pk == run_target_pk =>
245 {
246 end += 1;
247 }
248 _ => break,
249 }
250 }
251
252 if end - idx == 1 {
253 let op = nested[idx].clone();
254 op.execute(&tx, &parent_pk).await?;
255 } else {
256 let expected = (end - idx) as u64;
257 let mut pks: Vec<FilterValue> = Vec::with_capacity(end - idx + 1);
258 pks.push(parent_pk.clone());
259 for op in &nested[idx..end] {
260 if let NestedWriteOp::Connect { pk, .. } = op {
261 pks.push(pk.clone());
262 }
263 }
264 let placeholders: Vec<String> =
265 (2..=pks.len()).map(|i| dialect.placeholder(i)).collect();
266 let sql = format!(
267 "UPDATE {} SET {} = {} WHERE {} IN ({})",
268 dialect.quote_ident(run_table),
269 dialect.quote_ident(run_fk),
270 dialect.placeholder(1),
271 dialect.quote_ident(run_target_pk),
272 placeholders.join(", "),
273 );
274 let affected = tx.execute_raw(&sql, pks).await?;
275 if affected != expected {
276 return Err(crate::error::QueryError::not_found(run_table)
277 .with_context("Nested Connect batch")
278 .with_help(format!(
279 "Expected {} matching rows but UPDATE affected {}",
280 expected, affected
281 )));
282 }
283 }
284 idx = end;
285 } else {
286 let op = nested[idx].clone();
287 op.execute(&tx, &parent_pk).await?;
288 idx += 1;
289 }
290 }
291 Ok(parent)
292 })
293 .await
294 }
295
296 fn build_sql_parts(
300 filter: &Filter,
301 updates: &[(String, WriteOp)],
302 select: &Select,
303 dialect: &dyn crate::dialect::SqlDialect,
304 ) -> (String, Vec<FilterValue>) {
305 let mut sql = String::new();
306 let mut params = Vec::new();
307 let mut param_idx = 1;
308
309 sql.push_str("UPDATE ");
310 sql.push_str(M::TABLE_NAME);
311
312 sql.push_str(" SET ");
313 let set_parts: Vec<String> = updates
314 .iter()
315 .map(|(col, op)| {
316 let placeholder = dialect.placeholder(param_idx);
317 let (fragment, value) = op.to_set_fragment(col, &placeholder);
318 if let Some(v) = value {
319 params.push(v);
320 param_idx += 1;
321 }
322 fragment
323 })
324 .collect();
325 sql.push_str(&set_parts.join(", "));
326
327 if !filter.is_none() {
328 let (where_sql, where_params) = filter.to_sql(param_idx - 1, dialect);
329 sql.push_str(" WHERE ");
330 sql.push_str(&where_sql);
331 params.extend(where_params);
332 }
333
334 sql.push_str(&dialect.returning_clause(&select.to_sql()));
335
336 (sql, params)
337 }
338
339 pub async fn exec_one(self) -> QueryResult<M>
341 where
342 M: Send + 'static,
343 {
344 let dialect = self.engine.dialect();
345 let (sql, params) = self.build_sql(dialect);
346 self.engine.query_one::<M>(&sql, params).await
347 }
348
349 pub fn with_where_input<W: crate::inputs::WhereUniqueInput<Model = M>>(mut self, w: W) -> Self {
353 let f = w.into_ir();
354 self.filter = self.filter.and_then(f);
355 self
356 }
357
358 pub fn with_select_input<S: crate::inputs::SelectInput<Model = M>>(mut self, s: S) -> Self {
360 self.select = s.into_ir();
361 self
362 }
363
364 pub fn with_update_input<I>(mut self, input: I) -> Self
372 where
373 I: crate::inputs::UpdateInput<Model = M, Data = crate::inputs::UpdatePayload>,
374 {
375 let data: crate::inputs::UpdatePayload = input.into_ir();
376 for (col, op) in data {
377 self.updates.push((col, op));
378 }
379 self
380 }
381
382 #[doc(hidden)]
384 pub fn filter_for_test(&self) -> &Filter {
385 &self.filter
386 }
387}
388
389pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
391 engine: E,
392 filter: Filter,
393 updates: Vec<(String, WriteOp)>,
394 _model: PhantomData<M>,
395}
396
397impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
398 pub fn new(engine: E) -> Self {
400 Self {
401 engine,
402 filter: Filter::None,
403 updates: Vec::new(),
404 _model: PhantomData,
405 }
406 }
407
408 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
410 let new_filter = filter.into();
411 self.filter = self.filter.and_then(new_filter);
412 self
413 }
414
415 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
417 self.updates
418 .push((column.into(), WriteOp::Set(value.into())));
419 self
420 }
421
422 pub fn set_op(mut self, column: impl Into<String>, op: WriteOp) -> Self {
424 self.updates.push((column.into(), op));
425 self
426 }
427
428 pub fn with_where_input<W: crate::inputs::WhereInput<Model = M>>(mut self, w: W) -> Self {
430 let f = w.into_ir();
431 self.filter = self.filter.and_then(f);
432 self
433 }
434
435 pub fn with_update_input<I>(mut self, input: I) -> Self
441 where
442 I: crate::inputs::UpdateInput<Model = M, Data = crate::inputs::UpdatePayload>,
443 {
444 let data: crate::inputs::UpdatePayload = input.into_ir();
445 for (col, op) in data {
446 self.updates.push((col, op));
447 }
448 self
449 }
450
451 pub fn build_sql(
453 &self,
454 dialect: &dyn crate::dialect::SqlDialect,
455 ) -> (String, Vec<FilterValue>) {
456 let mut sql = String::new();
457 let mut params = Vec::new();
458 let mut param_idx = 1;
459
460 sql.push_str("UPDATE ");
462 sql.push_str(M::TABLE_NAME);
463
464 sql.push_str(" SET ");
466 let set_parts: Vec<String> = self
467 .updates
468 .iter()
469 .map(|(col, op)| {
470 let placeholder = dialect.placeholder(param_idx);
471 let (fragment, value) = op.to_set_fragment(col, &placeholder);
472 if let Some(v) = value {
473 params.push(v);
474 param_idx += 1;
475 }
476 fragment
477 })
478 .collect();
479 sql.push_str(&set_parts.join(", "));
480
481 if !self.filter.is_none() {
483 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
484 sql.push_str(" WHERE ");
485 sql.push_str(&where_sql);
486 params.extend(where_params);
487 }
488
489 (sql, params)
490 }
491
492 pub async fn exec(self) -> QueryResult<u64> {
494 let dialect = self.engine.dialect();
495 let (sql, params) = self.build_sql(dialect);
496 self.engine.execute_raw(&sql, params).await
497 }
498}
499
500#[cfg(test)]
501mod tests {
502 use super::*;
503 use crate::error::QueryError;
504 use crate::types::Select;
505
    /// Minimal model fixture: table `test_models` with single-column primary key `id`.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    // Ignores the row entirely and always yields the unit fixture.
    impl crate::row::FromRow for TestModel {
        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
            Ok(TestModel)
        }
    }
520
521 #[derive(Clone)]
522 struct MockEngine {
523 return_count: u64,
524 }
525
526 impl MockEngine {
527 fn new() -> Self {
528 Self { return_count: 0 }
529 }
530
531 fn with_count(count: u64) -> Self {
532 Self {
533 return_count: count,
534 }
535 }
536 }
537
    // Every method ignores its SQL and parameters and returns a fixed result.
    // Only `execute_raw` is configurable, via `return_count`.
    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }

        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        // Always fails with `not_found`; `test_update_exec_one` relies on this.
        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        // Returns no rows, so `UpdateOperation::exec` yields an empty Vec.
        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        // Reports the count the engine was constructed with.
        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }
608
    // --- UpdateOperation: SQL rendering ------------------------------------

    // With no filter and no assignments, the `UPDATE ... SET` prefix and the
    // default `RETURNING *` are still rendered.
    #[test]
    fn test_update_new() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    // One SET value plus one WHERE value gives two parameters; SET comes first
    // even though `where` was called before `set`.
    #[test]
    fn test_update_basic() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .set("email", "updated@example.com");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    // `set_many` numbers placeholders in iteration order.
    #[test]
    fn test_update_with_set_many() {
        let updates = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set_many(updates);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    // `increment` renders a self-referencing assignment and binds the delta.
    #[test]
    fn test_update_increment() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .increment("counter", 5);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(
            sql.contains("counter = counter + $1"),
            "expected `counter = counter + $1`, got: {sql}"
        );
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("RETURNING id, name"));
    }

    // Two `where` calls are ANDed together.
    #[test]
    fn test_update_with_complex_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
            .set("verified", FilterValue::Bool(true));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        // 1 SET value + 2 WHERE values.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_without_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("status", "updated");

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    // Setting an explicit Null still binds a parameter.
    #[test]
    fn test_update_with_null_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("deleted_at", FilterValue::Null);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));

        let (_sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    // --- UpdateOperation: execution against MockEngine ---------------------

    // MockEngine::execute_update returns no rows.
    #[tokio::test]
    async fn test_update_exec() {
        let op =
            UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    // MockEngine::query_one always fails, so exec_one must surface an error.
    #[tokio::test]
    async fn test_update_exec_one() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let result = op.exec_one().await;
        assert!(result.is_err());
    }
768
    // --- UpdateManyOperation -----------------------------------------------

    // Bulk updates never carry a returning clause.
    #[test]
    fn test_update_many_new() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "id".into(),
                vec![
                    FilterValue::Int(1),
                    FilterValue::Int(2),
                    FilterValue::Int(3),
                ],
            ))
            .set("status", "processed");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        // 1 SET value + 3 IN-list values.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "department".into(),
                FilterValue::String("engineering".to_string()),
            ))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .set("reviewed", FilterValue::Bool(true));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("reset_password", FilterValue::Bool(true));

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(!sql.contains("WHERE"));
    }

    // `exec` surfaces whatever count `execute_raw` reports.
    #[tokio::test]
    async fn test_update_many_exec() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
            .set("status", "updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
    }
838
    // --- Parameter numbering and value types -------------------------------

    // WHERE placeholders are numbered after the SET placeholders. The filter
    // column is rendered quoted (`"id"`), while SET columns are not.
    #[test]
    fn test_update_param_ordering() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains(r#""id" = $3"#));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains(r#""id" = $2"#));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_float_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("price", FilterValue::Float(99.99));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_update_with_json_value() {
        let json_value = serde_json::json!({"key": "value"});
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json_value.clone()));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(json_value));
    }
892
    /// Update-input stub whose payload is exactly the wrapped `(column, op)` list.
    struct MockUpdateInput(Vec<(String, WriteOp)>);

    impl crate::inputs::UpdateInput for MockUpdateInput {
        type Model = TestModel;
        type Data = crate::inputs::UpdatePayload;
        // Hands the wrapped list back unchanged.
        fn into_ir(self) -> Self::Data {
            self.0
        }
    }
907
    // `with_update_input` appends the payload's ops in order; `Increment`
    // renders as `col = col + $n`.
    #[test]
    fn with_update_input_appends_set_ops() {
        let input = MockUpdateInput(vec![
            (
                "name".into(),
                WriteOp::Set(FilterValue::String("Bob".into())),
            ),
            ("age".into(), WriteOp::Increment(FilterValue::Int(1))),
        ]);

        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .with_update_input(input);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("name = $1"), "got: {sql}");
        assert!(sql.contains("age = age + $2"), "got: {sql}");
        // Two SET values followed by the WHERE id value.
        assert_eq!(params.len(), 3);
        assert_eq!(params[0], FilterValue::String("Bob".into()));
        assert_eq!(params[1], FilterValue::Int(1));
    }

    // `Unset` renders a literal NULL and binds no parameter.
    #[test]
    fn with_update_input_unset_emits_null_no_param() {
        let input = MockUpdateInput(vec![("nickname".into(), WriteOp::Unset)]);

        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .with_update_input(input);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("nickname = NULL"), "got: {sql}");
        assert!(params.is_empty(), "expected no params, got: {params:?}");
    }

    #[test]
    fn update_many_with_update_input_appends() {
        let input = MockUpdateInput(vec![(
            "name".into(),
            WriteOp::Set(FilterValue::String("Bob".into())),
        )]);

        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .with_update_input(input);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"), "got: {sql}");
        assert!(sql.contains("WHERE"));
        assert_eq!(params.len(), 2);
    }
966
967 use std::sync::{Arc, Mutex};
970
    /// Shared log of `(sql, params)` pairs captured by `RecordingEngine`.
    type StatementLog = Arc<Mutex<Vec<(String, Vec<FilterValue>)>>>;

    /// Engine that records `execute_update` / `execute_raw` calls so tests can
    /// assert on the exact statements issued. Clones share the same `Arc`s.
    #[derive(Clone)]
    struct RecordingEngine {
        recorded: StatementLog,
        // Queued results for `execute_raw`, consumed from the back via `pop`.
        // When empty, `execute_raw` falls back to a computed default.
        affected: Arc<Mutex<Vec<u64>>>,
    }

    impl RecordingEngine {
        fn new() -> Self {
            Self {
                recorded: Arc::new(Mutex::new(Vec::new())),
                affected: Arc::new(Mutex::new(Vec::new())),
            }
        }

        /// Snapshot of the statements recorded so far, in recording order.
        fn statements(&self) -> Vec<(String, Vec<FilterValue>)> {
            self.recorded.lock().unwrap().clone()
        }
    }
995
    // Lets the recording engine be used with `UpdateOperation::with`.
    impl crate::capabilities::SupportsNestedWrites for RecordingEngine {}

    // Only `execute_update` and `execute_raw` record their statements; every
    // other method returns a fixed result.
    impl QueryEngine for RecordingEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }

        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        // Records the statement and reports that no rows were returned.
        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            let recorded = self.recorded.clone();
            let sql = sql.to_string();
            Box::pin(async move {
                recorded.lock().unwrap().push((sql, params));
                Ok(Vec::new())
            })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        // Records the statement and reports a row count: the next queued
        // `affected` value if any, otherwise `default`.
        fn execute_raw(
            &self,
            sql: &str,
            params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let recorded = self.recorded.clone();
            let affected = self.affected.clone();
            let sql_string = sql.to_string();
            // For a batched `... IN (...)` Connect, params are [parent key,
            // child keys...], so "every child matched" is params.len() - 1.
            // Any other statement defaults to one affected row.
            let default = if sql.contains(" IN (") {
                (params.len() as u64).saturating_sub(1)
            } else {
                1
            };
            Box::pin(async move {
                recorded.lock().unwrap().push((sql_string, params));
                Ok(affected.lock().unwrap().pop().unwrap_or(default))
            })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }
1083
    // --- Nested writes -------------------------------------------------------

    // The parent UPDATE is recorded first, followed by the nested Create's
    // INSERT into the child table carrying the foreign key.
    #[tokio::test]
    async fn update_with_nested_create_runs_parent_then_child_insert() {
        let engine = RecordingEngine::new();
        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(7)))
            .set("name", "Renamed")
            .with(NestedWriteOp::Create {
                relation: "posts",
                target_table: "posts",
                foreign_key: "author_id",
                payload: vec![vec![("title".into(), FilterValue::String("p1".into()))]],
            });

        let _ = op.exec().await.expect("update + nested create");

        let stmts = engine.statements();
        assert_eq!(
            stmts.len(),
            2,
            "parent UPDATE + nested child INSERT; got {stmts:#?}"
        );
        assert!(
            stmts[0].0.contains("UPDATE test_models"),
            "got: {}",
            stmts[0].0
        );
        assert!(stmts[1].0.contains("INSERT INTO"), "got: {}", stmts[1].0);
        assert!(stmts[1].0.contains("posts"), "got: {}", stmts[1].0);
        assert!(stmts[1].0.contains("author_id"), "got: {}", stmts[1].0);
    }

    // A nested Disconnect issues an UPDATE that NULLs the foreign key and binds
    // only the target row's key.
    #[tokio::test]
    async fn update_with_nested_disconnect_emits_set_null_update() {
        let engine = RecordingEngine::new();
        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(7)))
            .set("name", "Renamed")
            .with(NestedWriteOp::Disconnect {
                relation: "posts",
                target_table: "posts",
                foreign_key: "author_id",
                target_pk: "id",
                pk: FilterValue::Int(42),
            });

        let _ = op.exec().await.expect("update + nested disconnect");

        let stmts = engine.statements();
        assert_eq!(stmts.len(), 2, "got {stmts:#?}");
        assert!(
            stmts[0].0.contains("UPDATE test_models"),
            "got: {}",
            stmts[0].0
        );
        let (sql, params) = &stmts[1];
        assert!(sql.contains("UPDATE"), "got: {sql}");
        assert!(sql.contains("posts"), "got: {sql}");
        assert!(sql.contains("author_id"), "got: {sql}");
        assert!(sql.contains("NULL"), "got: {sql}");
        assert_eq!(params, &vec![FilterValue::Int(42)]);
    }

    // A non-PK `where` with nested writes must be rejected before any SQL is sent.
    #[tokio::test]
    async fn update_nested_requires_pk_in_where_filter() {
        let engine = RecordingEngine::new();
        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
            // `email` is not the primary key, so no parent key can be derived.
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("a@x.com".into()),
            ))
            .set("name", "Renamed")
            .with(NestedWriteOp::Disconnect {
                relation: "posts",
                target_table: "posts",
                foreign_key: "author_id",
                target_pk: "id",
                pk: FilterValue::Int(42),
            });

        let result = op.exec().await;
        let err = result.err().expect("non-PK where must error");
        let msg = err.to_string();
        assert!(
            msg.contains("primary-key column") || msg.contains("primary key"),
            "expected PK-required diagnostic, got: {msg}"
        );
        // Validation happens before the transaction opens, so nothing is recorded.
        assert!(
            engine.statements().is_empty(),
            "no SQL should run: {:#?}",
            engine.statements()
        );
    }
1179}