Skip to main content

dbkit_core/
mutation.rs

1use std::marker::PhantomData;
2
3use crate::compile::{CompiledSql, SqlBuilder, ToSql};
4use crate::expr::{ColumnValue, Expr, Value};
5use crate::schema::{Column, ColumnRef, Table};
6
7#[derive(Debug, Clone)]
8pub struct Insert<Out> {
9    table: Table,
10    columns: Vec<ColumnRef>,
11    values: Vec<Value>,
12    row_count: usize,
13    mode: InsertMode,
14    conflict: Option<InsertConflict>,
15    returning: Option<Vec<ColumnRef>>,
16    returning_all: bool,
17    _marker: PhantomData<Out>,
18}
19
/// Records which construction API an [`Insert`] has been driven through.
///
/// `value()` and `row()` share the same column/value storage, so the builder
/// tracks the mode in order to reject `value()` once `row()` has been used.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum InsertMode {
    /// Neither `value()` nor `row()` has been called yet.
    Unset,
    /// Single-row construction through chained `value()` calls.
    Values,
    /// Multi-row construction through `row()`.
    Rows,
}
26
27#[derive(Debug, Clone, PartialEq, Eq)]
28enum InsertConflict {
29    DoNothing { target: Vec<ColumnRef> },
30    DoUpdate { target: Vec<ColumnRef>, updates: Vec<ColumnRef> },
31}
32
33mod private {
34    pub trait Sealed {}
35}
36
37pub trait ConflictColumns<M>: private::Sealed {
38    fn into_columns(self) -> Vec<ColumnRef>;
39}
40
41impl<M, T> ConflictColumns<M> for Column<M, T> {
42    fn into_columns(self) -> Vec<ColumnRef> {
43        vec![self.as_ref()]
44    }
45}
46impl<M, T> private::Sealed for Column<M, T> {}
47
48macro_rules! impl_conflict_columns_tuple {
49    ($(($($ty:ident:$col:ident),+)),+ $(,)?) => {
50        $(
51            impl<M, $($ty),+> ConflictColumns<M> for ($(Column<M, $ty>,)+) {
52                fn into_columns(self) -> Vec<ColumnRef> {
53                    let ($($col,)+) = self;
54                    vec![$($col.as_ref()),+]
55                }
56            }
57
58            impl<M, $($ty),+> private::Sealed for ($(Column<M, $ty>,)+) {}
59        )+
60    };
61}
62
63impl_conflict_columns_tuple!(
64    (T1:c1, T2:c2),
65    (T1:c1, T2:c2, T3:c3),
66    (T1:c1, T2:c2, T3:c3, T4:c4),
67    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5),
68    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6),
69    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7),
70    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8),
71    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9),
72    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10),
73    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11),
74    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12),
75    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13),
76    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14),
77    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15),
78    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16),
79    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17),
80    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18),
81    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19),
82    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20),
83    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21),
84    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22),
85    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23),
86    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24),
87    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25),
88    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26),
89    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27),
90    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28),
91    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29),
92    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29, T30:c30),
93    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29, T30:c30, T31:c31),
94    (T1:c1, T2:c2, T3:c3, T4:c4, T5:c5, T6:c6, T7:c7, T8:c8, T9:c9, T10:c10, T11:c11, T12:c12, T13:c13, T14:c14, T15:c15, T16:c16, T17:c17, T18:c18, T19:c19, T20:c20, T21:c21, T22:c22, T23:c23, T24:c24, T25:c25, T26:c26, T27:c27, T28:c28, T29:c29, T30:c30, T31:c31, T32:c32)
95);
96
97impl<Out> Insert<Out> {
98    pub fn new(table: Table) -> Self {
99        Self {
100            table,
101            columns: Vec::new(),
102            values: Vec::new(),
103            row_count: 0,
104            mode: InsertMode::Unset,
105            conflict: None,
106            returning: None,
107            returning_all: false,
108            _marker: PhantomData,
109        }
110    }
111
112    pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
113    where
114        V: ColumnValue<T>,
115    {
116        if self.mode == InsertMode::Rows {
117            panic!("dbkit: cannot use value() after row()");
118        }
119        self.mode = InsertMode::Values;
120        if self.row_count == 0 {
121            self.row_count = 1;
122        }
123        let value = match value.into_value() {
124            Some(value) => value,
125            None => Value::Null,
126        };
127        self.columns.push(column.as_ref());
128        self.values.push(value);
129        self
130    }
131
132    pub fn row<F>(mut self, build: F) -> Self
133    where
134        F: FnOnce(InsertRow) -> InsertRow,
135    {
136        if self.mode == InsertMode::Values {
137            self.mode = InsertMode::Rows;
138        }
139        if self.mode == InsertMode::Unset {
140            self.mode = InsertMode::Rows;
141        }
142
143        let expected = if self.columns.is_empty() {
144            None
145        } else {
146            Some(self.columns.clone())
147        };
148        let row = build(InsertRow::new(expected));
149        if self.columns.is_empty() {
150            self.columns = row.columns.clone();
151        } else if row.columns != self.columns {
152            panic!("dbkit: insert row columns must match");
153        }
154        if row.values.len() != self.columns.len() {
155            panic!("dbkit: insert row value count mismatch");
156        }
157
158        if self.row_count == 0 && !self.values.is_empty() {
159            self.row_count = 1;
160        }
161        self.values.extend(row.values);
162        self.row_count += 1;
163        self
164    }
165
166    pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
167        self.returning = Some(columns);
168        self.returning_all = false;
169        self
170    }
171
172    pub fn returning_all(mut self) -> Self {
173        self.returning = None;
174        self.returning_all = true;
175        self
176    }
177
178    pub fn on_conflict_do_nothing<M, C>(mut self, target: C) -> Self
179    where
180        C: ConflictColumns<M>,
181    {
182        self.conflict = Some(InsertConflict::DoNothing {
183            target: target.into_columns(),
184        });
185        self
186    }
187
188    pub fn on_conflict_do_update<M, C, U>(mut self, target: C, updates: U) -> Self
189    where
190        C: ConflictColumns<M>,
191        U: ConflictColumns<M>,
192    {
193        self.conflict = Some(InsertConflict::DoUpdate {
194            target: target.into_columns(),
195            updates: updates.into_columns(),
196        });
197        self
198    }
199
200    pub fn compile(&self) -> CompiledSql {
201        let mut builder = SqlBuilder::new();
202        builder.push_sql("INSERT INTO ");
203        builder.push_sql(&self.table.qualified_name());
204        builder.push_sql(" (");
205        for (idx, col) in self.columns.iter().enumerate() {
206            if idx > 0 {
207                builder.push_sql(", ");
208            }
209            builder.push_sql(col.name);
210        }
211        builder.push_sql(") VALUES (");
212        let row_len = self.columns.len();
213        let row_count = if self.row_count == 0 && !self.values.is_empty() {
214            1
215        } else {
216            self.row_count
217        };
218        if row_count == 0 {
219            builder.push_sql(")");
220        } else {
221            for row_idx in 0..row_count {
222                if row_idx > 0 {
223                    builder.push_sql(", (");
224                }
225                for col_idx in 0..row_len {
226                    if col_idx > 0 {
227                        builder.push_sql(", ");
228                    }
229                    let value = self.values[row_idx * row_len + col_idx].clone();
230                    builder.push_value(value);
231                }
232                builder.push_sql(")");
233            }
234        }
235        if let Some(conflict) = &self.conflict {
236            let target = match conflict {
237                InsertConflict::DoNothing { target } => target,
238                InsertConflict::DoUpdate { target, .. } => target,
239            };
240            builder.push_sql(" ON CONFLICT (");
241            for (idx, col) in target.iter().enumerate() {
242                if idx > 0 {
243                    builder.push_sql(", ");
244                }
245                builder.push_sql(col.name);
246            }
247            builder.push_sql(")");
248
249            match conflict {
250                InsertConflict::DoNothing { .. } => {
251                    builder.push_sql(" DO NOTHING");
252                }
253                InsertConflict::DoUpdate { updates, .. } => {
254                    builder.push_sql(" DO UPDATE SET ");
255                    for (idx, col) in updates.iter().enumerate() {
256                        if idx > 0 {
257                            builder.push_sql(", ");
258                        }
259                        builder.push_sql(col.name);
260                        builder.push_sql(" = EXCLUDED.");
261                        builder.push_sql(col.name);
262                    }
263                }
264            }
265        }
266        if self.returning_all {
267            builder.push_sql(" RETURNING ");
268            builder.push_sql(self.table.qualifier());
269            builder.push_sql(".*");
270        } else if let Some(columns) = &self.returning {
271            builder.push_sql(" RETURNING ");
272            for (idx, col) in columns.iter().enumerate() {
273                if idx > 0 {
274                    builder.push_sql(", ");
275                }
276                builder.push_column(*col);
277            }
278        }
279        builder.finish()
280    }
281}
282
283#[derive(Debug, Clone)]
284pub struct InsertRow {
285    columns: Vec<ColumnRef>,
286    values: Vec<Value>,
287    expected: Option<Vec<ColumnRef>>,
288}
289
290impl InsertRow {
291    fn new(expected: Option<Vec<ColumnRef>>) -> Self {
292        let columns = expected.clone().unwrap_or_default();
293        Self {
294            columns,
295            values: Vec::new(),
296            expected,
297        }
298    }
299
300    pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
301    where
302        V: ColumnValue<T>,
303    {
304        let column_ref = column.as_ref();
305        if let Some(expected) = &self.expected {
306            let idx = self.values.len();
307            if idx >= expected.len() {
308                panic!("dbkit: insert row has too many values");
309            }
310            if expected[idx] != column_ref {
311                panic!("dbkit: insert row column mismatch");
312            }
313        } else {
314            self.columns.push(column_ref);
315        }
316
317        let value = match value.into_value() {
318            Some(value) => value,
319            None => Value::Null,
320        };
321        self.values.push(value);
322        self
323    }
324}
325
326#[derive(Debug, Clone)]
327pub struct Update<Out> {
328    table: Table,
329    sets: Vec<(ColumnRef, Value)>,
330    filters: Vec<Expr<bool>>,
331    returning: Option<Vec<ColumnRef>>,
332    returning_all: bool,
333    _marker: PhantomData<Out>,
334}
335
336impl<Out> Update<Out> {
337    pub fn new(table: Table) -> Self {
338        Self {
339            table,
340            sets: Vec::new(),
341            filters: Vec::new(),
342            returning: None,
343            returning_all: false,
344            _marker: PhantomData,
345        }
346    }
347
348    pub fn set<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
349    where
350        V: ColumnValue<T>,
351    {
352        let value = match value.into_value() {
353            Some(value) => value,
354            None => Value::Null,
355        };
356        self.sets.push((column.as_ref(), value));
357        self
358    }
359
360    pub fn filter(mut self, expr: Expr<bool>) -> Self {
361        self.filters.push(expr);
362        self
363    }
364
365    pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
366        self.returning = Some(columns);
367        self.returning_all = false;
368        self
369    }
370
371    pub fn returning_all(mut self) -> Self {
372        self.returning = None;
373        self.returning_all = true;
374        self
375    }
376
377    pub fn compile(&self) -> CompiledSql {
378        let mut builder = SqlBuilder::new();
379        builder.push_sql("UPDATE ");
380        builder.push_sql(&self.table.qualified_name());
381        builder.push_sql(" SET ");
382        for (idx, (col, value)) in self.sets.iter().enumerate() {
383            if idx > 0 {
384                builder.push_sql(", ");
385            }
386            builder.push_sql(col.name);
387            builder.push_sql(" = ");
388            builder.push_value(value.clone());
389        }
390        if !self.filters.is_empty() {
391            builder.push_sql(" WHERE ");
392            for (idx, expr) in self.filters.iter().enumerate() {
393                if idx > 0 {
394                    builder.push_sql(" AND ");
395                }
396                expr.node.to_sql(&mut builder);
397            }
398        }
399        if self.returning_all {
400            builder.push_sql(" RETURNING ");
401            builder.push_sql(self.table.qualifier());
402            builder.push_sql(".*");
403        } else if let Some(columns) = &self.returning {
404            builder.push_sql(" RETURNING ");
405            for (idx, col) in columns.iter().enumerate() {
406                if idx > 0 {
407                    builder.push_sql(", ");
408                }
409                builder.push_column(*col);
410            }
411        }
412        builder.finish()
413    }
414}
415
416#[derive(Debug, Clone)]
417pub struct Delete {
418    table: Table,
419    filters: Vec<Expr<bool>>,
420    returning: Option<Vec<ColumnRef>>,
421    returning_all: bool,
422}
423
424impl Delete {
425    pub fn new(table: Table) -> Self {
426        Self {
427            table,
428            filters: Vec::new(),
429            returning: None,
430            returning_all: false,
431        }
432    }
433
434    pub fn filter(mut self, expr: Expr<bool>) -> Self {
435        self.filters.push(expr);
436        self
437    }
438
439    pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
440        self.returning = Some(columns);
441        self.returning_all = false;
442        self
443    }
444
445    pub fn returning_all(mut self) -> Self {
446        self.returning = None;
447        self.returning_all = true;
448        self
449    }
450
451    pub fn compile(&self) -> CompiledSql {
452        let mut builder = SqlBuilder::new();
453        builder.push_sql("DELETE FROM ");
454        builder.push_sql(&self.table.qualified_name());
455        if !self.filters.is_empty() {
456            builder.push_sql(" WHERE ");
457            for (idx, expr) in self.filters.iter().enumerate() {
458                if idx > 0 {
459                    builder.push_sql(" AND ");
460                }
461                expr.node.to_sql(&mut builder);
462            }
463        }
464        if self.returning_all {
465            builder.push_sql(" RETURNING ");
466            builder.push_sql(self.table.qualifier());
467            builder.push_sql(".*");
468        } else if let Some(columns) = &self.returning {
469            builder.push_sql(" RETURNING ");
470            for (idx, col) in columns.iter().enumerate() {
471                if idx > 0 {
472                    builder.push_sql(", ");
473                }
474                builder.push_column(*col);
475            }
476        }
477        builder.finish()
478    }
479}