prax_query/operations/
update.rs

1//! Update operation for modifying existing records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// An update operation for modifying existing records.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let users = client
16///     .user()
17///     .update()
18///     .r#where(user::id::equals(1))
19///     .set("name", "Updated Name")
20///     .exec()
21///     .await?;
22/// ```
23pub struct UpdateOperation<E: QueryEngine, M: Model> {
24    engine: E,
25    filter: Filter,
26    updates: Vec<(String, FilterValue)>,
27    select: Select,
28    _model: PhantomData<M>,
29}
30
31impl<E: QueryEngine, M: Model> UpdateOperation<E, M> {
32    /// Create a new Update operation.
33    pub fn new(engine: E) -> Self {
34        Self {
35            engine,
36            filter: Filter::None,
37            updates: Vec::new(),
38            select: Select::All,
39            _model: PhantomData,
40        }
41    }
42
43    /// Add a filter condition.
44    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
45        let new_filter = filter.into();
46        self.filter = self.filter.and_then(new_filter);
47        self
48    }
49
50    /// Set a column to a new value.
51    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
52        self.updates.push((column.into(), value.into()));
53        self
54    }
55
56    /// Set multiple columns from an iterator.
57    pub fn set_many(
58        mut self,
59        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
60    ) -> Self {
61        for (col, val) in values {
62            self.updates.push((col.into(), val.into()));
63        }
64        self
65    }
66
67    /// Increment a numeric column.
68    pub fn increment(self, column: impl Into<String>, amount: i64) -> Self {
69        // This would need special handling in SQL generation
70        // For now, we'll implement a basic version
71        self.set(column, FilterValue::Int(amount))
72    }
73
74    /// Select specific fields to return.
75    pub fn select(mut self, select: impl Into<Select>) -> Self {
76        self.select = select.into();
77        self
78    }
79
80    /// Build the SQL query.
81    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
82        let mut sql = String::new();
83        let mut params = Vec::new();
84        let mut param_idx = 1;
85
86        // UPDATE clause
87        sql.push_str("UPDATE ");
88        sql.push_str(M::TABLE_NAME);
89
90        // SET clause
91        sql.push_str(" SET ");
92        let set_parts: Vec<_> = self
93            .updates
94            .iter()
95            .map(|(col, val)| {
96                params.push(val.clone());
97                let part = format!("{} = ${}", col, param_idx);
98                param_idx += 1;
99                part
100            })
101            .collect();
102        sql.push_str(&set_parts.join(", "));
103
104        // WHERE clause
105        if !self.filter.is_none() {
106            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
107            sql.push_str(" WHERE ");
108            sql.push_str(&where_sql);
109            params.extend(where_params);
110        }
111
112        // RETURNING clause
113        sql.push_str(" RETURNING ");
114        sql.push_str(&self.select.to_sql());
115
116        (sql, params)
117    }
118
119    /// Execute the update and return modified records.
120    pub async fn exec(self) -> QueryResult<Vec<M>>
121    where
122        M: Send + 'static,
123    {
124        let (sql, params) = self.build_sql();
125        self.engine.execute_update::<M>(&sql, params).await
126    }
127
128    /// Execute the update and return the first modified record.
129    pub async fn exec_one(self) -> QueryResult<M>
130    where
131        M: Send + 'static,
132    {
133        let (sql, params) = self.build_sql();
134        self.engine.query_one::<M>(&sql, params).await
135    }
136}
137
138/// Update many records at once.
139pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
140    engine: E,
141    filter: Filter,
142    updates: Vec<(String, FilterValue)>,
143    _model: PhantomData<M>,
144}
145
146impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
147    /// Create a new UpdateMany operation.
148    pub fn new(engine: E) -> Self {
149        Self {
150            engine,
151            filter: Filter::None,
152            updates: Vec::new(),
153            _model: PhantomData,
154        }
155    }
156
157    /// Add a filter condition.
158    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
159        let new_filter = filter.into();
160        self.filter = self.filter.and_then(new_filter);
161        self
162    }
163
164    /// Set a column to a new value.
165    pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
166        self.updates.push((column.into(), value.into()));
167        self
168    }
169
170    /// Build the SQL query.
171    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
172        let mut sql = String::new();
173        let mut params = Vec::new();
174        let mut param_idx = 1;
175
176        // UPDATE clause
177        sql.push_str("UPDATE ");
178        sql.push_str(M::TABLE_NAME);
179
180        // SET clause
181        sql.push_str(" SET ");
182        let set_parts: Vec<_> = self
183            .updates
184            .iter()
185            .map(|(col, val)| {
186                params.push(val.clone());
187                let part = format!("{} = ${}", col, param_idx);
188                param_idx += 1;
189                part
190            })
191            .collect();
192        sql.push_str(&set_parts.join(", "));
193
194        // WHERE clause
195        if !self.filter.is_none() {
196            let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
197            sql.push_str(" WHERE ");
198            sql.push_str(&where_sql);
199            params.extend(where_params);
200        }
201
202        (sql, params)
203    }
204
205    /// Execute the update and return the count of modified records.
206    pub async fn exec(self) -> QueryResult<u64> {
207        let (sql, params) = self.build_sql();
208        self.engine.execute_raw(&sql, params).await
209    }
210}
211
#[cfg(test)]
mod tests {
    // Unit tests for SQL generation and execution of the update builders.
    // Execution tests use `MockEngine`, which ignores the SQL and parameters
    // it receives and returns canned results.

    use super::*;
    use crate::error::QueryError;
    use crate::types::Select;

    /// Minimal model mapped to the `test_models` table.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub: row-returning queries yield empty results or `not_found`,
    /// and `execute_raw` reports `return_count` affected rows.
    #[derive(Clone)]
    struct MockEngine {
        return_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self { return_count: 0 }
        }

        fn with_count(count: u64) -> Self {
            Self {
                return_count: count,
            }
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // ========== UpdateOperation Tests ==========

    #[test]
    fn test_update_new() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_basic() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .set("email", "updated@example.com");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_set_many() {
        let updates = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set_many(updates);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_increment() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .increment("counter", 5);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("counter = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, name"));
    }

    #[test]
    fn test_update_with_complex_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
            .set("verified", FilterValue::Bool(true));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3); // 1 set + 2 where
    }

    #[test]
    fn test_update_without_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("status", "updated");

        let (sql, _) = op.build_sql();

        // Should not have WHERE clause
        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    #[test]
    fn test_update_with_null_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("deleted_at", FilterValue::Null);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));

        let (sql, params) = op.build_sql();

        // Previously `sql` was bound but never checked (unused-variable warning).
        assert!(sql.contains("active = $1"));
        assert!(sql.contains("verified = $2"));
        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    #[tokio::test]
    async fn test_update_exec() {
        let op =
            UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_exec_one() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let result = op.exec_one().await;
        assert!(result.is_err()); // MockEngine returns not_found
    }

    // ========== UpdateManyOperation Tests ==========

    #[test]
    fn test_update_many_new() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(!sql.contains("RETURNING")); // UpdateMany doesn't return records
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "id".into(),
                vec![
                    FilterValue::Int(1),
                    FilterValue::Int(2),
                    FilterValue::Int(3),
                ],
            ))
            .set("status", "processed");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        assert_eq!(params.len(), 4); // 1 set + 3 IN values
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "department".into(),
                FilterValue::String("engineering".to_string()),
            ))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .set("reviewed", FilterValue::Bool(true));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("reset_password", FilterValue::Bool(true));

        let (sql, _) = op.build_sql();

        assert!(!sql.contains("WHERE"));
    }

    #[tokio::test]
    async fn test_update_many_exec() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
            .set("status", "updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
    }

    // ========== SQL Generation Edge Cases ==========

    #[test]
    fn test_update_param_ordering() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql();

        // SET params come first, then WHERE params
        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains("id = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("id = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_float_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("price", FilterValue::Float(99.99));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_update_with_json_value() {
        let json_value = serde_json::json!({"key": "value"});
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json_value.clone()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(json_value));
    }
}