1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10pub struct UpsertOperation<E: QueryEngine, M: Model> {
25 engine: E,
26 filter: Filter,
27 create_columns: Vec<String>,
28 create_values: Vec<FilterValue>,
29 update_columns: Vec<String>,
30 update_values: Vec<FilterValue>,
31 conflict_columns: Vec<String>,
32 select: Select,
33 _model: PhantomData<M>,
34}
35
36impl<E: QueryEngine, M: Model> UpsertOperation<E, M> {
37 pub fn new(engine: E) -> Self {
39 Self {
40 engine,
41 filter: Filter::None,
42 create_columns: Vec::new(),
43 create_values: Vec::new(),
44 update_columns: Vec::new(),
45 update_values: Vec::new(),
46 conflict_columns: Vec::new(),
47 select: Select::All,
48 _model: PhantomData,
49 }
50 }
51
52 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
54 self.filter = filter.into();
55 self
56 }
57
58 pub fn on_conflict(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
60 self.conflict_columns = columns.into_iter().map(Into::into).collect();
61 self
62 }
63
64 pub fn create(
66 mut self,
67 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
68 ) -> Self {
69 for (col, val) in values {
70 self.create_columns.push(col.into());
71 self.create_values.push(val.into());
72 }
73 self
74 }
75
76 pub fn create_set(
78 mut self,
79 column: impl Into<String>,
80 value: impl Into<FilterValue>,
81 ) -> Self {
82 self.create_columns.push(column.into());
83 self.create_values.push(value.into());
84 self
85 }
86
87 pub fn update(
89 mut self,
90 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
91 ) -> Self {
92 for (col, val) in values {
93 self.update_columns.push(col.into());
94 self.update_values.push(val.into());
95 }
96 self
97 }
98
99 pub fn update_set(
101 mut self,
102 column: impl Into<String>,
103 value: impl Into<FilterValue>,
104 ) -> Self {
105 self.update_columns.push(column.into());
106 self.update_values.push(value.into());
107 self
108 }
109
110 pub fn select(mut self, select: impl Into<Select>) -> Self {
112 self.select = select.into();
113 self
114 }
115
116 pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
118 let mut sql = String::new();
119 let mut params = Vec::new();
120 let mut param_idx = 1;
121
122 sql.push_str("INSERT INTO ");
124 sql.push_str(M::TABLE_NAME);
125
126 sql.push_str(" (");
128 sql.push_str(&self.create_columns.join(", "));
129 sql.push(')');
130
131 sql.push_str(" VALUES (");
133 let placeholders: Vec<_> = self
134 .create_values
135 .iter()
136 .map(|v| {
137 params.push(v.clone());
138 let p = format!("${}", param_idx);
139 param_idx += 1;
140 p
141 })
142 .collect();
143 sql.push_str(&placeholders.join(", "));
144 sql.push(')');
145
146 sql.push_str(" ON CONFLICT ");
148 if !self.conflict_columns.is_empty() {
149 sql.push('(');
150 sql.push_str(&self.conflict_columns.join(", "));
151 sql.push_str(") ");
152 }
153
154 if self.update_columns.is_empty() {
156 sql.push_str("DO NOTHING");
157 } else {
158 sql.push_str("DO UPDATE SET ");
159 let update_parts: Vec<_> = self
160 .update_columns
161 .iter()
162 .zip(self.update_values.iter())
163 .map(|(col, val)| {
164 params.push(val.clone());
165 let part = format!("{} = ${}", col, param_idx);
166 param_idx += 1;
167 part
168 })
169 .collect();
170 sql.push_str(&update_parts.join(", "));
171 }
172
173 sql.push_str(" RETURNING ");
175 sql.push_str(&self.select.to_sql());
176
177 (sql, params)
178 }
179
180 pub async fn exec(self) -> QueryResult<M>
182 where
183 M: Send + 'static,
184 {
185 let (sql, params) = self.build_sql();
186 self.engine.execute_insert::<M>(&sql, params).await
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;

    /// Minimal model used to drive SQL generation.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub: query methods return empty/None/0, while `query_one` and
    /// `execute_insert` always fail with a not-found error.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> crate::traits::BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    #[test]
    fn test_upsert_new() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine);
        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_upsert_basic() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test")
            .update_set("name", "Updated");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_single_conflict_column() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (id)"));
    }

    #[test]
    fn test_upsert_multiple_conflict_columns() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["tenant_id", "email"])
            .create_set("email", "test@example.com")
            .create_set("tenant_id", FilterValue::Int(1));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT (tenant_id, email)"));
    }

    #[test]
    fn test_upsert_without_conflict_columns() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .create_set("email", "test@example.com");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("ON CONFLICT"));
        assert!(!sql.contains("ON CONFLICT ("));
    }

    #[test]
    fn test_upsert_create_with_set() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(email, name)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_create_with_iterator() {
        let create_data = vec![
            ("email", FilterValue::String("test@example.com".to_string())),
            ("name", FilterValue::String("Test User".to_string())),
            ("age", FilterValue::Int(25)),
        ];
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create(create_data);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("(email, name, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_update_with_set() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated Name")
            .update_set("updated_at", "2024-01-01");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("name = $"));
        assert!(sql.contains("updated_at = $"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_update_with_iterator() {
        let update_data = vec![
            ("name", FilterValue::String("Updated".to_string())),
            ("status", FilterValue::String("active".to_string())),
        ];
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .update(update_data);

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO UPDATE SET"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_do_nothing() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let (sql, _) = op.build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert!(!sql.contains("DO UPDATE"));
    }

    #[test]
    fn test_upsert_do_nothing_multiple_create() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test");

        let (sql, params) = op.build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_with_select() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id", "email"]));

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING id, email"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_upsert_select_all() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .select(Select::All);

        let (sql, _) = op.build_sql();

        assert!(sql.contains("RETURNING *"));
    }

    #[test]
    fn test_upsert_with_where() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("test@example.com".to_string()),
            ))
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let (sql, params) = op.build_sql();

        // Setting a filter must not disturb the rest of the statement or the
        // ordering of the create parameters.
        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO NOTHING"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(
            params[0],
            FilterValue::String("test@example.com".to_string())
        );
    }

    #[test]
    fn test_upsert_sql_structure() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id"]));

        let (sql, _) = op.build_sql();

        let insert_pos = sql.find("INSERT INTO").unwrap();
        let values_pos = sql.find("VALUES").unwrap();
        let conflict_pos = sql.find("ON CONFLICT").unwrap();
        let update_pos = sql.find("DO UPDATE SET").unwrap();
        let returning_pos = sql.find("RETURNING").unwrap();

        assert!(insert_pos < values_pos);
        assert!(values_pos < conflict_pos);
        assert!(conflict_pos < update_pos);
        assert!(update_pos < returning_pos);
    }

    #[test]
    fn test_upsert_table_name() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine);
        let (sql, _) = op.build_sql();

        assert!(sql.contains("test_models"));
    }

    #[test]
    fn test_upsert_param_ordering() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "create@test.com")
            .create_set("name", "Create Name")
            .update_set("name", "Update Name");

        let (sql, params) = op.build_sql();

        // Create placeholders come first, update placeholders continue the numbering.
        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("name = $3"));
        assert_eq!(params.len(), 3);
        assert_eq!(params[0], FilterValue::String("create@test.com".to_string()));
        assert_eq!(params[1], FilterValue::String("Create Name".to_string()));
        assert_eq!(params[2], FilterValue::String("Update Name".to_string()));
    }

    #[tokio::test]
    async fn test_upsert_exec() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["email"])
            .create_set("email", "test@example.com");

        let result = op.exec().await;

        // MockEngine::execute_insert always returns an error.
        assert!(result.is_err());
    }

    #[test]
    fn test_upsert_full_chain() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .r#where(Filter::Equals(
                "email".into(),
                FilterValue::String("test@example.com".to_string()),
            ))
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User")
            .update_set("name", "Updated User")
            .select(Select::fields(["id", "name", "email"]));

        let (sql, params) = op.build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("RETURNING id, name, email"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_with_null_value() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("nickname", FilterValue::Null);

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_upsert_with_boolean_value() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("active", FilterValue::Bool(true))
            .update_set("active", FilterValue::Bool(false));

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Bool(true));
        assert_eq!(params[2], FilterValue::Bool(false));
    }

    #[test]
    fn test_upsert_with_numeric_values() {
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("score", FilterValue::Float(99.5));

        let (_, params) = op.build_sql();

        assert_eq!(params[0], FilterValue::Int(1));
        assert_eq!(params[1], FilterValue::Float(99.5));
    }

    #[test]
    fn test_upsert_with_json_value() {
        let json = serde_json::json!({"key": "value"});
        let op = UpsertOperation::<MockEngine, TestModel>::new(MockEngine)
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("metadata", FilterValue::Json(json.clone()));

        let (_, params) = op.build_sql();

        assert_eq!(params[1], FilterValue::Json(json));
    }
}
618