Skip to main content

reinhardt_query/query/
insert.rs

//! INSERT statement builder
//!
//! This module provides the [`InsertStatement`] type for building SQL INSERT queries.
4
5use crate::{
6	types::{DynIden, IntoIden, IntoTableRef, TableRef},
7	value::{IntoValue, Value, Values},
8};
9
10use super::{
11	returning::ReturningClause,
12	select::SelectStatement,
13	traits::{QueryBuilderTrait, QueryStatementBuilder, QueryStatementWriter},
14};
15
16/// Source of data for INSERT statement
17///
18/// This enum represents the data source for an INSERT statement.
19/// It can be either explicit values (VALUES clause) or a subquery (SELECT statement).
20#[derive(Debug, Clone)]
21pub enum InsertSource {
22	/// Explicit values for INSERT (VALUES clause)
23	Values(Vec<Vec<Value>>),
24	/// Subquery for INSERT FROM SELECT
25	Subquery(Box<SelectStatement>),
26}
27
28impl Default for InsertSource {
29	fn default() -> Self {
30		Self::Values(Vec::new())
31	}
32}
33
34/// INSERT statement builder
35///
36/// This struct provides a fluent API for constructing INSERT queries.
37///
38/// # Examples
39///
40/// ```rust,ignore
41/// use reinhardt_query::prelude::*;
42///
43/// let query = Query::insert()
44///     .into_table("users")
45///     .columns(["name", "email"])
46///     .values_panic(["Alice", "alice@example.com"])
47///     .values_panic(["Bob", "bob@example.com"]);
48/// ```
49#[derive(Debug, Clone)]
50pub struct InsertStatement {
51	pub(crate) table: Option<TableRef>,
52	pub(crate) columns: Vec<DynIden>,
53	pub(crate) source: InsertSource,
54	pub(crate) returning: Option<ReturningClause>,
55	pub(crate) on_conflict: Option<super::on_conflict::OnConflict>,
56}
57
58impl InsertStatement {
59	/// Create a new INSERT statement
60	pub fn new() -> Self {
61		Self {
62			table: None,
63			columns: Vec::new(),
64			source: InsertSource::Values(Vec::new()),
65			returning: None,
66			on_conflict: None,
67		}
68	}
69
70	/// Take the ownership of data in the current [`InsertStatement`]
71	pub fn take(&mut self) -> Self {
72		Self {
73			table: self.table.take(),
74			columns: std::mem::take(&mut self.columns),
75			source: std::mem::replace(&mut self.source, InsertSource::Values(Vec::new())),
76			returning: self.returning.take(),
77			on_conflict: self.on_conflict.take(),
78		}
79	}
80
81	/// Set the table to insert into
82	///
83	/// # Examples
84	///
85	/// ```rust,ignore
86	/// use reinhardt_query::prelude::*;
87	///
88	/// let query = Query::insert()
89	///     .into_table("users");
90	/// ```
91	pub fn into_table<T>(&mut self, tbl: T) -> &mut Self
92	where
93		T: IntoTableRef,
94	{
95		self.table = Some(tbl.into_table_ref());
96		self
97	}
98
99	/// Add a column to insert into
100	///
101	/// # Examples
102	///
103	/// ```rust,ignore
104	/// use reinhardt_query::prelude::*;
105	///
106	/// let query = Query::insert()
107	///     .into_table("users")
108	///     .column("name")
109	///     .column("email");
110	/// ```
111	pub fn column<C>(&mut self, col: C) -> &mut Self
112	where
113		C: IntoIden,
114	{
115		self.columns.push(col.into_iden());
116		self
117	}
118
119	/// Add multiple columns to insert into
120	///
121	/// # Examples
122	///
123	/// ```rust,ignore
124	/// use reinhardt_query::prelude::*;
125	///
126	/// let query = Query::insert()
127	///     .into_table("users")
128	///     .columns(["name", "email", "created_at"]);
129	/// ```
130	pub fn columns<I, C>(&mut self, cols: I) -> &mut Self
131	where
132		I: IntoIterator<Item = C>,
133		C: IntoIden,
134	{
135		for col in cols {
136			self.column(col);
137		}
138		self
139	}
140
141	/// Add values for the columns
142	///
143	/// Returns `Err` if the number of values doesn't match the number of columns.
144	///
145	/// # Examples
146	///
147	/// ```rust,ignore
148	/// use reinhardt_query::prelude::*;
149	///
150	/// let result = Query::insert()
151	///     .into_table("users")
152	///     .columns(["name", "email"])
153	///     .values(vec!["Alice".into(), "alice@example.com".into()]);
154	/// ```
155	pub fn values(&mut self, values: Vec<Value>) -> Result<&mut Self, String> {
156		if !self.columns.is_empty() && values.len() != self.columns.len() {
157			return Err(format!(
158				"Number of values ({}) doesn't match number of columns ({})",
159				values.len(),
160				self.columns.len()
161			));
162		}
163		match &mut self.source {
164			InsertSource::Values(vals) => vals.push(values),
165			InsertSource::Subquery(_) => {
166				self.source = InsertSource::Values(vec![values]);
167			}
168		}
169		Ok(self)
170	}
171
172	/// Add values for the columns (panics on mismatch)
173	///
174	/// # Panics
175	///
176	/// Panics if the number of values doesn't match the number of columns.
177	///
178	/// # Examples
179	///
180	/// ```rust,ignore
181	/// use reinhardt_query::prelude::*;
182	///
183	/// let query = Query::insert()
184	///     .into_table("users")
185	///     .columns(["name", "email"])
186	///     .values_panic(["Alice", "alice@example.com"])
187	///     .values_panic(["Bob", "bob@example.com"]);
188	/// ```
189	pub fn values_panic<I, V>(&mut self, values: I) -> &mut Self
190	where
191		I: IntoIterator<Item = V>,
192		V: IntoValue,
193	{
194		let values: Vec<Value> = values.into_iter().map(|v| v.into_value()).collect();
195		if !self.columns.is_empty() && values.len() != self.columns.len() {
196			panic!(
197				"Number of values ({}) doesn't match number of columns ({})",
198				values.len(),
199				self.columns.len()
200			);
201		}
202		match &mut self.source {
203			InsertSource::Values(vals) => vals.push(values),
204			InsertSource::Subquery(_) => {
205				self.source = InsertSource::Values(vec![values]);
206			}
207		}
208		self
209	}
210
211	/// Add a RETURNING clause with multiple columns
212	///
213	/// # Examples
214	///
215	/// ```rust,ignore
216	/// use reinhardt_query::prelude::*;
217	///
218	/// let query = Query::insert()
219	///     .into_table("users")
220	///     .columns(["name", "email"])
221	///     .values_panic(["Alice", "alice@example.com"])
222	///     .returning(["id", "created_at"]);
223	/// ```
224	pub fn returning<I, C>(&mut self, cols: I) -> &mut Self
225	where
226		I: IntoIterator<Item = C>,
227		C: crate::types::IntoColumnRef,
228	{
229		self.returning = Some(ReturningClause::columns(cols));
230		self
231	}
232
233	/// Add a RETURNING clause for a single column
234	///
235	/// # Examples
236	///
237	/// ```rust,ignore
238	/// use reinhardt_query::prelude::*;
239	///
240	/// let query = Query::insert()
241	///     .into_table("users")
242	///     .columns(["name"])
243	///     .values_panic(["Alice"])
244	///     .returning_col(Alias::new("id"));
245	/// ```
246	pub fn returning_col<C>(&mut self, col: C) -> &mut Self
247	where
248		C: crate::types::IntoColumnRef,
249	{
250		self.returning = Some(ReturningClause::columns([col]));
251		self
252	}
253
254	/// Set ON CONFLICT clause for upsert behavior.
255	///
256	/// # Examples
257	///
258	/// ```rust,ignore
259	/// use reinhardt_query::prelude::*;
260	/// use reinhardt_query::query::OnConflict;
261	///
262	/// let query = Query::insert()
263	///     .into_table("users")
264	///     .columns(["id", "name"])
265	///     .values_panic([1, "Alice"])
266	///     .on_conflict(OnConflict::column("id").update_columns(["name"]));
267	/// ```
268	pub fn on_conflict(&mut self, on_conflict: super::on_conflict::OnConflict) -> &mut Self {
269		self.on_conflict = Some(on_conflict);
270		self
271	}
272
273	/// Add a RETURNING * clause
274	///
275	/// # Examples
276	///
277	/// ```rust,ignore
278	/// use reinhardt_query::prelude::*;
279	///
280	/// let query = Query::insert()
281	///     .into_table("users")
282	///     .columns(["name", "email"])
283	///     .values_panic(["Alice", "alice@example.com"])
284	///     .returning_all();
285	/// ```
286	pub fn returning_all(&mut self) -> &mut Self {
287		self.returning = Some(ReturningClause::all());
288		self
289	}
290
291	/// Use a subquery as the data source for INSERT
292	///
293	/// # Examples
294	///
295	/// ```rust,ignore
296	/// use reinhardt_query::prelude::*;
297	///
298	/// let select = Query::select()
299	///     .column("name")
300	///     .column("email")
301	///     .from("temp_users");
302	///
303	/// let query = Query::insert()
304	///     .into_table("users")
305	///     .columns(["name", "email"])
306	///     .from_subquery(select);
307	/// ```
308	pub fn from_subquery(&mut self, select: SelectStatement) -> &mut Self {
309		self.source = InsertSource::Subquery(Box::new(select));
310		self
311	}
312
313	/// Get the values if this is a VALUES source
314	///
315	/// Returns `None` if the source is a subquery.
316	pub fn get_values(&self) -> Option<&Vec<Vec<Value>>> {
317		match &self.source {
318			InsertSource::Values(vals) => Some(vals),
319			InsertSource::Subquery(_) => None,
320		}
321	}
322}
323
324impl Default for InsertStatement {
325	fn default() -> Self {
326		Self::new()
327	}
328}
329
330impl QueryStatementBuilder for InsertStatement {
331	fn build_any(&self, query_builder: &dyn QueryBuilderTrait) -> (String, Values) {
332		use crate::backend::{
333			MySqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder,
334		};
335		use std::any::Any;
336
337		let any_builder = query_builder as &dyn Any;
338
339		if let Some(pg) = any_builder.downcast_ref::<PostgresQueryBuilder>() {
340			return pg.build_insert(self);
341		}
342
343		if let Some(mysql) = any_builder.downcast_ref::<MySqlQueryBuilder>() {
344			return mysql.build_insert(self);
345		}
346
347		if let Some(sqlite) = any_builder.downcast_ref::<SqliteQueryBuilder>() {
348			return sqlite.build_insert(self);
349		}
350
351		panic!(
352			"Unsupported query builder type. Use PostgresQueryBuilder, MySqlQueryBuilder, or SqliteQueryBuilder."
353		);
354	}
355}
356
357impl QueryStatementWriter for InsertStatement {}
358
#[cfg(test)]
mod tests {
	use super::*;
	use crate::Query;

	#[test]
	fn test_insert_basic() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name", "email"])
			.values_panic(["Alice", "alice@example.com"]);

		assert!(query.table.is_some());
		assert_eq!(query.columns.len(), 2);
		let values = query.get_values().expect("should have values");
		assert_eq!(values.len(), 1);
		assert_eq!(values[0].len(), 2);
	}

	#[test]
	fn test_insert_multiple_rows() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name", "email"])
			.values_panic(["Alice", "alice@example.com"])
			.values_panic(["Bob", "bob@example.com"]);

		let values = query.get_values().expect("should have values");
		assert_eq!(values.len(), 2);
	}

	#[test]
	#[should_panic(expected = "Number of values")]
	fn test_insert_values_mismatch() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name", "email"])
			.values_panic(["Alice"]); // Should panic: 1 value, 2 columns
	}

	#[test]
	fn test_insert_values_ok_for_matching_arity() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name", "email"])
			.values(vec!["Alice".into_value(), "alice@example.com".into_value()])
			.expect("arity matches the column list");

		let values = query.get_values().expect("should have values");
		assert_eq!(values.len(), 1);
		assert_eq!(values[0].len(), 2);
	}

	#[test]
	fn test_insert_values_mismatch_returns_err() {
		let mut query = InsertStatement::new();
		query.into_table("users").columns(["name", "email"]);

		let err = query
			.values(vec!["Alice".into_value()])
			.expect_err("1 value for 2 columns must be rejected");
		assert!(
			err.contains("Number of values (1)"),
			"unexpected error message: {err}"
		);
		assert_eq!(
			query.get_values().map(Vec::len),
			Some(0),
			"a rejected row must not be stored"
		);
	}

	#[test]
	fn test_insert_values_after_subquery_switches_to_values() {
		let select = Query::select()
			.column("name")
			.from("temp_users")
			.to_owned();
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name"])
			.from_subquery(select)
			.values_panic(["Alice"]);

		let values = query
			.get_values()
			.expect("adding a row should switch the source back to VALUES");
		assert_eq!(values.len(), 1);
	}

	#[test]
	fn test_insert_returning() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name"])
			.values_panic(["Alice"])
			.returning(["id", "created_at"]);

		assert!(query.returning.is_some());
		let returning = query.returning.unwrap();
		assert!(!returning.is_all());
	}

	#[test]
	fn test_insert_returning_col() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name"])
			.values_panic(["Alice"])
			.returning_col("id");

		let returning = query.returning.expect("returning clause should be set");
		assert!(!returning.is_all());
	}

	#[test]
	fn test_insert_returning_all() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name"])
			.values_panic(["Alice"])
			.returning_all();

		assert!(query.returning.is_some());
		let returning = query.returning.unwrap();
		assert!(returning.is_all());
	}

	#[test]
	fn test_insert_take() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name"])
			.values_panic(["Alice"]);

		let taken = query.take();
		assert!(taken.table.is_some());
		assert!(query.table.is_none());
	}

	#[test]
	fn test_insert_take_resets_all_fields() {
		let mut query = InsertStatement::new();
		query
			.into_table("users")
			.columns(["name"])
			.values_panic(["Alice"])
			.returning_all();

		let taken = query.take();
		assert_eq!(taken.columns.len(), 1);
		assert_eq!(taken.get_values().map(Vec::len), Some(1));
		assert!(taken.returning.is_some());

		assert!(query.columns.is_empty());
		assert_eq!(query.get_values().map(Vec::len), Some(0));
		assert!(query.returning.is_none());
		assert!(query.on_conflict.is_none());
	}

	#[test]
	fn test_insert_from_subquery() {
		let mut query = InsertStatement::new();
		let select = Query::select()
			.column("name")
			.column("email")
			.from("temp_users")
			.to_owned();

		query
			.into_table("users")
			.columns(["name", "email"])
			.from_subquery(select);

		assert!(query.table.is_some());
		assert_eq!(query.columns.len(), 2);
		assert!(
			query.get_values().is_none(),
			"should not have values when using subquery"
		);
	}
}