1use std::fmt::Debug;
54use std::marker::PhantomData;
55
56use crate::filter::{Filter, FilterValue};
57use crate::sql::quote_identifier;
58use crate::traits::Model;
59
60#[derive(Debug, Clone)]
62pub enum NestedWrite<T: Model> {
63 Create(Vec<NestedCreateData<T>>),
65 CreateOrConnect(Vec<NestedCreateOrConnectData<T>>),
67 Connect(Vec<Filter>),
69 Disconnect(Vec<Filter>),
71 Set(Vec<Filter>),
73 Delete(Vec<Filter>),
75 Update(Vec<NestedUpdateData<T>>),
77 Upsert(Vec<NestedUpsertData<T>>),
79 UpdateMany(NestedUpdateManyData<T>),
81 DeleteMany(Filter),
83}
84
85impl<T: Model> NestedWrite<T> {
86 pub fn create(data: NestedCreateData<T>) -> Self {
88 Self::Create(vec![data])
89 }
90
91 pub fn create_many(data: Vec<NestedCreateData<T>>) -> Self {
93 Self::Create(data)
94 }
95
96 pub fn connect_one(filter: impl Into<Filter>) -> Self {
98 Self::Connect(vec![filter.into()])
99 }
100
101 pub fn connect(filters: Vec<impl Into<Filter>>) -> Self {
103 Self::Connect(filters.into_iter().map(Into::into).collect())
104 }
105
106 pub fn disconnect_one(filter: impl Into<Filter>) -> Self {
108 Self::Disconnect(vec![filter.into()])
109 }
110
111 pub fn disconnect(filters: Vec<impl Into<Filter>>) -> Self {
113 Self::Disconnect(filters.into_iter().map(Into::into).collect())
114 }
115
116 pub fn set(filters: Vec<impl Into<Filter>>) -> Self {
118 Self::Set(filters.into_iter().map(Into::into).collect())
119 }
120
121 pub fn delete(filters: Vec<impl Into<Filter>>) -> Self {
123 Self::Delete(filters.into_iter().map(Into::into).collect())
124 }
125
126 pub fn delete_many(filter: impl Into<Filter>) -> Self {
128 Self::DeleteMany(filter.into())
129 }
130}
131
132#[derive(Debug, Clone)]
134pub struct NestedCreateData<T: Model> {
135 pub data: Vec<(String, FilterValue)>,
137 _model: PhantomData<T>,
139}
140
141impl<T: Model> NestedCreateData<T> {
142 pub fn new(data: Vec<(String, FilterValue)>) -> Self {
144 Self {
145 data,
146 _model: PhantomData,
147 }
148 }
149
150 pub fn from_pairs(
152 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
153 ) -> Self {
154 Self::new(
155 pairs
156 .into_iter()
157 .map(|(k, v)| (k.into(), v.into()))
158 .collect(),
159 )
160 }
161}
162
163impl<T: Model> Default for NestedCreateData<T> {
164 fn default() -> Self {
165 Self::new(Vec::new())
166 }
167}
168
169#[derive(Debug, Clone)]
171pub struct NestedCreateOrConnectData<T: Model> {
172 pub filter: Filter,
174 pub create: NestedCreateData<T>,
176}
177
178impl<T: Model> NestedCreateOrConnectData<T> {
179 pub fn new(filter: impl Into<Filter>, create: NestedCreateData<T>) -> Self {
181 Self {
182 filter: filter.into(),
183 create,
184 }
185 }
186}
187
188#[derive(Debug, Clone)]
190pub struct NestedUpdateData<T: Model> {
191 pub filter: Filter,
193 pub data: Vec<(String, FilterValue)>,
195 _model: PhantomData<T>,
197}
198
199impl<T: Model> NestedUpdateData<T> {
200 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
202 Self {
203 filter: filter.into(),
204 data,
205 _model: PhantomData,
206 }
207 }
208
209 pub fn from_pairs(
211 filter: impl Into<Filter>,
212 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
213 ) -> Self {
214 Self::new(
215 filter,
216 pairs
217 .into_iter()
218 .map(|(k, v)| (k.into(), v.into()))
219 .collect(),
220 )
221 }
222}
223
224#[derive(Debug, Clone)]
226pub struct NestedUpsertData<T: Model> {
227 pub filter: Filter,
229 pub create: NestedCreateData<T>,
231 pub update: Vec<(String, FilterValue)>,
233 _model: PhantomData<T>,
235}
236
237impl<T: Model> NestedUpsertData<T> {
238 pub fn new(
240 filter: impl Into<Filter>,
241 create: NestedCreateData<T>,
242 update: Vec<(String, FilterValue)>,
243 ) -> Self {
244 Self {
245 filter: filter.into(),
246 create,
247 update,
248 _model: PhantomData,
249 }
250 }
251}
252
253#[derive(Debug, Clone)]
255pub struct NestedUpdateManyData<T: Model> {
256 pub filter: Filter,
258 pub data: Vec<(String, FilterValue)>,
260 _model: PhantomData<T>,
262}
263
264impl<T: Model> NestedUpdateManyData<T> {
265 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
267 Self {
268 filter: filter.into(),
269 data,
270 _model: PhantomData,
271 }
272 }
273}
274
275#[derive(Debug)]
277pub struct NestedWriteBuilder {
278 parent_table: String,
280 parent_pk: Vec<String>,
282 related_table: String,
284 foreign_key: String,
286 is_one_to_many: bool,
288 join_table: Option<JoinTableInfo>,
290}
291
/// Layout of the join table backing a many-to-many relation.
#[derive(Clone, Debug)]
pub struct JoinTableInfo {
    /// Name of the join table.
    pub table_name: String,
    /// Join-table column holding the parent record's key.
    pub parent_column: String,
    /// Join-table column holding the related record's key.
    pub related_column: String,
}
302
303impl NestedWriteBuilder {
304 pub fn one_to_many(
306 parent_table: impl Into<String>,
307 parent_pk: Vec<String>,
308 related_table: impl Into<String>,
309 foreign_key: impl Into<String>,
310 ) -> Self {
311 Self {
312 parent_table: parent_table.into(),
313 parent_pk,
314 related_table: related_table.into(),
315 foreign_key: foreign_key.into(),
316 is_one_to_many: true,
317 join_table: None,
318 }
319 }
320
321 pub fn many_to_many(
323 parent_table: impl Into<String>,
324 parent_pk: Vec<String>,
325 related_table: impl Into<String>,
326 join_table: JoinTableInfo,
327 ) -> Self {
328 Self {
329 parent_table: parent_table.into(),
330 parent_pk,
331 related_table: related_table.into(),
332 foreign_key: String::new(), is_one_to_many: false,
334 join_table: Some(join_table),
335 }
336 }
337
338 pub fn build_connect_sql<T: Model>(
340 &self,
341 parent_id: &FilterValue,
342 filters: &[Filter],
343 ) -> Vec<(String, Vec<FilterValue>)> {
344 let mut statements = Vec::new();
345
346 if self.is_one_to_many {
347 for filter in filters {
349 let (where_sql, mut params) = filter.to_sql(1);
350 let sql = format!(
351 "UPDATE {} SET {} = ${} WHERE {}",
352 quote_identifier(&self.related_table),
353 quote_identifier(&self.foreign_key),
354 params.len() + 1,
355 where_sql
356 );
357 params.push(parent_id.clone());
358 statements.push((sql, params));
359 }
360 } else if let Some(join) = &self.join_table {
361 for filter in filters {
364 let (where_sql, mut params) = filter.to_sql(1);
365
366 let select_sql = format!(
368 "SELECT {} FROM {} WHERE {}",
369 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
370 quote_identifier(&self.related_table),
371 where_sql
372 );
373
374 let insert_sql = format!(
376 "INSERT INTO {} ({}, {}) SELECT ${}, {} FROM {} WHERE {} ON CONFLICT DO NOTHING",
377 quote_identifier(&join.table_name),
378 quote_identifier(&join.parent_column),
379 quote_identifier(&join.related_column),
380 params.len() + 1,
381 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
382 quote_identifier(&self.related_table),
383 where_sql
384 );
385 params.push(parent_id.clone());
386 statements.push((insert_sql, params));
387 let _ = select_sql;
389 }
390 }
391
392 statements
393 }
394
395 pub fn build_disconnect_sql(
397 &self,
398 parent_id: &FilterValue,
399 filters: &[Filter],
400 ) -> Vec<(String, Vec<FilterValue>)> {
401 let mut statements = Vec::new();
402
403 if self.is_one_to_many {
404 for filter in filters {
406 let (where_sql, mut params) = filter.to_sql(1);
407 let sql = format!(
408 "UPDATE {} SET {} = NULL WHERE {} AND {} = ${}",
409 quote_identifier(&self.related_table),
410 quote_identifier(&self.foreign_key),
411 where_sql,
412 quote_identifier(&self.foreign_key),
413 params.len() + 1
414 );
415 params.push(parent_id.clone());
416 statements.push((sql, params));
417 }
418 } else if let Some(join) = &self.join_table {
419 for filter in filters {
421 let (where_sql, mut params) = filter.to_sql(2);
422 let sql = format!(
423 "DELETE FROM {} WHERE {} = $1 AND {} IN (SELECT id FROM {} WHERE {})",
424 quote_identifier(&join.table_name),
425 quote_identifier(&join.parent_column),
426 quote_identifier(&join.related_column),
427 quote_identifier(&self.related_table),
428 where_sql
429 );
430 let mut final_params = vec![parent_id.clone()];
431 final_params.extend(params);
432 params = final_params;
433 statements.push((sql, params));
434 }
435 }
436
437 statements
438 }
439
440 pub fn build_set_sql<T: Model>(
442 &self,
443 parent_id: &FilterValue,
444 filters: &[Filter],
445 ) -> Vec<(String, Vec<FilterValue>)> {
446 let mut statements = Vec::new();
447
448 if self.is_one_to_many {
450 let sql = format!(
451 "UPDATE {} SET {} = NULL WHERE {} = $1",
452 quote_identifier(&self.related_table),
453 quote_identifier(&self.foreign_key),
454 quote_identifier(&self.foreign_key)
455 );
456 statements.push((sql, vec![parent_id.clone()]));
457 } else if let Some(join) = &self.join_table {
458 let sql = format!(
459 "DELETE FROM {} WHERE {} = $1",
460 quote_identifier(&join.table_name),
461 quote_identifier(&join.parent_column)
462 );
463 statements.push((sql, vec![parent_id.clone()]));
464 }
465
466 statements.extend(self.build_connect_sql::<T>(parent_id, filters));
468
469 statements
470 }
471
472 pub fn build_create_sql<T: Model>(
474 &self,
475 parent_id: &FilterValue,
476 creates: &[NestedCreateData<T>],
477 ) -> Vec<(String, Vec<FilterValue>)> {
478 let mut statements = Vec::new();
479
480 for create in creates {
481 let mut columns: Vec<String> = create.data.iter().map(|(k, _)| k.clone()).collect();
482 let mut values: Vec<FilterValue> = create.data.iter().map(|(_, v)| v.clone()).collect();
483
484 columns.push(self.foreign_key.clone());
486 values.push(parent_id.clone());
487
488 let placeholders: Vec<String> = (1..=values.len()).map(|i| format!("${}", i)).collect();
489
490 let sql = format!(
491 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
492 quote_identifier(&self.related_table),
493 columns
494 .iter()
495 .map(|c| quote_identifier(c))
496 .collect::<Vec<_>>()
497 .join(", "),
498 placeholders.join(", ")
499 );
500
501 statements.push((sql, values));
502 }
503
504 statements
505 }
506
507 pub fn build_delete_sql(
509 &self,
510 parent_id: &FilterValue,
511 filters: &[Filter],
512 ) -> Vec<(String, Vec<FilterValue>)> {
513 let mut statements = Vec::new();
514
515 for filter in filters {
516 let (where_sql, mut params) = filter.to_sql(1);
517 let sql = format!(
518 "DELETE FROM {} WHERE {} AND {} = ${}",
519 quote_identifier(&self.related_table),
520 where_sql,
521 quote_identifier(&self.foreign_key),
522 params.len() + 1
523 );
524 params.push(parent_id.clone());
525 statements.push((sql, params));
526 }
527
528 statements
529 }
530}
531
532#[derive(Debug, Default)]
534pub struct NestedWriteOperations {
535 pub pre_statements: Vec<(String, Vec<FilterValue>)>,
537 pub post_statements: Vec<(String, Vec<FilterValue>)>,
539}
540
541impl NestedWriteOperations {
542 pub fn new() -> Self {
544 Self::default()
545 }
546
547 pub fn add_pre(&mut self, sql: String, params: Vec<FilterValue>) {
549 self.pre_statements.push((sql, params));
550 }
551
552 pub fn add_post(&mut self, sql: String, params: Vec<FilterValue>) {
554 self.post_statements.push((sql, params));
555 }
556
557 pub fn extend(&mut self, other: Self) {
559 self.pre_statements.extend(other.pre_statements);
560 self.post_statements.extend(other.post_statements);
561 }
562
563 pub fn is_empty(&self) -> bool {
565 self.pre_statements.is_empty() && self.post_statements.is_empty()
566 }
567
568 pub fn len(&self) -> usize {
570 self.pre_statements.len() + self.post_statements.len()
571 }
572}
573
#[cfg(test)]
mod tests {
    use super::*;

    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "Post";
        const TABLE_NAME: &'static str = "posts";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "title", "user_id"];
    }

    struct TagModel;

    impl Model for TagModel {
        const MODEL_NAME: &'static str = "Tag";
        const TABLE_NAME: &'static str = "tags";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// `users` 1-n `posts`, linked by `posts.user_id`.
    fn users_posts_builder() -> NestedWriteBuilder {
        NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id")
    }

    /// `posts` n-m `tags`, linked through `post_tags(post_id, tag_id)`.
    fn posts_tags_builder() -> NestedWriteBuilder {
        NestedWriteBuilder::many_to_many(
            "posts",
            vec!["id".to_string()],
            "tags",
            JoinTableInfo {
                table_name: "post_tags".to_string(),
                parent_column: "post_id".to_string(),
                related_column: "tag_id".to_string(),
            },
        )
    }

    #[test]
    fn test_nested_create_data() {
        let data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("Test Post".to_string()))]);

        assert_eq!(data.data.len(), 1);
        assert_eq!(data.data[0].0, "title");
    }

    #[test]
    fn test_nested_write_create() {
        let data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("Test Post".to_string()))]);

        let write: NestedWrite<TestModel> = NestedWrite::create(data);

        match write {
            NestedWrite::Create(creates) => assert_eq!(creates.len(), 1),
            _ => panic!("Expected Create variant"),
        }
    }

    #[test]
    fn test_nested_write_connect() {
        let write: NestedWrite<TestModel> = NestedWrite::connect(vec![
            Filter::Equals("id".into(), FilterValue::Int(1)),
            Filter::Equals("id".into(), FilterValue::Int(2)),
        ]);

        match write {
            NestedWrite::Connect(filters) => assert_eq!(filters.len(), 2),
            _ => panic!("Expected Connect variant"),
        }
    }

    #[test]
    fn test_nested_write_disconnect() {
        let write: NestedWrite<TestModel> =
            NestedWrite::disconnect_one(Filter::Equals("id".into(), FilterValue::Int(1)));

        match write {
            NestedWrite::Disconnect(filters) => assert_eq!(filters.len(), 1),
            _ => panic!("Expected Disconnect variant"),
        }
    }

    #[test]
    fn test_nested_write_set() {
        let write: NestedWrite<TestModel> =
            NestedWrite::set(vec![Filter::Equals("id".into(), FilterValue::Int(1))]);

        match write {
            NestedWrite::Set(filters) => assert_eq!(filters.len(), 1),
            _ => panic!("Expected Set variant"),
        }
    }

    #[test]
    fn test_builder_one_to_many_connect() {
        let builder = users_posts_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_connect_sql::<TestModel>(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        assert_eq!(params.len(), 2);
        // The parent id is bound after the filter's own values.
        assert!(matches!(params.last(), Some(FilterValue::Int(1))));
    }

    #[test]
    fn test_builder_connect_without_filters_is_empty() {
        let parent_id = FilterValue::Int(1);

        assert!(users_posts_builder()
            .build_connect_sql::<TestModel>(&parent_id, &[])
            .is_empty());
        assert!(posts_tags_builder()
            .build_connect_sql::<TagModel>(&parent_id, &[])
            .is_empty());
    }

    #[test]
    fn test_builder_one_to_many_disconnect() {
        let builder = users_posts_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_disconnect_sql(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("SET"));
        assert!(sql.contains("NULL"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_disconnect_binds_parent_first() {
        let builder = posts_tags_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_disconnect_sql(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("DELETE FROM"));
        assert!(sql.contains("post_tags"));
        assert!(sql.contains("$1"));
        assert_eq!(params.len(), 2);
        assert!(matches!(params[0], FilterValue::Int(1)));
    }

    #[test]
    fn test_builder_many_to_many_connect() {
        let builder = posts_tags_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_connect_sql::<TagModel>(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("post_tags"));
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
        assert_eq!(params.len(), 2);
        assert!(matches!(params.last(), Some(FilterValue::Int(1))));
    }

    #[test]
    fn test_builder_create() {
        let builder = users_posts_builder();

        let parent_id = FilterValue::Int(1);
        let creates = vec![NestedCreateData::<TestModel>::from_pairs([(
            "title",
            FilterValue::String("New Post".to_string()),
        )])];

        let statements = builder.build_create_sql::<TestModel>(&parent_id, &creates);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("RETURNING"));
        assert_eq!(params.len(), 2);
        assert!(matches!(params.last(), Some(FilterValue::Int(1))));
    }

    #[test]
    fn test_builder_one_to_many_delete() {
        let builder = users_posts_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_delete_sql(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.starts_with("DELETE FROM"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        assert_eq!(params.len(), 2);
        assert!(matches!(params.last(), Some(FilterValue::Int(1))));
    }

    #[test]
    fn test_builder_set() {
        let builder = users_posts_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_set_sql::<TestModel>(&parent_id, &filters);

        // One statement clears the existing links, then one connect per filter.
        assert_eq!(statements.len(), 2);

        let (first_sql, first_params) = &statements[0];
        assert!(first_sql.contains("UPDATE"));
        assert!(first_sql.contains("NULL"));
        assert_eq!(first_params.len(), 1);

        let (second_sql, _) = &statements[1];
        assert!(second_sql.contains("UPDATE"));
    }

    #[test]
    fn test_builder_many_to_many_set() {
        let builder = posts_tags_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_set_sql::<TagModel>(&parent_id, &filters);

        assert_eq!(statements.len(), 2);
        assert!(statements[0].0.contains("DELETE FROM"));
        assert!(statements[0].0.contains("post_tags"));
        assert!(statements[1].0.contains("INSERT INTO"));
    }

    #[test]
    fn test_nested_write_operations() {
        let mut ops = NestedWriteOperations::new();
        assert!(ops.is_empty());
        assert_eq!(ops.len(), 0);

        ops.add_pre("SELECT 1".to_string(), vec![]);
        ops.add_post("SELECT 2".to_string(), vec![]);

        assert!(!ops.is_empty());
        assert_eq!(ops.len(), 2);
    }

    #[test]
    fn test_nested_write_operations_extend() {
        let mut ops = NestedWriteOperations::new();
        ops.add_pre("SELECT 1".to_string(), vec![]);

        let mut other = NestedWriteOperations::new();
        other.add_post("SELECT 2".to_string(), vec![]);
        other.add_post("SELECT 3".to_string(), vec![]);

        ops.extend(other);

        assert_eq!(ops.pre_statements.len(), 1);
        assert_eq!(ops.post_statements.len(), 2);
        assert_eq!(ops.post_statements[1].0, "SELECT 3");
        assert_eq!(ops.len(), 3);
    }

    #[test]
    fn test_nested_create_or_connect() {
        let create_data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New Post".to_string()))]);

        let create_or_connect = NestedCreateOrConnectData::new(
            Filter::Equals("title".into(), FilterValue::String("Existing".to_string())),
            create_data,
        );

        assert!(matches!(create_or_connect.filter, Filter::Equals(..)));
        assert_eq!(create_or_connect.create.data.len(), 1);
    }

    #[test]
    fn test_nested_update_data() {
        let update: NestedUpdateData<TestModel> = NestedUpdateData::from_pairs(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            [("title", FilterValue::String("Updated".to_string()))],
        );

        assert!(matches!(update.filter, Filter::Equals(..)));
        assert_eq!(update.data.len(), 1);
        assert_eq!(update.data[0].0, "title");
    }

    #[test]
    fn test_nested_upsert_data() {
        let create: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New".to_string()))]);

        let upsert: NestedUpsertData<TestModel> = NestedUpsertData::new(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            create,
            vec![(
                "title".to_string(),
                FilterValue::String("Updated".to_string()),
            )],
        );

        assert!(matches!(upsert.filter, Filter::Equals(..)));
        assert_eq!(upsert.create.data.len(), 1);
        assert_eq!(upsert.update.len(), 1);
    }
}