1use crate::builder::WhereExpr;
29use crate::client::{GenericClient, StreamingClient};
30use crate::error::{OrmError, OrmResult};
31use crate::ident::{Ident, IntoIdent};
32use crate::row::FromRow;
33use crate::sql::{FromRowStream, Sql};
34use std::sync::Arc;
35use tokio_postgres::types::ToSql;
36
/// One assignment that can appear in the `SET` list of an `UPDATE` statement.
pub enum SetExpr {
    /// `column = <value>`, where the value is handed to the statement as a
    /// bind parameter rather than being written into the SQL text.
    Value {
        /// Column being assigned.
        column: Ident,
        /// Value to bind. Held behind an `Arc` so rendering only clones the
        /// handle, not the value itself.
        value: Arc<dyn ToSql + Send + Sync>,
    },
    /// `column = column + amount`; a negative `amount` is rendered as a
    /// subtraction of its magnitude.
    Increment { column: Ident, amount: i64 },
    /// A complete assignment fragment pushed into the statement as written,
    /// e.g. `shipped_at = NOW()`.
    Raw(String),
}
65
66impl SetExpr {
67 pub fn set<T: ToSql + Send + Sync + 'static>(
69 column: impl IntoIdent,
70 value: T,
71 ) -> OrmResult<Self> {
72 Ok(SetExpr::Value {
73 column: column.into_ident()?,
74 value: Arc::new(value),
75 })
76 }
77
78 pub fn increment(column: impl IntoIdent, amount: i64) -> OrmResult<Self> {
82 Ok(SetExpr::Increment {
83 column: column.into_ident()?,
84 amount,
85 })
86 }
87
88 pub fn raw(expr: impl Into<String>) -> Self {
94 SetExpr::Raw(expr.into())
95 }
96
97 fn append_to_sql(&self, sql: &mut Sql) {
98 match self {
99 SetExpr::Value { column, value } => {
100 sql.push_ident_ref(column);
101 sql.push(" = ");
102 sql.push_bind_value(value.clone());
103 }
104 SetExpr::Increment { column, amount } => {
105 sql.push_ident_ref(column);
106 sql.push(" = ");
107 sql.push_ident_ref(column);
108 if *amount >= 0 {
109 let s = format!(" + {amount}");
110 sql.push(&s);
111 } else {
112 let s = format!(" - {}", amount.abs());
113 sql.push(&s);
114 }
115 }
116 SetExpr::Raw(expr) => {
117 sql.push(expr);
118 }
119 }
120 }
121}
122
/// Builder for an `UPDATE` statement that may touch many rows.
///
/// `build_sql` refuses to produce a statement unless either a filter has been
/// supplied or `all_rows` has been set.
#[must_use]
pub struct UpdateManyBuilder {
    /// Table being updated.
    pub(crate) table: Ident,
    /// Assignments rendered, comma-separated, after `SET`.
    pub(crate) sets: Vec<SetExpr>,
    /// Accumulated `WHERE` condition; `None` until `filter` is called.
    pub(crate) where_clause: Option<WhereExpr>,
    /// Explicit opt-in that allows a statement with no `WHERE` clause.
    pub(crate) all_rows: bool,
}
146
147impl UpdateManyBuilder {
148 pub fn filter(mut self, condition: impl Into<WhereExpr>) -> Self {
150 let new_where = condition.into();
151 self.where_clause = Some(match self.where_clause.take() {
152 Some(existing) => existing.and_with(new_where),
153 None => new_where,
154 });
155 self
156 }
157
158 pub fn all_rows(mut self) -> Self {
162 self.all_rows = true;
163 self
164 }
165
166 pub fn build_sql(&self) -> OrmResult<Sql> {
170 if self.where_clause.is_none() && !self.all_rows {
171 return Err(OrmError::Validation(
172 "update_many requires a .filter() condition or .all_rows() to proceed. \
173 This prevents accidental full-table updates."
174 .to_string(),
175 ));
176 }
177
178 let mut sql = Sql::new("UPDATE ");
179 sql.push_ident_ref(&self.table);
180 sql.push(" SET ");
181
182 for (i, set) in self.sets.iter().enumerate() {
183 if i > 0 {
184 sql.push(", ");
185 }
186 set.append_to_sql(&mut sql);
187 }
188
189 if let Some(ref where_clause) = self.where_clause {
190 sql.push(" WHERE ");
191 where_clause.append_to_sql(&mut sql);
192 }
193
194 Ok(sql)
195 }
196
197 pub async fn execute(self, conn: &impl GenericClient) -> OrmResult<u64> {
199 let sql = self.build_sql()?;
200 sql.execute(conn).await
201 }
202
203 pub async fn returning<T: FromRow>(self, conn: &impl GenericClient) -> OrmResult<Vec<T>> {
207 let mut sql = self.build_sql()?;
208 sql.push(" RETURNING *");
209 sql.fetch_all_as(conn).await
210 }
211
212 pub async fn returning_stream<T: FromRow>(
216 self,
217 conn: &impl StreamingClient,
218 ) -> OrmResult<FromRowStream<T>> {
219 let mut sql = self.build_sql()?;
220 sql.push(" RETURNING *");
221 sql.stream_as(conn).await
222 }
223}
224
/// Builder for a `DELETE` statement that may remove many rows.
///
/// `build_sql` refuses to produce a statement unless either a filter has been
/// supplied or `all_rows` has been set.
#[must_use]
pub struct DeleteManyBuilder {
    /// Table rows are deleted from.
    pub(crate) table: Ident,
    /// Accumulated `WHERE` condition; `None` until `filter` is called.
    pub(crate) where_clause: Option<WhereExpr>,
    /// Explicit opt-in that allows a statement with no `WHERE` clause.
    pub(crate) all_rows: bool,
}
245
246impl DeleteManyBuilder {
247 pub fn filter(mut self, condition: impl Into<WhereExpr>) -> Self {
249 let new_where = condition.into();
250 self.where_clause = Some(match self.where_clause.take() {
251 Some(existing) => existing.and_with(new_where),
252 None => new_where,
253 });
254 self
255 }
256
257 pub fn all_rows(mut self) -> Self {
261 self.all_rows = true;
262 self
263 }
264
265 pub fn build_sql(&self) -> OrmResult<Sql> {
269 if self.where_clause.is_none() && !self.all_rows {
270 return Err(OrmError::Validation(
271 "delete_many requires a .filter() condition or .all_rows() to proceed. \
272 This prevents accidental full-table deletes."
273 .to_string(),
274 ));
275 }
276
277 let mut sql = Sql::new("DELETE FROM ");
278 sql.push_ident_ref(&self.table);
279
280 if let Some(ref where_clause) = self.where_clause {
281 sql.push(" WHERE ");
282 where_clause.append_to_sql(&mut sql);
283 }
284
285 Ok(sql)
286 }
287
288 pub async fn execute(self, conn: &impl GenericClient) -> OrmResult<u64> {
290 let sql = self.build_sql()?;
291 sql.execute(conn).await
292 }
293
294 pub async fn returning<T: FromRow>(self, conn: &impl GenericClient) -> OrmResult<Vec<T>> {
298 let mut sql = self.build_sql()?;
299 sql.push(" RETURNING *");
300 sql.fetch_all_as(conn).await
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::condition::Condition;

    /// Parses a table name; every caller passes a literal known to be valid.
    fn table(name: &'static str) -> Ident {
        Ident::parse(name).unwrap()
    }

    /// Assembles an `UpdateManyBuilder` directly from its fields.
    fn update_builder(
        name: &'static str,
        sets: Vec<SetExpr>,
        where_clause: Option<WhereExpr>,
        all_rows: bool,
    ) -> UpdateManyBuilder {
        UpdateManyBuilder {
            table: table(name),
            sets,
            where_clause,
            all_rows,
        }
    }

    /// Assembles a `DeleteManyBuilder` directly from its fields.
    fn delete_builder(
        name: &'static str,
        where_clause: Option<WhereExpr>,
        all_rows: bool,
    ) -> DeleteManyBuilder {
        DeleteManyBuilder {
            table: table(name),
            where_clause,
            all_rows,
        }
    }

    // A bound SET value and a bound WHERE value get consecutive placeholders.
    #[test]
    fn update_many_basic_sql() {
        let assignments = vec![SetExpr::set("status", "inactive").unwrap()];
        let condition = WhereExpr::Atom(Condition::eq("active", true).unwrap());
        let rendered = update_builder("users", assignments, Some(condition), false)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "UPDATE users SET status = $1 WHERE active = $2"
        );
        assert_eq!(rendered.params_ref().len(), 2);
    }

    // Multiple assignments are comma-separated; raw fragments bind nothing.
    #[test]
    fn update_many_multiple_sets() {
        let assignments = vec![
            SetExpr::set("status", "shipped").unwrap(),
            SetExpr::raw("shipped_at = NOW()"),
        ];
        let condition = WhereExpr::Atom(Condition::eq("id", 1_i64).unwrap());
        let rendered = update_builder("orders", assignments, Some(condition), false)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "UPDATE orders SET status = $1, shipped_at = NOW() WHERE id = $2"
        );
        assert_eq!(rendered.params_ref().len(), 2);
    }

    // The increment amount is written inline, so only the filter is bound.
    #[test]
    fn update_many_increment() {
        let assignments = vec![SetExpr::increment("view_count", 1).unwrap()];
        let condition = WhereExpr::Atom(Condition::eq("id", 42_i64).unwrap());
        let rendered = update_builder("products", assignments, Some(condition), false)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "UPDATE products SET view_count = view_count + 1 WHERE id = $1"
        );
        assert_eq!(rendered.params_ref().len(), 1);
    }

    // A negative amount renders as a subtraction.
    #[test]
    fn update_many_decrement() {
        let assignments = vec![SetExpr::increment("stock", -5).unwrap()];
        let condition = WhereExpr::Atom(Condition::eq("id", 1_i64).unwrap());
        let rendered = update_builder("products", assignments, Some(condition), false)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "UPDATE products SET stock = stock - 5 WHERE id = $1"
        );
    }

    // `all_rows` permits a statement without WHERE.
    #[test]
    fn update_many_all_rows() {
        let assignments = vec![SetExpr::set("status", "archived").unwrap()];
        let rendered = update_builder("temp_data", assignments, None, true)
            .build_sql()
            .unwrap();
        assert_eq!(rendered.to_sql(), "UPDATE temp_data SET status = $1");
    }

    // Without a filter or `all_rows`, building must fail.
    #[test]
    fn update_many_rejects_no_where() {
        let assignments = vec![SetExpr::set("status", "x").unwrap()];
        let outcome = update_builder("users", assignments, None, false).build_sql();
        assert!(outcome.is_err());
    }

    #[test]
    fn delete_many_basic_sql() {
        let condition = WhereExpr::raw("expires_at < NOW()");
        let rendered = delete_builder("sessions", Some(condition), false)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "DELETE FROM sessions WHERE expires_at < NOW()"
        );
    }

    #[test]
    fn delete_many_with_condition() {
        let condition = WhereExpr::And(vec![
            WhereExpr::Atom(Condition::eq("level", "debug").unwrap()),
            WhereExpr::Atom(Condition::eq("archived", true).unwrap()),
        ]);
        let rendered = delete_builder("audit_logs", Some(condition), false)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "DELETE FROM audit_logs WHERE (level = $1 AND archived = $2)"
        );
        assert_eq!(rendered.params_ref().len(), 2);
    }

    #[test]
    fn delete_many_all_rows() {
        let rendered = delete_builder("temp_data", None, true)
            .build_sql()
            .unwrap();
        assert_eq!(rendered.to_sql(), "DELETE FROM temp_data");
    }

    #[test]
    fn delete_many_rejects_no_where() {
        let outcome = delete_builder("users", None, false).build_sql();
        assert!(outcome.is_err());
    }

    // Same statement as `update_many_basic_sql`, built through `crate::sql`.
    #[test]
    fn update_many_via_sql_builder() {
        let assignments = [SetExpr::set("status", "inactive").unwrap()];
        let condition = Condition::eq("active", true).unwrap();
        let rendered = crate::sql("users")
            .update_many(assignments)
            .unwrap()
            .filter(condition)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "UPDATE users SET status = $1 WHERE active = $2"
        );
    }

    // Same statement as `delete_many_basic_sql`, built through `crate::sql`.
    #[test]
    fn delete_many_via_sql_builder() {
        let condition = WhereExpr::raw("expires_at < NOW()");
        let rendered = crate::sql("sessions")
            .delete_many()
            .unwrap()
            .filter(condition)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "DELETE FROM sessions WHERE expires_at < NOW()"
        );
    }

    // Two `filter` calls are joined with AND inside parentheses.
    #[test]
    fn update_many_filter_combines_with_and() {
        let assignments = [SetExpr::set("status", "archived").unwrap()];
        let first = Condition::eq("status", "cancelled").unwrap();
        let second = Condition::eq("archived", false).unwrap();
        let rendered = crate::sql("orders")
            .update_many(assignments)
            .unwrap()
            .filter(first)
            .filter(second)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "UPDATE orders SET status = $1 WHERE (status = $2 AND archived = $3)"
        );
    }

    #[test]
    fn delete_many_filter_combines_with_and() {
        let first = Condition::eq("level", "debug").unwrap();
        let second = Condition::eq("archived", true).unwrap();
        let rendered = crate::sql("logs")
            .delete_many()
            .unwrap()
            .filter(first)
            .filter(second)
            .build_sql()
            .unwrap();
        assert_eq!(
            rendered.to_sql(),
            "DELETE FROM logs WHERE (level = $1 AND archived = $2)"
        );
    }

    // Column names go through identifier validation in both constructors.
    #[test]
    fn set_expr_validates_column_name() {
        assert!(SetExpr::set("valid_column", "value").is_ok());
        for rejected in ["1invalid", "has space"] {
            assert!(SetExpr::set(rejected, "value").is_err());
        }
        assert!(SetExpr::increment("valid_col", 1).is_ok());
        assert!(SetExpr::increment("bad;col", 1).is_err());
    }
}