1use super::cluster::{ClusterConfig, NodeConfig, Transport};
11use super::placement::PlacementDecision;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CheckpointMetadata {
19 pub adapter_idx: usize,
20 pub epoch: usize,
21 pub avg_loss: f32,
22 #[serde(default)]
23 pub val_loss: Option<f32>,
24 #[serde(default)]
25 pub node_name: Option<String>,
26 #[serde(default)]
27 pub timestamp: Option<String>,
28}
29
30#[derive(Debug, Clone)]
32pub struct AdapterStatus {
33 pub adapter_idx: usize,
35 pub node_name: String,
37 pub checkpoint_dir: PathBuf,
39 pub latest: Option<CheckpointMetadata>,
41}
42
/// One row of the adapter leaderboard.
#[derive(Clone, Debug)]
pub struct LeaderboardEntry {
    /// 1-based position in the leaderboard (1 is the lowest loss).
    pub rank: usize,
    /// Index identifying the adapter.
    pub adapter_idx: usize,
    /// Name of the node the adapter runs on.
    pub node_name: String,
    /// Epoch of the checkpoint the loss was taken from.
    pub epoch: usize,
    /// Loss used for ranking.
    pub loss: f32,
}
52
53pub struct CheckpointCoordinator {
55 pub adapters: HashMap<usize, AdapterStatus>,
57 pub poll_interval_secs: u64,
59 cluster: ClusterConfig,
61}
62
63impl CheckpointCoordinator {
64 pub fn new(
66 cluster: ClusterConfig,
67 placements: &[PlacementDecision],
68 checkpoint_dirs: &HashMap<usize, PathBuf>,
69 poll_interval_secs: u64,
70 ) -> Self {
71 let mut adapters = HashMap::new();
72 for p in placements {
73 let checkpoint_dir = checkpoint_dirs
74 .get(&p.adapter_idx)
75 .cloned()
76 .unwrap_or_else(|| PathBuf::from(format!("checkpoints/adapter-{}", p.adapter_idx)));
77 adapters.insert(
78 p.adapter_idx,
79 AdapterStatus {
80 adapter_idx: p.adapter_idx,
81 node_name: p.node_name.clone(),
82 checkpoint_dir,
83 latest: None,
84 },
85 );
86 }
87 Self { adapters, poll_interval_secs, cluster }
88 }
89
90 pub fn poll_all(&mut self) -> Vec<PollResult> {
95 let mut results = Vec::new();
96 let adapter_list: Vec<(usize, String, PathBuf)> = self
97 .adapters
98 .values()
99 .map(|a| (a.adapter_idx, a.node_name.clone(), a.checkpoint_dir.clone()))
100 .collect();
101
102 for (idx, node_name, checkpoint_dir) in adapter_list {
103 let result = self.poll_adapter(idx, &node_name, &checkpoint_dir);
104 results.push(result);
105 }
106 results
107 }
108
109 fn poll_adapter(
110 &mut self,
111 adapter_idx: usize,
112 node_name: &str,
113 checkpoint_dir: &Path,
114 ) -> PollResult {
115 let node = self.cluster.find_node(node_name);
116 let transport = node.map_or(Transport::Local, |n| n.transport);
117
118 let metadata = match transport {
119 Transport::Local => read_local_metadata(checkpoint_dir),
120 Transport::Ssh => {
121 let host = node.map_or("unknown", |n| &n.host);
122 let user = node.and_then(|n| n.user.as_deref());
123 read_ssh_metadata(host, user, checkpoint_dir)
124 }
125 };
126
127 match metadata {
128 Ok(meta) => {
129 if let Some(status) = self.adapters.get_mut(&adapter_idx) {
130 status.latest = Some(meta.clone());
131 }
132 PollResult::Ok { adapter_idx, metadata: meta }
133 }
134 Err(e) => PollResult::Error { adapter_idx, node_name: node_name.to_string(), error: e },
135 }
136 }
137
138 pub fn leaderboard(&self) -> Vec<LeaderboardEntry> {
140 let mut entries: Vec<_> = self
141 .adapters
142 .values()
143 .filter_map(|a| {
144 a.latest.as_ref().map(|meta| LeaderboardEntry {
145 rank: 0,
146 adapter_idx: a.adapter_idx,
147 node_name: a.node_name.clone(),
148 epoch: meta.epoch,
149 loss: meta.val_loss.unwrap_or(meta.avg_loss),
150 })
151 })
152 .collect();
153
154 entries.sort_by(|a, b| a.loss.partial_cmp(&b.loss).unwrap_or(std::cmp::Ordering::Equal));
155 for (i, entry) in entries.iter_mut().enumerate() {
156 entry.rank = i + 1;
157 }
158 entries
159 }
160
161 pub fn best_adapter(&self) -> Option<&AdapterStatus> {
163 let board = self.leaderboard();
164 board.first().and_then(|entry| self.adapters.get(&entry.adapter_idx))
165 }
166
167 pub fn format_leaderboard(&self) -> String {
169 let board = self.leaderboard();
170 if board.is_empty() {
171 return "No checkpoints available yet.".to_string();
172 }
173 let mut out = String::from("Adapter Leaderboard:\n");
174 out.push_str(" Rank | Adapter | Node | Epoch | Loss\n");
175 out.push_str(" -----+---------+------------+-------+--------\n");
176 for entry in &board {
177 out.push_str(&format!(
178 " {:>4} | {:>7} | {:<10} | {:>5} | {:.4}\n",
179 entry.rank, entry.adapter_idx, entry.node_name, entry.epoch, entry.loss
180 ));
181 }
182 out
183 }
184}
185
186#[derive(Debug)]
188pub enum PollResult {
189 Ok { adapter_idx: usize, metadata: CheckpointMetadata },
190 Error { adapter_idx: usize, node_name: String, error: String },
191}
192
193fn read_local_metadata(checkpoint_dir: &Path) -> Result<CheckpointMetadata, String> {
195 let best_meta = checkpoint_dir.join("best").join("metadata.json");
196 let contents = std::fs::read_to_string(&best_meta)
197 .map_err(|e| format!("failed to read {}: {e}", best_meta.display()))?;
198 serde_json::from_str(&contents)
199 .map_err(|e| format!("failed to parse {}: {e}", best_meta.display()))
200}
201
202fn read_ssh_metadata(
207 host: &str,
208 user: Option<&str>,
209 checkpoint_dir: &Path,
210) -> Result<CheckpointMetadata, String> {
211 let remote_path = checkpoint_dir.join("best").join("metadata.json");
212 let cat_cmd = format!("cat {}", remote_path.display());
213 let output = exec_ssh_command(host, user, &cat_cmd)?;
214 serde_json::from_str(&output).map_err(|e| format!("failed to parse metadata from {host}: {e}"))
215}
216
/// Runs `script` in `bash` on `host` over SSH and returns the command's standard output.
///
/// The script is written to the remote shell's stdin rather than passed as an argument.
/// `ssh` is invoked with `ConnectTimeout=5`, `BatchMode=yes` and
/// `StrictHostKeyChecking=accept-new`; `user`, when given, is passed via `-l`.
///
/// # Errors
///
/// Returns an error string when:
/// - `host` is empty or starts with `-` (ssh would parse such a value as a command-line
///   option, e.g. `-oProxyCommand=...`, instead of a destination);
/// - `ssh` cannot be spawned or waited on;
/// - `ssh` exits with a non-zero status (the message includes its stderr);
/// - the captured stdout is not valid UTF-8.
fn exec_ssh_command(host: &str, user: Option<&str>, script: &str) -> Result<String, String> {
    if host.is_empty() || host.starts_with('-') {
        return Err(format!(
            "invalid ssh host '{host}': must be non-empty and must not start with '-'"
        ));
    }

    let mut cmd = std::process::Command::new("ssh");
    cmd.args(["-o", "ConnectTimeout=5"]);
    cmd.args(["-o", "BatchMode=yes"]);
    cmd.args(["-o", "StrictHostKeyChecking=accept-new"]);

    if let Some(u) = user {
        cmd.args(["-l", u]);
    }

    cmd.arg(host);
    cmd.arg("bash");

    cmd.stdin(std::process::Stdio::piped());
    cmd.stdout(std::process::Stdio::piped());
    cmd.stderr(std::process::Stdio::piped());

    let mut child = cmd.spawn().map_err(|e| format!("failed to spawn ssh to {host}: {e}"))?;

    if let Some(mut stdin) = child.stdin.take() {
        use std::io::Write;
        // A write failure is deliberately ignored: it means ssh has already gone away, and
        // the exit status and stderr collected below describe the failure better.
        let _ = stdin.write_all(script.as_bytes());
        // `stdin` is dropped at the end of this block, which closes the pipe so the remote
        // bash sees end-of-input.
    }

    let output = child.wait_with_output().map_err(|e| format!("ssh to {host} failed: {e}"))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(format!(
            "ssh to {host} exited {}: {stderr}",
            output.status.code().unwrap_or(-1)
        ));
    }

    String::from_utf8(output.stdout).map_err(|e| format!("invalid UTF-8 from ssh to {host}: {e}"))
}
261
262pub fn exec_launch(
269 node: &NodeConfig,
270 model_path: &Path,
271 data_path: &Path,
272 checkpoint_dir: &Path,
273 rank: u32,
274 epochs: u32,
275) -> Result<std::process::Child, String> {
276 let script = format!(
277 "apr finetune {} --task instruct --method qlora --quantize-nf4 \
278 --data {} --output {} --rank {rank} --epochs {epochs}",
279 model_path.display(),
280 data_path.display(),
281 checkpoint_dir.display(),
282 );
283
284 match node.transport {
285 Transport::Local => std::process::Command::new("bash")
286 .arg("-c")
287 .arg(&script)
288 .stdin(std::process::Stdio::null())
289 .stdout(std::process::Stdio::piped())
290 .stderr(std::process::Stdio::piped())
291 .spawn()
292 .map_err(|e| format!("failed to launch local training: {e}")),
293 Transport::Ssh => {
294 let mut cmd = std::process::Command::new("ssh");
295 cmd.args(["-o", "ConnectTimeout=5"]);
296 cmd.args(["-o", "BatchMode=yes"]);
297 cmd.args(["-o", "StrictHostKeyChecking=accept-new"]);
298 if let Some(ref u) = node.user {
299 cmd.args(["-l", u]);
300 }
301 cmd.arg(&node.host);
302 cmd.arg("bash");
303 cmd.stdin(std::process::Stdio::piped());
304 cmd.stdout(std::process::Stdio::piped());
305 cmd.stderr(std::process::Stdio::piped());
306
307 let mut child =
308 cmd.spawn().map_err(|e| format!("failed to ssh to {}: {e}", node.host))?;
309
310 if let Some(stdin) = child.stdin.take() {
311 use std::io::Write;
312 let mut stdin = stdin;
313 let _ = stdin.write_all(script.as_bytes());
314 }
315
316 Ok(child)
317 }
318 }
319}
320
321pub fn build_launch_command(
326 node: &NodeConfig,
327 model_path: &Path,
328 data_path: &Path,
329 checkpoint_dir: &Path,
330 rank: u32,
331 epochs: u32,
332) -> String {
333 let base = format!(
334 "apr finetune {} --task instruct --method qlora --quantize-nf4 \
335 --data {} --output {} --rank {rank} --epochs {epochs}",
336 model_path.display(),
337 data_path.display(),
338 checkpoint_dir.display(),
339 );
340
341 match node.transport {
342 Transport::Local => base,
343 Transport::Ssh => {
344 let user_prefix = node.user.as_ref().map_or_else(String::new, |u| format!("{u}@"));
345 format!("ssh {user_prefix}{} '{base}'", node.host)
346 }
347 }
348}
349
/// Result of a health probe against one cluster node.
#[derive(Clone, Debug)]
pub struct NodeHealth {
    /// Name of the probed node.
    pub node_name: String,
    /// Whether the probe command could be run on the node.
    pub reachable: bool,
    /// Trimmed output of `apr --version`, when the CLI was found.
    pub apr_version: Option<String>,
    /// Description of the problem when the node is unreachable or lacks the `apr` CLI.
    pub error: Option<String>,
}
362
363pub fn check_cluster_health(cluster: &ClusterConfig) -> Vec<NodeHealth> {
368 cluster.nodes.iter().map(check_node_health).collect()
369}
370
371fn check_node_health(node: &NodeConfig) -> NodeHealth {
372 let script = "apr --version 2>/dev/null || echo 'apr: not found'";
373 let result = match node.transport {
374 Transport::Local => std::process::Command::new("bash")
375 .arg("-c")
376 .arg(script)
377 .output()
378 .map_err(|e| format!("failed to check local health: {e}"))
379 .and_then(|out| {
380 String::from_utf8(out.stdout).map_err(|e| format!("invalid UTF-8: {e}"))
381 }),
382 Transport::Ssh => exec_ssh_command(&node.host, node.user.as_deref(), script),
383 };
384
385 match result {
386 Ok(output) => {
387 let trimmed = output.trim().to_string();
388 let has_apr = !trimmed.contains("not found") && !trimmed.is_empty();
389 NodeHealth {
390 node_name: node.name.clone(),
391 reachable: true,
392 apr_version: if has_apr { Some(trimmed) } else { None },
393 error: if has_apr { None } else { Some("apr CLI not found on node".to_string()) },
394 }
395 }
396 Err(e) => NodeHealth {
397 node_name: node.name.clone(),
398 reachable: false,
399 apr_version: None,
400 error: Some(e),
401 },
402 }
403}
404
405impl CheckpointCoordinator {
406 pub fn pull_best_checkpoint(&self, dest: &Path) -> Result<PathBuf, String> {
413 let best =
414 self.best_adapter().ok_or_else(|| "no adapters with checkpoint data".to_string())?;
415
416 let node = self
417 .cluster
418 .find_node(&best.node_name)
419 .ok_or_else(|| format!("node '{}' not found in cluster", best.node_name))?;
420
421 let source_dir = best.checkpoint_dir.join("best");
422 let dest_dir = dest.join(format!("adapter-{}-best", best.adapter_idx));
423
424 match node.transport {
425 Transport::Local => {
426 copy_dir_recursive(&source_dir, &dest_dir)?;
427 Ok(dest_dir)
428 }
429 Transport::Ssh => {
430 std::fs::create_dir_all(&dest_dir)
431 .map_err(|e| format!("failed to create {}: {e}", dest_dir.display()))?;
432 let user_prefix = node.user.as_ref().map_or_else(String::new, |u| format!("{u}@"));
433 let remote = format!("{user_prefix}{}:{}/", node.host, source_dir.display());
434 let output = std::process::Command::new("scp")
435 .args(["-r", "-o", "ConnectTimeout=10", "-o", "BatchMode=yes"])
436 .arg(&remote)
437 .arg(dest_dir.to_str().unwrap_or("."))
438 .output()
439 .map_err(|e| format!("failed to run scp: {e}"))?;
440
441 if !output.status.success() {
442 let stderr = String::from_utf8_lossy(&output.stderr);
443 return Err(format!("scp failed: {stderr}"));
444 }
445 Ok(dest_dir)
446 }
447 }
448 }
449}
450
/// Recursively copies the contents of directory `src` into directory `dst`.
///
/// `dst` (and any missing parents) is created first, before `src` is read, so a missing
/// `src` still leaves an empty `dst` behind. Entries for which `Path::is_dir` is true are
/// recursed into; everything else is copied with `std::fs::copy`.
///
/// # Errors
///
/// Returns a message naming the path that could not be created, read, or copied. The copy
/// stops at the first failure.
fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<(), String> {
    if let Err(e) = std::fs::create_dir_all(dst) {
        return Err(format!("failed to create {}: {e}", dst.display()));
    }

    let listing = match std::fs::read_dir(src) {
        Ok(listing) => listing,
        Err(e) => return Err(format!("failed to read {}: {e}", src.display())),
    };

    for item in listing {
        let item = item.map_err(|e| format!("failed to read entry: {e}"))?;
        let from = item.path();
        let to = dst.join(item.file_name());
        if from.is_dir() {
            copy_dir_recursive(&from, &to)?;
        } else if let Err(e) = std::fs::copy(&from, &to) {
            return Err(format!("failed to copy {}: {e}", from.display()));
        }
    }
    Ok(())
}
468
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use std::fs;

    // ---------------------------------------------------------------- fixtures

    /// Two-node cluster: a local `desktop` node and an SSH `jetson` node.
    fn test_cluster() -> ClusterConfig {
        ClusterConfig::from_yaml(
            r"
nodes:
  - name: desktop
    host: localhost
    gpus:
      - uuid: GPU-abcd-1234
        type: rtx-4090
        vram_mb: 24564
    max_adapters: 3
  - name: jetson
    host: jetson.local
    transport: ssh
    gpus:
      - uuid: GPU-efgh-5678
        type: jetson-orin
        vram_mb: 8192
        memory_type: unified
    max_adapters: 1
",
        )
        .expect("valid")
    }

    /// Adapters 0 and 1 on `desktop`, adapter 2 on `jetson`.
    fn test_placements() -> Vec<PlacementDecision> {
        vec![
            PlacementDecision { adapter_idx: 0, node_name: "desktop".to_string(), score: 2.5 },
            PlacementDecision { adapter_idx: 1, node_name: "desktop".to_string(), score: 1.2 },
            PlacementDecision { adapter_idx: 2, node_name: "jetson".to_string(), score: 0.3 },
        ]
    }

    /// Coordinator over the fixture cluster and placements, default dirs, 300 s interval.
    fn default_coordinator() -> CheckpointCoordinator {
        CheckpointCoordinator::new(test_cluster(), &test_placements(), &HashMap::new(), 300)
    }

    /// Metadata with no node name and no timestamp.
    fn meta(
        adapter_idx: usize,
        epoch: usize,
        avg_loss: f32,
        val_loss: Option<f32>,
    ) -> CheckpointMetadata {
        CheckpointMetadata {
            adapter_idx,
            epoch,
            avg_loss,
            val_loss,
            node_name: None,
            timestamp: None,
        }
    }

    /// Metadata tagged with a node name, no timestamp.
    fn meta_on(
        node: &str,
        adapter_idx: usize,
        epoch: usize,
        avg_loss: f32,
        val_loss: Option<f32>,
    ) -> CheckpointMetadata {
        CheckpointMetadata {
            node_name: Some(node.to_string()),
            ..meta(adapter_idx, epoch, avg_loss, val_loss)
        }
    }

    /// Overwrites the `latest` metadata of an adapter that must already be tracked.
    fn set_latest(coord: &mut CheckpointCoordinator, adapter_idx: usize, m: CheckpointMetadata) {
        coord.adapters.get_mut(&adapter_idx).expect("valid").latest = Some(m);
    }

    /// Writes `<checkpoint_dir>/best/metadata.json` and returns the `best` directory.
    fn write_best_metadata(checkpoint_dir: &Path, m: &CheckpointMetadata) -> PathBuf {
        let best_dir = checkpoint_dir.join("best");
        fs::create_dir_all(&best_dir).expect("valid");
        fs::write(best_dir.join("metadata.json"), serde_json::to_string(m).expect("valid"))
            .expect("valid");
        best_dir
    }

    /// GPU-less node config with `max_adapters: 1`.
    fn node(name: &str, host: &str, transport: Transport, user: Option<&str>) -> NodeConfig {
        NodeConfig {
            name: name.to_string(),
            host: host.to_string(),
            transport,
            user: user.map(str::to_string),
            gpus: vec![],
            max_adapters: 1,
            cpu_cores: None,
            ram_mb: None,
        }
    }

    // ------------------------------------------------------------------- tests

    #[test]
    fn test_coordinator_creation() {
        let coord = default_coordinator();
        assert_eq!(coord.adapters.len(), 3);
        assert_eq!(coord.poll_interval_secs, 300);
    }

    #[test]
    fn test_empty_leaderboard() {
        let coord = default_coordinator();
        assert!(coord.leaderboard().is_empty());
        assert!(coord.best_adapter().is_none());
    }

    #[test]
    fn test_leaderboard_with_data() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta_on("desktop", 0, 3, 0.5, Some(0.45)));
        set_latest(&mut coord, 1, meta_on("desktop", 1, 3, 0.8, Some(0.75)));
        set_latest(&mut coord, 2, meta_on("jetson", 2, 2, 0.3, Some(0.28)));

        let board = coord.leaderboard();
        assert_eq!(board.len(), 3);
        // Lowest validation loss first.
        assert_eq!(board[0].adapter_idx, 2);
        assert_eq!(board[0].rank, 1);
        assert_eq!(board[1].adapter_idx, 0);
        assert_eq!(board[2].adapter_idx, 1);

        let best = coord.best_adapter().expect("valid");
        assert_eq!(best.adapter_idx, 2);
    }

    #[test]
    fn test_poll_local_checkpoint() {
        let dir = tempfile::tempdir().expect("valid");
        write_best_metadata(dir.path(), &meta_on("desktop", 0, 5, 0.42, Some(0.39)));

        let placements = vec![PlacementDecision {
            adapter_idx: 0,
            node_name: "desktop".to_string(),
            score: 2.5,
        }];
        let dirs = HashMap::from([(0, dir.path().to_path_buf())]);

        let mut coord = CheckpointCoordinator::new(test_cluster(), &placements, &dirs, 300);
        let results = coord.poll_all();

        assert_eq!(results.len(), 1);
        match &results[0] {
            PollResult::Ok { adapter_idx, metadata } => {
                assert_eq!(*adapter_idx, 0);
                assert_eq!(metadata.epoch, 5);
                assert!((metadata.avg_loss - 0.42).abs() < f32::EPSILON);
            }
            PollResult::Error { error, .. } => panic!("unexpected error: {error}"),
        }
    }

    #[test]
    fn test_poll_ssh_attempts_real_ssh() {
        let placements =
            vec![PlacementDecision { adapter_idx: 2, node_name: "jetson".to_string(), score: 0.3 }];
        let mut coord =
            CheckpointCoordinator::new(test_cluster(), &placements, &HashMap::new(), 300);
        let results = coord.poll_all();

        assert_eq!(results.len(), 1);
        match &results[0] {
            PollResult::Error { error, .. } => {
                assert!(
                    !error.contains("not yet available"),
                    "SSH transport must not be stubbed: {error}"
                );
            }
            // Succeeding is also acceptable when the host happens to be reachable.
            PollResult::Ok { .. } => {}
        }
    }

    #[test]
    fn test_exec_ssh_command_unreachable_host() {
        let result = exec_ssh_command("192.0.2.1", Some("nobody"), "echo test");
        assert!(result.is_err());
        let err = result.unwrap_err();
        let looks_like_ssh_failure =
            ["ssh", "Connection", "timed out", "refused", "resolve", "No route", "exited"]
                .iter()
                .any(|needle| err.contains(needle));
        assert!(looks_like_ssh_failure, "expected real SSH error, got: {err}");
    }

    #[test]
    fn test_exec_ssh_command_builds_correct_args() {
        // Only checks that the call fails; the arguments themselves are not inspected.
        let result = exec_ssh_command("192.0.2.1", Some("testuser"), "echo hello");
        assert!(result.is_err());
    }

    #[test]
    fn test_format_leaderboard() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta(0, 2, 0.5, None));

        let display = coord.format_leaderboard();
        assert!(display.contains("Adapter Leaderboard"));
        assert!(display.contains("0.5000"));
    }

    #[test]
    fn test_build_launch_command_local() {
        let desktop = node("desktop", "localhost", Transport::Local, None);
        let cmd = build_launch_command(
            &desktop,
            Path::new("model.apr"),
            Path::new("data.jsonl"),
            Path::new("/tmp/ckpt"),
            16,
            3,
        );
        assert!(cmd.starts_with("apr finetune model.apr"));
        assert!(cmd.contains("--rank 16"));
        assert!(cmd.contains("--epochs 3"));
        assert!(!cmd.contains("ssh"));
    }

    #[test]
    fn test_exec_launch_local() {
        let local = node("test", "localhost", Transport::Local, None);
        let result = exec_launch(
            &local,
            Path::new("/nonexistent/model.apr"),
            Path::new("/nonexistent/data.jsonl"),
            Path::new("/tmp/test-ckpt"),
            16,
            1,
        );
        assert!(result.is_ok(), "local exec_launch should spawn: {:?}", result.err());
        // Only the spawn matters; stop the process and reap it.
        let mut child = result.expect("valid");
        let _ = child.kill();
        let _ = child.wait();
    }

    #[test]
    fn test_build_launch_command_ssh() {
        let jetson = node("jetson", "jetson.local", Transport::Ssh, Some("noah"));
        let cmd = build_launch_command(
            &jetson,
            Path::new("model.apr"),
            Path::new("data.jsonl"),
            Path::new("/tmp/ckpt"),
            16,
            3,
        );
        assert!(cmd.starts_with("ssh noah@jetson.local"));
        assert!(cmd.contains("apr finetune model.apr"));
    }

    #[test]
    fn test_check_node_health_local() {
        let local = node("local", "localhost", Transport::Local, None);
        let health = check_node_health(&local);
        assert_eq!(health.node_name, "local");
        assert!(health.reachable);
    }

    #[test]
    fn test_check_cluster_health() {
        let results = check_cluster_health(&test_cluster());
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].node_name, "desktop");
        assert!(results[0].reachable);
    }

    #[test]
    fn test_pull_best_checkpoint_local() {
        let dir = tempfile::tempdir().expect("valid");
        let ckpt_dir = dir.path().join("adapter-0");
        let best_dir = write_best_metadata(&ckpt_dir, &meta_on("desktop", 0, 3, 0.35, Some(0.30)));
        fs::write(best_dir.join("adapter.safetensors"), b"fake-weights").expect("valid");

        let placements = vec![PlacementDecision {
            adapter_idx: 0,
            node_name: "desktop".to_string(),
            score: 2.5,
        }];
        let dirs = HashMap::from([(0, ckpt_dir)]);

        let mut coord = CheckpointCoordinator::new(test_cluster(), &placements, &dirs, 300);
        // Populate `latest` so a best adapter exists.
        let _ = coord.poll_all();

        let dest = tempfile::tempdir().expect("valid");
        let result = coord.pull_best_checkpoint(dest.path());
        assert!(result.is_ok(), "pull should succeed: {:?}", result.err());

        let pulled = result.expect("valid");
        assert!(pulled.join("metadata.json").exists());
        assert!(pulled.join("adapter.safetensors").exists());
    }

    #[test]
    fn test_pull_no_checkpoints_fails() {
        let coord = default_coordinator();
        let dest = tempfile::tempdir().expect("valid");
        let result = coord.pull_best_checkpoint(dest.path());
        assert!(result.is_err());
        assert!(result.expect_err("should fail").contains("no adapters"));
    }

    #[test]
    fn test_copy_dir_recursive() {
        let src = tempfile::tempdir().expect("valid");
        let sub = src.path().join("subdir");
        fs::create_dir_all(&sub).expect("valid");
        fs::write(src.path().join("a.txt"), "hello").expect("valid");
        fs::write(sub.join("b.txt"), "world").expect("valid");

        let dst = tempfile::tempdir().expect("valid");
        let dst_path = dst.path().join("copy");
        copy_dir_recursive(src.path(), &dst_path).expect("valid");

        assert!(dst_path.join("a.txt").exists());
        assert!(dst_path.join("subdir").join("b.txt").exists());
        assert_eq!(fs::read_to_string(dst_path.join("a.txt")).expect("valid"), "hello");
    }

    #[test]
    fn test_checkpoint_metadata_serde_roundtrip() {
        let original = CheckpointMetadata {
            adapter_idx: 3,
            epoch: 10,
            avg_loss: 0.123,
            val_loss: Some(0.099),
            node_name: Some("gpu-node-1".to_string()),
            timestamp: Some("2026-03-08T12:00:00Z".to_string()),
        };
        let json = serde_json::to_string(&original).unwrap();
        let restored: CheckpointMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.adapter_idx, 3);
        assert_eq!(restored.epoch, 10);
        assert!((restored.avg_loss - 0.123).abs() < f32::EPSILON);
        assert!((restored.val_loss.unwrap() - 0.099).abs() < f32::EPSILON);
        assert_eq!(restored.node_name.unwrap(), "gpu-node-1");
        assert_eq!(restored.timestamp.unwrap(), "2026-03-08T12:00:00Z");
    }

    #[test]
    fn test_checkpoint_metadata_serde_defaults() {
        // Optional fields are omitted entirely.
        let json = r#"{"adapter_idx":0,"epoch":1,"avg_loss":0.5}"#;
        let parsed: CheckpointMetadata = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.adapter_idx, 0);
        assert!(parsed.val_loss.is_none());
        assert!(parsed.node_name.is_none());
        assert!(parsed.timestamp.is_none());
    }

    #[test]
    fn test_coordinator_custom_checkpoint_dirs() {
        let dirs = HashMap::from([
            (0, PathBuf::from("/custom/path/adapter-0")),
            (2, PathBuf::from("/custom/path/adapter-2")),
        ]);

        let coord = CheckpointCoordinator::new(test_cluster(), &test_placements(), &dirs, 600);
        assert_eq!(coord.adapters[&0].checkpoint_dir, PathBuf::from("/custom/path/adapter-0"));
        assert_eq!(coord.adapters[&2].checkpoint_dir, PathBuf::from("/custom/path/adapter-2"));
        // Adapter 1 has no override and falls back to the default location.
        assert_eq!(coord.adapters[&1].checkpoint_dir, PathBuf::from("checkpoints/adapter-1"));
    }

    #[test]
    fn test_coordinator_default_checkpoint_dirs() {
        let placements = test_placements();
        let coord = CheckpointCoordinator::new(test_cluster(), &placements, &HashMap::new(), 300);
        for p in &placements {
            let expected = PathBuf::from(format!("checkpoints/adapter-{}", p.adapter_idx));
            assert_eq!(coord.adapters[&p.adapter_idx].checkpoint_dir, expected);
        }
    }

    #[test]
    fn test_leaderboard_uses_val_loss_when_available() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta(0, 5, 1.0, Some(0.5)));

        let board = coord.leaderboard();
        assert_eq!(board.len(), 1);
        assert!((board[0].loss - 0.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_leaderboard_falls_back_to_avg_loss() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta(0, 5, 1.0, None));

        let board = coord.leaderboard();
        assert_eq!(board.len(), 1);
        assert!((board[0].loss - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_leaderboard_ranking_three_adapters() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta(0, 2, 0.7, None));
        set_latest(&mut coord, 1, meta(1, 3, 0.3, None));
        set_latest(&mut coord, 2, meta(2, 1, 0.5, None));

        let board = coord.leaderboard();
        assert_eq!(board.len(), 3);
        assert_eq!(board[0].adapter_idx, 1);
        assert_eq!(board[0].rank, 1);
        assert_eq!(board[1].adapter_idx, 2);
        assert_eq!(board[1].rank, 2);
        assert_eq!(board[2].adapter_idx, 0);
        assert_eq!(board[2].rank, 3);

        let best = coord.best_adapter().unwrap();
        assert_eq!(best.adapter_idx, 1);
    }

    #[test]
    fn test_format_leaderboard_empty() {
        let coord = default_coordinator();
        assert_eq!(coord.format_leaderboard(), "No checkpoints available yet.");
    }

    #[test]
    fn test_format_leaderboard_multiple_entries() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta(0, 3, 0.5, Some(0.45)));
        set_latest(&mut coord, 1, meta(1, 2, 0.8, Some(0.75)));

        let display = coord.format_leaderboard();
        for expected in
            ["Adapter Leaderboard", "Rank", "Adapter", "Node", "Epoch", "Loss", "0.4500", "0.7500"]
        {
            assert!(display.contains(expected), "missing {expected:?} in:\n{display}");
        }
    }

    #[test]
    fn test_adapter_status_fields() {
        let coord = default_coordinator();
        let status = &coord.adapters[&0];
        assert_eq!(status.adapter_idx, 0);
        assert_eq!(status.node_name, "desktop");
        assert!(status.latest.is_none());
    }

    #[test]
    fn test_build_launch_command_ssh_no_user() {
        let remote = node("remote", "gpu-server.example.com", Transport::Ssh, None);
        let cmd = build_launch_command(
            &remote,
            Path::new("/models/qwen.safetensors"),
            Path::new("/data/train.jsonl"),
            Path::new("/checkpoints"),
            32,
            5,
        );
        // Without a user there is no `user@` prefix.
        assert!(cmd.starts_with("ssh gpu-server.example.com"));
        assert!(cmd.contains("apr finetune"));
        assert!(cmd.contains("--rank 32"));
        assert!(cmd.contains("--epochs 5"));
    }

    #[test]
    fn test_poll_result_debug_ok() {
        let result = PollResult::Ok { adapter_idx: 0, metadata: meta(0, 1, 0.5, None) };
        let debug = format!("{result:?}");
        assert!(debug.contains("Ok"));
    }

    #[test]
    fn test_poll_result_debug_error() {
        let result = PollResult::Error {
            adapter_idx: 1,
            node_name: "node1".to_string(),
            error: "connection refused".to_string(),
        };
        let debug = format!("{result:?}");
        assert!(debug.contains("Error"));
        assert!(debug.contains("connection refused"));
    }

    #[test]
    fn test_node_health_debug() {
        let health = NodeHealth {
            node_name: "test-node".to_string(),
            reachable: true,
            apr_version: Some("1.0.0".to_string()),
            error: None,
        };
        let debug = format!("{health:?}");
        assert!(debug.contains("test-node"));
        assert!(debug.contains("1.0.0"));
    }

    #[test]
    fn test_node_health_unreachable() {
        let health = NodeHealth {
            node_name: "offline-node".to_string(),
            reachable: false,
            apr_version: None,
            error: Some("connection timeout".to_string()),
        };
        assert!(!health.reachable);
        assert!(health.error.is_some());
        assert!(health.apr_version.is_none());
    }

    #[test]
    fn test_copy_dir_recursive_nested() {
        let src = tempfile::tempdir().unwrap();
        let deep = src.path().join("a").join("b").join("c");
        fs::create_dir_all(&deep).unwrap();
        fs::write(deep.join("deep.txt"), "deep content").unwrap();
        fs::write(src.path().join("root.txt"), "root").unwrap();

        let dst = tempfile::tempdir().unwrap();
        let dst_path = dst.path().join("output");
        copy_dir_recursive(src.path(), &dst_path).unwrap();

        let copied_deep = dst_path.join("a").join("b").join("c").join("deep.txt");
        assert!(dst_path.join("root.txt").exists());
        assert!(copied_deep.exists());
        assert_eq!(fs::read_to_string(&copied_deep).unwrap(), "deep content");
    }

    #[test]
    fn test_copy_dir_recursive_nonexistent_src() {
        let dst = tempfile::tempdir().unwrap();
        let result = copy_dir_recursive(Path::new("/nonexistent/path"), &dst.path().join("out"));
        assert!(result.is_err());
    }

    #[test]
    fn test_read_local_metadata_missing_file() {
        let dir = tempfile::tempdir().unwrap();
        let result = read_local_metadata(dir.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("failed to read"));
    }

    #[test]
    fn test_read_local_metadata_invalid_json() {
        let dir = tempfile::tempdir().unwrap();
        let best_dir = dir.path().join("best");
        fs::create_dir_all(&best_dir).unwrap();
        fs::write(best_dir.join("metadata.json"), "not valid json").unwrap();

        let result = read_local_metadata(dir.path());
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.contains("failed to parse"));
    }

    #[test]
    fn test_read_local_metadata_valid() {
        let dir = tempfile::tempdir().unwrap();
        let stored = CheckpointMetadata {
            adapter_idx: 42,
            epoch: 7,
            avg_loss: 0.123,
            val_loss: Some(0.111),
            node_name: Some("gpu-0".to_string()),
            timestamp: Some("2026-01-01".to_string()),
        };
        write_best_metadata(dir.path(), &stored);

        let loaded = read_local_metadata(dir.path()).unwrap();
        assert_eq!(loaded.adapter_idx, 42);
        assert_eq!(loaded.epoch, 7);
    }

    #[test]
    fn test_poll_all_updates_adapter_status() {
        let dir = tempfile::tempdir().unwrap();
        write_best_metadata(dir.path(), &meta(0, 10, 0.22, Some(0.19)));

        let placements = vec![PlacementDecision {
            adapter_idx: 0,
            node_name: "desktop".to_string(),
            score: 1.0,
        }];
        let dirs = HashMap::from([(0, dir.path().to_path_buf())]);
        let mut coord = CheckpointCoordinator::new(test_cluster(), &placements, &dirs, 60);

        // Nothing is known before the first poll.
        assert!(coord.adapters[&0].latest.is_none());

        let results = coord.poll_all();
        assert_eq!(results.len(), 1);

        // The poll must have recorded the metadata on the adapter.
        let latest = coord.adapters[&0].latest.as_ref().unwrap();
        assert_eq!(latest.epoch, 10);
        assert!((latest.avg_loss - 0.22).abs() < f32::EPSILON);
    }

    #[test]
    fn test_leaderboard_entry_fields() {
        let entry = LeaderboardEntry {
            rank: 1,
            adapter_idx: 5,
            node_name: "gpu-node".to_string(),
            epoch: 15,
            loss: 0.123,
        };
        assert_eq!(entry.rank, 1);
        assert_eq!(entry.adapter_idx, 5);
        assert_eq!(entry.node_name, "gpu-node");
        assert_eq!(entry.epoch, 15);
        assert!((entry.loss - 0.123).abs() < f32::EPSILON);
    }

    #[test]
    fn test_leaderboard_nan_loss_handling() {
        let mut coord = default_coordinator();
        set_latest(&mut coord, 0, meta(0, 1, f32::NAN, Some(f32::NAN)));

        // A NaN loss must not make the leaderboard panic or drop the entry.
        let board = coord.leaderboard();
        assert_eq!(board.len(), 1);
    }
}