1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10pub struct UpdateOperation<E: QueryEngine, M: Model> {
24 engine: E,
25 filter: Filter,
26 updates: Vec<(String, FilterValue)>,
27 select: Select,
28 _model: PhantomData<M>,
29}
30
31impl<E: QueryEngine, M: Model> UpdateOperation<E, M> {
32 pub fn new(engine: E) -> Self {
34 Self {
35 engine,
36 filter: Filter::None,
37 updates: Vec::new(),
38 select: Select::All,
39 _model: PhantomData,
40 }
41 }
42
43 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
45 let new_filter = filter.into();
46 self.filter = self.filter.and_then(new_filter);
47 self
48 }
49
50 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
52 self.updates.push((column.into(), value.into()));
53 self
54 }
55
56 pub fn set_many(
58 mut self,
59 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
60 ) -> Self {
61 for (col, val) in values {
62 self.updates.push((col.into(), val.into()));
63 }
64 self
65 }
66
67 pub fn increment(self, column: impl Into<String>, amount: i64) -> Self {
69 self.set(column, FilterValue::Int(amount))
72 }
73
74 pub fn select(mut self, select: impl Into<Select>) -> Self {
76 self.select = select.into();
77 self
78 }
79
80 pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
82 let mut sql = String::new();
83 let mut params = Vec::new();
84 let mut param_idx = 1;
85
86 sql.push_str("UPDATE ");
88 sql.push_str(M::TABLE_NAME);
89
90 sql.push_str(" SET ");
92 let set_parts: Vec<_> = self
93 .updates
94 .iter()
95 .map(|(col, val)| {
96 params.push(val.clone());
97 let part = format!("{} = ${}", col, param_idx);
98 param_idx += 1;
99 part
100 })
101 .collect();
102 sql.push_str(&set_parts.join(", "));
103
104 if !self.filter.is_none() {
106 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
107 sql.push_str(" WHERE ");
108 sql.push_str(&where_sql);
109 params.extend(where_params);
110 }
111
112 sql.push_str(" RETURNING ");
114 sql.push_str(&self.select.to_sql());
115
116 (sql, params)
117 }
118
119 pub async fn exec(self) -> QueryResult<Vec<M>>
121 where
122 M: Send + 'static,
123 {
124 let (sql, params) = self.build_sql();
125 self.engine.execute_update::<M>(&sql, params).await
126 }
127
128 pub async fn exec_one(self) -> QueryResult<M>
130 where
131 M: Send + 'static,
132 {
133 let (sql, params) = self.build_sql();
134 self.engine.query_one::<M>(&sql, params).await
135 }
136}
137
138pub struct UpdateManyOperation<E: QueryEngine, M: Model> {
140 engine: E,
141 filter: Filter,
142 updates: Vec<(String, FilterValue)>,
143 _model: PhantomData<M>,
144}
145
146impl<E: QueryEngine, M: Model> UpdateManyOperation<E, M> {
147 pub fn new(engine: E) -> Self {
149 Self {
150 engine,
151 filter: Filter::None,
152 updates: Vec::new(),
153 _model: PhantomData,
154 }
155 }
156
157 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
159 let new_filter = filter.into();
160 self.filter = self.filter.and_then(new_filter);
161 self
162 }
163
164 pub fn set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
166 self.updates.push((column.into(), value.into()));
167 self
168 }
169
170 pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
172 let mut sql = String::new();
173 let mut params = Vec::new();
174 let mut param_idx = 1;
175
176 sql.push_str("UPDATE ");
178 sql.push_str(M::TABLE_NAME);
179
180 sql.push_str(" SET ");
182 let set_parts: Vec<_> = self
183 .updates
184 .iter()
185 .map(|(col, val)| {
186 params.push(val.clone());
187 let part = format!("{} = ${}", col, param_idx);
188 param_idx += 1;
189 part
190 })
191 .collect();
192 sql.push_str(&set_parts.join(", "));
193
194 if !self.filter.is_none() {
196 let (where_sql, where_params) = self.filter.to_sql(param_idx - 1);
197 sql.push_str(" WHERE ");
198 sql.push_str(&where_sql);
199 params.extend(where_params);
200 }
201
202 (sql, params)
203 }
204
205 pub async fn exec(self) -> QueryResult<u64> {
207 let (sql, params) = self.build_sql();
208 self.engine.execute_raw(&sql, params).await
209 }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;

    /// Minimal model mapped to the `test_models` table.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub.
    ///
    /// - Row-returning calls yield empty, `None` or not-found results.
    /// - `execute_raw` reports the configured `return_count`.
    #[derive(Clone)]
    struct MockEngine {
        return_count: u64,
    }

    impl MockEngine {
        fn new() -> Self {
            Self { return_count: 0 }
        }

        fn with_count(count: u64) -> Self {
            Self { return_count: count }
        }
    }

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            let count = self.return_count;
            Box::pin(async move { Ok(count) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    // ----- UpdateOperation -----

    #[test]
    fn test_update_new() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_basic() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("name = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_many_fields() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .set("email", "updated@example.com");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_update_with_set_many() {
        let updates = vec![
            ("name", FilterValue::String("Alice".to_string())),
            ("email", FilterValue::String("alice@test.com".to_string())),
            ("age", FilterValue::Int(30)),
        ];
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set_many(updates);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("name = $1"));
        assert!(sql.contains("email = $2"));
        assert!(sql.contains("age = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_increment() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .increment("counter", 5);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("counter = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Int(5));
    }

    #[test]
    fn test_update_with_select() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated")
            .select(Select::fields(["id", "name"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, name"));
    }

    #[test]
    fn test_update_with_complex_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "status".into(),
                FilterValue::String("active".to_string()),
            ))
            .r#where(Filter::Gt("age".into(), FilterValue::Int(18)))
            .set("verified", FilterValue::Bool(true));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("WHERE"));
        assert!(sql.contains("AND"));
        // One SET value plus two filter values.
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_without_filter() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("status", "updated");

        let (sql, _) = op.build_sql();

        assert!(!sql.contains("WHERE"));
        assert!(sql.contains("UPDATE test_models SET"));
    }

    #[test]
    fn test_update_with_null_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("deleted_at", FilterValue::Null);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("deleted_at = $1"));
        assert_eq!(params.len(), 1);
        assert_eq!(params[0], FilterValue::Null);
    }

    #[test]
    fn test_update_with_boolean() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("active", FilterValue::Bool(true))
            .set("verified", FilterValue::Bool(false));

        let (_, params) = op.build_sql();

        assert_eq!(params.len(), 2);
        assert_eq!(params[0], FilterValue::Bool(true));
        assert_eq!(params[1], FilterValue::Bool(false));
    }

    #[tokio::test]
    async fn test_update_exec() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("name", "Updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_update_exec_one() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)))
            .set("name", "Updated");

        // The mock's `query_one` always reports not-found.
        let result = op.exec_one().await;
        assert!(result.is_err());
    }

    // ----- UpdateManyOperation -----

    #[test]
    fn test_update_many_new() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new());
        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(!sql.contains("RETURNING"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_update_many_basic() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::In(
                "id".into(),
                vec![FilterValue::Int(1), FilterValue::Int(2), FilterValue::Int(3)],
            ))
            .set("status", "processed");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("UPDATE test_models SET"));
        assert!(sql.contains("status = $1"));
        assert!(sql.contains("WHERE"));
        assert!(sql.contains("IN"));
        // One SET value plus three IN-list values.
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_update_many_with_multiple_conditions() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .r#where(Filter::Equals(
                "department".into(),
                FilterValue::String("engineering".to_string()),
            ))
            .r#where(Filter::Equals("active".into(), FilterValue::Bool(true)))
            .set("reviewed", FilterValue::Bool(true));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("AND"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_without_where() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("reset_password", FilterValue::Bool(true));

        let (sql, _) = op.build_sql();

        assert!(!sql.contains("WHERE"));
    }

    #[tokio::test]
    async fn test_update_many_exec() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::with_count(5))
            .set("status", "updated");

        let result = op.exec().await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 5);
    }

    // ----- Placeholder numbering -----

    #[test]
    fn test_update_param_ordering() {
        // The filter is added last but its placeholder must still follow the SET ones.
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .set("field2", "value2")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("field2 = $2"));
        assert!(sql.contains("id = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_update_many_param_ordering() {
        let op = UpdateManyOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("field1", "value1")
            .r#where(Filter::Equals("id".into(), FilterValue::Int(1)));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("field1 = $1"));
        assert!(sql.contains("id = $2"));
        assert_eq!(params.len(), 2);
    }

    // ----- Value kinds -----

    #[test]
    fn test_update_with_float_value() {
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("price", FilterValue::Float(99.99));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("price = $1"));
        assert_eq!(params.len(), 1);
    }

    #[test]
    fn test_update_with_json_value() {
        let json_value = serde_json::json!({"key": "value"});
        let op = UpdateOperation::<MockEngine, TestModel>::new(MockEngine::new())
            .set("metadata", FilterValue::Json(json_value.clone()));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("metadata = $1"));
        assert_eq!(params[0], FilterValue::Json(json_value));
    }
}
578