dbkit_core/
mutation.rs

1use std::marker::PhantomData;
2
3use crate::compile::{CompiledSql, SqlBuilder, ToSql};
4use crate::expr::{ColumnValue, Expr, Value};
5use crate::schema::{Column, ColumnRef, Table};
6
7#[derive(Debug, Clone)]
8pub struct Insert<Out> {
9    table: Table,
10    columns: Vec<ColumnRef>,
11    values: Vec<Value>,
12    row_count: usize,
13    mode: InsertMode,
14    returning: Option<Vec<ColumnRef>>,
15    returning_all: bool,
16    _marker: PhantomData<Out>,
17}
18
/// Records which API an `Insert` builder has been fed through.
///
/// `value()` is rejected once `row()` has been used, so the builder has to
/// remember which style it is in.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InsertMode {
    /// Neither `value()` nor `row()` has been called yet.
    Unset,
    /// Columns were added one at a time with `value()`.
    Values,
    /// At least one row was added with `row()`.
    Rows,
}
25
26impl<Out> Insert<Out> {
27    pub fn new(table: Table) -> Self {
28        Self {
29            table,
30            columns: Vec::new(),
31            values: Vec::new(),
32            row_count: 0,
33            mode: InsertMode::Unset,
34            returning: None,
35            returning_all: false,
36            _marker: PhantomData,
37        }
38    }
39
40    pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
41    where
42        V: ColumnValue<T>,
43    {
44        if self.mode == InsertMode::Rows {
45            panic!("dbkit: cannot use value() after row()");
46        }
47        self.mode = InsertMode::Values;
48        if self.row_count == 0 {
49            self.row_count = 1;
50        }
51        let value = match value.into_value() {
52            Some(value) => value,
53            None => Value::Null,
54        };
55        self.columns.push(column.as_ref());
56        self.values.push(value);
57        self
58    }
59
60    pub fn row<F>(mut self, build: F) -> Self
61    where
62        F: FnOnce(InsertRow) -> InsertRow,
63    {
64        if self.mode == InsertMode::Values {
65            self.mode = InsertMode::Rows;
66        }
67        if self.mode == InsertMode::Unset {
68            self.mode = InsertMode::Rows;
69        }
70
71        let expected = if self.columns.is_empty() {
72            None
73        } else {
74            Some(self.columns.clone())
75        };
76        let row = build(InsertRow::new(expected));
77        if self.columns.is_empty() {
78            self.columns = row.columns.clone();
79        } else if row.columns != self.columns {
80            panic!("dbkit: insert row columns must match");
81        }
82        if row.values.len() != self.columns.len() {
83            panic!("dbkit: insert row value count mismatch");
84        }
85
86        if self.row_count == 0 && !self.values.is_empty() {
87            self.row_count = 1;
88        }
89        self.values.extend(row.values);
90        self.row_count += 1;
91        self
92    }
93
94    pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
95        self.returning = Some(columns);
96        self.returning_all = false;
97        self
98    }
99
100    pub fn returning_all(mut self) -> Self {
101        self.returning = None;
102        self.returning_all = true;
103        self
104    }
105
106    pub fn compile(&self) -> CompiledSql {
107        let mut builder = SqlBuilder::new();
108        builder.push_sql("INSERT INTO ");
109        builder.push_sql(&self.table.qualified_name());
110        builder.push_sql(" (");
111        for (idx, col) in self.columns.iter().enumerate() {
112            if idx > 0 {
113                builder.push_sql(", ");
114            }
115            builder.push_sql(col.name);
116        }
117        builder.push_sql(") VALUES (");
118        let row_len = self.columns.len();
119        let row_count = if self.row_count == 0 && !self.values.is_empty() {
120            1
121        } else {
122            self.row_count
123        };
124        if row_count == 0 {
125            builder.push_sql(")");
126        } else {
127            for row_idx in 0..row_count {
128                if row_idx > 0 {
129                    builder.push_sql(", (");
130                }
131                for col_idx in 0..row_len {
132                    if col_idx > 0 {
133                        builder.push_sql(", ");
134                    }
135                    let value = self.values[row_idx * row_len + col_idx].clone();
136                    builder.push_value(value);
137                }
138                builder.push_sql(")");
139            }
140        }
141        if self.returning_all {
142            builder.push_sql(" RETURNING ");
143            builder.push_sql(self.table.qualifier());
144            builder.push_sql(".*");
145        } else if let Some(columns) = &self.returning {
146            builder.push_sql(" RETURNING ");
147            for (idx, col) in columns.iter().enumerate() {
148                if idx > 0 {
149                    builder.push_sql(", ");
150                }
151                builder.push_column(*col);
152            }
153        }
154        builder.finish()
155    }
156}
157
158#[derive(Debug, Clone)]
159pub struct InsertRow {
160    columns: Vec<ColumnRef>,
161    values: Vec<Value>,
162    expected: Option<Vec<ColumnRef>>,
163}
164
165impl InsertRow {
166    fn new(expected: Option<Vec<ColumnRef>>) -> Self {
167        let columns = expected.clone().unwrap_or_default();
168        Self {
169            columns,
170            values: Vec::new(),
171            expected,
172        }
173    }
174
175    pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
176    where
177        V: ColumnValue<T>,
178    {
179        let column_ref = column.as_ref();
180        if let Some(expected) = &self.expected {
181            let idx = self.values.len();
182            if idx >= expected.len() {
183                panic!("dbkit: insert row has too many values");
184            }
185            if expected[idx] != column_ref {
186                panic!("dbkit: insert row column mismatch");
187            }
188        } else {
189            self.columns.push(column_ref);
190        }
191
192        let value = match value.into_value() {
193            Some(value) => value,
194            None => Value::Null,
195        };
196        self.values.push(value);
197        self
198    }
199}
200
201#[derive(Debug, Clone)]
202pub struct Update<Out> {
203    table: Table,
204    sets: Vec<(ColumnRef, Value)>,
205    filters: Vec<Expr<bool>>,
206    returning: Option<Vec<ColumnRef>>,
207    returning_all: bool,
208    _marker: PhantomData<Out>,
209}
210
211impl<Out> Update<Out> {
212    pub fn new(table: Table) -> Self {
213        Self {
214            table,
215            sets: Vec::new(),
216            filters: Vec::new(),
217            returning: None,
218            returning_all: false,
219            _marker: PhantomData,
220        }
221    }
222
223    pub fn set<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
224    where
225        V: ColumnValue<T>,
226    {
227        let value = match value.into_value() {
228            Some(value) => value,
229            None => Value::Null,
230        };
231        self.sets.push((column.as_ref(), value));
232        self
233    }
234
235    pub fn filter(mut self, expr: Expr<bool>) -> Self {
236        self.filters.push(expr);
237        self
238    }
239
240    pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
241        self.returning = Some(columns);
242        self.returning_all = false;
243        self
244    }
245
246    pub fn returning_all(mut self) -> Self {
247        self.returning = None;
248        self.returning_all = true;
249        self
250    }
251
252    pub fn compile(&self) -> CompiledSql {
253        let mut builder = SqlBuilder::new();
254        builder.push_sql("UPDATE ");
255        builder.push_sql(&self.table.qualified_name());
256        builder.push_sql(" SET ");
257        for (idx, (col, value)) in self.sets.iter().enumerate() {
258            if idx > 0 {
259                builder.push_sql(", ");
260            }
261            builder.push_sql(col.name);
262            builder.push_sql(" = ");
263            builder.push_value(value.clone());
264        }
265        if !self.filters.is_empty() {
266            builder.push_sql(" WHERE ");
267            for (idx, expr) in self.filters.iter().enumerate() {
268                if idx > 0 {
269                    builder.push_sql(" AND ");
270                }
271                expr.node.to_sql(&mut builder);
272            }
273        }
274        if self.returning_all {
275            builder.push_sql(" RETURNING ");
276            builder.push_sql(self.table.qualifier());
277            builder.push_sql(".*");
278        } else if let Some(columns) = &self.returning {
279            builder.push_sql(" RETURNING ");
280            for (idx, col) in columns.iter().enumerate() {
281                if idx > 0 {
282                    builder.push_sql(", ");
283                }
284                builder.push_column(*col);
285            }
286        }
287        builder.finish()
288    }
289}
290
291#[derive(Debug, Clone)]
292pub struct Delete {
293    table: Table,
294    filters: Vec<Expr<bool>>,
295    returning: Option<Vec<ColumnRef>>,
296    returning_all: bool,
297}
298
299impl Delete {
300    pub fn new(table: Table) -> Self {
301        Self {
302            table,
303            filters: Vec::new(),
304            returning: None,
305            returning_all: false,
306        }
307    }
308
309    pub fn filter(mut self, expr: Expr<bool>) -> Self {
310        self.filters.push(expr);
311        self
312    }
313
314    pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
315        self.returning = Some(columns);
316        self.returning_all = false;
317        self
318    }
319
320    pub fn returning_all(mut self) -> Self {
321        self.returning = None;
322        self.returning_all = true;
323        self
324    }
325
326    pub fn compile(&self) -> CompiledSql {
327        let mut builder = SqlBuilder::new();
328        builder.push_sql("DELETE FROM ");
329        builder.push_sql(&self.table.qualified_name());
330        if !self.filters.is_empty() {
331            builder.push_sql(" WHERE ");
332            for (idx, expr) in self.filters.iter().enumerate() {
333                if idx > 0 {
334                    builder.push_sql(" AND ");
335                }
336                expr.node.to_sql(&mut builder);
337            }
338        }
339        if self.returning_all {
340            builder.push_sql(" RETURNING ");
341            builder.push_sql(self.table.qualifier());
342            builder.push_sql(".*");
343        } else if let Some(columns) = &self.returning {
344            builder.push_sql(" RETURNING ");
345            for (idx, col) in columns.iter().enumerate() {
346                if idx > 0 {
347                    builder.push_sql(", ");
348                }
349                builder.push_column(*col);
350            }
351        }
352        builder.finish()
353    }
354}