Skip to main content

zlayer_agent/
cdi.rs

1//! Container Device Interface (CDI) support.
2//!
3//! CDI is a vendor-neutral mechanism for declaring and injecting devices
4//! into OCI containers. `ZLayer` discovers CDI specs from standard locations
5//! (`/etc/cdi/`, `/var/run/cdi/`) and applies them to container specs
6//! as an alternative to manual device passthrough.
7//!
8//! See: <https://github.com/cncf-tags/container-device-interface>
9
10use std::collections::HashMap;
11use std::path::Path;
12
13use serde::{Deserialize, Serialize};
14use tracing::{debug, info, warn};
15
/// Directories searched for CDI spec files, visited in the order listed.
const CDI_SPEC_DIRS: &[&str] = &[
    // Scanned first.
    "/etc/cdi",
    // Scanned second, so specs found here are loaded after those in /etc/cdi.
    "/var/run/cdi",
];
18
19/// A parsed CDI specification file
20#[derive(Debug, Clone, Serialize, Deserialize)]
21#[serde(rename_all = "camelCase")]
22pub struct CdiSpec {
23    /// CDI spec version (e.g. "0.6.0")
24    pub cdi_version: String,
25    /// Device vendor and class (e.g. "nvidia.com/gpu")
26    pub kind: String,
27    /// Devices declared by this spec
28    #[serde(default)]
29    pub devices: Vec<CdiDevice>,
30    /// Container edits applied to all devices of this kind
31    #[serde(default)]
32    pub container_edits: Option<CdiContainerEdits>,
33}
34
35/// A device within a CDI spec
36#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(rename_all = "camelCase")]
38pub struct CdiDevice {
39    /// Device name (e.g. "0" for GPU 0)
40    pub name: String,
41    /// Container edits specific to this device
42    #[serde(default)]
43    pub container_edits: Option<CdiContainerEdits>,
44}
45
46/// Modifications to apply to the OCI container spec
47#[derive(Debug, Clone, Serialize, Deserialize, Default)]
48#[serde(rename_all = "camelCase")]
49pub struct CdiContainerEdits {
50    /// Environment variables to add
51    #[serde(default)]
52    pub env: Vec<String>,
53    /// Device nodes to create in the container
54    #[serde(default)]
55    pub device_nodes: Vec<CdiDeviceNode>,
56    /// Mounts to add
57    #[serde(default)]
58    pub mounts: Vec<CdiMount>,
59    /// Hooks to run
60    #[serde(default)]
61    pub hooks: Option<CdiHooks>,
62}
63
64/// A device node to create in the container
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct CdiDeviceNode {
68    /// Path inside the container
69    pub path: String,
70    /// Host path (defaults to container path)
71    pub host_path: Option<String>,
72    /// Device type: "b" (block) or "c" (char)
73    #[serde(rename = "type")]
74    pub device_type: Option<String>,
75    /// Major device number
76    pub major: Option<i64>,
77    /// Minor device number
78    pub minor: Option<i64>,
79    /// File mode (e.g. 0o666)
80    #[serde(default)]
81    pub file_mode: Option<u32>,
82    /// Owner UID
83    pub uid: Option<u32>,
84    /// Owner GID
85    pub gid: Option<u32>,
86    /// Device permissions ("rwm")
87    pub permissions: Option<String>,
88}
89
90/// A mount to add to the container
91#[derive(Debug, Clone, Serialize, Deserialize)]
92#[serde(rename_all = "camelCase")]
93pub struct CdiMount {
94    /// Container path
95    pub container_path: String,
96    /// Host path
97    pub host_path: String,
98    /// Mount options
99    #[serde(default)]
100    pub options: Vec<String>,
101}
102
103/// OCI lifecycle hooks
104#[derive(Debug, Clone, Serialize, Deserialize, Default)]
105#[serde(rename_all = "camelCase")]
106pub struct CdiHooks {
107    #[serde(default)]
108    pub prestart: Vec<CdiHook>,
109    #[serde(default)]
110    pub create_runtime: Vec<CdiHook>,
111    #[serde(default)]
112    pub create_container: Vec<CdiHook>,
113    #[serde(default)]
114    pub start_container: Vec<CdiHook>,
115    #[serde(default)]
116    pub poststart: Vec<CdiHook>,
117    #[serde(default)]
118    pub poststop: Vec<CdiHook>,
119}
120
121/// A single OCI hook
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct CdiHook {
124    pub path: String,
125    #[serde(default)]
126    pub args: Vec<String>,
127    #[serde(default)]
128    pub env: Vec<String>,
129}
130
131/// Registry of all discovered CDI specs, indexed by fully-qualified device name.
132///
133/// A fully-qualified CDI device name has the format `vendor.com/class=device`,
134/// e.g. `nvidia.com/gpu=0`.
135#[derive(Debug, Default)]
136pub struct CdiRegistry {
137    /// All discovered specs, keyed by kind (e.g. "nvidia.com/gpu")
138    specs: HashMap<String, CdiSpec>,
139}
140
141impl CdiRegistry {
142    /// Discover and load CDI specs from the standard directories.
143    ///
144    /// Scans `/etc/cdi/` and `/var/run/cdi/` for `*.json` and `*.yaml` files,
145    /// parses them, and indexes them by kind.
146    pub fn discover() -> Self {
147        let mut registry = Self::default();
148
149        for dir in CDI_SPEC_DIRS {
150            let dir_path = Path::new(dir);
151            if !dir_path.is_dir() {
152                debug!(dir = %dir, "CDI spec directory does not exist, skipping");
153                continue;
154            }
155
156            let entries = match std::fs::read_dir(dir_path) {
157                Ok(e) => e,
158                Err(e) => {
159                    warn!(dir = %dir, error = %e, "Failed to read CDI spec directory");
160                    continue;
161                }
162            };
163
164            for entry in entries.flatten() {
165                let path = entry.path();
166                let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
167                if ext != "json" && ext != "yaml" && ext != "yml" {
168                    continue;
169                }
170
171                match Self::load_spec(&path) {
172                    Ok(spec) => {
173                        info!(
174                            kind = %spec.kind,
175                            devices = spec.devices.len(),
176                            path = %path.display(),
177                            "Loaded CDI spec"
178                        );
179                        registry.specs.insert(spec.kind.clone(), spec);
180                    }
181                    Err(e) => {
182                        warn!(path = %path.display(), error = %e, "Failed to parse CDI spec");
183                    }
184                }
185            }
186        }
187
188        registry
189    }
190
191    /// Load a single CDI spec file.
192    fn load_spec(path: &Path) -> Result<CdiSpec, CdiError> {
193        let content = std::fs::read_to_string(path)
194            .map_err(|e| CdiError::Io(format!("{}: {e}", path.display())))?;
195
196        let ext = path.extension().and_then(|e| e.to_str()).unwrap_or("");
197        if ext == "json" {
198            serde_json::from_str(&content)
199                .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
200        } else {
201            serde_yaml::from_str(&content)
202                .map_err(|e| CdiError::Parse(format!("{}: {e}", path.display())))
203        }
204    }
205
206    /// Look up a CDI spec by kind (e.g. "nvidia.com/gpu").
207    #[must_use]
208    pub fn get_spec(&self, kind: &str) -> Option<&CdiSpec> {
209        self.specs.get(kind)
210    }
211
212    /// Resolve a fully-qualified CDI device name to its container edits.
213    ///
214    /// Format: `vendor.com/class=device` (e.g. `nvidia.com/gpu=0`)
215    ///
216    /// Returns the merged container edits (global + device-specific) or None
217    /// if the device is not found.
218    #[must_use]
219    pub fn resolve_device(&self, qualified_name: &str) -> Option<CdiContainerEdits> {
220        let (kind, device_name) = qualified_name.split_once('=')?;
221        let spec = self.specs.get(kind)?;
222        let device = spec.devices.iter().find(|d| d.name == device_name)?;
223
224        // Merge global and device-specific edits
225        let mut merged = spec.container_edits.clone().unwrap_or_default();
226
227        if let Some(ref dev_edits) = device.container_edits {
228            merged.env.extend(dev_edits.env.clone());
229            merged.device_nodes.extend(dev_edits.device_nodes.clone());
230            merged.mounts.extend(dev_edits.mounts.clone());
231        }
232
233        Some(merged)
234    }
235
236    /// Check if any CDI specs are available.
237    #[must_use]
238    pub fn is_empty(&self) -> bool {
239        self.specs.is_empty()
240    }
241
242    /// Get all available kinds.
243    pub fn kinds(&self) -> impl Iterator<Item = &str> {
244        self.specs.keys().map(String::as_str)
245    }
246
247    /// Generate a CDI spec for NVIDIA GPUs using nvidia-ctk.
248    ///
249    /// Runs `nvidia-ctk cdi generate` and returns the resulting spec,
250    /// or None if nvidia-ctk is not available.
251    pub async fn generate_nvidia_spec() -> Option<CdiSpec> {
252        let output = tokio::process::Command::new("nvidia-ctk")
253            .args(["cdi", "generate"])
254            .output()
255            .await
256            .ok()?;
257
258        if !output.status.success() {
259            let stderr = String::from_utf8_lossy(&output.stderr);
260            warn!("nvidia-ctk cdi generate failed: {stderr}");
261            return None;
262        }
263
264        let stdout = String::from_utf8_lossy(&output.stdout);
265        match serde_yaml::from_str(&stdout) {
266            Ok(spec) => {
267                info!("Generated NVIDIA CDI spec via nvidia-ctk");
268                Some(spec)
269            }
270            Err(e) => {
271                warn!("Failed to parse nvidia-ctk output: {e}");
272                None
273            }
274        }
275    }
276}
277
278/// Errors from CDI operations
279#[derive(Debug, thiserror::Error)]
280pub enum CdiError {
281    /// I/O error reading a CDI spec file
282    #[error("CDI I/O error: {0}")]
283    Io(String),
284    /// Failed to parse a CDI spec file
285    #[error("CDI parse error: {0}")]
286    Parse(String),
287}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON fixture shaped like an NVIDIA GPU spec: devices `"0"` and `"all"`,
    /// plus spec-wide env, device-node and mount edits.
    fn sample_spec_json() -> &'static str {
        r#"{
            "cdiVersion": "0.6.0",
            "kind": "nvidia.com/gpu",
            "devices": [
                {
                    "name": "0",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=0"],
                        "deviceNodes": [
                            {
                                "path": "/dev/nvidia0",
                                "hostPath": "/dev/nvidia0",
                                "type": "c",
                                "major": 195,
                                "minor": 0
                            }
                        ]
                    }
                },
                {
                    "name": "all",
                    "containerEdits": {
                        "env": ["NVIDIA_VISIBLE_DEVICES=all"]
                    }
                }
            ],
            "containerEdits": {
                "env": ["NVIDIA_DRIVER_CAPABILITIES=all"],
                "deviceNodes": [
                    {
                        "path": "/dev/nvidiactl",
                        "hostPath": "/dev/nvidiactl",
                        "type": "c"
                    }
                ],
                "mounts": [
                    {
                        "containerPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "hostPath": "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
                        "options": ["ro", "nosuid", "nodev", "bind"]
                    }
                ]
            }
        }"#
    }

    /// Registry holding only the parsed JSON fixture.
    fn sample_registry() -> CdiRegistry {
        let spec: CdiSpec = serde_json::from_str(sample_spec_json()).unwrap();
        let mut registry = CdiRegistry::default();
        registry.specs.insert(spec.kind.clone(), spec);
        registry
    }

    #[test]
    fn parse_cdi_spec_json() {
        let spec: CdiSpec = serde_json::from_str(sample_spec_json()).unwrap();
        assert_eq!(spec.cdi_version, "0.6.0");
        assert_eq!(spec.kind, "nvidia.com/gpu");
        assert_eq!(spec.devices.len(), 2);
        assert_eq!(spec.devices[0].name, "0");

        let global = spec.container_edits.expect("spec-wide edits present");
        assert_eq!(global.env, ["NVIDIA_DRIVER_CAPABILITIES=all"]);
        assert_eq!(global.device_nodes.len(), 1);
        assert_eq!(global.mounts.len(), 1);
    }

    #[test]
    fn resolve_device_merges_edits() {
        let edits = sample_registry()
            .resolve_device("nvidia.com/gpu=0")
            .expect("should resolve gpu 0");

        // Spec-wide and device-specific environment are both present.
        for var in ["NVIDIA_DRIVER_CAPABILITIES=all", "NVIDIA_VISIBLE_DEVICES=0"] {
            assert!(edits.env.iter().any(|e| e == var), "missing env {var}");
        }

        // /dev/nvidiactl (spec-wide) plus /dev/nvidia0 (device-specific).
        assert_eq!(edits.device_nodes.len(), 2);

        // The single spec-wide mount is kept.
        assert_eq!(edits.mounts.len(), 1);
    }

    #[test]
    fn resolve_unknown_device_returns_none() {
        assert!(CdiRegistry::default()
            .resolve_device("nvidia.com/gpu=99")
            .is_none());
    }

    #[test]
    fn resolve_malformed_name_returns_none() {
        assert!(CdiRegistry::default()
            .resolve_device("no-equals-sign")
            .is_none());
    }

    #[test]
    fn empty_registry() {
        let registry = CdiRegistry::default();
        assert!(registry.is_empty());
        assert!(registry.kinds().next().is_none());
    }

    #[test]
    fn parse_cdi_spec_yaml() {
        let yaml = r#"
cdiVersion: "0.6.0"
kind: "vendor.com/net"
devices:
  - name: "eth0"
    containerEdits:
      env:
        - "NET_DEVICE=eth0"
"#;
        let spec: CdiSpec = serde_yaml::from_str(yaml).unwrap();
        assert_eq!(spec.kind, "vendor.com/net");

        let names: Vec<&str> = spec.devices.iter().map(|d| d.name.as_str()).collect();
        assert_eq!(names, ["eth0"]);
    }
}