1#![allow(dead_code)]
2
3use std::fmt::Debug;
56use std::marker::PhantomData;
57
58use crate::filter::{Filter, FilterValue};
59use crate::sql::quote_identifier;
60use crate::traits::Model;
61
62#[derive(Debug, Clone)]
64pub enum NestedWrite<T: Model> {
65 Create(Vec<NestedCreateData<T>>),
67 CreateOrConnect(Vec<NestedCreateOrConnectData<T>>),
69 Connect(Vec<Filter>),
71 Disconnect(Vec<Filter>),
73 Set(Vec<Filter>),
75 Delete(Vec<Filter>),
77 Update(Vec<NestedUpdateData<T>>),
79 Upsert(Vec<NestedUpsertData<T>>),
81 UpdateMany(NestedUpdateManyData<T>),
83 DeleteMany(Filter),
85}
86
87impl<T: Model> NestedWrite<T> {
88 pub fn create(data: NestedCreateData<T>) -> Self {
90 Self::Create(vec![data])
91 }
92
93 pub fn create_many(data: Vec<NestedCreateData<T>>) -> Self {
95 Self::Create(data)
96 }
97
98 pub fn connect_one(filter: impl Into<Filter>) -> Self {
100 Self::Connect(vec![filter.into()])
101 }
102
103 pub fn connect(filters: Vec<impl Into<Filter>>) -> Self {
105 Self::Connect(filters.into_iter().map(Into::into).collect())
106 }
107
108 pub fn disconnect_one(filter: impl Into<Filter>) -> Self {
110 Self::Disconnect(vec![filter.into()])
111 }
112
113 pub fn disconnect(filters: Vec<impl Into<Filter>>) -> Self {
115 Self::Disconnect(filters.into_iter().map(Into::into).collect())
116 }
117
118 pub fn set(filters: Vec<impl Into<Filter>>) -> Self {
120 Self::Set(filters.into_iter().map(Into::into).collect())
121 }
122
123 pub fn delete(filters: Vec<impl Into<Filter>>) -> Self {
125 Self::Delete(filters.into_iter().map(Into::into).collect())
126 }
127
128 pub fn delete_many(filter: impl Into<Filter>) -> Self {
130 Self::DeleteMany(filter.into())
131 }
132}
133
134#[derive(Debug, Clone)]
136pub struct NestedCreateData<T: Model> {
137 pub data: Vec<(String, FilterValue)>,
139 _model: PhantomData<T>,
141}
142
143impl<T: Model> NestedCreateData<T> {
144 pub fn new(data: Vec<(String, FilterValue)>) -> Self {
146 Self {
147 data,
148 _model: PhantomData,
149 }
150 }
151
152 pub fn from_pairs(
154 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
155 ) -> Self {
156 Self::new(
157 pairs
158 .into_iter()
159 .map(|(k, v)| (k.into(), v.into()))
160 .collect(),
161 )
162 }
163}
164
165impl<T: Model> Default for NestedCreateData<T> {
166 fn default() -> Self {
167 Self::new(Vec::new())
168 }
169}
170
171#[derive(Debug, Clone)]
173pub struct NestedCreateOrConnectData<T: Model> {
174 pub filter: Filter,
176 pub create: NestedCreateData<T>,
178}
179
180impl<T: Model> NestedCreateOrConnectData<T> {
181 pub fn new(filter: impl Into<Filter>, create: NestedCreateData<T>) -> Self {
183 Self {
184 filter: filter.into(),
185 create,
186 }
187 }
188}
189
190#[derive(Debug, Clone)]
192pub struct NestedUpdateData<T: Model> {
193 pub filter: Filter,
195 pub data: Vec<(String, FilterValue)>,
197 _model: PhantomData<T>,
199}
200
201impl<T: Model> NestedUpdateData<T> {
202 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
204 Self {
205 filter: filter.into(),
206 data,
207 _model: PhantomData,
208 }
209 }
210
211 pub fn from_pairs(
213 filter: impl Into<Filter>,
214 pairs: impl IntoIterator<Item = (impl Into<String>, impl Into<FilterValue>)>,
215 ) -> Self {
216 Self::new(
217 filter,
218 pairs
219 .into_iter()
220 .map(|(k, v)| (k.into(), v.into()))
221 .collect(),
222 )
223 }
224}
225
226#[derive(Debug, Clone)]
228pub struct NestedUpsertData<T: Model> {
229 pub filter: Filter,
231 pub create: NestedCreateData<T>,
233 pub update: Vec<(String, FilterValue)>,
235 _model: PhantomData<T>,
237}
238
239impl<T: Model> NestedUpsertData<T> {
240 pub fn new(
242 filter: impl Into<Filter>,
243 create: NestedCreateData<T>,
244 update: Vec<(String, FilterValue)>,
245 ) -> Self {
246 Self {
247 filter: filter.into(),
248 create,
249 update,
250 _model: PhantomData,
251 }
252 }
253}
254
255#[derive(Debug, Clone)]
257pub struct NestedUpdateManyData<T: Model> {
258 pub filter: Filter,
260 pub data: Vec<(String, FilterValue)>,
262 _model: PhantomData<T>,
264}
265
266impl<T: Model> NestedUpdateManyData<T> {
267 pub fn new(filter: impl Into<Filter>, data: Vec<(String, FilterValue)>) -> Self {
269 Self {
270 filter: filter.into(),
271 data,
272 _model: PhantomData,
273 }
274 }
275}
276
277#[derive(Debug)]
279pub struct NestedWriteBuilder {
280 parent_table: String,
282 parent_pk: Vec<String>,
284 related_table: String,
286 foreign_key: String,
288 is_one_to_many: bool,
290 join_table: Option<JoinTableInfo>,
292}
293
/// Describes the join table linking the two sides of a many-to-many relation.
#[derive(Clone, Debug)]
pub struct JoinTableInfo {
    /// Name of the join table.
    pub table_name: String,
    /// Join-table column holding the parent row's id.
    pub parent_column: String,
    /// Join-table column holding the related row's id.
    pub related_column: String,
}
304
305impl NestedWriteBuilder {
306 pub fn one_to_many(
308 parent_table: impl Into<String>,
309 parent_pk: Vec<String>,
310 related_table: impl Into<String>,
311 foreign_key: impl Into<String>,
312 ) -> Self {
313 Self {
314 parent_table: parent_table.into(),
315 parent_pk,
316 related_table: related_table.into(),
317 foreign_key: foreign_key.into(),
318 is_one_to_many: true,
319 join_table: None,
320 }
321 }
322
323 pub fn many_to_many(
325 parent_table: impl Into<String>,
326 parent_pk: Vec<String>,
327 related_table: impl Into<String>,
328 join_table: JoinTableInfo,
329 ) -> Self {
330 Self {
331 parent_table: parent_table.into(),
332 parent_pk,
333 related_table: related_table.into(),
334 foreign_key: String::new(), is_one_to_many: false,
336 join_table: Some(join_table),
337 }
338 }
339
340 pub fn build_connect_sql<T: Model>(
342 &self,
343 parent_id: &FilterValue,
344 filters: &[Filter],
345 ) -> Vec<(String, Vec<FilterValue>)> {
346 let mut statements = Vec::new();
347
348 if self.is_one_to_many {
349 for filter in filters {
351 let (where_sql, mut params) = filter.to_sql(1);
352 let sql = format!(
353 "UPDATE {} SET {} = ${} WHERE {}",
354 quote_identifier(&self.related_table),
355 quote_identifier(&self.foreign_key),
356 params.len() + 1,
357 where_sql
358 );
359 params.push(parent_id.clone());
360 statements.push((sql, params));
361 }
362 } else if let Some(join) = &self.join_table {
363 for filter in filters {
366 let (where_sql, mut params) = filter.to_sql(1);
367
368 let select_sql = format!(
370 "SELECT {} FROM {} WHERE {}",
371 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
372 quote_identifier(&self.related_table),
373 where_sql
374 );
375
376 let insert_sql = format!(
378 "INSERT INTO {} ({}, {}) SELECT ${}, {} FROM {} WHERE {} ON CONFLICT DO NOTHING",
379 quote_identifier(&join.table_name),
380 quote_identifier(&join.parent_column),
381 quote_identifier(&join.related_column),
382 params.len() + 1,
383 quote_identifier(T::PRIMARY_KEY.first().unwrap_or(&"id")),
384 quote_identifier(&self.related_table),
385 where_sql
386 );
387 params.push(parent_id.clone());
388 statements.push((insert_sql, params));
389 let _ = select_sql;
391 }
392 }
393
394 statements
395 }
396
397 pub fn build_disconnect_sql(
399 &self,
400 parent_id: &FilterValue,
401 filters: &[Filter],
402 ) -> Vec<(String, Vec<FilterValue>)> {
403 let mut statements = Vec::new();
404
405 if self.is_one_to_many {
406 for filter in filters {
408 let (where_sql, mut params) = filter.to_sql(1);
409 let sql = format!(
410 "UPDATE {} SET {} = NULL WHERE {} AND {} = ${}",
411 quote_identifier(&self.related_table),
412 quote_identifier(&self.foreign_key),
413 where_sql,
414 quote_identifier(&self.foreign_key),
415 params.len() + 1
416 );
417 params.push(parent_id.clone());
418 statements.push((sql, params));
419 }
420 } else if let Some(join) = &self.join_table {
421 for filter in filters {
423 let (where_sql, mut params) = filter.to_sql(2);
424 let sql = format!(
425 "DELETE FROM {} WHERE {} = $1 AND {} IN (SELECT id FROM {} WHERE {})",
426 quote_identifier(&join.table_name),
427 quote_identifier(&join.parent_column),
428 quote_identifier(&join.related_column),
429 quote_identifier(&self.related_table),
430 where_sql
431 );
432 let mut final_params = vec![parent_id.clone()];
433 final_params.extend(params);
434 params = final_params;
435 statements.push((sql, params));
436 }
437 }
438
439 statements
440 }
441
442 pub fn build_set_sql<T: Model>(
444 &self,
445 parent_id: &FilterValue,
446 filters: &[Filter],
447 ) -> Vec<(String, Vec<FilterValue>)> {
448 let mut statements = Vec::new();
449
450 if self.is_one_to_many {
452 let sql = format!(
453 "UPDATE {} SET {} = NULL WHERE {} = $1",
454 quote_identifier(&self.related_table),
455 quote_identifier(&self.foreign_key),
456 quote_identifier(&self.foreign_key)
457 );
458 statements.push((sql, vec![parent_id.clone()]));
459 } else if let Some(join) = &self.join_table {
460 let sql = format!(
461 "DELETE FROM {} WHERE {} = $1",
462 quote_identifier(&join.table_name),
463 quote_identifier(&join.parent_column)
464 );
465 statements.push((sql, vec![parent_id.clone()]));
466 }
467
468 statements.extend(self.build_connect_sql::<T>(parent_id, filters));
470
471 statements
472 }
473
474 pub fn build_create_sql<T: Model>(
476 &self,
477 parent_id: &FilterValue,
478 creates: &[NestedCreateData<T>],
479 ) -> Vec<(String, Vec<FilterValue>)> {
480 let mut statements = Vec::new();
481
482 for create in creates {
483 let mut columns: Vec<String> = create.data.iter().map(|(k, _)| k.clone()).collect();
484 let mut values: Vec<FilterValue> = create.data.iter().map(|(_, v)| v.clone()).collect();
485
486 columns.push(self.foreign_key.clone());
488 values.push(parent_id.clone());
489
490 let placeholders: Vec<String> = (1..=values.len()).map(|i| format!("${}", i)).collect();
491
492 let sql = format!(
493 "INSERT INTO {} ({}) VALUES ({}) RETURNING *",
494 quote_identifier(&self.related_table),
495 columns
496 .iter()
497 .map(|c| quote_identifier(c))
498 .collect::<Vec<_>>()
499 .join(", "),
500 placeholders.join(", ")
501 );
502
503 statements.push((sql, values));
504 }
505
506 statements
507 }
508
509 pub fn build_delete_sql(
511 &self,
512 parent_id: &FilterValue,
513 filters: &[Filter],
514 ) -> Vec<(String, Vec<FilterValue>)> {
515 let mut statements = Vec::new();
516
517 for filter in filters {
518 let (where_sql, mut params) = filter.to_sql(1);
519 let sql = format!(
520 "DELETE FROM {} WHERE {} AND {} = ${}",
521 quote_identifier(&self.related_table),
522 where_sql,
523 quote_identifier(&self.foreign_key),
524 params.len() + 1
525 );
526 params.push(parent_id.clone());
527 statements.push((sql, params));
528 }
529
530 statements
531 }
532}
533
534#[derive(Debug, Default)]
536pub struct NestedWriteOperations {
537 pub pre_statements: Vec<(String, Vec<FilterValue>)>,
539 pub post_statements: Vec<(String, Vec<FilterValue>)>,
541}
542
543impl NestedWriteOperations {
544 pub fn new() -> Self {
546 Self::default()
547 }
548
549 pub fn add_pre(&mut self, sql: String, params: Vec<FilterValue>) {
551 self.pre_statements.push((sql, params));
552 }
553
554 pub fn add_post(&mut self, sql: String, params: Vec<FilterValue>) {
556 self.post_statements.push((sql, params));
557 }
558
559 pub fn extend(&mut self, other: Self) {
561 self.pre_statements.extend(other.pre_statements);
562 self.post_statements.extend(other.post_statements);
563 }
564
565 pub fn is_empty(&self) -> bool {
567 self.pre_statements.is_empty() && self.post_statements.is_empty()
568 }
569
570 pub fn len(&self) -> usize {
572 self.pre_statements.len() + self.post_statements.len()
573 }
574}
575
#[cfg(test)]
mod tests {
    use super::*;

    struct TestModel;

    impl Model for TestModel {
        const MODEL_NAME: &'static str = "Post";
        const TABLE_NAME: &'static str = "posts";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "title", "user_id"];
    }

    struct TagModel;

    impl Model for TagModel {
        const MODEL_NAME: &'static str = "Tag";
        const TABLE_NAME: &'static str = "tags";
        const PRIMARY_KEY: &'static [&'static str] = &["id"];
        const COLUMNS: &'static [&'static str] = &["id", "name"];
    }

    /// Builder for the `posts` <-> `tags` many-to-many relation via `post_tags`.
    fn tags_builder() -> NestedWriteBuilder {
        NestedWriteBuilder::many_to_many(
            "posts",
            vec!["id".to_string()],
            "tags",
            JoinTableInfo {
                table_name: "post_tags".to_string(),
                parent_column: "post_id".to_string(),
                related_column: "tag_id".to_string(),
            },
        )
    }

    #[test]
    fn test_nested_create_data() {
        let data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("Test Post".to_string()))]);

        assert_eq!(data.data.len(), 1);
        assert_eq!(data.data[0].0, "title");
    }

    #[test]
    fn test_nested_create_data_default_is_empty() {
        let data = NestedCreateData::<TestModel>::default();
        assert!(data.data.is_empty());
    }

    #[test]
    fn test_nested_write_create() {
        let data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("Test Post".to_string()))]);

        let write: NestedWrite<TestModel> = NestedWrite::create(data);

        match write {
            NestedWrite::Create(creates) => assert_eq!(creates.len(), 1),
            _ => panic!("Expected Create variant"),
        }
    }

    #[test]
    fn test_nested_write_connect() {
        let write: NestedWrite<TestModel> = NestedWrite::connect(vec![
            Filter::Equals("id".into(), FilterValue::Int(1)),
            Filter::Equals("id".into(), FilterValue::Int(2)),
        ]);

        match write {
            NestedWrite::Connect(filters) => assert_eq!(filters.len(), 2),
            _ => panic!("Expected Connect variant"),
        }
    }

    #[test]
    fn test_nested_write_connect_one_and_delete_many() {
        let connect: NestedWrite<TestModel> =
            NestedWrite::connect_one(Filter::Equals("id".into(), FilterValue::Int(1)));
        assert!(matches!(connect, NestedWrite::Connect(ref filters) if filters.len() == 1));

        let delete_many: NestedWrite<TestModel> =
            NestedWrite::delete_many(Filter::Equals("user_id".into(), FilterValue::Int(1)));
        assert!(matches!(delete_many, NestedWrite::DeleteMany(Filter::Equals(..))));
    }

    #[test]
    fn test_nested_write_disconnect() {
        let write: NestedWrite<TestModel> =
            NestedWrite::disconnect_one(Filter::Equals("id".into(), FilterValue::Int(1)));

        match write {
            NestedWrite::Disconnect(filters) => assert_eq!(filters.len(), 1),
            _ => panic!("Expected Disconnect variant"),
        }
    }

    #[test]
    fn test_nested_write_set() {
        let write: NestedWrite<TestModel> =
            NestedWrite::set(vec![Filter::Equals("id".into(), FilterValue::Int(1))]);

        match write {
            NestedWrite::Set(filters) => assert_eq!(filters.len(), 1),
            _ => panic!("Expected Set variant"),
        }
    }

    #[test]
    fn test_builder_one_to_many_connect() {
        let builder =
            NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id");

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_connect_sql::<TestModel>(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_one_to_many_disconnect() {
        let builder =
            NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id");

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_disconnect_sql(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("UPDATE"));
        assert!(sql.contains("SET"));
        assert!(sql.contains("NULL"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_one_to_many_delete() {
        let builder =
            NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id");

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_delete_sql(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.starts_with("DELETE FROM"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("user_id"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_connect() {
        let builder = tags_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_connect_sql::<TagModel>(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, _params) = &statements[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("post_tags"));
        assert!(sql.contains("ON CONFLICT DO NOTHING"));
    }

    #[test]
    fn test_builder_many_to_many_disconnect() {
        let builder = tags_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_disconnect_sql(&parent_id, &filters);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.starts_with("DELETE FROM"));
        assert!(sql.contains("post_tags"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_many_to_many_set() {
        let builder = tags_builder();

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_set_sql::<TagModel>(&parent_id, &filters);

        assert_eq!(statements.len(), 2);
        let (first_sql, first_params) = &statements[0];
        assert!(first_sql.starts_with("DELETE FROM"));
        assert!(first_sql.contains("post_tags"));
        assert_eq!(first_params.len(), 1);
    }

    #[test]
    fn test_builder_create() {
        let builder =
            NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id");

        let parent_id = FilterValue::Int(1);
        let creates = vec![NestedCreateData::<TestModel>::from_pairs([(
            "title",
            FilterValue::String("New Post".to_string()),
        )])];

        let statements = builder.build_create_sql::<TestModel>(&parent_id, &creates);

        assert_eq!(statements.len(), 1);
        let (sql, params) = &statements[0];
        assert!(sql.contains("INSERT INTO"));
        assert!(sql.contains("posts"));
        assert!(sql.contains("RETURNING"));
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_builder_set() {
        let builder =
            NestedWriteBuilder::one_to_many("users", vec!["id".to_string()], "posts", "user_id");

        let parent_id = FilterValue::Int(1);
        let filters = vec![Filter::Equals("id".into(), FilterValue::Int(10))];

        let statements = builder.build_set_sql::<TestModel>(&parent_id, &filters);

        assert!(statements.len() >= 2);

        let (first_sql, _) = &statements[0];
        assert!(first_sql.contains("UPDATE"));
        assert!(first_sql.contains("NULL"));
    }

    #[test]
    fn test_nested_write_operations() {
        let mut ops = NestedWriteOperations::new();
        assert!(ops.is_empty());
        assert_eq!(ops.len(), 0);

        ops.add_pre("SELECT 1".to_string(), vec![]);
        ops.add_post("SELECT 2".to_string(), vec![]);

        assert!(!ops.is_empty());
        assert_eq!(ops.len(), 2);
    }

    #[test]
    fn test_nested_write_operations_extend() {
        let mut ops = NestedWriteOperations::new();
        ops.add_pre("SELECT 1".to_string(), vec![]);

        let mut other = NestedWriteOperations::new();
        other.add_pre("SELECT 2".to_string(), vec![]);
        other.add_post("SELECT 3".to_string(), vec![]);

        ops.extend(other);

        assert_eq!(ops.pre_statements.len(), 2);
        assert_eq!(ops.post_statements.len(), 1);
        assert_eq!(ops.pre_statements[1].0, "SELECT 2");
        assert_eq!(ops.len(), 3);
    }

    #[test]
    fn test_nested_create_or_connect() {
        let create_data: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New Post".to_string()))]);

        let create_or_connect = NestedCreateOrConnectData::new(
            Filter::Equals("title".into(), FilterValue::String("Existing".to_string())),
            create_data,
        );

        assert!(matches!(create_or_connect.filter, Filter::Equals(..)));
        assert_eq!(create_or_connect.create.data.len(), 1);
    }

    #[test]
    fn test_nested_update_data() {
        let update: NestedUpdateData<TestModel> = NestedUpdateData::from_pairs(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            [("title", FilterValue::String("Updated".to_string()))],
        );

        assert!(matches!(update.filter, Filter::Equals(..)));
        assert_eq!(update.data.len(), 1);
        assert_eq!(update.data[0].0, "title");
    }

    #[test]
    fn test_nested_update_many_data() {
        let update: NestedUpdateManyData<TestModel> = NestedUpdateManyData::new(
            Filter::Equals("user_id".into(), FilterValue::Int(1)),
            vec![("title".to_string(), FilterValue::String("Bulk".to_string()))],
        );

        assert!(matches!(update.filter, Filter::Equals(..)));
        assert_eq!(update.data.len(), 1);
        assert_eq!(update.data[0].0, "title");
    }

    #[test]
    fn test_nested_upsert_data() {
        let create: NestedCreateData<TestModel> =
            NestedCreateData::from_pairs([("title", FilterValue::String("New".to_string()))]);

        let upsert: NestedUpsertData<TestModel> = NestedUpsertData::new(
            Filter::Equals("id".into(), FilterValue::Int(1)),
            create,
            vec![(
                "title".to_string(),
                FilterValue::String("Updated".to_string()),
            )],
        );

        assert!(matches!(upsert.filter, Filter::Equals(..)));
        assert_eq!(upsert.create.data.len(), 1);
        assert_eq!(upsert.update.len(), 1);
    }
}