Skip to main content

prax_query/
upsert.rs

1//! Upsert and conflict resolution support.
2//!
3//! This module provides types for building upsert operations with conflict
4//! resolution across different database backends.
5//!
6//! # Database Support
7//!
8//! | Feature          | PostgreSQL     | MySQL              | SQLite         | MSSQL   | MongoDB      |
9//! |------------------|----------------|--------------------|----------------|---------|--------------|
10//! | ON CONFLICT      | ✅             | ❌                 | ✅             | ❌      | ❌           |
11//! | ON DUPLICATE KEY | ❌             | ✅                 | ❌             | ❌      | ❌           |
12//! | MERGE statement  | ❌             | ❌                 | ❌             | ✅      | ❌           |
13//! | Native upsert    | ❌             | ❌                 | ❌             | ❌      | ✅ upsert:true|
14//! | Conflict targets | ✅             | ❌ (implicit PK/UK)| ✅             | ✅      | ✅ filter    |
15//!
16//! # Example Usage
17//!
18//! ```rust,ignore
19//! use prax_query::upsert::{Upsert, ConflictTarget, ConflictAction};
20//!
21//! // PostgreSQL: INSERT ... ON CONFLICT (email) DO UPDATE SET ...
22//! let upsert = Upsert::new("users")
23//!     .columns(["email", "name", "updated_at"])
24//!     .values(["$1", "$2", "NOW()"])
25//!     .on_conflict(ConflictTarget::columns(["email"]))
26//!     .do_update(["name", "updated_at"]);
27//! ```
28
29use serde::{Deserialize, Serialize};
30
31use crate::error::{QueryError, QueryResult};
32use crate::sql::DatabaseType;
33
/// An upsert operation (INSERT with conflict handling).
///
/// Holds raw SQL fragments and renders them for a specific dialect via the
/// `to_*_sql` methods. Table, column, value and clause strings are spliced
/// into the generated SQL verbatim (no quoting or escaping), so they must come
/// from trusted code rather than end-user input.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Upsert {
    /// Table name.
    pub table: String,
    /// Columns to insert.
    pub columns: Vec<String>,
    /// Values to insert (expressions or placeholders), paired with `columns`
    /// by position.
    pub values: Vec<String>,
    /// Conflict target specification; `None` renders no explicit target.
    pub conflict_target: Option<ConflictTarget>,
    /// Action to take on conflict.
    pub conflict_action: ConflictAction,
    /// WHERE clause for the conflict update (rendered for PostgreSQL and SQLite).
    pub where_clause: Option<String>,
    /// RETURNING columns (rendered for PostgreSQL and SQLite).
    pub returning: Option<Vec<String>>,
}
52
/// What to match on for conflict detection.
///
/// Not every dialect can express every variant; the per-dialect renderers in
/// this type's `impl` decide how (or whether) each one is emitted.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ConflictTarget {
    /// Match on specific columns (unique constraint), rendered as `(a, b, ...)`.
    Columns(Vec<String>),
    /// Match on a named constraint (`ON CONSTRAINT name` in PostgreSQL).
    Constraint(String),
    /// Match on index expression (PostgreSQL), rendered as `(expr)`.
    IndexExpression(String),
    /// No specific target (MySQL ON DUPLICATE KEY); renders as nothing.
    Implicit,
}
65
66impl ConflictTarget {
67    /// Create a column-based conflict target.
68    pub fn columns<I, S>(cols: I) -> Self
69    where
70        I: IntoIterator<Item = S>,
71        S: Into<String>,
72    {
73        Self::Columns(cols.into_iter().map(Into::into).collect())
74    }
75
76    /// Create a constraint-based conflict target.
77    pub fn constraint(name: impl Into<String>) -> Self {
78        Self::Constraint(name.into())
79    }
80
81    /// Create an index expression conflict target.
82    pub fn index_expression(expr: impl Into<String>) -> Self {
83        Self::IndexExpression(expr.into())
84    }
85
86    /// Generate PostgreSQL ON CONFLICT target.
87    pub fn to_postgres_sql(&self) -> String {
88        match self {
89            Self::Columns(cols) => format!("({})", cols.join(", ")),
90            Self::Constraint(name) => format!("ON CONSTRAINT {}", name),
91            Self::IndexExpression(expr) => format!("({})", expr),
92            Self::Implicit => String::new(),
93        }
94    }
95
96    /// Generate SQLite ON CONFLICT target.
97    pub fn to_sqlite_sql(&self) -> String {
98        match self {
99            Self::Columns(cols) => format!("({})", cols.join(", ")),
100            Self::Constraint(_) | Self::IndexExpression(_) => {
101                // SQLite doesn't support these directly
102                String::new()
103            }
104            Self::Implicit => String::new(),
105        }
106    }
107}
108
/// Action to take when a conflict is detected.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ConflictAction {
    /// Do nothing (ignore the insert). Rendered as `DO NOTHING` for
    /// PostgreSQL/SQLite, `INSERT IGNORE` for MySQL and an insert-only
    /// `MERGE` for MSSQL.
    DoNothing,
    /// Update the existing row as described by the [`UpdateSpec`].
    DoUpdate(UpdateSpec),
}
117
/// Specification for what to update on conflict.
///
/// The PostgreSQL, MySQL and SQLite renderers use `excluded_columns` when it
/// is non-empty and fall back to `assignments` otherwise; the two lists are
/// never combined.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct UpdateSpec {
    /// Explicit column assignments.
    pub assignments: Vec<Assignment>,
    /// Columns to overwrite with the value proposed for insertion
    /// (`EXCLUDED.col` in PostgreSQL, `excluded.col` in SQLite, `VALUES(col)`
    /// in MySQL).
    pub excluded_columns: Vec<String>,
}
126
/// A single column assignment in a conflict update, rendered as
/// `column = value`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Assignment {
    /// Column name (left-hand side).
    pub column: String,
    /// Value expression (right-hand side).
    pub value: AssignmentValue,
}
135
/// Value for an assignment.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum AssignmentValue {
    /// Use the value proposed for insertion (`EXCLUDED.col` / `VALUES(col)`).
    Excluded,
    /// Use a raw SQL expression, inserted verbatim.
    Expression(String),
    /// Use a parameter placeholder. PostgreSQL renders the index as `$n`;
    /// MySQL and SQLite render a plain `?` and ignore the index.
    Param(usize),
}
146
147impl Upsert {
148    /// Create a new upsert for the given table.
149    pub fn new(table: impl Into<String>) -> Self {
150        Self {
151            table: table.into(),
152            columns: Vec::new(),
153            values: Vec::new(),
154            conflict_target: None,
155            conflict_action: ConflictAction::DoNothing,
156            where_clause: None,
157            returning: None,
158        }
159    }
160
161    /// Create an upsert builder.
162    pub fn builder(table: impl Into<String>) -> UpsertBuilder {
163        UpsertBuilder::new(table)
164    }
165
166    /// Set the columns to insert.
167    pub fn columns<I, S>(mut self, cols: I) -> Self
168    where
169        I: IntoIterator<Item = S>,
170        S: Into<String>,
171    {
172        self.columns = cols.into_iter().map(Into::into).collect();
173        self
174    }
175
176    /// Set the values to insert.
177    pub fn values<I, S>(mut self, vals: I) -> Self
178    where
179        I: IntoIterator<Item = S>,
180        S: Into<String>,
181    {
182        self.values = vals.into_iter().map(Into::into).collect();
183        self
184    }
185
186    /// Set the conflict target.
187    pub fn on_conflict(mut self, target: ConflictTarget) -> Self {
188        self.conflict_target = Some(target);
189        self
190    }
191
192    /// Set conflict action to DO NOTHING.
193    pub fn do_nothing(mut self) -> Self {
194        self.conflict_action = ConflictAction::DoNothing;
195        self
196    }
197
198    /// Set conflict action to DO UPDATE for specified columns (using EXCLUDED).
199    pub fn do_update<I, S>(mut self, cols: I) -> Self
200    where
201        I: IntoIterator<Item = S>,
202        S: Into<String>,
203    {
204        self.conflict_action = ConflictAction::DoUpdate(UpdateSpec {
205            assignments: Vec::new(),
206            excluded_columns: cols.into_iter().map(Into::into).collect(),
207        });
208        self
209    }
210
211    /// Set conflict action to DO UPDATE with specific assignments.
212    pub fn do_update_set(mut self, assignments: Vec<Assignment>) -> Self {
213        self.conflict_action = ConflictAction::DoUpdate(UpdateSpec {
214            assignments,
215            excluded_columns: Vec::new(),
216        });
217        self
218    }
219
220    /// Add a WHERE clause for the update (PostgreSQL).
221    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
222        self.where_clause = Some(condition.into());
223        self
224    }
225
226    /// Add RETURNING clause (PostgreSQL).
227    pub fn returning<I, S>(mut self, cols: I) -> Self
228    where
229        I: IntoIterator<Item = S>,
230        S: Into<String>,
231    {
232        self.returning = Some(cols.into_iter().map(Into::into).collect());
233        self
234    }
235
236    /// Generate PostgreSQL INSERT ... ON CONFLICT SQL.
237    pub fn to_postgres_sql(&self) -> String {
238        let mut sql = format!(
239            "INSERT INTO {} ({}) VALUES ({})",
240            self.table,
241            self.columns.join(", "),
242            self.values.join(", ")
243        );
244
245        sql.push_str(" ON CONFLICT ");
246
247        if let Some(ref target) = self.conflict_target {
248            sql.push_str(&target.to_postgres_sql());
249            sql.push(' ');
250        }
251
252        match &self.conflict_action {
253            ConflictAction::DoNothing => {
254                sql.push_str("DO NOTHING");
255            }
256            ConflictAction::DoUpdate(spec) => {
257                sql.push_str("DO UPDATE SET ");
258                let assignments: Vec<String> = if !spec.excluded_columns.is_empty() {
259                    spec.excluded_columns
260                        .iter()
261                        .map(|c| format!("{} = EXCLUDED.{}", c, c))
262                        .collect()
263                } else {
264                    spec.assignments
265                        .iter()
266                        .map(|a| {
267                            let value = match &a.value {
268                                AssignmentValue::Excluded => format!("EXCLUDED.{}", a.column),
269                                AssignmentValue::Expression(expr) => expr.clone(),
270                                AssignmentValue::Param(n) => format!("${}", n),
271                            };
272                            format!("{} = {}", a.column, value)
273                        })
274                        .collect()
275                };
276                sql.push_str(&assignments.join(", "));
277
278                if let Some(ref where_clause) = self.where_clause {
279                    sql.push_str(" WHERE ");
280                    sql.push_str(where_clause);
281                }
282            }
283        }
284
285        if let Some(ref returning) = self.returning {
286            sql.push_str(" RETURNING ");
287            sql.push_str(&returning.join(", "));
288        }
289
290        sql
291    }
292
293    /// Generate MySQL INSERT ... ON DUPLICATE KEY UPDATE SQL.
294    pub fn to_mysql_sql(&self) -> String {
295        let mut sql = format!(
296            "INSERT INTO {} ({}) VALUES ({})",
297            self.table,
298            self.columns.join(", "),
299            self.values.join(", ")
300        );
301
302        match &self.conflict_action {
303            ConflictAction::DoNothing => {
304                // MySQL doesn't have DO NOTHING, use INSERT IGNORE
305                sql = format!(
306                    "INSERT IGNORE INTO {} ({}) VALUES ({})",
307                    self.table,
308                    self.columns.join(", "),
309                    self.values.join(", ")
310                );
311            }
312            ConflictAction::DoUpdate(spec) => {
313                sql.push_str(" ON DUPLICATE KEY UPDATE ");
314                let assignments: Vec<String> = if !spec.excluded_columns.is_empty() {
315                    spec.excluded_columns
316                        .iter()
317                        .map(|c| format!("{} = VALUES({})", c, c))
318                        .collect()
319                } else {
320                    spec.assignments
321                        .iter()
322                        .map(|a| {
323                            let value = match &a.value {
324                                AssignmentValue::Excluded => format!("VALUES({})", a.column),
325                                AssignmentValue::Expression(expr) => expr.clone(),
326                                AssignmentValue::Param(_n) => "?".to_string(),
327                            };
328                            format!("{} = {}", a.column, value)
329                        })
330                        .collect()
331                };
332                sql.push_str(&assignments.join(", "));
333            }
334        }
335
336        sql
337    }
338
339    /// Generate SQLite INSERT ... ON CONFLICT SQL.
340    pub fn to_sqlite_sql(&self) -> String {
341        let mut sql = format!(
342            "INSERT INTO {} ({}) VALUES ({})",
343            self.table,
344            self.columns.join(", "),
345            self.values.join(", ")
346        );
347
348        sql.push_str(" ON CONFLICT");
349
350        if let Some(ref target) = self.conflict_target {
351            let target_sql = target.to_sqlite_sql();
352            if !target_sql.is_empty() {
353                sql.push(' ');
354                sql.push_str(&target_sql);
355            }
356        }
357
358        match &self.conflict_action {
359            ConflictAction::DoNothing => {
360                sql.push_str(" DO NOTHING");
361            }
362            ConflictAction::DoUpdate(spec) => {
363                sql.push_str(" DO UPDATE SET ");
364                let assignments: Vec<String> = if !spec.excluded_columns.is_empty() {
365                    spec.excluded_columns
366                        .iter()
367                        .map(|c| format!("{} = excluded.{}", c, c))
368                        .collect()
369                } else {
370                    spec.assignments
371                        .iter()
372                        .map(|a| {
373                            let value = match &a.value {
374                                AssignmentValue::Excluded => format!("excluded.{}", a.column),
375                                AssignmentValue::Expression(expr) => expr.clone(),
376                                AssignmentValue::Param(_n) => "?".to_string(),
377                            };
378                            format!("{} = {}", a.column, value)
379                        })
380                        .collect()
381                };
382                sql.push_str(&assignments.join(", "));
383
384                if let Some(ref where_clause) = self.where_clause {
385                    sql.push_str(" WHERE ");
386                    sql.push_str(where_clause);
387                }
388            }
389        }
390
391        if let Some(ref returning) = self.returning {
392            sql.push_str(" RETURNING ");
393            sql.push_str(&returning.join(", "));
394        }
395
396        sql
397    }
398
399    /// Generate MSSQL MERGE statement.
400    pub fn to_mssql_sql(&self) -> String {
401        let target = self
402            .conflict_target
403            .as_ref()
404            .and_then(|t| match t {
405                ConflictTarget::Columns(cols) => Some(cols.clone()),
406                _ => None,
407            })
408            .unwrap_or_else(|| vec![self.columns.first().cloned().unwrap_or_default()]);
409
410        let source_cols: Vec<String> = self
411            .columns
412            .iter()
413            .zip(&self.values)
414            .map(|(c, v)| format!("{} AS {}", v, c))
415            .collect();
416
417        let match_conditions: Vec<String> = target
418            .iter()
419            .map(|c| format!("target.{} = source.{}", c, c))
420            .collect();
421
422        let mut sql = format!(
423            "MERGE INTO {} AS target USING (SELECT {}) AS source ON {}",
424            self.table,
425            source_cols.join(", "),
426            match_conditions.join(" AND ")
427        );
428
429        match &self.conflict_action {
430            ConflictAction::DoNothing => {
431                // MSSQL MERGE requires at least one action
432                sql.push_str(" WHEN NOT MATCHED THEN INSERT (");
433                sql.push_str(&self.columns.join(", "));
434                sql.push_str(") VALUES (");
435                let source_refs: Vec<String> = self
436                    .columns
437                    .iter()
438                    .map(|c| format!("source.{}", c))
439                    .collect();
440                sql.push_str(&source_refs.join(", "));
441                sql.push(')');
442            }
443            ConflictAction::DoUpdate(spec) => {
444                sql.push_str(" WHEN MATCHED THEN UPDATE SET ");
445
446                let update_cols = if !spec.excluded_columns.is_empty() {
447                    &spec.excluded_columns
448                } else {
449                    &self.columns
450                };
451
452                let assignments: Vec<String> = update_cols
453                    .iter()
454                    .filter(|c| !target.contains(c))
455                    .map(|c| format!("target.{} = source.{}", c, c))
456                    .collect();
457
458                if assignments.is_empty() {
459                    // Need at least one assignment, use first non-key column
460                    let first_non_key = self.columns.iter().find(|c| !target.contains(*c));
461                    if let Some(col) = first_non_key {
462                        sql.push_str(&format!("target.{} = source.{}", col, col));
463                    } else {
464                        sql.push_str(&format!(
465                            "target.{} = source.{}",
466                            self.columns[0], self.columns[0]
467                        ));
468                    }
469                } else {
470                    sql.push_str(&assignments.join(", "));
471                }
472
473                sql.push_str(" WHEN NOT MATCHED THEN INSERT (");
474                sql.push_str(&self.columns.join(", "));
475                sql.push_str(") VALUES (");
476                let source_refs: Vec<String> = self
477                    .columns
478                    .iter()
479                    .map(|c| format!("source.{}", c))
480                    .collect();
481                sql.push_str(&source_refs.join(", "));
482                sql.push(')');
483            }
484        }
485
486        sql.push(';');
487        sql
488    }
489
490    /// Generate SQL for the specified database type.
491    pub fn to_sql(&self, db_type: DatabaseType) -> String {
492        match db_type {
493            DatabaseType::PostgreSQL => self.to_postgres_sql(),
494            DatabaseType::MySQL => self.to_mysql_sql(),
495            DatabaseType::SQLite => self.to_sqlite_sql(),
496            DatabaseType::MSSQL => self.to_mssql_sql(),
497        }
498    }
499}
500
/// Builder for upsert operations.
///
/// Unlike the chaining methods on [`Upsert`] itself, this builder validates
/// its input when [`UpsertBuilder::build`] is called.
#[derive(Debug, Clone, Default)]
pub struct UpsertBuilder {
    table: String,
    columns: Vec<String>,
    values: Vec<String>,
    conflict_target: Option<ConflictTarget>,
    // `None` means "not set"; `build` falls back to `ConflictAction::DoNothing`.
    conflict_action: Option<ConflictAction>,
    where_clause: Option<String>,
    returning: Option<Vec<String>>,
}
512
513impl UpsertBuilder {
514    /// Create a new builder.
515    pub fn new(table: impl Into<String>) -> Self {
516        Self {
517            table: table.into(),
518            ..Default::default()
519        }
520    }
521
522    /// Add columns to insert.
523    pub fn columns<I, S>(mut self, cols: I) -> Self
524    where
525        I: IntoIterator<Item = S>,
526        S: Into<String>,
527    {
528        self.columns = cols.into_iter().map(Into::into).collect();
529        self
530    }
531
532    /// Add values to insert.
533    pub fn values<I, S>(mut self, vals: I) -> Self
534    where
535        I: IntoIterator<Item = S>,
536        S: Into<String>,
537    {
538        self.values = vals.into_iter().map(Into::into).collect();
539        self
540    }
541
542    /// Set conflict target columns.
543    pub fn on_conflict_columns<I, S>(mut self, cols: I) -> Self
544    where
545        I: IntoIterator<Item = S>,
546        S: Into<String>,
547    {
548        self.conflict_target = Some(ConflictTarget::columns(cols));
549        self
550    }
551
552    /// Set conflict target constraint.
553    pub fn on_conflict_constraint(mut self, name: impl Into<String>) -> Self {
554        self.conflict_target = Some(ConflictTarget::Constraint(name.into()));
555        self
556    }
557
558    /// Set action to DO NOTHING.
559    pub fn do_nothing(mut self) -> Self {
560        self.conflict_action = Some(ConflictAction::DoNothing);
561        self
562    }
563
564    /// Set action to DO UPDATE with excluded columns.
565    pub fn do_update<I, S>(mut self, cols: I) -> Self
566    where
567        I: IntoIterator<Item = S>,
568        S: Into<String>,
569    {
570        self.conflict_action = Some(ConflictAction::DoUpdate(UpdateSpec {
571            assignments: Vec::new(),
572            excluded_columns: cols.into_iter().map(Into::into).collect(),
573        }));
574        self
575    }
576
577    /// Set action to DO UPDATE with assignments.
578    pub fn do_update_assignments(mut self, assignments: Vec<Assignment>) -> Self {
579        self.conflict_action = Some(ConflictAction::DoUpdate(UpdateSpec {
580            assignments,
581            excluded_columns: Vec::new(),
582        }));
583        self
584    }
585
586    /// Add WHERE clause for update.
587    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
588        self.where_clause = Some(condition.into());
589        self
590    }
591
592    /// Add RETURNING clause.
593    pub fn returning<I, S>(mut self, cols: I) -> Self
594    where
595        I: IntoIterator<Item = S>,
596        S: Into<String>,
597    {
598        self.returning = Some(cols.into_iter().map(Into::into).collect());
599        self
600    }
601
602    /// Build the upsert.
603    pub fn build(self) -> QueryResult<Upsert> {
604        if self.columns.is_empty() {
605            return Err(QueryError::invalid_input(
606                "columns",
607                "Upsert requires at least one column",
608            ));
609        }
610        if self.values.is_empty() {
611            return Err(QueryError::invalid_input(
612                "values",
613                "Upsert requires at least one value",
614            ));
615        }
616
617        Ok(Upsert {
618            table: self.table,
619            columns: self.columns,
620            values: self.values,
621            conflict_target: self.conflict_target,
622            conflict_action: self.conflict_action.unwrap_or(ConflictAction::DoNothing),
623            where_clause: self.where_clause,
624            returning: self.returning,
625        })
626    }
627}
628
629/// MongoDB upsert operations.
630pub mod mongodb {
631    use serde::{Deserialize, Serialize};
632    use serde_json::Value as JsonValue;
633
634    /// MongoDB upsert operation builder.
635    #[derive(Debug, Clone, Default)]
636    pub struct MongoUpsert {
637        /// Filter to find existing document.
638        pub filter: serde_json::Map<String, JsonValue>,
639        /// Update operations or replacement document.
640        pub update: JsonValue,
641        /// Insert-only fields ($setOnInsert).
642        pub set_on_insert: Option<serde_json::Map<String, JsonValue>>,
643        /// Array filters for updates.
644        pub array_filters: Option<Vec<JsonValue>>,
645    }
646
647    impl MongoUpsert {
648        /// Create a new upsert with filter.
649        pub fn new() -> MongoUpsertBuilder {
650            MongoUpsertBuilder::default()
651        }
652
653        /// Convert to updateOne options.
654        pub fn to_update_one(&self) -> JsonValue {
655            let mut options = serde_json::Map::new();
656            options.insert("upsert".to_string(), JsonValue::Bool(true));
657
658            if let Some(ref filters) = self.array_filters {
659                options.insert(
660                    "arrayFilters".to_string(),
661                    JsonValue::Array(filters.clone()),
662                );
663            }
664
665            serde_json::json!({
666                "filter": self.filter,
667                "update": self.update,
668                "options": options
669            })
670        }
671
672        /// Convert to findOneAndUpdate options.
673        pub fn to_find_one_and_update(&self, return_new: bool) -> JsonValue {
674            let mut options = serde_json::Map::new();
675            options.insert("upsert".to_string(), JsonValue::Bool(true));
676            options.insert(
677                "returnDocument".to_string(),
678                JsonValue::String(if return_new { "after" } else { "before" }.to_string()),
679            );
680
681            if let Some(ref filters) = self.array_filters {
682                options.insert(
683                    "arrayFilters".to_string(),
684                    JsonValue::Array(filters.clone()),
685                );
686            }
687
688            serde_json::json!({
689                "filter": self.filter,
690                "update": self.update,
691                "options": options
692            })
693        }
694
695        /// Convert to replaceOne options.
696        pub fn to_replace_one(&self, replacement: JsonValue) -> JsonValue {
697            serde_json::json!({
698                "filter": self.filter,
699                "replacement": replacement,
700                "options": { "upsert": true }
701            })
702        }
703    }
704
705    /// Builder for MongoDB upsert.
706    #[derive(Debug, Clone, Default)]
707    #[allow(dead_code)]
708    pub struct MongoUpsertBuilder {
709        filter: serde_json::Map<String, JsonValue>,
710        set: serde_json::Map<String, JsonValue>,
711        set_on_insert: serde_json::Map<String, JsonValue>,
712        inc: serde_json::Map<String, JsonValue>,
713        unset: Vec<String>,
714        push: serde_json::Map<String, JsonValue>,
715        pull: serde_json::Map<String, JsonValue>,
716        add_to_set: serde_json::Map<String, JsonValue>,
717        array_filters: Option<Vec<JsonValue>>,
718    }
719
720    impl MongoUpsertBuilder {
721        /// Set filter field equality.
722        pub fn filter_eq(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
723            self.filter.insert(field.into(), value.into());
724            self
725        }
726
727        /// Set filter with raw document.
728        pub fn filter(mut self, filter: serde_json::Map<String, JsonValue>) -> Self {
729            self.filter = filter;
730            self
731        }
732
733        /// Add $set field.
734        pub fn set(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
735            self.set.insert(field.into(), value.into());
736            self
737        }
738
739        /// Add $setOnInsert field (only on insert).
740        pub fn set_on_insert(
741            mut self,
742            field: impl Into<String>,
743            value: impl Into<JsonValue>,
744        ) -> Self {
745            self.set_on_insert.insert(field.into(), value.into());
746            self
747        }
748
749        /// Add $inc field.
750        pub fn inc(mut self, field: impl Into<String>, amount: impl Into<JsonValue>) -> Self {
751            self.inc.insert(field.into(), amount.into());
752            self
753        }
754
755        /// Add $unset field.
756        pub fn unset(mut self, field: impl Into<String>) -> Self {
757            self.unset.push(field.into());
758            self
759        }
760
761        /// Add $push field.
762        pub fn push(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
763            self.push.insert(field.into(), value.into());
764            self
765        }
766
767        /// Add $addToSet field.
768        pub fn add_to_set(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
769            self.add_to_set.insert(field.into(), value.into());
770            self
771        }
772
773        /// Add array filters.
774        pub fn array_filter(mut self, filter: JsonValue) -> Self {
775            self.array_filters.get_or_insert_with(Vec::new).push(filter);
776            self
777        }
778
779        /// Build the upsert.
780        pub fn build(self) -> MongoUpsert {
781            let mut update = serde_json::Map::new();
782
783            if !self.set.is_empty() {
784                update.insert("$set".to_string(), JsonValue::Object(self.set));
785            }
786
787            if !self.set_on_insert.is_empty() {
788                update.insert(
789                    "$setOnInsert".to_string(),
790                    JsonValue::Object(self.set_on_insert.clone()),
791                );
792            }
793
794            if !self.inc.is_empty() {
795                update.insert("$inc".to_string(), JsonValue::Object(self.inc));
796            }
797
798            if !self.unset.is_empty() {
799                let unset_obj: serde_json::Map<String, JsonValue> = self
800                    .unset
801                    .into_iter()
802                    .map(|f| (f, JsonValue::String(String::new())))
803                    .collect();
804                update.insert("$unset".to_string(), JsonValue::Object(unset_obj));
805            }
806
807            if !self.push.is_empty() {
808                update.insert("$push".to_string(), JsonValue::Object(self.push));
809            }
810
811            if !self.add_to_set.is_empty() {
812                update.insert("$addToSet".to_string(), JsonValue::Object(self.add_to_set));
813            }
814
815            MongoUpsert {
816                filter: self.filter,
817                update: JsonValue::Object(update),
818                set_on_insert: if self.set_on_insert.is_empty() {
819                    None
820                } else {
821                    Some(self.set_on_insert)
822                },
823                array_filters: self.array_filters,
824            }
825        }
826    }
827
828    /// Bulk upsert operation.
829    #[derive(Debug, Clone, Default)]
830    pub struct BulkUpsert {
831        /// Operations to perform.
832        pub operations: Vec<BulkUpsertOp>,
833        /// Whether operations are ordered.
834        pub ordered: bool,
835    }
836
837    /// A single bulk upsert operation.
838    #[derive(Debug, Clone, Serialize, Deserialize)]
839    pub struct BulkUpsertOp {
840        /// Filter to match document.
841        pub filter: serde_json::Map<String, JsonValue>,
842        /// Update document.
843        pub update: JsonValue,
844    }
845
846    impl BulkUpsert {
847        /// Create a new bulk upsert.
848        pub fn new() -> Self {
849            Self::default()
850        }
851
852        /// Set ordered mode.
853        pub fn ordered(mut self, ordered: bool) -> Self {
854            self.ordered = ordered;
855            self
856        }
857
858        /// Add an upsert operation.
859        pub fn add(
860            mut self,
861            filter: serde_json::Map<String, JsonValue>,
862            update: JsonValue,
863        ) -> Self {
864            self.operations.push(BulkUpsertOp { filter, update });
865            self
866        }
867
868        /// Convert to bulkWrite operations.
869        pub fn to_bulk_write(&self) -> JsonValue {
870            let ops: Vec<JsonValue> = self
871                .operations
872                .iter()
873                .map(|op| {
874                    serde_json::json!({
875                        "updateOne": {
876                            "filter": op.filter,
877                            "update": op.update,
878                            "upsert": true
879                        }
880                    })
881                })
882                .collect();
883
884            serde_json::json!({
885                "operations": ops,
886                "options": { "ordered": self.ordered }
887            })
888        }
889    }
890
891    /// Helper to create a MongoDB upsert.
892    pub fn upsert() -> MongoUpsertBuilder {
893        MongoUpsertBuilder::default()
894    }
895
896    /// Helper to create a bulk upsert.
897    pub fn bulk_upsert() -> BulkUpsert {
898        BulkUpsert::new()
899    }
900}
901
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_postgres_on_conflict_do_nothing() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_nothing();

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("INSERT INTO users"));
        assert!(sql.contains("ON CONFLICT (email) DO NOTHING"));
    }

    #[test]
    fn test_postgres_on_conflict_do_update() {
        let upsert = Upsert::new("users")
            .columns(["email", "name", "updated_at"])
            .values(["$1", "$2", "NOW()"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name", "updated_at"]);

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("ON CONFLICT (email) DO UPDATE SET"));
        assert!(sql.contains("name = EXCLUDED.name"));
        assert!(sql.contains("updated_at = EXCLUDED.updated_at"));
    }

    #[test]
    fn test_postgres_with_where() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"])
            .where_clause("users.active = true");

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("WHERE users.active = true"));
    }

    #[test]
    fn test_postgres_with_returning() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"])
            .returning(["id", "email"]);

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("RETURNING id, email"));
    }

    #[test]
    fn test_mysql_on_duplicate_key() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["?", "?"])
            .do_update(["name"]);

        let sql = upsert.to_mysql_sql();
        assert!(sql.contains("INSERT INTO users"));
        assert!(sql.contains("ON DUPLICATE KEY UPDATE"));
        assert!(sql.contains("name = VALUES(name)"));
    }

    #[test]
    fn test_mysql_insert_ignore() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["?", "?"])
            .do_nothing();

        let sql = upsert.to_mysql_sql();
        assert!(sql.contains("INSERT IGNORE INTO users"));
    }

    #[test]
    fn test_sqlite_on_conflict() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["?", "?"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"]);

        let sql = upsert.to_sqlite_sql();
        assert!(sql.contains("ON CONFLICT (email) DO UPDATE SET"));
        assert!(sql.contains("name = excluded.name"));
    }

    #[test]
    fn test_mssql_merge() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["@P1", "@P2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"]);

        let sql = upsert.to_mssql_sql();
        assert!(sql.contains("MERGE INTO users AS target"));
        assert!(sql.contains("USING (SELECT"));
        assert!(sql.contains("WHEN MATCHED THEN UPDATE SET"));
        assert!(sql.contains("WHEN NOT MATCHED THEN INSERT"));
    }

    #[test]
    fn test_upsert_builder() {
        let upsert = UpsertBuilder::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict_columns(["email"])
            .do_update(["name"])
            .returning(["id"])
            .build()
            .unwrap();

        assert_eq!(upsert.table, "users");
        assert_eq!(upsert.columns, vec!["email", "name"]);
    }

    #[test]
    fn test_conflict_target_constraint() {
        let target = ConflictTarget::constraint("users_email_key");
        assert_eq!(target.to_postgres_sql(), "ON CONSTRAINT users_email_key");
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_simple_upsert() {
            let upsert = upsert()
                .filter_eq("email", "test@example.com")
                .set("name", "John")
                .set("updated_at", serde_json::json!({"$date": "2024-01-01"}))
                .set_on_insert("created_at", serde_json::json!({"$date": "2024-01-01"}))
                .build();

            let doc = upsert.to_update_one();
            assert!(doc["options"]["upsert"].as_bool().unwrap());
            assert!(doc["update"]["$set"]["name"].is_string());
            assert!(doc["update"]["$setOnInsert"].is_object());
        }

        #[test]
        fn test_upsert_with_inc() {
            let upsert = upsert()
                .filter_eq("_id", "doc1")
                .inc("visits", 1)
                .set("last_visit", "2024-01-01")
                .build();

            let doc = upsert.to_update_one();
            assert_eq!(doc["update"]["$inc"]["visits"], 1);
        }

        #[test]
        fn test_find_one_and_update() {
            let upsert = upsert()
                .filter_eq("email", "test@example.com")
                .set("name", "Updated")
                .build();

            let doc = upsert.to_find_one_and_update(true);
            assert_eq!(doc["options"]["returnDocument"], "after");
            assert!(doc["options"]["upsert"].as_bool().unwrap());
        }

        #[test]
        fn test_bulk_upsert() {
            let mut filter1 = serde_json::Map::new();
            filter1.insert("email".to_string(), serde_json::json!("a@b.com"));

            let mut filter2 = serde_json::Map::new();
            filter2.insert("email".to_string(), serde_json::json!("c@d.com"));

            let bulk = bulk_upsert()
                .ordered(false)
                .add(filter1, serde_json::json!({"$set": {"name": "A"}}))
                .add(filter2, serde_json::json!({"$set": {"name": "B"}}));

            let doc = bulk.to_bulk_write();
            assert_eq!(doc["options"]["ordered"], false);

            let ops = doc["operations"]
                .as_array()
                .expect("operations should render as an array");
            assert_eq!(ops.len(), 2);

            // Every entry is an `updateOne` with upsert enabled, and the
            // filter/update pairs are passed through in insertion order.
            assert_eq!(ops[0]["updateOne"]["filter"]["email"], "a@b.com");
            assert_eq!(ops[0]["updateOne"]["update"]["$set"]["name"], "A");
            assert_eq!(ops[0]["updateOne"]["upsert"], true);
            assert_eq!(ops[1]["updateOne"]["filter"]["email"], "c@d.com");
            assert_eq!(ops[1]["updateOne"]["update"]["$set"]["name"], "B");
            assert_eq!(ops[1]["updateOne"]["upsert"], true);
        }

        #[test]
        fn test_bulk_upsert_ordered_toggle() {
            let doc = bulk_upsert().ordered(true).to_bulk_write();
            assert_eq!(doc["options"]["ordered"], true);
        }

        #[test]
        fn test_bulk_upsert_defaults() {
            // `BulkUpsert` derives `Default`, so a fresh batch is unordered and empty.
            let doc = bulk_upsert().to_bulk_write();
            assert_eq!(doc["options"]["ordered"], false);
            assert!(doc["operations"]
                .as_array()
                .expect("operations should render as an array")
                .is_empty());
        }
    }
}