Skip to main content

alimentar/
federated.rs

1//! Federated Split Coordination for Privacy-Preserving ML
2//!
3//! This module enables distributed/federated ML workflows where data stays
4//! local on each node (sovereignty) and only metadata/sketches cross
5//! boundaries.
6//!
7//! # Architecture
8//!
9//! ```text
10//! Node A (EU):     data_eu.ald      → local train_eu.ald, test_eu.ald
11//! Node B (US):     data_us.ald      → local train_us.ald, test_us.ald
12//! Node C (APAC):   data_apac.ald    → local train_apac.ald, test_apac.ald
13//!                        ↓
14//!              Coordinator (sees only manifests)
15//!                        ↓
16//!              Global split verification
17//! ```
18
19// Federated coordination uses ratio calculations where precision loss is acceptable
20#![allow(clippy::cast_precision_loss)]
21
22use std::collections::HashMap;
23
24use serde::{Deserialize, Serialize};
25
26use crate::{
27    dataset::{ArrowDataset, Dataset},
28    error::{Error, Result},
29    split::DatasetSplit,
30};
31
32/// Federated split coordination (no raw data leaves nodes)
33#[derive(Debug, Clone)]
34pub struct FederatedSplitCoordinator {
35    /// Strategy for distributed splitting
36    strategy: FederatedSplitStrategy,
37}
38
/// How a federated split is planned across nodes.
#[derive(Clone, Debug, PartialEq)]
pub enum FederatedSplitStrategy {
    /// Every node splits locally and receives the same seed and train ratio
    /// (simple, no coordination).
    LocalWithSeed {
        /// Random seed handed to every node for reproducibility
        seed: u64,
        /// Fraction of rows assigned to training (0.0 to 1.0)
        train_ratio: f64,
    },

    /// Stratified across nodes - the coordinator sees only label distributions.
    GlobalStratified {
        /// Name of the column holding class labels
        label_column: String,
        /// Target distribution (coordinator computes from sketches).
        ///
        /// Note: the planner in this module does not consult this map yet.
        target_distribution: HashMap<String, f64>,
    },

    /// IID sampling - each node applies the same train ratio to its own data.
    ProportionalIID {
        /// Fraction of rows assigned to training
        train_ratio: f64,
    },
}
64
65/// Per-node split manifest (shared with coordinator, no raw data)
66#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
67pub struct NodeSplitManifest {
68    /// Unique node identifier
69    pub node_id: String,
70    /// Total rows in dataset
71    pub total_rows: u64,
72    /// Rows in training split
73    pub train_rows: u64,
74    /// Rows in test split
75    pub test_rows: u64,
76    /// Rows in validation split (optional)
77    pub validation_rows: Option<u64>,
78    /// Label distribution (for stratification verification)
79    pub label_distribution: Option<HashMap<String, u64>>,
80    /// Hash of split indices (for reproducibility verification)
81    pub split_hash: [u8; 32],
82}
83
84/// Instructions for a node to execute its split
85#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
86pub struct NodeSplitInstruction {
87    /// Node this instruction is for
88    pub node_id: String,
89    /// Random seed to use
90    pub seed: u64,
91    /// Training ratio
92    pub train_ratio: f64,
93    /// Test ratio
94    pub test_ratio: f64,
95    /// Validation ratio (optional)
96    pub validation_ratio: Option<f64>,
97    /// Column to stratify by (optional)
98    pub stratify_column: Option<String>,
99}
100
101/// Report on global split quality across all nodes
102#[derive(Debug, Clone)]
103pub struct GlobalSplitReport {
104    /// Total rows across all nodes
105    pub total_rows: u64,
106    /// Total training rows
107    pub total_train_rows: u64,
108    /// Total test rows
109    pub total_test_rows: u64,
110    /// Total validation rows
111    pub total_validation_rows: Option<u64>,
112    /// Effective train/test/val ratios
113    pub effective_train_ratio: f64,
114    /// Effective test ratio
115    pub effective_test_ratio: f64,
116    /// Effective validation ratio
117    pub effective_validation_ratio: Option<f64>,
118    /// Per-node summaries
119    pub node_summaries: Vec<NodeSummary>,
120    /// Global label distribution (if stratified)
121    pub global_label_distribution: Option<HashMap<String, u64>>,
122    /// Whether split quality meets requirements
123    pub quality_passed: bool,
124    /// Quality issues found
125    pub issues: Vec<SplitQualityIssue>,
126}
127
/// Summary for a single node
///
/// Derives `PartialEq` like the other value types in this module so that
/// summaries can be compared directly (for example in assertions).
#[derive(Debug, Clone, PartialEq)]
pub struct NodeSummary {
    /// Node ID
    pub node_id: String,
    /// Contribution ratio (this node's rows / total rows across all nodes)
    pub contribution_ratio: f64,
    /// Train ratio for this node (train rows / this node's total rows)
    pub train_ratio: f64,
    /// Test ratio for this node (test rows / this node's total rows)
    pub test_ratio: f64,
}
140
/// Problems that verification of a federated split can report.
#[derive(Clone, Debug, PartialEq)]
pub enum SplitQualityIssue {
    /// A node's train ratio is too far from the expected value
    RatioDeviation {
        /// Node the deviation was observed on
        node_id: String,
        /// Ratio the node was compared against
        expected: f64,
        /// Ratio the node actually has
        actual: f64,
    },
    /// A label is distributed unevenly across nodes
    DistributionImbalance {
        /// The label that is imbalanced
        label: String,
        /// Nodes whose distribution differs significantly
        nodes: Vec<String>,
    },
    /// A node does not have enough samples in one of its splits
    InsufficientSamples {
        /// Node with too few samples
        node_id: String,
        /// Sample count that was found
        samples: u64,
        /// Smallest acceptable sample count
        minimum: u64,
    },
    /// A node's split hash differs from the expected one (reproducibility issue)
    HashMismatch {
        /// Node whose hash did not match
        node_id: String,
    },
}
175
176impl FederatedSplitCoordinator {
177    /// Create a new coordinator with the given strategy
178    #[must_use]
179    pub fn new(strategy: FederatedSplitStrategy) -> Self {
180        Self { strategy }
181    }
182
183    /// Get the current strategy
184    #[must_use]
185    pub fn strategy(&self) -> &FederatedSplitStrategy {
186        &self.strategy
187    }
188
189    /// Compute split instructions for each node (runs on coordinator)
190    ///
191    /// The coordinator only sees manifests (metadata), never raw data.
192    pub fn compute_split_plan(
193        &self,
194        manifests: &[NodeSplitManifest],
195    ) -> Result<Vec<NodeSplitInstruction>> {
196        if manifests.is_empty() {
197            return Err(Error::invalid_config(
198                "Cannot compute plan for empty manifest list",
199            ));
200        }
201
202        match &self.strategy {
203            FederatedSplitStrategy::LocalWithSeed { seed, train_ratio } => Ok(
204                Self::compute_local_seed_plan(manifests, *seed, *train_ratio),
205            ),
206            FederatedSplitStrategy::GlobalStratified {
207                label_column,
208                target_distribution,
209            } => Ok(Self::compute_stratified_plan(
210                manifests,
211                label_column,
212                target_distribution,
213            )),
214            FederatedSplitStrategy::ProportionalIID { train_ratio } => {
215                Ok(Self::compute_proportional_plan(manifests, *train_ratio))
216            }
217        }
218    }
219
220    /// Compute plan for LocalWithSeed strategy
221    fn compute_local_seed_plan(
222        manifests: &[NodeSplitManifest],
223        seed: u64,
224        train_ratio: f64,
225    ) -> Vec<NodeSplitInstruction> {
226        let test_ratio = 1.0 - train_ratio;
227
228        manifests
229            .iter()
230            .map(|m| NodeSplitInstruction {
231                node_id: m.node_id.clone(),
232                seed,
233                train_ratio,
234                test_ratio,
235                validation_ratio: None,
236                stratify_column: None,
237            })
238            .collect()
239    }
240
241    /// Compute plan for GlobalStratified strategy
242    fn compute_stratified_plan(
243        manifests: &[NodeSplitManifest],
244        label_column: &str,
245        _target_distribution: &HashMap<String, f64>,
246    ) -> Vec<NodeSplitInstruction> {
247        // For stratified splits, each node uses a deterministic seed based on node_id
248        // and stratifies by the label column
249        let base_seed = 42u64; // Fixed base for reproducibility
250
251        // Default to 80/20 split if not specified in target distribution
252        // Future: use target_distribution to adjust per-node ratios
253        let train_ratio = 0.8;
254        let test_ratio = 0.2;
255
256        manifests
257            .iter()
258            .enumerate()
259            .map(|(i, m)| {
260                // Derive node-specific seed from base + index
261                let node_seed = base_seed.wrapping_add(i as u64);
262
263                NodeSplitInstruction {
264                    node_id: m.node_id.clone(),
265                    seed: node_seed,
266                    train_ratio,
267                    test_ratio,
268                    validation_ratio: None,
269                    stratify_column: Some(label_column.to_string()),
270                }
271            })
272            .collect()
273    }
274
275    /// Compute plan for ProportionalIID strategy
276    fn compute_proportional_plan(
277        manifests: &[NodeSplitManifest],
278        train_ratio: f64,
279    ) -> Vec<NodeSplitInstruction> {
280        let test_ratio = 1.0 - train_ratio;
281
282        // Each node gets a unique seed based on position
283        manifests
284            .iter()
285            .enumerate()
286            .map(|(i, m)| NodeSplitInstruction {
287                node_id: m.node_id.clone(),
288                seed: i as u64,
289                train_ratio,
290                test_ratio,
291                validation_ratio: None,
292                stratify_column: None,
293            })
294            .collect()
295    }
296
297    /// Execute split locally (runs on each node)
298    ///
299    /// This function runs on the data-owning node - raw data never leaves.
300    pub fn execute_local_split(
301        dataset: &ArrowDataset,
302        instruction: &NodeSplitInstruction,
303    ) -> Result<DatasetSplit> {
304        let val_ratio = instruction.validation_ratio;
305
306        if let Some(ref column) = instruction.stratify_column {
307            DatasetSplit::stratified(
308                dataset,
309                column,
310                instruction.train_ratio,
311                instruction.test_ratio,
312                val_ratio,
313                Some(instruction.seed),
314            )
315        } else {
316            DatasetSplit::from_ratios(
317                dataset,
318                instruction.train_ratio,
319                instruction.test_ratio,
320                val_ratio,
321                Some(instruction.seed),
322            )
323        }
324    }
325
326    /// Verify global split quality (runs on coordinator)
327    ///
328    /// Only examines manifests - no access to raw data.
329    pub fn verify_global_split(manifests: &[NodeSplitManifest]) -> Result<GlobalSplitReport> {
330        if manifests.is_empty() {
331            return Err(Error::invalid_config("Cannot verify empty manifest list"));
332        }
333
334        let total_rows: u64 = manifests.iter().map(|m| m.total_rows).sum();
335        let total_train_rows: u64 = manifests.iter().map(|m| m.train_rows).sum();
336        let total_test_rows: u64 = manifests.iter().map(|m| m.test_rows).sum();
337        let total_validation_rows: Option<u64> =
338            if manifests.iter().any(|m| m.validation_rows.is_some()) {
339                Some(manifests.iter().filter_map(|m| m.validation_rows).sum())
340            } else {
341                None
342            };
343
344        let effective_train_ratio = if total_rows > 0 {
345            total_train_rows as f64 / total_rows as f64
346        } else {
347            0.0
348        };
349
350        let effective_test_ratio = if total_rows > 0 {
351            total_test_rows as f64 / total_rows as f64
352        } else {
353            0.0
354        };
355
356        let effective_validation_ratio = total_validation_rows.map(|v| {
357            if total_rows > 0 {
358                v as f64 / total_rows as f64
359            } else {
360                0.0
361            }
362        });
363
364        // Build node summaries
365        let node_summaries: Vec<NodeSummary> = manifests
366            .iter()
367            .map(|m| {
368                let contribution_ratio = if total_rows > 0 {
369                    m.total_rows as f64 / total_rows as f64
370                } else {
371                    0.0
372                };
373
374                let train_ratio = if m.total_rows > 0 {
375                    m.train_rows as f64 / m.total_rows as f64
376                } else {
377                    0.0
378                };
379
380                let test_ratio = if m.total_rows > 0 {
381                    m.test_rows as f64 / m.total_rows as f64
382                } else {
383                    0.0
384                };
385
386                NodeSummary {
387                    node_id: m.node_id.clone(),
388                    contribution_ratio,
389                    train_ratio,
390                    test_ratio,
391                }
392            })
393            .collect();
394
395        // Aggregate global label distribution
396        let global_label_distribution = Self::aggregate_label_distributions(manifests);
397
398        // Check for quality issues
399        let issues = Self::detect_quality_issues(manifests, &node_summaries);
400
401        let quality_passed = issues.is_empty();
402
403        Ok(GlobalSplitReport {
404            total_rows,
405            total_train_rows,
406            total_test_rows,
407            total_validation_rows,
408            effective_train_ratio,
409            effective_test_ratio,
410            effective_validation_ratio,
411            node_summaries,
412            global_label_distribution,
413            quality_passed,
414            issues,
415        })
416    }
417
418    /// Aggregate label distributions from all nodes
419    fn aggregate_label_distributions(
420        manifests: &[NodeSplitManifest],
421    ) -> Option<HashMap<String, u64>> {
422        let mut global_dist: HashMap<String, u64> = HashMap::new();
423        let mut any_has_distribution = false;
424
425        for manifest in manifests {
426            if let Some(ref dist) = manifest.label_distribution {
427                any_has_distribution = true;
428                for (label, count) in dist {
429                    *global_dist.entry(label.clone()).or_insert(0) += count;
430                }
431            }
432        }
433
434        if any_has_distribution {
435            Some(global_dist)
436        } else {
437            None
438        }
439    }
440
441    /// Detect quality issues in the split
442    fn detect_quality_issues(
443        manifests: &[NodeSplitManifest],
444        summaries: &[NodeSummary],
445    ) -> Vec<SplitQualityIssue> {
446        // Minimum samples threshold for valid split
447        const MIN_SAMPLES: u64 = 10;
448
449        let mut issues = Vec::new();
450
451        // Check for insufficient samples
452        for manifest in manifests {
453            if manifest.train_rows < MIN_SAMPLES || manifest.test_rows < MIN_SAMPLES {
454                issues.push(SplitQualityIssue::InsufficientSamples {
455                    node_id: manifest.node_id.clone(),
456                    samples: manifest.train_rows.min(manifest.test_rows),
457                    minimum: MIN_SAMPLES,
458                });
459            }
460        }
461
462        // Check for ratio deviation (more than 10% from mean)
463        if !summaries.is_empty() {
464            let mean_train_ratio: f64 =
465                summaries.iter().map(|s| s.train_ratio).sum::<f64>() / summaries.len() as f64;
466
467            for summary in summaries {
468                let deviation = (summary.train_ratio - mean_train_ratio).abs();
469                if deviation > 0.1 {
470                    issues.push(SplitQualityIssue::RatioDeviation {
471                        node_id: summary.node_id.clone(),
472                        expected: mean_train_ratio,
473                        actual: summary.train_ratio,
474                    });
475                }
476            }
477        }
478
479        issues
480    }
481}
482
483impl NodeSplitManifest {
484    /// Create a new manifest from split results
485    #[must_use]
486    pub fn new(
487        node_id: impl Into<String>,
488        total_rows: u64,
489        train_rows: u64,
490        test_rows: u64,
491    ) -> Self {
492        Self {
493            node_id: node_id.into(),
494            total_rows,
495            train_rows,
496            test_rows,
497            validation_rows: None,
498            label_distribution: None,
499            split_hash: [0u8; 32],
500        }
501    }
502
503    /// Set validation rows
504    #[must_use]
505    pub fn with_validation(mut self, rows: u64) -> Self {
506        self.validation_rows = Some(rows);
507        self
508    }
509
510    /// Set label distribution
511    #[must_use]
512    pub fn with_label_distribution(mut self, distribution: HashMap<String, u64>) -> Self {
513        self.label_distribution = Some(distribution);
514        self
515    }
516
517    /// Set split hash
518    #[must_use]
519    pub fn with_split_hash(mut self, hash: [u8; 32]) -> Self {
520        self.split_hash = hash;
521        self
522    }
523
524    /// Create manifest from a dataset split
525    #[must_use]
526    pub fn from_split(node_id: impl Into<String>, split: &DatasetSplit) -> Self {
527        let train_rows = split.train.len() as u64;
528        let test_rows = split.test.len() as u64;
529        let validation_rows = split.validation.as_ref().map(|v| v.len() as u64);
530
531        let mut manifest = Self::new(
532            node_id,
533            train_rows + test_rows + validation_rows.unwrap_or(0),
534            train_rows,
535            test_rows,
536        );
537
538        if let Some(v) = validation_rows {
539            manifest = manifest.with_validation(v);
540        }
541
542        manifest
543    }
544
545    /// Serialize to JSON bytes
546    pub fn to_json(&self) -> Result<Vec<u8>> {
547        serde_json::to_vec(self).map_err(|e| Error::Format(e.to_string()))
548    }
549
550    /// Deserialize from JSON bytes
551    pub fn from_json(data: &[u8]) -> Result<Self> {
552        serde_json::from_slice(data).map_err(|e| Error::Format(e.to_string()))
553    }
554}
555
556impl NodeSplitInstruction {
557    /// Serialize to JSON bytes
558    pub fn to_json(&self) -> Result<Vec<u8>> {
559        serde_json::to_vec(self).map_err(|e| Error::Format(e.to_string()))
560    }
561
562    /// Deserialize from JSON bytes
563    pub fn from_json(data: &[u8]) -> Result<Self> {
564        serde_json::from_slice(data).map_err(|e| Error::Format(e.to_string()))
565    }
566}
567
568#[cfg(test)]
569mod tests {
570    use super::*;
571
572    // ============================================================
573    // FederatedSplitStrategy tests
574    // ============================================================
575
576    #[test]
577    fn test_strategy_local_with_seed() {
578        let strategy = FederatedSplitStrategy::LocalWithSeed {
579            seed: 42,
580            train_ratio: 0.8,
581        };
582
583        match strategy {
584            FederatedSplitStrategy::LocalWithSeed { seed, train_ratio } => {
585                assert_eq!(seed, 42);
586                assert!((train_ratio - 0.8).abs() < f64::EPSILON);
587            }
588            _ => panic!("Wrong variant"),
589        }
590    }
591
592    #[test]
593    fn test_strategy_global_stratified() {
594        let mut target = HashMap::new();
595        target.insert("class_a".to_string(), 0.5);
596        target.insert("class_b".to_string(), 0.5);
597
598        let strategy = FederatedSplitStrategy::GlobalStratified {
599            label_column: "label".to_string(),
600            target_distribution: target.clone(),
601        };
602
603        match strategy {
604            FederatedSplitStrategy::GlobalStratified {
605                label_column,
606                target_distribution,
607            } => {
608                assert_eq!(label_column, "label");
609                assert_eq!(target_distribution, target);
610            }
611            _ => panic!("Wrong variant"),
612        }
613    }
614
615    #[test]
616    fn test_strategy_proportional_iid() {
617        let strategy = FederatedSplitStrategy::ProportionalIID { train_ratio: 0.7 };
618
619        match strategy {
620            FederatedSplitStrategy::ProportionalIID { train_ratio } => {
621                assert!((train_ratio - 0.7).abs() < f64::EPSILON);
622            }
623            _ => panic!("Wrong variant"),
624        }
625    }
626
627    #[test]
628    fn test_strategy_clone_and_debug() {
629        let strategy = FederatedSplitStrategy::LocalWithSeed {
630            seed: 42,
631            train_ratio: 0.8,
632        };
633
634        let cloned = strategy.clone();
635        assert_eq!(strategy, cloned);
636
637        let debug = format!("{:?}", strategy);
638        assert!(debug.contains("LocalWithSeed"));
639        assert!(debug.contains("42"));
640    }
641
642    // ============================================================
643    // NodeSplitManifest tests
644    // ============================================================
645
646    #[test]
647    fn test_manifest_new() {
648        let manifest = NodeSplitManifest::new("node_a", 1000, 800, 200);
649
650        assert_eq!(manifest.node_id, "node_a");
651        assert_eq!(manifest.total_rows, 1000);
652        assert_eq!(manifest.train_rows, 800);
653        assert_eq!(manifest.test_rows, 200);
654        assert!(manifest.validation_rows.is_none());
655        assert!(manifest.label_distribution.is_none());
656    }
657
658    #[test]
659    fn test_manifest_with_validation() {
660        let manifest = NodeSplitManifest::new("node_a", 1000, 700, 200).with_validation(100);
661
662        assert_eq!(manifest.validation_rows, Some(100));
663    }
664
665    #[test]
666    fn test_manifest_with_label_distribution() {
667        let mut dist = HashMap::new();
668        dist.insert("cat".to_string(), 500);
669        dist.insert("dog".to_string(), 500);
670
671        let manifest =
672            NodeSplitManifest::new("node_a", 1000, 800, 200).with_label_distribution(dist.clone());
673
674        assert_eq!(manifest.label_distribution, Some(dist));
675    }
676
677    #[test]
678    fn test_manifest_with_split_hash() {
679        let hash = [1u8; 32];
680        let manifest = NodeSplitManifest::new("node_a", 1000, 800, 200).with_split_hash(hash);
681
682        assert_eq!(manifest.split_hash, hash);
683    }
684
685    #[test]
686    fn test_manifest_serialization() {
687        let manifest = NodeSplitManifest::new("node_a", 1000, 800, 200);
688
689        let json = manifest.to_json().expect("serialization failed");
690        let parsed = NodeSplitManifest::from_json(&json).expect("deserialization failed");
691
692        assert_eq!(manifest, parsed);
693    }
694
695    #[test]
696    fn test_manifest_full_serialization() {
697        let mut dist = HashMap::new();
698        dist.insert("a".to_string(), 400);
699        dist.insert("b".to_string(), 600);
700
701        let manifest = NodeSplitManifest::new("node_eu", 1000, 700, 200)
702            .with_validation(100)
703            .with_label_distribution(dist)
704            .with_split_hash([42u8; 32]);
705
706        let json = manifest.to_json().expect("serialization failed");
707        let parsed = NodeSplitManifest::from_json(&json).expect("deserialization failed");
708
709        assert_eq!(manifest, parsed);
710    }
711
712    // ============================================================
713    // NodeSplitInstruction tests
714    // ============================================================
715
716    #[test]
717    fn test_instruction_serialization() {
718        let instruction = NodeSplitInstruction {
719            node_id: "node_a".to_string(),
720            seed: 42,
721            train_ratio: 0.8,
722            test_ratio: 0.2,
723            validation_ratio: None,
724            stratify_column: None,
725        };
726
727        let json = instruction.to_json().expect("serialization failed");
728        let parsed = NodeSplitInstruction::from_json(&json).expect("deserialization failed");
729
730        assert_eq!(instruction, parsed);
731    }
732
733    #[test]
734    fn test_instruction_with_stratification() {
735        let instruction = NodeSplitInstruction {
736            node_id: "node_b".to_string(),
737            seed: 123,
738            train_ratio: 0.7,
739            test_ratio: 0.15,
740            validation_ratio: Some(0.15),
741            stratify_column: Some("label".to_string()),
742        };
743
744        let json = instruction.to_json().expect("serialization failed");
745        let parsed = NodeSplitInstruction::from_json(&json).expect("deserialization failed");
746
747        assert_eq!(instruction, parsed);
748    }
749
750    // ============================================================
751    // FederatedSplitCoordinator tests
752    // ============================================================
753
754    #[test]
755    fn test_coordinator_new() {
756        let strategy = FederatedSplitStrategy::LocalWithSeed {
757            seed: 42,
758            train_ratio: 0.8,
759        };
760        let coordinator = FederatedSplitCoordinator::new(strategy.clone());
761
762        assert_eq!(coordinator.strategy(), &strategy);
763    }
764
765    #[test]
766    fn test_coordinator_empty_manifests_error() {
767        let coordinator = FederatedSplitCoordinator::new(FederatedSplitStrategy::LocalWithSeed {
768            seed: 42,
769            train_ratio: 0.8,
770        });
771
772        let result = coordinator.compute_split_plan(&[]);
773        assert!(result.is_err());
774    }
775
776    #[test]
777    fn test_coordinator_local_seed_plan() {
778        let coordinator = FederatedSplitCoordinator::new(FederatedSplitStrategy::LocalWithSeed {
779            seed: 42,
780            train_ratio: 0.8,
781        });
782
783        let manifests = vec![
784            NodeSplitManifest::new("node_a", 1000, 800, 200),
785            NodeSplitManifest::new("node_b", 2000, 1600, 400),
786        ];
787
788        let plan = coordinator
789            .compute_split_plan(&manifests)
790            .expect("plan failed");
791
792        assert_eq!(plan.len(), 2);
793
794        // All nodes get same seed
795        assert_eq!(plan[0].seed, 42);
796        assert_eq!(plan[1].seed, 42);
797
798        // All nodes get same ratios
799        assert!((plan[0].train_ratio - 0.8).abs() < f64::EPSILON);
800        assert!((plan[1].train_ratio - 0.8).abs() < f64::EPSILON);
801    }
802
803    #[test]
804    fn test_coordinator_stratified_plan() {
805        let mut target = HashMap::new();
806        target.insert("a".to_string(), 0.5);
807        target.insert("b".to_string(), 0.5);
808
809        let coordinator =
810            FederatedSplitCoordinator::new(FederatedSplitStrategy::GlobalStratified {
811                label_column: "label".to_string(),
812                target_distribution: target,
813            });
814
815        let manifests = vec![
816            NodeSplitManifest::new("node_a", 1000, 800, 200),
817            NodeSplitManifest::new("node_b", 2000, 1600, 400),
818        ];
819
820        let plan = coordinator
821            .compute_split_plan(&manifests)
822            .expect("plan failed");
823
824        assert_eq!(plan.len(), 2);
825
826        // Each node has stratify column set
827        assert_eq!(plan[0].stratify_column, Some("label".to_string()));
828        assert_eq!(plan[1].stratify_column, Some("label".to_string()));
829
830        // Nodes have different seeds (derived from position)
831        assert_ne!(plan[0].seed, plan[1].seed);
832    }
833
834    #[test]
835    fn test_coordinator_proportional_plan() {
836        let coordinator = FederatedSplitCoordinator::new(FederatedSplitStrategy::ProportionalIID {
837            train_ratio: 0.7,
838        });
839
840        let manifests = vec![
841            NodeSplitManifest::new("node_a", 1000, 700, 300),
842            NodeSplitManifest::new("node_b", 2000, 1400, 600),
843            NodeSplitManifest::new("node_c", 500, 350, 150),
844        ];
845
846        let plan = coordinator
847            .compute_split_plan(&manifests)
848            .expect("plan failed");
849
850        assert_eq!(plan.len(), 3);
851
852        // All nodes get same ratio
853        for instruction in &plan {
854            assert!((instruction.train_ratio - 0.7).abs() < f64::EPSILON);
855            assert!((instruction.test_ratio - 0.3).abs() < f64::EPSILON);
856        }
857
858        // Each node has unique seed
859        assert_eq!(plan[0].seed, 0);
860        assert_eq!(plan[1].seed, 1);
861        assert_eq!(plan[2].seed, 2);
862    }
863
864    // ============================================================
865    // GlobalSplitReport tests
866    // ============================================================
867
868    #[test]
869    fn test_verify_global_split_empty_error() {
870        let result = FederatedSplitCoordinator::verify_global_split(&[]);
871        assert!(result.is_err());
872    }
873
874    #[test]
875    fn test_verify_global_split_single_node() {
876        let manifests = vec![NodeSplitManifest::new("node_a", 1000, 800, 200)];
877
878        let report =
879            FederatedSplitCoordinator::verify_global_split(&manifests).expect("verify failed");
880
881        assert_eq!(report.total_rows, 1000);
882        assert_eq!(report.total_train_rows, 800);
883        assert_eq!(report.total_test_rows, 200);
884        assert!((report.effective_train_ratio - 0.8).abs() < f64::EPSILON);
885        assert!((report.effective_test_ratio - 0.2).abs() < f64::EPSILON);
886        assert!(report.quality_passed);
887    }
888
889    #[test]
890    fn test_verify_global_split_multiple_nodes() {
891        let manifests = vec![
892            NodeSplitManifest::new("node_a", 1000, 800, 200),
893            NodeSplitManifest::new("node_b", 2000, 1600, 400),
894            NodeSplitManifest::new("node_c", 1000, 800, 200),
895        ];
896
897        let report =
898            FederatedSplitCoordinator::verify_global_split(&manifests).expect("verify failed");
899
900        assert_eq!(report.total_rows, 4000);
901        assert_eq!(report.total_train_rows, 3200);
902        assert_eq!(report.total_test_rows, 800);
903        assert!((report.effective_train_ratio - 0.8).abs() < f64::EPSILON);
904
905        assert_eq!(report.node_summaries.len(), 3);
906        assert!(report.quality_passed);
907    }
908
909    #[test]
910    fn test_verify_global_split_with_validation() {
911        let manifests = vec![
912            NodeSplitManifest::new("node_a", 1000, 700, 200).with_validation(100),
913            NodeSplitManifest::new("node_b", 2000, 1400, 400).with_validation(200),
914        ];
915
916        let report =
917            FederatedSplitCoordinator::verify_global_split(&manifests).expect("verify failed");
918
919        assert_eq!(report.total_validation_rows, Some(300));
920        assert!((report.effective_validation_ratio.unwrap() - 0.1).abs() < f64::EPSILON);
921    }
922
/// Per-node label counts supplied through `with_label_distribution` must be
/// added up label by label in the report's global distribution.
#[test]
fn test_verify_global_split_aggregates_labels() {
    // Label counts reported by each node.
    let node_a_labels: HashMap<String, _> = [("cat", 600), ("dog", 400)]
        .iter()
        .map(|&(label, count)| (label.to_string(), count))
        .collect();
    let node_b_labels: HashMap<String, _> = [("cat", 800), ("dog", 1200)]
        .iter()
        .map(|&(label, count)| (label.to_string(), count))
        .collect();

    let manifests = vec![
        NodeSplitManifest::new("node_a", 1000, 800, 200).with_label_distribution(node_a_labels),
        NodeSplitManifest::new("node_b", 2000, 1600, 400).with_label_distribution(node_b_labels),
    ];

    let outcome = FederatedSplitCoordinator::verify_global_split(&manifests)
        .expect("verify failed");
    let merged = outcome
        .global_label_distribution
        .expect("should have distribution");

    // cat: 600 + 800, dog: 400 + 1200
    assert_eq!(merged.get("cat").copied(), Some(1400));
    assert_eq!(merged.get("dog").copied(), Some(1600));
}
947
/// A node with very few samples must fail the quality check and produce an
/// `InsufficientSamples` issue that names that node.
#[test]
fn test_verify_detects_insufficient_samples() {
    let healthy = NodeSplitManifest::new("node_a", 1000, 800, 200);
    // Too few test samples
    let starved = NodeSplitManifest::new("node_b", 15, 10, 5);
    let manifests = vec![healthy, starved];

    let outcome = FederatedSplitCoordinator::verify_global_split(&manifests)
        .expect("verify failed");

    assert!(!outcome.quality_passed);
    assert!(!outcome.issues.is_empty());

    // At least one issue must be an InsufficientSamples entry for node_b.
    let flagged_node_b = outcome.issues.iter().any(|issue| match issue {
        SplitQualityIssue::InsufficientSamples { node_id, .. } => node_id == "node_b",
        _ => false,
    });
    assert!(flagged_node_b);
}
966
/// Mixing an 80/20 node with a 50/50 node must fail the quality check and
/// yield at least one `RatioDeviation` issue.
#[test]
fn test_verify_detects_ratio_deviation() {
    let conforming = NodeSplitManifest::new("node_a", 1000, 800, 200); // 80/20
    let skewed = NodeSplitManifest::new("node_b", 1000, 500, 500); // 50/50 - big deviation
    let manifests = vec![conforming, skewed];

    let outcome = FederatedSplitCoordinator::verify_global_split(&manifests)
        .expect("verify failed");

    assert!(!outcome.quality_passed);

    let issues = &outcome.issues;
    let saw_deviation = issues
        .iter()
        .any(|issue| matches!(issue, SplitQualityIssue::RatioDeviation { .. }));
    assert!(saw_deviation);
}
985
/// Each node summary's `contribution_ratio` must equal that node's share of
/// the combined row count (1000 and 3000 out of 4000).
#[test]
fn test_node_summary_contribution_ratio() {
    let manifests = vec![
        NodeSplitManifest::new("node_a", 1000, 800, 200),
        NodeSplitManifest::new("node_b", 3000, 2400, 600),
    ];

    let outcome = FederatedSplitCoordinator::verify_global_split(&manifests)
        .expect("verify failed");

    // (summary index, expected share):
    //   node_a has 1000/4000 = 0.25
    //   node_b has 3000/4000 = 0.75
    for &(index, expected_share) in &[(0_usize, 0.25_f64), (1, 0.75)] {
        let share = outcome.node_summaries[index].contribution_ratio;
        assert!((share - expected_share).abs() < f64::EPSILON);
    }
}
1002
/// Every `SplitQualityIssue` variant can be constructed, and its `Debug`
/// output contains the variant's name.
#[test]
fn test_split_quality_issue_variants() {
    // Each issue is paired with the name its Debug rendering must contain.
    let cases = vec![
        (
            SplitQualityIssue::RatioDeviation {
                node_id: "node_a".to_string(),
                expected: 0.8,
                actual: 0.5,
            },
            "RatioDeviation",
        ),
        (
            SplitQualityIssue::DistributionImbalance {
                label: "cat".to_string(),
                nodes: vec!["node_a".to_string(), "node_b".to_string()],
            },
            "DistributionImbalance",
        ),
        (
            SplitQualityIssue::InsufficientSamples {
                node_id: "node_a".to_string(),
                samples: 5,
                minimum: 10,
            },
            "InsufficientSamples",
        ),
        (
            SplitQualityIssue::HashMismatch {
                node_id: "node_a".to_string(),
            },
            "HashMismatch",
        ),
    ];

    for (issue, variant_name) in cases {
        let rendered = format!("{:?}", issue);
        assert!(rendered.contains(variant_name));
    }
}
1030
/// The `Debug` rendering of a report must mention the struct name and its
/// `total_rows` field.
#[test]
fn test_global_split_report_debug() {
    let manifests = vec![NodeSplitManifest::new("node_a", 100, 80, 20)];

    // Render the report as soon as verification succeeds.
    let rendered = FederatedSplitCoordinator::verify_global_split(&manifests)
        .map(|report| format!("{:?}", report))
        .expect("verify failed");

    assert!(rendered.contains("GlobalSplitReport"));
    assert!(rendered.contains("total_rows"));
}
1041}