1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fs;
4use std::path::Path;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct Accelerator {
9 pub name: String,
10 pub memory: u32,
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct AcceleratorsConfig {
16 pub supported: Vec<Accelerator>,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct PlatformConfig {
22 pub name: String,
23 pub accelerator_map: HashMap<String, String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct Config {
29 pub accelerators: AcceleratorsConfig,
30}
31
/// Something that can translate internal accelerator names into the names a
/// particular platform uses for the same hardware.
pub trait AcceleratorProvider {
    /// Identifier of this provider (presumably the platform's name — confirm
    /// with implementors).
    fn name(&self) -> &str;

    /// Lookup table keyed by internal accelerator name; each value is the
    /// platform-specific name for that accelerator.
    fn accelerator_map(&self) -> &HashMap<String, String>;

    /// Translates `internal_name` into this platform's name for it.
    ///
    /// Returns `None` when the platform's map has no entry for `internal_name`.
    fn get_platform_name(&self, internal_name: &str) -> Option<&String> {
        let mapping = self.accelerator_map();
        mapping.get(internal_name)
    }
}
45
46impl Config {
47 pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
49 let config_content =
50 fs::read_to_string(path).map_err(|e| ConfigError::IoError(e.to_string()))?;
51
52 serde_yaml::from_str(&config_content).map_err(|e| ConfigError::ParseError(e.to_string()))
53 }
54
55 pub fn get_accelerator_by_name(&self, name: &str) -> Option<&Accelerator> {
57 self.accelerators
58 .supported
59 .iter()
60 .find(|acc| acc.name == name)
61 }
62
63 pub fn default() -> Self {
65 let supported = vec![
66 Accelerator {
67 name: "A100_PCIe".to_string(),
68 memory: 80,
69 },
70 Accelerator {
71 name: "A100_SXM".to_string(),
72 memory: 80,
73 },
74 Accelerator {
75 name: "A30".to_string(),
76 memory: 24,
77 },
78 Accelerator {
79 name: "A40".to_string(),
80 memory: 48,
81 },
82 Accelerator {
83 name: "H100_NVL".to_string(),
84 memory: 94,
85 },
86 Accelerator {
87 name: "H100_PCIe".to_string(),
88 memory: 80,
89 },
90 Accelerator {
91 name: "H100_SXM".to_string(),
92 memory: 80,
93 },
94 Accelerator {
95 name: "H200_SXM".to_string(),
96 memory: 143,
97 },
98 Accelerator {
99 name: "L4".to_string(),
100 memory: 24,
101 },
102 Accelerator {
103 name: "L40".to_string(),
104 memory: 48,
105 },
106 Accelerator {
107 name: "L40S".to_string(),
108 memory: 48,
109 },
110 Accelerator {
111 name: "MI300X".to_string(),
112 memory: 192,
113 },
114 Accelerator {
115 name: "RTX_2000_Ada".to_string(),
116 memory: 16,
117 },
118 Accelerator {
119 name: "RTX_3070".to_string(),
120 memory: 8,
121 },
122 Accelerator {
123 name: "RTX_3080".to_string(),
124 memory: 10,
125 },
126 Accelerator {
127 name: "RTX_3080_Ti".to_string(),
128 memory: 12,
129 },
130 Accelerator {
131 name: "RTX_3090".to_string(),
132 memory: 24,
133 },
134 Accelerator {
135 name: "RTX_3090_Ti".to_string(),
136 memory: 24,
137 },
138 Accelerator {
139 name: "RTX_4000_Ada".to_string(),
140 memory: 20,
141 },
142 Accelerator {
143 name: "RTX_4070_Ti".to_string(),
144 memory: 12,
145 },
146 Accelerator {
147 name: "RTX_4080".to_string(),
148 memory: 16,
149 },
150 Accelerator {
151 name: "RTX_4080_SUPER".to_string(),
152 memory: 16,
153 },
154 Accelerator {
155 name: "RTX_4090".to_string(),
156 memory: 24,
157 },
158 Accelerator {
159 name: "RTX_5000_Ada".to_string(),
160 memory: 32,
161 },
162 Accelerator {
163 name: "RTX_6000_Ada".to_string(),
164 memory: 48,
165 },
166 Accelerator {
167 name: "RTX_A2000".to_string(),
168 memory: 6,
169 },
170 Accelerator {
171 name: "RTX_A4000".to_string(),
172 memory: 16,
173 },
174 Accelerator {
175 name: "RTX_A4500".to_string(),
176 memory: 20,
177 },
178 Accelerator {
179 name: "RTX_A5000".to_string(),
180 memory: 24,
181 },
182 Accelerator {
183 name: "RTX_A6000".to_string(),
184 memory: 48,
185 },
186 Accelerator {
187 name: "V100".to_string(),
188 memory: 16,
189 },
190 Accelerator {
191 name: "V100_FHHL".to_string(),
192 memory: 16,
193 },
194 Accelerator {
195 name: "V100_SXM2".to_string(),
196 memory: 16,
197 },
198 Accelerator {
199 name: "V100_SXM2_32GB".to_string(),
200 memory: 32,
201 },
202 ];
203 Config {
204 accelerators: AcceleratorsConfig { supported },
205 }
206 }
207}
208
209#[derive(Debug, thiserror::Error)]
211pub enum ConfigError {
212 #[error("IO error: {0}")]
213 IoError(String),
214
215 #[error("Parse error: {0}")]
216 ParseError(String),
217}