Skip to main content

prax_query/operations/
update.rs

1//! Update operation for modifying existing records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// An update operation for modifying existing records.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let users = client
16///     .user()
17///     .update()
18///     .r#where(user::id::equals(1))
19///     .set("name", "Updated Name")
20///     .exec()
21///     .await?;
22/// ```
23pub struct UpdateOperation<E: QueryEngine, M: Model> {
24    engine: E,
25    filter: Filter,
26    updates: Vec<(String, FilterValue)>,
27    select: Select,
28    _model: PhantomData<M>,
29}
30
31impl<E: QueryEngine, M: Model + crate::row::FromRow> UpdateOperation<E, M> {
32    /// Create a new Update operation.
33    pub fn new(engine: E) -> Self {
34        Self {
35            engine,
36            filter: Filter::None,
37            updates: Vec::new(),
38            select: Select::All,
39            _model: PhantomData,
40        }
41    }
42
43    /// Add a filter condition.
44    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
45        let new_filter = filter.into();
46        self.filter = self.filter.and_then(new_filter);
47        self
48    }
49
50    /// Set a column to a new value.
51    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
52        self.updates.push((column.into(), value.into()));
53        self
54    }
55
56    /// Set multiple columns from an iterator.
57    pub fn set_many(
58        mut self,
59        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
60    ) -> Self {
61        for (col, val) in values {
62            self.updates.push((col.into(), val.into()));
63        }
64        self
65    }
66
67    /// Increment a numeric column.
68    pub fn increment(self, column: impl Into<String>, amount: i64) -> Self {
69        // This would need special handling in SQL generation
70        // For now, we'll implement a basic version
71        self.set(column, FilterValue::Int(amount))
72    }
73
74    /// Select specific fields to return.
75    pub fn select(mut self, select: impl Into<Select>) -> Self {
76        self.select = select.into();
77        self
78    }
79
80    /// Build the SQL query.
81    pub fn build_sql(
82        &self,
83        dialect: &dyn crate::dialect::SqlDialect,
84    ) -> (String, Vec<FilterValue>) {
85        let mut sql = String::new();
86        let mut params = Vec::new();
87        let mut param_idx = 1;
88
89        // UPDATE clause
90        sql.push_str("UPDATE ");
91        sql.push_str(M::TABLE_NAME);
92
93        // SET clause
94        sql.push_str(" SET ");
95        let set_parts: Vec<_> = self
96            .updates
97            .iter()
98            .map(|(col, val)| {
99                params.push(val.clone());
100                let part = format!("{} = {}", col, dialect.placeholder(param_idx));
101                param_idx += 1;
102                part
103            })
104            .collect();
105        sql.push_str(&set_parts.join(", "));
106
107        // WHERE clause
108        if !self.filter.is_none() {
109            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
110            sql.push_str(" WHERE ");
111            sql.push_str(&where_sql);
112            params.extend(where_params);
113        }
114
115        // RETURNING clause
116        sql.push_str(&dialect.returning_clause(&self.select.to_sql()));
117
118        (sql, params)
119    }
120
121    /// Execute the update and return modified records.
122    pub async fn exec(self) -> QueryResult<Vec<M>>
123    where
124        M: Send + 'static,
125    {
126        let dialect = self.engine.dialect();
127        let (sql, params) = self.build_sql(dialect);
128        self.engine.execute_update::<M>(&sql, params).await
129    }
130
131    /// Execute the update and return the first modified record.
132    pub async fn exec_one(self) -> QueryResult<M>
133    where
134        M: Send + 'static,
135    {
136        let dialect = self.engine.dialect();
137        let (sql, params) = self.build_sql(dialect);
138        self.engine.query_one::<M>(&sql, params).await
139    }
140}
141
142/// Update many records at once.
143pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
144    engine: E,
145    filter: Filter,
146    updates: Vec<(String, FilterValue)>,
147    _model: PhantomData<M>,
148}
149
150impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
151    /// Create a new UpdateMany operation.
152    pub fn new(engine: E) -> Self {
153        Self {
154            engine,
155            filter: Filter::None,
156            updates: Vec::new(),
157            _model: PhantomData,
158        }
159    }
160
161    /// Add a filter condition.
162    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
163        let new_filter = filter.into();
164        self.filter = self.filter.and_then(new_filter);
165        self
166    }
167
168    /// Set a column to a new value.
169    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
170        self.updates.push((column.into(), value.into()));
171        self
172    }
173
174    /// Build the SQL query.
175    pub fn build_sql(
176        &self,
177        dialect: &dyn crate::dialect::SqlDialect,
178    ) -> (String, Vec<FilterValue>) {
179        let mut sql = String::new();
180        let mut params = Vec::new();
181        let mut param_idx = 1;
182
183        // UPDATE clause
184        sql.push_str("UPDATE ");
185        sql.push_str(M::TABLE_NAME);
186
187        // SET clause
188        sql.push_str(" SET ");
189        let set_parts: Vec<_> = self
190            .updates
191            .iter()
192            .map(|(col, val)| {
193                params.push(val.clone());
194                let part = format!("{} = {}", col, dialect.placeholder(param_idx));
195                param_idx += 1;
196                part
197            })
198            .collect();
199        sql.push_str(&set_parts.join(", "));
200
201        // WHERE clause
202        if !self.filter.is_none() {
203            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
204            sql.push_str(" WHERE ");
205            sql.push_str(&where_sql);
206            params.extend(where_params);
207        }
208
209        (sql, params)
210    }
211
212    /// Execute the update and return the count of modified records.
213    pub async fn exec(self) -> QueryResult<u64> {
214        let dialect = self.engine.dialect();
215        let (sql, params) = self.build_sql(dialect);
216        self.engine.execute_raw(&sql, params).await
217    }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialect::Postgres;
    use crate::error::QueryError;
    use crate::row::FromRow;
    use crate::traits::BoxFuture;
    use crate::types::Select;

    /// Boxed future returned by every mocked engine method.
    type Pending<'a, T> = BoxFuture<'a, QueryResult<T>>;
    /// Single-record update bound to the mock engine and test model.
    type Update = UpdateOperation<MockEngine, TestModel>;
    /// Bulk update bound to the mock engine and test model.
    type UpdateMany = UpdateManyOperation<MockEngine, TestModel>;

    /// Minimal model: only the associated constants matter to these tests.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    impl FromRow for TestModel {
        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
            Ok(TestModel)
        }
    }

    /// Engine stub: never touches a database, answers with canned results.
    #[derive(Clone)]
    struct MockEngine {
        // Value reported by `execute_raw`.
        return_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self { return_count: 0 }
        }

        fn with_count(count: u64) -> Self {
            Self {
                return_count: count,
            }
        }
    }

    impl QueryEngine for MockEngine {
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &Postgres
        }

        fn query_many<T: Model + FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> Pending<'_, Vec<T>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> Pending<'_, T> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> Pending<'_, Option<T>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> Pending<'_, T> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> Pending<'_, Vec<T>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(&self, _sql: &str, _params: Vec<FilterValue>) -> Pending<'_, u64> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(&self, _sql: &str, _params: Vec<FilterValue>) -> Pending<'_, u64> {
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(&self, _sql: &str, _params: Vec<FilterValue>) -> Pending<'_, u64> {
            Box::pin(async { Ok(0) })
        }
    }

    /// Fresh single-record update on a default mock engine.
    fn update() -> Update {
        Update::new(MockEngine::new())
    }

    /// Fresh bulk update on a default mock engine.
    fn update_many() -> UpdateMany {
        UpdateMany::new(MockEngine::new())
    }

    // ========== UpdateOperation Tests ==========

    #[test]
    fn test_update_new() {
        let (sql, params) = update().build_sql(&Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_basic() {
        let (sql, params) = update()
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated")
            .build_sql(&Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let (sql, params) = update()
            .set("name", "Updated")
            .set("email", "updated@example.com")
            .build_sql(&Postgres);

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_set_many() {
        let assignments = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];

        let (sql, params) = update().set_many(assignments).build_sql(&Postgres);

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_increment() {
        let (sql, params) = update().increment("counter", 5).build_sql(&Postgres);

        assert!(sql.contains("counter = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let (sql, _) = update()
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]))
            .build_sql(&Postgres);

        assert!(sql.contains("RETURNING id, name"));
    }

    #[test]
    fn test_update_with_complex_filter() {
        let is_active = Filter::Equals(
            "status".into(),
            FilterValue::String("active".to_string()),
        );
        let is_adult = Filter::Gt("age".into(), FilterValue::Int(18));

        let (sql, params) = update()
            .r#where(is_active)
            .r#where(is_adult)
            .set("verified", FilterValue::Bool(true))
            .build_sql(&Postgres);

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        // One SET value plus two filter values.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_without_filter() {
        let (sql, _) = update().set("status", "updated").build_sql(&Postgres);

        // No filter was added, so no WHERE clause may appear.
        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    #[test]
    fn test_update_with_null_value() {
        let (sql, params) = update()
            .set("deleted_at", FilterValue::Null)
            .build_sql(&Postgres);

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let (_sql, params) = update()
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false))
            .build_sql(&Postgres);

        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    #[tokio::test]
    async fn test_update_exec() {
        let outcome = update().set("name", "Updated").exec().await;

        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_exec_one() {
        let outcome = update()
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated")
            .exec_one()
            .await;

        // The mock engine's `query_one` always answers with not_found.
        assert!(outcome.is_err());
    }

    // ========== UpdateManyOperation Tests ==========

    #[test]
    fn test_update_many_new() {
        let (sql, params) = update_many().build_sql(&Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        // Bulk updates report a count, so nothing is returned.
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let ids = vec![
            FilterValue::Int(1),
            FilterValue::Int(2),
            FilterValue::Int(3),
        ];

        let (sql, params) = update_many()
            .r#where(Filter::In("id".into(), ids))
            .set("status", "processed")
            .build_sql(&Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        // One SET value plus three IN values.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let in_engineering = Filter::Equals(
            "department".into(),
            FilterValue::String("engineering".to_string()),
        );
        let is_active = Filter::Equals("active".into(), FilterValue::Bool(true));

        let (sql, params) = update_many()
            .r#where(in_engineering)
            .r#where(is_active)
            .set("reviewed", FilterValue::Bool(true))
            .build_sql(&Postgres);

        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let (sql, _) = update_many()
            .set("reset_password", FilterValue::Bool(true))
            .build_sql(&Postgres);

        assert!(!sql.contains("WHERE"));
    }

    #[tokio::test]
    async fn test_update_many_exec() {
        let outcome = UpdateMany::new(MockEngine::with_count(5))
            .set("status", "updated")
            .exec()
            .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 5);
    }

    // ========== SQL Generation Edge Cases ==========

    #[test]
    fn test_update_param_ordering() {
        let (sql, params) = update()
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .build_sql(&Postgres);

        // SET placeholders are numbered first; the filter continues after them.
        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains(r#""id" = $3"#));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let (sql, params) = update_many()
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .build_sql(&Postgres);

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains(r#""id" = $2"#));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_float_value() {
        let (sql, params) = update()
            .set("price", FilterValue::Float(99.99))
            .build_sql(&Postgres);

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_update_with_json_value() {
        let payload = serde_json::json!({"key": "value"});

        let (sql, params) = update()
            .set("metadata", FilterValue::Json(payload.clone()))
            .build_sql(&Postgres);

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(payload));
    }
}