Skip to main content

reinhardt_query/query/
update.rs

1//! UPDATE statement builder
2//!
3//! This module provides the `UpdateStatement` type for building SQL UPDATE queries.
4
5use crate::{
6	expr::{Condition, ConditionHolder, IntoCondition, SimpleExpr},
7	types::{DynIden, IntoIden, IntoTableRef, TableRef},
8	value::{IntoValue, Values},
9};
10
11use super::{
12	returning::ReturningClause,
13	traits::{QueryBuilderTrait, QueryStatementBuilder, QueryStatementWriter},
14};
15
16/// UPDATE statement builder
17///
18/// This struct provides a fluent API for constructing UPDATE queries.
19///
20/// # Examples
21///
22/// ```rust,ignore
23/// use reinhardt_query::prelude::*;
24///
25/// let query = Query::update()
26///     .table("users")
27///     .value("active", false)
28///     .and_where(Expr::col("last_login").lt("2020-01-01"));
29/// ```
30#[derive(Debug, Clone)]
31pub struct UpdateStatement {
32	pub(crate) table: Option<TableRef>,
33	pub(crate) values: Vec<(DynIden, SimpleExpr)>,
34	pub(crate) r#where: ConditionHolder,
35	pub(crate) returning: Option<ReturningClause>,
36}
37
38impl UpdateStatement {
39	/// Create a new UPDATE statement
40	pub fn new() -> Self {
41		Self {
42			table: None,
43			values: Vec::new(),
44			r#where: ConditionHolder::new(),
45			returning: None,
46		}
47	}
48
49	/// Take the ownership of data in the current [`UpdateStatement`]
50	pub fn take(&mut self) -> Self {
51		Self {
52			table: self.table.take(),
53			values: std::mem::take(&mut self.values),
54			r#where: std::mem::replace(&mut self.r#where, ConditionHolder::new()),
55			returning: self.returning.take(),
56		}
57	}
58
59	/// Set the table to update
60	///
61	/// # Examples
62	///
63	/// ```rust,ignore
64	/// use reinhardt_query::prelude::*;
65	///
66	/// let query = Query::update()
67	///     .table("users");
68	/// ```
69	pub fn table<T>(&mut self, tbl: T) -> &mut Self
70	where
71		T: IntoTableRef,
72	{
73		self.table = Some(tbl.into_table_ref());
74		self
75	}
76
77	/// Set a column value
78	///
79	/// # Examples
80	///
81	/// ```rust,ignore
82	/// use reinhardt_query::prelude::*;
83	///
84	/// let query = Query::update()
85	///     .table("users")
86	///     .value("active", false)
87	///     .value("name", "Alice");
88	/// ```
89	pub fn value<C, V>(&mut self, col: C, val: V) -> &mut Self
90	where
91		C: IntoIden,
92		V: IntoValue,
93	{
94		self.values
95			.push((col.into_iden(), SimpleExpr::Value(val.into_value())));
96		self
97	}
98
99	/// Set a column to an expression value
100	///
101	/// Use this method when the value is an expression (e.g., `Expr::current_timestamp()`,
102	/// `Expr::col("other_column")`, `Expr::cust("NULL")`) rather than a plain value.
103	///
104	/// # Examples
105	///
106	/// ```rust,ignore
107	/// use reinhardt_query::prelude::*;
108	///
109	/// let query = Query::update()
110	///     .table("users")
111	///     .value_expr("updated_at", Expr::current_timestamp())
112	///     .value_expr("status", Expr::cust("NULL"));
113	/// ```
114	pub fn value_expr<C, E>(&mut self, col: C, expr: E) -> &mut Self
115	where
116		C: IntoIden,
117		E: Into<SimpleExpr>,
118	{
119		self.values.push((col.into_iden(), expr.into()));
120		self
121	}
122
123	/// Set multiple column values
124	///
125	/// # Examples
126	///
127	/// ```rust,ignore
128	/// use reinhardt_query::prelude::*;
129	///
130	/// let query = Query::update()
131	///     .table("users")
132	///     .values([
133	///         ("name", "Alice".into()),
134	///         ("email", "alice@example.com".into()),
135	///     ]);
136	/// ```
137	pub fn values<I, C, V>(&mut self, values: I) -> &mut Self
138	where
139		I: IntoIterator<Item = (C, V)>,
140		C: IntoIden,
141		V: IntoValue,
142	{
143		for (col, val) in values {
144			self.value(col, val);
145		}
146		self
147	}
148
149	/// Add a condition to the WHERE clause
150	///
151	/// # Examples
152	///
153	/// ```rust,ignore
154	/// use reinhardt_query::prelude::*;
155	///
156	/// let query = Query::update()
157	///     .table("users")
158	///     .value("active", false)
159	///     .and_where(Expr::col("last_login").lt("2020-01-01"));
160	/// ```
161	pub fn and_where<C>(&mut self, condition: C) -> &mut Self
162	where
163		C: IntoCondition,
164	{
165		self.r#where.add_and(condition);
166		self
167	}
168
169	/// Add a conditional WHERE clause.
170	///
171	/// This is an alias for [`and_where`](Self::and_where) that accepts a [`Condition`].
172	pub fn cond_where(&mut self, condition: Condition) -> &mut Self {
173		self.r#where.add_and(condition);
174		self
175	}
176
177	/// Add a RETURNING clause
178	///
179	/// # Examples
180	///
181	/// ```rust,ignore
182	/// use reinhardt_query::prelude::*;
183	///
184	/// let query = Query::update()
185	///     .table("users")
186	///     .value("active", false)
187	///     .and_where(Expr::col("id").eq(1))
188	///     .returning(["id", "updated_at"]);
189	/// ```
190	pub fn returning<I, C>(&mut self, cols: I) -> &mut Self
191	where
192		I: IntoIterator<Item = C>,
193		C: crate::types::IntoColumnRef,
194	{
195		self.returning = Some(ReturningClause::columns(cols));
196		self
197	}
198
199	/// Add a RETURNING * clause
200	///
201	/// # Examples
202	///
203	/// ```rust,ignore
204	/// use reinhardt_query::prelude::*;
205	///
206	/// let query = Query::update()
207	///     .table("users")
208	///     .value("active", false)
209	///     .and_where(Expr::col("id").eq(1))
210	///     .returning_all();
211	/// ```
212	pub fn returning_all(&mut self) -> &mut Self {
213		self.returning = Some(ReturningClause::all());
214		self
215	}
216}
217
218impl Default for UpdateStatement {
219	fn default() -> Self {
220		Self::new()
221	}
222}
223
224impl QueryStatementBuilder for UpdateStatement {
225	fn build_any(&self, query_builder: &dyn QueryBuilderTrait) -> (String, Values) {
226		use crate::backend::{
227			MySqlQueryBuilder, PostgresQueryBuilder, QueryBuilder, SqliteQueryBuilder,
228		};
229		use std::any::Any;
230
231		let any_builder = query_builder as &dyn Any;
232
233		if let Some(pg) = any_builder.downcast_ref::<PostgresQueryBuilder>() {
234			return pg.build_update(self);
235		}
236
237		if let Some(mysql) = any_builder.downcast_ref::<MySqlQueryBuilder>() {
238			return mysql.build_update(self);
239		}
240
241		if let Some(sqlite) = any_builder.downcast_ref::<SqliteQueryBuilder>() {
242			return sqlite.build_update(self);
243		}
244
245		panic!(
246			"Unsupported query builder type. Use PostgresQueryBuilder, MySqlQueryBuilder, or SqliteQueryBuilder."
247		);
248	}
249}
250
// No overrides: `UpdateStatement` uses only the provided (default) methods
// of `QueryStatementWriter`.
impl QueryStatementWriter for UpdateStatement {}
252
#[cfg(test)]
mod tests {
	use super::*;
	use crate::expr::{Expr, ExprTrait};

	#[test]
	fn test_update_basic() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.value("name", "Alice")
			.value("email", "alice@example.com");

		assert!(query.table.is_some());
		assert_eq!(query.values.len(), 2);
	}

	#[test]
	fn test_update_default_is_empty() {
		let query = UpdateStatement::default();

		assert!(query.table.is_none());
		assert!(query.values.is_empty());
		assert!(query.r#where.is_empty());
		assert!(query.returning.is_none());
	}

	#[test]
	fn test_update_with_where() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.value("active", false)
			.and_where(Expr::col("id").eq(1));

		assert!(query.table.is_some());
		assert_eq!(query.values.len(), 1);
		assert!(!query.r#where.is_empty());
	}

	#[test]
	fn test_update_multiple_values() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.values([("name", "Alice"), ("email", "alice@example.com")]);

		assert_eq!(query.values.len(), 2);
		// `values` stores plain values, exactly like `value`.
		assert!(query
			.values
			.iter()
			.all(|(_, expr)| matches!(expr, SimpleExpr::Value(_))));
	}

	#[test]
	fn test_update_values_appends_to_existing() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.value("active", true)
			.values([("name", "Alice"), ("email", "alice@example.com")]);

		assert_eq!(query.values.len(), 3);
	}

	#[test]
	fn test_update_value_expr() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.value_expr("is_owner", Expr::col("owner_id").eq(1));

		assert_eq!(query.values.len(), 1);
		// An expression must be stored as-is, not wrapped as a plain value.
		assert!(!matches!(query.values[0].1, SimpleExpr::Value(_)));
	}

	#[test]
	fn test_update_returning() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.value("active", false)
			.returning(["id", "updated_at"]);

		assert!(query.returning.is_some());
		let returning = query.returning.unwrap();
		assert!(!returning.is_all());
	}

	#[test]
	fn test_update_returning_all() {
		let mut query = UpdateStatement::new();
		query.table("users").value("active", false).returning_all();

		assert!(query.returning.is_some());
		let returning = query.returning.unwrap();
		assert!(returning.is_all());
	}

	#[test]
	fn test_update_returning_is_replaced() {
		let mut query = UpdateStatement::new();
		query.table("users").returning(["id"]).returning_all();

		let returning = query.returning.expect("returning_all sets a clause");
		assert!(returning.is_all());
	}

	#[test]
	fn test_update_take() {
		let mut query = UpdateStatement::new();
		query
			.table("users")
			.value("active", false)
			.and_where(Expr::col("id").eq(1))
			.returning_all();

		let taken = query.take();

		// Everything moves into the returned statement...
		assert!(taken.table.is_some());
		assert_eq!(taken.values.len(), 1);
		assert!(!taken.r#where.is_empty());
		assert!(taken.returning.is_some());

		// ...and the source is left empty.
		assert!(query.table.is_none());
		assert!(query.values.is_empty());
		assert!(query.r#where.is_empty());
		assert!(query.returning.is_none());
	}
}