container_device_interface/
cache.rs1use std::{
2 cell::RefCell,
3 collections::{HashMap, HashSet},
4 error::Error,
5 fmt,
6 sync::{Arc, Mutex},
7};
8
9use anyhow::Result;
10
11use oci_spec::runtime as oci;
12
13use crate::{
14 container_edits::ContainerEdits,
16 device::Device,
17 spec::Spec,
18 spec_dirs::{convert_errors, scan_spec_dirs, with_spec_dirs, SpecError, DEFAULT_SPEC_DIRS},
19};
20
/// Error describing a device name that is defined by two specs of equal
/// priority; the cache records it against both spec files.
#[derive(Debug)]
struct ConflictError {
    /// Qualified name of the device defined twice.
    name: String,
    /// Path of the spec currently being indexed.
    dev_path: String,
    /// Path of the spec that had already defined the device.
    old_path: String,
}

impl ConflictError {
    fn new(name: &str, dev_path: &str, old_path: &str) -> Self {
        let (name, dev_path, old_path) = (
            String::from(name),
            String::from(dev_path),
            String::from(old_path),
        );
        Self {
            name,
            dev_path,
            old_path,
        }
    }
}

impl fmt::Display for ConflictError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            name,
            dev_path,
            old_path,
        } = self;
        write!(f, "conflicting device {} (specs {}, {})", name, dev_path, old_path)
    }
}

impl std::error::Error for ConflictError {}

51pub type CdiOption = Box<dyn FnOnce(&mut Cache)>;
55
56pub fn with_auto_refresh(auto_refresh: bool) -> CdiOption {
62 Box::new(move |c: &mut Cache| {
63 c.auto_refresh = auto_refresh;
64 })
65}
66
67#[allow(dead_code)]
68#[derive(Default)]
69pub struct Cache {
70 pub spec_dirs: Vec<String>,
71 pub specs: HashMap<String, Vec<Spec>>,
72 pub devices: HashMap<String, Device>,
73 pub errors: HashMap<String, Vec<Box<dyn std::error::Error + Send + Sync + 'static>>>,
74 pub dir_errors: HashMap<String, Box<dyn std::error::Error + Send + Sync + 'static>>,
75
76 pub auto_refresh: bool,
77 }
79
80pub fn new_cache(options: Vec<CdiOption>) -> Arc<Mutex<Cache>> {
81 let cache = Arc::new(Mutex::new(Cache::default()));
82
83 {
84 let mut c = cache.lock().unwrap();
85
86 with_spec_dirs(&DEFAULT_SPEC_DIRS)(&mut c);
87 c.configure(options);
88 let _ = c.refresh();
89 } cache
92}
93
94impl Cache {
95 pub fn new(
96 spec_dirs: Vec<String>,
97 specs: HashMap<String, Vec<Spec>>,
98 devices: HashMap<String, Device>,
99 ) -> Self {
100 Self {
101 spec_dirs,
102 specs,
103 devices,
104 errors: HashMap::new(),
105 dir_errors: HashMap::new(),
106 auto_refresh: false,
107 }
109 }
110
111 pub fn configure(&mut self, options: Vec<CdiOption>) {
112 for option in options {
113 option(self);
114 }
115 }
116
117 pub fn get_device(&mut self, dev_name: &str) -> Option<&Device> {
118 let _ = self.refresh_if_required(false);
119
120 self.devices.get(dev_name)
121 }
122
123 pub fn list_devices(&mut self) -> Vec<String> {
124 let _ = self.refresh_if_required(false);
125
126 let mut devices: Vec<String> = self.devices.keys().cloned().collect();
127 devices.sort();
128 devices
129 }
130
131 pub fn list_vendors(&mut self) -> Vec<String> {
132 let mut vendors: Vec<String> = Vec::new();
133
134 let _ = self.refresh_if_required(false);
135
136 for vendor in self.specs.keys() {
137 vendors.push(vendor.clone());
138 }
139 vendors.sort();
140 vendors
141 }
142
143 pub fn get_vendor_specs(&mut self, vendor: &str) -> Vec<Spec> {
144 let _ = self.refresh_if_required(false);
145
146 match self.specs.get(vendor) {
147 Some(specs) => specs.clone(),
148 None => Vec::new(),
149 }
150 }
151
152 pub fn refresh(&mut self) -> Result<(), Box<dyn Error>> {
154 let mut specs: HashMap<String, Vec<Spec>> = HashMap::new();
155 let mut devices: HashMap<String, Device> = HashMap::new();
156 let mut conflicts: HashSet<String> = HashSet::new();
157 let mut spec_errors: HashMap<String, Vec<Box<dyn Error>>> = HashMap::new();
158
159 let collect_error = RefCell::new(|err: Box<dyn Error>, paths: Vec<String>| {
161 let err_string = err.to_string();
162 for path in paths {
163 spec_errors
164 .entry(path.to_string())
165 .or_default()
166 .push(Box::new(SpecError::new(&err_string.to_string())));
167 }
168 });
169
170 let resolve_conflict = RefCell::new(|name: &str, dev: &Device, old: &Device| -> bool {
171 let dev_spec = dev.get_spec();
172 let old_spec = old.get_spec();
173 let dev_prio = dev_spec.get_priority();
174 let old_prio = old_spec.get_priority();
175
176 match dev_prio.cmp(&old_prio) {
177 std::cmp::Ordering::Greater => false,
178 std::cmp::Ordering::Equal => {
179 let dev_path = dev_spec.get_path();
180 let old_path = old_spec.get_path();
181 collect_error.borrow_mut()(
182 Box::new(ConflictError::new(name, &dev_path, &old_path)),
183 vec![dev_path.clone(), old_path.clone()],
184 );
185 conflicts.insert(name.to_owned());
186 true
187 }
188 std::cmp::Ordering::Less => true,
189 }
190 });
191
192 let mut scan_spec_fn = |s: Spec| -> Result<(), Box<dyn Error>> {
193 let vendor = s.get_vendor().to_owned();
194 specs.entry(vendor.clone()).or_default().push(s.clone());
195 let spec_devices = s.get_devices();
196 for dev in spec_devices.values() {
197 let qualified = dev.get_qualified_name();
198 if let Some(other) = devices.get(&qualified) {
199 if resolve_conflict.borrow_mut()(&qualified, dev, other) {
200 continue;
201 }
202 }
203 devices.insert(qualified, dev.clone());
204 }
205
206 Ok(())
207 };
208
209 let scaned_specs: Vec<Spec> = scan_spec_dirs(&self.spec_dirs)?;
210 for spec in scaned_specs {
211 scan_spec_fn(spec)?
212 }
213
214 for conflict in conflicts.iter() {
215 self.devices.remove(conflict);
216 }
217
218 self.specs = specs;
219 self.devices = devices;
220 self.errors = convert_errors(&spec_errors);
221
222 let errs: Vec<String> = spec_errors
223 .values()
224 .flat_map(|errors| errors.iter().map(|err| err.to_string()))
225 .collect();
226
227 if !errs.is_empty() {
228 Err(errs.join(", ").into())
229 } else {
230 Ok(())
231 }
232 }
233
234 fn refresh_if_required(&mut self, force: bool) -> Result<bool, Box<dyn std::error::Error>> {
235 if force || (self.auto_refresh) {
241 self.refresh()?;
242 return Ok(true);
243 }
244
245 Ok(false)
246 }
247
248 pub fn inject_devices(
249 &mut self,
250 oci_spec: Option<&mut oci::Spec>,
251 devices: Vec<String>,
252 ) -> Result<Vec<String>, Box<dyn Error + Send + Sync + 'static>> {
253 let mut unresolved = Vec::new();
254
255 let oci_spec = match oci_spec {
256 Some(spec) => spec,
257 None => return Err("can't inject devices, OCI Spec is empty".into()),
258 };
259
260 let _ = self.refresh_if_required(false);
261
262 let edits = &mut ContainerEdits::new();
263 let mut specs: HashSet<Spec> = HashSet::new();
264
265 for device in devices {
266 if let Some(dev) = self.devices.get(&device) {
267 let mut spec = dev.get_spec();
268 if specs.insert(spec.clone()) {
269 if let Some(ce) = spec.edits() {
272 edits.append(ce)?
273 }
274 }
275 edits.append(dev.edits())?;
276 } else {
277 unresolved.push(device);
278 }
279 }
280
281 if !unresolved.is_empty() {
282 return Err(format!("unresolvable CDI devices {}", unresolved.join(", ")).into());
283 }
284
285 if let Err(err) = edits.apply(oci_spec) {
286 return Err(format!("failed to inject devices: {}", err).into());
287 }
288
289 Ok(Vec::new())
290 }
291
292 pub fn get_errors(&self) -> HashMap<String, Vec<anyhow::Error>> {
293 HashMap::new()
295 }
296}