// container_device_interface/cache.rs

1use std::{
2    cell::RefCell,
3    collections::{HashMap, HashSet},
4    error::Error,
5    fmt,
6    sync::{Arc, Mutex},
7};
8
9use anyhow::Result;
10
11use oci_spec::runtime as oci;
12
13use crate::{
14    //watch::Watch,
15    container_edits::ContainerEdits,
16    device::Device,
17    spec::Spec,
18    spec_dirs::{convert_errors, scan_spec_dirs, with_spec_dirs, SpecError, DEFAULT_SPEC_DIRS},
19};
20
/// Error raised when two specs of equal priority both define the same
/// fully-qualified device name.
#[derive(Debug)]
struct ConflictError {
    /// Qualified name of the contested device.
    name: String,
    /// Path of the spec providing the device that was being added.
    dev_path: String,
    /// Path of the spec providing the device that was already registered.
    old_path: String,
}

impl ConflictError {
    /// Builds a conflict error, copying each borrowed argument into an owned string.
    fn new(name: &str, dev_path: &str, old_path: &str) -> Self {
        let (name, dev_path, old_path) = (
            String::from(name),
            String::from(dev_path),
            String::from(old_path),
        );
        ConflictError {
            name,
            dev_path,
            old_path,
        }
    }
}

impl fmt::Display for ConflictError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            dev_path,
            old_path,
        } = self;
        write!(f, "conflicting device {} (specs {}, {})", name, dev_path, old_path)
    }
}

// No underlying source error: the message carries all the information.
impl Error for ConflictError {}
50
51// CdiOption is an option to change some aspect of default CDI behavior.
52// We define the CdiOption type using a type alias, which is a Box<dyn FnOnce(&mut Cache)>.
53// This means that CdiOption is a trait object that represents a one-time closure that takes a &mut Cache parameter.
54pub type CdiOption = Box<dyn FnOnce(&mut Cache)>;
55
56// with_auto_refresh returns an option to control automatic Cache refresh.
57// By default auto-refresh is enabled, the list of Spec directories are
58// monitored and the Cache is automatically refreshed whenever a change
59// is detected. This option can be used to disable this behavior when a
60// manually refreshed mode is preferable.
61pub fn with_auto_refresh(auto_refresh: bool) -> CdiOption {
62    Box::new(move |c: &mut Cache| {
63        c.auto_refresh = auto_refresh;
64    })
65}
66
67#[allow(dead_code)]
68#[derive(Default)]
69pub struct Cache {
70    pub spec_dirs: Vec<String>,
71    pub specs: HashMap<String, Vec<Spec>>,
72    pub devices: HashMap<String, Device>,
73    pub errors: HashMap<String, Vec<Box<dyn std::error::Error + Send + Sync + 'static>>>,
74    pub dir_errors: HashMap<String, Box<dyn std::error::Error + Send + Sync + 'static>>,
75
76    pub auto_refresh: bool,
77    //watch: Watch,
78}
79
80pub fn new_cache(options: Vec<CdiOption>) -> Arc<Mutex<Cache>> {
81    let cache = Arc::new(Mutex::new(Cache::default()));
82
83    {
84        let mut c = cache.lock().unwrap();
85
86        with_spec_dirs(&DEFAULT_SPEC_DIRS)(&mut c);
87        c.configure(options);
88        let _ = c.refresh();
89    } // MutexGuard `c` is dropped here
90
91    cache
92}
93
94impl Cache {
95    pub fn new(
96        spec_dirs: Vec<String>,
97        specs: HashMap<String, Vec<Spec>>,
98        devices: HashMap<String, Device>,
99    ) -> Self {
100        Self {
101            spec_dirs,
102            specs,
103            devices,
104            errors: HashMap::new(),
105            dir_errors: HashMap::new(),
106            auto_refresh: false,
107            //watch: Watch::new(),
108        }
109    }
110
111    pub fn configure(&mut self, options: Vec<CdiOption>) {
112        for option in options {
113            option(self);
114        }
115    }
116
117    pub fn get_device(&mut self, dev_name: &str) -> Option<&Device> {
118        let _ = self.refresh_if_required(false);
119
120        self.devices.get(dev_name)
121    }
122
123    pub fn list_devices(&mut self) -> Vec<String> {
124        let _ = self.refresh_if_required(false);
125
126        let mut devices: Vec<String> = self.devices.keys().cloned().collect();
127        devices.sort();
128        devices
129    }
130
131    pub fn list_vendors(&mut self) -> Vec<String> {
132        let mut vendors: Vec<String> = Vec::new();
133
134        let _ = self.refresh_if_required(false);
135
136        for vendor in self.specs.keys() {
137            vendors.push(vendor.clone());
138        }
139        vendors.sort();
140        vendors
141    }
142
143    pub fn get_vendor_specs(&mut self, vendor: &str) -> Vec<Spec> {
144        let _ = self.refresh_if_required(false);
145
146        match self.specs.get(vendor) {
147            Some(specs) => specs.clone(),
148            None => Vec::new(),
149        }
150    }
151
152    // refresh the Cache by rescanning CDI Spec directories and files.
153    pub fn refresh(&mut self) -> Result<(), Box<dyn Error>> {
154        let mut specs: HashMap<String, Vec<Spec>> = HashMap::new();
155        let mut devices: HashMap<String, Device> = HashMap::new();
156        let mut conflicts: HashSet<String> = HashSet::new();
157        let mut spec_errors: HashMap<String, Vec<Box<dyn Error>>> = HashMap::new();
158
159        // Wrap collect_error and resolve_conflict in RefCell
160        let collect_error = RefCell::new(|err: Box<dyn Error>, paths: Vec<String>| {
161            let err_string = err.to_string();
162            for path in paths {
163                spec_errors
164                    .entry(path.to_string())
165                    .or_default()
166                    .push(Box::new(SpecError::new(&err_string.to_string())));
167            }
168        });
169
170        let resolve_conflict = RefCell::new(|name: &str, dev: &Device, old: &Device| -> bool {
171            let dev_spec = dev.get_spec();
172            let old_spec = old.get_spec();
173            let dev_prio = dev_spec.get_priority();
174            let old_prio = old_spec.get_priority();
175
176            match dev_prio.cmp(&old_prio) {
177                std::cmp::Ordering::Greater => false,
178                std::cmp::Ordering::Equal => {
179                    let dev_path = dev_spec.get_path();
180                    let old_path = old_spec.get_path();
181                    collect_error.borrow_mut()(
182                        Box::new(ConflictError::new(name, &dev_path, &old_path)),
183                        vec![dev_path.clone(), old_path.clone()],
184                    );
185                    conflicts.insert(name.to_owned());
186                    true
187                }
188                std::cmp::Ordering::Less => true,
189            }
190        });
191
192        let mut scan_spec_fn = |s: Spec| -> Result<(), Box<dyn Error>> {
193            let vendor = s.get_vendor().to_owned();
194            specs.entry(vendor.clone()).or_default().push(s.clone());
195            let spec_devices = s.get_devices();
196            for dev in spec_devices.values() {
197                let qualified = dev.get_qualified_name();
198                if let Some(other) = devices.get(&qualified) {
199                    if resolve_conflict.borrow_mut()(&qualified, dev, other) {
200                        continue;
201                    }
202                }
203                devices.insert(qualified, dev.clone());
204            }
205
206            Ok(())
207        };
208
209        let scaned_specs: Vec<Spec> = scan_spec_dirs(&self.spec_dirs)?;
210        for spec in scaned_specs {
211            scan_spec_fn(spec)?
212        }
213
214        for conflict in conflicts.iter() {
215            self.devices.remove(conflict);
216        }
217
218        self.specs = specs;
219        self.devices = devices;
220        self.errors = convert_errors(&spec_errors);
221
222        let errs: Vec<String> = spec_errors
223            .values()
224            .flat_map(|errors| errors.iter().map(|err| err.to_string()))
225            .collect();
226
227        if !errs.is_empty() {
228            Err(errs.join(", ").into())
229        } else {
230            Ok(())
231        }
232    }
233
234    fn refresh_if_required(&mut self, force: bool) -> Result<bool, Box<dyn std::error::Error>> {
235        // We need to refresh if
236        // - it's forced by an explicit call to Refresh() in manual mode
237        // - a missing Spec dir appears (added to watch) in auto-refresh mode
238        // TODO: Here it will be recoverd if watch is completed.
239        // if force || (self.auto_refresh && self.watch.update(&mut self.dir_errors, vec![])) {
240        if force || (self.auto_refresh) {
241            self.refresh()?;
242            return Ok(true);
243        }
244
245        Ok(false)
246    }
247
248    pub fn inject_devices(
249        &mut self,
250        oci_spec: Option<&mut oci::Spec>,
251        devices: Vec<String>,
252    ) -> Result<Vec<String>, Box<dyn Error + Send + Sync + 'static>> {
253        let mut unresolved = Vec::new();
254
255        let oci_spec = match oci_spec {
256            Some(spec) => spec,
257            None => return Err("can't inject devices, OCI Spec is empty".into()),
258        };
259
260        let _ = self.refresh_if_required(false);
261
262        let edits = &mut ContainerEdits::new();
263        let mut specs: HashSet<Spec> = HashSet::new();
264
265        for device in devices {
266            if let Some(dev) = self.devices.get(&device) {
267                let mut spec = dev.get_spec();
268                if specs.insert(spec.clone()) {
269                    // spec.edits may be none when we only have dev.edits
270                    // allow dev.edits to be added even if spec.edits is None
271                    if let Some(ce) = spec.edits() {
272                        edits.append(ce)?
273                    }
274                }
275                edits.append(dev.edits())?;
276            } else {
277                unresolved.push(device);
278            }
279        }
280
281        if !unresolved.is_empty() {
282            return Err(format!("unresolvable CDI devices {}", unresolved.join(", ")).into());
283        }
284
285        if let Err(err) = edits.apply(oci_spec) {
286            return Err(format!("failed to inject devices: {}", err).into());
287        }
288
289        Ok(Vec::new())
290    }
291
292    pub fn get_errors(&self) -> HashMap<String, Vec<anyhow::Error>> {
293        // Return errors if any
294        HashMap::new()
295    }
296}