Skip to main content

entrenar/gpu/
coordinator.rs

1//! Checkpoint coordination for multi-node adapter training (GPU-SHARE Phase 3, §3.4).
2//!
3//! The coordinator polls each node's checkpoint directory for metadata,
4//! compares val_loss across adapters, maintains a leaderboard, and identifies
5//! the best adapter at end of training.
6//!
7//! Remote nodes are polled via `cat checkpoint_dir/best/metadata.json` over SSH.
8//! Local nodes read the file directly.
9
10use super::cluster::{ClusterConfig, NodeConfig, Transport};
11use super::placement::PlacementDecision;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15
16/// Metadata written by `save_adapter_checkpoint()` in multi_adapter_pipeline.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct CheckpointMetadata {
19    pub adapter_idx: usize,
20    pub epoch: usize,
21    pub avg_loss: f32,
22    #[serde(default)]
23    pub val_loss: Option<f32>,
24    #[serde(default)]
25    pub node_name: Option<String>,
26    #[serde(default)]
27    pub timestamp: Option<String>,
28}
29
30/// Status of a single adapter across the cluster.
31#[derive(Debug, Clone)]
32pub struct AdapterStatus {
33    /// Adapter index.
34    pub adapter_idx: usize,
35    /// Node name where this adapter is running.
36    pub node_name: String,
37    /// Checkpoint directory on the remote node.
38    pub checkpoint_dir: PathBuf,
39    /// Latest checkpoint metadata (if available).
40    pub latest: Option<CheckpointMetadata>,
41}
42
/// One row of the adapter leaderboard, ordered by loss.
#[derive(Clone, Debug)]
pub struct LeaderboardEntry {
    /// 1-based position in the leaderboard (1 is the lowest loss).
    pub rank: usize,
    /// Index of the ranked adapter.
    pub adapter_idx: usize,
    /// Node the adapter is running on.
    pub node_name: String,
    /// Epoch of the checkpoint the loss was taken from.
    pub epoch: usize,
    /// Loss used for ranking: validation loss when present, otherwise average loss.
    pub loss: f32,
}
52
53/// Coordinator for multi-node training checkpoint polling.
54pub struct CheckpointCoordinator {
55    /// Adapter statuses indexed by adapter_idx.
56    pub adapters: HashMap<usize, AdapterStatus>,
57    /// Poll interval in seconds.
58    pub poll_interval_secs: u64,
59    /// Cluster configuration reference.
60    cluster: ClusterConfig,
61}
62
63impl CheckpointCoordinator {
64    /// Create a new coordinator from placement decisions.
65    pub fn new(
66        cluster: ClusterConfig,
67        placements: &[PlacementDecision],
68        checkpoint_dirs: &HashMap<usize, PathBuf>,
69        poll_interval_secs: u64,
70    ) -> Self {
71        let mut adapters = HashMap::new();
72        for p in placements {
73            let checkpoint_dir = checkpoint_dirs
74                .get(&p.adapter_idx)
75                .cloned()
76                .unwrap_or_else(|| PathBuf::from(format!("checkpoints/adapter-{}", p.adapter_idx)));
77            adapters.insert(
78                p.adapter_idx,
79                AdapterStatus {
80                    adapter_idx: p.adapter_idx,
81                    node_name: p.node_name.clone(),
82                    checkpoint_dir,
83                    latest: None,
84                },
85            );
86        }
87        Self { adapters, poll_interval_secs, cluster }
88    }
89
90    /// Poll all adapters for their latest checkpoint metadata.
91    ///
92    /// For local nodes, reads the file directly.
93    /// For SSH nodes, executes `ssh host cat <path>` and parses JSON.
94    pub fn poll_all(&mut self) -> Vec<PollResult> {
95        let mut results = Vec::new();
96        let adapter_list: Vec<(usize, String, PathBuf)> = self
97            .adapters
98            .values()
99            .map(|a| (a.adapter_idx, a.node_name.clone(), a.checkpoint_dir.clone()))
100            .collect();
101
102        for (idx, node_name, checkpoint_dir) in adapter_list {
103            let result = self.poll_adapter(idx, &node_name, &checkpoint_dir);
104            results.push(result);
105        }
106        results
107    }
108
109    fn poll_adapter(
110        &mut self,
111        adapter_idx: usize,
112        node_name: &str,
113        checkpoint_dir: &Path,
114    ) -> PollResult {
115        let node = self.cluster.find_node(node_name);
116        let transport = node.map_or(Transport::Local, |n| n.transport);
117
118        let metadata = match transport {
119            Transport::Local => read_local_metadata(checkpoint_dir),
120            Transport::Ssh => {
121                let host = node.map_or("unknown", |n| &n.host);
122                let user = node.and_then(|n| n.user.as_deref());
123                read_ssh_metadata(host, user, checkpoint_dir)
124            }
125        };
126
127        match metadata {
128            Ok(meta) => {
129                if let Some(status) = self.adapters.get_mut(&adapter_idx) {
130                    status.latest = Some(meta.clone());
131                }
132                PollResult::Ok { adapter_idx, metadata: meta }
133            }
134            Err(e) => PollResult::Error { adapter_idx, node_name: node_name.to_string(), error: e },
135        }
136    }
137
138    /// Generate leaderboard sorted by loss (ascending).
139    pub fn leaderboard(&self) -> Vec<LeaderboardEntry> {
140        let mut entries: Vec<_> = self
141            .adapters
142            .values()
143            .filter_map(|a| {
144                a.latest.as_ref().map(|meta| LeaderboardEntry {
145                    rank: 0,
146                    adapter_idx: a.adapter_idx,
147                    node_name: a.node_name.clone(),
148                    epoch: meta.epoch,
149                    loss: meta.val_loss.unwrap_or(meta.avg_loss),
150                })
151            })
152            .collect();
153
154        entries.sort_by(|a, b| a.loss.partial_cmp(&b.loss).unwrap_or(std::cmp::Ordering::Equal));
155        for (i, entry) in entries.iter_mut().enumerate() {
156            entry.rank = i + 1;
157        }
158        entries
159    }
160
161    /// Find the best adapter (lowest loss).
162    pub fn best_adapter(&self) -> Option<&AdapterStatus> {
163        let board = self.leaderboard();
164        board.first().and_then(|entry| self.adapters.get(&entry.adapter_idx))
165    }
166
167    /// Format leaderboard as a human-readable string.
168    pub fn format_leaderboard(&self) -> String {
169        let board = self.leaderboard();
170        if board.is_empty() {
171            return "No checkpoints available yet.".to_string();
172        }
173        let mut out = String::from("Adapter Leaderboard:\n");
174        out.push_str("  Rank | Adapter | Node       | Epoch | Loss\n");
175        out.push_str("  -----+---------+------------+-------+--------\n");
176        for entry in &board {
177            out.push_str(&format!(
178                "  {:>4} | {:>7} | {:<10} | {:>5} | {:.4}\n",
179                entry.rank, entry.adapter_idx, entry.node_name, entry.epoch, entry.loss
180            ));
181        }
182        out
183    }
184}
185
186/// Result of polling a single adapter's checkpoint.
187#[derive(Debug)]
188pub enum PollResult {
189    Ok { adapter_idx: usize, metadata: CheckpointMetadata },
190    Error { adapter_idx: usize, node_name: String, error: String },
191}
192
193/// Read checkpoint metadata from a local path.
194fn read_local_metadata(checkpoint_dir: &Path) -> Result<CheckpointMetadata, String> {
195    let best_meta = checkpoint_dir.join("best").join("metadata.json");
196    let contents = std::fs::read_to_string(&best_meta)
197        .map_err(|e| format!("failed to read {}: {e}", best_meta.display()))?;
198    serde_json::from_str(&contents)
199        .map_err(|e| format!("failed to parse {}: {e}", best_meta.display()))
200}
201
202/// Read checkpoint metadata from a remote node via SSH.
203///
204/// Executes `ssh [-l user] host cat <checkpoint_dir>/best/metadata.json`
205/// and parses the JSON output. Timeout: 10 seconds.
206fn read_ssh_metadata(
207    host: &str,
208    user: Option<&str>,
209    checkpoint_dir: &Path,
210) -> Result<CheckpointMetadata, String> {
211    let remote_path = checkpoint_dir.join("best").join("metadata.json");
212    let cat_cmd = format!("cat {}", remote_path.display());
213    let output = exec_ssh_command(host, user, &cat_cmd)?;
214    serde_json::from_str(&output).map_err(|e| format!("failed to parse metadata from {host}: {e}"))
215}
216
/// Execute a command on a remote host via SSH.
///
/// Spawns `ssh -o ConnectTimeout=5 -o BatchMode=yes
/// -o StrictHostKeyChecking=accept-new [-l user] host bash` and feeds `script`
/// to the remote `bash` through stdin, so the script text is never passed as
/// an ssh command-line argument. Callers are still responsible for quoting
/// any values they interpolate into `script`.
///
/// Only connection setup is time-bounded (`ConnectTimeout`); the remote
/// command itself is waited on without a deadline.
///
/// # Errors
///
/// Returns an error when `ssh` cannot be spawned, when it exits with a
/// non-zero status (the message includes its stderr), when the script could
/// not be fully delivered to the remote shell, or when stdout is not UTF-8.
fn exec_ssh_command(host: &str, user: Option<&str>, script: &str) -> Result<String, String> {
    use std::io::Write;
    use std::process::{Command, Stdio};

    let mut cmd = Command::new("ssh");
    cmd.args(["-o", "ConnectTimeout=5"]);
    cmd.args(["-o", "BatchMode=yes"]);
    cmd.args(["-o", "StrictHostKeyChecking=accept-new"]);
    if let Some(u) = user {
        cmd.args(["-l", u]);
    }
    cmd.arg(host);
    cmd.arg("bash");
    cmd.stdin(Stdio::piped());
    cmd.stdout(Stdio::piped());
    cmd.stderr(Stdio::piped());

    let mut child = cmd.spawn().map_err(|e| format!("failed to spawn ssh to {host}: {e}"))?;

    // Send the script, then drop the handle so the remote shell sees EOF.
    // The write result is kept rather than discarded: see below.
    let delivery = child.stdin.take().map(|mut stdin| stdin.write_all(script.as_bytes()));

    let output = child.wait_with_output().map_err(|e| format!("ssh to {host} failed: {e}"))?;

    // Check the exit status first: when ssh itself failed, its stderr is a
    // better diagnostic than the broken-pipe error the write would report.
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(format!(
            "ssh to {host} exited {}: {stderr}",
            output.status.code().unwrap_or(-1)
        ));
    }

    // A zero exit status after a failed write means only part of the script
    // ran, so the captured output cannot be trusted.
    if let Some(Err(e)) = delivery {
        return Err(format!("failed to send script to {host}: {e}"));
    }

    String::from_utf8(output.stdout).map_err(|e| format!("invalid UTF-8 from ssh to {host}: {e}"))
}
261
262/// Execute a training job on a remote or local node.
263///
264/// For local nodes, spawns the process directly.
265/// For SSH nodes, pipes the command via stdin to `ssh host bash`.
266///
267/// Returns the child process handle for monitoring.
268pub fn exec_launch(
269    node: &NodeConfig,
270    model_path: &Path,
271    data_path: &Path,
272    checkpoint_dir: &Path,
273    rank: u32,
274    epochs: u32,
275) -> Result<std::process::Child, String> {
276    let script = format!(
277        "apr finetune {} --task instruct --method qlora --quantize-nf4 \
278         --data {} --output {} --rank {rank} --epochs {epochs}",
279        model_path.display(),
280        data_path.display(),
281        checkpoint_dir.display(),
282    );
283
284    match node.transport {
285        Transport::Local => std::process::Command::new("bash")
286            .arg("-c")
287            .arg(&script)
288            .stdin(std::process::Stdio::null())
289            .stdout(std::process::Stdio::piped())
290            .stderr(std::process::Stdio::piped())
291            .spawn()
292            .map_err(|e| format!("failed to launch local training: {e}")),
293        Transport::Ssh => {
294            let mut cmd = std::process::Command::new("ssh");
295            cmd.args(["-o", "ConnectTimeout=5"]);
296            cmd.args(["-o", "BatchMode=yes"]);
297            cmd.args(["-o", "StrictHostKeyChecking=accept-new"]);
298            if let Some(ref u) = node.user {
299                cmd.args(["-l", u]);
300            }
301            cmd.arg(&node.host);
302            cmd.arg("bash");
303            cmd.stdin(std::process::Stdio::piped());
304            cmd.stdout(std::process::Stdio::piped());
305            cmd.stderr(std::process::Stdio::piped());
306
307            let mut child =
308                cmd.spawn().map_err(|e| format!("failed to ssh to {}: {e}", node.host))?;
309
310            if let Some(stdin) = child.stdin.take() {
311                use std::io::Write;
312                let mut stdin = stdin;
313                let _ = stdin.write_all(script.as_bytes());
314            }
315
316            Ok(child)
317        }
318    }
319}
320
321/// Build a remote launch command for an adapter job on a node.
322///
323/// Returns the shell command that would be executed on the remote node
324/// to start training.
325pub fn build_launch_command(
326    node: &NodeConfig,
327    model_path: &Path,
328    data_path: &Path,
329    checkpoint_dir: &Path,
330    rank: u32,
331    epochs: u32,
332) -> String {
333    let base = format!(
334        "apr finetune {} --task instruct --method qlora --quantize-nf4 \
335         --data {} --output {} --rank {rank} --epochs {epochs}",
336        model_path.display(),
337        data_path.display(),
338        checkpoint_dir.display(),
339    );
340
341    match node.transport {
342        Transport::Local => base,
343        Transport::Ssh => {
344            let user_prefix = node.user.as_ref().map_or_else(String::new, |u| format!("{u}@"));
345            format!("ssh {user_prefix}{} '{base}'", node.host)
346        }
347    }
348}
349
/// Outcome of probing a single cluster node.
#[derive(Clone, Debug)]
pub struct NodeHealth {
    /// Name of the node, as given in the cluster config.
    pub node_name: String,
    /// `true` when the probe command could be run on the node (locally or over SSH).
    pub reachable: bool,
    /// Output of `apr --version` on the node, when the CLI was found.
    pub apr_version: Option<String>,
    /// Description of what went wrong, when the probe failed or `apr` was missing.
    pub error: Option<String>,
}
362
363/// Check health of all nodes in a cluster (GPU-SHARE §3.6).
364///
365/// For local nodes, checks that `apr --version` is available.
366/// For SSH nodes, runs `ssh host 'apr --version'` with timeout.
367pub fn check_cluster_health(cluster: &ClusterConfig) -> Vec<NodeHealth> {
368    cluster.nodes.iter().map(check_node_health).collect()
369}
370
371fn check_node_health(node: &NodeConfig) -> NodeHealth {
372    let script = "apr --version 2>/dev/null || echo 'apr: not found'";
373    let result = match node.transport {
374        Transport::Local => std::process::Command::new("bash")
375            .arg("-c")
376            .arg(script)
377            .output()
378            .map_err(|e| format!("failed to check local health: {e}"))
379            .and_then(|out| {
380                String::from_utf8(out.stdout).map_err(|e| format!("invalid UTF-8: {e}"))
381            }),
382        Transport::Ssh => exec_ssh_command(&node.host, node.user.as_deref(), script),
383    };
384
385    match result {
386        Ok(output) => {
387            let trimmed = output.trim().to_string();
388            let has_apr = !trimmed.contains("not found") && !trimmed.is_empty();
389            NodeHealth {
390                node_name: node.name.clone(),
391                reachable: true,
392                apr_version: if has_apr { Some(trimmed) } else { None },
393                error: if has_apr { None } else { Some("apr CLI not found on node".to_string()) },
394            }
395        }
396        Err(e) => NodeHealth {
397            node_name: node.name.clone(),
398            reachable: false,
399            apr_version: None,
400            error: Some(e),
401        },
402    }
403}
404
405impl CheckpointCoordinator {
406    /// Pull the best adapter's checkpoint from its node to a local directory (§3.4).
407    ///
408    /// For local nodes, copies the checkpoint directory.
409    /// For SSH nodes, uses `scp -r` to fetch the checkpoint.
410    ///
411    /// Returns the local path where the checkpoint was saved.
412    pub fn pull_best_checkpoint(&self, dest: &Path) -> Result<PathBuf, String> {
413        let best =
414            self.best_adapter().ok_or_else(|| "no adapters with checkpoint data".to_string())?;
415
416        let node = self
417            .cluster
418            .find_node(&best.node_name)
419            .ok_or_else(|| format!("node '{}' not found in cluster", best.node_name))?;
420
421        let source_dir = best.checkpoint_dir.join("best");
422        let dest_dir = dest.join(format!("adapter-{}-best", best.adapter_idx));
423
424        match node.transport {
425            Transport::Local => {
426                copy_dir_recursive(&source_dir, &dest_dir)?;
427                Ok(dest_dir)
428            }
429            Transport::Ssh => {
430                std::fs::create_dir_all(&dest_dir)
431                    .map_err(|e| format!("failed to create {}: {e}", dest_dir.display()))?;
432                let user_prefix = node.user.as_ref().map_or_else(String::new, |u| format!("{u}@"));
433                let remote = format!("{user_prefix}{}:{}/", node.host, source_dir.display());
434                let output = std::process::Command::new("scp")
435                    .args(["-r", "-o", "ConnectTimeout=10", "-o", "BatchMode=yes"])
436                    .arg(&remote)
437                    .arg(dest_dir.to_str().unwrap_or("."))
438                    .output()
439                    .map_err(|e| format!("failed to run scp: {e}"))?;
440
441                if !output.status.success() {
442                    let stderr = String::from_utf8_lossy(&output.stderr);
443                    return Err(format!("scp failed: {stderr}"));
444                }
445                Ok(dest_dir)
446            }
447        }
448    }
449}
450
/// Copy the directory tree rooted at `src` into `dst`, creating `dst` as needed.
///
/// Sub-directories are copied depth-first; every other entry is copied with
/// `std::fs::copy`, overwriting a file of the same name already in `dst`.
///
/// # Errors
///
/// Stops at the first failure and returns a message describing which
/// create, read or copy step failed.
fn copy_dir_recursive(src: &Path, dst: &Path) -> Result<(), String> {
    if let Err(e) = std::fs::create_dir_all(dst) {
        return Err(format!("failed to create {}: {e}", dst.display()));
    }

    let listing = match std::fs::read_dir(src) {
        Ok(listing) => listing,
        Err(e) => return Err(format!("failed to read {}: {e}", src.display())),
    };

    for item in listing {
        let item = item.map_err(|e| format!("failed to read entry: {e}"))?;
        let from = item.path();
        let to = dst.join(item.file_name());

        if from.is_dir() {
            copy_dir_recursive(&from, &to)?;
        } else if let Err(e) = std::fs::copy(&from, &to) {
            return Err(format!("failed to copy {}: {e}", from.display()));
        }
    }

    Ok(())
}
468
469#[cfg(test)]
470mod tests {
471    #![allow(clippy::unwrap_used)]
472    use super::*;
473    use std::fs;
474
475    fn test_cluster() -> ClusterConfig {
476        ClusterConfig::from_yaml(
477            r"
478nodes:
479  - name: desktop
480    host: localhost
481    gpus:
482      - uuid: GPU-abcd-1234
483        type: rtx-4090
484        vram_mb: 24564
485    max_adapters: 3
486  - name: jetson
487    host: jetson.local
488    transport: ssh
489    gpus:
490      - uuid: GPU-efgh-5678
491        type: jetson-orin
492        vram_mb: 8192
493        memory_type: unified
494    max_adapters: 1
495",
496        )
497        .expect("valid")
498    }
499
500    fn test_placements() -> Vec<PlacementDecision> {
501        vec![
502            PlacementDecision { adapter_idx: 0, node_name: "desktop".to_string(), score: 2.5 },
503            PlacementDecision { adapter_idx: 1, node_name: "desktop".to_string(), score: 1.2 },
504            PlacementDecision { adapter_idx: 2, node_name: "jetson".to_string(), score: 0.3 },
505        ]
506    }
507
508    #[test]
509    fn test_coordinator_creation() {
510        let cluster = test_cluster();
511        let placements = test_placements();
512        let dirs = HashMap::new();
513        let coord = CheckpointCoordinator::new(cluster, &placements, &dirs, 300);
514        assert_eq!(coord.adapters.len(), 3);
515        assert_eq!(coord.poll_interval_secs, 300);
516    }
517
518    #[test]
519    fn test_empty_leaderboard() {
520        let cluster = test_cluster();
521        let coord = CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
522        let board = coord.leaderboard();
523        assert!(board.is_empty());
524        assert!(coord.best_adapter().is_none());
525    }
526
527    #[test]
528    fn test_leaderboard_with_data() {
529        let cluster = test_cluster();
530        let mut coord =
531            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
532
533        // Manually set latest metadata
534        coord.adapters.get_mut(&0).expect("valid").latest = Some(CheckpointMetadata {
535            adapter_idx: 0,
536            epoch: 3,
537            avg_loss: 0.5,
538            val_loss: Some(0.45),
539            node_name: Some("desktop".to_string()),
540            timestamp: None,
541        });
542        coord.adapters.get_mut(&1).expect("valid").latest = Some(CheckpointMetadata {
543            adapter_idx: 1,
544            epoch: 3,
545            avg_loss: 0.8,
546            val_loss: Some(0.75),
547            node_name: Some("desktop".to_string()),
548            timestamp: None,
549        });
550        coord.adapters.get_mut(&2).expect("valid").latest = Some(CheckpointMetadata {
551            adapter_idx: 2,
552            epoch: 2,
553            avg_loss: 0.3,
554            val_loss: Some(0.28),
555            node_name: Some("jetson".to_string()),
556            timestamp: None,
557        });
558
559        let board = coord.leaderboard();
560        assert_eq!(board.len(), 3);
561        assert_eq!(board[0].adapter_idx, 2); // 0.28 lowest
562        assert_eq!(board[0].rank, 1);
563        assert_eq!(board[1].adapter_idx, 0); // 0.45
564        assert_eq!(board[2].adapter_idx, 1); // 0.75
565
566        let best = coord.best_adapter().expect("valid");
567        assert_eq!(best.adapter_idx, 2);
568    }
569
570    #[test]
571    fn test_poll_local_checkpoint() {
572        let dir = tempfile::tempdir().expect("valid");
573        let best_dir = dir.path().join("best");
574        fs::create_dir_all(&best_dir).expect("valid");
575        let meta = CheckpointMetadata {
576            adapter_idx: 0,
577            epoch: 5,
578            avg_loss: 0.42,
579            val_loss: Some(0.39),
580            node_name: Some("desktop".to_string()),
581            timestamp: None,
582        };
583        fs::write(best_dir.join("metadata.json"), serde_json::to_string(&meta).expect("valid"))
584            .expect("valid");
585
586        let cluster = test_cluster();
587        let placements = vec![PlacementDecision {
588            adapter_idx: 0,
589            node_name: "desktop".to_string(),
590            score: 2.5,
591        }];
592        let mut dirs = HashMap::new();
593        dirs.insert(0, dir.path().to_path_buf());
594
595        let mut coord = CheckpointCoordinator::new(cluster, &placements, &dirs, 300);
596        let results = coord.poll_all();
597
598        assert_eq!(results.len(), 1);
599        match &results[0] {
600            PollResult::Ok { adapter_idx, metadata } => {
601                assert_eq!(*adapter_idx, 0);
602                assert_eq!(metadata.epoch, 5);
603                assert!((metadata.avg_loss - 0.42).abs() < f32::EPSILON);
604            }
605            PollResult::Error { error, .. } => panic!("unexpected error: {error}"),
606        }
607    }
608
609    #[test]
610    fn test_poll_ssh_attempts_real_ssh() {
611        // SSH poll now attempts real SSH (will fail on missing host, not with stub error)
612        let cluster = test_cluster();
613        let placements =
614            vec![PlacementDecision { adapter_idx: 2, node_name: "jetson".to_string(), score: 0.3 }];
615        let mut coord = CheckpointCoordinator::new(cluster, &placements, &HashMap::new(), 300);
616        let results = coord.poll_all();
617
618        assert_eq!(results.len(), 1);
619        match &results[0] {
620            PollResult::Error { error, .. } => {
621                // Real SSH errors: connection refused, host unreachable, etc.
622                // Must NOT contain the old stub message
623                assert!(
624                    !error.contains("not yet available"),
625                    "SSH transport must not be stubbed: {error}"
626                );
627            }
628            PollResult::Ok { .. } => {
629                // If SSH host happens to be reachable (unlikely in CI), that's fine
630            }
631        }
632    }
633
634    #[test]
635    fn test_exec_ssh_command_unreachable_host() {
636        // Verify exec_ssh_command returns a real SSH error, not a stub
637        let result = exec_ssh_command("192.0.2.1", Some("nobody"), "echo test");
638        assert!(result.is_err());
639        let err = result.unwrap_err();
640        // Should be a real SSH/network error
641        assert!(
642            err.contains("ssh")
643                || err.contains("Connection")
644                || err.contains("timed out")
645                || err.contains("refused")
646                || err.contains("resolve")
647                || err.contains("No route")
648                || err.contains("exited"),
649            "expected real SSH error, got: {err}"
650        );
651    }
652
653    #[test]
654    fn test_exec_ssh_command_builds_correct_args() {
655        // Verify the SSH command with user sets -l flag
656        // This tests the Command construction path (will fail to connect but that's expected)
657        let result = exec_ssh_command("192.0.2.1", Some("testuser"), "echo hello");
658        assert!(result.is_err()); // Expected: can't connect to RFC 5737 test address
659    }
660
661    #[test]
662    fn test_format_leaderboard() {
663        let cluster = test_cluster();
664        let mut coord =
665            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
666        coord.adapters.get_mut(&0).expect("valid").latest = Some(CheckpointMetadata {
667            adapter_idx: 0,
668            epoch: 2,
669            avg_loss: 0.5,
670            val_loss: None,
671            node_name: None,
672            timestamp: None,
673        });
674
675        let display = coord.format_leaderboard();
676        assert!(display.contains("Adapter Leaderboard"));
677        assert!(display.contains("0.5000"));
678    }
679
680    #[test]
681    fn test_build_launch_command_local() {
682        let node = NodeConfig {
683            name: "desktop".to_string(),
684            host: "localhost".to_string(),
685            transport: Transport::Local,
686            user: None,
687            gpus: vec![],
688            max_adapters: 1,
689            cpu_cores: None,
690            ram_mb: None,
691        };
692        let cmd = build_launch_command(
693            &node,
694            Path::new("model.apr"),
695            Path::new("data.jsonl"),
696            Path::new("/tmp/ckpt"),
697            16,
698            3,
699        );
700        assert!(cmd.starts_with("apr finetune model.apr"));
701        assert!(cmd.contains("--rank 16"));
702        assert!(cmd.contains("--epochs 3"));
703        assert!(!cmd.contains("ssh"));
704    }
705
706    #[test]
707    fn test_exec_launch_local() {
708        // exec_launch for local node should spawn a bash process
709        let node = NodeConfig {
710            name: "test".to_string(),
711            host: "localhost".to_string(),
712            transport: Transport::Local,
713            user: None,
714            gpus: vec![],
715            max_adapters: 1,
716            cpu_cores: None,
717            ram_mb: None,
718        };
719        // Use a command that will fail fast (no real apr binary needed)
720        let result = exec_launch(
721            &node,
722            Path::new("/nonexistent/model.apr"),
723            Path::new("/nonexistent/data.jsonl"),
724            Path::new("/tmp/test-ckpt"),
725            16,
726            1,
727        );
728        // Should successfully spawn even if the command fails
729        assert!(result.is_ok(), "local exec_launch should spawn: {:?}", result.err());
730        let mut child = result.expect("valid");
731        let _ = child.kill(); // Clean up
732        let _ = child.wait(); // Reap zombie
733    }
734
735    #[test]
736    fn test_build_launch_command_ssh() {
737        let node = NodeConfig {
738            name: "jetson".to_string(),
739            host: "jetson.local".to_string(),
740            transport: Transport::Ssh,
741            user: Some("noah".to_string()),
742            gpus: vec![],
743            max_adapters: 1,
744            cpu_cores: None,
745            ram_mb: None,
746        };
747        let cmd = build_launch_command(
748            &node,
749            Path::new("model.apr"),
750            Path::new("data.jsonl"),
751            Path::new("/tmp/ckpt"),
752            16,
753            3,
754        );
755        assert!(cmd.starts_with("ssh noah@jetson.local"));
756        assert!(cmd.contains("apr finetune model.apr"));
757    }
758
759    #[test]
760    fn test_check_node_health_local() {
761        let node = NodeConfig {
762            name: "local".to_string(),
763            host: "localhost".to_string(),
764            transport: Transport::Local,
765            user: None,
766            gpus: vec![],
767            max_adapters: 1,
768            cpu_cores: None,
769            ram_mb: None,
770        };
771        let health = check_node_health(&node);
772        assert_eq!(health.node_name, "local");
773        assert!(health.reachable);
774        // apr may or may not be installed — just verify the check ran
775    }
776
777    #[test]
778    fn test_check_cluster_health() {
779        let cluster = test_cluster();
780        let results = check_cluster_health(&cluster);
781        assert_eq!(results.len(), 2); // desktop + jetson
782        assert_eq!(results[0].node_name, "desktop");
783        assert!(results[0].reachable); // local node should be reachable
784    }
785
786    #[test]
787    fn test_pull_best_checkpoint_local() {
788        let dir = tempfile::tempdir().expect("valid");
789        let ckpt_dir = dir.path().join("adapter-0");
790        let best_dir = ckpt_dir.join("best");
791        fs::create_dir_all(&best_dir).expect("valid");
792
793        // Write metadata + a model file
794        let meta = CheckpointMetadata {
795            adapter_idx: 0,
796            epoch: 3,
797            avg_loss: 0.35,
798            val_loss: Some(0.30),
799            node_name: Some("desktop".to_string()),
800            timestamp: None,
801        };
802        fs::write(best_dir.join("metadata.json"), serde_json::to_string(&meta).expect("valid"))
803            .expect("valid");
804        fs::write(best_dir.join("adapter.safetensors"), b"fake-weights").expect("valid");
805
806        let cluster = test_cluster();
807        let placements = vec![PlacementDecision {
808            adapter_idx: 0,
809            node_name: "desktop".to_string(),
810            score: 2.5,
811        }];
812        let mut dirs = HashMap::new();
813        dirs.insert(0, ckpt_dir);
814
815        let mut coord = CheckpointCoordinator::new(cluster, &placements, &dirs, 300);
816        // Poll to get metadata
817        let _ = coord.poll_all();
818
819        let dest = tempfile::tempdir().expect("valid");
820        let result = coord.pull_best_checkpoint(dest.path());
821        assert!(result.is_ok(), "pull should succeed: {:?}", result.err());
822
823        let pulled = result.expect("valid");
824        assert!(pulled.join("metadata.json").exists());
825        assert!(pulled.join("adapter.safetensors").exists());
826    }
827
828    #[test]
829    fn test_pull_no_checkpoints_fails() {
830        let cluster = test_cluster();
831        let coord = CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
832        let dest = tempfile::tempdir().expect("valid");
833        let result = coord.pull_best_checkpoint(dest.path());
834        assert!(result.is_err());
835        assert!(result.expect_err("should fail").contains("no adapters"));
836    }
837
838    #[test]
839    fn test_copy_dir_recursive() {
840        let src = tempfile::tempdir().expect("valid");
841        let sub = src.path().join("subdir");
842        fs::create_dir_all(&sub).expect("valid");
843        fs::write(src.path().join("a.txt"), "hello").expect("valid");
844        fs::write(sub.join("b.txt"), "world").expect("valid");
845
846        let dst = tempfile::tempdir().expect("valid");
847        let dst_path = dst.path().join("copy");
848        copy_dir_recursive(src.path(), &dst_path).expect("valid");
849
850        assert!(dst_path.join("a.txt").exists());
851        assert!(dst_path.join("subdir").join("b.txt").exists());
852        assert_eq!(fs::read_to_string(dst_path.join("a.txt")).expect("valid"), "hello");
853    }
854
855    // ── Additional coverage tests ──
856
857    #[test]
858    fn test_checkpoint_metadata_serde_roundtrip() {
859        let meta = CheckpointMetadata {
860            adapter_idx: 3,
861            epoch: 10,
862            avg_loss: 0.123,
863            val_loss: Some(0.099),
864            node_name: Some("gpu-node-1".to_string()),
865            timestamp: Some("2026-03-08T12:00:00Z".to_string()),
866        };
867        let json = serde_json::to_string(&meta).unwrap();
868        let restored: CheckpointMetadata = serde_json::from_str(&json).unwrap();
869        assert_eq!(restored.adapter_idx, 3);
870        assert_eq!(restored.epoch, 10);
871        assert!((restored.avg_loss - 0.123).abs() < f32::EPSILON);
872        assert!((restored.val_loss.unwrap() - 0.099).abs() < f32::EPSILON);
873        assert_eq!(restored.node_name.unwrap(), "gpu-node-1");
874        assert_eq!(restored.timestamp.unwrap(), "2026-03-08T12:00:00Z");
875    }
876
877    #[test]
878    fn test_checkpoint_metadata_serde_defaults() {
879        // Test deserialization with missing optional fields
880        let json = r#"{"adapter_idx":0,"epoch":1,"avg_loss":0.5}"#;
881        let meta: CheckpointMetadata = serde_json::from_str(json).unwrap();
882        assert_eq!(meta.adapter_idx, 0);
883        assert!(meta.val_loss.is_none());
884        assert!(meta.node_name.is_none());
885        assert!(meta.timestamp.is_none());
886    }
887
888    #[test]
889    fn test_coordinator_custom_checkpoint_dirs() {
890        let cluster = test_cluster();
891        let placements = test_placements();
892        let mut dirs = HashMap::new();
893        dirs.insert(0, PathBuf::from("/custom/path/adapter-0"));
894        dirs.insert(2, PathBuf::from("/custom/path/adapter-2"));
895
896        let coord = CheckpointCoordinator::new(cluster, &placements, &dirs, 600);
897        assert_eq!(coord.adapters[&0].checkpoint_dir, PathBuf::from("/custom/path/adapter-0"));
898        assert_eq!(coord.adapters[&2].checkpoint_dir, PathBuf::from("/custom/path/adapter-2"));
899        // Adapter 1 should have auto-generated path
900        assert_eq!(coord.adapters[&1].checkpoint_dir, PathBuf::from("checkpoints/adapter-1"));
901    }
902
903    #[test]
904    fn test_coordinator_default_checkpoint_dirs() {
905        let cluster = test_cluster();
906        let placements = test_placements();
907        let dirs = HashMap::new();
908
909        let coord = CheckpointCoordinator::new(cluster, &placements, &dirs, 300);
910        for p in &placements {
911            let expected = PathBuf::from(format!("checkpoints/adapter-{}", p.adapter_idx));
912            assert_eq!(coord.adapters[&p.adapter_idx].checkpoint_dir, expected);
913        }
914    }
915
916    #[test]
917    fn test_leaderboard_uses_val_loss_when_available() {
918        let cluster = test_cluster();
919        let mut coord =
920            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
921
922        coord.adapters.get_mut(&0).unwrap().latest = Some(CheckpointMetadata {
923            adapter_idx: 0,
924            epoch: 5,
925            avg_loss: 1.0,
926            val_loss: Some(0.5), // val_loss should be used
927            node_name: None,
928            timestamp: None,
929        });
930
931        let board = coord.leaderboard();
932        assert_eq!(board.len(), 1);
933        assert!((board[0].loss - 0.5).abs() < f32::EPSILON);
934    }
935
936    #[test]
937    fn test_leaderboard_falls_back_to_avg_loss() {
938        let cluster = test_cluster();
939        let mut coord =
940            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
941
942        coord.adapters.get_mut(&0).unwrap().latest = Some(CheckpointMetadata {
943            adapter_idx: 0,
944            epoch: 5,
945            avg_loss: 1.0,
946            val_loss: None, // no val_loss, should fallback to avg_loss
947            node_name: None,
948            timestamp: None,
949        });
950
951        let board = coord.leaderboard();
952        assert_eq!(board.len(), 1);
953        assert!((board[0].loss - 1.0).abs() < f32::EPSILON);
954    }
955
956    #[test]
957    fn test_leaderboard_ranking_three_adapters() {
958        let cluster = test_cluster();
959        let mut coord =
960            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
961
962        // Set losses: adapter 0 = 0.7, adapter 1 = 0.3, adapter 2 = 0.5
963        coord.adapters.get_mut(&0).unwrap().latest = Some(CheckpointMetadata {
964            adapter_idx: 0,
965            epoch: 2,
966            avg_loss: 0.7,
967            val_loss: None,
968            node_name: None,
969            timestamp: None,
970        });
971        coord.adapters.get_mut(&1).unwrap().latest = Some(CheckpointMetadata {
972            adapter_idx: 1,
973            epoch: 3,
974            avg_loss: 0.3,
975            val_loss: None,
976            node_name: None,
977            timestamp: None,
978        });
979        coord.adapters.get_mut(&2).unwrap().latest = Some(CheckpointMetadata {
980            adapter_idx: 2,
981            epoch: 1,
982            avg_loss: 0.5,
983            val_loss: None,
984            node_name: None,
985            timestamp: None,
986        });
987
988        let board = coord.leaderboard();
989        assert_eq!(board.len(), 3);
990        // Ranked by loss ascending
991        assert_eq!(board[0].adapter_idx, 1); // 0.3
992        assert_eq!(board[0].rank, 1);
993        assert_eq!(board[1].adapter_idx, 2); // 0.5
994        assert_eq!(board[1].rank, 2);
995        assert_eq!(board[2].adapter_idx, 0); // 0.7
996        assert_eq!(board[2].rank, 3);
997
998        // best_adapter returns lowest loss
999        let best = coord.best_adapter().unwrap();
1000        assert_eq!(best.adapter_idx, 1);
1001    }
1002
1003    #[test]
1004    fn test_format_leaderboard_empty() {
1005        let cluster = test_cluster();
1006        let coord = CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
1007        let display = coord.format_leaderboard();
1008        assert_eq!(display, "No checkpoints available yet.");
1009    }
1010
1011    #[test]
1012    fn test_format_leaderboard_multiple_entries() {
1013        let cluster = test_cluster();
1014        let mut coord =
1015            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
1016
1017        coord.adapters.get_mut(&0).unwrap().latest = Some(CheckpointMetadata {
1018            adapter_idx: 0,
1019            epoch: 3,
1020            avg_loss: 0.5,
1021            val_loss: Some(0.45),
1022            node_name: None,
1023            timestamp: None,
1024        });
1025        coord.adapters.get_mut(&1).unwrap().latest = Some(CheckpointMetadata {
1026            adapter_idx: 1,
1027            epoch: 2,
1028            avg_loss: 0.8,
1029            val_loss: Some(0.75),
1030            node_name: None,
1031            timestamp: None,
1032        });
1033
1034        let display = coord.format_leaderboard();
1035        assert!(display.contains("Adapter Leaderboard"));
1036        assert!(display.contains("Rank"));
1037        assert!(display.contains("Adapter"));
1038        assert!(display.contains("Node"));
1039        assert!(display.contains("Epoch"));
1040        assert!(display.contains("Loss"));
1041        assert!(display.contains("0.4500"));
1042        assert!(display.contains("0.7500"));
1043    }
1044
1045    #[test]
1046    fn test_adapter_status_fields() {
1047        let cluster = test_cluster();
1048        let placements = test_placements();
1049        let coord = CheckpointCoordinator::new(cluster, &placements, &HashMap::new(), 300);
1050
1051        let status = &coord.adapters[&0];
1052        assert_eq!(status.adapter_idx, 0);
1053        assert_eq!(status.node_name, "desktop");
1054        assert!(status.latest.is_none());
1055    }
1056
1057    #[test]
1058    fn test_build_launch_command_ssh_no_user() {
1059        let node = NodeConfig {
1060            name: "remote".to_string(),
1061            host: "gpu-server.example.com".to_string(),
1062            transport: Transport::Ssh,
1063            user: None,
1064            gpus: vec![],
1065            max_adapters: 1,
1066            cpu_cores: None,
1067            ram_mb: None,
1068        };
1069        let cmd = build_launch_command(
1070            &node,
1071            Path::new("/models/qwen.safetensors"),
1072            Path::new("/data/train.jsonl"),
1073            Path::new("/checkpoints"),
1074            32,
1075            5,
1076        );
1077        // Without user, no user@ prefix
1078        assert!(cmd.starts_with("ssh gpu-server.example.com"));
1079        assert!(cmd.contains("apr finetune"));
1080        assert!(cmd.contains("--rank 32"));
1081        assert!(cmd.contains("--epochs 5"));
1082    }
1083
1084    #[test]
1085    fn test_poll_result_debug_ok() {
1086        let result = PollResult::Ok {
1087            adapter_idx: 0,
1088            metadata: CheckpointMetadata {
1089                adapter_idx: 0,
1090                epoch: 1,
1091                avg_loss: 0.5,
1092                val_loss: None,
1093                node_name: None,
1094                timestamp: None,
1095            },
1096        };
1097        let debug = format!("{result:?}");
1098        assert!(debug.contains("Ok"));
1099    }
1100
1101    #[test]
1102    fn test_poll_result_debug_error() {
1103        let result = PollResult::Error {
1104            adapter_idx: 1,
1105            node_name: "node1".to_string(),
1106            error: "connection refused".to_string(),
1107        };
1108        let debug = format!("{result:?}");
1109        assert!(debug.contains("Error"));
1110        assert!(debug.contains("connection refused"));
1111    }
1112
1113    #[test]
1114    fn test_node_health_debug() {
1115        let health = NodeHealth {
1116            node_name: "test-node".to_string(),
1117            reachable: true,
1118            apr_version: Some("1.0.0".to_string()),
1119            error: None,
1120        };
1121        let debug = format!("{health:?}");
1122        assert!(debug.contains("test-node"));
1123        assert!(debug.contains("1.0.0"));
1124    }
1125
1126    #[test]
1127    fn test_node_health_unreachable() {
1128        let health = NodeHealth {
1129            node_name: "offline-node".to_string(),
1130            reachable: false,
1131            apr_version: None,
1132            error: Some("connection timeout".to_string()),
1133        };
1134        assert!(!health.reachable);
1135        assert!(health.error.is_some());
1136        assert!(health.apr_version.is_none());
1137    }
1138
1139    #[test]
1140    fn test_copy_dir_recursive_nested() {
1141        let src = tempfile::tempdir().unwrap();
1142        let deep = src.path().join("a").join("b").join("c");
1143        fs::create_dir_all(&deep).unwrap();
1144        fs::write(deep.join("deep.txt"), "deep content").unwrap();
1145        fs::write(src.path().join("root.txt"), "root").unwrap();
1146
1147        let dst = tempfile::tempdir().unwrap();
1148        let dst_path = dst.path().join("output");
1149        copy_dir_recursive(src.path(), &dst_path).unwrap();
1150
1151        assert!(dst_path.join("root.txt").exists());
1152        assert!(dst_path.join("a").join("b").join("c").join("deep.txt").exists());
1153        assert_eq!(
1154            fs::read_to_string(dst_path.join("a").join("b").join("c").join("deep.txt")).unwrap(),
1155            "deep content"
1156        );
1157    }
1158
1159    #[test]
1160    fn test_copy_dir_recursive_nonexistent_src() {
1161        let dst = tempfile::tempdir().unwrap();
1162        let result = copy_dir_recursive(Path::new("/nonexistent/path"), &dst.path().join("out"));
1163        assert!(result.is_err());
1164    }
1165
1166    #[test]
1167    fn test_read_local_metadata_missing_file() {
1168        let dir = tempfile::tempdir().unwrap();
1169        let result = read_local_metadata(dir.path());
1170        assert!(result.is_err());
1171        let err = result.unwrap_err();
1172        assert!(err.contains("failed to read"));
1173    }
1174
1175    #[test]
1176    fn test_read_local_metadata_invalid_json() {
1177        let dir = tempfile::tempdir().unwrap();
1178        let best_dir = dir.path().join("best");
1179        fs::create_dir_all(&best_dir).unwrap();
1180        fs::write(best_dir.join("metadata.json"), "not valid json").unwrap();
1181        let result = read_local_metadata(dir.path());
1182        assert!(result.is_err());
1183        let err = result.unwrap_err();
1184        assert!(err.contains("failed to parse"));
1185    }
1186
1187    #[test]
1188    fn test_read_local_metadata_valid() {
1189        let dir = tempfile::tempdir().unwrap();
1190        let best_dir = dir.path().join("best");
1191        fs::create_dir_all(&best_dir).unwrap();
1192        let meta = CheckpointMetadata {
1193            adapter_idx: 42,
1194            epoch: 7,
1195            avg_loss: 0.123,
1196            val_loss: Some(0.111),
1197            node_name: Some("gpu-0".to_string()),
1198            timestamp: Some("2026-01-01".to_string()),
1199        };
1200        fs::write(best_dir.join("metadata.json"), serde_json::to_string(&meta).unwrap()).unwrap();
1201        let result = read_local_metadata(dir.path()).unwrap();
1202        assert_eq!(result.adapter_idx, 42);
1203        assert_eq!(result.epoch, 7);
1204    }
1205
1206    #[test]
1207    fn test_poll_all_updates_adapter_status() {
1208        let dir = tempfile::tempdir().unwrap();
1209        let best_dir = dir.path().join("best");
1210        fs::create_dir_all(&best_dir).unwrap();
1211        let meta = CheckpointMetadata {
1212            adapter_idx: 0,
1213            epoch: 10,
1214            avg_loss: 0.22,
1215            val_loss: Some(0.19),
1216            node_name: None,
1217            timestamp: None,
1218        };
1219        fs::write(best_dir.join("metadata.json"), serde_json::to_string(&meta).unwrap()).unwrap();
1220
1221        let cluster = test_cluster();
1222        let placements = vec![PlacementDecision {
1223            adapter_idx: 0,
1224            node_name: "desktop".to_string(),
1225            score: 1.0,
1226        }];
1227        let mut dirs = HashMap::new();
1228        dirs.insert(0, dir.path().to_path_buf());
1229        let mut coord = CheckpointCoordinator::new(cluster, &placements, &dirs, 60);
1230
1231        // Before poll, latest is None
1232        assert!(coord.adapters[&0].latest.is_none());
1233
1234        let results = coord.poll_all();
1235        assert_eq!(results.len(), 1);
1236
1237        // After poll, latest should be populated
1238        let latest = coord.adapters[&0].latest.as_ref().unwrap();
1239        assert_eq!(latest.epoch, 10);
1240        assert!((latest.avg_loss - 0.22).abs() < f32::EPSILON);
1241    }
1242
1243    #[test]
1244    fn test_leaderboard_entry_fields() {
1245        let entry = LeaderboardEntry {
1246            rank: 1,
1247            adapter_idx: 5,
1248            node_name: "gpu-node".to_string(),
1249            epoch: 15,
1250            loss: 0.123,
1251        };
1252        assert_eq!(entry.rank, 1);
1253        assert_eq!(entry.adapter_idx, 5);
1254        assert_eq!(entry.node_name, "gpu-node");
1255        assert_eq!(entry.epoch, 15);
1256        assert!((entry.loss - 0.123).abs() < f32::EPSILON);
1257    }
1258
1259    #[test]
1260    fn test_leaderboard_nan_loss_handling() {
1261        let cluster = test_cluster();
1262        let mut coord =
1263            CheckpointCoordinator::new(cluster, &test_placements(), &HashMap::new(), 300);
1264
1265        // NaN in val_loss should still produce a leaderboard entry
1266        coord.adapters.get_mut(&0).unwrap().latest = Some(CheckpointMetadata {
1267            adapter_idx: 0,
1268            epoch: 1,
1269            avg_loss: f32::NAN,
1270            val_loss: Some(f32::NAN),
1271            node_name: None,
1272            timestamp: None,
1273        });
1274
1275        let board = coord.leaderboard();
1276        assert_eq!(board.len(), 1);
1277    }
1278}