1use std::marker::PhantomData;
4
5use crate::error::QueryResult;
6use crate::filter::{Filter, FilterValue};
7use crate::traits::{Model, QueryEngine};
8use crate::types::Select;
9
10pub struct UpsertOperation<E: QueryEngine, M: Model> {
25 engine: E,
26 filter: Filter,
27 create_columns: Vec<String>,
28 create_values: Vec<FilterValue>,
29 update_columns: Vec<String>,
30 update_values: Vec<FilterValue>,
31 conflict_columns: Vec<String>,
32 select: Select,
33 _model: PhantomData<M>,
34}
35
36impl<E: QueryEngine, M: Model> UpsertOperation<E, M> {
37 pub fn new(engine: E) -> Self {
39 Self {
40 engine,
41 filter: Filter::None,
42 create_columns: Vec::new(),
43 create_values: Vec::new(),
44 update_columns: Vec::new(),
45 update_values: Vec::new(),
46 conflict_columns: Vec::new(),
47 select: Select::All,
48 _model: PhantomData,
49 }
50 }
51
52 pub fn r#where(mut self, filter: impl Into<Filter>) -> Self {
54 self.filter = filter.into();
55 self
56 }
57
58 pub fn on_conflict(mut self, columns: impl IntoIterator<Item = impl Into<String>>) -> Self {
60 self.conflict_columns = columns.into_iter().map(Into::into).collect();
61 self
62 }
63
64 pub fn create(
66 mut self,
67 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
68 ) -> Self {
69 for (col, val) in values {
70 self.create_columns.push(col.into());
71 self.create_values.push(val.into());
72 }
73 self
74 }
75
76 pub fn create_set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
78 self.create_columns.push(column.into());
79 self.create_values.push(value.into());
80 self
81 }
82
83 pub fn update(
85 mut self,
86 values: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
87 ) -> Self {
88 for (col, val) in values {
89 self.update_columns.push(col.into());
90 self.update_values.push(val.into());
91 }
92 self
93 }
94
95 pub fn update_set(mut self, column: impl Into<String>, value: impl Into<FilterValue>) -> Self {
97 self.update_columns.push(column.into());
98 self.update_values.push(value.into());
99 self
100 }
101
102 pub fn select(mut self, select: impl Into<Select>) -> Self {
104 self.select = select.into();
105 self
106 }
107
108 pub fn build_sql(&self) -> (String, Vec<FilterValue>) {
110 let mut sql = String::new();
111 let mut params = Vec::new();
112 let mut param_idx = 1;
113
114 sql.push_str("INSERT INTO ");
116 sql.push_str(M::TABLE_NAME);
117
118 sql.push_str(" (");
120 sql.push_str(&self.create_columns.join(", "));
121 sql.push(')');
122
123 sql.push_str(" VALUES (");
125 let placeholders: Vec<_> = self
126 .create_values
127 .iter()
128 .map(|v| {
129 params.push(v.clone());
130 let p = format!("${}", param_idx);
131 param_idx += 1;
132 p
133 })
134 .collect();
135 sql.push_str(&placeholders.join(", "));
136 sql.push(')');
137
138 sql.push_str(" ON CONFLICT ");
140 if !self.conflict_columns.is_empty() {
141 sql.push('(');
142 sql.push_str(&self.conflict_columns.join(", "));
143 sql.push_str(") ");
144 }
145
146 if self.update_columns.is_empty() {
148 sql.push_str("DO NOTHING");
149 } else {
150 sql.push_str("DO UPDATE SET ");
151 let update_parts: Vec<_> = self
152 .update_columns
153 .iter()
154 .zip(self.update_values.iter())
155 .map(|(col, val)| {
156 params.push(val.clone());
157 let part = format!("{} = ${}", col, param_idx);
158 param_idx += 1;
159 part
160 })
161 .collect();
162 sql.push_str(&update_parts.join(", "));
163 }
164
165 sql.push_str(" RETURNING ");
167 sql.push_str(&self.select.to_sql());
168
169 (sql, params)
170 }
171
172 pub async fn exec(self) -> QueryResult<M>
174 where
175 M: Send + 'static,
176 {
177 let (sql, params) = self.build_sql();
178 self.engine.execute_insert::<M>(&sql, params).await
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::QueryError;
    use crate::traits::BoxFuture;

    /// Minimal model; only `TABLE_NAME` matters to the SQL builder.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "TestModel";
        const TABLE_NAME: &'static str = "test_models";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name", "email"];
    }

    /// Engine stub that never touches a database.
    ///
    /// `query_one` and `execute_insert` fail with `QueryError::not_found`;
    /// every other method returns an empty or zero result.
    #[derive(Clone)]
    struct MockEngine;

    impl QueryEngine for MockEngine {
        fn query_many<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn query_one<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn query_optional<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Option<T>>> {
            Box::pin(async { Ok(None) })
        }

        fn execute_insert<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<T>> {
            Box::pin(async { Err(QueryError::not_found("test")) })
        }

        fn execute_update<T: Model + Send + 'static>(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<Vec<T>>> {
            Box::pin(async { Ok(Vec::new()) })
        }

        fn execute_delete(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn execute_raw(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }

        fn count(
            &self,
            _sql: &str,
            _params: Vec<FilterValue>,
        ) -> BoxFuture<'_, QueryResult<u64>> {
            Box::pin(async { Ok(0) })
        }
    }

    /// Fresh operation against the mock engine.
    fn upsert() -> UpsertOperation<MockEngine, TestModel> {
        UpsertOperation::new(MockEngine)
    }

    /// Shorthand for a string parameter value.
    fn text(value: &str) -> FilterValue {
        FilterValue::String(value.to_string())
    }

    #[test]
    fn test_upsert_new() {
        let (sql, params) = upsert().build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT"));
        assert!(sql.contains("RETURNING *"));
        assert!(params.is_empty());
    }

    #[test]
    fn test_upsert_basic() {
        let (sql, params) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test")
            .update_set("name", "Updated")
            .build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("RETURNING *"));
        assert_eq!(params.len(), 3);
    }

    /// Pins the full rendering of the DO UPDATE form up to the RETURNING list.
    #[test]
    fn test_upsert_exact_sql_do_update() {
        let (sql, _) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test")
            .update_set("name", "Updated")
            .build_sql();

        assert!(sql.starts_with(
            "INSERT INTO test_models (email, name) VALUES ($1, $2) \
             ON CONFLICT (email) DO UPDATE SET name = $3 RETURNING "
        ));
    }

    /// Pins the full rendering of the DO NOTHING form up to the RETURNING list.
    #[test]
    fn test_upsert_exact_sql_do_nothing() {
        let (sql, _) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .build_sql();

        assert!(sql.starts_with(
            "INSERT INTO test_models (email) VALUES ($1) \
             ON CONFLICT (email) DO NOTHING RETURNING "
        ));
    }

    #[test]
    fn test_upsert_single_conflict_column() {
        let (sql, _) = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .build_sql();

        assert!(sql.contains("ON CONFLICT (id)"));
    }

    #[test]
    fn test_upsert_multiple_conflict_columns() {
        let (sql, _) = upsert()
            .on_conflict(["tenant_id", "email"])
            .create_set("email", "test@example.com")
            .create_set("tenant_id", FilterValue::Int(1))
            .build_sql();

        assert!(sql.contains("ON CONFLICT (tenant_id, email)"));
    }

    #[test]
    fn test_upsert_without_conflict_columns() {
        let (sql, _) = upsert().create_set("email", "test@example.com").build_sql();

        assert!(sql.contains("ON CONFLICT"));
        assert!(!sql.contains("ON CONFLICT ("));
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
    }

    #[test]
    fn test_upsert_create_with_set() {
        let (sql, params) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User")
            .build_sql();

        assert!(sql.contains("(email, name)"));
        assert!(sql.contains("VALUES ($1, $2)"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_create_with_iterator() {
        let create_data = vec![
            ("email", text("test@example.com")),
            ("name", text("Test User")),
            ("age", FilterValue::Int(25)),
        ];
        let (sql, params) = upsert()
            .on_conflict(["email"])
            .create(create_data)
            .build_sql();

        assert!(sql.contains("(email, name, age)"));
        assert!(sql.contains("VALUES ($1, $2, $3)"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_update_with_set() {
        let (sql, params) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated Name")
            .update_set("updated_at", "2024-01-01")
            .build_sql();

        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("name = $"));
        assert!(sql.contains("updated_at = $"));
        // Update placeholders continue numbering after the create values.
        assert!(sql.contains("DO UPDATE SET name = $2, updated_at = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_update_with_iterator() {
        let update_data = vec![("name", text("Updated")), ("status", text("active"))];
        let (sql, params) = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .update(update_data)
            .build_sql();

        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("name = $2, status = $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_do_nothing() {
        let (sql, _) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert!(!sql.contains("DO UPDATE"));
    }

    #[test]
    fn test_upsert_do_nothing_multiple_create() {
        let (sql, params) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test")
            .build_sql();

        assert!(sql.contains("DO NOTHING"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_upsert_with_select() {
        let (sql, _) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id", "email"]))
            .build_sql();

        assert!(sql.contains("RETURNING id, email"));
        assert!(!sql.contains("RETURNING *"));
    }

    #[test]
    fn test_upsert_select_all() {
        let (sql, _) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .select(Select::All)
            .build_sql();

        assert!(sql.contains("RETURNING *"));
    }

    /// A filter must not disturb the INSERT half: the create value stays
    /// first in the parameter list and keeps placeholder `$1`.
    #[test]
    fn test_upsert_with_where() {
        let (sql, params) = upsert()
            .r#where(Filter::Equals("email".into(), text("test@example.com")))
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .build_sql();

        assert!(sql.contains("VALUES ($1)"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert_eq!(params[0], text("test@example.com"));
    }

    #[test]
    fn test_upsert_sql_structure() {
        let (sql, _) = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .update_set("name", "Updated")
            .select(Select::fields(["id"]))
            .build_sql();

        let position = |needle: &str| {
            sql.find(needle)
                .unwrap_or_else(|| panic!("`{needle}` missing from `{sql}`"))
        };
        let insert_pos = position("INSERT INTO");
        let values_pos = position("VALUES");
        let conflict_pos = position("ON CONFLICT");
        let update_pos = position("DO UPDATE SET");
        let returning_pos = position("RETURNING");

        assert!(insert_pos < values_pos);
        assert!(values_pos < conflict_pos);
        assert!(conflict_pos < update_pos);
        assert!(update_pos < returning_pos);
    }

    #[test]
    fn test_upsert_table_name() {
        let (sql, _) = upsert().build_sql();

        assert!(sql.contains("test_models"));
    }

    #[test]
    fn test_upsert_param_ordering() {
        let (sql, params) = upsert()
            .on_conflict(["email"])
            .create_set("email", "create@test.com")
            .create_set("name", "Create Name")
            .update_set("name", "Update Name")
            .build_sql();

        assert!(sql.contains("VALUES ($1, $2)"));
        assert!(sql.contains("name = $3"));
        assert_eq!(
            params,
            vec![
                text("create@test.com"),
                text("Create Name"),
                text("Update Name"),
            ]
        );
    }

    #[tokio::test]
    async fn test_upsert_exec() {
        let result = upsert()
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .exec()
            .await;

        // MockEngine::execute_insert always fails.
        assert!(result.is_err());
    }

    #[test]
    fn test_upsert_full_chain() {
        let (sql, params) = upsert()
            .r#where(Filter::Equals("email".into(), text("test@example.com")))
            .on_conflict(["email"])
            .create_set("email", "test@example.com")
            .create_set("name", "Test User")
            .update_set("name", "Updated User")
            .select(Select::fields(["id", "name", "email"]))
            .build_sql();

        assert!(sql.contains("INSERT INTO test_models"));
        assert!(sql.contains("ON CONFLICT (email)"));
        assert!(sql.contains("DO UPDATE SET"));
        assert!(sql.contains("RETURNING id, name, email"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_upsert_with_null_value() {
        let (_, params) = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("nickname", FilterValue::Null)
            .build_sql();

        assert_eq!(params.len(), 2);
        assert_eq!(params[1], FilterValue::Null);
    }

    #[test]
    fn test_upsert_with_boolean_value() {
        let (_, params) = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("active", FilterValue::Bool(true))
            .update_set("active", FilterValue::Bool(false))
            .build_sql();

        assert_eq!(params[1], FilterValue::Bool(true));
        assert_eq!(params[2], FilterValue::Bool(false));
    }

    #[test]
    fn test_upsert_with_numeric_values() {
        let (_, params) = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("score", FilterValue::Float(99.5))
            .build_sql();

        assert_eq!(params[0], FilterValue::Int(1));
        assert_eq!(params[1], FilterValue::Float(99.5));
    }

    #[test]
    fn test_upsert_with_json_value() {
        let json = serde_json::json!({"key": "value"});
        let (_, params) = upsert()
            .on_conflict(["id"])
            .create_set("id", FilterValue::Int(1))
            .create_set("metadata", FilterValue::Json(json.clone()))
            .build_sql();

        assert_eq!(params[1], FilterValue::Json(json));
    }
}