Skip to main content

entrenar/gpu/
cluster.rs

1//! Cluster configuration for multi-node GPU training (GPU-SHARE Phase 3, §3.2).
2//!
3//! Parses `cluster.yaml` files describing heterogeneous training clusters
4//! with mixed GPU types (RTX 4090, Jetson, CPU-only nodes).
5//!
6//! # Example
7//!
8//! ```yaml
9//! nodes:
10//!   - name: desktop
11//!     host: localhost
12//!     gpus:
13//!       - uuid: GPU-abcd-1234
14//!         type: rtx-4090
15//!         vram_mb: 24564
16//!         memory_type: discrete
17//!     max_adapters: 3
18//! ```
19
20use serde::{Deserialize, Serialize};
21use std::collections::HashSet;
22use std::path::Path;
23
24/// Top-level cluster configuration.
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ClusterConfig {
27    /// Cluster nodes (at least one required).
28    pub nodes: Vec<NodeConfig>,
29}
30
31/// Configuration for a single training node.
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct NodeConfig {
34    /// Human-readable node name (must be unique within cluster).
35    pub name: String,
36    /// Hostname or IP address.
37    pub host: String,
38    /// Transport method for remote nodes.
39    #[serde(default)]
40    pub transport: Transport,
41    /// SSH user for remote nodes (defaults to current user).
42    #[serde(default)]
43    pub user: Option<String>,
44    /// GPUs available on this node (empty = CPU-only).
45    #[serde(default)]
46    pub gpus: Vec<GpuConfig>,
47    /// Maximum number of concurrent adapters on this node.
48    #[serde(default = "default_max_adapters")]
49    pub max_adapters: usize,
50    /// CPU cores available (for CPU-only nodes).
51    #[serde(default)]
52    pub cpu_cores: Option<u32>,
53    /// RAM in MB (for CPU-only nodes).
54    #[serde(default)]
55    pub ram_mb: Option<u64>,
56}
57
58/// Transport method for connecting to a node.
59#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
60#[serde(rename_all = "lowercase")]
61#[derive(Default)]
62pub enum Transport {
63    /// Local node (no transport needed).
64    #[default]
65    Local,
66    /// SSH transport via forjar.
67    Ssh,
68}
69
70/// Configuration for a single GPU on a node.
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct GpuConfig {
73    /// GPU UUID from nvidia-smi (e.g., GPU-abcd-1234).
74    pub uuid: String,
75    /// GPU type identifier (e.g., rtx-4090, jetson-orin).
76    #[serde(rename = "type")]
77    pub gpu_type: String,
78    /// Total VRAM in MB.
79    pub vram_mb: u64,
80    /// Memory architecture (affects reserve factor).
81    #[serde(default)]
82    pub memory_type: MemoryType,
83}
84
85/// GPU memory architecture.
86#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
87#[serde(rename_all = "lowercase")]
88#[derive(Default)]
89pub enum MemoryType {
90    /// Discrete GPU memory (85% reserve factor).
91    #[default]
92    Discrete,
93    /// Unified memory shared with CPU (60% reserve factor).
94    Unified,
95}
96
97impl MemoryType {
98    /// Reserve factor: fraction of VRAM usable for training.
99    #[must_use]
100    pub fn reserve_factor(self) -> f32 {
101        match self {
102            Self::Discrete => 0.85,
103            Self::Unified => 0.60,
104        }
105    }
106}
107
/// Serde default for `NodeConfig::max_adapters`: one adapter per node when the key is absent.
const fn default_max_adapters() -> usize {
    1
}
111
112/// Cluster configuration validation errors.
113#[derive(Debug, thiserror::Error)]
114pub enum ClusterValidationError {
115    #[error("cluster must have at least one node")]
116    NoNodes,
117    #[error("duplicate node name: {0}")]
118    DuplicateNodeName(String),
119    #[error("node '{name}': max_adapters must be >= 1")]
120    ZeroMaxAdapters { name: String },
121    #[error("node '{node}': GPU '{uuid}' has zero VRAM")]
122    ZeroVram { node: String, uuid: String },
123    #[error("node '{node}': duplicate GPU UUID '{uuid}'")]
124    DuplicateGpuUuid { node: String, uuid: String },
125    #[error("node '{node}': SSH transport requires a host other than localhost")]
126    SshLocalhost { node: String },
127}
128
129impl ClusterConfig {
130    /// Load cluster config from a YAML file.
131    ///
132    /// # Errors
133    /// Returns error if file cannot be read or parsed, or if validation fails.
134    pub fn from_file(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
135        let contents = std::fs::read_to_string(path)?;
136        let config: Self = serde_yaml::from_str(&contents)?;
137        config.validate()?;
138        Ok(config)
139    }
140
141    /// Parse cluster config from a YAML string.
142    ///
143    /// # Errors
144    /// Returns error if parsing or validation fails.
145    pub fn from_yaml(yaml: &str) -> Result<Self, Box<dyn std::error::Error>> {
146        let config: Self = serde_yaml::from_str(yaml)?;
147        config.validate()?;
148        Ok(config)
149    }
150
151    /// Validate cluster configuration.
152    ///
153    /// # Errors
154    /// Returns the first validation error found.
155    pub fn validate(&self) -> Result<(), ClusterValidationError> {
156        if self.nodes.is_empty() {
157            return Err(ClusterValidationError::NoNodes);
158        }
159
160        let mut names = HashSet::new();
161        for node in &self.nodes {
162            if !names.insert(&node.name) {
163                return Err(ClusterValidationError::DuplicateNodeName(node.name.clone()));
164            }
165            if node.max_adapters == 0 {
166                return Err(ClusterValidationError::ZeroMaxAdapters { name: node.name.clone() });
167            }
168            if node.transport == Transport::Ssh
169                && (node.host == "localhost" || node.host == "127.0.0.1")
170            {
171                return Err(ClusterValidationError::SshLocalhost { node: node.name.clone() });
172            }
173            validate_node_gpus(node)?;
174        }
175        Ok(())
176    }
177
178    /// Total number of adapters the cluster can train concurrently.
179    #[must_use]
180    pub fn total_adapter_capacity(&self) -> usize {
181        self.nodes.iter().map(|n| n.max_adapters).sum()
182    }
183
184    /// Find a node by name.
185    #[must_use]
186    pub fn find_node(&self, name: &str) -> Option<&NodeConfig> {
187        self.nodes.iter().find(|n| n.name == name)
188    }
189}
190
191fn validate_node_gpus(node: &NodeConfig) -> Result<(), ClusterValidationError> {
192    let mut gpu_uuids = HashSet::new();
193    for gpu in &node.gpus {
194        if gpu.vram_mb == 0 {
195            return Err(ClusterValidationError::ZeroVram {
196                node: node.name.clone(),
197                uuid: gpu.uuid.clone(),
198            });
199        }
200        if !gpu_uuids.insert(&gpu.uuid) {
201            return Err(ClusterValidationError::DuplicateGpuUuid {
202                node: node.name.clone(),
203                uuid: gpu.uuid.clone(),
204            });
205        }
206    }
207    Ok(())
208}
209
210impl NodeConfig {
211    /// Total VRAM across all GPUs on this node (in MB).
212    #[must_use]
213    pub fn total_vram_mb(&self) -> u64 {
214        self.gpus.iter().map(|g| g.vram_mb).sum()
215    }
216
217    /// Usable VRAM (total × reserve_factor) across all GPUs.
218    #[must_use]
219    pub fn usable_vram_mb(&self) -> u64 {
220        self.gpus
221            .iter()
222            .map(|g| (g.vram_mb as f64 * f64::from(g.memory_type.reserve_factor())) as u64)
223            .sum()
224    }
225
226    /// Whether this node is local (no transport needed).
227    #[must_use]
228    pub fn is_local(&self) -> bool {
229        self.transport == Transport::Local
230    }
231
232    /// Whether this is a CPU-only node (no GPUs).
233    #[must_use]
234    pub fn is_cpu_only(&self) -> bool {
235        self.gpus.is_empty()
236    }
237}
238
239impl GpuConfig {
240    /// Usable VRAM after applying reserve factor.
241    #[must_use]
242    pub fn usable_vram_mb(&self) -> u64 {
243        (self.vram_mb as f64 * f64::from(self.memory_type.reserve_factor())) as u64
244    }
245}
246
247impl std::fmt::Display for ClusterConfig {
248    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249        writeln!(
250            f,
251            "Cluster: {} node(s), {} adapter slots",
252            self.nodes.len(),
253            self.total_adapter_capacity()
254        )?;
255        for node in &self.nodes {
256            write!(f, "  {}: {} ({})", node.name, node.host, node.transport)?;
257            if node.gpus.is_empty() {
258                write!(f, " [CPU-only]")?;
259            } else {
260                for gpu in &node.gpus {
261                    write!(f, " [{} {} MB {:?}]", gpu.gpu_type, gpu.vram_mb, gpu.memory_type)?;
262                }
263            }
264            writeln!(f, " max_adapters={}", node.max_adapters)?;
265        }
266        Ok(())
267    }
268}
269
270impl std::fmt::Display for Transport {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        match self {
273            Self::Local => write!(f, "local"),
274            Self::Ssh => write!(f, "ssh"),
275        }
276    }
277}
278
// Unit tests for cluster.yaml parsing, validation, capacity/VRAM helpers,
// Display output and serde round-tripping.
#[cfg(test)]
mod tests {
    // unwrap() is fine here: a panic is exactly how a test should fail.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Shared fixture with three nodes:
    /// - a local desktop with one discrete RTX 4090 and 3 adapter slots;
    /// - an SSH Jetson with unified memory and 1 slot;
    /// - an SSH CPU-only box with an explicit user, core count and RAM.
    ///
    /// The raw string is YAML, so its indentation is significant.
    fn sample_yaml() -> &'static str {
        r"
nodes:
  - name: desktop
    host: localhost
    gpus:
      - uuid: GPU-abcd-1234
        type: rtx-4090
        vram_mb: 24564
        memory_type: discrete
    max_adapters: 3

  - name: jetson
    host: jetson.local
    transport: ssh
    gpus:
      - uuid: GPU-efgh-5678
        type: jetson-orin
        vram_mb: 8192
        memory_type: unified
    max_adapters: 1

  - name: intel-box
    host: 10.0.0.5
    transport: ssh
    user: noah
    gpus: []
    cpu_cores: 16
    ram_mb: 65536
    max_adapters: 1
"
    }

    /// Every field of the fixture lands in the expected struct member,
    /// including defaults (desktop's implicit `local` transport).
    #[test]
    fn test_parse_cluster_yaml() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        assert_eq!(config.nodes.len(), 3);

        let desktop = &config.nodes[0];
        assert_eq!(desktop.name, "desktop");
        assert_eq!(desktop.host, "localhost");
        assert_eq!(desktop.transport, Transport::Local);
        assert_eq!(desktop.gpus.len(), 1);
        assert_eq!(desktop.gpus[0].uuid, "GPU-abcd-1234");
        assert_eq!(desktop.gpus[0].gpu_type, "rtx-4090");
        assert_eq!(desktop.gpus[0].vram_mb, 24564);
        assert_eq!(desktop.gpus[0].memory_type, MemoryType::Discrete);
        assert_eq!(desktop.max_adapters, 3);

        let jetson = &config.nodes[1];
        assert_eq!(jetson.transport, Transport::Ssh);
        assert_eq!(jetson.gpus[0].memory_type, MemoryType::Unified);

        let intel = &config.nodes[2];
        assert!(intel.is_cpu_only());
        assert_eq!(intel.user, Some("noah".to_string()));
        assert_eq!(intel.cpu_cores, Some(16));
        assert_eq!(intel.ram_mb, Some(65536));
    }

    /// Capacity is the plain sum of per-node `max_adapters`.
    #[test]
    fn test_total_adapter_capacity() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        assert_eq!(config.total_adapter_capacity(), 5); // 3 + 1 + 1
    }

    /// Node-level VRAM totals, with the reserve factor applied and truncated to whole MB.
    #[test]
    fn test_node_vram_calculations() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        let desktop = &config.nodes[0];
        assert_eq!(desktop.total_vram_mb(), 24564);
        // 24564 * 0.85 = 20879.4 → 20879
        assert_eq!(desktop.usable_vram_mb(), 20879);

        let jetson = &config.nodes[1];
        assert_eq!(jetson.total_vram_mb(), 8192);
        // 8192 * 0.60 = 4915.2 → 4915
        assert_eq!(jetson.usable_vram_mb(), 4915);
    }

    /// Per-GPU usable VRAM for a hand-built discrete GPU.
    #[test]
    fn test_gpu_usable_vram() {
        let gpu = GpuConfig {
            uuid: "GPU-test".to_string(),
            gpu_type: "rtx-4090".to_string(),
            vram_mb: 24000,
            memory_type: MemoryType::Discrete,
        };
        assert_eq!(gpu.usable_vram_mb(), 20400); // 24000 * 0.85
    }

    /// Lookup by exact node name; unknown names yield `None`.
    #[test]
    fn test_find_node() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        assert!(config.find_node("desktop").is_some());
        assert!(config.find_node("jetson").is_some());
        assert!(config.find_node("nonexistent").is_none());
    }

    // The validation tests below each trigger exactly one ClusterValidationError
    // variant and match on a fragment of its Display message.

    #[test]
    fn test_validation_no_nodes() {
        let yaml = "nodes: []";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least one node"));
    }

    #[test]
    fn test_validation_duplicate_names() {
        let yaml = r"
nodes:
  - name: box1
    host: localhost
    max_adapters: 1
  - name: box1
    host: 10.0.0.2
    transport: ssh
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("duplicate node name"));
    }

    #[test]
    fn test_validation_zero_max_adapters() {
        let yaml = r"
nodes:
  - name: bad
    host: localhost
    max_adapters: 0
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("max_adapters"));
    }

    #[test]
    fn test_validation_zero_vram() {
        let yaml = r"
nodes:
  - name: bad
    host: localhost
    gpus:
      - uuid: GPU-bad
        type: unknown
        vram_mb: 0
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("zero VRAM"));
    }

    #[test]
    fn test_validation_duplicate_gpu_uuid() {
        let yaml = r"
nodes:
  - name: dupes
    host: localhost
    gpus:
      - uuid: GPU-same
        type: rtx-4090
        vram_mb: 24000
      - uuid: GPU-same
        type: rtx-4090
        vram_mb: 24000
    max_adapters: 2
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("duplicate GPU UUID"));
    }

    #[test]
    fn test_validation_ssh_localhost() {
        let yaml = r"
nodes:
  - name: bad-ssh
    host: localhost
    transport: ssh
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("SSH transport"));
    }

    /// The summary contains the header counts, node names, GPU types and the CPU-only marker.
    #[test]
    fn test_display() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        let display = format!("{config}");
        assert!(display.contains("3 node(s)"));
        assert!(display.contains("5 adapter slots"));
        assert!(display.contains("desktop"));
        assert!(display.contains("rtx-4090"));
        assert!(display.contains("CPU-only"));
    }

    /// Reserve factors are the documented constants (compared with an epsilon because they are f32).
    #[test]
    fn test_reserve_factor() {
        assert!((MemoryType::Discrete.reserve_factor() - 0.85).abs() < f32::EPSILON);
        assert!((MemoryType::Unified.reserve_factor() - 0.60).abs() < f32::EPSILON);
    }

    /// Only `name` and `host` are required; everything else takes its serde default.
    #[test]
    fn test_minimal_config() {
        let yaml = r"
nodes:
  - name: single
    host: localhost
";
        let config = ClusterConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.nodes.len(), 1);
        assert_eq!(config.nodes[0].max_adapters, 1); // default
        assert!(config.nodes[0].gpus.is_empty()); // default
    }

    /// Serializing and re-parsing keeps the node count and total capacity
    /// (the config types do not derive PartialEq, so whole-struct equality isn't checked).
    #[test]
    fn test_serialization_roundtrip() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        let yaml = serde_yaml::to_string(&config).unwrap();
        let reparsed = ClusterConfig::from_yaml(&yaml).unwrap();
        assert_eq!(reparsed.nodes.len(), config.nodes.len());
        assert_eq!(reparsed.total_adapter_capacity(), config.total_adapter_capacity());
    }
}
511
/// GPU dispatch cost model (PW-01: 5× PCIe rule).
///
/// Decides whether GPU dispatch pays off by comparing estimated GPU compute
/// time with estimated PCIe transfer time. Work is dispatched only when
/// compute time exceeds `dispatch_threshold` × transfer time (5× by default).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GpuCostModel {
    /// PCIe transfer cost per MB (microseconds).
    pub pcie_cost_per_mb: f64,
    /// GPU compute cost per MFLOP (microseconds).
    pub gpu_compute_per_mflop: f64,
    /// Multiple of the transfer cost that compute cost must exceed (default: 5.0).
    pub dispatch_threshold: f64,
}
524
525impl Default for GpuCostModel {
526    fn default() -> Self {
527        Self {
528            pcie_cost_per_mb: 40.0,      // PCIe 4.0 ~25 GB/s → ~40 µs/MB
529            gpu_compute_per_mflop: 0.01, // RTX 4090 ~80 TFLOPS → ~0.01 µs/MFLOP
530            dispatch_threshold: 5.0,     // 5× PCIe rule
531        }
532    }
533}
534
535impl GpuCostModel {
536    /// Check if GPU dispatch is beneficial for the given workload.
537    ///
538    /// Returns true when compute time > dispatch_threshold × transfer time (crossover).
539    pub fn should_dispatch_gpu(&self, data_mb: f64, compute_mflops: f64) -> bool {
540        let transfer_cost = data_mb * self.pcie_cost_per_mb;
541        let compute_cost = compute_mflops * self.gpu_compute_per_mflop;
542        compute_cost > self.dispatch_threshold * transfer_cost
543    }
544}
545
#[cfg(test)]
mod cost_model_tests {
    use super::*;

    /// cost_test (PW-13 prediction_accuracy), small workload.
    ///
    /// With default costs, moving 1 MB sets a 5 × 40 µs = 200 µs bar. The
    /// 100 MFLOP of work needs only about 1 µs of GPU time, so the model
    /// predicts CPU.
    #[test]
    fn cost_test_small_workload_stays_cpu() {
        let (data_mb, compute_mflops) = (1.0, 100.0);
        let goes_to_gpu = GpuCostModel::default().should_dispatch_gpu(data_mb, compute_mflops);
        assert!(!goes_to_gpu);
    }

    /// cost_test (PW-13 prediction_accuracy), large workload.
    ///
    /// 1_000_000 MFLOP over 1 MB costs about 10_000 µs of GPU time. That
    /// clears the 200 µs bar, so the model predicts GPU.
    #[test]
    fn cost_test_large_workload_goes_gpu() {
        let (data_mb, compute_mflops) = (1.0, 1_000_000.0);
        let goes_to_gpu = GpuCostModel::default().should_dispatch_gpu(data_mb, compute_mflops);
        assert!(goes_to_gpu);
    }
}