1use crate::input::mutation::InputValue;
6use crate::resolver::query::TableFilter;
7use bytes::Bytes;
8use postrust_core::plan::{CoercibleField, CoercibleLogicTree, MutatePlan};
9use postrust_core::schema_cache::Table;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12
13#[derive(Debug, Clone, Default)]
15pub struct InsertArgs {
16 pub objects: Vec<HashMap<String, InputValue>>,
18 pub on_conflict: Option<OnConflictArgs>,
20 pub returning: Vec<String>,
22}
23
24impl InsertArgs {
25 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn with_object(mut self, object: HashMap<String, InputValue>) -> Self {
32 self.objects.push(object);
33 self
34 }
35
36 pub fn with_objects(mut self, objects: Vec<HashMap<String, InputValue>>) -> Self {
38 self.objects = objects;
39 self
40 }
41
42 pub fn with_on_conflict(mut self, on_conflict: OnConflictArgs) -> Self {
44 self.on_conflict = Some(on_conflict);
45 self
46 }
47
48 pub fn with_returning(mut self, returning: Vec<String>) -> Self {
50 self.returning = returning;
51 self
52 }
53
54 pub fn has_objects(&self) -> bool {
56 !self.objects.is_empty()
57 }
58
59 pub fn object_count(&self) -> usize {
61 self.objects.len()
62 }
63
64 pub fn to_json_bytes(&self) -> Option<Bytes> {
66 if self.objects.is_empty() {
67 return None;
68 }
69
70 let json_objects: Vec<serde_json::Value> = self
72 .objects
73 .iter()
74 .map(|obj| {
75 let map: serde_json::Map<String, serde_json::Value> = obj
76 .iter()
77 .map(|(k, v)| (k.clone(), input_value_to_json(v)))
78 .collect();
79 serde_json::Value::Object(map)
80 })
81 .collect();
82
83 let json = if json_objects.len() == 1 {
84 serde_json::to_vec(&json_objects[0]).ok()
85 } else {
86 serde_json::to_vec(&json_objects).ok()
87 };
88
89 json.map(Bytes::from)
90 }
91
92 pub fn column_names(&self) -> Vec<String> {
94 self.objects
95 .first()
96 .map(|obj| obj.keys().cloned().collect())
97 .unwrap_or_default()
98 }
99}
100
101#[derive(Debug, Clone, Default, Serialize, Deserialize)]
103pub struct OnConflictArgs {
104 pub constraint: Vec<String>,
106 pub update_columns: Vec<String>,
108 pub where_filter: Option<TableFilter>,
110}
111
112impl OnConflictArgs {
113 pub fn new(constraint: Vec<String>) -> Self {
115 Self {
116 constraint,
117 update_columns: vec![],
118 where_filter: None,
119 }
120 }
121
122 pub fn with_update_columns(mut self, columns: Vec<String>) -> Self {
124 self.update_columns = columns;
125 self
126 }
127
128 pub fn with_where(mut self, filter: TableFilter) -> Self {
130 self.where_filter = Some(filter);
131 self
132 }
133}
134
135#[derive(Debug, Clone, Default)]
137pub struct UpdateArgs {
138 pub filter: Option<TableFilter>,
140 pub set: HashMap<String, InputValue>,
142 pub returning: Vec<String>,
144}
145
146impl UpdateArgs {
147 pub fn new() -> Self {
149 Self::default()
150 }
151
152 pub fn with_filter(mut self, filter: TableFilter) -> Self {
154 self.filter = Some(filter);
155 self
156 }
157
158 pub fn with_set(mut self, set: HashMap<String, InputValue>) -> Self {
160 self.set = set;
161 self
162 }
163
164 pub fn with_returning(mut self, returning: Vec<String>) -> Self {
166 self.returning = returning;
167 self
168 }
169
170 pub fn has_filter(&self) -> bool {
172 self.filter.is_some()
173 }
174
175 pub fn has_set(&self) -> bool {
177 !self.set.is_empty()
178 }
179
180 pub fn to_json_bytes(&self) -> Option<Bytes> {
182 if self.set.is_empty() {
183 return None;
184 }
185
186 let map: serde_json::Map<String, serde_json::Value> = self
187 .set
188 .iter()
189 .map(|(k, v)| (k.clone(), input_value_to_json(v)))
190 .collect();
191
192 serde_json::to_vec(&serde_json::Value::Object(map))
193 .ok()
194 .map(Bytes::from)
195 }
196
197 pub fn column_names(&self) -> Vec<String> {
199 self.set.keys().cloned().collect()
200 }
201}
202
203#[derive(Debug, Clone, Default)]
205pub struct DeleteArgs {
206 pub filter: Option<TableFilter>,
208 pub returning: Vec<String>,
210}
211
212impl DeleteArgs {
213 pub fn new() -> Self {
215 Self::default()
216 }
217
218 pub fn with_filter(mut self, filter: TableFilter) -> Self {
220 self.filter = Some(filter);
221 self
222 }
223
224 pub fn with_returning(mut self, returning: Vec<String>) -> Self {
226 self.returning = returning;
227 self
228 }
229
230 pub fn has_filter(&self) -> bool {
232 self.filter.is_some()
233 }
234}
235
236fn input_value_to_json(value: &InputValue) -> serde_json::Value {
238 match value {
239 InputValue::Null => serde_json::Value::Null,
240 InputValue::Bool(b) => serde_json::Value::Bool(*b),
241 InputValue::Int(i) => serde_json::Value::Number((*i).into()),
242 InputValue::Float(f) => {
243 serde_json::Number::from_f64(*f)
244 .map(serde_json::Value::Number)
245 .unwrap_or(serde_json::Value::Null)
246 }
247 InputValue::String(s) => serde_json::Value::String(s.clone()),
248 InputValue::Object(obj) => {
249 let map: serde_json::Map<String, serde_json::Value> = obj
250 .iter()
251 .map(|(k, v)| (k.clone(), input_value_to_json(v)))
252 .collect();
253 serde_json::Value::Object(map)
254 }
255 InputValue::Array(arr) => {
256 serde_json::Value::Array(arr.iter().map(input_value_to_json).collect())
257 }
258 }
259}
260
261fn build_coercible_fields(columns: &[String], table: &Table) -> Vec<CoercibleField> {
263 columns
264 .iter()
265 .filter_map(|name| {
266 table
267 .columns
268 .get(name)
269 .map(|col| CoercibleField::simple(name, &col.data_type))
270 })
271 .collect()
272}
273
274fn build_where_clauses(filter: &Option<TableFilter>, table: &Table) -> Vec<CoercibleLogicTree> {
276 let Some(filter) = filter else {
277 return vec![];
278 };
279
280 let type_resolver = |name: &str| -> String {
281 table
282 .get_column(name)
283 .map(|c| c.data_type.clone())
284 .unwrap_or_else(|| "text".to_string())
285 };
286
287 filter
288 .to_logic_tree()
289 .map(|tree| vec![CoercibleLogicTree::from_logic_tree(&tree, type_resolver)])
290 .unwrap_or_default()
291}
292
293pub fn build_insert_plan(args: &InsertArgs, table: &Table) -> MutatePlan {
295 let columns = build_coercible_fields(&args.column_names(), table);
296 let body = args.to_json_bytes();
297 let returning = if args.returning.is_empty() {
298 table.pk_cols.clone()
299 } else {
300 args.returning.clone()
301 };
302
303 let on_conflict = args.on_conflict.as_ref().map(|oc| {
304 (
305 postrust_core::api_request::PreferResolution::MergeDuplicates,
306 oc.constraint.clone(),
307 )
308 });
309
310 MutatePlan::Insert {
311 target: table.qualified_identifier(),
312 columns,
313 body,
314 on_conflict,
315 where_clauses: vec![],
316 returning,
317 pk_cols: table.pk_cols.clone(),
318 apply_defaults: true,
319 }
320}
321
322pub fn build_update_plan(args: &UpdateArgs, table: &Table) -> MutatePlan {
324 let columns = build_coercible_fields(&args.column_names(), table);
325 let body = args.to_json_bytes();
326 let where_clauses = build_where_clauses(&args.filter, table);
327 let returning = if args.returning.is_empty() {
328 table.pk_cols.clone()
329 } else {
330 args.returning.clone()
331 };
332
333 MutatePlan::Update {
334 target: table.qualified_identifier(),
335 columns,
336 body,
337 where_clauses,
338 returning,
339 apply_defaults: false,
340 }
341}
342
343pub fn build_delete_plan(args: &DeleteArgs, table: &Table) -> MutatePlan {
345 let where_clauses = build_where_clauses(&args.filter, table);
346 let returning = if args.returning.is_empty() {
347 table.pk_cols.clone()
348 } else {
349 args.returning.clone()
350 };
351
352 MutatePlan::Delete {
353 target: table.qualified_identifier(),
354 where_clauses,
355 returning,
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use crate::input::filter::IntFilterInput;
    use crate::resolver::query::FieldFilter;
    use indexmap::IndexMap;
    use postrust_core::schema_cache::Column;
    use pretty_assertions::assert_eq;

    /// Builds an input row from `(column, value)` pairs.
    fn object(pairs: Vec<(&str, InputValue)>) -> HashMap<String, InputValue> {
        pairs
            .into_iter()
            .map(|(column, value)| (column.to_string(), value))
            .collect()
    }

    /// Shorthand for `InputValue::String`.
    fn string_value(value: &str) -> InputValue {
        InputValue::String(value.to_string())
    }

    /// Filter equivalent to `id = 1`.
    fn id_equals_one() -> TableFilter {
        TableFilter::new().with_field(
            "id",
            FieldFilter::int(IntFilterInput {
                eq: Some(1),
                ..Default::default()
            }),
        )
    }

    /// `public.users(id integer PK, name text NOT NULL, email text NULL)`.
    fn create_test_table() -> Table {
        let mut columns = IndexMap::new();
        columns.insert(
            "id".into(),
            Column {
                name: "id".into(),
                description: None,
                nullable: false,
                data_type: "integer".into(),
                nominal_type: "int4".into(),
                max_len: None,
                default: Some("nextval('users_id_seq')".into()),
                enum_values: vec![],
                is_pk: true,
                position: 1,
            },
        );
        columns.insert(
            "name".into(),
            Column {
                name: "name".into(),
                description: None,
                nullable: false,
                data_type: "text".into(),
                nominal_type: "text".into(),
                max_len: None,
                default: None,
                enum_values: vec![],
                is_pk: false,
                position: 2,
            },
        );
        columns.insert(
            "email".into(),
            Column {
                name: "email".into(),
                description: None,
                nullable: true,
                data_type: "text".into(),
                nominal_type: "text".into(),
                max_len: None,
                default: None,
                enum_values: vec![],
                is_pk: false,
                position: 3,
            },
        );

        Table {
            schema: "public".into(),
            name: "users".into(),
            description: None,
            is_view: false,
            insertable: true,
            updatable: true,
            deletable: true,
            pk_cols: vec!["id".into()],
            columns,
        }
    }

    // ---- InsertArgs -------------------------------------------------------

    #[test]
    fn test_insert_args_default() {
        let args = InsertArgs::new();
        assert!(!args.has_objects());
        assert_eq!(args.object_count(), 0);
    }

    #[test]
    fn test_insert_args_with_object() {
        let args = InsertArgs::new().with_object(object(vec![
            ("name", string_value("Alice")),
            ("email", string_value("alice@example.com")),
        ]));
        assert!(args.has_objects());
        assert_eq!(args.object_count(), 1);
    }

    #[test]
    fn test_insert_args_with_multiple_objects() {
        let args = InsertArgs::new().with_objects(vec![
            object(vec![("name", string_value("Alice"))]),
            object(vec![("name", string_value("Bob"))]),
        ]);
        assert_eq!(args.object_count(), 2);
    }

    #[test]
    fn test_insert_args_column_names() {
        let args = InsertArgs::new().with_object(object(vec![
            ("name", string_value("Alice")),
            ("email", string_value("alice@example.com")),
        ]));
        let columns = args.column_names();
        assert_eq!(columns.len(), 2);
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"email".to_string()));
    }

    #[test]
    fn test_insert_args_to_json_bytes() {
        let args = InsertArgs::new().with_object(object(vec![("name", string_value("Alice"))]));
        let bytes = args.to_json_bytes().unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["name"], "Alice");
    }

    #[test]
    fn test_insert_args_to_json_bytes_empty_is_none() {
        assert!(InsertArgs::new().to_json_bytes().is_none());
    }

    #[test]
    fn test_insert_args_to_json_bytes_multiple_objects_is_array() {
        let args = InsertArgs::new().with_objects(vec![
            object(vec![("name", string_value("Alice"))]),
            object(vec![("name", string_value("Bob"))]),
        ]);
        let bytes = args.to_json_bytes().unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json.as_array().map(Vec::len), Some(2));
        assert_eq!(json[0]["name"], "Alice");
        assert_eq!(json[1]["name"], "Bob");
    }

    #[test]
    fn test_insert_args_with_returning() {
        let args = InsertArgs::new().with_returning(vec!["id".to_string(), "name".to_string()]);
        assert_eq!(args.returning.len(), 2);
    }

    #[test]
    fn test_insert_args_with_on_conflict() {
        let on_conflict = OnConflictArgs::new(vec!["email".to_string()])
            .with_update_columns(vec!["name".to_string()]);

        let args = InsertArgs::new().with_on_conflict(on_conflict);
        assert!(args.on_conflict.is_some());
    }

    // ---- OnConflictArgs ---------------------------------------------------

    #[test]
    fn test_on_conflict_args() {
        let args = OnConflictArgs::new(vec!["id".to_string()])
            .with_update_columns(vec!["name".to_string(), "email".to_string()]);

        assert_eq!(args.constraint, vec!["id".to_string()]);
        assert_eq!(args.update_columns.len(), 2);
    }

    // ---- UpdateArgs -------------------------------------------------------

    #[test]
    fn test_update_args_default() {
        let args = UpdateArgs::new();
        assert!(!args.has_filter());
        assert!(!args.has_set());
    }

    #[test]
    fn test_update_args_with_set() {
        let args = UpdateArgs::new().with_set(object(vec![("name", string_value("Updated"))]));
        assert!(args.has_set());
        assert_eq!(args.column_names().len(), 1);
    }

    #[test]
    fn test_update_args_with_filter() {
        let args = UpdateArgs::new().with_filter(id_equals_one());
        assert!(args.has_filter());
    }

    #[test]
    fn test_update_args_to_json_bytes() {
        let args = UpdateArgs::new().with_set(object(vec![
            ("name", string_value("Updated")),
            ("active", InputValue::Bool(true)),
        ]));
        let bytes = args.to_json_bytes().unwrap();
        let json: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(json["name"], "Updated");
        assert_eq!(json["active"], true);
    }

    #[test]
    fn test_update_args_to_json_bytes_empty_is_none() {
        assert!(UpdateArgs::new().to_json_bytes().is_none());
    }

    // ---- DeleteArgs -------------------------------------------------------

    #[test]
    fn test_delete_args_default() {
        let args = DeleteArgs::new();
        assert!(!args.has_filter());
    }

    #[test]
    fn test_delete_args_with_filter() {
        let args = DeleteArgs::new().with_filter(id_equals_one());
        assert!(args.has_filter());
    }

    #[test]
    fn test_delete_args_with_returning() {
        let args = DeleteArgs::new().with_returning(vec!["id".to_string(), "name".to_string()]);
        assert_eq!(args.returning.len(), 2);
    }

    // ---- input_value_to_json ----------------------------------------------

    #[test]
    fn test_input_value_to_json_null() {
        let json = input_value_to_json(&InputValue::Null);
        assert!(json.is_null());
    }

    #[test]
    fn test_input_value_to_json_bool() {
        let json = input_value_to_json(&InputValue::Bool(true));
        assert_eq!(json, serde_json::Value::Bool(true));
    }

    #[test]
    fn test_input_value_to_json_int() {
        let json = input_value_to_json(&InputValue::Int(42));
        assert_eq!(json, serde_json::json!(42));
    }

    #[test]
    fn test_input_value_to_json_float() {
        // 2.5 is exactly representable and, unlike 3.14, does not trip clippy's
        // deny-by-default `approx_constant` lint.
        let json = input_value_to_json(&InputValue::Float(2.5));
        assert_eq!(json, serde_json::json!(2.5));
    }

    #[test]
    fn test_input_value_to_json_non_finite_float_is_null() {
        assert!(input_value_to_json(&InputValue::Float(f64::NAN)).is_null());
        assert!(input_value_to_json(&InputValue::Float(f64::INFINITY)).is_null());
    }

    #[test]
    fn test_input_value_to_json_string() {
        let json = input_value_to_json(&string_value("hello"));
        assert_eq!(json, serde_json::json!("hello"));
    }

    #[test]
    fn test_input_value_to_json_array() {
        let arr = vec![InputValue::Int(1), InputValue::Int(2), InputValue::Int(3)];
        let json = input_value_to_json(&InputValue::Array(arr));
        assert_eq!(json, serde_json::json!([1, 2, 3]));
    }

    #[test]
    fn test_input_value_to_json_object() {
        let obj = object(vec![
            ("name", string_value("test")),
            ("count", InputValue::Int(5)),
        ]);
        let json = input_value_to_json(&InputValue::Object(obj));
        assert_eq!(json["name"], "test");
        assert_eq!(json["count"], 5);
    }

    // ---- plan builders ----------------------------------------------------

    #[test]
    fn test_build_insert_plan_basic() {
        let table = create_test_table();
        let args = InsertArgs::new().with_object(object(vec![("name", string_value("Alice"))]));
        let plan = build_insert_plan(&args, &table);

        match plan {
            MutatePlan::Insert {
                target,
                body,
                returning,
                ..
            } => {
                assert_eq!(target.name, "users");
                assert!(body.is_some());
                assert_eq!(returning, vec!["id".to_string()]);
            }
            _ => panic!("Expected Insert plan"),
        }
    }

    #[test]
    fn test_build_insert_plan_with_returning() {
        let table = create_test_table();
        let args = InsertArgs::new()
            .with_object(object(vec![("name", string_value("Alice"))]))
            .with_returning(vec!["id".to_string(), "name".to_string()]);
        let plan = build_insert_plan(&args, &table);

        match plan {
            MutatePlan::Insert { returning, .. } => {
                assert_eq!(returning.len(), 2);
            }
            _ => panic!("Expected Insert plan"),
        }
    }

    #[test]
    fn test_build_insert_plan_with_on_conflict() {
        let table = create_test_table();
        let args = InsertArgs::new()
            .with_object(object(vec![("name", string_value("Alice"))]))
            .with_on_conflict(OnConflictArgs::new(vec!["id".to_string()]));
        let plan = build_insert_plan(&args, &table);

        match plan {
            MutatePlan::Insert { on_conflict, .. } => {
                assert!(on_conflict.is_some());
                let (_, cols) = on_conflict.unwrap();
                assert_eq!(cols, vec!["id".to_string()]);
            }
            _ => panic!("Expected Insert plan"),
        }
    }

    #[test]
    fn test_build_insert_plan_skips_unknown_columns() {
        let table = create_test_table();
        let args = InsertArgs::new().with_object(object(vec![
            ("name", string_value("Alice")),
            ("nickname", string_value("Al")),
        ]));

        match build_insert_plan(&args, &table) {
            MutatePlan::Insert { columns, .. } => assert_eq!(columns.len(), 1),
            _ => panic!("Expected Insert plan"),
        }
    }

    #[test]
    fn test_build_update_plan_basic() {
        let table = create_test_table();
        let args = UpdateArgs::new()
            .with_set(object(vec![("name", string_value("Updated"))]))
            .with_filter(id_equals_one());
        let plan = build_update_plan(&args, &table);

        match plan {
            MutatePlan::Update {
                target,
                body,
                where_clauses,
                ..
            } => {
                assert_eq!(target.name, "users");
                assert!(body.is_some());
                assert!(!where_clauses.is_empty());
            }
            _ => panic!("Expected Update plan"),
        }
    }

    #[test]
    fn test_build_update_plan_with_returning() {
        let table = create_test_table();
        let args = UpdateArgs::new()
            .with_set(object(vec![("name", string_value("Updated"))]))
            .with_returning(vec!["id".to_string(), "name".to_string()]);
        let plan = build_update_plan(&args, &table);

        match plan {
            MutatePlan::Update { returning, .. } => {
                assert_eq!(returning.len(), 2);
            }
            _ => panic!("Expected Update plan"),
        }
    }

    #[test]
    fn test_build_update_plan_without_filter_has_no_where_clauses() {
        let table = create_test_table();
        let args = UpdateArgs::new().with_set(object(vec![("name", string_value("Updated"))]));

        match build_update_plan(&args, &table) {
            MutatePlan::Update { where_clauses, .. } => assert!(where_clauses.is_empty()),
            _ => panic!("Expected Update plan"),
        }
    }

    #[test]
    fn test_build_delete_plan_basic() {
        let table = create_test_table();
        let args = DeleteArgs::new().with_filter(id_equals_one());
        let plan = build_delete_plan(&args, &table);

        match plan {
            MutatePlan::Delete {
                target,
                where_clauses,
                returning,
            } => {
                assert_eq!(target.name, "users");
                assert!(!where_clauses.is_empty());
                assert_eq!(returning, vec!["id".to_string()]);
            }
            _ => panic!("Expected Delete plan"),
        }
    }

    #[test]
    fn test_build_delete_plan_with_returning() {
        let table = create_test_table();
        let args = DeleteArgs::new().with_returning(vec![
            "id".to_string(),
            "name".to_string(),
            "email".to_string(),
        ]);
        let plan = build_delete_plan(&args, &table);

        match plan {
            MutatePlan::Delete { returning, .. } => {
                assert_eq!(returning.len(), 3);
            }
            _ => panic!("Expected Delete plan"),
        }
    }

    #[test]
    fn test_build_delete_plan_no_filter() {
        let table = create_test_table();
        let args = DeleteArgs::new();
        let plan = build_delete_plan(&args, &table);

        match plan {
            MutatePlan::Delete { where_clauses, .. } => {
                assert!(where_clauses.is_empty());
            }
            _ => panic!("Expected Delete plan"),
        }
    }
}