1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10pub struct UpdateOperation<E: QueryEngine, M: Model> {
24 engine: E,
25 filter: Filter,
26 updates: Vec<(String, FilterValue)>,
27 select: Select,
28 _model: PhantomData<M>,
29}
30
31impl<E: QueryEngine, M: Model + crate::row::FromRow> UpdateOperation<E, M> {
32 pub fn new(engine: E) -> Self {
34 Self {
35 engine,
36 filter: Filter::None,
37 updates: Vec::new(),
38 select: Select::All,
39 _model: PhantomData,
40 }
41 }
42
43 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
45 let new_filter = filter.into();
46 self.filter = self.filter.and_then(new_filter);
47 self
48 }
49
50 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
52 self.updates.push((column.into(), value.into()));
53 self
54 }
55
56 pub fn set_many(
58 mut self,
59 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
60 ) -> Self {
61 for (col, val) in values {
62 self.updates.push((col.into(), val.into()));
63 }
64 self
65 }
66
67 pub fn increment(self, column: impl Into<String>, amount: i64) -> Self {
69 self.set(column, FilterValue::Int(amount))
72 }
73
74 pub fn select(mut self, select: impl Into<Select>) -> Self {
76 self.select = select.into();
77 self
78 }
79
80 pub fn build_sql(
82 &self,
83 dialect: &dyn crate::dialect::SqlDialect,
84 ) -> (String, Vec<FilterValue>) {
85 let mut sql = String::new();
86 let mut params = Vec::new();
87 let mut param_idx = 1;
88
89 sql.push_str("UPDATE ");
91 sql.push_str(M::TABLE_NAME);
92
93 sql.push_str(" SET ");
95 let set_parts: Vec<_> = self
96 .updates
97 .iter()
98 .map(|(col, val)| {
99 params.push(val.clone());
100 let part = format!("{} = {}", col, dialect.placeholder(param_idx));
101 param_idx += 1;
102 part
103 })
104 .collect();
105 sql.push_str(&set_parts.join(", "));
106
107 if !self.filter.is_none() {
109 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
110 sql.push_str(" WHERE ");
111 sql.push_str(&where_sql);
112 params.extend(where_params);
113 }
114
115 sql.push_str(&dialect.returning_clause(&self.select.to_sql()));
117
118 (sql, params)
119 }
120
121 pub async fn exec(self) -> QueryResult<Vec<M>>
123 where
124 M: Send + 'static,
125 {
126 let dialect = self.engine.dialect();
127 let (sql, params) = self.build_sql(dialect);
128 self.engine.execute_update::<M>(&sql, params).await
129 }
130
131 pub async fn exec_one(self) -> QueryResult<M>
133 where
134 M: Send + 'static,
135 {
136 let dialect = self.engine.dialect();
137 let (sql, params) = self.build_sql(dialect);
138 self.engine.query_one::<M>(&sql, params).await
139 }
140}
141
142pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
144 engine: E,
145 filter: Filter,
146 updates: Vec<(String, FilterValue)>,
147 _model: PhantomData<M>,
148}
149
150impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
151 pub fn new(engine: E) -> Self {
153 Self {
154 engine,
155 filter: Filter::None,
156 updates: Vec::new(),
157 _model: PhantomData,
158 }
159 }
160
161 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
163 let new_filter = filter.into();
164 self.filter = self.filter.and_then(new_filter);
165 self
166 }
167
168 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
170 self.updates.push((column.into(), value.into()));
171 self
172 }
173
174 pub fn build_sql(
176 &self,
177 dialect: &dyn crate::dialect::SqlDialect,
178 ) -> (String, Vec<FilterValue>) {
179 let mut sql = String::new();
180 let mut params = Vec::new();
181 let mut param_idx = 1;
182
183 sql.push_str("UPDATE ");
185 sql.push_str(M::TABLE_NAME);
186
187 sql.push_str(" SET ");
189 let set_parts: Vec<_> = self
190 .updates
191 .iter()
192 .map(|(col, val)| {
193 params.push(val.clone());
194 let part = format!("{} = {}", col, dialect.placeholder(param_idx));
195 param_idx += 1;
196 part
197 })
198 .collect();
199 sql.push_str(&set_parts.join(", "));
200
201 if !self.filter.is_none() {
203 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1, dialect);
204 sql.push_str(" WHERE ");
205 sql.push_str(&where_sql);
206 params.extend(where_params);
207 }
208
209 (sql, params)
210 }
211
212 pub async fn exec(self) -> QueryResult<u64> {
214 let dialect = self.engine.dialect();
215 let (sql, params) = self.build_sql(dialect);
216 self.engine.execute_raw(&sql, params).await
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use crate::types::Select;

    /// Minimal model: only its `Model` constants are used by the SQL builders.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    // Ignores the row and always yields a `TestModel`; satisfies the
    // `FromRow` bound required by the executing methods.
    impl crate::row::FromRow for TestModel {
        fn from_row(_row: &impl crate::row::RowRef) -> Result<Self, crate::row::RowError> {
            Ok(TestModel)
        }
    }

    /// Query-engine stub: ignores the SQL and params it receives and returns
    /// canned results. `return_count` is what `execute_raw` reports.
    #[derive(Clone)]
    struct MockEngine {
        return_count: u64,
    }

    impl MockEngine {
        /// Engine whose `execute_raw` reports 0.
        fn new() -> Self {
            Self { return_count: 0 }
        }

        /// Engine whose `execute_raw` reports `count`.
        fn with_count(count: u64) -> Self {
            Self {
                return_count: count,
            }
        }
    }

    impl QueryEngine for MockEngine {
        // All tests render SQL with the Postgres dialect (`$n` placeholders).
        fn dialect(&self) -> &dyn crate::dialect::SqlDialect {
            &crate::dialect::Postgres
        }

        fn query_many<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        // Always fails; `test_update_exec_one` relies on this.
        fn query_one<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        // Always "updates" zero rows; `test_update_exec` relies on this.
        fn execute_update<T: Model + crate::row::FromRow + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        // Reports the configured count; `UpdateManyOperation::exec` goes through here.
        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            // Copy out before the `async move` so the future does not borrow `self`.
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // With no assignments the SET list is empty. This pins the current output,
    // which is not a valid UPDATE on its own.
    #[test]
    fn test_update_new() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_basic() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        // 1 SET value + 1 WHERE value.
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .set("email", "updated@example.com");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_set_many() {
        let updates = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set_many(updates);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        // Placeholders follow iteration order of the input.
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    // Pins current behavior: `increment` emits a plain assignment bound to
    // `amount`, not `counter = counter + $1`.
    #[test]
    fn test_update_increment() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .increment("counter", 5);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("counter = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("RETURNING id, name"));
    }

    #[test]
    fn test_update_with_complex_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
            .set("verified", FilterValue::Bool(true));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        // Two `r#where` calls are combined with AND.
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        // 1 SET value + 2 WHERE values.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_without_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("status", "updated");

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        // `Filter::None` means no WHERE clause, so every row is targeted.
        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    // NULL is bound as a parameter rather than inlined into the SQL.
    #[test]
    fn test_update_with_null_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("deleted_at", FilterValue::Null);

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));

        let (_sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    #[tokio::test]
    async fn test_update_exec() {
        let op =
            UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new()).set("name", "Updated");

        // `MockEngine::execute_update` always returns an empty Vec.
        let result = op.exec().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_exec_one() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let result = op.exec_one().await;
        // `MockEngine::query_one` always returns `QueryError::not_found`.
        assert!(result.is_err());
    }

    #[test]
    fn test_update_many_new() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        // The bulk builder never emits a RETURNING clause.
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "id".into(),
                vec![
                    FilterValue::Int(1),
                    FilterValue::Int(2),
                    FilterValue::Int(3),
                ],
            ))
            .set("status", "processed");

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        // 1 SET value + 3 IN-list values.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "department".into(),
                FilterValue::String("engineering".to_string()),
            ))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .set("reviewed", FilterValue::Bool(true));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("AND"));
        // 1 SET value + 2 WHERE values.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("reset_password", FilterValue::Bool(true));

        let (sql, _) = op.build_sql(&crate::dialect::Postgres);

        assert!(!sql.contains("WHERE"));
    }

    #[tokio::test]
    async fn test_update_many_exec() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
            .set("status", "updated");

        // `exec` forwards whatever count `execute_raw` reports.
        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
    }

    // SET placeholders are numbered first and WHERE placeholders continue the
    // sequence, regardless of the order the builder methods were called in.
    // Note: filter columns come out quoted (`"id"`), SET columns do not.
    #[test]
    fn test_update_param_ordering() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains(r#""id" = $3"#));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains(r#""id" = $2"#));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_float_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("price", FilterValue::Float(99.99));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_update_with_json_value() {
        let json_value = serde_json::json!({"key": "value"});
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json_value.clone()));

        let (sql, params) = op.build_sql(&crate::dialect::Postgres);

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(json_value));
    }
}