1use std::path::{Path, PathBuf};
4
5use clap::Subcommand;
6
7use super::basic::load_dataset;
8use crate::{
9 federated::{
10 FederatedSplitCoordinator, FederatedSplitStrategy, NodeSplitInstruction, NodeSplitManifest,
11 },
12 split::DatasetSplit,
13 Dataset,
14};
15
16#[derive(Subcommand)]
18pub enum FedCommands {
19 Manifest {
21 input: PathBuf,
23 #[arg(short, long)]
25 output: PathBuf,
26 #[arg(short, long)]
28 node_id: String,
29 #[arg(short = 'r', long, default_value = "0.8")]
31 train_ratio: f64,
32 #[arg(short, long, default_value = "42")]
34 seed: u64,
35 #[arg(short, long, default_value = "json")]
37 format: String,
38 },
39 Plan {
41 #[arg(required = true)]
43 manifests: Vec<PathBuf>,
44 #[arg(short, long)]
46 output: PathBuf,
47 #[arg(short, long, default_value = "local")]
49 strategy: String,
50 #[arg(short = 'r', long, default_value = "0.8")]
52 train_ratio: f64,
53 #[arg(long, default_value = "42")]
55 seed: u64,
56 #[arg(long)]
58 stratify_column: Option<String>,
59 #[arg(short, long, default_value = "json")]
61 format: String,
62 },
63 Split {
65 input: PathBuf,
67 #[arg(short, long)]
69 plan: PathBuf,
70 #[arg(short, long)]
72 node_id: String,
73 #[arg(long)]
75 train_output: PathBuf,
76 #[arg(long)]
78 test_output: PathBuf,
79 #[arg(long)]
81 validation_output: Option<PathBuf>,
82 },
83 Verify {
85 #[arg(required = true)]
87 manifests: Vec<PathBuf>,
88 #[arg(short, long, default_value = "text")]
90 format: String,
91 },
92}
93
94pub(crate) fn parse_fed_strategy(
96 strategy: &str,
97 train_ratio: f64,
98 seed: u64,
99 stratify_column: Option<&str>,
100) -> Option<FederatedSplitStrategy> {
101 match strategy.to_lowercase().as_str() {
102 "local" | "local-seed" => Some(FederatedSplitStrategy::LocalWithSeed { seed, train_ratio }),
103 "proportional" | "iid" => Some(FederatedSplitStrategy::ProportionalIID { train_ratio }),
104 "stratified" => {
105 let column = stratify_column.unwrap_or("label").to_string();
106 Some(FederatedSplitStrategy::GlobalStratified {
107 label_column: column,
108 target_distribution: std::collections::HashMap::new(),
109 })
110 }
111 _ => None,
112 }
113}
114
115pub(crate) fn cmd_fed_manifest(
117 input: &Path,
118 output: &Path,
119 node_id: &str,
120 train_ratio: f64,
121 seed: u64,
122 format: &str,
123) -> crate::Result<()> {
124 let dataset = load_dataset(input)?;
125
126 let split =
128 DatasetSplit::from_ratios(&dataset, train_ratio, 1.0 - train_ratio, None, Some(seed))?;
129
130 let manifest = NodeSplitManifest::from_split(node_id, &split);
131
132 match format {
133 "binary" | "bin" => {
134 let bytes =
135 rmp_serde::to_vec(&manifest).map_err(|e| crate::Error::Format(e.to_string()))?;
136 std::fs::write(output, bytes).map_err(|e| crate::Error::io(e, output))?;
137 }
138 _ => {
139 let json = serde_json::to_string_pretty(&manifest)
140 .map_err(|e| crate::Error::Format(e.to_string()))?;
141 std::fs::write(output, json).map_err(|e| crate::Error::io(e, output))?;
142 }
143 }
144
145 println!(
146 "Created manifest for node '{}' ({} rows) -> {}",
147 node_id,
148 dataset.len(),
149 output.display()
150 );
151
152 Ok(())
153}
154
155pub(crate) fn load_manifest(path: &Path) -> crate::Result<NodeSplitManifest> {
157 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
158
159 match ext {
160 "bin" | "binary" => {
161 let bytes = std::fs::read(path).map_err(|e| crate::Error::io(e, path))?;
162 rmp_serde::from_slice(&bytes)
163 .map_err(|e| crate::Error::Format(format!("Invalid manifest binary: {}", e)))
164 }
165 _ => {
166 let json = std::fs::read_to_string(path).map_err(|e| crate::Error::io(e, path))?;
167 serde_json::from_str(&json)
168 .map_err(|e| crate::Error::Format(format!("Invalid manifest JSON: {}", e)))
169 }
170 }
171}
172
173#[allow(clippy::too_many_arguments)]
175pub(crate) fn cmd_fed_plan(
176 manifests: &[PathBuf],
177 output: &Path,
178 strategy: &str,
179 train_ratio: f64,
180 seed: u64,
181 stratify_column: Option<&str>,
182 format: &str,
183) -> crate::Result<()> {
184 if manifests.is_empty() {
185 return Err(crate::Error::invalid_config("No manifests provided"));
186 }
187
188 let loaded: Vec<NodeSplitManifest> = manifests
189 .iter()
190 .map(|p| load_manifest(p))
191 .collect::<Result<Vec<_>, _>>()?;
192
193 let strategy =
194 parse_fed_strategy(strategy, train_ratio, seed, stratify_column).ok_or_else(|| {
195 crate::Error::invalid_config(format!(
196 "Unknown strategy: {}. Use 'local', 'proportional', or 'stratified'",
197 strategy
198 ))
199 })?;
200
201 let coordinator = FederatedSplitCoordinator::new(strategy);
202 let instructions = coordinator.compute_split_plan(&loaded)?;
203
204 match format {
205 "binary" | "bin" => {
206 let bytes = rmp_serde::to_vec(&instructions)
207 .map_err(|e| crate::Error::Format(e.to_string()))?;
208 std::fs::write(output, bytes).map_err(|e| crate::Error::io(e, output))?;
209 }
210 _ => {
211 let json = serde_json::to_string_pretty(&instructions)
212 .map_err(|e| crate::Error::Format(e.to_string()))?;
213 std::fs::write(output, json).map_err(|e| crate::Error::io(e, output))?;
214 }
215 }
216
217 println!(
218 "Created split plan for {} nodes -> {}",
219 instructions.len(),
220 output.display()
221 );
222
223 Ok(())
224}
225
226pub(crate) fn load_plan(path: &Path) -> crate::Result<Vec<NodeSplitInstruction>> {
228 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
229
230 match ext {
231 "bin" | "binary" => {
232 let bytes = std::fs::read(path).map_err(|e| crate::Error::io(e, path))?;
233 rmp_serde::from_slice(&bytes)
234 .map_err(|e| crate::Error::Format(format!("Invalid plan binary: {}", e)))
235 }
236 _ => {
237 let json = std::fs::read_to_string(path).map_err(|e| crate::Error::io(e, path))?;
238 serde_json::from_str(&json)
239 .map_err(|e| crate::Error::Format(format!("Invalid plan JSON: {}", e)))
240 }
241 }
242}
243
244pub(crate) fn cmd_fed_split(
246 input: &Path,
247 plan: &Path,
248 node_id: &str,
249 train_output: &Path,
250 test_output: &Path,
251 validation_output: Option<&PathBuf>,
252) -> crate::Result<()> {
253 let dataset = load_dataset(input)?;
254 let instructions = load_plan(plan)?;
255
256 let instruction = instructions
258 .iter()
259 .find(|i| i.node_id == node_id)
260 .ok_or_else(|| {
261 crate::Error::invalid_config(format!(
262 "No instruction found for node '{}' in plan",
263 node_id
264 ))
265 })?;
266
267 let split = FederatedSplitCoordinator::execute_local_split(&dataset, instruction)?;
269
270 split.train.to_parquet(train_output)?;
272 split.test.to_parquet(test_output)?;
273
274 if let (Some(val_output), Some(val_data)) = (validation_output, &split.validation) {
275 val_data.to_parquet(val_output)?;
276 }
277
278 println!(
279 "Split executed for node '{}': {} train, {} test{}",
280 node_id,
281 split.train.len(),
282 split.test.len(),
283 split
284 .validation
285 .as_ref()
286 .map_or(String::new(), |v| format!(", {} validation", v.len()))
287 );
288
289 Ok(())
290}
291
292pub(crate) fn cmd_fed_verify(manifests: &[PathBuf], format: &str) -> crate::Result<()> {
294 if manifests.is_empty() {
295 return Err(crate::Error::invalid_config("No manifests provided"));
296 }
297
298 let loaded: Vec<NodeSplitManifest> = manifests
299 .iter()
300 .map(|p| load_manifest(p))
301 .collect::<Result<Vec<_>, _>>()?;
302
303 let report = FederatedSplitCoordinator::verify_global_split(&loaded)?;
304
305 if format == "json" {
306 let json = serde_json::json!({
307 "total_rows": report.total_rows,
308 "total_train_rows": report.total_train_rows,
309 "total_test_rows": report.total_test_rows,
310 "total_validation_rows": report.total_validation_rows,
311 "effective_train_ratio": report.effective_train_ratio,
312 "effective_test_ratio": report.effective_test_ratio,
313 "effective_validation_ratio": report.effective_validation_ratio,
314 "quality_passed": report.quality_passed,
315 "issues": report.issues.iter().map(|i| format!("{:?}", i)).collect::<Vec<_>>(),
316 "node_summaries": report.node_summaries.iter().map(|n| {
317 serde_json::json!({
318 "node_id": n.node_id,
319 "contribution_ratio": n.contribution_ratio,
320 "train_ratio": n.train_ratio,
321 "test_ratio": n.test_ratio,
322 })
323 }).collect::<Vec<_>>()
324 });
325
326 let json_str =
327 serde_json::to_string_pretty(&json).map_err(|e| crate::Error::Format(e.to_string()))?;
328 println!("{}", json_str);
329 } else {
330 println!("Federated Split Verification");
332 println!("============================");
333 println!();
334 println!("Global Statistics:");
335 println!(" Total rows: {}", report.total_rows);
336 println!(
337 " Train rows: {} ({:.1}%)",
338 report.total_train_rows,
339 report.effective_train_ratio * 100.0
340 );
341 println!(
342 " Test rows: {} ({:.1}%)",
343 report.total_test_rows,
344 report.effective_test_ratio * 100.0
345 );
346 if let Some(val) = report.total_validation_rows {
347 println!(
348 " Validation rows: {} ({:.1}%)",
349 val,
350 report.effective_validation_ratio.unwrap_or(0.0) * 100.0
351 );
352 }
353 println!();
354
355 println!("Node Summaries:");
356 println!(
357 "{:<15} {:>12} {:>10} {:>10}",
358 "NODE", "CONTRIBUTION", "TRAIN", "TEST"
359 );
360 println!("{}", "-".repeat(50));
361
362 for summary in &report.node_summaries {
363 println!(
364 "{:<15} {:>11.1}% {:>9.1}% {:>9.1}%",
365 summary.node_id,
366 summary.contribution_ratio * 100.0,
367 summary.train_ratio * 100.0,
368 summary.test_ratio * 100.0
369 );
370 }
371
372 println!();
373 if report.quality_passed {
374 println!("\u{2713} Quality check passed");
375 } else {
376 println!("\u{26A0} Quality issues detected:");
377 for issue in &report.issues {
378 println!(" - {:?}", issue);
379 }
380 }
381 }
382
383 Ok(())
384}
385
#[cfg(test)]
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_precision_loss,
    clippy::uninlined_format_args,
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::redundant_clone,
    clippy::cast_lossless,
    clippy::redundant_closure_for_method_calls,
    clippy::too_many_lines,
    clippy::float_cmp,
    clippy::similar_names,
    clippy::needless_late_init,
    clippy::redundant_pattern_matching
)]
mod tests {
    use std::sync::Arc;

    use arrow::{
        array::{Int32Array, StringArray},
        datatypes::{DataType, Field, Schema},
    };

    use super::*;
    use crate::ArrowDataset;

    /// Unwrap `result`, panicking with `context` and the underlying error.
    ///
    /// Unlike `.ok().unwrap_or_else(|| panic!(..))`, this keeps the error in the
    /// panic message, so a failing test says why it failed.
    fn must<T, E: std::fmt::Debug>(result: Result<T, E>, context: &str) -> T {
        match result {
            Ok(value) => value,
            Err(err) => panic!("{}: {:?}", context, err),
        }
    }

    /// Create a temporary directory that is removed when the returned guard drops.
    fn temp_dir() -> tempfile::TempDir {
        must(tempfile::tempdir(), "Should create temp dir")
    }

    /// Write a Parquet file at `path` with `rows` rows of (`id: Int32`, `name: Utf8`).
    fn create_test_parquet(path: &Path, rows: usize) {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int32, false),
            Field::new("name", DataType::Utf8, false),
        ]));

        let ids: Vec<i32> = (0..rows as i32).collect();
        let names: Vec<String> = ids.iter().map(|i| format!("item_{}", i)).collect();

        let batch = must(
            arrow::array::RecordBatch::try_new(
                schema,
                vec![
                    Arc::new(Int32Array::from(ids)),
                    Arc::new(StringArray::from(names)),
                ],
            ),
            "Should create batch",
        );

        let dataset = must(ArrowDataset::from_batch(batch), "Should create dataset");
        must(dataset.to_parquet(path), "Should write parquet");
    }

    /// Write a `rows`-row dataset to `data` and a JSON manifest for `node_id` to `manifest`.
    fn create_node(data: &Path, manifest: &Path, node_id: &str, rows: usize) {
        create_test_parquet(data, rows);
        must(
            cmd_fed_manifest(data, manifest, node_id, 0.8, 42, "json"),
            "Should create manifest",
        );
    }

    #[test]
    fn test_parse_fed_strategy() {
        assert!(matches!(
            parse_fed_strategy("local", 0.8, 42, None),
            Some(FederatedSplitStrategy::LocalWithSeed { seed: 42, .. })
        ));
        assert!(matches!(
            parse_fed_strategy("proportional", 0.8, 42, None),
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
        assert!(matches!(
            parse_fed_strategy("stratified", 0.8, 42, Some("label")),
            Some(FederatedSplitStrategy::GlobalStratified { .. })
        ));
        assert!(parse_fed_strategy("invalid", 0.8, 42, None).is_none());
    }

    #[test]
    fn test_parse_fed_strategy_is_case_insensitive() {
        assert!(matches!(
            parse_fed_strategy("LOCAL-SEED", 0.8, 7, None),
            Some(FederatedSplitStrategy::LocalWithSeed { seed: 7, .. })
        ));
        assert!(matches!(
            parse_fed_strategy("IID", 0.8, 42, None),
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
    }

    #[test]
    fn test_parse_fed_strategy_stratified_default_column() {
        match parse_fed_strategy("stratified", 0.8, 42, None) {
            Some(FederatedSplitStrategy::GlobalStratified { label_column, .. }) => {
                assert_eq!(label_column, "label");
            }
            _ => panic!("expected GlobalStratified strategy"),
        }
    }

    #[test]
    fn test_cmd_fed_manifest_basic() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");
        create_test_parquet(&data_path, 100);

        must(
            cmd_fed_manifest(&data_path, &manifest_path, "node-1", 0.8, 42, "json"),
            "Should create manifest",
        );
        assert!(manifest_path.exists());

        let content = must(std::fs::read_to_string(&manifest_path), "Should read file");
        let parsed: serde_json::Value = must(serde_json::from_str(&content), "Should parse JSON");
        assert_eq!(
            parsed.get("node_id").and_then(|v| v.as_str()),
            Some("node-1")
        );
        assert_eq!(parsed.get("total_rows").and_then(|v| v.as_u64()), Some(100));
    }

    #[test]
    fn test_cmd_fed_manifest_binary() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.bin");
        create_test_parquet(&data_path, 50);

        must(
            cmd_fed_manifest(&data_path, &manifest_path, "node-2", 0.8, 42, "binary"),
            "Should create binary manifest",
        );
        assert!(manifest_path.exists());

        // The binary format really was selected: the file is not JSON.
        let bytes = must(std::fs::read(&manifest_path), "Should read file");
        assert!(serde_json::from_slice::<serde_json::Value>(&bytes).is_err());
    }

    #[test]
    fn test_cmd_fed_plan_local_strategy() {
        let dir = temp_dir();
        let data1 = dir.path().join("data1.parquet");
        let data2 = dir.path().join("data2.parquet");
        let manifest1 = dir.path().join("manifest1.json");
        let manifest2 = dir.path().join("manifest2.json");
        let plan_path = dir.path().join("plan.json");

        create_node(&data1, &manifest1, "node-1", 100);
        create_node(&data2, &manifest2, "node-2", 150);

        let manifests = vec![manifest1, manifest2];
        must(
            cmd_fed_plan(&manifests, &plan_path, "local", 0.8, 42, None, "json"),
            "Should create plan",
        );
        assert!(plan_path.exists());

        let content = must(std::fs::read_to_string(&plan_path), "Should read file");
        let parsed: serde_json::Value = must(serde_json::from_str(&content), "Should parse JSON");
        assert_eq!(parsed.as_array().map(Vec::len), Some(2));

        // The plan also round-trips through the loader used by `cmd_fed_split`.
        let instructions = must(load_plan(&plan_path), "Should load plan");
        assert!(instructions.iter().any(|i| i.node_id == "node-1"));
        assert!(instructions.iter().any(|i| i.node_id == "node-2"));
    }

    #[test]
    fn test_cmd_fed_plan_empty_manifests_fails() {
        let dir = temp_dir();
        let plan_path = dir.path().join("plan.json");

        let manifests: Vec<PathBuf> = vec![];
        let result = cmd_fed_plan(&manifests, &plan_path, "local", 0.8, 42, None, "json");
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_split_basic() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");
        let plan_path = dir.path().join("plan.json");
        let train_path = dir.path().join("train.parquet");
        let test_path = dir.path().join("test.parquet");

        create_node(&data_path, &manifest_path, "node-1", 100);
        must(
            cmd_fed_plan(&[manifest_path], &plan_path, "local", 0.8, 42, None, "json"),
            "Should create plan",
        );

        must(
            cmd_fed_split(
                &data_path,
                &plan_path,
                "node-1",
                &train_path,
                &test_path,
                None,
            ),
            "Should execute split",
        );
        assert!(train_path.exists());
        assert!(test_path.exists());

        let train_ds = must(ArrowDataset::from_parquet(&train_path), "Should load train");
        let test_ds = must(ArrowDataset::from_parquet(&test_path), "Should load test");

        assert!(train_ds.len() > 0);
        assert!(test_ds.len() > 0);
        assert_eq!(train_ds.len() + test_ds.len(), 100);
    }

    #[test]
    fn test_cmd_fed_split_node_not_found() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");
        let plan_path = dir.path().join("plan.json");
        let train_path = dir.path().join("train.parquet");
        let test_path = dir.path().join("test.parquet");

        create_node(&data_path, &manifest_path, "node-1", 100);
        must(
            cmd_fed_plan(&[manifest_path], &plan_path, "local", 0.8, 42, None, "json"),
            "Should create plan",
        );

        let result = cmd_fed_split(
            &data_path,
            &plan_path,
            "wrong-node",
            &train_path,
            &test_path,
            None,
        );
        assert!(result.is_err());
        assert!(!train_path.exists());
        assert!(!test_path.exists());
    }

    #[test]
    fn test_cmd_fed_verify_basic() {
        let dir = temp_dir();
        let data1 = dir.path().join("data1.parquet");
        let data2 = dir.path().join("data2.parquet");
        let manifest1 = dir.path().join("manifest1.json");
        let manifest2 = dir.path().join("manifest2.json");

        create_node(&data1, &manifest1, "node-1", 100);
        create_node(&data2, &manifest2, "node-2", 150);

        let manifests = vec![manifest1, manifest2];
        must(cmd_fed_verify(&manifests, "text"), "Should verify");
    }

    #[test]
    fn test_cmd_fed_verify_json_format() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");

        create_node(&data_path, &manifest_path, "node-1", 100);

        must(cmd_fed_verify(&[manifest_path], "json"), "Should verify");
    }

    #[test]
    fn test_cmd_fed_verify_empty_manifests_fails() {
        let manifests: Vec<PathBuf> = vec![];
        let result = cmd_fed_verify(&manifests, "text");
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_plan_proportional_strategy() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");
        let plan_path = dir.path().join("plan.json");

        create_node(&data_path, &manifest_path, "node-1", 100);

        must(
            cmd_fed_plan(
                &[manifest_path],
                &plan_path,
                "proportional",
                0.7,
                42,
                None,
                "json",
            ),
            "Should create plan",
        );

        let content = must(std::fs::read_to_string(&plan_path), "Should read file");
        let parsed: serde_json::Value = must(serde_json::from_str(&content), "Should parse JSON");
        let train_ratio = parsed
            .as_array()
            .and_then(|instructions| instructions.first())
            .and_then(|instruction| instruction.get("train_ratio"))
            .and_then(|v| v.as_f64());
        assert!(
            matches!(train_ratio, Some(ratio) if (ratio - 0.7).abs() < 0.01),
            "unexpected train_ratio: {:?}",
            train_ratio
        );
    }

    #[test]
    fn test_load_manifest_invalid_json() {
        let dir = temp_dir();
        let manifest_path = dir.path().join("invalid.json");

        must(
            std::fs::write(&manifest_path, "not valid json"),
            "Should write file",
        );

        let result = load_manifest(&manifest_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_load_plan_invalid_json() {
        let dir = temp_dir();
        let plan_path = dir.path().join("invalid.json");

        must(std::fs::write(&plan_path, "{ broken }"), "Should write file");

        let result = load_plan(&plan_path);
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_plan_invalid_strategy() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");
        let plan_path = dir.path().join("plan.json");

        create_node(&data_path, &manifest_path, "node-1", 100);

        let result = cmd_fed_plan(
            &[manifest_path],
            &plan_path,
            "invalid_strategy",
            0.8,
            42,
            None,
            "json",
        );
        assert!(result.is_err());
    }

    #[test]
    fn test_cmd_fed_plan_stratified_strategy() {
        let dir = temp_dir();
        let data_path = dir.path().join("data.parquet");
        let manifest_path = dir.path().join("manifest.json");
        let plan_path = dir.path().join("plan.json");

        create_node(&data_path, &manifest_path, "node-1", 100);

        must(
            cmd_fed_plan(
                &[manifest_path],
                &plan_path,
                "stratified",
                0.8,
                42,
                Some("name"),
                "json",
            ),
            "Should create stratified plan",
        );
    }

    #[test]
    fn test_cmd_fed_verify_with_quality_issues() {
        let dir = temp_dir();
        let data_path = dir.path().join("small.parquet");
        let manifest_path = dir.path().join("manifest.json");

        create_node(&data_path, &manifest_path, "small-node", 15);

        // Quality problems are reported in the output, not returned as errors.
        must(cmd_fed_verify(&[manifest_path], "text"), "Should verify");
    }

    #[test]
    fn test_parse_fed_strategy_iid() {
        let result = parse_fed_strategy("iid", 0.8, 42, None);
        assert!(matches!(
            result,
            Some(FederatedSplitStrategy::ProportionalIID { .. })
        ));
    }

    #[test]
    fn test_parse_fed_strategy_proportional() {
        let result = parse_fed_strategy("proportional", 0.8, 42, None);
        assert!(result.is_some());
    }

    #[test]
    fn test_parse_fed_strategy_local() {
        let result = parse_fed_strategy("local", 0.8, 42, None);
        assert!(result.is_some());
    }

    #[test]
    fn test_parse_fed_strategy_stratified() {
        let result = parse_fed_strategy("stratified", 0.8, 42, Some("label"));
        assert!(result.is_some());
    }

    #[test]
    fn test_parse_fed_strategy_unknown() {
        assert!(parse_fed_strategy("invalid", 0.8, 42, None).is_none());
    }

    #[test]
    fn test_cmd_fed_plan() {
        let dir = temp_dir();
        let manifest1 = dir.path().join("node1.json");
        let manifest2 = dir.path().join("node2.json");
        let data1 = dir.path().join("data1.parquet");
        let data2 = dir.path().join("data2.parquet");
        let output = dir.path().join("plan.json");

        create_node(&data1, &manifest1, "node1", 50);
        create_node(&data2, &manifest2, "node2", 50);

        let manifests = vec![manifest1, manifest2];
        must(
            cmd_fed_plan(&manifests, &output, "iid", 0.8, 42, None, "json"),
            "Should create plan",
        );
        assert!(output.exists());
    }
}