1use crate::{
4 builder::SqlFragment,
5 identifier::{escape_ident, from_qi, QualifiedIdentifier},
6 param::SqlParam,
7};
8
9#[derive(Clone, Debug, Default)]
11pub struct InsertBuilder {
12 table: Option<SqlFragment>,
13 columns: Vec<String>,
14 values: Vec<Vec<SqlFragment>>,
15 on_conflict: Option<OnConflict>,
16 returning: Vec<SqlFragment>,
17}
18
19#[derive(Clone, Debug)]
20pub enum OnConflict {
21 DoNothing,
22 DoUpdate {
23 columns: Vec<String>,
24 set: Vec<(String, SqlFragment)>,
25 where_clause: Option<SqlFragment>,
26 },
27}
28
29impl InsertBuilder {
30 pub fn new() -> Self {
32 Self::default()
33 }
34
35 pub fn into_table(mut self, qi: &QualifiedIdentifier) -> Self {
37 self.table = Some(SqlFragment::raw(from_qi(qi)));
38 self
39 }
40
41 pub fn columns(mut self, cols: Vec<String>) -> Self {
43 self.columns = cols;
44 self
45 }
46
47 pub fn values(mut self, vals: Vec<SqlParam>) -> Self {
49 let row: Vec<SqlFragment> = vals
50 .into_iter()
51 .map(|v| {
52 let mut frag = SqlFragment::new();
53 frag.push_param(v);
54 frag
55 })
56 .collect();
57 self.values.push(row);
58 self
59 }
60
61 pub fn values_raw(mut self, vals: Vec<SqlFragment>) -> Self {
63 self.values.push(vals);
64 self
65 }
66
67 pub fn on_conflict_do_nothing(mut self) -> Self {
69 self.on_conflict = Some(OnConflict::DoNothing);
70 self
71 }
72
73 pub fn on_conflict_do_update(
75 mut self,
76 conflict_columns: Vec<String>,
77 set: Vec<(String, SqlFragment)>,
78 ) -> Self {
79 self.on_conflict = Some(OnConflict::DoUpdate {
80 columns: conflict_columns,
81 set,
82 where_clause: None,
83 });
84 self
85 }
86
87 pub fn returning(mut self, column: &str) -> Self {
89 self.returning
90 .push(SqlFragment::raw(escape_ident(column)));
91 self
92 }
93
94 pub fn returning_all(mut self) -> Self {
96 self.returning.push(SqlFragment::raw("*"));
97 self
98 }
99
100 pub fn build(self) -> SqlFragment {
102 let mut result = SqlFragment::new();
103
104 result.push("INSERT INTO ");
105
106 if let Some(table) = self.table {
107 result.append(table);
108 }
109
110 if !self.columns.is_empty() {
112 result.push(" (");
113 for (i, col) in self.columns.iter().enumerate() {
114 if i > 0 {
115 result.push(", ");
116 }
117 result.push(&escape_ident(col));
118 }
119 result.push(")");
120 }
121
122 if !self.values.is_empty() {
124 result.push(" VALUES ");
125 for (i, row) in self.values.into_iter().enumerate() {
126 if i > 0 {
127 result.push(", ");
128 }
129 result.push("(");
130 for (j, val) in row.into_iter().enumerate() {
131 if j > 0 {
132 result.push(", ");
133 }
134 result.append(val);
135 }
136 result.push(")");
137 }
138 } else {
139 result.push(" DEFAULT VALUES");
140 }
141
142 if let Some(conflict) = self.on_conflict {
144 match conflict {
145 OnConflict::DoNothing => {
146 result.push(" ON CONFLICT DO NOTHING");
147 }
148 OnConflict::DoUpdate {
149 columns,
150 set,
151 where_clause,
152 } => {
153 result.push(" ON CONFLICT (");
154 for (i, col) in columns.iter().enumerate() {
155 if i > 0 {
156 result.push(", ");
157 }
158 result.push(&escape_ident(col));
159 }
160 result.push(") DO UPDATE SET ");
161 for (i, (col, val)) in set.into_iter().enumerate() {
162 if i > 0 {
163 result.push(", ");
164 }
165 result.push(&escape_ident(&col));
166 result.push(" = ");
167 result.append(val);
168 }
169 if let Some(where_sql) = where_clause {
170 result.push(" WHERE ");
171 result.append(where_sql);
172 }
173 }
174 }
175 }
176
177 if !self.returning.is_empty() {
179 result.push(" RETURNING ");
180 for (i, ret) in self.returning.into_iter().enumerate() {
181 if i > 0 {
182 result.push(", ");
183 }
184 result.append(ret);
185 }
186 }
187
188 result
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_insert() {
        let qi = QualifiedIdentifier::new("public", "users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into(), "email".into()])
            .values(vec![SqlParam::text("John"), SqlParam::text("john@example.com")])
            .build();

        assert!(sql.sql().starts_with("INSERT INTO "));
        assert!(sql.sql().contains(" VALUES ("));
        assert!(!sql.sql().contains("DEFAULT VALUES"));
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_insert_default_values() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new().into_table(&qi).build();

        assert!(sql.sql().starts_with("INSERT INTO "));
        assert!(sql.sql().ends_with(" DEFAULT VALUES"));
        assert_eq!(sql.params().len(), 0);
    }

    #[test]
    fn test_insert_multiple_rows() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into()])
            .values(vec![SqlParam::text("John")])
            .values(vec![SqlParam::text("Jane")])
            .build();

        // Rows are parenthesised and separated by ", ".
        assert!(sql.sql().contains("), ("));
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_insert_values_raw() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["created_at".into()])
            .values_raw(vec![SqlFragment::raw("DEFAULT")])
            .build();

        assert!(sql.sql().contains(" VALUES (DEFAULT)"));
        assert_eq!(sql.params().len(), 0);
    }

    #[test]
    fn test_insert_returning() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into()])
            .values(vec![SqlParam::text("John")])
            .returning("id")
            .build();

        assert!(sql.sql().contains(" RETURNING "));
        assert_eq!(sql.params().len(), 1);
    }

    #[test]
    fn test_insert_returning_all() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["name".into()])
            .values(vec![SqlParam::text("John")])
            .returning_all()
            .build();

        assert!(sql.sql().ends_with(" RETURNING *"));
    }

    #[test]
    fn test_insert_on_conflict_nothing() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["email".into()])
            .values(vec![SqlParam::text("john@example.com")])
            .on_conflict_do_nothing()
            .build();

        assert!(sql.sql().ends_with(" ON CONFLICT DO NOTHING"));
        assert_eq!(sql.params().len(), 1);
    }

    #[test]
    fn test_insert_upsert() {
        let qi = QualifiedIdentifier::unqualified("users");
        let mut name_val = SqlFragment::new();
        name_val.push("EXCLUDED.\"name\"");

        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["id".into(), "name".into()])
            .values(vec![SqlParam::Int(1), SqlParam::text("John")])
            .on_conflict_do_update(vec!["id".into()], vec![("name".into(), name_val)])
            .build();

        assert!(sql.sql().contains(" ON CONFLICT ("));
        assert!(sql.sql().contains(") DO UPDATE SET "));
        assert!(sql.sql().contains(" = EXCLUDED.\"name\""));
        assert_eq!(sql.params().len(), 2);
    }

    #[test]
    fn test_clause_order() {
        let qi = QualifiedIdentifier::unqualified("users");
        let sql = InsertBuilder::new()
            .into_table(&qi)
            .columns(vec!["email".into()])
            .values(vec![SqlParam::text("john@example.com")])
            .on_conflict_do_nothing()
            .returning("id")
            .build();

        let text = sql.sql();
        let values_at = text.find(" VALUES ").expect("VALUES clause present");
        let conflict_at = text.find(" ON CONFLICT ").expect("ON CONFLICT clause present");
        let returning_at = text.find(" RETURNING ").expect("RETURNING clause present");
        assert!(values_at < conflict_at);
        assert!(conflict_at < returning_at);
    }
}