reinhardt_query/query/
insert.rs1use crate::{
6 types::{DynIden, IntoIden, IntoTableRef, TableRef},
7 value::{IntoValue, Value, Values},
8};
9
10use super::{
11 returning::ReturningClause,
12 select::SelectStatement,
13 traits::{QueryBuilderTrait, QueryStatementBuilder, QueryStatementWriter},
14};
15
16#[derive(Debug, Clone)]
21pub enum InsertSource {
22 Values(Vec<Vec<Value>>),
24 Subquery(Box<SelectStatement>),
26}
27
28impl Default for InsertSource {
29 fn default() -> Self {
30 Self::Values(Vec::new())
31 }
32}
33
34#[derive(Debug, Clone)]
50pub struct InsertStatement {
51 pub(crate) table: Option<TableRef>,
52 pub(crate) columns: Vec<DynIden>,
53 pub(crate) source: InsertSource,
54 pub(crate) returning: Option<ReturningClause>,
55 pub(crate) on_conflict: Option<super::on_conflict::OnConflict>,
56}
57
58impl InsertStatement {
59 pub fn new() -> Self {
61 Self {
62 table: None,
63 columns: Vec::new(),
64 source: InsertSource::Values(Vec::new()),
65 returning: None,
66 on_conflict: None,
67 }
68 }
69
70 pub fn take(&mut self) -> Self {
72 Self {
73 table: self.table.take(),
74 columns: std::mem::take(&mut self.columns),
75 source: std::mem::replace(&mut self.source, InsertSource::Values(Vec::new())),
76 returning: self.returning.take(),
77 on_conflict: self.on_conflict.take(),
78 }
79 }
80
81 pub fn into_table<T>(&mut self, tbl: T) -> &mut Self
92 where
93 T: IntoTableRef,
94 {
95 self.table = Some(tbl.into_table_ref());
96 self
97 }
98
99 pub fn column<C>(&mut self, col: C) -> &mut Self
112 where
113 C: IntoIden,
114 {
115 self.columns.push(col.into_iden());
116 self
117 }
118
119 pub fn columns<I, C>(&mut self, cols: I) -> &mut Self
131 where
132 I: IntoIterator<Item = C>,
133 C: IntoIden,
134 {
135 for col in cols {
136 self.column(col);
137 }
138 self
139 }
140
141 pub fn values(&mut self, values: Vec<Value>) -> Result<&mut Self, String> {
156 if !self.columns.is_empty() && values.len() != self.columns.len() {
157 return Err(format!(
158 "Number of values ({}) doesn't match number of columns ({})",
159 values.len(),
160 self.columns.len()
161 ));
162 }
163 match &mut self.source {
164 InsertSource::Values(vals) => vals.push(values),
165 InsertSource::Subquery(_) => {
166 self.source = InsertSource::Values(vec![values]);
167 }
168 }
169 Ok(self)
170 }
171
172 pub fn values_panic<I, V>(&mut self, values: I) -> &mut Self
190 where
191 I: IntoIterator<Item = V>,
192 V: IntoValue,
193 {
194 let values: Vec<Value> = values.into_iter().map(|v| v.into_value()).collect();
195 if !self.columns.is_empty() && values.len() != self.columns.len() {
196 panic!(
197 "Number of values ({}) doesn't match number of columns ({})",
198 values.len(),
199 self.columns.len()
200 );
201 }
202 match &mut self.source {
203 InsertSource::Values(vals) => vals.push(values),
204 InsertSource::Subquery(_) => {
205 self.source = InsertSource::Values(vec![values]);
206 }
207 }
208 self
209 }
210
211 pub fn returning<I, C>(&mut self, cols: I) -> &mut Self
225 where
226 I: IntoIterator<Item = C>,
227 C: crate::types::IntoColumnRef,
228 {
229 self.returning = Some(ReturningClause::columns(cols));
230 self
231 }
232
233 pub fn returning_col<C>(&mut self, col: C) -> &mut Self
247 where
248 C: crate::types::IntoColumnRef,
249 {
250 self.returning = Some(ReturningClause::columns([col]));
251 self
252 }
253
254 pub fn on_conflict(&mut self, on_conflict: super::on_conflict::OnConflict) -> &mut Self {
269 self.on_conflict = Some(on_conflict);
270 self
271 }
272
273 pub fn returning_all(&mut self) -> &mut Self {
287 self.returning = Some(ReturningClause::all());
288 self
289 }
290
291 pub fn from_subquery(&mut self, select: SelectStatement) -> &mut Self {
309 self.source = InsertSource::Subquery(Box::new(select));
310 self
311 }
312
313 pub fn get_values(&self) -> Option<&Vec<Vec<Value>>> {
317 match &self.source {
318 InsertSource::Values(vals) => Some(vals),
319 InsertSource::Subquery(_) => None,
320 }
321 }
322}
323
324impl Default for InsertStatement {
325 fn default() -> Self {
326 Self::new()
327 }
328}
329
330impl QueryStatementBuilder for InsertStatement {
331 fn build_any(&self, query_builder: &dyn QueryBuilderTrait) -> (String, Values) {
332 use crate::backend::{
333 MySqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder,
334 };
335 use std::any::Any;
336
337 let any_builder = query_builder as &dyn Any;
338
339 if let Some(pg) = any_builder.downcast_ref::<PostgresQueryBuilder>() {
340 return pg.build_insert(self);
341 }
342
343 if let Some(mysql) = any_builder.downcast_ref::<MySqlQueryBuilder>() {
344 return mysql.build_insert(self);
345 }
346
347 if let Some(sqlite) = any_builder.downcast_ref::<SqliteQueryBuilder>() {
348 return sqlite.build_insert(self);
349 }
350
351 panic!(
352 "Unsupported query builder type. Use PostgresQueryBuilder, MySqlQueryBuilder, or SqliteQueryBuilder."
353 );
354 }
355}
356
357impl QueryStatementWriter for InsertStatement {}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Query;

    #[test]
    fn test_insert_basic() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name", "email"])
            .values_panic(["Alice", "alice@example.com"]);

        assert!(query.table.is_some());
        assert_eq!(query.columns.len(), 2);
        let values = query.get_values().expect("should have values");
        assert_eq!(values.len(), 1);
        assert_eq!(values[0].len(), 2);
    }

    #[test]
    fn test_insert_multiple_rows() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name", "email"])
            .values_panic(["Alice", "alice@example.com"])
            .values_panic(["Bob", "bob@example.com"]);

        let values = query.get_values().expect("should have values");
        assert_eq!(values.len(), 2);
    }

    #[test]
    #[should_panic(expected = "Number of values")]
    fn test_insert_values_mismatch() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name", "email"])
            .values_panic(["Alice"]);
    }

    #[test]
    fn test_insert_values_result_ok() {
        let mut query = InsertStatement::new();
        query.into_table("users").columns(["name", "email"]);

        let result = query.values(vec!["Alice".into_value(), "alice@example.com".into_value()]);
        assert!(result.is_ok());
        assert_eq!(query.get_values().map(Vec::len), Some(1));
    }

    #[test]
    fn test_insert_values_result_mismatch() {
        let mut query = InsertStatement::new();
        query.into_table("users").columns(["name", "email"]);

        let err = query
            .values(vec!["Alice".into_value()])
            .expect_err("a row narrower than the column list must be rejected");
        assert!(err.contains("Number of values"));
        // A rejected row must not be stored.
        assert_eq!(query.get_values().map(Vec::len), Some(0));
    }

    #[test]
    fn test_insert_returning() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name"])
            .values_panic(["Alice"])
            .returning(["id", "created_at"]);

        let returning = query.returning.expect("returning clause should be set");
        assert!(!returning.is_all());
    }

    #[test]
    fn test_insert_returning_col() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name"])
            .values_panic(["Alice"])
            .returning_col("id");

        let returning = query.returning.expect("returning clause should be set");
        assert!(!returning.is_all());
    }

    #[test]
    fn test_insert_returning_all() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name"])
            .values_panic(["Alice"])
            .returning_all();

        let returning = query.returning.expect("returning clause should be set");
        assert!(returning.is_all());
    }

    #[test]
    fn test_insert_take() {
        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name"])
            .values_panic(["Alice"]);

        let taken = query.take();
        assert!(taken.table.is_some());
        assert!(query.table.is_none());
        assert!(query.columns.is_empty());
        assert_eq!(taken.get_values().map(Vec::len), Some(1));
    }

    #[test]
    fn test_insert_from_subquery() {
        let mut query = InsertStatement::new();
        let select = Query::select()
            .column("name")
            .column("email")
            .from("temp_users")
            .to_owned();

        query
            .into_table("users")
            .columns(["name", "email"])
            .from_subquery(select);

        assert!(query.table.is_some());
        assert_eq!(query.columns.len(), 2);
        assert!(
            query.get_values().is_none(),
            "should not have values when using subquery"
        );
    }

    #[test]
    fn test_insert_values_replace_subquery() {
        let select = Query::select()
            .column("name")
            .from("temp_users")
            .to_owned();

        let mut query = InsertStatement::new();
        query
            .into_table("users")
            .columns(["name"])
            .from_subquery(select)
            .values_panic(["Alice"]);

        assert_eq!(query.get_values().map(Vec::len), Some(1));
    }
}