prax_query/operations/upsert.rs

1//! Upsert operation for creating or updating records.
2
3use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10/// An upsert (insert or update) operation.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let user = client
16///     .user()
17///     .upsert()
18///     .r#where(user::email::equals("test@example.com"))
19///     .create(user::Create { email: "test@example.com".into(), name: Some("Test".into()) })
20///     .update(user::Update { name: Some("Updated".into()), ..Default::default() })
21///     .exec()
22///     .await?;
23/// ```
24pub struct UpsertOperation<E: QueryEngine, M: Model> {
25    engine: E,
26    filter: Filter,
27    create_columns: Vec<String>,
28    create_values: Vec<FilterValue>,
29    update_columns: Vec<String>,
30    update_values: Vec<FilterValue>,
31    conflict_columns: Vec<String>,
32    select: Select,
33    _model: PhantomData<M>,
34}
35
36impl<E: QueryEngine, M: Model> UpsertOperation<E, M> {
37    /// Create a new Upsert operation.
38    pub fn new(engine: E) -> Self {
39        Self {
40            engine,
41            filter: Filter::None,
42            create_columns: Vec::new(),
43            create_values: Vec::new(),
44            update_columns: Vec::new(),
45            update_values: Vec::new(),
46            conflict_columns: Vec::new(),
47            select: Select::All,
48            _model: PhantomData,
49        }
50    }
51
52    /// Add a filter condition (identifies the record to upsert).
53    pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
54        self.filter = filter.into();
55        self
56    }
57
58    /// Set the columns to check for conflict.
59    pub fn on_conflict(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
60        self.conflict_columns = columns.into_iter().map(Into::into).collect();
61        self
62    }
63
64    /// Set the create data.
65    pub fn create(
66        mut self,
67        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
68    ) -> Self {
69        for (col, val) in values {
70            self.create_columns.push(col.into());
71            self.create_values.push(val.into());
72        }
73        self
74    }
75
76    /// Set a single create column.
77    pub fn create_set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
78        self.create_columns.push(column.into());
79        self.create_values.push(value.into());
80        self
81    }
82
83    /// Set the update data.
84    pub fn update(
85        mut self,
86        values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
87    ) -> Self {
88        for (col, val) in values {
89            self.update_columns.push(col.into());
90            self.update_values.push(val.into());
91        }
92        self
93    }
94
95    /// Set a single update column.
96    pub fn update_set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
97        self.update_columns.push(column.into());
98        self.update_values.push(value.into());
99        self
100    }
101
102    /// Select specific fields to return.
103    pub fn select(mut self, select: impl Into<Select>) -> Self {
104        self.select = select.into();
105        self
106    }
107
108    /// Build the SQL query.
109    pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
110        let mut sql = String::new();
111        let mut params = Vec::new();
112        let mut param_idx = 1;
113
114        // INSERT INTO clause
115        sql.push_str("INSERT INTO ");
116        sql.push_str(M::TABLE_NAME);
117
118        // Columns
119        sql.push_str(" (");
120        sql.push_str(&self.create_columns.join(", "));
121        sql.push(')');
122
123        // VALUES
124        sql.push_str(" VALUES (");
125        let placeholders: Vec<_> = self
126            .create_values
127            .iter()
128            .map(|v| {
129                params.push(v.clone());
130                let p = format!("${}", param_idx);
131                param_idx += 1;
132                p
133            })
134            .collect();
135        sql.push_str(&placeholders.join(", "));
136        sql.push(')');
137
138        // ON CONFLICT
139        sql.push_str(" ON CONFLICT ");
140        if !self.conflict_columns.is_empty() {
141            sql.push('(');
142            sql.push_str(&self.conflict_columns.join(", "));
143            sql.push_str(") ");
144        }
145
146        // DO UPDATE SET
147        if self.update_columns.is_empty() {
148            sql.push_str("DO NOTHING");
149        } else {
150            sql.push_str("DO UPDATE SET ");
151            let update_parts: Vec<_> = self
152                .update_columns
153                .iter()
154                .zip(self.update_values.iter())
155                .map(|(col, val)| {
156                    params.push(val.clone());
157                    let part = format!("{} = ${}", col, param_idx);
158                    param_idx += 1;
159                    part
160                })
161                .collect();
162            sql.push_str(&update_parts.join(", "));
163        }
164
165        // RETURNING clause
166        sql.push_str(" RETURNING ");
167        sql.push_str(&self.select.to_sql());
168
169        (sql, params)
170    }
171
172    /// Execute the upsert and return the record.
173    pub async fn exec(self) -> QueryResult<M>
174    where
175        M: Send + 'static,
176    {
177        let (sql, params) = self.build_sql();
178        self.engine.execute_insert::<M>(&sql, params).await
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;

    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub: reads return empty/none, `execute_insert` always fails.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // ========== Construction Tests ==========

    #[test]
    fn test_upsert_new() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine);
        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_upsert_basic() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test")
            .update_set("name", "Updated");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 3); // 2 create + 1 update
    }

    // ========== Conflict Column Tests ==========

    #[test]
    fn test_upsert_single_conflict_column() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (id)"));
    }

    #[test]
    fn test_upsert_multiple_conflict_columns() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["tenant_id", "email"])
            .create_set("email", "test@example.com")
            .create_set("tenant_id", FilterValue::Int(1));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (tenant_id, email)"));
    }

    #[test]
    fn test_upsert_without_conflict_columns() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .create_set("email", "test@example.com");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT"));
        assert!(!sql.contains("ON CONFLICT ("));
    }

    // ========== Create Tests ==========

    #[test]
    fn test_upsert_create_with_set() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(email, name)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_create_with_iterator() {
        let create_data = vec![
            ("email", FilterValue::String("test@example.com".to_string())),
            ("name", FilterValue::String("Test User".to_string())),
            ("age", FilterValue::Int(25)),
        ];
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create(create_data);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(email, name, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
    }

    // ========== Update Tests ==========

    #[test]
    fn test_upsert_update_with_set() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated Name")
            .update_set("updated_at", "2024-01-01");

        let (sql, params) = op.build_sql();

        // Update placeholders continue numbering after the single create value.
        assert!(sql.contains("DO UPDATE SET name = $2, updated_at = $3"));
        assert_eq!(params.len(), 3); // 1 create + 2 update
    }

    #[test]
    fn test_upsert_update_with_iterator() {
        let update_data = vec![
            ("name", FilterValue::String("Updated".to_string())),
            ("status", FilterValue::String("active".to_string())),
        ];
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .update(update_data);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO UPDATE SET"));
        assert_eq!(params.len(), 3); // 1 create + 2 update
    }

    // ========== Do Nothing Tests ==========

    #[test]
    fn test_upsert_do_nothing() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert!(!sql.contains("DO UPDATE"));
    }

    #[test]
    fn test_upsert_do_nothing_multiple_create() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert_eq!(params.len(), 2);
    }

    // ========== Select Tests ==========

    #[test]
    fn test_upsert_with_select() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id", "email"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, email"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_upsert_select_all() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .select(Select::All);

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING *"));
    }

    // ========== Where Filter Tests ==========

    #[test]
    fn test_upsert_with_where() {
        let with_filter = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("test@example.com".to_string()),
            ))
            .on_conflict(["email"])
            .create_set("email", "test@example.com");
        let without_filter = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        // `r#where` only records the filter; it must not change the generated
        // SQL or its parameters.
        assert_eq!(with_filter.build_sql(), without_filter.build_sql());
    }

    // ========== SQL Structure Tests ==========

    #[test]
    fn test_upsert_sql_structure() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id"]));

        let (sql, _) = op.build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let conflict_pos = sql.find("ON CONFLICT").unwrap();
        let update_pos = sql.find("DO UPDATE SET").unwrap();
        let returning_pos = sql.find("RETURNING").unwrap();

        assert!(insert_pos < values_pos);
        assert!(values_pos < conflict_pos);
        assert!(conflict_pos < update_pos);
        assert!(update_pos < returning_pos);
    }

    #[test]
    fn test_upsert_table_name() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine);
        let (sql, _) = op.build_sql();

        assert!(sql.contains("test_models"));
    }

    // ========== Param Ordering Tests ==========

    #[test]
    fn test_upsert_param_ordering() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "create@test.com")
            .create_set("name", "Create Name")
            .update_set("name", "Update Name");

        let (sql, params) = op.build_sql();

        // Create params first, then update params
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("name = $3"));
        assert_eq!(params.len(), 3);
        assert_eq!(
            params[0],
            FilterValue::String("create@test.com".to_string())
        );
        assert_eq!(params[1], FilterValue::String("Create Name".to_string()));
        assert_eq!(params[2], FilterValue::String("Update Name".to_string()));
    }

    // ========== Async Execution Tests ==========

    #[tokio::test]
    async fn test_upsert_exec() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let result = op.exec().await;

        // MockEngine returns not_found for execute_insert
        assert!(result.is_err());
    }

    // ========== Method Chaining Tests ==========

    #[test]
    fn test_upsert_full_chain() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("test@example.com".to_string()),
            ))
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User")
            .update_set("name", "Updated User")
            .select(Select::fields(["id", "name", "email"]));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("RETURNING id, name, email"));
        assert_eq!(params.len(), 3);
    }

    // ========== Value Type Tests ==========

    #[test]
    fn test_upsert_with_null_value() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("nickname", FilterValue::Null);

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_upsert_with_boolean_value() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("active", FilterValue::Bool(true))
            .update_set("active", FilterValue::Bool(false));

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Bool(true));
        assert_eq!(params[2], FilterValue::Bool(false));
    }

    #[test]
    fn test_upsert_with_numeric_values() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("score", FilterValue::Float(99.5));

        let (_, params) = op.build_sql();

        assert_eq!(params[0], FilterValue::Int(1));
        assert_eq!(params[1], FilterValue::Float(99.5));
    }

    #[test]
    fn test_upsert_with_json_value() {
        let json = serde_json::json!({"key": "value"});
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("metadata", FilterValue::Json(json.clone()));

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Json(json));
    }
}
618}