Skip to main content

prax_query/operations/
update.rs

1//! Update operation for modifying existing records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::inputs::WriteOp;
8use crate::nested::NestedWriteOp;
9use crate::traits::{Model, QueryEngine};
10use crate::types::Select;
11
12/// Extract the parent PK value from a where-unique filter when it
13/// equal-matches the model's primary-key column directly.
14///
15/// Returns `None` for any other filter shape (non-PK equals, non-equals
16/// comparators, AND/OR composites, etc.). Nested-write callers turn this
17/// into a clear "where must equal-match the PK" error.
18pub(crate) fn extract_pk_from_filter(filter: &Filter, pk_col: &str) -> Option<FilterValue> {
19    match filter {
20        Filter::Equals(name, value) if name.as_ref() == pk_col => Some(value.clone()),
21        _ => None,
22    }
23}
24
25/// An update operation for modifying existing records.
26///
27/// # Example
28///
29/// ```rust,ignore
30/// let users = client
31///     .user()
32///     .update()
33///     .r#where(user::id::equals(1))
34///     .set("name", "Updated Name")
35///     .exec()
36///     .await?;
37/// ```
38pub struct UpdateOperation<E: QueryEngine, M: Model> {
39    engine: E,
40    filter: Filter,
41    updates: Vec<(String, WriteOp)>,
42    select: Select,
43    /// Queued nested-write ops run after the parent UPDATE inside an
44    /// implicit transaction. Populated by [`UpdateOperation::with`].
45    /// Empty on the fast path (single UPDATE, no transaction wrap).
46    nested: Vec<NestedWriteOp>,
47    _model: PhantomData<M>,
48}
49
50impl<E: QueryEngine, M: Model + crate::row::FromRow> UpdateOperation<E, M> {
51    /// Create a new Update operation.
52    pub fn new(engine: E) -> Self {
53        Self {
54            engine,
55            filter: Filter::None,
56            updates: Vec::new(),
57            select: Select::All,
58            nested: Vec::new(),
59            _model: PhantomData,
60        }
61    }
62
63    /// Add a filter condition.
64    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
65        let new_filter = filter.into();
66        self.filter = self.filter.and_then(new_filter);
67        self
68    }
69
70    /// Set a column to a new value.
71    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
72        self.updates
73            .push((column.into(), WriteOp::Set(value.into())));
74        self
75    }
76
77    /// Set multiple columns from an iterator.
78    pub fn set_many(
79        mut self,
80        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
81    ) -> Self {
82        for (col, val) in values {
83            self.updates.push((col.into(), WriteOp::Set(val.into())));
84        }
85        self
86    }
87
88    /// Increment a numeric column.
89    pub fn increment(mut self, column: impl Into<String>, amount: i64) -> Self {
90        self.updates
91            .push((column.into(), WriteOp::Increment(FilterValue::Int(amount))));
92        self
93    }
94
95    /// Apply a column-keyed [`WriteOp`].
96    ///
97    /// Used by `with_update_input` (and tests) to push an arbitrary
98    /// scalar atomic operator onto the update list. The DSL surface
99    /// for these operators is the `*FieldUpdate` wrappers in
100    /// [`crate::inputs::scalar_update`].
101    pub fn set_op(mut self, column: impl Into<String>, op: WriteOp) -> Self {
102        self.updates.push((column.into(), op));
103        self
104    }
105
106    /// Select specific fields to return.
107    pub fn select(mut self, select: impl Into<Select>) -> Self {
108        self.select = select.into();
109        self
110    }
111
112    /// Build the SQL query.
113    pub fn build_sql(
114        &self,
115        dialect: &dyn crate::dialect::SqlDialect,
116    ) -> (String, Vec<FilterValue>) {
117        let mut sql = String::new();
118        let mut params = Vec::new();
119        let mut param_idx = 1;
120
121        // UPDATE clause
122        sql.push_str("UPDATE ");
123        sql.push_str(M::TABLE_NAME);
124
125        // SET clause
126        sql.push_str(" SET ");
127        let set_parts: Vec<String> = self
128            .updates
129            .iter()
130            .map(|(col, op)| {
131                let placeholder = dialect.placeholder(param_idx);
132                let (fragment, value) = op.to_set_fragment(col, &placeholder);
133                if let Some(v) = value {
134                    params.push(v);
135                    param_idx += 1;
136                }
137                fragment
138            })
139            .collect();
140        sql.push_str(&set_parts.join(", "));
141
142        // WHERE clause
143        if !self.filter.is_none() {
144            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
145            sql.push_str(" WHERE ");
146            sql.push_str(&where_sql);
147            params.extend(where_params);
148        }
149
150        // RETURNING clause
151        sql.push_str(&dialect.returning_clause(&self.select.to_sql()));
152
153        (sql, params)
154    }
155
156    /// Queue a nested write to run alongside this update.
157    ///
158    /// The parent `UPDATE` and every queued nested op execute inside a
159    /// single implicit transaction — any failure rolls back the parent
160    /// UPDATE too.
161    ///
162    /// Nested writes inside `update!` currently require the `where:`
163    /// filter to equal-match the primary-key column. Non-PK unique
164    /// columns (e.g. `where: { email: "..." }`) error at exec time
165    /// with a clear diagnostic. Lifting this restriction needs a
166    /// SELECT-then-update pattern to capture the row's PK — deferred.
167    pub fn with(mut self, nw: NestedWriteOp) -> Self
168    where
169        E: crate::capabilities::SupportsNestedWrites,
170    {
171        self.nested.push(nw);
172        self
173    }
174
175    /// Execute the update and return modified records.
176    pub async fn exec(self) -> QueryResult<Vec<M>>
177    where
178        M: Send + 'static,
179    {
180        // Fast path: no nested writes — single UPDATE statement.
181        if self.nested.is_empty() {
182            let dialect = self.engine.dialect();
183            let (sql, params) = self.build_sql(dialect);
184            return self.engine.execute_update::<M>(&sql, params).await;
185        }
186
187        // Slow path: extract the parent PK from the `where` filter, then
188        // run the UPDATE + queued nested ops inside a transaction.
189        let parent_pk =
190            extract_pk_from_filter(&self.filter, M::PRIMARY_KEY[0]).ok_or_else(|| {
191                crate::error::QueryError::invalid_input(
192                    "where",
193                    "nested writes inside `update!` require the `where:` clause to equal-match \
194                     the primary-key column",
195                )
196                .with_help(format!(
197                    "expected `where: {{ {pk}: <value> }}` on `{table}` — non-PK unique \
198                     columns are not yet supported for nested writes inside update!. \
199                     Lift this restriction by running the nested ops in a separate operation \
200                     after looking up the row's PK.",
201                    pk = M::PRIMARY_KEY[0],
202                    table = M::TABLE_NAME,
203                ))
204            })?;
205
206        let UpdateOperation {
207            engine,
208            filter,
209            updates,
210            select,
211            nested,
212            _model,
213        } = self;
214
215        engine
216            .transaction(move |tx| async move {
217                let dialect = tx.dialect();
218                let (sql, params) = Self::build_sql_parts(&filter, &updates, &select, dialect);
219                let parent: Vec<M> = tx.execute_update::<M>(&sql, params).await?;
220
221                // Batch consecutive Connect ops with the same target.
222                let mut idx = 0;
223                while idx < nested.len() {
224                    if let NestedWriteOp::Connect {
225                        target_table: run_table,
226                        foreign_key: run_fk,
227                        target_pk: run_target_pk,
228                        ..
229                    } = &nested[idx]
230                    {
231                        let run_table = *run_table;
232                        let run_fk = *run_fk;
233                        let run_target_pk = *run_target_pk;
234                        let mut end = idx + 1;
235                        while end < nested.len() {
236                            match &nested[end] {
237                                NestedWriteOp::Connect {
238                                    target_table,
239                                    foreign_key,
240                                    target_pk,
241                                    ..
242                                } if *target_table == run_table
243                                    && *foreign_key == run_fk
244                                    && *target_pk == run_target_pk =>
245                                {
246                                    end += 1;
247                                }
248                                _ => break,
249                            }
250                        }
251
252                        if end - idx == 1 {
253                            let op = nested[idx].clone();
254                            op.execute(&tx, &parent_pk).await?;
255                        } else {
256                            let expected = (end - idx) as u64;
257                            let mut pks: Vec<FilterValue> = Vec::with_capacity(end - idx + 1);
258                            pks.push(parent_pk.clone());
259                            for op in &nested[idx..end] {
260                                if let NestedWriteOp::Connect { pk, .. } = op {
261                                    pks.push(pk.clone());
262                                }
263                            }
264                            let placeholders: Vec<String> =
265                                (2..=pks.len()).map(|i| dialect.placeholder(i)).collect();
266                            let sql = format!(
267                                "UPDATE {} SET {} = {} WHERE {} IN ({})",
268                                dialect.quote_ident(run_table),
269                                dialect.quote_ident(run_fk),
270                                dialect.placeholder(1),
271                                dialect.quote_ident(run_target_pk),
272                                placeholders.join(", "),
273                            );
274                            let affected = tx.execute_raw(&sql, pks).await?;
275                            if affected != expected {
276                                return Err(crate::error::QueryError::not_found(run_table)
277                                    .with_context("Nested Connect batch")
278                                    .with_help(format!(
279                                        "Expected {} matching rows but UPDATE affected {}",
280                                        expected, affected
281                                    )));
282                            }
283                        }
284                        idx = end;
285                    } else {
286                        let op = nested[idx].clone();
287                        op.execute(&tx, &parent_pk).await?;
288                        idx += 1;
289                    }
290                }
291                Ok(parent)
292            })
293            .await
294    }
295
296    /// Free-function form of [`Self::build_sql`] — takes the pieces by
297    /// reference so the `exec` path can reuse it after destructuring
298    /// `self` to move the captured state into the transaction closure.
299    fn build_sql_parts(
300        filter: &Filter,
301        updates: &[(String, WriteOp)],
302        select: &Select,
303        dialect: &dyn crate::dialect::SqlDialect,
304    ) -> (String, Vec<FilterValue>) {
305        let mut sql = String::new();
306        let mut params = Vec::new();
307        let mut param_idx = 1;
308
309        sql.push_str("UPDATE ");
310        sql.push_str(M::TABLE_NAME);
311
312        sql.push_str(" SET ");
313        let set_parts: Vec<String> = updates
314            .iter()
315            .map(|(col, op)| {
316                let placeholder = dialect.placeholder(param_idx);
317                let (fragment, value) = op.to_set_fragment(col, &placeholder);
318                if let Some(v) = value {
319                    params.push(v);
320                    param_idx += 1;
321                }
322                fragment
323            })
324            .collect();
325        sql.push_str(&set_parts.join(", "));
326
327        if !filter.is_none() {
328            let (where_sql, where_params) = filter.to_sql(param_idx - 1, dialect);
329            sql.push_str(" WHERE ");
330            sql.push_str(&where_sql);
331            params.extend(where_params);
332        }
333
334        sql.push_str(&dialect.returning_clause(&select.to_sql()));
335
336        (sql, params)
337    }
338
339    /// Execute the update and return the first modified record.
340    pub async fn exec_one(self) -> QueryResult<M>
341    where
342        M: Send + 'static,
343    {
344        let dialect = self.engine.dialect();
345        let (sql, params) = self.build_sql(dialect);
346        self.engine.query_one::<M>(&sql, params).await
347    }
348
349    /// Apply a typed `WhereUniqueInput`. AND-composes with any
350    /// previously set filter so callers can combine the unique key
351    /// with side conditions when they need to.
352    pub fn with_where_input<W: crate::inputs::WhereUniqueInput<Model = M>>(mut self, w: W) -> Self {
353        let f = w.into_ir();
354        self.filter = self.filter.and_then(f);
355        self
356    }
357
358    /// Apply a typed `SelectInput`.
359    pub fn with_select_input<S: crate::inputs::SelectInput<Model = M>>(mut self, s: S) -> Self {
360        self.select = s.into_ir();
361        self
362    }
363
364    /// Apply a typed `UpdateInput`.
365    ///
366    /// The input's `into_ir` produces a `Vec<(column, WriteOp)>` —
367    /// each entry is appended to the operation's SET list. Atomic
368    /// operators (`Increment`/`Decrement`/`Multiply`/`Divide`) emit
369    /// `col = col <op> $n` in the resulting SQL; `Set` emits
370    /// `col = $n`; `Unset` emits `col = NULL` with no placeholder.
371    pub fn with_update_input<I>(mut self, input: I) -> Self
372    where
373        I: crate::inputs::UpdateInput<Model = M, Data = crate::inputs::UpdatePayload>,
374    {
375        let data: crate::inputs::UpdatePayload = input.into_ir();
376        for (col, op) in data {
377            self.updates.push((col, op));
378        }
379        self
380    }
381
382    /// Doc-hidden accessor for the current filter.
383    #[doc(hidden)]
384    pub fn filter_for_test(&self) -> &Filter {
385        &self.filter
386    }
387}
388
389/// Update many records at once.
390pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
391    engine: E,
392    filter: Filter,
393    updates: Vec<(String, WriteOp)>,
394    _model: PhantomData<M>,
395}
396
397impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
398    /// Create a new UpdateMany operation.
399    pub fn new(engine: E) -> Self {
400        Self {
401            engine,
402            filter: Filter::None,
403            updates: Vec::new(),
404            _model: PhantomData,
405        }
406    }
407
408    /// Add a filter condition.
409    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
410        let new_filter = filter.into();
411        self.filter = self.filter.and_then(new_filter);
412        self
413    }
414
415    /// Set a column to a new value.
416    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
417        self.updates
418            .push((column.into(), WriteOp::Set(value.into())));
419        self
420    }
421
422    /// Apply a column-keyed [`WriteOp`].
423    pub fn set_op(mut self, column: impl Into<String>, op: WriteOp) -> Self {
424        self.updates.push((column.into(), op));
425        self
426    }
427
428    /// Apply a typed `WhereInput`. AND-composes with the existing filter.
429    pub fn with_where_input<W: crate::inputs::WhereInput<Model = M>>(mut self, w: W) -> Self {
430        let f = w.into_ir();
431        self.filter = self.filter.and_then(f);
432        self
433    }
434
435    /// Apply a typed `UpdateInput`.
436    ///
437    /// See [`UpdateOperation::with_update_input`] for the lowering
438    /// semantics — the only difference here is that `update_many` does
439    /// not return rows.
440    pub fn with_update_input<I>(mut self, input: I) -> Self
441    where
442        I: crate::inputs::UpdateInput<Model = M, Data = crate::inputs::UpdatePayload>,
443    {
444        let data: crate::inputs::UpdatePayload = input.into_ir();
445        for (col, op) in data {
446            self.updates.push((col, op));
447        }
448        self
449    }
450
451    /// Build the SQL query.
452    pub fn build_sql(
453        &self,
454        dialect: &dyn crate::dialect::SqlDialect,
455    ) -> (String, Vec<FilterValue>) {
456        let mut sql = String::new();
457        let mut params = Vec::new();
458        let mut param_idx = 1;
459
460        // UPDATE clause
461        sql.push_str("UPDATE ");
462        sql.push_str(M::TABLE_NAME);
463
464        // SET clause
465        sql.push_str(" SET ");
466        let set_parts: Vec<String> = self
467            .updates
468            .iter()
469            .map(|(col, op)| {
470                let placeholder = dialect.placeholder(param_idx);
471                let (fragment, value) = op.to_set_fragment(col, &placeholder);
472                if let Some(v) = value {
473                    params.push(v);
474                    param_idx += 1;
475                }
476                fragment
477            })
478            .collect();
479        sql.push_str(&set_parts.join(", "));
480
481        // WHERE clause
482        if !self.filter.is_none() {
483            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
484            sql.push_str(" WHERE ");
485            sql.push_str(&where_sql);
486            params.extend(where_params);
487        }
488
489        (sql, params)
490    }
491
492    /// Execute the update and return the count of modified records.
493    pub async fn exec(self) -> QueryResult<u64> {
494        let dialect = self.engine.dialect();
495        let (sql, params) = self.build_sql(dialect);
496        self.engine.execute_raw(&sql, params).await
497    }
498}
499
500#[cfg(test)]
501mod tests {
502    use super::*;
503    use crate::error::QueryError;
504    use crate::types::Select;
505
506    struct TestModel;
507
508    impl Model for TestModel {
509        const MODEL_NAME: &'static str = "TestModel";
510        const TABLE_NAME: &'static str = "test_models";
511        const PRIMARY_KEY: &'static [&'static str] = &["id"];
512        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
513    }
514
515    impl crate::row::FromRow for TestModel {
516        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
517            Ok(TestModel)
518        }
519    }
520
521    #[derive(Clone)]
522    struct MockEngine {
523        return_count: u64,
524    }
525
526    impl MockEngine {
527        fn new() -> Self {
528            Self { return_count: 0 }
529        }
530
531        fn with_count(count: u64) -> Self {
532            Self {
533                return_count: count,
534            }
535        }
536    }
537
538    impl QueryEngine for MockEngine {
539        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
540            &crate::dialect::Postgres
541        }
542
543        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
544            &self,
545            _sql: &str,
546            _params: Vec<FilterValue>,
547        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
548            Box::pin(async { Ok(Vec::new()) })
549        }
550
551        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
552            &self,
553            _sql: &str,
554            _params: Vec<FilterValue>,
555        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
556            Box::pin(async { Err(QueryError::not_found("test")) })
557        }
558
559        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
560            &self,
561            _sql: &str,
562            _params: Vec<FilterValue>,
563        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
564            Box::pin(async { Ok(None) })
565        }
566
567        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
568            &self,
569            _sql: &str,
570            _params: Vec<FilterValue>,
571        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
572            Box::pin(async { Err(QueryError::not_found("test")) })
573        }
574
575        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
576            &self,
577            _sql: &str,
578            _params: Vec<FilterValue>,
579        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
580            Box::pin(async { Ok(Vec::new()) })
581        }
582
583        fn execute_delete(
584            &self,
585            _sql: &str,
586            _params: Vec<FilterValue>,
587        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
588            Box::pin(async { Ok(0) })
589        }
590
591        fn execute_raw(
592            &self,
593            _sql: &str,
594            _params: Vec<FilterValue>,
595        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
596            let count = self.return_count;
597            Box::pin(async move { Ok(count) })
598        }
599
600        fn count(
601            &self,
602            _sql: &str,
603            _params: Vec<FilterValue>,
604        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
605            Box::pin(async { Ok(0) })
606        }
607    }
608
609    // ========== UpdateOperation Tests ==========
610
611    #[test]
612    fn test_update_new() {
613        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
614        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
615
616        assert!(sql.contains("UPDATE test_models SET"));
617        assert!(sql.contains("RETURNING *"));
618        assert!(params.is_empty());
619    }
620
621    #[test]
622    fn test_update_basic() {
623        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
624            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
625            .set("name", "Updated");
626
627        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
628
629        assert!(sql.contains("UPDATE test_models SET"));
630        assert!(sql.contains("name = $1"));
631        assert!(sql.contains("WHERE"));
632        assert!(sql.contains("RETURNING *"));
633        assert_eq!(params.len(), 2);
634    }
635
636    #[test]
637    fn test_update_many_fields() {
638        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
639            .set("name", "Updated")
640            .set("email", "updated@example.com");
641
642        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
643
644        assert!(sql.contains("name = $1"));
645        assert!(sql.contains("email = $2"));
646        assert_eq!(params.len(), 2);
647    }
648
649    #[test]
650    fn test_update_with_set_many() {
651        let updates = vec![
652            ("name", FilterValue::String("Alice".to_string())),
653            ("email", FilterValue::String("alice@test.com".to_string())),
654            ("age", FilterValue::Int(30)),
655        ];
656        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set_many(updates);
657
658        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
659
660        assert!(sql.contains("name = $1"));
661        assert!(sql.contains("email = $2"));
662        assert!(sql.contains("age = $3"));
663        assert_eq!(params.len(), 3);
664    }
665
666    #[test]
667    fn test_update_increment() {
668        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
669            .increment("counter", 5);
670
671        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
672
673        // `increment` now lowers to a true `col = col + $n` atomic
674        // operator (the prior implementation collapsed it to a plain
675        // `set`, which was a documented bug).
676        assert!(
677            sql.contains("counter = counter + $1"),
678            "expected `counter = counter + $1`, got: {sql}"
679        );
680        assert_eq!(params.len(), 1);
681        assert_eq!(params[0], FilterValue::Int(5));
682    }
683
684    #[test]
685    fn test_update_with_select() {
686        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
687            .set("name", "Updated")
688            .select(Select::fields(["id", "name"]));
689
690        let (sql, _) = op.build_sql(&crate::dialect::Postgres);
691
692        assert!(sql.contains("RETURNING id, name"));
693    }
694
695    #[test]
696    fn test_update_with_complex_filter() {
697        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
698            .r#where(Filter::Equals(
699                "status".into(),
700                FilterValue::String("active".to_string()),
701            ))
702            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
703            .set("verified", FilterValue::Bool(true));
704
705        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
706
707        assert!(sql.contains("WHERE"));
708        assert!(sql.contains("AND"));
709        assert_eq!(params.len(), 3); // 1 set + 2 where
710    }
711
712    #[test]
713    fn test_update_without_filter() {
714        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
715            .set("status", "updated");
716
717        let (sql, _) = op.build_sql(&crate::dialect::Postgres);
718
719        // Should not have WHERE clause
720        assert!(!sql.contains("WHERE"));
721        assert!(sql.contains("UPDATE test_models SET"));
722    }
723
724    #[test]
725    fn test_update_with_null_value() {
726        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
727            .set("deleted_at", FilterValue::Null);
728
729        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
730
731        assert!(sql.contains("deleted_at = $1"));
732        assert_eq!(params.len(), 1);
733        assert_eq!(params[0], FilterValue::Null);
734    }
735
736    #[test]
737    fn test_update_with_boolean() {
738        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
739            .set("active", FilterValue::Bool(true))
740            .set("verified", FilterValue::Bool(false));
741
742        let (_sql, params) = op.build_sql(&crate::dialect::Postgres);
743
744        assert_eq!(params.len(), 2);
745        assert_eq!(params[0], FilterValue::Bool(true));
746        assert_eq!(params[1], FilterValue::Bool(false));
747    }
748
749    #[tokio::test]
750    async fn test_update_exec() {
751        let op =
752            UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Updated");
753
754        let result = op.exec().await;
755        assert!(result.is_ok());
756        assert!(result.unwrap().is_empty());
757    }
758
759    #[tokio::test]
760    async fn test_update_exec_one() {
761        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
762            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
763            .set("name", "Updated");
764
765        let result = op.exec_one().await;
766        assert!(result.is_err()); // MockEngine returns not_found
767    }
768
769    // ========== UpdateManyOperation Tests ==========
770
771    #[test]
772    fn test_update_many_new() {
773        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
774        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
775
776        assert!(sql.contains("UPDATE test_models SET"));
777        assert!(!sql.contains("RETURNING")); // UpdateMany doesn't return records
778        assert!(params.is_empty());
779    }
780
781    #[test]
782    fn test_update_many_basic() {
783        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
784            .r#where(Filter::In(
785                "id".into(),
786                vec![
787                    FilterValue::Int(1),
788                    FilterValue::Int(2),
789                    FilterValue::Int(3),
790                ],
791            ))
792            .set("status", "processed");
793
794        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
795
796        assert!(sql.contains("UPDATE test_models SET"));
797        assert!(sql.contains("status = $1"));
798        assert!(sql.contains("WHERE"));
799        assert!(sql.contains("IN"));
800        assert_eq!(params.len(), 4); // 1 set + 3 IN values
801    }
802
803    #[test]
804    fn test_update_many_with_multiple_conditions() {
805        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
806            .r#where(Filter::Equals(
807                "department".into(),
808                FilterValue::String("engineering".to_string()),
809            ))
810            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
811            .set("reviewed", FilterValue::Bool(true));
812
813        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
814
815        assert!(sql.contains("AND"));
816        assert_eq!(params.len(), 3);
817    }
818
819    #[test]
820    fn test_update_many_without_where() {
821        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
822            .set("reset_password", FilterValue::Bool(true));
823
824        let (sql, _) = op.build_sql(&crate::dialect::Postgres);
825
826        assert!(!sql.contains("WHERE"));
827    }
828
829    #[tokio::test]
830    async fn test_update_many_exec() {
831        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
832            .set("status", "updated");
833
834        let result = op.exec().await;
835        assert!(result.is_ok());
836        assert_eq!(result.unwrap(), 5);
837    }
838
839    // ========== SQL Generation Edge Cases ==========
840
841    #[test]
842    fn test_update_param_ordering() {
843        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
844            .set("field1", "value1")
845            .set("field2", "value2")
846            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));
847
848        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
849
850        // SET params come first, then WHERE params
851        assert!(sql.contains("field1 = $1"));
852        assert!(sql.contains("field2 = $2"));
853        assert!(sql.contains(r#""id" = $3"#));
854        assert_eq!(params.len(), 3);
855    }
856
857    #[test]
858    fn test_update_many_param_ordering() {
859        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
860            .set("field1", "value1")
861            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));
862
863        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
864
865        assert!(sql.contains("field1 = $1"));
866        assert!(sql.contains(r#""id" = $2"#));
867        assert_eq!(params.len(), 2);
868    }
869
870    #[test]
871    fn test_update_with_float_value() {
872        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
873            .set("price", FilterValue::Float(99.99));
874
875        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
876
877        assert!(sql.contains("price = $1"));
878        assert_eq!(params.len(), 1);
879    }
880
881    #[test]
882    fn test_update_with_json_value() {
883        let json_value = serde_json::json!({"key": "value"});
884        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
885            .set("metadata", FilterValue::Json(json_value.clone()));
886
887        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
888
889        assert!(sql.contains("metadata = $1"));
890        assert_eq!(params[0], FilterValue::Json(json_value));
891    }
892
893    // ========== Phase 5a: typed-input wiring ==========
894
895    /// Mock `UpdateInput` used by the `with_update_input` tests. The
896    /// codegen-emitted equivalent isn't available inside `prax-query`,
897    /// so we hand-roll the trait impl against `TestModel`.
898    struct MockUpdateInput(Vec<(String, WriteOp)>);
899
900    impl crate::inputs::UpdateInput for MockUpdateInput {
901        type Model = TestModel;
902        type Data = crate::inputs::UpdatePayload;
903        fn into_ir(self) -> Self::Data {
904            self.0
905        }
906    }
907
908    #[test]
909    fn with_update_input_appends_set_ops() {
910        let input = MockUpdateInput(vec![
911            (
912                "name".into(),
913                WriteOp::Set(FilterValue::String("Bob".into())),
914            ),
915            ("age".into(), WriteOp::Increment(FilterValue::Int(1))),
916        ]);
917
918        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
919            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
920            .with_update_input(input);
921
922        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
923
924        // `Set` emits the plain `col = $n` form; `Increment` emits the
925        // atomic-operator `col = col + $n` fragment.
926        assert!(sql.contains("name = $1"), "got: {sql}");
927        assert!(sql.contains("age = age + $2"), "got: {sql}");
928        // 2 SET params + 1 WHERE param.
929        assert_eq!(params.len(), 3);
930        assert_eq!(params[0], FilterValue::String("Bob".into()));
931        assert_eq!(params[1], FilterValue::Int(1));
932    }
933
934    #[test]
935    fn with_update_input_unset_emits_null_no_param() {
936        let input = MockUpdateInput(vec![("nickname".into(), WriteOp::Unset)]);
937
938        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
939            .with_update_input(input);
940
941        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
942
943        assert!(sql.contains("nickname = NULL"), "got: {sql}");
944        // Unset emits no placeholder, so no parameter is pushed.
945        assert!(params.is_empty(), "expected no params, got: {params:?}");
946    }
947
948    #[test]
949    fn update_many_with_update_input_appends() {
950        let input = MockUpdateInput(vec![(
951            "name".into(),
952            WriteOp::Set(FilterValue::String("Bob".into())),
953        )]);
954
955        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
956            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
957            .with_update_input(input);
958
959        let (sql, params) = op.build_sql(&crate::dialect::Postgres);
960
961        assert!(sql.contains("UPDATE test_models SET"));
962        assert!(sql.contains("name = $1"), "got: {sql}");
963        assert!(sql.contains("WHERE"));
964        assert_eq!(params.len(), 2);
965    }
966
967    // ========== Phase 5c: nested-write wiring on update! ==========
968
969    use std::sync::{Arc, Mutex};
970
971    type StatementLog = Arc<Mutex<Vec<(String, Vec<FilterValue>)>>>;
972
973    /// Recording engine mirroring `nested.rs::tests::RecordingEngine`.
974    /// Captures every (sql, params) on `execute_raw`, returns the next
975    /// entry of `affected` (or 1 as fallback), and `execute_update`
976    /// records the parent UPDATE too while returning an empty row vec.
977    #[derive(Clone)]
978    struct RecordingEngine {
979        recorded: StatementLog,
980        affected: Arc<Mutex<Vec<u64>>>,
981    }
982
983    impl RecordingEngine {
984        fn new() -> Self {
985            Self {
986                recorded: Arc::new(Mutex::new(Vec::new())),
987                affected: Arc::new(Mutex::new(Vec::new())),
988            }
989        }
990
991        fn statements(&self) -> Vec<(String, Vec<FilterValue>)> {
992            self.recorded.lock().unwrap().clone()
993        }
994    }
995
    // Marker opt-in: lets `RecordingEngine` be used with the
    // `.with(NestedWriteOp::..)` tests in this module.
    // NOTE(review): presumably the nested-write path on `UpdateOperation`
    // is gated on this capability bound — confirm in `crate::capabilities`.
    impl crate::capabilities::SupportsNestedWrites for RecordingEngine {}
997
998    impl QueryEngine for RecordingEngine {
999        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
1000            &crate::dialect::Postgres
1001        }
1002
1003        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
1004            &self,
1005            _sql: &str,
1006            _params: Vec<FilterValue>,
1007        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
1008            Box::pin(async { Ok(Vec::new()) })
1009        }
1010
1011        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
1012            &self,
1013            _sql: &str,
1014            _params: Vec<FilterValue>,
1015        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
1016            Box::pin(async { Err(QueryError::not_found("test")) })
1017        }
1018
1019        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
1020            &self,
1021            _sql: &str,
1022            _params: Vec<FilterValue>,
1023        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
1024            Box::pin(async { Ok(None) })
1025        }
1026
1027        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
1028            &self,
1029            _sql: &str,
1030            _params: Vec<FilterValue>,
1031        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
1032            Box::pin(async { Err(QueryError::not_found("test")) })
1033        }
1034
1035        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
1036            &self,
1037            sql: &str,
1038            params: Vec<FilterValue>,
1039        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
1040            let recorded = self.recorded.clone();
1041            let sql = sql.to_string();
1042            Box::pin(async move {
1043                recorded.lock().unwrap().push((sql, params));
1044                Ok(Vec::new())
1045            })
1046        }
1047
1048        fn execute_delete(
1049            &self,
1050            _sql: &str,
1051            _params: Vec<FilterValue>,
1052        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
1053            Box::pin(async { Ok(0) })
1054        }
1055
1056        fn execute_raw(
1057            &self,
1058            sql: &str,
1059            params: Vec<FilterValue>,
1060        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
1061            let recorded = self.recorded.clone();
1062            let affected = self.affected.clone();
1063            let sql_string = sql.to_string();
1064            let default = if sql.contains(" IN (") {
1065                (params.len() as u64).saturating_sub(1)
1066            } else {
1067                1
1068            };
1069            Box::pin(async move {
1070                recorded.lock().unwrap().push((sql_string, params));
1071                Ok(affected.lock().unwrap().pop().unwrap_or(default))
1072            })
1073        }
1074
1075        fn count(
1076            &self,
1077            _sql: &str,
1078            _params: Vec<FilterValue>,
1079        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
1080            Box::pin(async { Ok(0) })
1081        }
1082    }
1083
1084    #[tokio::test]
1085    async fn update_with_nested_create_runs_parent_then_child_insert() {
1086        let engine = RecordingEngine::new();
1087        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
1088            .r#where(Filter::Equals("id".into(), FilterValue::Int(7)))
1089            .set("name", "Renamed")
1090            .with(NestedWriteOp::Create {
1091                relation: "posts",
1092                target_table: "posts",
1093                foreign_key: "author_id",
1094                payload: vec![vec![("title".into(), FilterValue::String("p1".into()))]],
1095            });
1096
1097        let _ = op.exec().await.expect("update + nested create");
1098
1099        let stmts = engine.statements();
1100        assert_eq!(
1101            stmts.len(),
1102            2,
1103            "parent UPDATE + nested child INSERT; got {stmts:#?}"
1104        );
1105        assert!(
1106            stmts[0].0.contains("UPDATE test_models"),
1107            "got: {}",
1108            stmts[0].0
1109        );
1110        assert!(stmts[1].0.contains("INSERT INTO"), "got: {}", stmts[1].0);
1111        assert!(stmts[1].0.contains("posts"), "got: {}", stmts[1].0);
1112        assert!(stmts[1].0.contains("author_id"), "got: {}", stmts[1].0);
1113    }
1114
1115    #[tokio::test]
1116    async fn update_with_nested_disconnect_emits_set_null_update() {
1117        let engine = RecordingEngine::new();
1118        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
1119            .r#where(Filter::Equals("id".into(), FilterValue::Int(7)))
1120            .set("name", "Renamed")
1121            .with(NestedWriteOp::Disconnect {
1122                relation: "posts",
1123                target_table: "posts",
1124                foreign_key: "author_id",
1125                target_pk: "id",
1126                pk: FilterValue::Int(42),
1127            });
1128
1129        let _ = op.exec().await.expect("update + nested disconnect");
1130
1131        let stmts = engine.statements();
1132        assert_eq!(stmts.len(), 2, "got {stmts:#?}");
1133        assert!(
1134            stmts[0].0.contains("UPDATE test_models"),
1135            "got: {}",
1136            stmts[0].0
1137        );
1138        let (sql, params) = &stmts[1];
1139        assert!(sql.contains("UPDATE"), "got: {sql}");
1140        assert!(sql.contains("posts"), "got: {sql}");
1141        assert!(sql.contains("author_id"), "got: {sql}");
1142        assert!(sql.contains("NULL"), "got: {sql}");
1143        assert_eq!(params, &vec![FilterValue::Int(42)]);
1144    }
1145
1146    #[tokio::test]
1147    async fn update_nested_requires_pk_in_where_filter() {
1148        let engine = RecordingEngine::new();
1149        // `email` is not the PK column for TestModel — the executor must
1150        // refuse the nested-write path with a clear diagnostic.
1151        let op = UpdateOperation::<RecordingEngine, TestModel>::new(engine.clone())
1152            .r#where(Filter::Equals(
1153                "email".into(),
1154                FilterValue::String("a@x.com".into()),
1155            ))
1156            .set("name", "Renamed")
1157            .with(NestedWriteOp::Disconnect {
1158                relation: "posts",
1159                target_table: "posts",
1160                foreign_key: "author_id",
1161                target_pk: "id",
1162                pk: FilterValue::Int(42),
1163            });
1164
1165        let result = op.exec().await;
1166        let err = result.err().expect("non-PK where must error");
1167        let msg = err.to_string();
1168        assert!(
1169            msg.contains("primary-key column") || msg.contains("primary key"),
1170            "expected PK-required diagnostic, got: {msg}"
1171        );
1172        // Nothing should have been emitted to the engine.
1173        assert!(
1174            engine.statements().is_empty(),
1175            "no SQL should run: {:#?}",
1176            engine.statements()
1177        );
1178    }
1179}