Skip to main content

alimentar/cli/
fed.rs

1//! Federated split coordination CLI commands.
2
3use std::path::{Path, PathBuf};
4
5use clap::Subcommand;
6
7use super::basic::load_dataset;
8use crate::{
9    federated::{
10        FederatedSplitCoordinator, FederatedSplitStrategy, NodeSplitInstruction, NodeSplitManifest,
11    },
12    split::DatasetSplit,
13    Dataset,
14};
15
16/// Federated split coordination commands.
17#[derive(Subcommand)]
18pub enum FedCommands {
19    /// Generate a manifest from local dataset (runs on each node)
20    Manifest {
21        /// Input dataset file
22        input: PathBuf,
23        /// Output manifest file
24        #[arg(short, long)]
25        output: PathBuf,
26        /// Unique node identifier
27        #[arg(short, long)]
28        node_id: String,
29        /// Training set ratio
30        #[arg(short = 'r', long, default_value = "0.8")]
31        train_ratio: f64,
32        /// Random seed for reproducibility
33        #[arg(short, long, default_value = "42")]
34        seed: u64,
35        /// Output format (json, binary)
36        #[arg(short, long, default_value = "json")]
37        format: String,
38    },
39    /// Create a split plan from manifests (runs on coordinator)
40    Plan {
41        /// Manifest files from all nodes
42        #[arg(required = true)]
43        manifests: Vec<PathBuf>,
44        /// Output plan file
45        #[arg(short, long)]
46        output: PathBuf,
47        /// Split strategy (local, proportional, stratified)
48        #[arg(short, long, default_value = "local")]
49        strategy: String,
50        /// Training set ratio
51        #[arg(short = 'r', long, default_value = "0.8")]
52        train_ratio: f64,
53        /// Random seed
54        #[arg(long, default_value = "42")]
55        seed: u64,
56        /// Column for stratification (required for stratified strategy)
57        #[arg(long)]
58        stratify_column: Option<String>,
59        /// Output format (json, binary)
60        #[arg(short, long, default_value = "json")]
61        format: String,
62    },
63    /// Execute local split based on plan (runs on each node)
64    Split {
65        /// Input dataset file
66        input: PathBuf,
67        /// Split plan file
68        #[arg(short, long)]
69        plan: PathBuf,
70        /// This node's ID
71        #[arg(short, long)]
72        node_id: String,
73        /// Output training set file
74        #[arg(long)]
75        train_output: PathBuf,
76        /// Output test set file
77        #[arg(long)]
78        test_output: PathBuf,
79        /// Output validation set file (optional)
80        #[arg(long)]
81        validation_output: Option<PathBuf>,
82    },
83    /// Verify global split quality from manifests (runs on coordinator)
84    Verify {
85        /// Manifest files from all nodes
86        #[arg(required = true)]
87        manifests: Vec<PathBuf>,
88        /// Output format (text, json)
89        #[arg(short, long, default_value = "text")]
90        format: String,
91    },
92}
93
94/// Parse federated split strategy from string.
95pub(crate) fn parse_fed_strategy(
96    strategy: &str,
97    train_ratio: f64,
98    seed: u64,
99    stratify_column: Option<&str>,
100) -> Option<FederatedSplitStrategy> {
101    match strategy.to_lowercase().as_str() {
102        "local" | "local-seed" => Some(FederatedSplitStrategy::LocalWithSeed { seed, train_ratio }),
103        "proportional" | "iid" => Some(FederatedSplitStrategy::ProportionalIID { train_ratio }),
104        "stratified" => {
105            let column = stratify_column.unwrap_or("label").to_string();
106            Some(FederatedSplitStrategy::GlobalStratified {
107                label_column: column,
108                target_distribution: std::collections::HashMap::new(),
109            })
110        }
111        _ => None,
112    }
113}
114
115/// Generate a manifest from local dataset.
116pub(crate) fn cmd_fed_manifest(
117    input: &Path,
118    output: &Path,
119    node_id: &str,
120    train_ratio: f64,
121    seed: u64,
122    format: &str,
123) -> crate::Result<()> {
124    let dataset = load_dataset(input)?;
125
126    // Create a split to generate the manifest
127    let split =
128        DatasetSplit::from_ratios(&dataset, train_ratio, 1.0 - train_ratio, None, Some(seed))?;
129
130    let manifest = NodeSplitManifest::from_split(node_id, &split);
131
132    match format {
133        "binary" | "bin" => {
134            let bytes =
135                rmp_serde::to_vec(&manifest).map_err(|e| crate::Error::Format(e.to_string()))?;
136            std::fs::write(output, bytes).map_err(|e| crate::Error::io(e, output))?;
137        }
138        _ => {
139            let json = serde_json::to_string_pretty(&manifest)
140                .map_err(|e| crate::Error::Format(e.to_string()))?;
141            std::fs::write(output, json).map_err(|e| crate::Error::io(e, output))?;
142        }
143    }
144
145    println!(
146        "Created manifest for node '{}' ({} rows) -> {}",
147        node_id,
148        dataset.len(),
149        output.display()
150    );
151
152    Ok(())
153}
154
155/// Load a manifest from a file.
156pub(crate) fn load_manifest(path: &Path) -> crate::Result<NodeSplitManifest> {
157    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
158
159    match ext {
160        "bin" | "binary" => {
161            let bytes = std::fs::read(path).map_err(|e| crate::Error::io(e, path))?;
162            rmp_serde::from_slice(&bytes)
163                .map_err(|e| crate::Error::Format(format!("Invalid manifest binary: {}", e)))
164        }
165        _ => {
166            let json = std::fs::read_to_string(path).map_err(|e| crate::Error::io(e, path))?;
167            serde_json::from_str(&json)
168                .map_err(|e| crate::Error::Format(format!("Invalid manifest JSON: {}", e)))
169        }
170    }
171}
172
173/// Create a split plan from manifests.
174#[allow(clippy::too_many_arguments)]
175pub(crate) fn cmd_fed_plan(
176    manifests: &[PathBuf],
177    output: &Path,
178    strategy: &str,
179    train_ratio: f64,
180    seed: u64,
181    stratify_column: Option<&str>,
182    format: &str,
183) -> crate::Result<()> {
184    if manifests.is_empty() {
185        return Err(crate::Error::invalid_config("No manifests provided"));
186    }
187
188    let loaded: Vec<NodeSplitManifest> = manifests
189        .iter()
190        .map(|p| load_manifest(p))
191        .collect::<Result<Vec<_>, _>>()?;
192
193    let strategy =
194        parse_fed_strategy(strategy, train_ratio, seed, stratify_column).ok_or_else(|| {
195            crate::Error::invalid_config(format!(
196                "Unknown strategy: {}. Use 'local', 'proportional', or 'stratified'",
197                strategy
198            ))
199        })?;
200
201    let coordinator = FederatedSplitCoordinator::new(strategy);
202    let instructions = coordinator.compute_split_plan(&loaded)?;
203
204    match format {
205        "binary" | "bin" => {
206            let bytes = rmp_serde::to_vec(&instructions)
207                .map_err(|e| crate::Error::Format(e.to_string()))?;
208            std::fs::write(output, bytes).map_err(|e| crate::Error::io(e, output))?;
209        }
210        _ => {
211            let json = serde_json::to_string_pretty(&instructions)
212                .map_err(|e| crate::Error::Format(e.to_string()))?;
213            std::fs::write(output, json).map_err(|e| crate::Error::io(e, output))?;
214        }
215    }
216
217    println!(
218        "Created split plan for {} nodes -> {}",
219        instructions.len(),
220        output.display()
221    );
222
223    Ok(())
224}
225
226/// Load a split plan from a file.
227pub(crate) fn load_plan(path: &Path) -> crate::Result<Vec<NodeSplitInstruction>> {
228    let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
229
230    match ext {
231        "bin" | "binary" => {
232            let bytes = std::fs::read(path).map_err(|e| crate::Error::io(e, path))?;
233            rmp_serde::from_slice(&bytes)
234                .map_err(|e| crate::Error::Format(format!("Invalid plan binary: {}", e)))
235        }
236        _ => {
237            let json = std::fs::read_to_string(path).map_err(|e| crate::Error::io(e, path))?;
238            serde_json::from_str(&json)
239                .map_err(|e| crate::Error::Format(format!("Invalid plan JSON: {}", e)))
240        }
241    }
242}
243
244/// Execute local split based on plan.
245pub(crate) fn cmd_fed_split(
246    input: &Path,
247    plan: &Path,
248    node_id: &str,
249    train_output: &Path,
250    test_output: &Path,
251    validation_output: Option<&PathBuf>,
252) -> crate::Result<()> {
253    let dataset = load_dataset(input)?;
254    let instructions = load_plan(plan)?;
255
256    // Find instruction for this node
257    let instruction = instructions
258        .iter()
259        .find(|i| i.node_id == node_id)
260        .ok_or_else(|| {
261            crate::Error::invalid_config(format!(
262                "No instruction found for node '{}' in plan",
263                node_id
264            ))
265        })?;
266
267    // Execute the split
268    let split = FederatedSplitCoordinator::execute_local_split(&dataset, instruction)?;
269
270    // Save outputs
271    split.train.to_parquet(train_output)?;
272    split.test.to_parquet(test_output)?;
273
274    if let (Some(val_output), Some(val_data)) = (validation_output, &split.validation) {
275        val_data.to_parquet(val_output)?;
276    }
277
278    println!(
279        "Split executed for node '{}': {} train, {} test{}",
280        node_id,
281        split.train.len(),
282        split.test.len(),
283        split
284            .validation
285            .as_ref()
286            .map_or(String::new(), |v| format!(", {} validation", v.len()))
287    );
288
289    Ok(())
290}
291
292/// Verify global split quality from manifests.
293pub(crate) fn cmd_fed_verify(manifests: &[PathBuf], format: &str) -> crate::Result<()> {
294    if manifests.is_empty() {
295        return Err(crate::Error::invalid_config("No manifests provided"));
296    }
297
298    let loaded: Vec<NodeSplitManifest> = manifests
299        .iter()
300        .map(|p| load_manifest(p))
301        .collect::<Result<Vec<_>, _>>()?;
302
303    let report = FederatedSplitCoordinator::verify_global_split(&loaded)?;
304
305    if format == "json" {
306        let json = serde_json::json!({
307            "total_rows": report.total_rows,
308            "total_train_rows": report.total_train_rows,
309            "total_test_rows": report.total_test_rows,
310            "total_validation_rows": report.total_validation_rows,
311            "effective_train_ratio": report.effective_train_ratio,
312            "effective_test_ratio": report.effective_test_ratio,
313            "effective_validation_ratio": report.effective_validation_ratio,
314            "quality_passed": report.quality_passed,
315            "issues": report.issues.iter().map(|i| format!("{:?}", i)).collect::<Vec<_>>(),
316            "node_summaries": report.node_summaries.iter().map(|n| {
317                serde_json::json!({
318                    "node_id": n.node_id,
319                    "contribution_ratio": n.contribution_ratio,
320                    "train_ratio": n.train_ratio,
321                    "test_ratio": n.test_ratio,
322                })
323            }).collect::<Vec<_>>()
324        });
325
326        let json_str =
327            serde_json::to_string_pretty(&json).map_err(|e| crate::Error::Format(e.to_string()))?;
328        println!("{}", json_str);
329    } else {
330        // Text format
331        println!("Federated Split Verification");
332        println!("============================");
333        println!();
334        println!("Global Statistics:");
335        println!("  Total rows:        {}", report.total_rows);
336        println!(
337            "  Train rows:        {} ({:.1}%)",
338            report.total_train_rows,
339            report.effective_train_ratio * 100.0
340        );
341        println!(
342            "  Test rows:         {} ({:.1}%)",
343            report.total_test_rows,
344            report.effective_test_ratio * 100.0
345        );
346        if let Some(val) = report.total_validation_rows {
347            println!(
348                "  Validation rows:   {} ({:.1}%)",
349                val,
350                report.effective_validation_ratio.unwrap_or(0.0) * 100.0
351            );
352        }
353        println!();
354
355        println!("Node Summaries:");
356        println!(
357            "{:<15} {:>12} {:>10} {:>10}",
358            "NODE", "CONTRIBUTION", "TRAIN", "TEST"
359        );
360        println!("{}", "-".repeat(50));
361
362        for summary in &report.node_summaries {
363            println!(
364                "{:<15} {:>11.1}% {:>9.1}% {:>9.1}%",
365                summary.node_id,
366                summary.contribution_ratio * 100.0,
367                summary.train_ratio * 100.0,
368                summary.test_ratio * 100.0
369            );
370        }
371
372        println!();
373        if report.quality_passed {
374            println!("\u{2713} Quality check passed");
375        } else {
376            println!("\u{26A0} Quality issues detected:");
377            for issue in &report.issues {
378                println!("  - {:?}", issue);
379            }
380        }
381    }
382
383    Ok(())
384}
385
#[cfg(test)]
// Test code favours brevity; these lints are relaxed for this module only.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::redundant_clone,
    clippy::cast_lossless,
    clippy::redundant_closure_for_method_calls,
    clippy::too_many_lines,
    clippy::float_cmp,
    clippy::similar_names,
    clippy::needless_late_init,
    clippy::redundant_pattern_matching
)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Create a temporary directory for one test.
    fn scratch_dir() -> tempfile::TempDir {
        tempfile::tempdir().expect("Should create temp dir")
    }

    /// Write a Parquet file with `rows` rows and two columns: `id` (Int32,
    /// `0..rows`) and `name` (Utf8, `item_<id>`).
    fn create_test_parquet(path: &Path, rows: usize) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = arrow::array::RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int32Array::from(ids)),
                Arc::new(StringArray::from(names)),
            ],
        )
        .expect("Should create batch");

        let dataset = ArrowDataset::from_batch(batch).expect("Should create dataset");

        dataset.to_parquet(path).expect("Should write parquet");
    }

    /// Read `path` and parse it as a JSON value.
    fn read_json(path: &Path) -> serde_json::Value {
        let content = std::fs::read_to_string(path).expect("Should read file");
        serde_json::from_str(&content).expect("Should parse JSON")
    }

    #[test]
    fn test_parse_fed_strategy() {
        assert!(matches!(
            parse_fed_strategy("local", 0.8, 42, None),
            Some(FederatedSplitStrategy::LocalWithSeed { seed: 42, .. })
        ));
        assert!(matches!(
            parse_fed_strategy("proportional", 0.8, 42, None),
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
        assert!(matches!(
            parse_fed_strategy("stratified", 0.8, 42, Some("label")),
            Some(FederatedSplitStrategy::GlobalStratified { .. })
        ));
        assert!(parse_fed_strategy("invalid", 0.8, 42, None).is_none());
    }

    #[test]
    fn test_cmd_fed_manifest_basic() {
        let temp_dir = scratch_dir();
        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");
        create_test_parquet(&data_path, 100);

        let result = cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json");
        assert!(result.is_ok());
        assert!(manifest_path.exists());

        // Verify manifest contents
        let parsed = read_json(&manifest_path);
        assert_eq!(
            parsed.get("node_id").and_then(|v| v.as_str()),
            Some("node-1")
        );
        assert_eq!(parsed.get("total_rows").and_then(|v| v.as_u64()), Some(100));
    }

    #[test]
    fn test_cmd_fed_manifest_binary() {
        let temp_dir = scratch_dir();
        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.bin");
        create_test_parquet(&data_path, 50);

        let result = cmd_fed_manifest(&data_path, &manifest_path, "node-2", 0.8, 42, "binary");
        assert!(result.is_ok());
        assert!(manifest_path.exists());
    }

    #[test]
    fn test_cmd_fed_plan_local_strategy() {
        let temp_dir = scratch_dir();

        // Create manifests for two nodes
        let data1 = temp_dir.path().join("data1.parquet");
        let data2 = temp_dir.path().join("data2.parquet");
        let manifest1 = temp_dir.path().join("manifest1.json");
        let manifest2 = temp_dir.path().join("manifest2.json");
        let plan_path = temp_dir.path().join("plan.json");

        create_test_parquet(&data1, 100);
        create_test_parquet(&data2, 150);

        cmd_fed_manifest(&data1, &manifest1, "node-1", 0.8, 42, "json")
            .expect("Should create manifest1");
        cmd_fed_manifest(&data2, &manifest2, "node-2", 0.8, 42, "json")
            .expect("Should create manifest2");

        let manifests = vec![manifest1, manifest2];
        let result = cmd_fed_plan(&manifests, &plan_path, "local", 0.8, 42, None, "json");
        assert!(result.is_ok());
        assert!(plan_path.exists());

        // Verify plan contents
        let parsed = read_json(&plan_path);
        let instructions = parsed.as_array();
        assert!(instructions.is_some());
        assert_eq!(instructions.map(|a| a.len()), Some(2));
    }

    #[test]
    fn test_cmd_fed_plan_empty_manifests_fails() {
        let temp_dir = scratch_dir();
        let plan_path = temp_dir.path().join("plan.json");

        let manifests: Vec<PathBuf> = vec![];
        let result = cmd_fed_plan(&manifests, &plan_path, "local", 0.8, 42, None, "json");
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_split_basic() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");
        let plan_path = temp_dir.path().join("plan.json");
        let train_path = temp_dir.path().join("train.parquet");
        let test_path = temp_dir.path().join("test.parquet");

        create_test_parquet(&data_path, 100);

        // Create manifest
        cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json")
            .expect("Should create manifest");

        // Create plan
        let manifests = vec![manifest_path];
        cmd_fed_plan(&manifests, &plan_path, "local", 0.8, 42, None, "json")
            .expect("Should create plan");

        // Execute split
        let result = cmd_fed_split(
            &data_path,
            &plan_path,
            "node-1",
            &train_path,
            &test_path,
            None,
        );
        assert!(result.is_ok());
        assert!(train_path.exists());
        assert!(test_path.exists());

        // Verify split sizes
        let train_ds = ArrowDataset::from_parquet(&train_path).expect("Should load train");
        let test_ds = ArrowDataset::from_parquet(&test_path).expect("Should load test");

        assert!(train_ds.len() > 0);
        assert!(test_ds.len() > 0);
        assert_eq!(train_ds.len() + test_ds.len(), 100);
    }

    #[test]
    fn test_cmd_fed_split_node_not_found() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");
        let plan_path = temp_dir.path().join("plan.json");
        let train_path = temp_dir.path().join("train.parquet");
        let test_path = temp_dir.path().join("test.parquet");

        create_test_parquet(&data_path, 100);

        cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json")
            .expect("Should create manifest");

        let manifests = vec![manifest_path];
        cmd_fed_plan(&manifests, &plan_path, "local", 0.8, 42, None, "json")
            .expect("Should create plan");

        // Try to split with wrong node ID
        let result = cmd_fed_split(
            &data_path,
            &plan_path,
            "wrong-node",
            &train_path,
            &test_path,
            None,
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_verify_basic() {
        let temp_dir = scratch_dir();

        let data1 = temp_dir.path().join("data1.parquet");
        let data2 = temp_dir.path().join("data2.parquet");
        let manifest1 = temp_dir.path().join("manifest1.json");
        let manifest2 = temp_dir.path().join("manifest2.json");

        create_test_parquet(&data1, 100);
        create_test_parquet(&data2, 150);

        cmd_fed_manifest(&data1, &manifest1, "node-1", 0.8, 42, "json")
            .expect("Should create manifest1");
        cmd_fed_manifest(&data2, &manifest2, "node-2", 0.8, 42, "json")
            .expect("Should create manifest2");

        let manifests = vec![manifest1, manifest2];
        let result = cmd_fed_verify(&manifests, "text");
        assert!(result.is_ok());
    }

    #[test]
    fn test_cmd_fed_verify_json_format() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");

        create_test_parquet(&data_path, 100);

        cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json")
            .expect("Should create manifest");

        let manifests = vec![manifest_path];
        let result = cmd_fed_verify(&manifests, "json");
        assert!(result.is_ok());
    }

    #[test]
    fn test_cmd_fed_verify_empty_manifests_fails() {
        let manifests: Vec<PathBuf> = vec![];
        let result = cmd_fed_verify(&manifests, "text");
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_plan_proportional_strategy() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");
        let plan_path = temp_dir.path().join("plan.json");

        create_test_parquet(&data_path, 100);

        cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json")
            .expect("Should create manifest");

        let manifests = vec![manifest_path];
        let result = cmd_fed_plan(
            &manifests,
            &plan_path,
            "proportional",
            0.7,
            42,
            None,
            "json",
        );
        assert!(result.is_ok());

        // Verify train ratio is 0.7
        let parsed = read_json(&plan_path);
        let instructions = parsed.as_array().expect("Should be array");
        let train_ratio = instructions[0].get("train_ratio").and_then(|v| v.as_f64());
        assert!((train_ratio.unwrap_or(0.0) - 0.7).abs() < 0.01);
    }

    #[test]
    fn test_load_manifest_invalid_json() {
        let temp_dir = scratch_dir();
        let manifest_path = temp_dir.path().join("invalid.json");

        std::fs::write(&manifest_path, "not valid json").expect("Should write file");

        let result = load_manifest(&manifest_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_plan_invalid_json() {
        let temp_dir = scratch_dir();
        let plan_path = temp_dir.path().join("invalid.json");

        std::fs::write(&plan_path, "{ broken }").expect("Should write file");

        let result = load_plan(&plan_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_plan_invalid_strategy() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");
        let plan_path = temp_dir.path().join("plan.json");

        create_test_parquet(&data_path, 100);

        cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json")
            .expect("Should create manifest");

        let manifests = vec![manifest_path];
        let result = cmd_fed_plan(
            &manifests,
            &plan_path,
            "invalid_strategy",
            0.8,
            42,
            None,
            "json",
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_plan_stratified_strategy() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("data.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");
        let plan_path = temp_dir.path().join("plan.json");

        create_test_parquet(&data_path, 100);

        cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json")
            .expect("Should create manifest");

        let manifests = vec![manifest_path];
        let result = cmd_fed_plan(
            &manifests,
            &plan_path,
            "stratified",
            0.8,
            42,
            Some("name"),
            "json",
        );
        assert!(result.is_ok());
    }

    #[test]
    fn test_cmd_fed_verify_with_quality_issues() {
        let temp_dir = scratch_dir();

        let data_path = temp_dir.path().join("small.parquet");
        let manifest_path = temp_dir.path().join("manifest.json");

        create_test_parquet(&data_path, 15);

        cmd_fed_manifest(&data_path, &manifest_path, "small-node", 0.8, 42, "json")
            .expect("Should create manifest");

        let manifests = vec![manifest_path];
        let result = cmd_fed_verify(&manifests, "text");
        assert!(result.is_ok());
    }

    #[test]
    fn test_parse_fed_strategy_iid() {
        // `iid` is an alias for the proportional strategy; matching ignores case.
        assert!(matches!(
            parse_fed_strategy("iid", 0.8, 42, None),
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
        assert!(matches!(
            parse_fed_strategy("IID", 0.8, 42, None),
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
    }

    #[test]
    fn test_parse_fed_strategy_proportional() {
        assert!(matches!(
            parse_fed_strategy("proportional", 0.8, 42, None),
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
    }

    #[test]
    fn test_parse_fed_strategy_local() {
        assert!(matches!(
            parse_fed_strategy("local", 0.8, 42, None),
            Some(FederatedSplitStrategy::LocalWithSeed { seed: 42, .. })
        ));
        // `local-seed` is an alias for the same strategy.
        assert!(matches!(
            parse_fed_strategy("local-seed", 0.8, 7, None),
            Some(FederatedSplitStrategy::LocalWithSeed { seed: 7, .. })
        ));
    }

    #[test]
    fn test_parse_fed_strategy_stratified() {
        assert!(matches!(
            parse_fed_strategy("stratified", 0.8, 42, Some("label")),
            Some(FederatedSplitStrategy::GlobalStratified { label_column, .. })
                if label_column == "label"
        ));
        // An explicit column is used verbatim.
        assert!(matches!(
            parse_fed_strategy("stratified", 0.8, 42, Some("name")),
            Some(FederatedSplitStrategy::GlobalStratified { label_column, .. })
                if label_column == "name"
        ));
        // Without a column the parser falls back to "label".
        assert!(matches!(
            parse_fed_strategy("stratified", 0.8, 42, None),
            Some(FederatedSplitStrategy::GlobalStratified { label_column, .. })
                if label_column == "label"
        ));
    }

    #[test]
    fn test_parse_fed_strategy_unknown() {
        assert!(parse_fed_strategy("invalid", 0.8, 42, None).is_none());
    }

    #[test]
    fn test_cmd_fed_plan() {
        let temp_dir = scratch_dir();
        let manifest1 = temp_dir.path().join("node1.json");
        let manifest2 = temp_dir.path().join("node2.json");
        let data1 = temp_dir.path().join("data1.parquet");
        let data2 = temp_dir.path().join("data2.parquet");
        let output = temp_dir.path().join("plan.json");

        create_test_parquet(&data1, 50);
        create_test_parquet(&data2, 50);

        cmd_fed_manifest(&data1, &manifest1, "node1", 0.8, 42, "json")
            .expect("Should create manifest1");
        cmd_fed_manifest(&data2, &manifest2, "node2", 0.8, 42, "json")
            .expect("Should create manifest2");

        let manifests = vec![manifest1, manifest2];
        let result = cmd_fed_plan(&manifests, &output, "iid", 0.8, 42, None, "json");
        assert!(result.is_ok());
    }
}