nebulous/accelerator/
base.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fs;
4use std::path::Path;
5
6/// Represents an accelerator with its name and memory capacity in GB
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Accelerator {
9    pub name: String,
10    pub memory: u32,
11}
12
13/// Configuration for accelerators
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AcceleratorsConfig {
16    pub supported: Vec<Accelerator>,
17}
18
19/// Platform-specific configuration
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct PlatformConfig {
22    pub name: String,
23    pub accelerator_map: HashMap<String, String>,
24}
25
26/// Root configuration structure
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct Config {
29    pub accelerators: AcceleratorsConfig,
30}
31
/// A hosting platform that exposes accelerators under its own naming scheme.
///
/// Implementors supply a platform name and a lookup table; name translation is
/// provided on top of that table.
pub trait AcceleratorProvider {
    /// Identifier of the platform.
    fn name(&self) -> &str;

    /// Lookup table keyed by internal accelerator name, yielding this platform's name for it.
    fn accelerator_map(&self) -> &HashMap<String, String>;

    /// Translate an internal accelerator name into this platform's name.
    ///
    /// Returns `None` when `internal_name` has no entry in
    /// [`accelerator_map`](Self::accelerator_map).
    fn get_platform_name(&self, internal_name: &str) -> Option<&String> {
        let table = self.accelerator_map();
        table
            .get_key_value(internal_name)
            .map(|(_internal, platform)| platform)
    }
}
45
46impl Config {
47    /// Load configuration from a specified path
48    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
49        let config_content =
50            fs::read_to_string(path).map_err(|e| ConfigError::IoError(e.to_string()))?;
51
52        serde_yaml::from_str(&config_content).map_err(|e| ConfigError::ParseError(e.to_string()))
53    }
54
55    /// Get an accelerator by its name
56    pub fn get_accelerator_by_name(&self, name: &str) -> Option<&Accelerator> {
57        self.accelerators
58            .supported
59            .iter()
60            .find(|acc| acc.name == name)
61    }
62
63    /// Create a default configuration with predefined accelerators
64    pub fn default() -> Self {
65        let supported = vec![
66            Accelerator {
67                name: "A100_PCIe".to_string(),
68                memory: 80,
69            },
70            Accelerator {
71                name: "A100_SXM".to_string(),
72                memory: 80,
73            },
74            Accelerator {
75                name: "A30".to_string(),
76                memory: 24,
77            },
78            Accelerator {
79                name: "A40".to_string(),
80                memory: 48,
81            },
82            Accelerator {
83                name: "H100_NVL".to_string(),
84                memory: 94,
85            },
86            Accelerator {
87                name: "H100_PCIe".to_string(),
88                memory: 80,
89            },
90            Accelerator {
91                name: "H100_SXM".to_string(),
92                memory: 80,
93            },
94            Accelerator {
95                name: "H200_SXM".to_string(),
96                memory: 143,
97            },
98            Accelerator {
99                name: "L4".to_string(),
100                memory: 24,
101            },
102            Accelerator {
103                name: "L40".to_string(),
104                memory: 48,
105            },
106            Accelerator {
107                name: "L40S".to_string(),
108                memory: 48,
109            },
110            Accelerator {
111                name: "MI300X".to_string(),
112                memory: 192,
113            },
114            Accelerator {
115                name: "RTX_2000_Ada".to_string(),
116                memory: 16,
117            },
118            Accelerator {
119                name: "RTX_3070".to_string(),
120                memory: 8,
121            },
122            Accelerator {
123                name: "RTX_3080".to_string(),
124                memory: 10,
125            },
126            Accelerator {
127                name: "RTX_3080_Ti".to_string(),
128                memory: 12,
129            },
130            Accelerator {
131                name: "RTX_3090".to_string(),
132                memory: 24,
133            },
134            Accelerator {
135                name: "RTX_3090_Ti".to_string(),
136                memory: 24,
137            },
138            Accelerator {
139                name: "RTX_4000_Ada".to_string(),
140                memory: 20,
141            },
142            Accelerator {
143                name: "RTX_4070_Ti".to_string(),
144                memory: 12,
145            },
146            Accelerator {
147                name: "RTX_4080".to_string(),
148                memory: 16,
149            },
150            Accelerator {
151                name: "RTX_4080_SUPER".to_string(),
152                memory: 16,
153            },
154            Accelerator {
155                name: "RTX_4090".to_string(),
156                memory: 24,
157            },
158            Accelerator {
159                name: "RTX_5000_Ada".to_string(),
160                memory: 32,
161            },
162            Accelerator {
163                name: "RTX_6000_Ada".to_string(),
164                memory: 48,
165            },
166            Accelerator {
167                name: "RTX_A2000".to_string(),
168                memory: 6,
169            },
170            Accelerator {
171                name: "RTX_A4000".to_string(),
172                memory: 16,
173            },
174            Accelerator {
175                name: "RTX_A4500".to_string(),
176                memory: 20,
177            },
178            Accelerator {
179                name: "RTX_A5000".to_string(),
180                memory: 24,
181            },
182            Accelerator {
183                name: "RTX_A6000".to_string(),
184                memory: 48,
185            },
186            Accelerator {
187                name: "V100".to_string(),
188                memory: 16,
189            },
190            Accelerator {
191                name: "V100_FHHL".to_string(),
192                memory: 16,
193            },
194            Accelerator {
195                name: "V100_SXM2".to_string(),
196                memory: 16,
197            },
198            Accelerator {
199                name: "V100_SXM2_32GB".to_string(),
200                memory: 32,
201            },
202        ];
203        Config {
204            accelerators: AcceleratorsConfig { supported },
205        }
206    }
207}
208
209/// Error types for configuration operations
210#[derive(Debug, thiserror::Error)]
211pub enum ConfigError {
212    #[error("IO error: {0}")]
213    IoError(String),
214
215    #[error("Parse error: {0}")]
216    ParseError(String),
217}