prax_query/operations/
update.rs

1//! Update operation for modifying existing records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// An update operation for modifying existing records.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let users = client
16///     .user()
17///     .update()
18///     .r#where(user::id::equals(1))
19///     .set("name", "Updated Name")
20///     .exec()
21///     .await?;
22/// ```
23pub struct UpdateOperation<E: QueryEngine, M: Model> {
24    engine: E,
25    filter: Filter,
26    updates: Vec<(String, FilterValue)>,
27    select: Select,
28    _model: PhantomData<M>,
29}
30
31impl<E: QueryEngine, M: Model> UpdateOperation<E, M> {
32    /// Create a new Update operation.
33    pub fn new(engine: E) -> Self {
34        Self {
35            engine,
36            filter: Filter::None,
37            updates: Vec::new(),
38            select: Select::All,
39            _model: PhantomData,
40        }
41    }
42
43    /// Add a filter condition.
44    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
45        let new_filter = filter.into();
46        self.filter = self.filter.and_then(new_filter);
47        self
48    }
49
50    /// Set a column to a new value.
51    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
52        self.updates.push((column.into(), value.into()));
53        self
54    }
55
56    /// Set multiple columns from an iterator.
57    pub fn set_many(
58        mut self,
59        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
60    ) -> Self {
61        for (col, val) in values {
62            self.updates.push((col.into(), val.into()));
63        }
64        self
65    }
66
67    /// Increment a numeric column.
68    pub fn increment(self, column: impl Into<String>, amount: i64) -> Self {
69        // This would need special handling in SQL generation
70        // For now, we'll implement a basic version
71        self.set(column, FilterValue::Int(amount))
72    }
73
74    /// Select specific fields to return.
75    pub fn select(mut self, select: impl Into<Select>) -> Self {
76        self.select = select.into();
77        self
78    }
79
80    /// Build the SQL query.
81    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
82        let mut sql = String::new();
83        let mut params = Vec::new();
84        let mut param_idx = 1;
85
86        // UPDATE clause
87        sql.push_str("UPDATE ");
88        sql.push_str(M::TABLE_NAME);
89
90        // SET clause
91        sql.push_str(" SET ");
92        let set_parts: Vec<_> = self
93            .updates
94            .iter()
95            .map(|(col, val)| {
96                params.push(val.clone());
97                let part = format!("{} = ${}", col, param_idx);
98                param_idx += 1;
99                part
100            })
101            .collect();
102        sql.push_str(&set_parts.join(", "));
103
104        // WHERE clause
105        if !self.filter.is_none() {
106            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
107            sql.push_str(" WHERE ");
108            sql.push_str(&where_sql);
109            params.extend(where_params);
110        }
111
112        // RETURNING clause
113        sql.push_str(" RETURNING ");
114        sql.push_str(&self.select.to_sql());
115
116        (sql, params)
117    }
118
119    /// Execute the update and return modified records.
120    pub async fn exec(self) -> QueryResult<Vec<M>>
121    where
122        M: Send + 'static,
123    {
124        let (sql, params) = self.build_sql();
125        self.engine.execute_update::<M>(&sql, params).await
126    }
127
128    /// Execute the update and return the first modified record.
129    pub async fn exec_one(self) -> QueryResult<M>
130    where
131        M: Send + 'static,
132    {
133        let (sql, params) = self.build_sql();
134        self.engine.query_one::<M>(&sql, params).await
135    }
136}
137
138/// Update many records at once.
139pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
140    engine: E,
141    filter: Filter,
142    updates: Vec<(String, FilterValue)>,
143    _model: PhantomData<M>,
144}
145
146impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
147    /// Create a new UpdateMany operation.
148    pub fn new(engine: E) -> Self {
149        Self {
150            engine,
151            filter: Filter::None,
152            updates: Vec::new(),
153            _model: PhantomData,
154        }
155    }
156
157    /// Add a filter condition.
158    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
159        let new_filter = filter.into();
160        self.filter = self.filter.and_then(new_filter);
161        self
162    }
163
164    /// Set a column to a new value.
165    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
166        self.updates.push((column.into(), value.into()));
167        self
168    }
169
170    /// Build the SQL query.
171    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
172        let mut sql = String::new();
173        let mut params = Vec::new();
174        let mut param_idx = 1;
175
176        // UPDATE clause
177        sql.push_str("UPDATE ");
178        sql.push_str(M::TABLE_NAME);
179
180        // SET clause
181        sql.push_str(" SET ");
182        let set_parts: Vec<_> = self
183            .updates
184            .iter()
185            .map(|(col, val)| {
186                params.push(val.clone());
187                let part = format!("{} = ${}", col, param_idx);
188                param_idx += 1;
189                part
190            })
191            .collect();
192        sql.push_str(&set_parts.join(", "));
193
194        // WHERE clause
195        if !self.filter.is_none() {
196            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
197            sql.push_str(" WHERE ");
198            sql.push_str(&where_sql);
199            params.extend(where_params);
200        }
201
202        (sql, params)
203    }
204
205    /// Execute the update and return the count of modified records.
206    pub async fn exec(self) -> QueryResult<u64> {
207        let (sql, params) = self.build_sql();
208        self.engine.execute_raw(&sql, params).await
209    }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use crate::types::Select;

    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub. Row-returning calls yield empty results or a `not_found` error, and
    /// `execute_raw` reports `return_count` as the number of affected rows.
    #[derive(Clone)]
    struct MockEngine {
        return_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self { return_count: 0 }
        }

        fn with_count(count: u64) -> Self {
            Self { return_count: count }
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // ========== UpdateOperation Tests ==========

    #[test]
    fn test_update_new() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_basic() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .set("email", "updated@example.com");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_set_many() {
        let updates = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set_many(updates);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_increment() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .increment("counter", 5);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("counter = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, name"));
    }

    #[test]
    fn test_update_with_complex_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("status".into(), FilterValue::String("active".to_string())))
            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
            .set("verified", FilterValue::Bool(true));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3); // 1 set + 2 where
    }

    #[test]
    fn test_update_without_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("status", "updated");

        let (sql, _) = op.build_sql();

        // Should not have WHERE clause
        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    #[test]
    fn test_update_with_null_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("deleted_at", FilterValue::Null);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("active = $1"));
        assert!(sql.contains("verified = $2"));
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    #[tokio::test]
    async fn test_update_exec() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_exec_one() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let result = op.exec_one().await;
        assert!(result.is_err()); // MockEngine returns not_found
    }

    // ========== UpdateManyOperation Tests ==========

    #[test]
    fn test_update_many_new() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(!sql.contains("RETURNING")); // UpdateMany doesn't return records
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "id".into(),
                vec![FilterValue::Int(1), FilterValue::Int(2), FilterValue::Int(3)],
            ))
            .set("status", "processed");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 4); // 1 set + 3 IN values
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("department".into(), FilterValue::String("engineering".to_string())))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .set("reviewed", FilterValue::Bool(true));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("reset_password", FilterValue::Bool(true));

        let (sql, _) = op.build_sql();

        assert!(!sql.contains("WHERE"));
    }

    #[tokio::test]
    async fn test_update_many_exec() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
            .set("status", "updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
    }

    // ========== SQL Generation Edge Cases ==========

    #[test]
    fn test_update_param_ordering() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql();

        // SET params come first, then WHERE params
        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains("id = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("id = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_float_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("price", FilterValue::Float(99.99));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Float(99.99));
    }

    #[test]
    fn test_update_with_json_value() {
        let json_value = serde_json::json!({"key": "value"});
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json_value.clone()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(json_value));
    }
}
578