1use std::marker::PhantomData;
2
3use crate::compile::{CompiledSql, SqlBuilder, ToSql};
4use crate::expr::{ColumnValue, Expr, Value};
5use crate::schema::{Column, ColumnRef, Table};
6
/// Builder for an `INSERT INTO <table> (...) VALUES (...)` statement.
///
/// `Out` is only a type-level marker stored as `PhantomData`; this file never reads it.
/// Presumably it names the row type produced by `RETURNING` (TODO confirm).
/// NOTE(review): `#[derive(Clone)]` adds an `Out: Clone` bound to the generated impl.
#[derive(Debug, Clone)]
pub struct Insert<Out> {
    /// Target table.
    table: Table,
    /// Column list shared by every row, in insertion order.
    columns: Vec<ColumnRef>,
    /// All rows flattened row-major: row `r`, column `c` is at `r * columns.len() + c`.
    values: Vec<Value>,
    /// Number of rows stored in `values`.
    row_count: usize,
    /// Which builder API (`value()` or `row()`) has populated the statement so far.
    mode: InsertMode,
    /// Optional `ON CONFLICT` clause.
    conflict: Option<InsertConflict>,
    /// Explicit `RETURNING` columns; not rendered when `returning_all` is set.
    returning: Option<Vec<ColumnRef>>,
    /// Render `RETURNING <qualifier>.*` instead of an explicit column list.
    returning_all: bool,
    _marker: PhantomData<Out>,
}
19
/// Tracks which `Insert` builder API has populated the statement.
///
/// `Insert::value` may be followed by `Insert::row`, whose first row is then the one built
/// by `value()`. Calling `value()` after `row()` panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InsertMode {
    /// Neither `value()` nor `row()` has been called yet.
    Unset,
    /// A single row is being assembled column-by-column through `Insert::value`.
    Values,
    /// At least one row has been added through `Insert::row`.
    Rows,
}
26
/// Action for the `ON CONFLICT (<target>)` clause rendered by `Insert::compile`.
#[derive(Debug, Clone, PartialEq, Eq)]
enum InsertConflict {
    /// Renders `ON CONFLICT (target) DO NOTHING`.
    DoNothing { target: Vec<ColumnRef> },
    /// Renders `ON CONFLICT (target) DO UPDATE SET col = EXCLUDED.col` for each column in `updates`.
    DoUpdate { target: Vec<ColumnRef>, updates: Vec<ColumnRef> },
}
32
/// Sealing module.
///
/// `private` is not visible outside its parent module, so `Sealed` cannot be named
/// there. As a result, `ConflictColumns` (which requires it) can only be implemented here.
mod private {
    pub trait Sealed {}
}
36
/// A set of columns usable as an `ON CONFLICT` target or a `DO UPDATE SET` list.
///
/// This module implements it for a single `Column<M, T>` and for tuples of 2 to 32
/// columns that share the same `M`. Because of that shared `M`, a target and an update
/// list passed together must use the same `M` marker. The trait is sealed: no other
/// implementations are possible.
pub trait ConflictColumns<M>: private::Sealed {
    /// Converts the column(s) into untyped `ColumnRef`s, preserving their order.
    fn into_columns(self) -> Vec<ColumnRef>;
}
40
41impl<M, T> ConflictColumns<M> for Column<M, T> {
42 fn into_columns(self) -> Vec<ColumnRef> {
43 vec![self.as_ref()]
44 }
45}
46impl<M, T> private::Sealed for Column<M, T> {}
47
/// Implements `ConflictColumns` and the sealing marker for tuples of `Column`s.
///
/// Each parenthesised group `(T1:c1, T2:c2, ...)` produces impls for the tuple type
/// `(Column<M, T1>, Column<M, T2>, ...)`:
/// - the `$ty` idents become the second type parameter of each `Column`;
/// - the `$col` idents name the bindings used to destructure the tuple;
/// - every element shares the same first parameter `M`.
macro_rules! impl_conflict_columns_tuple {
    ($(($($ty:ident:$col:ident),+)),+ $(,)?) => {
        $(
            impl<M, $($ty),+> ConflictColumns<M> for ($(Column<M, $ty>,)+) {
                fn into_columns(self) -> Vec<ColumnRef> {
                    // Destructure in declaration order so the result keeps tuple order.
                    let ($($col,)+) = self;
                    vec![$($col.as_ref()),+]
                }
            }

            impl<M, $($ty),+> private::Sealed for ($(Column<M, $ty>,)+) {}
        )+
    };
}
62
// Tuple arities 2 through 32. Arity 1 is handled by the direct `Column<M, T>` impl,
// so a bare column (not a 1-tuple) is passed for a single-column target.
impl_conflict_columns_tuple!(
    (T1:c1, T2:c2),
    (T1:c1, T2:c2, T3:c3),
    (T1:c1, T2:c2, T3:c3, T4:c4),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29, T30:c30),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29, T30:c30, T31:c31),
    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29, T30:c30, T31:c31, T32:c32)
);
96
97impl<Out> Insert<Out> {
98 pub fn new(table: Table) -> Self {
99 Self {
100 table,
101 columns: Vec::new(),
102 values: Vec::new(),
103 row_count: 0,
104 mode: InsertMode::Unset,
105 conflict: None,
106 returning: None,
107 returning_all: false,
108 _marker: PhantomData,
109 }
110 }
111
112 pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
113 where
114 V: ColumnValue<T>,
115 {
116 if self.mode == InsertMode::Rows {
117 panic!("dbkit: cannot use value() after row()");
118 }
119 self.mode = InsertMode::Values;
120 if self.row_count == 0 {
121 self.row_count = 1;
122 }
123 let value = match value.into_value() {
124 Some(value) => value,
125 None => Value::Null,
126 };
127 self.columns.push(column.as_ref());
128 self.values.push(value);
129 self
130 }
131
132 pub fn row<F>(mut self, build: F) -> Self
133 where
134 F: FnOnce(InsertRow) -> InsertRow,
135 {
136 if self.mode == InsertMode::Values {
137 self.mode = InsertMode::Rows;
138 }
139 if self.mode == InsertMode::Unset {
140 self.mode = InsertMode::Rows;
141 }
142
143 let expected = if self.columns.is_empty() {
144 None
145 } else {
146 Some(self.columns.clone())
147 };
148 let row = build(InsertRow::new(expected));
149 if self.columns.is_empty() {
150 self.columns = row.columns.clone();
151 } else if row.columns != self.columns {
152 panic!("dbkit: insert row columns must match");
153 }
154 if row.values.len() != self.columns.len() {
155 panic!("dbkit: insert row value count mismatch");
156 }
157
158 if self.row_count == 0 && !self.values.is_empty() {
159 self.row_count = 1;
160 }
161 self.values.extend(row.values);
162 self.row_count += 1;
163 self
164 }
165
166 pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
167 self.returning = Some(columns);
168 self.returning_all = false;
169 self
170 }
171
172 pub fn returning_all(mut self) -> Self {
173 self.returning = None;
174 self.returning_all = true;
175 self
176 }
177
178 pub fn on_conflict_do_nothing<M, C>(mut self, target: C) -> Self
179 where
180 C: ConflictColumns<M>,
181 {
182 self.conflict = Some(InsertConflict::DoNothing {
183 target: target.into_columns(),
184 });
185 self
186 }
187
188 pub fn on_conflict_do_update<M, C, U>(mut self, target: C, updates: U) -> Self
189 where
190 C: ConflictColumns<M>,
191 U: ConflictColumns<M>,
192 {
193 self.conflict = Some(InsertConflict::DoUpdate {
194 target: target.into_columns(),
195 updates: updates.into_columns(),
196 });
197 self
198 }
199
200 pub fn compile(&self) -> CompiledSql {
201 let mut builder = SqlBuilder::new();
202 builder.push_sql("INSERT INTO ");
203 builder.push_sql(&self.table.qualified_name());
204 builder.push_sql(" (");
205 for (idx, col) in self.columns.iter().enumerate() {
206 if idx > 0 {
207 builder.push_sql(", ");
208 }
209 builder.push_sql(col.name);
210 }
211 builder.push_sql(") VALUES (");
212 let row_len = self.columns.len();
213 let row_count = if self.row_count == 0 && !self.values.is_empty() {
214 1
215 } else {
216 self.row_count
217 };
218 if row_count == 0 {
219 builder.push_sql(")");
220 } else {
221 for row_idx in 0..row_count {
222 if row_idx > 0 {
223 builder.push_sql(", (");
224 }
225 for col_idx in 0..row_len {
226 if col_idx > 0 {
227 builder.push_sql(", ");
228 }
229 let value = self.values[row_idx * row_len + col_idx].clone();
230 builder.push_value(value);
231 }
232 builder.push_sql(")");
233 }
234 }
235 if let Some(conflict) = &self.conflict {
236 let target = match conflict {
237 InsertConflict::DoNothing { target } => target,
238 InsertConflict::DoUpdate { target, .. } => target,
239 };
240 builder.push_sql(" ON CONFLICT (");
241 for (idx, col) in target.iter().enumerate() {
242 if idx > 0 {
243 builder.push_sql(", ");
244 }
245 builder.push_sql(col.name);
246 }
247 builder.push_sql(")");
248
249 match conflict {
250 InsertConflict::DoNothing { .. } => {
251 builder.push_sql(" DO NOTHING");
252 }
253 InsertConflict::DoUpdate { updates, .. } => {
254 builder.push_sql(" DO UPDATE SET ");
255 for (idx, col) in updates.iter().enumerate() {
256 if idx > 0 {
257 builder.push_sql(", ");
258 }
259 builder.push_sql(col.name);
260 builder.push_sql(" = EXCLUDED.");
261 builder.push_sql(col.name);
262 }
263 }
264 }
265 }
266 if self.returning_all {
267 builder.push_sql(" RETURNING ");
268 builder.push_sql(self.table.qualifier());
269 builder.push_sql(".*");
270 } else if let Some(columns) = &self.returning {
271 builder.push_sql(" RETURNING ");
272 for (idx, col) in columns.iter().enumerate() {
273 if idx > 0 {
274 builder.push_sql(", ");
275 }
276 builder.push_column(*col);
277 }
278 }
279 builder.finish()
280 }
281}
282
/// One row being assembled inside an `Insert::row` closure.
#[derive(Debug, Clone)]
pub struct InsertRow {
    /// This row's columns. Pre-filled from `expected` when one is given; otherwise each
    /// `value()` call appends to it.
    columns: Vec<ColumnRef>,
    /// Values in column order; a `None` from `into_value` is stored as `Value::Null`.
    values: Vec<Value>,
    /// Column list this row must follow in order, when the statement already has one.
    expected: Option<Vec<ColumnRef>>,
}
289
290impl InsertRow {
291 fn new(expected: Option<Vec<ColumnRef>>) -> Self {
292 let columns = expected.clone().unwrap_or_default();
293 Self {
294 columns,
295 values: Vec::new(),
296 expected,
297 }
298 }
299
300 pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
301 where
302 V: ColumnValue<T>,
303 {
304 let column_ref = column.as_ref();
305 if let Some(expected) = &self.expected {
306 let idx = self.values.len();
307 if idx >= expected.len() {
308 panic!("dbkit: insert row has too many values");
309 }
310 if expected[idx] != column_ref {
311 panic!("dbkit: insert row column mismatch");
312 }
313 } else {
314 self.columns.push(column_ref);
315 }
316
317 let value = match value.into_value() {
318 Some(value) => value,
319 None => Value::Null,
320 };
321 self.values.push(value);
322 self
323 }
324}
325
/// Builder for an `UPDATE <table> SET ... [WHERE ...] [RETURNING ...]` statement.
///
/// `Out` is only a type-level marker stored as `PhantomData`; this file never reads it.
/// Presumably it names the row type produced by `RETURNING` (TODO confirm).
/// NOTE(review): `#[derive(Clone)]` adds an `Out: Clone` bound to the generated impl.
#[derive(Debug, Clone)]
pub struct Update<Out> {
    /// Target table.
    table: Table,
    /// `column = value` assignments, rendered in the order they were added.
    sets: Vec<(ColumnRef, Value)>,
    /// Predicates joined with ` AND ` into the `WHERE` clause.
    filters: Vec<Expr<bool>>,
    /// Explicit `RETURNING` columns; not rendered when `returning_all` is set.
    returning: Option<Vec<ColumnRef>>,
    /// Render `RETURNING <qualifier>.*` instead of an explicit column list.
    returning_all: bool,
    _marker: PhantomData<Out>,
}
335
336impl<Out> Update<Out> {
337 pub fn new(table: Table) -> Self {
338 Self {
339 table,
340 sets: Vec::new(),
341 filters: Vec::new(),
342 returning: None,
343 returning_all: false,
344 _marker: PhantomData,
345 }
346 }
347
348 pub fn set<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
349 where
350 V: ColumnValue<T>,
351 {
352 let value = match value.into_value() {
353 Some(value) => value,
354 None => Value::Null,
355 };
356 self.sets.push((column.as_ref(), value));
357 self
358 }
359
360 pub fn filter(mut self, expr: Expr<bool>) -> Self {
361 self.filters.push(expr);
362 self
363 }
364
365 pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
366 self.returning = Some(columns);
367 self.returning_all = false;
368 self
369 }
370
371 pub fn returning_all(mut self) -> Self {
372 self.returning = None;
373 self.returning_all = true;
374 self
375 }
376
377 pub fn compile(&self) -> CompiledSql {
378 let mut builder = SqlBuilder::new();
379 builder.push_sql("UPDATE ");
380 builder.push_sql(&self.table.qualified_name());
381 builder.push_sql(" SET ");
382 for (idx, (col, value)) in self.sets.iter().enumerate() {
383 if idx > 0 {
384 builder.push_sql(", ");
385 }
386 builder.push_sql(col.name);
387 builder.push_sql(" = ");
388 builder.push_value(value.clone());
389 }
390 if !self.filters.is_empty() {
391 builder.push_sql(" WHERE ");
392 for (idx, expr) in self.filters.iter().enumerate() {
393 if idx > 0 {
394 builder.push_sql(" AND ");
395 }
396 expr.node.to_sql(&mut builder);
397 }
398 }
399 if self.returning_all {
400 builder.push_sql(" RETURNING ");
401 builder.push_sql(self.table.qualifier());
402 builder.push_sql(".*");
403 } else if let Some(columns) = &self.returning {
404 builder.push_sql(" RETURNING ");
405 for (idx, col) in columns.iter().enumerate() {
406 if idx > 0 {
407 builder.push_sql(", ");
408 }
409 builder.push_column(*col);
410 }
411 }
412 builder.finish()
413 }
414}
415
/// Builder for a `DELETE FROM <table> [WHERE ...] [RETURNING ...]` statement.
#[derive(Debug, Clone)]
pub struct Delete {
    /// Target table.
    table: Table,
    /// Predicates joined with ` AND ` into the `WHERE` clause. When empty, no `WHERE`
    /// clause is rendered.
    filters: Vec<Expr<bool>>,
    /// Explicit `RETURNING` columns; not rendered when `returning_all` is set.
    returning: Option<Vec<ColumnRef>>,
    /// Render `RETURNING <qualifier>.*` instead of an explicit column list.
    returning_all: bool,
}
423
424impl Delete {
425 pub fn new(table: Table) -> Self {
426 Self {
427 table,
428 filters: Vec::new(),
429 returning: None,
430 returning_all: false,
431 }
432 }
433
434 pub fn filter(mut self, expr: Expr<bool>) -> Self {
435 self.filters.push(expr);
436 self
437 }
438
439 pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
440 self.returning = Some(columns);
441 self.returning_all = false;
442 self
443 }
444
445 pub fn returning_all(mut self) -> Self {
446 self.returning = None;
447 self.returning_all = true;
448 self
449 }
450
451 pub fn compile(&self) -> CompiledSql {
452 let mut builder = SqlBuilder::new();
453 builder.push_sql("DELETE FROM ");
454 builder.push_sql(&self.table.qualified_name());
455 if !self.filters.is_empty() {
456 builder.push_sql(" WHERE ");
457 for (idx, expr) in self.filters.iter().enumerate() {
458 if idx > 0 {
459 builder.push_sql(" AND ");
460 }
461 expr.node.to_sql(&mut builder);
462 }
463 }
464 if self.returning_all {
465 builder.push_sql(" RETURNING ");
466 builder.push_sql(self.table.qualifier());
467 builder.push_sql(".*");
468 } else if let Some(columns) = &self.returning {
469 builder.push_sql(" RETURNING ");
470 for (idx, col) in columns.iter().enumerate() {
471 if idx > 0 {
472 builder.push_sql(", ");
473 }
474 builder.push_column(*col);
475 }
476 }
477 builder.finish()
478 }
479}