1use std::fmt::Debug;
54use std::marker::PhantomData;
55
56use crate::filter::{Filter, FilterValue};
57use crate::sql::quote_identifier;
58use crate::traits::Model;
59
/// A nested write applied to records of the related model `T` while writing their parent.
///
/// Each variant only carries the payload for one kind of nested write; the SQL for it is
/// produced elsewhere (see `NestedWriteBuilder`).
#[derive(Debug, Clone)]
pub enum NestedWrite<T: Model> {
    /// Create one or more new related records.
    Create(Vec<NestedCreateData<T>>),
    /// For each entry, connect the record matching its filter or create it from its
    /// create payload (the choice is presumably made by the executor — not shown here).
    CreateOrConnect(Vec<NestedCreateOrConnectData<T>>),
    /// Link existing related records matching each filter to the parent.
    Connect(Vec<Filter>),
    /// Unlink related records matching each filter from the parent.
    Disconnect(Vec<Filter>),
    /// Replace the parent's links with exactly the records matching these filters.
    Set(Vec<Filter>),
    /// Delete related records matching each filter.
    Delete(Vec<Filter>),
    /// Update related records selected by each entry's filter with that entry's data.
    Update(Vec<NestedUpdateData<T>>),
    /// Update-or-create entries; see `NestedUpsertData` for the payload.
    Upsert(Vec<NestedUpsertData<T>>),
    /// Apply one set of column values to every related record matching one filter.
    UpdateMany(NestedUpdateManyData<T>),
    /// Delete every related record matching one filter.
    DeleteMany(Filter),
}
84
85impl<T: Model> NestedWrite<T> {
86 pub fn create(data: NestedCreateData<T>) -> Self {
88 Self::Create(vec![data])
89 }
90
91 pub fn create_many(data: Vec<NestedCreateData<T>>) -> Self {
93 Self::Create(data)
94 }
95
96 pub fn connect_one(filter: impl Into<Filter>) -> Self {
98 Self::Connect(vec![filter.into()])
99 }
100
101 pub fn connect(filters: Vec<impl Into<Filter>>) -> Self {
103 Self::Connect(filters.into_iter().map(Into::into).collect())
104 }
105
106 pub fn disconnect_one(filter: impl Into<Filter>) -> Self {
108 Self::Disconnect(vec![filter.into()])
109 }
110
111 pub fn disconnect(filters: Vec<impl Into<Filter>>) -> Self {
113 Self::Disconnect(filters.into_iter().map(Into::into).collect())
114 }
115
116 pub fn set(filters: Vec<impl Into<Filter>>) -> Self {
118 Self::Set(filters.into_iter().map(Into::into).collect())
119 }
120
121 pub fn delete(filters: Vec<impl Into<Filter>>) -> Self {
123 Self::Delete(filters.into_iter().map(Into::into).collect())
124 }
125
126 pub fn delete_many(filter: impl Into<Filter>) -> Self {
128 Self::DeleteMany(filter.into())
129 }
130}
131
132#[derive(Debug, Clone)]
134pub struct NestedCreateData<T: Model> {
135 pub data: Vec<(String, FilterValue)>,
137 _model: PhantomData<T>,
139}
140
141impl<T: Model> NestedCreateData<T> {
142 pub fn new(data: Vec<(String, FilterValue)>) -> Self {
144 Self {
145 data,
146 _model: PhantomData,
147 }
148 }
149
150 pub fn from_pairs(pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>) -> Self {
152 Self::new(
153 pairs
154 .into_iter()
155 .map(|(k, v)| (k.into(), v.into()))
156 .collect(),
157 )
158 }
159}
160
161impl<T: Model> Default for NestedCreateData<T> {
162 fn default() -> Self {
163 Self::new(Vec::new())
164 }
165}
166
167#[derive(Debug, Clone)]
169pub struct NestedCreateOrConnectData<T: Model> {
170 pub filter: Filter,
172 pub create: NestedCreateData<T>,
174}
175
176impl<T: Model> NestedCreateOrConnectData<T> {
177 pub fn new(filter: impl Into<Filter>, create: NestedCreateData<T>) -> Self {
179 Self {
180 filter: filter.into(),
181 create,
182 }
183 }
184}
185
186#[derive(Debug, Clone)]
188pub struct NestedUpdateData<T: Model> {
189 pub filter: Filter,
191 pub data: Vec<(String, FilterValue)>,
193 _model: PhantomData<T>,
195}
196
197impl<T: Model> NestedUpdateData<T> {
198 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
200 Self {
201 filter: filter.into(),
202 data,
203 _model: PhantomData,
204 }
205 }
206
207 pub fn from_pairs(
209 filter: impl Into<Filter>,
210 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
211 ) -> Self {
212 Self::new(
213 filter,
214 pairs
215 .into_iter()
216 .map(|(k, v)| (k.into(), v.into()))
217 .collect(),
218 )
219 }
220}
221
222#[derive(Debug, Clone)]
224pub struct NestedUpsertData<T: Model> {
225 pub filter: Filter,
227 pub create: NestedCreateData<T>,
229 pub update: Vec<(String, FilterValue)>,
231 _model: PhantomData<T>,
233}
234
235impl<T: Model> NestedUpsertData<T> {
236 pub fn new(
238 filter: impl Into<Filter>,
239 create: NestedCreateData<T>,
240 update: Vec<(String, FilterValue)>,
241 ) -> Self {
242 Self {
243 filter: filter.into(),
244 create,
245 update,
246 _model: PhantomData,
247 }
248 }
249}
250
251#[derive(Debug, Clone)]
253pub struct NestedUpdateManyData<T: Model> {
254 pub filter: Filter,
256 pub data: Vec<(String, FilterValue)>,
258 _model: PhantomData<T>,
260}
261
262impl<T: Model> NestedUpdateManyData<T> {
263 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
265 Self {
266 filter: filter.into(),
267 data,
268 _model: PhantomData,
269 }
270 }
271}
272
/// Builds parameterized SQL for nested writes on one relation between a parent table and
/// a related table.
///
/// Two relation shapes are supported through the `one_to_many` and `many_to_many`
/// constructors: a foreign-key column on the related table, or a separate join table.
#[derive(Debug)]
pub struct NestedWriteBuilder {
    /// Name of the parent table.
    /// NOTE(review): stored but not read by any statement builder in this file.
    parent_table: String,
    /// Primary-key column(s) of the parent table.
    /// NOTE(review): stored but not read by any statement builder in this file.
    parent_pk: Vec<String>,
    /// Related table that the nested writes target.
    related_table: String,
    /// Column on `related_table` that references the parent (one-to-many only; set to an
    /// empty string by `many_to_many`).
    foreign_key: String,
    /// `true` for a one-to-many relation, `false` for many-to-many.
    is_one_to_many: bool,
    /// Join-table description; `Some` only for many-to-many relations.
    join_table: Option<JoinTableInfo>,
}
289
/// Describes the join table that links parent rows to related rows in a many-to-many
/// relation.
#[derive(Debug, Clone)]
pub struct JoinTableInfo {
    /// Name of the join table.
    pub table_name: String,
    /// Join-table column holding the parent's key.
    pub parent_column: String,
    /// Join-table column holding the related row's key.
    pub related_column: String,
}
300
301impl NestedWriteBuilder {
302 pub fn one_to_many(
304 parent_table: impl Into<String>,
305 parent_pk: Vec<String>,
306 related_table: impl Into<String>,
307 foreign_key: impl Into<String>,
308 ) -> Self {
309 Self {
310 parent_table: parent_table.into(),
311 parent_pk,
312 related_table: related_table.into(),
313 foreign_key: foreign_key.into(),
314 is_one_to_many: true,
315 join_table: None,
316 }
317 }
318
319 pub fn many_to_many(
321 parent_table: impl Into<String>,
322 parent_pk: Vec<String>,
323 related_table: impl Into<String>,
324 join_table: JoinTableInfo,
325 ) -> Self {
326 Self {
327 parent_table: parent_table.into(),
328 parent_pk,
329 related_table: related_table.into(),
330 foreign_key: String::new(), is_one_to_many: false,
332 join_table: Some(join_table),
333 }
334 }
335
336 pub fn build_connect_sql<T: Model>(
338 &self,
339 parent_id: &FilterValue,
340 filters: &[Filter],
341 ) -> Vec<(String, Vec<FilterValue>)> {
342 let mut statements = Vec::new();
343
344 if self.is_one_to_many {
345 for filter in filters {
347 let (where_sql, mut params) = filter.to_sql(1);
348 let sql = format!(
349 "UPDATE {} SET {} = ${} WHERE {}",
350 quote_identifier(&self.related_table),
351 quote_identifier(&self.foreign_key),
352 params.len() + 1,
353 where_sql
354 );
355 params.push(parent_id.clone());
356 statements.push((sql, params));
357 }
358 } else if let Some(join) = &self.join_table {
359 for filter in filters {
362 let (where_sql, mut params) = filter.to_sql(1);
363
364 let select_sql = format!(
366 "SELECT {} FROM {} WHERE {}",
367 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
368 quote_identifier(&self.related_table),
369 where_sql
370 );
371
372 let insert_sql = format!(
374 "INSERT INTO {} ({}, {}) SELECT ${}, {} FROM {} WHERE {} ON CONFLICT DO NOTHING",
375 quote_identifier(&join.table_name),
376 quote_identifier(&join.parent_column),
377 quote_identifier(&join.related_column),
378 params.len() + 1,
379 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
380 quote_identifier(&self.related_table),
381 where_sql
382 );
383 params.push(parent_id.clone());
384 statements.push((insert_sql, params));
385 let _ = select_sql;
387 }
388 }
389
390 statements
391 }
392
393 pub fn build_disconnect_sql(
395 &self,
396 parent_id: &FilterValue,
397 filters: &[Filter],
398 ) -> Vec<(String, Vec<FilterValue>)> {
399 let mut statements = Vec::new();
400
401 if self.is_one_to_many {
402 for filter in filters {
404 let (where_sql, mut params) = filter.to_sql(1);
405 let sql = format!(
406 "UPDATE {} SET {} = NULL WHERE {} AND {} = ${}",
407 quote_identifier(&self.related_table),
408 quote_identifier(&self.foreign_key),
409 where_sql,
410 quote_identifier(&self.foreign_key),
411 params.len() + 1
412 );
413 params.push(parent_id.clone());
414 statements.push((sql, params));
415 }
416 } else if let Some(join) = &self.join_table {
417 for filter in filters {
419 let (where_sql, mut params) = filter.to_sql(2);
420 let sql = format!(
421 "DELETE FROM {} WHERE {} = $1 AND {} IN (SELECT id FROM {} WHERE {})",
422 quote_identifier(&join.table_name),
423 quote_identifier(&join.parent_column),
424 quote_identifier(&join.related_column),
425 quote_identifier(&self.related_table),
426 where_sql
427 );
428 let mut final_params = vec![parent_id.clone()];
429 final_params.extend(params);
430 params = final_params;
431 statements.push((sql, params));
432 }
433 }
434
435 statements
436 }
437
438 pub fn build_set_sql<T: Model>(
440 &self,
441 parent_id: &FilterValue,
442 filters: &[Filter],
443 ) -> Vec<(String, Vec<FilterValue>)> {
444 let mut statements = Vec::new();
445
446 if self.is_one_to_many {
448 let sql = format!(
449 "UPDATE {} SET {} = NULL WHERE {} = $1",
450 quote_identifier(&self.related_table),
451 quote_identifier(&self.foreign_key),
452 quote_identifier(&self.foreign_key)
453 );
454 statements.push((sql, vec![parent_id.clone()]));
455 } else if let Some(join) = &self.join_table {
456 let sql = format!(
457 "DELETE FROM {} WHERE {} = $1",
458 quote_identifier(&join.table_name),
459 quote_identifier(&join.parent_column)
460 );
461 statements.push((sql, vec![parent_id.clone()]));
462 }
463
464 statements.extend(self.build_connect_sql::<T>(parent_id, filters));
466
467 statements
468 }
469
470 pub fn build_create_sql<T: Model>(
472 &self,
473 parent_id: &FilterValue,
474 creates: &[NestedCreateData<T>],
475 ) -> Vec<(String, Vec<FilterValue>)> {
476 let mut statements = Vec::new();
477
478 for create in creates {
479 let mut columns: Vec<String> = create.data.iter().map(|(k, _)| k.clone()).collect();
480 let mut values: Vec<FilterValue> = create.data.iter().map(|(_, v)| v.clone()).collect();
481
482 columns.push(self.foreign_key.clone());
484 values.push(parent_id.clone());
485
486 let placeholders: Vec<String> = (1..=values.len())
487 .map(|i| format!("${}", i))
488 .collect();
489
490 let sql = format!(
491 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
492 quote_identifier(&self.related_table),
493 columns.iter().map(|c| quote_identifier(c)).collect::<Vec<_>>().join(", "),
494 placeholders.join(", ")
495 );
496
497 statements.push((sql, values));
498 }
499
500 statements
501 }
502
503 pub fn build_delete_sql(
505 &self,
506 parent_id: &FilterValue,
507 filters: &[Filter],
508 ) -> Vec<(String, Vec<FilterValue>)> {
509 let mut statements = Vec::new();
510
511 for filter in filters {
512 let (where_sql, mut params) = filter.to_sql(1);
513 let sql = format!(
514 "DELETE FROM {} WHERE {} AND {} = ${}",
515 quote_identifier(&self.related_table),
516 where_sql,
517 quote_identifier(&self.foreign_key),
518 params.len() + 1
519 );
520 params.push(parent_id.clone());
521 statements.push((sql, params));
522 }
523
524 statements
525 }
526}
527
528#[derive(Debug, Default)]
530pub struct NestedWriteOperations {
531 pub pre_statements: Vec<(String, Vec<FilterValue>)>,
533 pub post_statements: Vec<(String, Vec<FilterValue>)>,
535}
536
537impl NestedWriteOperations {
538 pub fn new() -> Self {
540 Self::default()
541 }
542
543 pub fn add_pre(&mut self, sql: String, params: Vec<FilterValue>) {
545 self.pre_statements.push((sql, params));
546 }
547
548 pub fn add_post(&mut self, sql: String, params: Vec<FilterValue>) {
550 self.post_statements.push((sql, params));
551 }
552
553 pub fn extend(&mut self, other: Self) {
555 self.pre_statements.extend(other.pre_statements);
556 self.post_statements.extend(other.post_statements);
557 }
558
559 pub fn is_empty(&self) -> bool {
561 self.pre_statements.is_empty() && self.post_statements.is_empty()
562 }
563
564 pub fn len(&self) -> usize {
566 self.pre_statements.len() + self.post_statements.len()
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;

    /// Related model for the one-to-many cases: posts owned by users.
    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "Post";
        const TABLE_NAME: &'static str = "posts";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "title", "user_id"];
    }

    /// Related model for the many-to-many cases: tags linked to posts.
    struct TagModel;

    impl Model for TagModel {
        const MODEL_NAME: &'static str = "Tag";
        const TABLE_NAME: &'static str = "tags";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// Builder for `users` -> `posts`, keyed by `posts.user_id`.
    fn users_to_posts() -> NestedWriteBuilder {
        NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id")
    }

    /// Builder for `posts` <-> `tags`, linked through `post_tags`.
    fn posts_to_tags() -> NestedWriteBuilder {
        NestedWriteBuilder::many_to_many(
            "posts",
            vec!["id".to_string()],
            "tags",
            JoinTableInfo {
                table_name: "post_tags".to_string(),
                parent_column: "post_id".to_string(),
                related_column: "tag_id".to_string(),
            },
        )
    }

    #[test]
    fn test_nested_create_data() {
        let created: NestedCreateData<TestModel> = NestedCreateData::from_pairs([(
            "title",
            FilterValue::String("Test Post".to_string()),
        )]);

        assert_eq!(created.data.len(), 1);
        assert_eq!(created.data[0].0, "title");
    }

    #[test]
    fn test_nested_write_create() {
        let payload: NestedCreateData<TestModel> = NestedCreateData::from_pairs([(
            "title",
            FilterValue::String("Test Post".to_string()),
        )]);

        let write: NestedWrite<TestModel> = NestedWrite::create(payload);

        if let NestedWrite::Create(creates) = write {
            assert_eq!(creates.len(), 1);
        } else {
            panic!("Expected Create variant");
        }
    }

    #[test]
    fn test_nested_write_connect() {
        let write: NestedWrite<TestModel> = NestedWrite::connect(vec![
            Filter::Equals("id".into(), FilterValue::Int(1)),
            Filter::Equals("id".into(), FilterValue::Int(2)),
        ]);

        if let NestedWrite::Connect(filters) = write {
            assert_eq!(filters.len(), 2);
        } else {
            panic!("Expected Connect variant");
        }
    }

    #[test]
    fn test_nested_write_disconnect() {
        let write: NestedWrite<TestModel> =
            NestedWrite::disconnect_one(Filter::Equals("id".into(), FilterValue::Int(1)));

        if let NestedWrite::Disconnect(filters) = write {
            assert_eq!(filters.len(), 1);
        } else {
            panic!("Expected Disconnect variant");
        }
    }

    #[test]
    fn test_nested_write_set() {
        let write: NestedWrite<TestModel> =
            NestedWrite::set(vec![Filter::Equals("id".into(), FilterValue::Int(1))]);

        if let NestedWrite::Set(filters) = write {
            assert_eq!(filters.len(), 1);
        } else {
            panic!("Expected Set variant");
        }
    }

    #[test]
    fn test_builder_one_to_many_connect() {
        let builder = users_to_posts();
        let parent = FilterValue::Int(1);
        let targets = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let stmts = builder.build_connect_sql::<TestModel>(&parent, &targets);

        assert_eq!(stmts.len(), 1);
        let (sql, params) = &stmts[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_one_to_many_disconnect() {
        let builder = users_to_posts();
        let parent = FilterValue::Int(1);
        let targets = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let stmts = builder.build_disconnect_sql(&parent, &targets);

        assert_eq!(stmts.len(), 1);
        let (sql, params) = &stmts[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("SET"));
        assert!(sql.contains("NULL"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_connect() {
        let builder = posts_to_tags();
        let parent = FilterValue::Int(1);
        let targets = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let stmts = builder.build_connect_sql::<TagModel>(&parent, &targets);

        assert_eq!(stmts.len(), 1);
        let (sql, _params) = &stmts[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("post_tags"));
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
    }

    #[test]
    fn test_builder_create() {
        let builder = users_to_posts();
        let parent = FilterValue::Int(1);
        let creates = vec![NestedCreateData::<TestModel>::from_pairs([(
            "title",
            FilterValue::String("New Post".to_string()),
        )])];

        let stmts = builder.build_create_sql::<TestModel>(&parent, &creates);

        assert_eq!(stmts.len(), 1);
        let (sql, params) = &stmts[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("RETURNING"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_set() {
        let builder = users_to_posts();
        let parent = FilterValue::Int(1);
        let targets = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let stmts = builder.build_set_sql::<TestModel>(&parent, &targets);

        // One "clear existing links" statement, then at least one connect statement.
        assert!(stmts.len() >= 2);

        let (clear_sql, _) = &stmts[0];
        assert!(clear_sql.contains("UPDATE"));
        assert!(clear_sql.contains("NULL"));
    }

    #[test]
    fn test_nested_write_operations() {
        let mut ops = NestedWriteOperations::new();
        assert!(ops.is_empty());
        assert_eq!(ops.len(), 0);

        ops.add_pre("SELECT 1".to_string(), vec![]);
        ops.add_post("SELECT 2".to_string(), vec![]);

        assert!(!ops.is_empty());
        assert_eq!(ops.len(), 2);
    }

    #[test]
    fn test_nested_create_or_connect() {
        let fallback: NestedCreateData<TestModel> = NestedCreateData::from_pairs([(
            "title",
            FilterValue::String("New Post".to_string()),
        )]);

        let create_or_connect = NestedCreateOrConnectData::new(
            Filter::Equals("title".into(), FilterValue::String("Existing".to_string())),
            fallback,
        );

        assert!(matches!(create_or_connect.filter, Filter::Equals(..)));
        assert_eq!(create_or_connect.create.data.len(), 1);
    }

    #[test]
    fn test_nested_update_data() {
        let update: NestedUpdateData<TestModel> = NestedUpdateData::from_pairs(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            [("title", FilterValue::String("Updated".to_string()))],
        );

        assert!(matches!(update.filter, Filter::Equals(..)));
        assert_eq!(update.data.len(), 1);
        assert_eq!(update.data[0].0, "title");
    }

    #[test]
    fn test_nested_upsert_data() {
        let create: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New".to_string()))]);

        let upsert: NestedUpsertData<TestModel> = NestedUpsertData::new(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            create,
            vec![("title".to_string(), FilterValue::String("Updated".to_string()))],
        );

        assert!(matches!(upsert.filter, Filter::Equals(..)));
        assert_eq!(upsert.create.data.len(), 1);
        assert_eq!(upsert.update.len(), 1);
    }
}
826