1use serde::{Deserialize, Serialize};
21use std::collections::HashSet;
22use std::path::Path;
23
24#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct ClusterConfig {
27 pub nodes: Vec<NodeConfig>,
29}
30
31#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct NodeConfig {
34 pub name: String,
36 pub host: String,
38 #[serde(default)]
40 pub transport: Transport,
41 #[serde(default)]
43 pub user: Option<String>,
44 #[serde(default)]
46 pub gpus: Vec<GpuConfig>,
47 #[serde(default = "default_max_adapters")]
49 pub max_adapters: usize,
50 #[serde(default)]
52 pub cpu_cores: Option<u32>,
53 #[serde(default)]
55 pub ram_mb: Option<u64>,
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
60#[serde(rename_all = "lowercase")]
61#[derive(Default)]
62pub enum Transport {
63 #[default]
65 Local,
66 Ssh,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct GpuConfig {
73 pub uuid: String,
75 #[serde(rename = "type")]
77 pub gpu_type: String,
78 pub vram_mb: u64,
80 #[serde(default)]
82 pub memory_type: MemoryType,
83}
84
85#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
87#[serde(rename_all = "lowercase")]
88#[derive(Default)]
89pub enum MemoryType {
90 #[default]
92 Discrete,
93 Unified,
95}
96
97impl MemoryType {
98 #[must_use]
100 pub fn reserve_factor(self) -> f32 {
101 match self {
102 Self::Discrete => 0.85,
103 Self::Unified => 0.60,
104 }
105 }
106}
107
/// Serde fallback for `NodeConfig::max_adapters`: a node with no explicit value
/// gets a single adapter slot.
fn default_max_adapters() -> usize {
    const SINGLE_SLOT: usize = 1;
    SINGLE_SLOT
}
111
112#[derive(Debug, thiserror::Error)]
114pub enum ClusterValidationError {
115 #[error("cluster must have at least one node")]
116 NoNodes,
117 #[error("duplicate node name: {0}")]
118 DuplicateNodeName(String),
119 #[error("node '{name}': max_adapters must be >= 1")]
120 ZeroMaxAdapters { name: String },
121 #[error("node '{node}': GPU '{uuid}' has zero VRAM")]
122 ZeroVram { node: String, uuid: String },
123 #[error("node '{node}': duplicate GPU UUID '{uuid}'")]
124 DuplicateGpuUuid { node: String, uuid: String },
125 #[error("node '{node}': SSH transport requires a host other than localhost")]
126 SshLocalhost { node: String },
127}
128
129impl ClusterConfig {
130 pub fn from_file(path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
135 let contents = std::fs::read_to_string(path)?;
136 let config: Self = serde_yaml::from_str(&contents)?;
137 config.validate()?;
138 Ok(config)
139 }
140
141 pub fn from_yaml(yaml: &str) -> Result<Self, Box<dyn std::error::Error>> {
146 let config: Self = serde_yaml::from_str(yaml)?;
147 config.validate()?;
148 Ok(config)
149 }
150
151 pub fn validate(&self) -> Result<(), ClusterValidationError> {
156 if self.nodes.is_empty() {
157 return Err(ClusterValidationError::NoNodes);
158 }
159
160 let mut names = HashSet::new();
161 for node in &self.nodes {
162 if !names.insert(&node.name) {
163 return Err(ClusterValidationError::DuplicateNodeName(node.name.clone()));
164 }
165 if node.max_adapters == 0 {
166 return Err(ClusterValidationError::ZeroMaxAdapters { name: node.name.clone() });
167 }
168 if node.transport == Transport::Ssh
169 && (node.host == "localhost" || node.host == "127.0.0.1")
170 {
171 return Err(ClusterValidationError::SshLocalhost { node: node.name.clone() });
172 }
173 validate_node_gpus(node)?;
174 }
175 Ok(())
176 }
177
178 #[must_use]
180 pub fn total_adapter_capacity(&self) -> usize {
181 self.nodes.iter().map(|n| n.max_adapters).sum()
182 }
183
184 #[must_use]
186 pub fn find_node(&self, name: &str) -> Option<&NodeConfig> {
187 self.nodes.iter().find(|n| n.name == name)
188 }
189}
190
191fn validate_node_gpus(node: &NodeConfig) -> Result<(), ClusterValidationError> {
192 let mut gpu_uuids = HashSet::new();
193 for gpu in &node.gpus {
194 if gpu.vram_mb == 0 {
195 return Err(ClusterValidationError::ZeroVram {
196 node: node.name.clone(),
197 uuid: gpu.uuid.clone(),
198 });
199 }
200 if !gpu_uuids.insert(&gpu.uuid) {
201 return Err(ClusterValidationError::DuplicateGpuUuid {
202 node: node.name.clone(),
203 uuid: gpu.uuid.clone(),
204 });
205 }
206 }
207 Ok(())
208}
209
210impl NodeConfig {
211 #[must_use]
213 pub fn total_vram_mb(&self) -> u64 {
214 self.gpus.iter().map(|g| g.vram_mb).sum()
215 }
216
217 #[must_use]
219 pub fn usable_vram_mb(&self) -> u64 {
220 self.gpus
221 .iter()
222 .map(|g| (g.vram_mb as f64 * f64::from(g.memory_type.reserve_factor())) as u64)
223 .sum()
224 }
225
226 #[must_use]
228 pub fn is_local(&self) -> bool {
229 self.transport == Transport::Local
230 }
231
232 #[must_use]
234 pub fn is_cpu_only(&self) -> bool {
235 self.gpus.is_empty()
236 }
237}
238
239impl GpuConfig {
240 #[must_use]
242 pub fn usable_vram_mb(&self) -> u64 {
243 (self.vram_mb as f64 * f64::from(self.memory_type.reserve_factor())) as u64
244 }
245}
246
247impl std::fmt::Display for ClusterConfig {
248 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
249 writeln!(
250 f,
251 "Cluster: {} node(s), {} adapter slots",
252 self.nodes.len(),
253 self.total_adapter_capacity()
254 )?;
255 for node in &self.nodes {
256 write!(f, " {}: {} ({})", node.name, node.host, node.transport)?;
257 if node.gpus.is_empty() {
258 write!(f, " [CPU-only]")?;
259 } else {
260 for gpu in &node.gpus {
261 write!(f, " [{} {} MB {:?}]", gpu.gpu_type, gpu.vram_mb, gpu.memory_type)?;
262 }
263 }
264 writeln!(f, " max_adapters={}", node.max_adapters)?;
265 }
266 Ok(())
267 }
268}
269
270impl std::fmt::Display for Transport {
271 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272 match self {
273 Self::Local => write!(f, "local"),
274 Self::Ssh => write!(f, "ssh"),
275 }
276 }
277}
278
#[cfg(test)]
mod tests {
    // Tests unwrap freely: a failed unwrap is simply a failed test.
    #![allow(clippy::unwrap_used)]
    use super::*;

    /// Three-node fixture: a local GPU desktop, an SSH unified-memory node and
    /// an SSH CPU-only node. Nested keys are indented so the YAML parses as
    /// a list of node mappings.
    fn sample_yaml() -> &'static str {
        r"
nodes:
  - name: desktop
    host: localhost
    gpus:
      - uuid: GPU-abcd-1234
        type: rtx-4090
        vram_mb: 24564
        memory_type: discrete
    max_adapters: 3

  - name: jetson
    host: jetson.local
    transport: ssh
    gpus:
      - uuid: GPU-efgh-5678
        type: jetson-orin
        vram_mb: 8192
        memory_type: unified
    max_adapters: 1

  - name: intel-box
    host: 10.0.0.5
    transport: ssh
    user: noah
    gpus: []
    cpu_cores: 16
    ram_mb: 65536
    max_adapters: 1
"
    }

    #[test]
    fn test_parse_cluster_yaml() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        assert_eq!(config.nodes.len(), 3);

        let desktop = &config.nodes[0];
        assert_eq!(desktop.name, "desktop");
        assert_eq!(desktop.host, "localhost");
        assert_eq!(desktop.transport, Transport::Local);
        assert_eq!(desktop.gpus.len(), 1);
        assert_eq!(desktop.gpus[0].uuid, "GPU-abcd-1234");
        assert_eq!(desktop.gpus[0].gpu_type, "rtx-4090");
        assert_eq!(desktop.gpus[0].vram_mb, 24564);
        assert_eq!(desktop.gpus[0].memory_type, MemoryType::Discrete);
        assert_eq!(desktop.max_adapters, 3);

        let jetson = &config.nodes[1];
        assert_eq!(jetson.transport, Transport::Ssh);
        assert_eq!(jetson.gpus[0].memory_type, MemoryType::Unified);

        let intel = &config.nodes[2];
        assert!(intel.is_cpu_only());
        assert_eq!(intel.user, Some("noah".to_string()));
        assert_eq!(intel.cpu_cores, Some(16));
        assert_eq!(intel.ram_mb, Some(65536));
    }

    #[test]
    fn test_total_adapter_capacity() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        // 3 (desktop) + 1 (jetson) + 1 (intel-box)
        assert_eq!(config.total_adapter_capacity(), 5);
    }

    #[test]
    fn test_node_vram_calculations() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        let desktop = &config.nodes[0];
        assert_eq!(desktop.total_vram_mb(), 24564);
        // 24564 * 0.85, truncated
        assert_eq!(desktop.usable_vram_mb(), 20879);

        let jetson = &config.nodes[1];
        assert_eq!(jetson.total_vram_mb(), 8192);
        // 8192 * 0.60, truncated
        assert_eq!(jetson.usable_vram_mb(), 4915);
    }

    #[test]
    fn test_gpu_usable_vram() {
        let gpu = GpuConfig {
            uuid: "GPU-test".to_string(),
            gpu_type: "rtx-4090".to_string(),
            vram_mb: 24000,
            memory_type: MemoryType::Discrete,
        };
        // 24000 * 0.85
        assert_eq!(gpu.usable_vram_mb(), 20400);
    }

    #[test]
    fn test_find_node() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        assert!(config.find_node("desktop").is_some());
        assert!(config.find_node("jetson").is_some());
        assert!(config.find_node("nonexistent").is_none());
    }

    #[test]
    fn test_validation_no_nodes() {
        let yaml = "nodes: []";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least one node"));
    }

    #[test]
    fn test_validation_duplicate_names() {
        let yaml = r"
nodes:
  - name: box1
    host: localhost
    max_adapters: 1
  - name: box1
    host: 10.0.0.2
    transport: ssh
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("duplicate node name"));
    }

    #[test]
    fn test_validation_zero_max_adapters() {
        let yaml = r"
nodes:
  - name: bad
    host: localhost
    max_adapters: 0
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("max_adapters"));
    }

    #[test]
    fn test_validation_zero_vram() {
        let yaml = r"
nodes:
  - name: bad
    host: localhost
    gpus:
      - uuid: GPU-bad
        type: unknown
        vram_mb: 0
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("zero VRAM"));
    }

    #[test]
    fn test_validation_duplicate_gpu_uuid() {
        let yaml = r"
nodes:
  - name: dupes
    host: localhost
    gpus:
      - uuid: GPU-same
        type: rtx-4090
        vram_mb: 24000
      - uuid: GPU-same
        type: rtx-4090
        vram_mb: 24000
    max_adapters: 2
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("duplicate GPU UUID"));
    }

    #[test]
    fn test_validation_ssh_localhost() {
        let yaml = r"
nodes:
  - name: bad-ssh
    host: localhost
    transport: ssh
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("SSH transport"));
    }

    #[test]
    fn test_validation_ssh_loopback_ip() {
        let yaml = r"
nodes:
  - name: bad-ssh
    host: 127.0.0.1
    transport: ssh
    max_adapters: 1
";
        let result = ClusterConfig::from_yaml(yaml);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("SSH transport"));
    }

    #[test]
    fn test_display() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        let display = format!("{config}");
        assert!(display.contains("3 node(s)"));
        assert!(display.contains("5 adapter slots"));
        assert!(display.contains("desktop"));
        assert!(display.contains("rtx-4090"));
        assert!(display.contains("CPU-only"));
    }

    #[test]
    fn test_transport_display() {
        assert_eq!(Transport::Local.to_string(), "local");
        assert_eq!(Transport::Ssh.to_string(), "ssh");
    }

    #[test]
    fn test_reserve_factor() {
        assert!((MemoryType::Discrete.reserve_factor() - 0.85).abs() < f32::EPSILON);
        assert!((MemoryType::Unified.reserve_factor() - 0.60).abs() < f32::EPSILON);
    }

    #[test]
    fn test_minimal_config() {
        let yaml = r"
nodes:
  - name: single
    host: localhost
";
        let config = ClusterConfig::from_yaml(yaml).unwrap();
        assert_eq!(config.nodes.len(), 1);
        // Omitted fields fall back to their serde defaults.
        assert_eq!(config.nodes[0].max_adapters, 1);
        assert!(config.nodes[0].gpus.is_empty());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let config = ClusterConfig::from_yaml(sample_yaml()).unwrap();
        let yaml = serde_yaml::to_string(&config).unwrap();
        let reparsed = ClusterConfig::from_yaml(&yaml).unwrap();
        assert_eq!(reparsed.nodes.len(), config.nodes.len());
        assert_eq!(reparsed.total_adapter_capacity(), config.total_adapter_capacity());
    }
}
511
/// Tunable costs used to decide whether a workload should be dispatched to a GPU.
///
/// All three values are plain multipliers. The cost unit is not fixed by this
/// type; only the ratio between transfer and compute cost matters.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GpuCostModel {
    /// Cost charged per megabyte of data the workload has to transfer.
    pub pcie_cost_per_mb: f64,
    /// Cost charged per MFLOP of compute the workload performs.
    pub gpu_compute_per_mflop: f64,
    /// How many times larger than the transfer cost the compute cost must be
    /// before GPU dispatch is chosen.
    pub dispatch_threshold: f64,
}
524
525impl Default for GpuCostModel {
526 fn default() -> Self {
527 Self {
528 pcie_cost_per_mb: 40.0, gpu_compute_per_mflop: 0.01, dispatch_threshold: 5.0, }
532 }
533}
534
535impl GpuCostModel {
536 pub fn should_dispatch_gpu(&self, data_mb: f64, compute_mflops: f64) -> bool {
540 let transfer_cost = data_mb * self.pcie_cost_per_mb;
541 let compute_cost = compute_mflops * self.gpu_compute_per_mflop;
542 compute_cost > self.dispatch_threshold * transfer_cost
543 }
544}
545
#[cfg(test)]
mod cost_model_tests {
    use super::*;

    /// With the default model, 1 MB of data and 100 MFLOPs is not worth the transfer.
    #[test]
    fn cost_test_small_workload_stays_cpu() {
        let (data_mb, compute_mflops) = (1.0, 100.0);
        let goes_to_gpu = GpuCostModel::default().should_dispatch_gpu(data_mb, compute_mflops);
        assert!(!goes_to_gpu);
    }

    /// With the default model, 1 MB of data and 1,000,000 MFLOPs is dispatched to the GPU.
    #[test]
    fn cost_test_large_workload_goes_gpu() {
        let (data_mb, compute_mflops) = (1.0, 1_000_000.0);
        let goes_to_gpu = GpuCostModel::default().should_dispatch_gpu(data_mb, compute_mflops);
        assert!(goes_to_gpu);
    }
}