1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10pub struct UpdateOperation<E: QueryEngine, M: Model> {
24 engine: E,
25 filter: Filter,
26 updates: Vec<(String, FilterValue)>,
27 select: Select,
28 _model: PhantomData<M>,
29}
30
31impl<E: QueryEngine, M: Model> UpdateOperation<E, M> {
32 pub fn new(engine: E) -> Self {
34 Self {
35 engine,
36 filter: Filter::None,
37 updates: Vec::new(),
38 select: Select::All,
39 _model: PhantomData,
40 }
41 }
42
43 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
45 let new_filter = filter.into();
46 self.filter = self.filter.and_then(new_filter);
47 self
48 }
49
50 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
52 self.updates.push((column.into(), value.into()));
53 self
54 }
55
56 pub fn set_many(
58 mut self,
59 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
60 ) -> Self {
61 for (col, val) in values {
62 self.updates.push((col.into(), val.into()));
63 }
64 self
65 }
66
67 pub fn increment(self, column: impl Into<String>, amount: i64) -> Self {
69 self.set(column, FilterValue::Int(amount))
72 }
73
74 pub fn select(mut self, select: impl Into<Select>) -> Self {
76 self.select = select.into();
77 self
78 }
79
80 pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
82 let mut sql = String::new();
83 let mut params = Vec::new();
84 let mut param_idx = 1;
85
86 sql.push_str("UPDATE ");
88 sql.push_str(M::TABLE_NAME);
89
90 sql.push_str(" SET ");
92 let set_parts: Vec<_> = self
93 .updates
94 .iter()
95 .map(|(col, val)| {
96 params.push(val.clone());
97 let part = format!("{} = ${}", col, param_idx);
98 param_idx += 1;
99 part
100 })
101 .collect();
102 sql.push_str(&set_parts.join(", "));
103
104 if !self.filter.is_none() {
106 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
107 sql.push_str(" WHERE ");
108 sql.push_str(&where_sql);
109 params.extend(where_params);
110 }
111
112 sql.push_str(" RETURNING ");
114 sql.push_str(&self.select.to_sql());
115
116 (sql, params)
117 }
118
119 pub async fn exec(self) -> QueryResult<Vec<M>>
121 where
122 M: Send + 'static,
123 {
124 let (sql, params) = self.build_sql();
125 self.engine.execute_update::<M>(&sql, params).await
126 }
127
128 pub async fn exec_one(self) -> QueryResult<M>
130 where
131 M: Send + 'static,
132 {
133 let (sql, params) = self.build_sql();
134 self.engine.query_one::<M>(&sql, params).await
135 }
136}
137
138pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
140 engine: E,
141 filter: Filter,
142 updates: Vec<(String, FilterValue)>,
143 _model: PhantomData<M>,
144}
145
146impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
147 pub fn new(engine: E) -> Self {
149 Self {
150 engine,
151 filter: Filter::None,
152 updates: Vec::new(),
153 _model: PhantomData,
154 }
155 }
156
157 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
159 let new_filter = filter.into();
160 self.filter = self.filter.and_then(new_filter);
161 self
162 }
163
164 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
166 self.updates.push((column.into(), value.into()));
167 self
168 }
169
170 pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
172 let mut sql = String::new();
173 let mut params = Vec::new();
174 let mut param_idx = 1;
175
176 sql.push_str("UPDATE ");
178 sql.push_str(M::TABLE_NAME);
179
180 sql.push_str(" SET ");
182 let set_parts: Vec<_> = self
183 .updates
184 .iter()
185 .map(|(col, val)| {
186 params.push(val.clone());
187 let part = format!("{} = ${}", col, param_idx);
188 param_idx += 1;
189 part
190 })
191 .collect();
192 sql.push_str(&set_parts.join(", "));
193
194 if !self.filter.is_none() {
196 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
197 sql.push_str(" WHERE ");
198 sql.push_str(&where_sql);
199 params.extend(where_params);
200 }
201
202 (sql, params)
203 }
204
205 pub async fn exec(self) -> QueryResult<u64> {
207 let (sql, params) = self.build_sql();
208 self.engine.execute_raw(&sql, params).await
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use crate::traits::BoxFuture;
    use crate::types::Select;

    /// Model stand-in; only its associated constants are used.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub that ignores its inputs and returns canned results.
    #[derive(Clone)]
    struct MockEngine {
        /// Value reported by `execute_raw`.
        return_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self::with_count(0)
        }

        fn with_count(count: u64) -> Self {
            Self {
                return_count: count,
            }
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            let affected = self.return_count;
            Box::pin(async move { Ok(affected) })
        }

        fn count(&self, _sql: &str, _params: Vec<FilterValue>) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    type Update = UpdateOperation<MockEngine, TestModel>;
    type UpdateMany = UpdateManyOperation<MockEngine, TestModel>;

    /// Fresh single-row update builder backed by a default mock engine.
    fn new_update() -> Update {
        Update::new(MockEngine::new())
    }

    /// Fresh bulk update builder backed by a default mock engine.
    fn new_update_many() -> UpdateMany {
        UpdateMany::new(MockEngine::new())
    }

    // ---- UpdateOperation ----

    #[test]
    fn test_update_new() {
        let (sql, params) = new_update().build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_basic() {
        let builder = new_update()
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let builder = new_update()
            .set("name", "Updated")
            .set("email", "updated@example.com");

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_set_many() {
        let assignments = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];

        let (sql, params) = new_update().set_many(assignments).build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_increment() {
        let (sql, params) = new_update().increment("counter", 5).build_sql();

        assert!(sql.contains("counter = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let builder = new_update()
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = builder.build_sql();

        assert!(sql.contains("RETURNING id, name"));
    }

    #[test]
    fn test_update_with_complex_filter() {
        let is_active = Filter::Equals(
            "status".into(),
            FilterValue::String("active".to_string()),
        );
        let is_adult = Filter::Gt("age".into(), FilterValue::Int(18));

        let builder = new_update()
            .r#where(is_active)
            .r#where(is_adult)
            .set("verified", FilterValue::Bool(true));

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        // One assignment value plus two filter values.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_without_filter() {
        let (sql, _) = new_update().set("status", "updated").build_sql();

        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    #[test]
    fn test_update_with_null_value() {
        let (sql, params) = new_update()
            .set("deleted_at", FilterValue::Null)
            .build_sql();

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let builder = new_update()
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));

        let (_sql, params) = builder.build_sql();

        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    #[tokio::test]
    async fn test_update_exec() {
        let outcome = new_update().set("name", "Updated").exec().await;

        assert!(outcome.is_ok());
        assert!(outcome.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_exec_one() {
        let builder = new_update()
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        // The mock's `query_one` always returns an error.
        let outcome = builder.exec_one().await;
        assert!(outcome.is_err());
    }

    // ---- UpdateManyOperation ----

    #[test]
    fn test_update_many_new() {
        let (sql, params) = new_update_many().build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let ids = vec![
            FilterValue::Int(1),
            FilterValue::Int(2),
            FilterValue::Int(3),
        ];
        let builder = new_update_many()
            .r#where(Filter::In("id".into(), ids))
            .set("status", "processed");

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        // One assignment value plus three `IN` values.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let in_engineering = Filter::Equals(
            "department".into(),
            FilterValue::String("engineering".to_string()),
        );
        let is_active = Filter::Equals("active".into(), FilterValue::Bool(true));

        let builder = new_update_many()
            .r#where(in_engineering)
            .r#where(is_active)
            .set("reviewed", FilterValue::Bool(true));

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let (sql, _) = new_update_many()
            .set("reset_password", FilterValue::Bool(true))
            .build_sql();

        assert!(!sql.contains("WHERE"));
    }

    #[tokio::test]
    async fn test_update_many_exec() {
        let builder = UpdateMany::new(MockEngine::with_count(5)).set("status", "updated");

        let outcome = builder.exec().await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 5);
    }

    // ---- Placeholder numbering ----

    #[test]
    fn test_update_param_ordering() {
        let builder = new_update()
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains("id = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let builder = new_update_many()
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("id = $2"));
        assert_eq!(params.len(), 2);
    }

    // ---- Other value kinds ----

    #[test]
    fn test_update_with_float_value() {
        let (sql, params) = new_update()
            .set("price", FilterValue::Float(99.99))
            .build_sql();

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_update_with_json_value() {
        let payload = serde_json::json!({"key": "value"});
        let builder = new_update().set("metadata", FilterValue::Json(payload.clone()));

        let (sql, params) = builder.build_sql();

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(payload));
    }
}