postrust_sql/
insert.rs

//! Builder for `INSERT` statements with optional `ON CONFLICT` and `RETURNING` clauses.
2
3use crate::{
4    builder::SqlFragment,
5    identifier::{escape_ident, from_qi, QualifiedIdentifier},
6    param::SqlParam,
7};
8
9/// Builder for INSERT statements.
10#[derive(Clone, Debug, Default)]
11pub struct InsertBuilder {
12    table: Option<SqlFragment>,
13    columns: Vec<String>,
14    values: Vec<Vec<SqlFragment>>,
15    on_conflict: Option<OnConflict>,
16    returning: Vec<SqlFragment>,
17}
18
19#[derive(Clone, Debug)]
20pub enum OnConflict {
21    DoNothing,
22    DoUpdate {
23        columns: Vec<String>,
24        set: Vec<(String, SqlFragment)>,
25        where_clause: Option<SqlFragment>,
26    },
27}
28
29impl InsertBuilder {
30    /// Create a new INSERT builder.
31    pub fn new() -> Self {
32        Self::default()
33    }
34
35    /// Set the target table.
36    pub fn into_table(mut self, qi: &QualifiedIdentifier) -> Self {
37        self.table = Some(SqlFragment::raw(from_qi(qi)));
38        self
39    }
40
41    /// Set the columns to insert.
42    pub fn columns(mut self, cols: Vec<String>) -> Self {
43        self.columns = cols;
44        self
45    }
46
47    /// Add a row of values.
48    pub fn values(mut self, vals: Vec<SqlParam>) -> Self {
49        let row: Vec<SqlFragment> = vals
50            .into_iter()
51            .map(|v| {
52                let mut frag = SqlFragment::new();
53                frag.push_param(v);
54                frag
55            })
56            .collect();
57        self.values.push(row);
58        self
59    }
60
61    /// Add a row of raw SQL values.
62    pub fn values_raw(mut self, vals: Vec<SqlFragment>) -> Self {
63        self.values.push(vals);
64        self
65    }
66
67    /// Set ON CONFLICT DO NOTHING.
68    pub fn on_conflict_do_nothing(mut self) -> Self {
69        self.on_conflict = Some(OnConflict::DoNothing);
70        self
71    }
72
73    /// Set ON CONFLICT DO UPDATE.
74    pub fn on_conflict_do_update(
75        mut self,
76        conflict_columns: Vec<String>,
77        set: Vec<(String, SqlFragment)>,
78    ) -> Self {
79        self.on_conflict = Some(OnConflict::DoUpdate {
80            columns: conflict_columns,
81            set,
82            where_clause: None,
83        });
84        self
85    }
86
87    /// Add RETURNING clause.
88    pub fn returning(mut self, column: &str) -> Self {
89        self.returning
90            .push(SqlFragment::raw(escape_ident(column)));
91        self
92    }
93
94    /// Add RETURNING * clause.
95    pub fn returning_all(mut self) -> Self {
96        self.returning.push(SqlFragment::raw("*"));
97        self
98    }
99
100    /// Build the INSERT statement.
101    pub fn build(self) -> SqlFragment {
102        let mut result = SqlFragment::new();
103
104        result.push("INSERT INTO ");
105
106        if let Some(table) = self.table {
107            result.append(table);
108        }
109
110        // Columns
111        if !self.columns.is_empty() {
112            result.push(" (");
113            for (i, col) in self.columns.iter().enumerate() {
114                if i > 0 {
115                    result.push(", ");
116                }
117                result.push(&escape_ident(col));
118            }
119            result.push(")");
120        }
121
122        // VALUES
123        if !self.values.is_empty() {
124            result.push(" VALUES ");
125            for (i, row) in self.values.into_iter().enumerate() {
126                if i > 0 {
127                    result.push(", ");
128                }
129                result.push("(");
130                for (j, val) in row.into_iter().enumerate() {
131                    if j > 0 {
132                        result.push(", ");
133                    }
134                    result.append(val);
135                }
136                result.push(")");
137            }
138        } else {
139            result.push(" DEFAULT VALUES");
140        }
141
142        // ON CONFLICT
143        if let Some(conflict) = self.on_conflict {
144            match conflict {
145                OnConflict::DoNothing => {
146                    result.push(" ON CONFLICT DO NOTHING");
147                }
148                OnConflict::DoUpdate {
149                    columns,
150                    set,
151                    where_clause,
152                } => {
153                    result.push(" ON CONFLICT (");
154                    for (i, col) in columns.iter().enumerate() {
155                        if i > 0 {
156                            result.push(", ");
157                        }
158                        result.push(&escape_ident(col));
159                    }
160                    result.push(") DO UPDATE SET ");
161                    for (i, (col, val)) in set.into_iter().enumerate() {
162                        if i > 0 {
163                            result.push(", ");
164                        }
165                        result.push(&escape_ident(&col));
166                        result.push(" = ");
167                        result.append(val);
168                    }
169                    if let Some(where_sql) = where_clause {
170                        result.push(" WHERE ");
171                        result.append(where_sql);
172                    }
173                }
174            }
175        }
176
177        // RETURNING
178        if !self.returning.is_empty() {
179            result.push(" RETURNING ");
180            for (i, ret) in self.returning.into_iter().enumerate() {
181                if i > 0 {
182                    result.push(", ");
183                }
184                result.append(ret);
185            }
186        }
187
188        result
189    }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    // Expected strings are built from `from_qi` / `escape_ident` so the tests pin
    // the statement shape without hard-coding the quoting rules. Parameter
    // placeholder text belongs to `SqlFragment` and is deliberately not asserted.

    #[test]
    fn test_simple_insert() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into(), "email".into()])
            .values(vec![SqlParam::text("John"), SqlParam::text("john@example.com")])
            .build();

        let expected_prefix = format!(
            "INSERT INTO {} ({}, {}) VALUES (",
            from_qi(&qi),
            escape_ident("name"),
            escape_ident("email")
        );
        assert!(
            sql.sql().starts_with(&expected_prefix),
            "unexpected SQL: {}",
            sql.sql()
        );
        assert!(sql.sql().ends_with(')'));
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_multi_row_insert() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into()])
            .values(vec![SqlParam::text("John")])
            .values(vec![SqlParam::text("Jane")])
            .build();

        assert!(sql.sql().contains("), ("), "unexpected SQL: {}", sql.sql());
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_insert_default_values() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new().into_table(&qi).build();

        let expected = format!("INSERT INTO {} DEFAULT VALUES", from_qi(&qi));
        assert_eq!(sql.sql(), expected.as_str());
        assert_eq!(sql.params().len(), 0);
    }

    #[test]
    fn test_insert_returning() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into()])
            .values(vec![SqlParam::text("John")])
            .returning("id")
            .build();

        let expected_suffix = format!(" RETURNING {}", escape_ident("id"));
        assert!(
            sql.sql().ends_with(&expected_suffix),
            "unexpected SQL: {}",
            sql.sql()
        );
    }

    #[test]
    fn test_insert_returning_all() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into()])
            .values(vec![SqlParam::text("John")])
            .returning_all()
            .build();

        assert!(sql.sql().ends_with(" RETURNING *"), "unexpected SQL: {}", sql.sql());
    }

    #[test]
    fn test_insert_on_conflict_nothing() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["email".into()])
            .values(vec![SqlParam::text("john@example.com")])
            .on_conflict_do_nothing()
            .build();

        assert!(
            sql.sql().ends_with(" ON CONFLICT DO NOTHING"),
            "unexpected SQL: {}",
            sql.sql()
        );
    }

    #[test]
    fn test_insert_upsert() {
        let qi = QualifiedIdentifier::unqualified("users");
        let mut name_val = SqlFragment::new();
        name_val.push("EXCLUDED.\"name\"");

        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["id".into(), "name".into()])
            .values(vec![SqlParam::Int(1), SqlParam::text("John")])
            .on_conflict_do_update(vec!["id".into()], vec![("name".into(), name_val)])
            .build();

        let expected_suffix = format!(
            " ON CONFLICT ({}) DO UPDATE SET {} = EXCLUDED.\"name\"",
            escape_ident("id"),
            escape_ident("name")
        );
        assert!(
            sql.sql().ends_with(&expected_suffix),
            "unexpected SQL: {}",
            sql.sql()
        );
        assert_eq!(sql.params().len(), 2);
    }
}
252}