prax_query/
upsert.rs

1//! Upsert and conflict resolution support.
2//!
3//! This module provides types for building upsert operations with conflict
4//! resolution across different database backends.
5//!
6//! # Database Support
7//!
8//! | Feature          | PostgreSQL     | MySQL              | SQLite         | MSSQL   | MongoDB      |
9//! |------------------|----------------|--------------------|----------------|---------|--------------|
10//! | ON CONFLICT      | ✅             | ❌                 | ✅             | ❌      | ❌           |
11//! | ON DUPLICATE KEY | ❌             | ✅                 | ❌             | ❌      | ❌           |
12//! | MERGE statement  | ❌             | ❌                 | ❌             | ✅      | ❌           |
13//! | Native upsert    | ❌             | ❌                 | ❌             | ❌      | ✅ upsert:true|
14//! | Conflict targets | ✅             | ❌ (implicit PK/UK)| ✅             | ✅      | ✅ filter    |
15//!
16//! # Example Usage
17//!
18//! ```rust,ignore
19//! use prax_query::upsert::{Upsert, ConflictTarget, ConflictAction};
20//!
21//! // PostgreSQL: INSERT ... ON CONFLICT (email) DO UPDATE SET ...
22//! let upsert = Upsert::new("users")
23//!     .columns(["email", "name", "updated_at"])
24//!     .values(["$1", "$2", "NOW()"])
25//!     .on_conflict(ConflictTarget::columns(["email"]))
26//!     .do_update(["name", "updated_at"]);
27//! ```
28
29use serde::{Deserialize, Serialize};
30
31use crate::error::{QueryError, QueryResult};
32use crate::sql::DatabaseType;
33
34/// An upsert operation (INSERT with conflict handling).
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub struct Upsert {
37    /// Table name.
38    pub table: String,
39    /// Columns to insert.
40    pub columns: Vec<String>,
41    /// Values to insert (expressions or placeholders).
42    pub values: Vec<String>,
43    /// Conflict target specification.
44    pub conflict_target: Option<ConflictTarget>,
45    /// Action to take on conflict.
46    pub conflict_action: ConflictAction,
47    /// WHERE clause for conflict update (PostgreSQL).
48    pub where_clause: Option<String>,
49    /// RETURNING clause (PostgreSQL).
50    pub returning: Option<Vec<String>>,
51}
52
53/// What to match on for conflict detection.
54#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55pub enum ConflictTarget {
56    /// Match on specific columns (unique constraint).
57    Columns(Vec<String>),
58    /// Match on a named constraint.
59    Constraint(String),
60    /// Match on index expression (PostgreSQL).
61    IndexExpression(String),
62    /// No specific target (MySQL ON DUPLICATE KEY).
63    Implicit,
64}
65
66impl ConflictTarget {
67    /// Create a column-based conflict target.
68    pub fn columns<I, S>(cols: I) -> Self
69    where
70        I: IntoIterator<Item = S>,
71        S: Into<String>,
72    {
73        Self::Columns(cols.into_iter().map(Into::into).collect())
74    }
75
76    /// Create a constraint-based conflict target.
77    pub fn constraint(name: impl Into<String>) -> Self {
78        Self::Constraint(name.into())
79    }
80
81    /// Create an index expression conflict target.
82    pub fn index_expression(expr: impl Into<String>) -> Self {
83        Self::IndexExpression(expr.into())
84    }
85
86    /// Generate PostgreSQL ON CONFLICT target.
87    pub fn to_postgres_sql(&self) -> String {
88        match self {
89            Self::Columns(cols) => format!("({})", cols.join(", ")),
90            Self::Constraint(name) => format!("ON CONSTRAINT {}", name),
91            Self::IndexExpression(expr) => format!("({})", expr),
92            Self::Implicit => String::new(),
93        }
94    }
95
96    /// Generate SQLite ON CONFLICT target.
97    pub fn to_sqlite_sql(&self) -> String {
98        match self {
99            Self::Columns(cols) => format!("({})", cols.join(", ")),
100            Self::Constraint(_) | Self::IndexExpression(_) => {
101                // SQLite doesn't support these directly
102                String::new()
103            }
104            Self::Implicit => String::new(),
105        }
106    }
107}
108
109/// Action to take when a conflict is detected.
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
111pub enum ConflictAction {
112    /// Do nothing (ignore the insert).
113    DoNothing,
114    /// Update specified columns.
115    DoUpdate(UpdateSpec),
116}
117
118/// Specification for what to update on conflict.
119#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
120pub struct UpdateSpec {
121    /// Columns to update with their values.
122    pub assignments: Vec<Assignment>,
123    /// Use EXCLUDED values for columns (PostgreSQL/SQLite).
124    pub excluded_columns: Vec<String>,
125}
126
127/// A single column assignment.
128#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
129pub struct Assignment {
130    /// Column name.
131    pub column: String,
132    /// Value expression.
133    pub value: AssignmentValue,
134}
135
136/// Value for an assignment.
137#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
138pub enum AssignmentValue {
139    /// Use the EXCLUDED/VALUES value.
140    Excluded,
141    /// Use a literal expression.
142    Expression(String),
143    /// Use a parameter placeholder.
144    Param(usize),
145}
146
147impl Upsert {
148    /// Create a new upsert for the given table.
149    pub fn new(table: impl Into<String>) -> Self {
150        Self {
151            table: table.into(),
152            columns: Vec::new(),
153            values: Vec::new(),
154            conflict_target: None,
155            conflict_action: ConflictAction::DoNothing,
156            where_clause: None,
157            returning: None,
158        }
159    }
160
161    /// Create an upsert builder.
162    pub fn builder(table: impl Into<String>) -> UpsertBuilder {
163        UpsertBuilder::new(table)
164    }
165
166    /// Set the columns to insert.
167    pub fn columns<I, S>(mut self, cols: I) -> Self
168    where
169        I: IntoIterator<Item = S>,
170        S: Into<String>,
171    {
172        self.columns = cols.into_iter().map(Into::into).collect();
173        self
174    }
175
176    /// Set the values to insert.
177    pub fn values<I, S>(mut self, vals: I) -> Self
178    where
179        I: IntoIterator<Item = S>,
180        S: Into<String>,
181    {
182        self.values = vals.into_iter().map(Into::into).collect();
183        self
184    }
185
186    /// Set the conflict target.
187    pub fn on_conflict(mut self, target: ConflictTarget) -> Self {
188        self.conflict_target = Some(target);
189        self
190    }
191
192    /// Set conflict action to DO NOTHING.
193    pub fn do_nothing(mut self) -> Self {
194        self.conflict_action = ConflictAction::DoNothing;
195        self
196    }
197
198    /// Set conflict action to DO UPDATE for specified columns (using EXCLUDED).
199    pub fn do_update<I, S>(mut self, cols: I) -> Self
200    where
201        I: IntoIterator<Item = S>,
202        S: Into<String>,
203    {
204        self.conflict_action = ConflictAction::DoUpdate(UpdateSpec {
205            assignments: Vec::new(),
206            excluded_columns: cols.into_iter().map(Into::into).collect(),
207        });
208        self
209    }
210
211    /// Set conflict action to DO UPDATE with specific assignments.
212    pub fn do_update_set(mut self, assignments: Vec<Assignment>) -> Self {
213        self.conflict_action = ConflictAction::DoUpdate(UpdateSpec {
214            assignments,
215            excluded_columns: Vec::new(),
216        });
217        self
218    }
219
220    /// Add a WHERE clause for the update (PostgreSQL).
221    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
222        self.where_clause = Some(condition.into());
223        self
224    }
225
226    /// Add RETURNING clause (PostgreSQL).
227    pub fn returning<I, S>(mut self, cols: I) -> Self
228    where
229        I: IntoIterator<Item = S>,
230        S: Into<String>,
231    {
232        self.returning = Some(cols.into_iter().map(Into::into).collect());
233        self
234    }
235
236    /// Generate PostgreSQL INSERT ... ON CONFLICT SQL.
237    pub fn to_postgres_sql(&self) -> String {
238        let mut sql = format!(
239            "INSERT INTO {} ({}) VALUES ({})",
240            self.table,
241            self.columns.join(", "),
242            self.values.join(", ")
243        );
244
245        sql.push_str(" ON CONFLICT ");
246
247        if let Some(ref target) = self.conflict_target {
248            sql.push_str(&target.to_postgres_sql());
249            sql.push(' ');
250        }
251
252        match &self.conflict_action {
253            ConflictAction::DoNothing => {
254                sql.push_str("DO NOTHING");
255            }
256            ConflictAction::DoUpdate(spec) => {
257                sql.push_str("DO UPDATE SET ");
258                let assignments: Vec<String> = if !spec.excluded_columns.is_empty() {
259                    spec.excluded_columns
260                        .iter()
261                        .map(|c| format!("{} = EXCLUDED.{}", c, c))
262                        .collect()
263                } else {
264                    spec.assignments
265                        .iter()
266                        .map(|a| {
267                            let value = match &a.value {
268                                AssignmentValue::Excluded => format!("EXCLUDED.{}", a.column),
269                                AssignmentValue::Expression(expr) => expr.clone(),
270                                AssignmentValue::Param(n) => format!("${}", n),
271                            };
272                            format!("{} = {}", a.column, value)
273                        })
274                        .collect()
275                };
276                sql.push_str(&assignments.join(", "));
277
278                if let Some(ref where_clause) = self.where_clause {
279                    sql.push_str(" WHERE ");
280                    sql.push_str(where_clause);
281                }
282            }
283        }
284
285        if let Some(ref returning) = self.returning {
286            sql.push_str(" RETURNING ");
287            sql.push_str(&returning.join(", "));
288        }
289
290        sql
291    }
292
293    /// Generate MySQL INSERT ... ON DUPLICATE KEY UPDATE SQL.
294    pub fn to_mysql_sql(&self) -> String {
295        let mut sql = format!(
296            "INSERT INTO {} ({}) VALUES ({})",
297            self.table,
298            self.columns.join(", "),
299            self.values.join(", ")
300        );
301
302        match &self.conflict_action {
303            ConflictAction::DoNothing => {
304                // MySQL doesn't have DO NOTHING, use INSERT IGNORE
305                sql = format!(
306                    "INSERT IGNORE INTO {} ({}) VALUES ({})",
307                    self.table,
308                    self.columns.join(", "),
309                    self.values.join(", ")
310                );
311            }
312            ConflictAction::DoUpdate(spec) => {
313                sql.push_str(" ON DUPLICATE KEY UPDATE ");
314                let assignments: Vec<String> = if !spec.excluded_columns.is_empty() {
315                    spec.excluded_columns
316                        .iter()
317                        .map(|c| format!("{} = VALUES({})", c, c))
318                        .collect()
319                } else {
320                    spec.assignments
321                        .iter()
322                        .map(|a| {
323                            let value = match &a.value {
324                                AssignmentValue::Excluded => format!("VALUES({})", a.column),
325                                AssignmentValue::Expression(expr) => expr.clone(),
326                                AssignmentValue::Param(_n) => "?".to_string(),
327                            };
328                            format!("{} = {}", a.column, value)
329                        })
330                        .collect()
331                };
332                sql.push_str(&assignments.join(", "));
333            }
334        }
335
336        sql
337    }
338
339    /// Generate SQLite INSERT ... ON CONFLICT SQL.
340    pub fn to_sqlite_sql(&self) -> String {
341        let mut sql = format!(
342            "INSERT INTO {} ({}) VALUES ({})",
343            self.table,
344            self.columns.join(", "),
345            self.values.join(", ")
346        );
347
348        sql.push_str(" ON CONFLICT");
349
350        if let Some(ref target) = self.conflict_target {
351            let target_sql = target.to_sqlite_sql();
352            if !target_sql.is_empty() {
353                sql.push(' ');
354                sql.push_str(&target_sql);
355            }
356        }
357
358        match &self.conflict_action {
359            ConflictAction::DoNothing => {
360                sql.push_str(" DO NOTHING");
361            }
362            ConflictAction::DoUpdate(spec) => {
363                sql.push_str(" DO UPDATE SET ");
364                let assignments: Vec<String> = if !spec.excluded_columns.is_empty() {
365                    spec.excluded_columns
366                        .iter()
367                        .map(|c| format!("{} = excluded.{}", c, c))
368                        .collect()
369                } else {
370                    spec.assignments
371                        .iter()
372                        .map(|a| {
373                            let value = match &a.value {
374                                AssignmentValue::Excluded => format!("excluded.{}", a.column),
375                                AssignmentValue::Expression(expr) => expr.clone(),
376                                AssignmentValue::Param(_n) => "?".to_string(),
377                            };
378                            format!("{} = {}", a.column, value)
379                        })
380                        .collect()
381                };
382                sql.push_str(&assignments.join(", "));
383
384                if let Some(ref where_clause) = self.where_clause {
385                    sql.push_str(" WHERE ");
386                    sql.push_str(where_clause);
387                }
388            }
389        }
390
391        if let Some(ref returning) = self.returning {
392            sql.push_str(" RETURNING ");
393            sql.push_str(&returning.join(", "));
394        }
395
396        sql
397    }
398
399    /// Generate MSSQL MERGE statement.
400    pub fn to_mssql_sql(&self) -> String {
401        let target = self
402            .conflict_target
403            .as_ref()
404            .and_then(|t| match t {
405                ConflictTarget::Columns(cols) => Some(cols.clone()),
406                _ => None,
407            })
408            .unwrap_or_else(|| vec![self.columns.first().cloned().unwrap_or_default()]);
409
410        let source_cols: Vec<String> = self
411            .columns
412            .iter()
413            .zip(&self.values)
414            .map(|(c, v)| format!("{} AS {}", v, c))
415            .collect();
416
417        let match_conditions: Vec<String> = target
418            .iter()
419            .map(|c| format!("target.{} = source.{}", c, c))
420            .collect();
421
422        let mut sql = format!(
423            "MERGE INTO {} AS target USING (SELECT {}) AS source ON {}",
424            self.table,
425            source_cols.join(", "),
426            match_conditions.join(" AND ")
427        );
428
429        match &self.conflict_action {
430            ConflictAction::DoNothing => {
431                // MSSQL MERGE requires at least one action
432                sql.push_str(" WHEN NOT MATCHED THEN INSERT (");
433                sql.push_str(&self.columns.join(", "));
434                sql.push_str(") VALUES (");
435                let source_refs: Vec<String> = self.columns.iter().map(|c| format!("source.{}", c)).collect();
436                sql.push_str(&source_refs.join(", "));
437                sql.push(')');
438            }
439            ConflictAction::DoUpdate(spec) => {
440                sql.push_str(" WHEN MATCHED THEN UPDATE SET ");
441
442                let update_cols = if !spec.excluded_columns.is_empty() {
443                    &spec.excluded_columns
444                } else {
445                    &self.columns
446                };
447
448                let assignments: Vec<String> = update_cols
449                    .iter()
450                    .filter(|c| !target.contains(c))
451                    .map(|c| format!("target.{} = source.{}", c, c))
452                    .collect();
453
454                if assignments.is_empty() {
455                    // Need at least one assignment, use first non-key column
456                    let first_non_key = self.columns.iter().find(|c| !target.contains(*c));
457                    if let Some(col) = first_non_key {
458                        sql.push_str(&format!("target.{} = source.{}", col, col));
459                    } else {
460                        sql.push_str(&format!("target.{} = source.{}", self.columns[0], self.columns[0]));
461                    }
462                } else {
463                    sql.push_str(&assignments.join(", "));
464                }
465
466                sql.push_str(" WHEN NOT MATCHED THEN INSERT (");
467                sql.push_str(&self.columns.join(", "));
468                sql.push_str(") VALUES (");
469                let source_refs: Vec<String> = self.columns.iter().map(|c| format!("source.{}", c)).collect();
470                sql.push_str(&source_refs.join(", "));
471                sql.push(')');
472            }
473        }
474
475        sql.push(';');
476        sql
477    }
478
479    /// Generate SQL for the specified database type.
480    pub fn to_sql(&self, db_type: DatabaseType) -> String {
481        match db_type {
482            DatabaseType::PostgreSQL => self.to_postgres_sql(),
483            DatabaseType::MySQL => self.to_mysql_sql(),
484            DatabaseType::SQLite => self.to_sqlite_sql(),
485            DatabaseType::MSSQL => self.to_mssql_sql(),
486        }
487    }
488}
489
490/// Builder for upsert operations.
491#[derive(Debug, Clone, Default)]
492pub struct UpsertBuilder {
493    table: String,
494    columns: Vec<String>,
495    values: Vec<String>,
496    conflict_target: Option<ConflictTarget>,
497    conflict_action: Option<ConflictAction>,
498    where_clause: Option<String>,
499    returning: Option<Vec<String>>,
500}
501
502impl UpsertBuilder {
503    /// Create a new builder.
504    pub fn new(table: impl Into<String>) -> Self {
505        Self {
506            table: table.into(),
507            ..Default::default()
508        }
509    }
510
511    /// Add columns to insert.
512    pub fn columns<I, S>(mut self, cols: I) -> Self
513    where
514        I: IntoIterator<Item = S>,
515        S: Into<String>,
516    {
517        self.columns = cols.into_iter().map(Into::into).collect();
518        self
519    }
520
521    /// Add values to insert.
522    pub fn values<I, S>(mut self, vals: I) -> Self
523    where
524        I: IntoIterator<Item = S>,
525        S: Into<String>,
526    {
527        self.values = vals.into_iter().map(Into::into).collect();
528        self
529    }
530
531    /// Set conflict target columns.
532    pub fn on_conflict_columns<I, S>(mut self, cols: I) -> Self
533    where
534        I: IntoIterator<Item = S>,
535        S: Into<String>,
536    {
537        self.conflict_target = Some(ConflictTarget::columns(cols));
538        self
539    }
540
541    /// Set conflict target constraint.
542    pub fn on_conflict_constraint(mut self, name: impl Into<String>) -> Self {
543        self.conflict_target = Some(ConflictTarget::Constraint(name.into()));
544        self
545    }
546
547    /// Set action to DO NOTHING.
548    pub fn do_nothing(mut self) -> Self {
549        self.conflict_action = Some(ConflictAction::DoNothing);
550        self
551    }
552
553    /// Set action to DO UPDATE with excluded columns.
554    pub fn do_update<I, S>(mut self, cols: I) -> Self
555    where
556        I: IntoIterator<Item = S>,
557        S: Into<String>,
558    {
559        self.conflict_action = Some(ConflictAction::DoUpdate(UpdateSpec {
560            assignments: Vec::new(),
561            excluded_columns: cols.into_iter().map(Into::into).collect(),
562        }));
563        self
564    }
565
566    /// Set action to DO UPDATE with assignments.
567    pub fn do_update_assignments(mut self, assignments: Vec<Assignment>) -> Self {
568        self.conflict_action = Some(ConflictAction::DoUpdate(UpdateSpec {
569            assignments,
570            excluded_columns: Vec::new(),
571        }));
572        self
573    }
574
575    /// Add WHERE clause for update.
576    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
577        self.where_clause = Some(condition.into());
578        self
579    }
580
581    /// Add RETURNING clause.
582    pub fn returning<I, S>(mut self, cols: I) -> Self
583    where
584        I: IntoIterator<Item = S>,
585        S: Into<String>,
586    {
587        self.returning = Some(cols.into_iter().map(Into::into).collect());
588        self
589    }
590
591    /// Build the upsert.
592    pub fn build(self) -> QueryResult<Upsert> {
593        if self.columns.is_empty() {
594            return Err(QueryError::invalid_input("columns", "Upsert requires at least one column"));
595        }
596        if self.values.is_empty() {
597            return Err(QueryError::invalid_input("values", "Upsert requires at least one value"));
598        }
599
600        Ok(Upsert {
601            table: self.table,
602            columns: self.columns,
603            values: self.values,
604            conflict_target: self.conflict_target,
605            conflict_action: self.conflict_action.unwrap_or(ConflictAction::DoNothing),
606            where_clause: self.where_clause,
607            returning: self.returning,
608        })
609    }
610}
611
612/// MongoDB upsert operations.
613pub mod mongodb {
614    use serde::{Deserialize, Serialize};
615    use serde_json::Value as JsonValue;
616
617    /// MongoDB upsert operation builder.
618    #[derive(Debug, Clone, Default)]
619    pub struct MongoUpsert {
620        /// Filter to find existing document.
621        pub filter: serde_json::Map<String, JsonValue>,
622        /// Update operations or replacement document.
623        pub update: JsonValue,
624        /// Insert-only fields ($setOnInsert).
625        pub set_on_insert: Option<serde_json::Map<String, JsonValue>>,
626        /// Array filters for updates.
627        pub array_filters: Option<Vec<JsonValue>>,
628    }
629
630    impl MongoUpsert {
631        /// Create a new upsert with filter.
632        pub fn new() -> MongoUpsertBuilder {
633            MongoUpsertBuilder::default()
634        }
635
636        /// Convert to updateOne options.
637        pub fn to_update_one(&self) -> JsonValue {
638            let mut options = serde_json::Map::new();
639            options.insert("upsert".to_string(), JsonValue::Bool(true));
640
641            if let Some(ref filters) = self.array_filters {
642                options.insert("arrayFilters".to_string(), JsonValue::Array(filters.clone()));
643            }
644
645            serde_json::json!({
646                "filter": self.filter,
647                "update": self.update,
648                "options": options
649            })
650        }
651
652        /// Convert to findOneAndUpdate options.
653        pub fn to_find_one_and_update(&self, return_new: bool) -> JsonValue {
654            let mut options = serde_json::Map::new();
655            options.insert("upsert".to_string(), JsonValue::Bool(true));
656            options.insert(
657                "returnDocument".to_string(),
658                JsonValue::String(if return_new { "after" } else { "before" }.to_string()),
659            );
660
661            if let Some(ref filters) = self.array_filters {
662                options.insert("arrayFilters".to_string(), JsonValue::Array(filters.clone()));
663            }
664
665            serde_json::json!({
666                "filter": self.filter,
667                "update": self.update,
668                "options": options
669            })
670        }
671
672        /// Convert to replaceOne options.
673        pub fn to_replace_one(&self, replacement: JsonValue) -> JsonValue {
674            serde_json::json!({
675                "filter": self.filter,
676                "replacement": replacement,
677                "options": { "upsert": true }
678            })
679        }
680    }
681
682    /// Builder for MongoDB upsert.
683    #[derive(Debug, Clone, Default)]
684    #[allow(dead_code)]
685    pub struct MongoUpsertBuilder {
686        filter: serde_json::Map<String, JsonValue>,
687        set: serde_json::Map<String, JsonValue>,
688        set_on_insert: serde_json::Map<String, JsonValue>,
689        inc: serde_json::Map<String, JsonValue>,
690        unset: Vec<String>,
691        push: serde_json::Map<String, JsonValue>,
692        pull: serde_json::Map<String, JsonValue>,
693        add_to_set: serde_json::Map<String, JsonValue>,
694        array_filters: Option<Vec<JsonValue>>,
695    }
696
697    impl MongoUpsertBuilder {
698        /// Set filter field equality.
699        pub fn filter_eq(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
700            self.filter.insert(field.into(), value.into());
701            self
702        }
703
704        /// Set filter with raw document.
705        pub fn filter(mut self, filter: serde_json::Map<String, JsonValue>) -> Self {
706            self.filter = filter;
707            self
708        }
709
710        /// Add $set field.
711        pub fn set(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
712            self.set.insert(field.into(), value.into());
713            self
714        }
715
716        /// Add $setOnInsert field (only on insert).
717        pub fn set_on_insert(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
718            self.set_on_insert.insert(field.into(), value.into());
719            self
720        }
721
722        /// Add $inc field.
723        pub fn inc(mut self, field: impl Into<String>, amount: impl Into<JsonValue>) -> Self {
724            self.inc.insert(field.into(), amount.into());
725            self
726        }
727
728        /// Add $unset field.
729        pub fn unset(mut self, field: impl Into<String>) -> Self {
730            self.unset.push(field.into());
731            self
732        }
733
734        /// Add $push field.
735        pub fn push(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
736            self.push.insert(field.into(), value.into());
737            self
738        }
739
740        /// Add $addToSet field.
741        pub fn add_to_set(mut self, field: impl Into<String>, value: impl Into<JsonValue>) -> Self {
742            self.add_to_set.insert(field.into(), value.into());
743            self
744        }
745
746        /// Add array filters.
747        pub fn array_filter(mut self, filter: JsonValue) -> Self {
748            self.array_filters
749                .get_or_insert_with(Vec::new)
750                .push(filter);
751            self
752        }
753
754        /// Build the upsert.
755        pub fn build(self) -> MongoUpsert {
756            let mut update = serde_json::Map::new();
757
758            if !self.set.is_empty() {
759                update.insert("$set".to_string(), JsonValue::Object(self.set));
760            }
761
762            if !self.set_on_insert.is_empty() {
763                update.insert("$setOnInsert".to_string(), JsonValue::Object(self.set_on_insert.clone()));
764            }
765
766            if !self.inc.is_empty() {
767                update.insert("$inc".to_string(), JsonValue::Object(self.inc));
768            }
769
770            if !self.unset.is_empty() {
771                let unset_obj: serde_json::Map<String, JsonValue> = self
772                    .unset
773                    .into_iter()
774                    .map(|f| (f, JsonValue::String(String::new())))
775                    .collect();
776                update.insert("$unset".to_string(), JsonValue::Object(unset_obj));
777            }
778
779            if !self.push.is_empty() {
780                update.insert("$push".to_string(), JsonValue::Object(self.push));
781            }
782
783            if !self.add_to_set.is_empty() {
784                update.insert("$addToSet".to_string(), JsonValue::Object(self.add_to_set));
785            }
786
787            MongoUpsert {
788                filter: self.filter,
789                update: JsonValue::Object(update),
790                set_on_insert: if self.set_on_insert.is_empty() {
791                    None
792                } else {
793                    Some(self.set_on_insert)
794                },
795                array_filters: self.array_filters,
796            }
797        }
798    }
799
800    /// Bulk upsert operation.
801    #[derive(Debug, Clone, Default)]
802    pub struct BulkUpsert {
803        /// Operations to perform.
804        pub operations: Vec<BulkUpsertOp>,
805        /// Whether operations are ordered.
806        pub ordered: bool,
807    }
808
809    /// A single bulk upsert operation.
810    #[derive(Debug, Clone, Serialize, Deserialize)]
811    pub struct BulkUpsertOp {
812        /// Filter to match document.
813        pub filter: serde_json::Map<String, JsonValue>,
814        /// Update document.
815        pub update: JsonValue,
816    }
817
818    impl BulkUpsert {
819        /// Create a new bulk upsert.
820        pub fn new() -> Self {
821            Self::default()
822        }
823
824        /// Set ordered mode.
825        pub fn ordered(mut self, ordered: bool) -> Self {
826            self.ordered = ordered;
827            self
828        }
829
830        /// Add an upsert operation.
831        pub fn add(mut self, filter: serde_json::Map<String, JsonValue>, update: JsonValue) -> Self {
832            self.operations.push(BulkUpsertOp { filter, update });
833            self
834        }
835
836        /// Convert to bulkWrite operations.
837        pub fn to_bulk_write(&self) -> JsonValue {
838            let ops: Vec<JsonValue> = self
839                .operations
840                .iter()
841                .map(|op| {
842                    serde_json::json!({
843                        "updateOne": {
844                            "filter": op.filter,
845                            "update": op.update,
846                            "upsert": true
847                        }
848                    })
849                })
850                .collect();
851
852            serde_json::json!({
853                "operations": ops,
854                "options": { "ordered": self.ordered }
855            })
856        }
857    }
858
859    /// Helper to create a MongoDB upsert.
860    pub fn upsert() -> MongoUpsertBuilder {
861        MongoUpsertBuilder::default()
862    }
863
864    /// Helper to create a bulk upsert.
865    pub fn bulk_upsert() -> BulkUpsert {
866        BulkUpsert::new()
867    }
868}
869
#[cfg(test)]
mod tests {
    // Unit tests for per-dialect SQL upsert generation and the MongoDB upsert helpers.
    use super::*;

    #[test]
    fn test_postgres_on_conflict_do_nothing() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_nothing();

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("INSERT INTO users"));
        assert!(sql.contains("ON CONFLICT (email) DO NOTHING"));
    }

    #[test]
    fn test_postgres_on_conflict_do_update() {
        let upsert = Upsert::new("users")
            .columns(["email", "name", "updated_at"])
            .values(["$1", "$2", "NOW()"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name", "updated_at"]);

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("ON CONFLICT (email) DO UPDATE SET"));
        assert!(sql.contains("name = EXCLUDED.name"));
        assert!(sql.contains("updated_at = EXCLUDED.updated_at"));
    }

    #[test]
    fn test_postgres_with_where() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"])
            .where_clause("users.active = true");

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("WHERE users.active = true"));
    }

    #[test]
    fn test_postgres_with_returning() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"])
            .returning(["id", "email"]);

        let sql = upsert.to_postgres_sql();
        assert!(sql.contains("RETURNING id, email"));
    }

    #[test]
    fn test_mysql_on_duplicate_key() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["?", "?"])
            .do_update(["name"]);

        let sql = upsert.to_mysql_sql();
        assert!(sql.contains("INSERT INTO users"));
        assert!(sql.contains("ON DUPLICATE KEY UPDATE"));
        assert!(sql.contains("name = VALUES(name)"));
    }

    #[test]
    fn test_mysql_insert_ignore() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["?", "?"])
            .do_nothing();

        let sql = upsert.to_mysql_sql();
        assert!(sql.contains("INSERT IGNORE INTO users"));
    }

    #[test]
    fn test_sqlite_on_conflict() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["?", "?"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"]);

        let sql = upsert.to_sqlite_sql();
        assert!(sql.contains("ON CONFLICT (email) DO UPDATE SET"));
        assert!(sql.contains("name = excluded.name"));
    }

    #[test]
    fn test_mssql_merge() {
        let upsert = Upsert::new("users")
            .columns(["email", "name"])
            .values(["@P1", "@P2"])
            .on_conflict(ConflictTarget::columns(["email"]))
            .do_update(["name"]);

        let sql = upsert.to_mssql_sql();
        assert!(sql.contains("MERGE INTO users AS target"));
        assert!(sql.contains("USING (SELECT"));
        assert!(sql.contains("WHEN MATCHED THEN UPDATE SET"));
        assert!(sql.contains("WHEN NOT MATCHED THEN INSERT"));
    }

    #[test]
    fn test_upsert_builder() {
        let upsert = UpsertBuilder::new("users")
            .columns(["email", "name"])
            .values(["$1", "$2"])
            .on_conflict_columns(["email"])
            .do_update(["name"])
            .returning(["id"])
            .build()
            .unwrap();

        assert_eq!(upsert.table, "users");
        assert_eq!(upsert.columns, vec!["email", "name"]);
    }

    #[test]
    fn test_conflict_target_constraint() {
        let target = ConflictTarget::constraint("users_email_key");
        assert_eq!(target.to_postgres_sql(), "ON CONSTRAINT users_email_key");
    }

    mod mongodb_tests {
        use super::super::mongodb::*;

        #[test]
        fn test_simple_upsert() {
            let upsert = upsert()
                .filter_eq("email", "test@example.com")
                .set("name", "John")
                .set("updated_at", serde_json::json!({"$date": "2024-01-01"}))
                .set_on_insert("created_at", serde_json::json!({"$date": "2024-01-01"}))
                .build();

            let doc = upsert.to_update_one();
            assert!(doc["options"]["upsert"].as_bool().unwrap());
            assert!(doc["update"]["$set"]["name"].is_string());
            assert!(doc["update"]["$setOnInsert"].is_object());
        }

        #[test]
        fn test_upsert_with_inc() {
            let upsert = upsert()
                .filter_eq("_id", "doc1")
                .inc("visits", 1)
                .set("last_visit", "2024-01-01")
                .build();

            let doc = upsert.to_update_one();
            assert_eq!(doc["update"]["$inc"]["visits"], 1);
        }

        #[test]
        fn test_find_one_and_update() {
            let upsert = upsert()
                .filter_eq("email", "test@example.com")
                .set("name", "Updated")
                .build();

            let doc = upsert.to_find_one_and_update(true);
            assert_eq!(doc["options"]["returnDocument"], "after");
            assert!(doc["options"]["upsert"].as_bool().unwrap());
        }

        #[test]
        fn test_bulk_upsert() {
            let mut filter1 = serde_json::Map::new();
            filter1.insert("email".to_string(), serde_json::json!("a@b.com"));

            let mut filter2 = serde_json::Map::new();
            filter2.insert("email".to_string(), serde_json::json!("c@d.com"));

            let bulk = bulk_upsert()
                .ordered(false)
                .add(filter1, serde_json::json!({"$set": {"name": "A"}}))
                .add(filter2, serde_json::json!({"$set": {"name": "B"}}));

            let doc = bulk.to_bulk_write();
            assert_eq!(doc["options"]["ordered"], false);

            let ops = doc["operations"].as_array().unwrap();
            assert_eq!(ops.len(), 2);

            // Each entry must be an `updateOne` with `upsert: true` that carries the
            // caller's filter and update document, preserving insertion order.
            for op in ops {
                assert_eq!(op["updateOne"]["upsert"], true);
            }
            assert_eq!(ops[0]["updateOne"]["filter"]["email"], "a@b.com");
            assert_eq!(ops[0]["updateOne"]["update"]["$set"]["name"], "A");
            assert_eq!(ops[1]["updateOne"]["filter"]["email"], "c@d.com");
            assert_eq!(ops[1]["updateOne"]["update"]["$set"]["name"], "B");
        }

        #[test]
        fn test_bulk_upsert_ordered_without_operations() {
            // An explicitly ordered batch with no operations still renders both keys.
            let doc = bulk_upsert().ordered(true).to_bulk_write();
            assert_eq!(doc["options"]["ordered"], true);
            assert!(doc["operations"].as_array().unwrap().is_empty());
        }
    }
}
1061
1062