1use std::collections::HashMap;
11use std::path::{Path, PathBuf};
12
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
/// Default directories scanned for CDI spec files, in scan order.
///
/// Within one scan, a spec loaded later replaces an earlier spec of the same kind,
/// so specs under `/var/run/cdi` override those under `/etc/cdi`.
const CDI_SPEC_DIRS: &[&str] = &[
    "/etc/cdi",
    "/var/run/cdi",
];

/// Name of the environment variable holding extra spec directories (a platform
/// path list) that are scanned after [`CDI_SPEC_DIRS`].
pub const CDI_SPEC_DIRS_ENV: &str = "CDI_SPEC_DIRS";
25
/// Maps a lowercase GPU vendor identifier to its CDI device kind (`vendor.com/class`).
///
/// Matching is exact and case-sensitive; unknown vendors yield `None`.
#[must_use]
pub fn vendor_to_cdi_kind(vendor: &str) -> Option<&'static str> {
    const VENDOR_KINDS: [(&str, &str); 3] = [
        ("nvidia", "nvidia.com/gpu"),
        ("amd", "amd.com/gpu"),
        ("intel", "intel.com/gpu"),
    ];

    VENDOR_KINDS
        .iter()
        .find(|(known_vendor, _)| *known_vendor == vendor)
        .map(|&(_, cdi_kind)| cdi_kind)
}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
43#[serde(rename_all = "camelCase")]
44pub struct CdiSpec {
45 pub cdi_version: String,
47 pub kind: String,
49 #[serde(default)]
51 pub devices: Vec<CdiDevice>,
52 #[serde(default)]
54 pub container_edits: Option<CdiContainerEdits>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
59#[serde(rename_all = "camelCase")]
60pub struct CdiDevice {
61 pub name: String,
63 #[serde(default)]
65 pub container_edits: Option<CdiContainerEdits>,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize, Default)]
70#[serde(rename_all = "camelCase")]
71pub struct CdiContainerEdits {
72 #[serde(default)]
74 pub env: Vec<String>,
75 #[serde(default)]
77 pub device_nodes: Vec<CdiDeviceNode>,
78 #[serde(default)]
80 pub mounts: Vec<CdiMount>,
81 #[serde(default)]
83 pub hooks: Option<CdiHooks>,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
88#[serde(rename_all = "camelCase")]
89pub struct CdiDeviceNode {
90 pub path: String,
92 pub host_path: Option<String>,
94 #[serde(rename = "type")]
96 pub device_type: Option<String>,
97 pub major: Option<i64>,
99 pub minor: Option<i64>,
101 #[serde(default)]
103 pub file_mode: Option<u32>,
104 pub uid: Option<u32>,
106 pub gid: Option<u32>,
108 pub permissions: Option<String>,
110}
111
112#[derive(Debug, Clone, Serialize, Deserialize)]
114#[serde(rename_all = "camelCase")]
115pub struct CdiMount {
116 pub container_path: String,
118 pub host_path: String,
120 #[serde(default)]
122 pub options: Vec<String>,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize, Default)]
127#[serde(rename_all = "camelCase")]
128pub struct CdiHooks {
129 #[serde(default)]
130 pub prestart: Vec<CdiHook>,
131 #[serde(default)]
132 pub create_runtime: Vec<CdiHook>,
133 #[serde(default)]
134 pub create_container: Vec<CdiHook>,
135 #[serde(default)]
136 pub start_container: Vec<CdiHook>,
137 #[serde(default)]
138 pub poststart: Vec<CdiHook>,
139 #[serde(default)]
140 pub poststop: Vec<CdiHook>,
141}
142
143#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct CdiHook {
146 pub path: String,
147 #[serde(default)]
148 pub args: Vec<String>,
149 #[serde(default)]
150 pub env: Vec<String>,
151}
152
153#[derive(Debug, Default)]
158pub struct CdiRegistry {
159 specs: HashMap<String, CdiSpec>,
161}
162
163impl CdiRegistry {
164 pub fn discover() -> Self {
170 let mut dirs: Vec<PathBuf> = CDI_SPEC_DIRS.iter().map(PathBuf::from).collect();
171 if let Ok(env_dirs) = std::env::var(CDI_SPEC_DIRS_ENV) {
172 for entry in std::env::split_paths(&env_dirs) {
173 if !entry.as_os_str().is_empty() {
174 dirs.push(entry);
175 }
176 }
177 }
178 Self::discover_from(&dirs)
179 }
180
181 pub fn discover_from<P: AsRef<Path>>(dirs: &[P]) -> Self {
186 let mut registry = Self::default();
187
188 for dir in dirs {
189 let dir_path = dir.as_ref();
190 if !dir_path.is_dir() {
191 debug!(dir = %dir_path.display(), "CDI spec directory does not exist, skipping");
192 continue;
193 }
194
195 let entries = match std::fs::read_dir(dir_path) {
196 Ok(e) => e,
197 Err(e) => {
198 warn!(dir = %dir_path.display(), error = %e, "Failed to read CDI spec directory");
199 continue;
200 }
201 };
202
203 for entry in entries.flatten() {
204 let path = entry.path();
205 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
206 if ext != "json" && ext != "yaml" && ext != "yml" {
207 continue;
208 }
209
210 match Self::load_spec(&path) {
211 Ok(spec) => {
212 info!(
213 kind = %spec.kind,
214 devices = spec.devices.len(),
215 path = %path.display(),
216 "Loaded CDI spec"
217 );
218 registry.specs.insert(spec.kind.clone(), spec);
219 }
220 Err(e) => {
221 warn!(path = %path.display(), error = %e, "Failed to parse CDI spec");
222 }
223 }
224 }
225 }
226
227 registry
228 }
229
230 fn load_spec(path: &Path) -> Result<CdiSpec, CdiError> {
232 let content = std::fs::read_to_string(path)
233 .map_err(|e| CdiError::Io(format!("{}: {e}", path.display())))?;
234
235 let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
236 if ext == "json" {
237 serde_json::from_str(&content)
238 .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
239 } else {
240 serde_yaml::from_str(&content)
241 .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
242 }
243 }
244
245 #[must_use]
247 pub fn get_spec(&self, kind: &str) -> Option<&CdiSpec> {
248 self.specs.get(kind)
249 }
250
251 #[must_use]
258 pub fn resolve_device(&self, qualified_name: &str) -> Option<CdiContainerEdits> {
259 let (kind, device_name) = qualified_name.split_once('=')?;
260 let spec = self.specs.get(kind)?;
261 let device = spec.devices.iter().find(|d| d.name == device_name)?;
262
263 let mut merged = spec.container_edits.clone().unwrap_or_default();
265
266 if let Some(ref dev_edits) = device.container_edits {
267 merged.env.extend(dev_edits.env.clone());
268 merged.device_nodes.extend(dev_edits.device_nodes.clone());
269 merged.mounts.extend(dev_edits.mounts.clone());
270 if let Some(ref dev_hooks) = dev_edits.hooks {
271 let merged_hooks = merged.hooks.get_or_insert_with(CdiHooks::default);
272 merged_hooks.prestart.extend(dev_hooks.prestart.clone());
273 merged_hooks
274 .create_runtime
275 .extend(dev_hooks.create_runtime.clone());
276 merged_hooks
277 .create_container
278 .extend(dev_hooks.create_container.clone());
279 merged_hooks
280 .start_container
281 .extend(dev_hooks.start_container.clone());
282 merged_hooks.poststart.extend(dev_hooks.poststart.clone());
283 merged_hooks.poststop.extend(dev_hooks.poststop.clone());
284 }
285 }
286
287 Some(merged)
288 }
289
290 pub fn resolve_for_kind(
306 &self,
307 kind: &str,
308 device_names: &[String],
309 ) -> std::result::Result<Vec<CdiContainerEdits>, CdiError> {
310 let spec = self
311 .specs
312 .get(kind)
313 .ok_or_else(|| CdiError::SpecMissing(kind.to_string()))?;
314
315 let expanded: Vec<String> = if device_names.iter().any(|n| n == "all") {
319 let names: Vec<String> = spec
320 .devices
321 .iter()
322 .filter(|d| d.name != "all")
323 .map(|d| d.name.clone())
324 .collect();
325 if names.is_empty() {
326 return Err(CdiError::NoDevices(kind.to_string()));
327 }
328 names
329 } else {
330 device_names.to_vec()
331 };
332
333 let mut out = Vec::with_capacity(expanded.len());
334 for name in &expanded {
335 let qualified = format!("{kind}={name}");
336 let edits = self
337 .resolve_device(&qualified)
338 .ok_or_else(|| CdiError::DeviceMissing {
339 kind: kind.to_string(),
340 device: name.clone(),
341 })?;
342 out.push(edits);
343 }
344 Ok(out)
345 }
346
347 #[must_use]
349 pub fn is_empty(&self) -> bool {
350 self.specs.is_empty()
351 }
352
353 pub fn kinds(&self) -> impl Iterator<Item = &str> {
355 self.specs.keys().map(String::as_str)
356 }
357
358 pub async fn generate_nvidia_spec() -> Option<CdiSpec> {
363 let output = tokio::process::Command::new("nvidia-ctk")
364 .args(["cdi", "generate"])
365 .output()
366 .await
367 .ok()?;
368
369 if !output.status.success() {
370 let stderr = String::from_utf8_lossy(&output.stderr);
371 warn!("nvidia-ctk cdi generate failed: {stderr}");
372 return None;
373 }
374
375 let stdout = String::from_utf8_lossy(&output.stdout);
376 match serde_yaml::from_str(&stdout) {
377 Ok(spec) => {
378 info!("Generated NVIDIA CDI spec via nvidia-ctk");
379 Some(spec)
380 }
381 Err(e) => {
382 warn!("Failed to parse nvidia-ctk output: {e}");
383 None
384 }
385 }
386 }
387}
388
389#[derive(Debug, thiserror::Error)]
391pub enum CdiError {
392 #[error("CDI I/O error: {0}")]
394 Io(String),
395 #[error("CDI parse error: {0}")]
397 Parse(String),
398 #[error("no CDI spec installed for kind '{0}' (run the vendor's CDI generator)")]
403 SpecMissing(String),
404 #[error("CDI device '{device}' not declared in spec for kind '{kind}'")]
406 DeviceMissing {
407 kind: String,
409 device: String,
411 },
412 #[error("CDI spec for kind '{0}' declares no devices (host has no compatible hardware)")]
415 NoDevices(String),
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON fixture for kind `nvidia.com/gpu`: devices `"0"` and `"all"`, plus
    /// spec-wide edits (one env var, one device node, one mount).
    fn sample_spec_json() -> &'static str {
        r#"{
            "cdiVersion": "0.6.0",
            "kind": "nvidia.com/gpu",
            "devices": [
                {
                    "name": "0",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=0"],
                        "deviceNodes": [
                            {
                                "path": "/dev/nvidia0",
                                "hostPath": "/dev/nvidia0",
                                "type": "c",
                                "major": 195,
                                "minor": 0
                            }
                        ]
                    }
                },
                {
                    "name": "all",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=all"]
                    }
                }
            ],
            "containerEdits": {
                "env": ["NVIDIA_DRIVER_CAPABILITIES=all"],
                "deviceNodes": [
                    {
                        "path": "/dev/nvidiactl",
                        "hostPath": "/dev/nvidiactl",
                        "type": "c"
                    }
                ],
                "mounts": [
                    {
                        "containerPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "hostPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "options": ["ro", "nosuid", "nodev", "bind"]
                    }
                ]
            }
        }"#
    }

    /// The JSON fixture parses with camelCase field names mapped onto the structs.
    #[test]
    fn parse_cdi_spec_json() {
        let spec: CdiSpec = serde_json::from_str(sample_spec_json()).unwrap();
        assert_eq!(spec.cdi_version, "0.6.0");
        assert_eq!(spec.kind, "nvidia.com/gpu");
        assert_eq!(spec.devices.len(), 2);
        assert_eq!(spec.devices[0].name, "0");

        let global_edits = spec.container_edits.as_ref().unwrap();
        assert_eq!(global_edits.env, vec!["NVIDIA_DRIVER_CAPABILITIES=all"]);
        assert_eq!(global_edits.device_nodes.len(), 1);
        assert_eq!(global_edits.mounts.len(), 1);
    }

    /// Resolving a device combines the spec-wide edits with the device's own edits.
    #[test]
    fn resolve_device_merges_edits() {
        let spec: CdiSpec = serde_json::from_str(sample_spec_json()).unwrap();
        let mut registry = CdiRegistry::default();
        registry.specs.insert(spec.kind.clone(), spec);

        let edits = registry
            .resolve_device("nvidia.com/gpu=0")
            .expect("should resolve gpu 0");

        // Spec-wide env entry...
        assert!(edits
            .env
            .contains(&"NVIDIA_DRIVER_CAPABILITIES=all".to_string()));
        // ...and the device-specific one.
        assert!(edits.env.contains(&"NVIDIA_VISIBLE_DEVICES=0".to_string()));

        // /dev/nvidiactl (spec-wide) + /dev/nvidia0 (device).
        assert_eq!(edits.device_nodes.len(), 2);

        // Only the spec-wide mount; device "0" declares none.
        assert_eq!(edits.mounts.len(), 1);
    }

    /// An empty registry cannot resolve any device.
    #[test]
    fn resolve_unknown_device_returns_none() {
        let registry = CdiRegistry::default();
        assert!(registry.resolve_device("nvidia.com/gpu=99").is_none());
    }

    /// A name without `=` is not a qualified device name.
    #[test]
    fn resolve_malformed_name_returns_none() {
        let registry = CdiRegistry::default();
        assert!(registry.resolve_device("no-equals-sign").is_none());
    }

    /// A default registry reports itself empty and lists no kinds.
    #[test]
    fn empty_registry() {
        let registry = CdiRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.kinds().count(), 0);
    }

    /// A minimal YAML spec parses; omitted spec-wide edits are allowed.
    #[test]
    fn parse_cdi_spec_yaml() {
        let yaml = r#"
cdiVersion: "0.6.0"
kind: "vendor.com/net"
devices:
  - name: "eth0"
    containerEdits:
      env:
        - "NET_DEVICE=eth0"
"#;
        let spec: CdiSpec = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(spec.kind, "vendor.com/net");
        assert_eq!(spec.devices.len(), 1);
        assert_eq!(spec.devices[0].name, "eth0");
    }

    /// JSON fixture with devices `"0"` (which carries a `createContainer` hook in the
    /// stage-keyed map layout) and `"1"`, and no spec-wide edits.
    fn fixture_spec_with_hooks() -> &'static str {
        r#"{
            "cdiVersion": "0.6.0",
            "kind": "nvidia.com/gpu",
            "devices": [
                {
                    "name": "0",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=0"],
                        "deviceNodes": [
                            {"path": "/dev/nvidia0", "type": "c", "major": 195, "minor": 0}
                        ],
                        "hooks": {
                            "createContainer": [{
                                "path": "/usr/bin/nvidia-container-runtime-hook",
                                "args": ["nvidia-container-runtime-hook", "prestart"]
                            }]
                        }
                    }
                },
                {
                    "name": "1",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=1"],
                        "deviceNodes": [
                            {"path": "/dev/nvidia1", "type": "c", "major": 195, "minor": 1}
                        ]
                    }
                }
            ]
        }"#
    }

    /// Writes the hooks fixture to a temp dir and discovers it. The `TempDir` is
    /// returned so the directory outlives the test body.
    fn registry_with_fixture_dir() -> (tempfile::TempDir, CdiRegistry) {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nvidia.json");
        std::fs::write(&path, fixture_spec_with_hooks()).unwrap();
        let registry = CdiRegistry::discover_from(&[dir.path()]);
        (dir, registry)
    }

    /// Discovery picks up the `.json` fixture and keys it by kind.
    #[test]
    fn discover_from_loads_specs() {
        let (_keep, registry) = registry_with_fixture_dir();
        assert_eq!(registry.kinds().count(), 1);
        assert!(registry.get_spec("nvidia.com/gpu").is_some());
    }

    /// An existing but empty directory yields an empty registry.
    #[test]
    fn discover_from_empty_dir_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let registry = CdiRegistry::discover_from(&[dir.path()]);
        assert!(registry.is_empty());
    }

    /// One requested name yields one edit set, including the device's hooks.
    #[test]
    fn resolve_for_kind_returns_edits_per_device() {
        let (_keep, registry) = registry_with_fixture_dir();
        let edits = registry
            .resolve_for_kind("nvidia.com/gpu", &["0".to_string()])
            .expect("resolve gpu 0");
        assert_eq!(edits.len(), 1);
        assert!(edits[0].env.iter().any(|e| e == "NVIDIA_VISIBLE_DEVICES=0"));
        assert!(edits[0]
            .device_nodes
            .iter()
            .any(|d| d.path == "/dev/nvidia0"));
        let hooks = edits[0].hooks.as_ref().expect("hooks merged");
        assert_eq!(hooks.create_container.len(), 1);
    }

    /// `"all"` expands to each declared device.
    #[test]
    fn resolve_for_kind_all_expands_to_every_device() {
        let (_keep, registry) = registry_with_fixture_dir();
        let edits = registry
            .resolve_for_kind("nvidia.com/gpu", &["all".to_string()])
            .expect("resolve all");
        assert_eq!(edits.len(), 2, "should expand to both '0' and '1'");
        let names: Vec<&str> = edits
            .iter()
            .flat_map(|e| e.env.iter())
            .filter(|s| s.starts_with("NVIDIA_VISIBLE_DEVICES="))
            .map(String::as_str)
            .collect();
        assert!(names.contains(&"NVIDIA_VISIBLE_DEVICES=0"));
        assert!(names.contains(&"NVIDIA_VISIBLE_DEVICES=1"));
    }

    /// An unregistered kind reports `SpecMissing` carrying that kind.
    #[test]
    fn resolve_for_kind_missing_spec_errors() {
        let registry = CdiRegistry::default();
        let err = registry
            .resolve_for_kind("nvidia.com/gpu", &["0".to_string()])
            .unwrap_err();
        assert!(matches!(err, CdiError::SpecMissing(ref k) if k == "nvidia.com/gpu"));
    }

    /// An undeclared device reports `DeviceMissing` carrying that device name.
    #[test]
    fn resolve_for_kind_unknown_device_errors() {
        let (_keep, registry) = registry_with_fixture_dir();
        let err = registry
            .resolve_for_kind("nvidia.com/gpu", &["99".to_string()])
            .unwrap_err();
        assert!(matches!(
            err,
            CdiError::DeviceMissing { ref device, .. } if device == "99"
        ));
    }

    /// Known vendors map to their CDI kinds; others map to `None`.
    #[test]
    fn vendor_to_cdi_kind_maps_known_vendors() {
        assert_eq!(vendor_to_cdi_kind("nvidia"), Some("nvidia.com/gpu"));
        assert_eq!(vendor_to_cdi_kind("amd"), Some("amd.com/gpu"));
        assert_eq!(vendor_to_cdi_kind("intel"), Some("intel.com/gpu"));
        assert_eq!(vendor_to_cdi_kind("apple"), None);
    }
}