prax_query/operations/
upsert.rs

1//! Upsert operation for creating or updating records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// An upsert (insert or update) operation.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let user = client
16///     .user()
17///     .upsert()
18///     .r#where(user::email::equals("test@example.com"))
19///     .create(user::Create { email: "test@example.com".into(), name: Some("Test".into()) })
20///     .update(user::Update { name: Some("Updated".into()), ..Default::default() })
21///     .exec()
22///     .await?;
23/// ```
24pub struct UpsertOperation<E: QueryEngine, M: Model> {
25    engine: E,
26    filter: Filter,
27    create_columns: Vec<String>,
28    create_values: Vec<FilterValue>,
29    update_columns: Vec<String>,
30    update_values: Vec<FilterValue>,
31    conflict_columns: Vec<String>,
32    select: Select,
33    _model: PhantomData<M>,
34}
35
36impl<E: QueryEngine, M: Model> UpsertOperation<E, M> {
37    /// Create a new Upsert operation.
38    pub fn new(engine: E) -> Self {
39        Self {
40            engine,
41            filter: Filter::None,
42            create_columns: Vec::new(),
43            create_values: Vec::new(),
44            update_columns: Vec::new(),
45            update_values: Vec::new(),
46            conflict_columns: Vec::new(),
47            select: Select::All,
48            _model: PhantomData,
49        }
50    }
51
52    /// Add a filter condition (identifies the record to upsert).
53    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
54        self.filter = filter.into();
55        self
56    }
57
58    /// Set the columns to check for conflict.
59    pub fn on_conflict(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
60        self.conflict_columns = columns.into_iter().map(Into::into).collect();
61        self
62    }
63
64    /// Set the create data.
65    pub fn create(
66        mut self,
67        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
68    ) -> Self {
69        for (col, val) in values {
70            self.create_columns.push(col.into());
71            self.create_values.push(val.into());
72        }
73        self
74    }
75
76    /// Set a single create column.
77    pub fn create_set(
78        mut self,
79        column: impl Into<String>,
80        value: impl Into<FilterValue>,
81    ) -> Self {
82        self.create_columns.push(column.into());
83        self.create_values.push(value.into());
84        self
85    }
86
87    /// Set the update data.
88    pub fn update(
89        mut self,
90        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
91    ) -> Self {
92        for (col, val) in values {
93            self.update_columns.push(col.into());
94            self.update_values.push(val.into());
95        }
96        self
97    }
98
99    /// Set a single update column.
100    pub fn update_set(
101        mut self,
102        column: impl Into<String>,
103        value: impl Into<FilterValue>,
104    ) -> Self {
105        self.update_columns.push(column.into());
106        self.update_values.push(value.into());
107        self
108    }
109
110    /// Select specific fields to return.
111    pub fn select(mut self, select: impl Into<Select>) -> Self {
112        self.select = select.into();
113        self
114    }
115
116    /// Build the SQL query.
117    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
118        let mut sql = String::new();
119        let mut params = Vec::new();
120        let mut param_idx = 1;
121
122        // INSERT INTO clause
123        sql.push_str("INSERT INTO ");
124        sql.push_str(M::TABLE_NAME);
125
126        // Columns
127        sql.push_str(" (");
128        sql.push_str(&self.create_columns.join(", "));
129        sql.push(')');
130
131        // VALUES
132        sql.push_str(" VALUES (");
133        let placeholders: Vec<_> = self
134            .create_values
135            .iter()
136            .map(|v| {
137                params.push(v.clone());
138                let p = format!("${}", param_idx);
139                param_idx += 1;
140                p
141            })
142            .collect();
143        sql.push_str(&placeholders.join(", "));
144        sql.push(')');
145
146        // ON CONFLICT
147        sql.push_str(" ON CONFLICT ");
148        if !self.conflict_columns.is_empty() {
149            sql.push('(');
150            sql.push_str(&self.conflict_columns.join(", "));
151            sql.push_str(") ");
152        }
153
154        // DO UPDATE SET
155        if self.update_columns.is_empty() {
156            sql.push_str("DO NOTHING");
157        } else {
158            sql.push_str("DO UPDATE SET ");
159            let update_parts: Vec<_> = self
160                .update_columns
161                .iter()
162                .zip(self.update_values.iter())
163                .map(|(col, val)| {
164                    params.push(val.clone());
165                    let part = format!("{} = ${}", col, param_idx);
166                    param_idx += 1;
167                    part
168                })
169                .collect();
170            sql.push_str(&update_parts.join(", "));
171        }
172
173        // RETURNING clause
174        sql.push_str(" RETURNING ");
175        sql.push_str(&self.select.to_sql());
176
177        (sql, params)
178    }
179
180    /// Execute the upsert and return the record.
181    pub async fn exec(self) -> QueryResult<M>
182    where
183        M: Send + 'static,
184    {
185        let (sql, params) = self.build_sql();
186        self.engine.execute_insert::<M>(&sql, params).await
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;

    /// Minimal model used to exercise SQL generation.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub: list/optional/count-style calls return empty results,
    /// while `query_one` and `execute_insert` always fail with `not_found`.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    /// Shorthand for a fresh upsert against the mock engine and test model.
    fn upsert() -> UpsertOperation<MockEngine, TestModel> {
        UpsertOperation::new(MockEngine)
    }

    // ========== Construction Tests ==========

    #[test]
    fn test_upsert_new() {
        let (sql, params) = upsert().build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT"));
        assert!(sql.contains("DO NOTHING"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_upsert_basic() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test")
            .update_set("name", "Updated");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models (email, name)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("ON CONFLICT (email) DO UPDATE SET name = $3"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 3); // 2 create + 1 update
    }

    // ========== Conflict Column Tests ==========

    #[test]
    fn test_upsert_single_conflict_column() {
        let op = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (id)"));
    }

    #[test]
    fn test_upsert_multiple_conflict_columns() {
        let op = upsert()
            .on_conflict(["tenant_id", "email"])
            .create_set("email", "test@example.com")
            .create_set("tenant_id", FilterValue::Int(1));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (tenant_id, email)"));
    }

    #[test]
    fn test_upsert_without_conflict_columns() {
        let op = upsert().create_set("email", "test@example.com");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert!(!sql.contains("ON CONFLICT ("));
    }

    // ========== Create Tests ==========

    #[test]
    fn test_upsert_create_with_set() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(email, name)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_create_with_iterator() {
        let create_data = vec![
            ("email", FilterValue::String("test@example.com".to_string())),
            ("name", FilterValue::String("Test User".to_string())),
            ("age", FilterValue::Int(25)),
        ];
        let op = upsert().on_conflict(["email"]).create(create_data);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(email, name, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
        assert_eq!(params[2], FilterValue::Int(25));
    }

    // ========== Update Tests ==========

    #[test]
    fn test_upsert_update_with_set() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated Name")
            .update_set("updated_at", "2024-01-01");

        let (sql, params) = op.build_sql();

        // Update placeholders continue numbering after the single create value.
        assert!(sql.contains("DO UPDATE SET name = $2, updated_at = $3"));
        assert_eq!(params.len(), 3); // 1 create + 2 update
    }

    #[test]
    fn test_upsert_update_with_iterator() {
        let update_data = vec![
            ("name", FilterValue::String("Updated".to_string())),
            ("status", FilterValue::String("active".to_string())),
        ];
        let op = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .update(update_data);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO UPDATE SET name = $2, status = $3"));
        assert_eq!(params.len(), 3); // 1 create + 2 update
    }

    // ========== Do Nothing Tests ==========

    #[test]
    fn test_upsert_do_nothing() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (email) DO NOTHING"));
        assert!(!sql.contains("DO UPDATE"));
    }

    #[test]
    fn test_upsert_do_nothing_multiple_create() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert_eq!(params.len(), 2);
    }

    // ========== Select Tests ==========

    #[test]
    fn test_upsert_with_select() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id", "email"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, email"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_upsert_select_all() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .select(Select::All);

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING *"));
    }

    // ========== Where Filter Tests ==========

    #[test]
    fn test_upsert_with_where() {
        // `r#where` stores the filter but does not affect the generated upsert SQL.
        let base = || {
            upsert()
                .on_conflict(["email"])
                .create_set("email", "test@example.com")
        };
        let filtered = base().r#where(Filter::Equals(
            "email".into(),
            FilterValue::String("test@example.com".to_string()),
        ));

        assert_eq!(filtered.build_sql(), base().build_sql());
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_upsert_sql_structure() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id"]));

        let (sql, _) = op.build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let conflict_pos = sql.find("ON CONFLICT").unwrap();
        let update_pos = sql.find("DO UPDATE SET").unwrap();
        let returning_pos = sql.find("RETURNING").unwrap();

        assert!(insert_pos < values_pos);
        assert!(values_pos < conflict_pos);
        assert!(conflict_pos < update_pos);
        assert!(update_pos < returning_pos);
    }

    #[test]
    fn test_upsert_table_name() {
        let (sql, _) = upsert().build_sql();

        assert!(sql.contains("test_models"));
    }

    // ========== Param Ordering Tests ==========

    #[test]
    fn test_upsert_param_ordering() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "create@test.com")
            .create_set("name", "Create Name")
            .update_set("name", "Update Name");

        let (sql, params) = op.build_sql();

        // Create params first, then update params.
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("name = $3"));
        assert_eq!(params.len(), 3);
        assert_eq!(params[0], FilterValue::String("create@test.com".to_string()));
        assert_eq!(params[1], FilterValue::String("Create Name".to_string()));
        assert_eq!(params[2], FilterValue::String("Update Name".to_string()));
    }

    // ========== Async Execution Tests ==========

    #[tokio::test]
    async fn test_upsert_exec() {
        let op = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let result = op.exec().await;

        // MockEngine returns not_found for execute_insert.
        assert!(result.is_err());
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_upsert_full_chain() {
        let op = upsert()
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("test@example.com".to_string()),
            ))
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User")
            .update_set("name", "Updated User")
            .select(Select::fields(["id", "name", "email"]));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models (email, name) VALUES ($1, $2)"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET name = $3"));
        assert!(sql.contains("RETURNING id, name, email"));
        assert_eq!(params.len(), 3);
    }

    // ========== Value Type Tests ==========

    #[test]
    fn test_upsert_with_null_value() {
        let op = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("nickname", FilterValue::Null);

        let (_, params) = op.build_sql();

        assert_eq!(params.len(), 2);
        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_upsert_with_boolean_value() {
        let op = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("active", FilterValue::Bool(true))
            .update_set("active", FilterValue::Bool(false));

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Bool(true));
        assert_eq!(params[2], FilterValue::Bool(false));
    }

    #[test]
    fn test_upsert_with_numeric_values() {
        let op = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("score", FilterValue::Float(99.5));

        let (_, params) = op.build_sql();

        assert_eq!(params[0], FilterValue::Int(1));
        assert_eq!(params[1], FilterValue::Float(99.5));
    }

    #[test]
    fn test_upsert_with_json_value() {
        let json = serde_json::json!({"key": "value"});
        let op = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("metadata", FilterValue::Json(json.clone()));

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Json(json));
    }
}
618