1#![allow(dead_code)]
2
3use std::fmt::Debug;
62use std::marker::PhantomData;
63
64use crate::error::{QueryError, QueryResult};
65use crate::filter::{Filter, FilterValue};
66use crate::sql::quote_identifier;
67use crate::traits::{Model, QueryEngine};
68
69#[derive(Debug, Clone)]
71pub enum NestedWrite<T: Model> {
72 Create(Vec<NestedCreateData<T>>),
74 CreateOrConnect(Vec<NestedCreateOrConnectData<T>>),
76 Connect(Vec<Filter>),
78 Disconnect(Vec<Filter>),
80 Set(Vec<Filter>),
82 Delete(Vec<Filter>),
84 Update(Vec<NestedUpdateData<T>>),
86 Upsert(Vec<NestedUpsertData<T>>),
88 UpdateMany(NestedUpdateManyData<T>),
90 DeleteMany(Filter),
92}
93
94impl<T: Model> NestedWrite<T> {
95 pub fn create(data: NestedCreateData<T>) -> Self {
97 Self::Create(vec![data])
98 }
99
100 pub fn create_many(data: Vec<NestedCreateData<T>>) -> Self {
102 Self::Create(data)
103 }
104
105 pub fn connect_one(filter: impl Into<Filter>) -> Self {
107 Self::Connect(vec![filter.into()])
108 }
109
110 pub fn connect(filters: Vec<impl Into<Filter>>) -> Self {
112 Self::Connect(filters.into_iter().map(Into::into).collect())
113 }
114
115 pub fn disconnect_one(filter: impl Into<Filter>) -> Self {
117 Self::Disconnect(vec![filter.into()])
118 }
119
120 pub fn disconnect(filters: Vec<impl Into<Filter>>) -> Self {
122 Self::Disconnect(filters.into_iter().map(Into::into).collect())
123 }
124
125 pub fn set(filters: Vec<impl Into<Filter>>) -> Self {
127 Self::Set(filters.into_iter().map(Into::into).collect())
128 }
129
130 pub fn delete(filters: Vec<impl Into<Filter>>) -> Self {
132 Self::Delete(filters.into_iter().map(Into::into).collect())
133 }
134
135 pub fn delete_many(filter: impl Into<Filter>) -> Self {
137 Self::DeleteMany(filter.into())
138 }
139}
140
141#[derive(Debug, Clone)]
143pub struct NestedCreateData<T: Model> {
144 pub data: Vec<(String, FilterValue)>,
146 _model: PhantomData<T>,
148}
149
150impl<T: Model> NestedCreateData<T> {
151 pub fn new(data: Vec<(String, FilterValue)>) -> Self {
153 Self {
154 data,
155 _model: PhantomData,
156 }
157 }
158
159 pub fn from_pairs(
161 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
162 ) -> Self {
163 Self::new(
164 pairs
165 .into_iter()
166 .map(|(k, v)| (k.into(), v.into()))
167 .collect(),
168 )
169 }
170}
171
172impl<T: Model> Default for NestedCreateData<T> {
173 fn default() -> Self {
174 Self::new(Vec::new())
175 }
176}
177
178#[derive(Debug, Clone)]
180pub struct NestedCreateOrConnectData<T: Model> {
181 pub filter: Filter,
183 pub create: NestedCreateData<T>,
185}
186
187impl<T: Model> NestedCreateOrConnectData<T> {
188 pub fn new(filter: impl Into<Filter>, create: NestedCreateData<T>) -> Self {
190 Self {
191 filter: filter.into(),
192 create,
193 }
194 }
195}
196
197#[derive(Debug, Clone)]
199pub struct NestedUpdateData<T: Model> {
200 pub filter: Filter,
202 pub data: Vec<(String, FilterValue)>,
204 _model: PhantomData<T>,
206}
207
208impl<T: Model> NestedUpdateData<T> {
209 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
211 Self {
212 filter: filter.into(),
213 data,
214 _model: PhantomData,
215 }
216 }
217
218 pub fn from_pairs(
220 filter: impl Into<Filter>,
221 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
222 ) -> Self {
223 Self::new(
224 filter,
225 pairs
226 .into_iter()
227 .map(|(k, v)| (k.into(), v.into()))
228 .collect(),
229 )
230 }
231}
232
233#[derive(Debug, Clone)]
235pub struct NestedUpsertData<T: Model> {
236 pub filter: Filter,
238 pub create: NestedCreateData<T>,
240 pub update: Vec<(String, FilterValue)>,
242 _model: PhantomData<T>,
244}
245
246impl<T: Model> NestedUpsertData<T> {
247 pub fn new(
249 filter: impl Into<Filter>,
250 create: NestedCreateData<T>,
251 update: Vec<(String, FilterValue)>,
252 ) -> Self {
253 Self {
254 filter: filter.into(),
255 create,
256 update,
257 _model: PhantomData,
258 }
259 }
260}
261
262#[derive(Debug, Clone)]
264pub struct NestedUpdateManyData<T: Model> {
265 pub filter: Filter,
267 pub data: Vec<(String, FilterValue)>,
269 _model: PhantomData<T>,
271}
272
273impl<T: Model> NestedUpdateManyData<T> {
274 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
276 Self {
277 filter: filter.into(),
278 data,
279 _model: PhantomData,
280 }
281 }
282}
283
284#[derive(Debug)]
286pub struct NestedWriteBuilder {
287 parent_table: String,
289 parent_pk: Vec<String>,
291 related_table: String,
293 foreign_key: String,
295 is_one_to_many: bool,
297 join_table: Option<JoinTableInfo>,
299}
300
/// Describes the join table of a many-to-many relation.
#[derive(Clone, Debug)]
pub struct JoinTableInfo {
    /// Name of the join table.
    pub table_name: String,
    /// Join-table column holding the parent's key.
    pub parent_column: String,
    /// Join-table column holding the related record's key.
    pub related_column: String,
}
311
312impl NestedWriteBuilder {
313 pub fn one_to_many(
315 parent_table: impl Into<String>,
316 parent_pk: Vec<String>,
317 related_table: impl Into<String>,
318 foreign_key: impl Into<String>,
319 ) -> Self {
320 Self {
321 parent_table: parent_table.into(),
322 parent_pk,
323 related_table: related_table.into(),
324 foreign_key: foreign_key.into(),
325 is_one_to_many: true,
326 join_table: None,
327 }
328 }
329
330 pub fn many_to_many(
332 parent_table: impl Into<String>,
333 parent_pk: Vec<String>,
334 related_table: impl Into<String>,
335 join_table: JoinTableInfo,
336 ) -> Self {
337 Self {
338 parent_table: parent_table.into(),
339 parent_pk,
340 related_table: related_table.into(),
341 foreign_key: String::new(), is_one_to_many: false,
343 join_table: Some(join_table),
344 }
345 }
346
347 pub fn build_connect_sql<T: Model>(
349 &self,
350 parent_id: &FilterValue,
351 filters: &[Filter],
352 ) -> Vec<(String, Vec<FilterValue>)> {
353 let mut statements = Vec::new();
354
355 if self.is_one_to_many {
356 for filter in filters {
358 let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
359 let sql = format!(
360 "UPDATE {} SET {} = ${} WHERE {}",
361 quote_identifier(&self.related_table),
362 quote_identifier(&self.foreign_key),
363 params.len() + 1,
364 where_sql
365 );
366 params.push(parent_id.clone());
367 statements.push((sql, params));
368 }
369 } else if let Some(join) = &self.join_table {
370 for filter in filters {
373 let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
374
375 let select_sql = format!(
377 "SELECT {} FROM {} WHERE {}",
378 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
379 quote_identifier(&self.related_table),
380 where_sql
381 );
382
383 let insert_sql = format!(
385 "INSERT INTO {} ({}, {}) SELECT ${}, {} FROM {} WHERE {} ON CONFLICT DO NOTHING",
386 quote_identifier(&join.table_name),
387 quote_identifier(&join.parent_column),
388 quote_identifier(&join.related_column),
389 params.len() + 1,
390 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
391 quote_identifier(&self.related_table),
392 where_sql
393 );
394 params.push(parent_id.clone());
395 statements.push((insert_sql, params));
396 let _ = select_sql;
398 }
399 }
400
401 statements
402 }
403
404 pub fn build_disconnect_sql(
406 &self,
407 parent_id: &FilterValue,
408 filters: &[Filter],
409 ) -> Vec<(String, Vec<FilterValue>)> {
410 let mut statements = Vec::new();
411
412 if self.is_one_to_many {
413 for filter in filters {
415 let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
416 let sql = format!(
417 "UPDATE {} SET {} = NULL WHERE {} AND {} = ${}",
418 quote_identifier(&self.related_table),
419 quote_identifier(&self.foreign_key),
420 where_sql,
421 quote_identifier(&self.foreign_key),
422 params.len() + 1
423 );
424 params.push(parent_id.clone());
425 statements.push((sql, params));
426 }
427 } else if let Some(join) = &self.join_table {
428 for filter in filters {
430 let (where_sql, mut params) = filter.to_sql(2, &crate::dialect::Postgres);
431 let sql = format!(
432 "DELETE FROM {} WHERE {} = $1 AND {} IN (SELECT id FROM {} WHERE {})",
433 quote_identifier(&join.table_name),
434 quote_identifier(&join.parent_column),
435 quote_identifier(&join.related_column),
436 quote_identifier(&self.related_table),
437 where_sql
438 );
439 let mut final_params = vec![parent_id.clone()];
440 final_params.extend(params);
441 params = final_params;
442 statements.push((sql, params));
443 }
444 }
445
446 statements
447 }
448
449 pub fn build_set_sql<T: Model>(
451 &self,
452 parent_id: &FilterValue,
453 filters: &[Filter],
454 ) -> Vec<(String, Vec<FilterValue>)> {
455 let mut statements = Vec::new();
456
457 if self.is_one_to_many {
459 let sql = format!(
460 "UPDATE {} SET {} = NULL WHERE {} = $1",
461 quote_identifier(&self.related_table),
462 quote_identifier(&self.foreign_key),
463 quote_identifier(&self.foreign_key)
464 );
465 statements.push((sql, vec![parent_id.clone()]));
466 } else if let Some(join) = &self.join_table {
467 let sql = format!(
468 "DELETE FROM {} WHERE {} = $1",
469 quote_identifier(&join.table_name),
470 quote_identifier(&join.parent_column)
471 );
472 statements.push((sql, vec![parent_id.clone()]));
473 }
474
475 statements.extend(self.build_connect_sql::<T>(parent_id, filters));
477
478 statements
479 }
480
481 pub fn build_create_sql<T: Model>(
483 &self,
484 parent_id: &FilterValue,
485 creates: &[NestedCreateData<T>],
486 ) -> Vec<(String, Vec<FilterValue>)> {
487 let mut statements = Vec::new();
488
489 for create in creates {
490 let mut columns: Vec<String> = create.data.iter().map(|(k, _)| k.clone()).collect();
491 let mut values: Vec<FilterValue> = create.data.iter().map(|(_, v)| v.clone()).collect();
492
493 columns.push(self.foreign_key.clone());
495 values.push(parent_id.clone());
496
497 let placeholders: Vec<String> = (1..=values.len()).map(|i| format!("${}", i)).collect();
498
499 let sql = format!(
500 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
501 quote_identifier(&self.related_table),
502 columns
503 .iter()
504 .map(|c| quote_identifier(c))
505 .collect::<Vec<_>>()
506 .join(", "),
507 placeholders.join(", ")
508 );
509
510 statements.push((sql, values));
511 }
512
513 statements
514 }
515
516 pub fn build_delete_sql(
518 &self,
519 parent_id: &FilterValue,
520 filters: &[Filter],
521 ) -> Vec<(String, Vec<FilterValue>)> {
522 let mut statements = Vec::new();
523
524 for filter in filters {
525 let (where_sql, mut params) = filter.to_sql(1, &crate::dialect::Postgres);
526 let sql = format!(
527 "DELETE FROM {} WHERE {} AND {} = ${}",
528 quote_identifier(&self.related_table),
529 where_sql,
530 quote_identifier(&self.foreign_key),
531 params.len() + 1
532 );
533 params.push(parent_id.clone());
534 statements.push((sql, params));
535 }
536
537 statements
538 }
539}
540
541#[derive(Debug, Default)]
543pub struct NestedWriteOperations {
544 pub pre_statements: Vec<(String, Vec<FilterValue>)>,
546 pub post_statements: Vec<(String, Vec<FilterValue>)>,
548}
549
550impl NestedWriteOperations {
551 pub fn new() -> Self {
553 Self::default()
554 }
555
556 pub fn add_pre(&mut self, sql: String, params: Vec<FilterValue>) {
558 self.pre_statements.push((sql, params));
559 }
560
561 pub fn add_post(&mut self, sql: String, params: Vec<FilterValue>) {
563 self.post_statements.push((sql, params));
564 }
565
566 pub fn extend(&mut self, other: Self) {
568 self.pre_statements.extend(other.pre_statements);
569 self.post_statements.extend(other.post_statements);
570 }
571
572 pub fn is_empty(&self) -> bool {
574 self.pre_statements.is_empty() && self.post_statements.is_empty()
575 }
576
577 pub fn len(&self) -> usize {
579 self.pre_statements.len() + self.post_statements.len()
580 }
581}
582
583#[derive(Debug, Clone)]
601pub enum NestedWriteOp {
602 Create {
607 relation: String,
609 target_table: String,
611 foreign_key: String,
613 payload: Vec<Vec<(String, FilterValue)>>,
616 },
617 Connect {
627 relation: String,
629 pk: FilterValue,
631 },
632}
633
634impl NestedWriteOp {
635 pub async fn execute<E>(self, engine: &E, parent_pk: &FilterValue) -> QueryResult<()>
642 where
643 E: QueryEngine,
644 {
645 match self {
646 NestedWriteOp::Connect { relation, pk: _ } => {
647 let _ = relation;
648 Err(QueryError::internal(
649 "nested Connect is not yet implemented (needs child-PK column metadata)",
650 ))
651 }
652 NestedWriteOp::Create {
653 relation: _,
654 target_table,
655 foreign_key,
656 payload,
657 } => {
658 let dialect = engine.dialect();
659 for child in payload {
660 let mut columns: Vec<String> = child.iter().map(|(c, _)| c.clone()).collect();
664 let mut values: Vec<FilterValue> = child.into_iter().map(|(_, v)| v).collect();
665 columns.push(foreign_key.clone());
666 values.push(parent_pk.clone());
667
668 let placeholders: Vec<String> =
669 (1..=values.len()).map(|i| dialect.placeholder(i)).collect();
670 let quoted_cols: Vec<String> =
671 columns.iter().map(|c| dialect.quote_ident(c)).collect();
672
673 let sql = format!(
674 "INSERT INTO {} ({}) VALUES ({})",
675 dialect.quote_ident(&target_table),
676 quoted_cols.join(", "),
677 placeholders.join(", "),
678 );
679
680 engine.execute_raw(&sql, values).await?;
681 }
682 Ok(())
683 }
684 }
685 }
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    /// Child model for the one-to-many fixtures (`users` -> `posts`).
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "Post";
        const TABLE_NAME: &'static str = "posts";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "title", "user_id"];
    }

    /// Related model for the many-to-many fixtures (`posts` <-> `tags`).
    struct TagModel;

    impl Model for TagModel {
        const MODEL_NAME: &'static str = "Tag";
        const TABLE_NAME: &'static str = "tags";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// Wraps a string literal in `FilterValue::String`.
    fn text(value: &str) -> FilterValue {
        FilterValue::String(value.to_string())
    }

    /// `posts.user_id` references `users.id`.
    fn users_posts() -> NestedWriteBuilder {
        NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id")
    }

    /// `posts` and `tags` linked through `post_tags`.
    fn posts_tags() -> NestedWriteBuilder {
        NestedWriteBuilder::many_to_many(
            "posts",
            vec!["id".to_string()],
            "tags",
            JoinTableInfo {
                table_name: "post_tags".to_string(),
                parent_column: "post_id".to_string(),
                related_column: "tag_id".to_string(),
            },
        )
    }

    #[test]
    fn test_nested_create_data() {
        let created = NestedCreateData::<TestModel>::from_pairs([("title", text("Test Post"))]);

        assert_eq!(created.data.len(), 1);
        assert_eq!(created.data[0].0, "title");
    }

    #[test]
    fn test_nested_write_create() {
        let payload = NestedCreateData::<TestModel>::from_pairs([("title", text("Test Post"))]);

        if let NestedWrite::Create(rows) = NestedWrite::<TestModel>::create(payload) {
            assert_eq!(rows.len(), 1);
        } else {
            panic!("Expected Create variant");
        }
    }

    #[test]
    fn test_nested_write_connect() {
        let write = NestedWrite::<TestModel>::connect(vec![
            Filter::Equals("id".into(), FilterValue::Int(1)),
            Filter::Equals("id".into(), FilterValue::Int(2)),
        ]);

        if let NestedWrite::Connect(filters) = write {
            assert_eq!(filters.len(), 2);
        } else {
            panic!("Expected Connect variant");
        }
    }

    #[test]
    fn test_nested_write_disconnect() {
        let write = NestedWrite::<TestModel>::disconnect_one(Filter::Equals(
            "id".into(),
            FilterValue::Int(1),
        ));

        if let NestedWrite::Disconnect(filters) = write {
            assert_eq!(filters.len(), 1);
        } else {
            panic!("Expected Disconnect variant");
        }
    }

    #[test]
    fn test_nested_write_set() {
        let write =
            NestedWrite::<TestModel>::set(vec![Filter::Equals("id".into(), FilterValue::Int(1))]);

        if let NestedWrite::Set(filters) = write {
            assert_eq!(filters.len(), 1);
        } else {
            panic!("Expected Set variant");
        }
    }

    #[test]
    fn test_builder_one_to_many_connect() {
        let parent = FilterValue::Int(1);
        let targets = [Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = users_posts().build_connect_sql::<TestModel>(&parent, &targets);

        assert_eq!(statements.len(), 1);
        let (sql, params) = statements.first().expect("one statement");
        for needle in ["UPDATE", "posts", "user_id"] {
            assert!(sql.contains(needle), "missing {:?} in {}", needle, sql);
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_one_to_many_disconnect() {
        let parent = FilterValue::Int(1);
        let targets = [Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = users_posts().build_disconnect_sql(&parent, &targets);

        assert_eq!(statements.len(), 1);
        let (sql, params) = statements.first().expect("one statement");
        for needle in ["UPDATE", "SET", "NULL"] {
            assert!(sql.contains(needle), "missing {:?} in {}", needle, sql);
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_connect() {
        let parent = FilterValue::Int(1);
        let targets = [Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = posts_tags().build_connect_sql::<TagModel>(&parent, &targets);

        assert_eq!(statements.len(), 1);
        let (sql, _) = statements.first().expect("one statement");
        for needle in ["INSERT INTO", "post_tags", "ON CONFLICT DO NOTHING"] {
            assert!(sql.contains(needle), "missing {:?} in {}", needle, sql);
        }
    }

    #[test]
    fn test_builder_create() {
        let parent = FilterValue::Int(1);
        let rows = [NestedCreateData::<TestModel>::from_pairs([("title", text("New Post"))])];

        let statements = users_posts().build_create_sql::<TestModel>(&parent, &rows);

        assert_eq!(statements.len(), 1);
        let (sql, params) = statements.first().expect("one statement");
        for needle in ["INSERT INTO", "posts", "RETURNING"] {
            assert!(sql.contains(needle), "missing {:?} in {}", needle, sql);
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_set() {
        let parent = FilterValue::Int(1);
        let targets = [Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = users_posts().build_set_sql::<TestModel>(&parent, &targets);

        assert!(statements.len() >= 2);
        let (clear_sql, _) = statements.first().expect("clearing statement");
        assert!(clear_sql.contains("UPDATE"));
        assert!(clear_sql.contains("NULL"));
    }

    #[test]
    fn test_nested_write_operations() {
        let mut ops = NestedWriteOperations::new();
        assert!(ops.is_empty());
        assert_eq!(ops.len(), 0);

        ops.add_pre("SELECT 1".to_string(), Vec::new());
        ops.add_post("SELECT 2".to_string(), Vec::new());

        assert!(!ops.is_empty());
        assert_eq!(ops.len(), 2);
    }

    #[test]
    fn test_nested_create_or_connect() {
        let payload = NestedCreateData::<TestModel>::from_pairs([("title", text("New Post"))]);
        let lookup = Filter::Equals("title".into(), text("Existing"));

        let entry = NestedCreateOrConnectData::new(lookup, payload);

        assert!(matches!(entry.filter, Filter::Equals(..)));
        assert_eq!(entry.create.data.len(), 1);
    }

    #[test]
    fn test_nested_update_data() {
        let update = NestedUpdateData::<TestModel>::from_pairs(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            [("title", text("Updated"))],
        );

        assert!(matches!(update.filter, Filter::Equals(..)));
        assert_eq!(update.data.len(), 1);
        assert_eq!(update.data[0].0, "title");
    }

    #[test]
    fn test_nested_upsert_data() {
        let upsert = NestedUpsertData::<TestModel>::new(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            NestedCreateData::from_pairs([("title", text("New"))]),
            vec![("title".to_string(), text("Updated"))],
        );

        assert!(matches!(upsert.filter, Filter::Equals(..)));
        assert_eq!(upsert.create.data.len(), 1);
        assert_eq!(upsert.update.len(), 1);
    }
}