1#![allow(clippy::cast_precision_loss)]
21
22use std::collections::HashMap;
23
24use serde::{Deserialize, Serialize};
25
26use crate::{
27 dataset::{ArrowDataset, Dataset},
28 error::{Error, Result},
29 split::DatasetSplit,
30};
31
32#[derive(Debug, Clone)]
34pub struct FederatedSplitCoordinator {
35 strategy: FederatedSplitStrategy,
37}
38
/// How the coordinator assigns split parameters to each participating node.
#[derive(Clone, Debug, PartialEq)]
pub enum FederatedSplitStrategy {
    /// Every node receives the same `seed` and `train_ratio`; the test ratio
    /// is the complement `1.0 - train_ratio`.
    LocalWithSeed {
        /// Seed shared by all nodes.
        seed: u64,
        /// Fraction of each node's rows assigned to the training partition.
        train_ratio: f64,
    },

    /// Every node stratifies its local split on `label_column`.
    GlobalStratified {
        /// Column each node stratifies on.
        label_column: String,
        /// Desired global label proportions. NOTE(review): the current
        /// planner does not consult this map.
        target_distribution: HashMap<String, f64>,
    },

    /// Every node uses the same `train_ratio`, with a per-node seed equal to
    /// the node's position in the manifest list.
    ProportionalIID {
        /// Fraction of each node's rows assigned to the training partition.
        train_ratio: f64,
    },
}
64
65#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
67pub struct NodeSplitManifest {
68 pub node_id: String,
70 pub total_rows: u64,
72 pub train_rows: u64,
74 pub test_rows: u64,
76 pub validation_rows: Option<u64>,
78 pub label_distribution: Option<HashMap<String, u64>>,
80 pub split_hash: [u8; 32],
82}
83
84#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct NodeSplitInstruction {
87 pub node_id: String,
89 pub seed: u64,
91 pub train_ratio: f64,
93 pub test_ratio: f64,
95 pub validation_ratio: Option<f64>,
97 pub stratify_column: Option<String>,
99}
100
101#[derive(Debug, Clone)]
103pub struct GlobalSplitReport {
104 pub total_rows: u64,
106 pub total_train_rows: u64,
108 pub total_test_rows: u64,
110 pub total_validation_rows: Option<u64>,
112 pub effective_train_ratio: f64,
114 pub effective_test_ratio: f64,
116 pub effective_validation_ratio: Option<f64>,
118 pub node_summaries: Vec<NodeSummary>,
120 pub global_label_distribution: Option<HashMap<String, u64>>,
122 pub quality_passed: bool,
124 pub issues: Vec<SplitQualityIssue>,
126}
127
/// Per-node ratios derived from that node's manifest.
#[derive(Clone, Debug)]
pub struct NodeSummary {
    /// Node this summary describes.
    pub node_id: String,
    /// Share of all rows across all nodes that this node holds.
    pub contribution_ratio: f64,
    /// Share of this node's rows in its training partition.
    pub train_ratio: f64,
    /// Share of this node's rows in its test partition.
    pub test_ratio: f64,
}
140
/// A problem detected while verifying a federated split.
#[derive(Clone, Debug, PartialEq)]
pub enum SplitQualityIssue {
    /// A node's train ratio strays too far from the reference ratio.
    RatioDeviation {
        /// Offending node.
        node_id: String,
        /// Reference train ratio the node was compared against.
        expected: f64,
        /// The node's actual train ratio.
        actual: f64,
    },
    /// A label is unevenly represented across the listed nodes.
    /// Not emitted by the verifier in this module.
    DistributionImbalance {
        /// Affected label.
        label: String,
        /// Nodes involved in the imbalance.
        nodes: Vec<String>,
    },
    /// A node's smaller partition (train or test) has too few rows.
    InsufficientSamples {
        /// Offending node.
        node_id: String,
        /// Row count of the node's smaller partition.
        samples: u64,
        /// Minimum required row count.
        minimum: u64,
    },
    /// A node's split hash did not match expectations.
    /// Not emitted by the verifier in this module.
    HashMismatch {
        /// Offending node.
        node_id: String,
    },
}
175
176impl FederatedSplitCoordinator {
177 #[must_use]
179 pub fn new(strategy: FederatedSplitStrategy) -> Self {
180 Self { strategy }
181 }
182
183 #[must_use]
185 pub fn strategy(&self) -> &FederatedSplitStrategy {
186 &self.strategy
187 }
188
189 pub fn compute_split_plan(
193 &self,
194 manifests: &[NodeSplitManifest],
195 ) -> Result<Vec<NodeSplitInstruction>> {
196 if manifests.is_empty() {
197 return Err(Error::invalid_config(
198 "Cannot compute plan for empty manifest list",
199 ));
200 }
201
202 match &self.strategy {
203 FederatedSplitStrategy::LocalWithSeed { seed, train_ratio } => Ok(
204 Self::compute_local_seed_plan(manifests, *seed, *train_ratio),
205 ),
206 FederatedSplitStrategy::GlobalStratified {
207 label_column,
208 target_distribution,
209 } => Ok(Self::compute_stratified_plan(
210 manifests,
211 label_column,
212 target_distribution,
213 )),
214 FederatedSplitStrategy::ProportionalIID { train_ratio } => {
215 Ok(Self::compute_proportional_plan(manifests, *train_ratio))
216 }
217 }
218 }
219
220 fn compute_local_seed_plan(
222 manifests: &[NodeSplitManifest],
223 seed: u64,
224 train_ratio: f64,
225 ) -> Vec<NodeSplitInstruction> {
226 let test_ratio = 1.0 - train_ratio;
227
228 manifests
229 .iter()
230 .map(|m| NodeSplitInstruction {
231 node_id: m.node_id.clone(),
232 seed,
233 train_ratio,
234 test_ratio,
235 validation_ratio: None,
236 stratify_column: None,
237 })
238 .collect()
239 }
240
241 fn compute_stratified_plan(
243 manifests: &[NodeSplitManifest],
244 label_column: &str,
245 _target_distribution: &HashMap<String, f64>,
246 ) -> Vec<NodeSplitInstruction> {
247 let base_seed = 42u64; let train_ratio = 0.8;
254 let test_ratio = 0.2;
255
256 manifests
257 .iter()
258 .enumerate()
259 .map(|(i, m)| {
260 let node_seed = base_seed.wrapping_add(i as u64);
262
263 NodeSplitInstruction {
264 node_id: m.node_id.clone(),
265 seed: node_seed,
266 train_ratio,
267 test_ratio,
268 validation_ratio: None,
269 stratify_column: Some(label_column.to_string()),
270 }
271 })
272 .collect()
273 }
274
275 fn compute_proportional_plan(
277 manifests: &[NodeSplitManifest],
278 train_ratio: f64,
279 ) -> Vec<NodeSplitInstruction> {
280 let test_ratio = 1.0 - train_ratio;
281
282 manifests
284 .iter()
285 .enumerate()
286 .map(|(i, m)| NodeSplitInstruction {
287 node_id: m.node_id.clone(),
288 seed: i as u64,
289 train_ratio,
290 test_ratio,
291 validation_ratio: None,
292 stratify_column: None,
293 })
294 .collect()
295 }
296
297 pub fn execute_local_split(
301 dataset: &ArrowDataset,
302 instruction: &NodeSplitInstruction,
303 ) -> Result<DatasetSplit> {
304 let val_ratio = instruction.validation_ratio;
305
306 if let Some(ref column) = instruction.stratify_column {
307 DatasetSplit::stratified(
308 dataset,
309 column,
310 instruction.train_ratio,
311 instruction.test_ratio,
312 val_ratio,
313 Some(instruction.seed),
314 )
315 } else {
316 DatasetSplit::from_ratios(
317 dataset,
318 instruction.train_ratio,
319 instruction.test_ratio,
320 val_ratio,
321 Some(instruction.seed),
322 )
323 }
324 }
325
326 pub fn verify_global_split(manifests: &[NodeSplitManifest]) -> Result<GlobalSplitReport> {
330 if manifests.is_empty() {
331 return Err(Error::invalid_config("Cannot verify empty manifest list"));
332 }
333
334 let total_rows: u64 = manifests.iter().map(|m| m.total_rows).sum();
335 let total_train_rows: u64 = manifests.iter().map(|m| m.train_rows).sum();
336 let total_test_rows: u64 = manifests.iter().map(|m| m.test_rows).sum();
337 let total_validation_rows: Option<u64> =
338 if manifests.iter().any(|m| m.validation_rows.is_some()) {
339 Some(manifests.iter().filter_map(|m| m.validation_rows).sum())
340 } else {
341 None
342 };
343
344 let effective_train_ratio = if total_rows > 0 {
345 total_train_rows as f64 / total_rows as f64
346 } else {
347 0.0
348 };
349
350 let effective_test_ratio = if total_rows > 0 {
351 total_test_rows as f64 / total_rows as f64
352 } else {
353 0.0
354 };
355
356 let effective_validation_ratio = total_validation_rows.map(|v| {
357 if total_rows > 0 {
358 v as f64 / total_rows as f64
359 } else {
360 0.0
361 }
362 });
363
364 let node_summaries: Vec<NodeSummary> = manifests
366 .iter()
367 .map(|m| {
368 let contribution_ratio = if total_rows > 0 {
369 m.total_rows as f64 / total_rows as f64
370 } else {
371 0.0
372 };
373
374 let train_ratio = if m.total_rows > 0 {
375 m.train_rows as f64 / m.total_rows as f64
376 } else {
377 0.0
378 };
379
380 let test_ratio = if m.total_rows > 0 {
381 m.test_rows as f64 / m.total_rows as f64
382 } else {
383 0.0
384 };
385
386 NodeSummary {
387 node_id: m.node_id.clone(),
388 contribution_ratio,
389 train_ratio,
390 test_ratio,
391 }
392 })
393 .collect();
394
395 let global_label_distribution = Self::aggregate_label_distributions(manifests);
397
398 let issues = Self::detect_quality_issues(manifests, &node_summaries);
400
401 let quality_passed = issues.is_empty();
402
403 Ok(GlobalSplitReport {
404 total_rows,
405 total_train_rows,
406 total_test_rows,
407 total_validation_rows,
408 effective_train_ratio,
409 effective_test_ratio,
410 effective_validation_ratio,
411 node_summaries,
412 global_label_distribution,
413 quality_passed,
414 issues,
415 })
416 }
417
418 fn aggregate_label_distributions(
420 manifests: &[NodeSplitManifest],
421 ) -> Option<HashMap<String, u64>> {
422 let mut global_dist: HashMap<String, u64> = HashMap::new();
423 let mut any_has_distribution = false;
424
425 for manifest in manifests {
426 if let Some(ref dist) = manifest.label_distribution {
427 any_has_distribution = true;
428 for (label, count) in dist {
429 *global_dist.entry(label.clone()).or_insert(0) += count;
430 }
431 }
432 }
433
434 if any_has_distribution {
435 Some(global_dist)
436 } else {
437 None
438 }
439 }
440
441 fn detect_quality_issues(
443 manifests: &[NodeSplitManifest],
444 summaries: &[NodeSummary],
445 ) -> Vec<SplitQualityIssue> {
446 const MIN_SAMPLES: u64 = 10;
448
449 let mut issues = Vec::new();
450
451 for manifest in manifests {
453 if manifest.train_rows < MIN_SAMPLES || manifest.test_rows < MIN_SAMPLES {
454 issues.push(SplitQualityIssue::InsufficientSamples {
455 node_id: manifest.node_id.clone(),
456 samples: manifest.train_rows.min(manifest.test_rows),
457 minimum: MIN_SAMPLES,
458 });
459 }
460 }
461
462 if !summaries.is_empty() {
464 let mean_train_ratio: f64 =
465 summaries.iter().map(|s| s.train_ratio).sum::<f64>() / summaries.len() as f64;
466
467 for summary in summaries {
468 let deviation = (summary.train_ratio - mean_train_ratio).abs();
469 if deviation > 0.1 {
470 issues.push(SplitQualityIssue::RatioDeviation {
471 node_id: summary.node_id.clone(),
472 expected: mean_train_ratio,
473 actual: summary.train_ratio,
474 });
475 }
476 }
477 }
478
479 issues
480 }
481}
482
483impl NodeSplitManifest {
484 #[must_use]
486 pub fn new(
487 node_id: impl Into<String>,
488 total_rows: u64,
489 train_rows: u64,
490 test_rows: u64,
491 ) -> Self {
492 Self {
493 node_id: node_id.into(),
494 total_rows,
495 train_rows,
496 test_rows,
497 validation_rows: None,
498 label_distribution: None,
499 split_hash: [0u8; 32],
500 }
501 }
502
503 #[must_use]
505 pub fn with_validation(mut self, rows: u64) -> Self {
506 self.validation_rows = Some(rows);
507 self
508 }
509
510 #[must_use]
512 pub fn with_label_distribution(mut self, distribution: HashMap<String, u64>) -> Self {
513 self.label_distribution = Some(distribution);
514 self
515 }
516
517 #[must_use]
519 pub fn with_split_hash(mut self, hash: [u8; 32]) -> Self {
520 self.split_hash = hash;
521 self
522 }
523
524 #[must_use]
526 pub fn from_split(node_id: impl Into<String>, split: &DatasetSplit) -> Self {
527 let train_rows = split.train.len() as u64;
528 let test_rows = split.test.len() as u64;
529 let validation_rows = split.validation.as_ref().map(|v| v.len() as u64);
530
531 let mut manifest = Self::new(
532 node_id,
533 train_rows + test_rows + validation_rows.unwrap_or(0),
534 train_rows,
535 test_rows,
536 );
537
538 if let Some(v) = validation_rows {
539 manifest = manifest.with_validation(v);
540 }
541
542 manifest
543 }
544
545 pub fn to_json(&self) -> Result<Vec<u8>> {
547 serde_json::to_vec(self).map_err(|e| Error::Format(e.to_string()))
548 }
549
550 pub fn from_json(data: &[u8]) -> Result<Self> {
552 serde_json::from_slice(data).map_err(|e| Error::Format(e.to_string()))
553 }
554}
555
556impl NodeSplitInstruction {
557 pub fn to_json(&self) -> Result<Vec<u8>> {
559 serde_json::to_vec(self).map_err(|e| Error::Format(e.to_string()))
560 }
561
562 pub fn from_json(data: &[u8]) -> Result<Self> {
564 serde_json::from_slice(data).map_err(|e| Error::Format(e.to_string()))
565 }
566}
567
#[cfg(test)]
mod tests {
    use super::*;

    /// Float comparison used throughout: equal within one machine epsilon.
    fn approx_eq(lhs: f64, rhs: f64) -> bool {
        (lhs - rhs).abs() < f64::EPSILON
    }

    /// Builds a `String`-keyed map from borrowed `(key, value)` pairs.
    fn owned_map<V: Copy>(pairs: &[(&str, V)]) -> HashMap<String, V> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value))
            .collect()
    }

    /// Verifies `manifests`, failing the test if verification errors.
    fn verify(manifests: &[NodeSplitManifest]) -> GlobalSplitReport {
        FederatedSplitCoordinator::verify_global_split(manifests).expect("verify failed")
    }

    /// Plans a split for `manifests` under `strategy`, failing the test on error.
    fn plan_for(
        strategy: FederatedSplitStrategy,
        manifests: &[NodeSplitManifest],
    ) -> Vec<NodeSplitInstruction> {
        FederatedSplitCoordinator::new(strategy)
            .compute_split_plan(manifests)
            .expect("plan failed")
    }

    /// Serializes and re-parses a manifest through its JSON helpers.
    fn manifest_roundtrip(manifest: &NodeSplitManifest) -> NodeSplitManifest {
        let bytes = manifest.to_json().expect("serialization failed");
        NodeSplitManifest::from_json(&bytes).expect("deserialization failed")
    }

    /// Serializes and re-parses an instruction through its JSON helpers.
    fn instruction_roundtrip(instruction: &NodeSplitInstruction) -> NodeSplitInstruction {
        let bytes = instruction.to_json().expect("serialization failed");
        NodeSplitInstruction::from_json(&bytes).expect("deserialization failed")
    }

    // ---- FederatedSplitStrategy ----

    #[test]
    fn test_strategy_local_with_seed() {
        let strategy = FederatedSplitStrategy::LocalWithSeed {
            seed: 42,
            train_ratio: 0.8,
        };
        if let FederatedSplitStrategy::LocalWithSeed { seed, train_ratio } = strategy {
            assert_eq!(seed, 42);
            assert!(approx_eq(train_ratio, 0.8));
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_strategy_global_stratified() {
        let target = owned_map(&[("class_a", 0.5), ("class_b", 0.5)]);
        let strategy = FederatedSplitStrategy::GlobalStratified {
            label_column: "label".to_string(),
            target_distribution: target.clone(),
        };
        if let FederatedSplitStrategy::GlobalStratified {
            label_column,
            target_distribution,
        } = strategy
        {
            assert_eq!(label_column, "label");
            assert_eq!(target_distribution, target);
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_strategy_proportional_iid() {
        let strategy = FederatedSplitStrategy::ProportionalIID { train_ratio: 0.7 };
        if let FederatedSplitStrategy::ProportionalIID { train_ratio } = strategy {
            assert!(approx_eq(train_ratio, 0.7));
        } else {
            panic!("Wrong variant");
        }
    }

    #[test]
    fn test_strategy_clone_and_debug() {
        let original = FederatedSplitStrategy::LocalWithSeed {
            seed: 42,
            train_ratio: 0.8,
        };
        assert_eq!(original, original.clone());

        let rendered = format!("{:?}", original);
        for needle in ["LocalWithSeed", "42"] {
            assert!(rendered.contains(needle));
        }
    }

    // ---- NodeSplitManifest ----

    #[test]
    fn test_manifest_new() {
        let manifest = NodeSplitManifest::new("node_a", 1000, 800, 200);
        assert_eq!(manifest.node_id, "node_a");
        assert_eq!(
            (manifest.total_rows, manifest.train_rows, manifest.test_rows),
            (1000, 800, 200)
        );
        assert_eq!(manifest.validation_rows, None);
        assert_eq!(manifest.label_distribution, None);
    }

    #[test]
    fn test_manifest_with_validation() {
        let manifest = NodeSplitManifest::new("node_a", 1000, 700, 200).with_validation(100);
        assert_eq!(manifest.validation_rows, Some(100));
    }

    #[test]
    fn test_manifest_with_label_distribution() {
        let dist = owned_map(&[("cat", 500u64), ("dog", 500)]);
        let manifest =
            NodeSplitManifest::new("node_a", 1000, 800, 200).with_label_distribution(dist.clone());
        assert_eq!(manifest.label_distribution, Some(dist));
    }

    #[test]
    fn test_manifest_with_split_hash() {
        let hash = [1u8; 32];
        let manifest = NodeSplitManifest::new("node_a", 1000, 800, 200).with_split_hash(hash);
        assert_eq!(manifest.split_hash, hash);
    }

    #[test]
    fn test_manifest_serialization() {
        let manifest = NodeSplitManifest::new("node_a", 1000, 800, 200);
        assert_eq!(manifest_roundtrip(&manifest), manifest);
    }

    #[test]
    fn test_manifest_full_serialization() {
        let manifest = NodeSplitManifest::new("node_eu", 1000, 700, 200)
            .with_validation(100)
            .with_label_distribution(owned_map(&[("a", 400u64), ("b", 600)]))
            .with_split_hash([42u8; 32]);
        assert_eq!(manifest_roundtrip(&manifest), manifest);
    }

    // ---- NodeSplitInstruction ----

    #[test]
    fn test_instruction_serialization() {
        let instruction = NodeSplitInstruction {
            node_id: "node_a".to_string(),
            seed: 42,
            train_ratio: 0.8,
            test_ratio: 0.2,
            validation_ratio: None,
            stratify_column: None,
        };
        assert_eq!(instruction_roundtrip(&instruction), instruction);
    }

    #[test]
    fn test_instruction_with_stratification() {
        let instruction = NodeSplitInstruction {
            node_id: "node_b".to_string(),
            seed: 123,
            train_ratio: 0.7,
            test_ratio: 0.15,
            validation_ratio: Some(0.15),
            stratify_column: Some("label".to_string()),
        };
        assert_eq!(instruction_roundtrip(&instruction), instruction);
    }

    // ---- FederatedSplitCoordinator: planning ----

    #[test]
    fn test_coordinator_new() {
        let strategy = FederatedSplitStrategy::LocalWithSeed {
            seed: 42,
            train_ratio: 0.8,
        };
        let coordinator = FederatedSplitCoordinator::new(strategy.clone());
        assert_eq!(coordinator.strategy(), &strategy);
    }

    #[test]
    fn test_coordinator_empty_manifests_error() {
        let coordinator = FederatedSplitCoordinator::new(FederatedSplitStrategy::LocalWithSeed {
            seed: 42,
            train_ratio: 0.8,
        });
        assert!(coordinator.compute_split_plan(&[]).is_err());
    }

    #[test]
    fn test_coordinator_local_seed_plan() {
        let manifests = [
            NodeSplitManifest::new("node_a", 1000, 800, 200),
            NodeSplitManifest::new("node_b", 2000, 1600, 400),
        ];
        let plan = plan_for(
            FederatedSplitStrategy::LocalWithSeed {
                seed: 42,
                train_ratio: 0.8,
            },
            &manifests,
        );

        assert_eq!(plan.len(), 2);
        for instruction in &plan {
            assert_eq!(instruction.seed, 42);
            assert!(approx_eq(instruction.train_ratio, 0.8));
        }
    }

    #[test]
    fn test_coordinator_stratified_plan() {
        let manifests = [
            NodeSplitManifest::new("node_a", 1000, 800, 200),
            NodeSplitManifest::new("node_b", 2000, 1600, 400),
        ];
        let plan = plan_for(
            FederatedSplitStrategy::GlobalStratified {
                label_column: "label".to_string(),
                target_distribution: owned_map(&[("a", 0.5), ("b", 0.5)]),
            },
            &manifests,
        );

        assert_eq!(plan.len(), 2);
        for instruction in &plan {
            assert_eq!(instruction.stratify_column.as_deref(), Some("label"));
        }
        assert_ne!(plan[0].seed, plan[1].seed);
    }

    #[test]
    fn test_coordinator_proportional_plan() {
        let manifests = [
            NodeSplitManifest::new("node_a", 1000, 700, 300),
            NodeSplitManifest::new("node_b", 2000, 1400, 600),
            NodeSplitManifest::new("node_c", 500, 350, 150),
        ];
        let plan = plan_for(
            FederatedSplitStrategy::ProportionalIID { train_ratio: 0.7 },
            &manifests,
        );

        assert_eq!(plan.len(), 3);
        assert!(plan
            .iter()
            .all(|i| approx_eq(i.train_ratio, 0.7) && approx_eq(i.test_ratio, 0.3)));

        let seeds: Vec<u64> = plan.iter().map(|i| i.seed).collect();
        assert_eq!(seeds, vec![0, 1, 2]);
    }

    // ---- FederatedSplitCoordinator: verification ----

    #[test]
    fn test_verify_global_split_empty_error() {
        assert!(FederatedSplitCoordinator::verify_global_split(&[]).is_err());
    }

    #[test]
    fn test_verify_global_split_single_node() {
        let report = verify(&[NodeSplitManifest::new("node_a", 1000, 800, 200)]);

        assert_eq!(
            (report.total_rows, report.total_train_rows, report.total_test_rows),
            (1000, 800, 200)
        );
        assert!(approx_eq(report.effective_train_ratio, 0.8));
        assert!(approx_eq(report.effective_test_ratio, 0.2));
        assert!(report.quality_passed);
    }

    #[test]
    fn test_verify_global_split_multiple_nodes() {
        let report = verify(&[
            NodeSplitManifest::new("node_a", 1000, 800, 200),
            NodeSplitManifest::new("node_b", 2000, 1600, 400),
            NodeSplitManifest::new("node_c", 1000, 800, 200),
        ]);

        assert_eq!(
            (report.total_rows, report.total_train_rows, report.total_test_rows),
            (4000, 3200, 800)
        );
        assert!(approx_eq(report.effective_train_ratio, 0.8));
        assert_eq!(report.node_summaries.len(), 3);
        assert!(report.quality_passed);
    }

    #[test]
    fn test_verify_global_split_with_validation() {
        let report = verify(&[
            NodeSplitManifest::new("node_a", 1000, 700, 200).with_validation(100),
            NodeSplitManifest::new("node_b", 2000, 1400, 400).with_validation(200),
        ]);

        assert_eq!(report.total_validation_rows, Some(300));
        let validation_ratio = report
            .effective_validation_ratio
            .expect("validation ratio should be present");
        assert!(approx_eq(validation_ratio, 0.1));
    }

    #[test]
    fn test_verify_global_split_aggregates_labels() {
        let report = verify(&[
            NodeSplitManifest::new("node_a", 1000, 800, 200)
                .with_label_distribution(owned_map(&[("cat", 600u64), ("dog", 400)])),
            NodeSplitManifest::new("node_b", 2000, 1600, 400)
                .with_label_distribution(owned_map(&[("cat", 800u64), ("dog", 1200)])),
        ]);

        let global_dist = report
            .global_label_distribution
            .expect("should have distribution");
        assert_eq!(global_dist.get("cat"), Some(&1400));
        assert_eq!(global_dist.get("dog"), Some(&1600));
    }

    #[test]
    fn test_verify_detects_insufficient_samples() {
        let report = verify(&[
            NodeSplitManifest::new("node_a", 1000, 800, 200),
            // Only 5 test rows: below the 10-row minimum.
            NodeSplitManifest::new("node_b", 15, 10, 5),
        ]);

        assert!(!report.quality_passed);
        assert!(!report.issues.is_empty());
        assert!(report.issues.iter().any(|issue| matches!(
            issue,
            SplitQualityIssue::InsufficientSamples { node_id, .. } if node_id == "node_b"
        )));
    }

    #[test]
    fn test_verify_detects_ratio_deviation() {
        let report = verify(&[
            // 80/20 vs 50/50: each deviates 0.15 from the 0.65 mean.
            NodeSplitManifest::new("node_a", 1000, 800, 200),
            NodeSplitManifest::new("node_b", 1000, 500, 500),
        ]);

        assert!(!report.quality_passed);
        assert!(report
            .issues
            .iter()
            .any(|issue| matches!(issue, SplitQualityIssue::RatioDeviation { .. })));
    }

    #[test]
    fn test_node_summary_contribution_ratio() {
        let report = verify(&[
            NodeSplitManifest::new("node_a", 1000, 800, 200),
            NodeSplitManifest::new("node_b", 3000, 2400, 600),
        ]);

        let contributions: Vec<f64> = report
            .node_summaries
            .iter()
            .map(|s| s.contribution_ratio)
            .collect();
        assert!(approx_eq(contributions[0], 0.25));
        assert!(approx_eq(contributions[1], 0.75));
    }

    // ---- Report / issue types ----

    #[test]
    fn test_split_quality_issue_variants() {
        let cases = [
            (
                SplitQualityIssue::RatioDeviation {
                    node_id: "node_a".to_string(),
                    expected: 0.8,
                    actual: 0.5,
                },
                "RatioDeviation",
            ),
            (
                SplitQualityIssue::DistributionImbalance {
                    label: "cat".to_string(),
                    nodes: vec!["node_a".to_string(), "node_b".to_string()],
                },
                "DistributionImbalance",
            ),
            (
                SplitQualityIssue::InsufficientSamples {
                    node_id: "node_a".to_string(),
                    samples: 5,
                    minimum: 10,
                },
                "InsufficientSamples",
            ),
            (
                SplitQualityIssue::HashMismatch {
                    node_id: "node_a".to_string(),
                },
                "HashMismatch",
            ),
        ];

        for (issue, variant) in &cases {
            assert!(format!("{:?}", issue).contains(variant));
        }
    }

    #[test]
    fn test_global_split_report_debug() {
        let rendered = format!(
            "{:?}",
            verify(&[NodeSplitManifest::new("node_a", 100, 80, 20)])
        );
        assert!(rendered.contains("GlobalSplitReport"));
        assert!(rendered.contains("total_rows"));
    }
}