1use std::collections::HashMap;
11use std::path::Path;
12
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
/// Directories searched for CDI spec files, visited in the order listed.
const CDI_SPEC_DIRS: &[&str] = &[
    "/etc/cdi",
    "/var/run/cdi",
];
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(rename_all = "camelCase")]
22pub struct CdiSpec {
23 pub cdi_version: String,
25 pub kind: String,
27 #[serde(default)]
29 pub devices: Vec<CdiDevice>,
30 #[serde(default)]
32 pub container_edits: Option<CdiContainerEdits>,
33}
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct CdiDevice {
39 pub name: String,
41 #[serde(default)]
43 pub container_edits: Option<CdiContainerEdits>,
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize, Default)]
48#[serde(rename_all = "camelCase")]
49pub struct CdiContainerEdits {
50 #[serde(default)]
52 pub env: Vec<String>,
53 #[serde(default)]
55 pub device_nodes: Vec<CdiDeviceNode>,
56 #[serde(default)]
58 pub mounts: Vec<CdiMount>,
59 #[serde(default)]
61 pub hooks: Option<CdiHooks>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct CdiDeviceNode {
68 pub path: String,
70 pub host_path: Option<String>,
72 #[serde(rename = "type")]
74 pub device_type: Option<String>,
75 pub major: Option<i64>,
77 pub minor: Option<i64>,
79 #[serde(default)]
81 pub file_mode: Option<u32>,
82 pub uid: Option<u32>,
84 pub gid: Option<u32>,
86 pub permissions: Option<String>,
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize)]
92#[serde(rename_all = "camelCase")]
93pub struct CdiMount {
94 pub container_path: String,
96 pub host_path: String,
98 #[serde(default)]
100 pub options: Vec<String>,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize, Default)]
105#[serde(rename_all = "camelCase")]
106pub struct CdiHooks {
107 #[serde(default)]
108 pub prestart: Vec<CdiHook>,
109 #[serde(default)]
110 pub create_runtime: Vec<CdiHook>,
111 #[serde(default)]
112 pub create_container: Vec<CdiHook>,
113 #[serde(default)]
114 pub start_container: Vec<CdiHook>,
115 #[serde(default)]
116 pub poststart: Vec<CdiHook>,
117 #[serde(default)]
118 pub poststop: Vec<CdiHook>,
119}
120
121#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct CdiHook {
124 pub path: String,
125 #[serde(default)]
126 pub args: Vec<String>,
127 #[serde(default)]
128 pub env: Vec<String>,
129}
130
131#[derive(Debug, Default)]
136pub struct CdiRegistry {
137 specs: HashMap<String, CdiSpec>,
139}
140
141impl CdiRegistry {
142 pub fn discover() -> Self {
147 let mut registry = Self::default();
148
149 for dir in CDI_SPEC_DIRS {
150 let dir_path = Path::new(dir);
151 if !dir_path.is_dir() {
152 debug!(dir = %dir, "CDI spec directory does not exist, skipping");
153 continue;
154 }
155
156 let entries = match std::fs::read_dir(dir_path) {
157 Ok(e) => e,
158 Err(e) => {
159 warn!(dir = %dir, error = %e, "Failed to read CDI spec directory");
160 continue;
161 }
162 };
163
164 for entry in entries.flatten() {
165 let path = entry.path();
166 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
167 if ext != "json" && ext != "yaml" && ext != "yml" {
168 continue;
169 }
170
171 match Self::load_spec(&path) {
172 Ok(spec) => {
173 info!(
174 kind = %spec.kind,
175 devices = spec.devices.len(),
176 path = %path.display(),
177 "Loaded CDI spec"
178 );
179 registry.specs.insert(spec.kind.clone(), spec);
180 }
181 Err(e) => {
182 warn!(path = %path.display(), error = %e, "Failed to parse CDI spec");
183 }
184 }
185 }
186 }
187
188 registry
189 }
190
191 fn load_spec(path: &Path) -> Result<CdiSpec, CdiError> {
193 let content = std::fs::read_to_string(path)
194 .map_err(|e| CdiError::Io(format!("{}: {e}", path.display())))?;
195
196 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
197 if ext == "json" {
198 serde_json::from_str(&content)
199 .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
200 } else {
201 serde_yaml::from_str(&content)
202 .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
203 }
204 }
205
206 #[must_use]
208 pub fn get_spec(&self, kind: &str) -> Option<&CdiSpec> {
209 self.specs.get(kind)
210 }
211
212 #[must_use]
219 pub fn resolve_device(&self, qualified_name: &str) -> Option<CdiContainerEdits> {
220 let (kind, device_name) = qualified_name.split_once('=')?;
221 let spec = self.specs.get(kind)?;
222 let device = spec.devices.iter().find(|d| d.name == device_name)?;
223
224 let mut merged = spec.container_edits.clone().unwrap_or_default();
226
227 if let Some(ref dev_edits) = device.container_edits {
228 merged.env.extend(dev_edits.env.clone());
229 merged.device_nodes.extend(dev_edits.device_nodes.clone());
230 merged.mounts.extend(dev_edits.mounts.clone());
231 }
232
233 Some(merged)
234 }
235
236 #[must_use]
238 pub fn is_empty(&self) -> bool {
239 self.specs.is_empty()
240 }
241
242 pub fn kinds(&self) -> impl Iterator<Item = &str> {
244 self.specs.keys().map(String::as_str)
245 }
246
247 pub async fn generate_nvidia_spec() -> Option<CdiSpec> {
252 let output = tokio::process::Command::new("nvidia-ctk")
253 .args(["cdi", "generate"])
254 .output()
255 .await
256 .ok()?;
257
258 if !output.status.success() {
259 let stderr = String::from_utf8_lossy(&output.stderr);
260 warn!("nvidia-ctk cdi generate failed: {stderr}");
261 return None;
262 }
263
264 let stdout = String::from_utf8_lossy(&output.stdout);
265 match serde_yaml::from_str(&stdout) {
266 Ok(spec) => {
267 info!("Generated NVIDIA CDI spec via nvidia-ctk");
268 Some(spec)
269 }
270 Err(e) => {
271 warn!("Failed to parse nvidia-ctk output: {e}");
272 None
273 }
274 }
275 }
276}
277
278#[derive(Debug, thiserror::Error)]
280pub enum CdiError {
281 #[error("CDI I/O error: {0}")]
283 Io(String),
284 #[error("CDI parse error: {0}")]
286 Parse(String),
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON fixture: an `nvidia.com/gpu` spec with two devices (`0`, `all`)
    /// and spec-level edits (one env entry, one device node, one mount).
    fn sample_spec_json() -> &'static str {
        r#"{
            "cdiVersion": "0.6.0",
            "kind": "nvidia.com/gpu",
            "devices": [
                {
                    "name": "0",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=0"],
                        "deviceNodes": [
                            {
                                "path": "/dev/nvidia0",
                                "hostPath": "/dev/nvidia0",
                                "type": "c",
                                "major": 195,
                                "minor": 0
                            }
                        ]
                    }
                },
                {
                    "name": "all",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=all"]
                    }
                }
            ],
            "containerEdits": {
                "env": ["NVIDIA_DRIVER_CAPABILITIES=all"],
                "deviceNodes": [
                    {
                        "path": "/dev/nvidiactl",
                        "hostPath": "/dev/nvidiactl",
                        "type": "c"
                    }
                ],
                "mounts": [
                    {
                        "containerPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "hostPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "options": ["ro", "nosuid", "nodev", "bind"]
                    }
                ]
            }
        }"#
    }

    /// Parses the JSON fixture into a spec.
    fn sample_spec() -> CdiSpec {
        serde_json::from_str(sample_spec_json()).unwrap()
    }

    #[test]
    fn parse_cdi_spec_json() {
        let parsed = sample_spec();

        assert_eq!(parsed.cdi_version, "0.6.0");
        assert_eq!(parsed.kind, "nvidia.com/gpu");
        assert_eq!(parsed.devices.len(), 2);
        assert_eq!(parsed.devices[0].name, "0");

        let spec_level = parsed.container_edits.as_ref().unwrap();
        assert_eq!(spec_level.env, ["NVIDIA_DRIVER_CAPABILITIES=all"]);
        assert_eq!(spec_level.device_nodes.len(), 1);
        assert_eq!(spec_level.mounts.len(), 1);
    }

    #[test]
    fn resolve_device_merges_edits() {
        let parsed = sample_spec();
        let mut registry = CdiRegistry::default();
        registry.specs.insert(parsed.kind.clone(), parsed);

        let resolved = registry
            .resolve_device("nvidia.com/gpu=0")
            .expect("should resolve gpu 0");

        // Env must contain both the spec-level and the device-level entry.
        let has_env = |wanted: &str| resolved.env.iter().any(|entry| entry == wanted);
        assert!(has_env("NVIDIA_DRIVER_CAPABILITIES=all"));
        assert!(has_env("NVIDIA_VISIBLE_DEVICES=0"));

        // One spec-level node plus one device-level node.
        assert_eq!(resolved.device_nodes.len(), 2);

        // Only the spec declares a mount.
        assert_eq!(resolved.mounts.len(), 1);
    }

    #[test]
    fn resolve_unknown_device_returns_none() {
        let empty = CdiRegistry::default();
        assert!(empty.resolve_device("nvidia.com/gpu=99").is_none());
    }

    #[test]
    fn resolve_malformed_name_returns_none() {
        let empty = CdiRegistry::default();
        assert!(empty.resolve_device("no-equals-sign").is_none());
    }

    #[test]
    fn empty_registry() {
        let empty = CdiRegistry::default();
        assert!(empty.is_empty());
        assert_eq!(empty.kinds().count(), 0);
    }

    #[test]
    fn parse_cdi_spec_yaml() {
        let yaml = r#"
cdiVersion: "0.6.0"
kind: "vendor.com/net"
devices:
  - name: "eth0"
    containerEdits:
      env:
        - "NET_DEVICE=eth0"
"#;
        let parsed: CdiSpec = serde_yaml::from_str(yaml).unwrap();

        assert_eq!(parsed.kind, "vendor.com/net");
        assert_eq!(parsed.devices.len(), 1);
        assert_eq!(parsed.devices[0].name, "eth0");
    }
}