1use std::marker::PhantomData;
2
3use crate::compile::{CompiledSql, SqlBuilder, ToSql};
4use crate::expr::{ColumnValue, Expr, Value};
5use crate::schema::{Column, ColumnRef, Table};
6
/// Builder for an `INSERT` statement targeting a single table.
///
/// Rows are added either with repeated `value()` calls (one row) or with
/// `row()` closures (one or more rows); see `InsertMode`.
///
/// `Out` is only a type-level marker (held in `PhantomData`); nothing in this
/// file reads it. Presumably it is the type decoded from a `RETURNING`
/// result — confirm with the executor.
#[derive(Debug, Clone)]
pub struct Insert<Out> {
    table: Table,
    // Column list, in the order emitted in `INSERT INTO t (...)`.
    columns: Vec<ColumnRef>,
    // Row-major, flattened values: `columns.len()` entries per row.
    values: Vec<Value>,
    // Number of rows stored in `values`.
    row_count: usize,
    // Which builder API (value()/row()) has been used so far.
    mode: InsertMode,
    // Explicit `RETURNING` columns; `compile` checks `returning_all` first.
    returning: Option<Vec<ColumnRef>>,
    // When true, `compile` emits `RETURNING <qualifier>.*`.
    returning_all: bool,
    _marker: PhantomData<Out>,
}
18
/// Tracks which API has populated an `Insert`.
///
/// `row()` may follow `value()`: the columns set so far become the row
/// template. `value()` after `row()` panics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InsertMode {
    // Nothing has been added yet.
    Unset,
    // Only `value()` has been called; all values form a single row.
    Values,
    // `row()` has been called at least once; `value()` is now rejected.
    Rows,
}
25
26impl<Out> Insert<Out> {
27 pub fn new(table: Table) -> Self {
28 Self {
29 table,
30 columns: Vec::new(),
31 values: Vec::new(),
32 row_count: 0,
33 mode: InsertMode::Unset,
34 returning: None,
35 returning_all: false,
36 _marker: PhantomData,
37 }
38 }
39
40 pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
41 where
42 V: ColumnValue<T>,
43 {
44 if self.mode == InsertMode::Rows {
45 panic!("dbkit: cannot use value() after row()");
46 }
47 self.mode = InsertMode::Values;
48 if self.row_count == 0 {
49 self.row_count = 1;
50 }
51 let value = match value.into_value() {
52 Some(value) => value,
53 None => Value::Null,
54 };
55 self.columns.push(column.as_ref());
56 self.values.push(value);
57 self
58 }
59
60 pub fn row<F>(mut self, build: F) -> Self
61 where
62 F: FnOnce(InsertRow) -> InsertRow,
63 {
64 if self.mode == InsertMode::Values {
65 self.mode = InsertMode::Rows;
66 }
67 if self.mode == InsertMode::Unset {
68 self.mode = InsertMode::Rows;
69 }
70
71 let expected = if self.columns.is_empty() {
72 None
73 } else {
74 Some(self.columns.clone())
75 };
76 let row = build(InsertRow::new(expected));
77 if self.columns.is_empty() {
78 self.columns = row.columns.clone();
79 } else if row.columns != self.columns {
80 panic!("dbkit: insert row columns must match");
81 }
82 if row.values.len() != self.columns.len() {
83 panic!("dbkit: insert row value count mismatch");
84 }
85
86 if self.row_count == 0 && !self.values.is_empty() {
87 self.row_count = 1;
88 }
89 self.values.extend(row.values);
90 self.row_count += 1;
91 self
92 }
93
94 pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
95 self.returning = Some(columns);
96 self.returning_all = false;
97 self
98 }
99
100 pub fn returning_all(mut self) -> Self {
101 self.returning = None;
102 self.returning_all = true;
103 self
104 }
105
106 pub fn compile(&self) -> CompiledSql {
107 let mut builder = SqlBuilder::new();
108 builder.push_sql("INSERT INTO ");
109 builder.push_sql(&self.table.qualified_name());
110 builder.push_sql(" (");
111 for (idx, col) in self.columns.iter().enumerate() {
112 if idx > 0 {
113 builder.push_sql(", ");
114 }
115 builder.push_sql(col.name);
116 }
117 builder.push_sql(") VALUES (");
118 let row_len = self.columns.len();
119 let row_count = if self.row_count == 0 && !self.values.is_empty() {
120 1
121 } else {
122 self.row_count
123 };
124 if row_count == 0 {
125 builder.push_sql(")");
126 } else {
127 for row_idx in 0..row_count {
128 if row_idx > 0 {
129 builder.push_sql(", (");
130 }
131 for col_idx in 0..row_len {
132 if col_idx > 0 {
133 builder.push_sql(", ");
134 }
135 let value = self.values[row_idx * row_len + col_idx].clone();
136 builder.push_value(value);
137 }
138 builder.push_sql(")");
139 }
140 }
141 if self.returning_all {
142 builder.push_sql(" RETURNING ");
143 builder.push_sql(self.table.qualifier());
144 builder.push_sql(".*");
145 } else if let Some(columns) = &self.returning {
146 builder.push_sql(" RETURNING ");
147 for (idx, col) in columns.iter().enumerate() {
148 if idx > 0 {
149 builder.push_sql(", ");
150 }
151 builder.push_column(*col);
152 }
153 }
154 builder.finish()
155 }
156}
157
/// Builder for one row of a multi-row `Insert`, handed to the closure passed
/// to `Insert::row`.
#[derive(Debug, Clone)]
pub struct InsertRow {
    // Columns of this row; pre-filled from `expected` when it is `Some`.
    columns: Vec<ColumnRef>,
    // Values supplied so far, in column order.
    values: Vec<Value>,
    // Column template fixed by earlier rows or `value()` calls, if any.
    expected: Option<Vec<ColumnRef>>,
}
164
165impl InsertRow {
166 fn new(expected: Option<Vec<ColumnRef>>) -> Self {
167 let columns = expected.clone().unwrap_or_default();
168 Self {
169 columns,
170 values: Vec::new(),
171 expected,
172 }
173 }
174
175 pub fn value<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
176 where
177 V: ColumnValue<T>,
178 {
179 let column_ref = column.as_ref();
180 if let Some(expected) = &self.expected {
181 let idx = self.values.len();
182 if idx >= expected.len() {
183 panic!("dbkit: insert row has too many values");
184 }
185 if expected[idx] != column_ref {
186 panic!("dbkit: insert row column mismatch");
187 }
188 } else {
189 self.columns.push(column_ref);
190 }
191
192 let value = match value.into_value() {
193 Some(value) => value,
194 None => Value::Null,
195 };
196 self.values.push(value);
197 self
198 }
199}
200
/// Builder for an `UPDATE` statement targeting a single table.
///
/// `Out` is only a type-level marker (held in `PhantomData`); nothing in this
/// file reads it.
#[derive(Debug, Clone)]
pub struct Update<Out> {
    table: Table,
    // `SET column = value` assignments, in call order.
    sets: Vec<(ColumnRef, Value)>,
    // `WHERE` predicates; `compile` joins them with `AND`.
    filters: Vec<Expr<bool>>,
    // Explicit `RETURNING` columns; `compile` checks `returning_all` first.
    returning: Option<Vec<ColumnRef>>,
    // When true, `compile` emits `RETURNING <qualifier>.*`.
    returning_all: bool,
    _marker: PhantomData<Out>,
}
210
211impl<Out> Update<Out> {
212 pub fn new(table: Table) -> Self {
213 Self {
214 table,
215 sets: Vec::new(),
216 filters: Vec::new(),
217 returning: None,
218 returning_all: false,
219 _marker: PhantomData,
220 }
221 }
222
223 pub fn set<M, T, V>(mut self, column: Column<M, T>, value: V) -> Self
224 where
225 V: ColumnValue<T>,
226 {
227 let value = match value.into_value() {
228 Some(value) => value,
229 None => Value::Null,
230 };
231 self.sets.push((column.as_ref(), value));
232 self
233 }
234
235 pub fn filter(mut self, expr: Expr<bool>) -> Self {
236 self.filters.push(expr);
237 self
238 }
239
240 pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
241 self.returning = Some(columns);
242 self.returning_all = false;
243 self
244 }
245
246 pub fn returning_all(mut self) -> Self {
247 self.returning = None;
248 self.returning_all = true;
249 self
250 }
251
252 pub fn compile(&self) -> CompiledSql {
253 let mut builder = SqlBuilder::new();
254 builder.push_sql("UPDATE ");
255 builder.push_sql(&self.table.qualified_name());
256 builder.push_sql(" SET ");
257 for (idx, (col, value)) in self.sets.iter().enumerate() {
258 if idx > 0 {
259 builder.push_sql(", ");
260 }
261 builder.push_sql(col.name);
262 builder.push_sql(" = ");
263 builder.push_value(value.clone());
264 }
265 if !self.filters.is_empty() {
266 builder.push_sql(" WHERE ");
267 for (idx, expr) in self.filters.iter().enumerate() {
268 if idx > 0 {
269 builder.push_sql(" AND ");
270 }
271 expr.node.to_sql(&mut builder);
272 }
273 }
274 if self.returning_all {
275 builder.push_sql(" RETURNING ");
276 builder.push_sql(self.table.qualifier());
277 builder.push_sql(".*");
278 } else if let Some(columns) = &self.returning {
279 builder.push_sql(" RETURNING ");
280 for (idx, col) in columns.iter().enumerate() {
281 if idx > 0 {
282 builder.push_sql(", ");
283 }
284 builder.push_column(*col);
285 }
286 }
287 builder.finish()
288 }
289}
290
/// Builder for a `DELETE` statement targeting a single table.
#[derive(Debug, Clone)]
pub struct Delete {
    table: Table,
    // `WHERE` predicates; `compile` joins them with `AND` and omits `WHERE` when empty.
    filters: Vec<Expr<bool>>,
    // Explicit `RETURNING` columns; `compile` checks `returning_all` first.
    returning: Option<Vec<ColumnRef>>,
    // When true, `compile` emits `RETURNING <qualifier>.*`.
    returning_all: bool,
}
298
299impl Delete {
300 pub fn new(table: Table) -> Self {
301 Self {
302 table,
303 filters: Vec::new(),
304 returning: None,
305 returning_all: false,
306 }
307 }
308
309 pub fn filter(mut self, expr: Expr<bool>) -> Self {
310 self.filters.push(expr);
311 self
312 }
313
314 pub fn returning(mut self, columns: Vec<ColumnRef>) -> Self {
315 self.returning = Some(columns);
316 self.returning_all = false;
317 self
318 }
319
320 pub fn returning_all(mut self) -> Self {
321 self.returning = None;
322 self.returning_all = true;
323 self
324 }
325
326 pub fn compile(&self) -> CompiledSql {
327 let mut builder = SqlBuilder::new();
328 builder.push_sql("DELETE FROM ");
329 builder.push_sql(&self.table.qualified_name());
330 if !self.filters.is_empty() {
331 builder.push_sql(" WHERE ");
332 for (idx, expr) in self.filters.iter().enumerate() {
333 if idx > 0 {
334 builder.push_sql(" AND ");
335 }
336 expr.node.to_sql(&mut builder);
337 }
338 }
339 if self.returning_all {
340 builder.push_sql(" RETURNING ");
341 builder.push_sql(self.table.qualifier());
342 builder.push_sql(".*");
343 } else if let Some(columns) = &self.returning {
344 builder.push_sql(" RETURNING ");
345 for (idx, col) in columns.iter().enumerate() {
346 if idx > 0 {
347 builder.push_sql(", ");
348 }
349 builder.push_column(*col);
350 }
351 }
352 builder.finish()
353 }
354}