// oximedia_ml/cache.rs

//! Bounded LRU cache for loaded ONNX models.
//!
//! [`ModelCache`] wraps a capacity-bounded `HashMap` keyed on the
//! canonicalised model path. Each entry is an `Arc<OnnxModel>` so that
//! callers get cheap cloning and the underlying session is shared across
//! concurrent pipelines.
//!
//! Entries are keyed on the path alone: the `DeviceType` passed to
//! `ModelCache::get_or_load` is only used when a model is loaded on a miss,
//! so a hit returns the cached model regardless of the device requested.
//!
//! The LRU policy is tracked with a simple `Vec<PathBuf>` that records
//! insertion / access order. When capacity is exceeded, the front of the
//! vector (the least-recently-used entry) is evicted.
//!
//! ## Example
//!
//! ```no_run
//! # #[cfg(feature = "onnx")]
//! # fn demo() -> oximedia_ml::MlResult<()> {
//! use oximedia_ml::{DeviceType, ModelCache};
//!
//! let cache = ModelCache::with_capacity(4)?;
//!
//! // First call loads from disk; second call reuses the cached Arc.
//! let a = cache.get_or_load("scene.onnx", DeviceType::auto())?;
//! let b = cache.get_or_load("scene.onnx", DeviceType::auto())?;
//! assert!(std::sync::Arc::ptr_eq(&a, &b));
//! # Ok(())
//! # }
//! ```
//!
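//! The cache is thread-safe (internal `Mutex`), so a single cache can be
//! wrapped in `Arc<ModelCache>` and shared across pipelines / async
//! tasks without additional synchronisation. Note that a cache miss keeps
//! the lock held while the model is loaded, so concurrent callers (async
//! tasks included) block until that load finishes.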
32
33use std::collections::HashMap;
34use std::path::{Path, PathBuf};
35use std::sync::{Arc, Mutex};
36
37use crate::device::DeviceType;
38use crate::error::{MlError, MlResult};
39use crate::model::{canonical_path, OnnxModel};
40
/// Number of model slots a cache gets when built via [`ModelCache::new`]
/// or `ModelCache::default`.
pub const DEFAULT_CAPACITY: usize = 8;
43
44/// LRU cache of loaded [`OnnxModel`] handles.
45///
46/// See the [module-level docs][self] for an end-to-end example. Keys are
47/// canonicalised via [`canonical_path`] so equivalent relative/absolute
48/// paths resolve to the same slot.
49pub struct ModelCache {
50    inner: Mutex<Inner>,
51    capacity: usize,
52}
53
54struct Inner {
55    map: HashMap<PathBuf, Arc<OnnxModel>>,
56    order: Vec<PathBuf>,
57}
58
59impl ModelCache {
60    /// Create a new cache with [`DEFAULT_CAPACITY`] slots.
61    #[must_use]
62    pub fn new() -> Self {
63        // SAFETY: DEFAULT_CAPACITY is a non-zero const, so with_capacity cannot fail.
64        match Self::with_capacity(DEFAULT_CAPACITY) {
65            Ok(c) => c,
66            Err(_) => unreachable!("DEFAULT_CAPACITY is non-zero"),
67        }
68    }
69
70    /// Create a new cache with the given positive capacity.
71    ///
72    /// # Errors
73    ///
74    /// Returns [`MlError::CacheCapacityZero`] if `capacity == 0`.
75    pub fn with_capacity(capacity: usize) -> MlResult<Self> {
76        if capacity == 0 {
77            return Err(MlError::CacheCapacityZero);
78        }
79        Ok(Self {
80            inner: Mutex::new(Inner {
81                map: HashMap::with_capacity(capacity),
82                order: Vec::with_capacity(capacity),
83            }),
84            capacity,
85        })
86    }
87
88    /// Capacity reported by this cache.
89    #[must_use]
90    pub fn capacity(&self) -> usize {
91        self.capacity
92    }
93
94    /// Current number of cached models.
95    pub fn len(&self) -> MlResult<usize> {
96        let g = self
97            .inner
98            .lock()
99            .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
100        Ok(g.map.len())
101    }
102
103    /// Return whether the cache is empty.
104    pub fn is_empty(&self) -> MlResult<bool> {
105        self.len().map(|n| n == 0)
106    }
107
108    /// Load `path` or return the cached handle.
109    ///
110    /// Under the hood the entire operation is serialised by the internal
111    /// mutex, which avoids the classic double-checked-locking race where
112    /// two threads both miss and both load the same model.
113    ///
114    /// # Errors
115    ///
116    /// Returns [`MlError::Pipeline`] with stage `"cache"` if the internal
117    /// mutex is poisoned, or any error produced by
118    /// [`OnnxModel::load`][crate::OnnxModel::load] on a cache miss.
119    pub fn get_or_load(
120        &self,
121        path: impl AsRef<Path>,
122        device: DeviceType,
123    ) -> MlResult<Arc<OnnxModel>> {
124        let key = canonical_path(path.as_ref());
125        let mut guard = self
126            .inner
127            .lock()
128            .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
129
130        if let Some(existing) = guard.map.get(&key) {
131            let arc = Arc::clone(existing);
132            Self::touch(&mut guard.order, &key);
133            return Ok(arc);
134        }
135
136        let model = Arc::new(OnnxModel::load(&key, device)?);
137        guard.map.insert(key.clone(), Arc::clone(&model));
138        guard.order.push(key);
139        Self::evict_if_needed(&mut guard, self.capacity);
140        Ok(model)
141    }
142
143    /// Remove the cached entry for `path` if present and return it.
144    pub fn remove(&self, path: impl AsRef<Path>) -> MlResult<Option<Arc<OnnxModel>>> {
145        let key = canonical_path(path.as_ref());
146        let mut guard = self
147            .inner
148            .lock()
149            .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
150        let removed = guard.map.remove(&key);
151        guard.order.retain(|p| p != &key);
152        Ok(removed)
153    }
154
155    /// Drop all cached entries.
156    pub fn clear(&self) -> MlResult<()> {
157        let mut guard = self
158            .inner
159            .lock()
160            .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
161        guard.map.clear();
162        guard.order.clear();
163        Ok(())
164    }
165
166    fn touch(order: &mut Vec<PathBuf>, key: &Path) {
167        if let Some(pos) = order.iter().position(|p| p == key) {
168            let entry = order.remove(pos);
169            order.push(entry);
170        }
171    }
172
173    fn evict_if_needed(inner: &mut Inner, capacity: usize) {
174        while inner.map.len() > capacity {
175            if inner.order.is_empty() {
176                break;
177            }
178            let oldest = inner.order.remove(0);
179            inner.map.remove(&oldest);
180        }
181    }
182}
183
184impl Default for ModelCache {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189
190impl std::fmt::Debug for ModelCache {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        let len = self.len().unwrap_or(0);
193        f.debug_struct("ModelCache")
194            .field("capacity", &self.capacity)
195            .field("len", &len)
196            .finish()
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_capacity_is_rejected() {
        let outcome = ModelCache::with_capacity(0);
        let err = outcome.expect_err("expected error");
        assert!(matches!(err, MlError::CacheCapacityZero));
    }

    #[test]
    fn default_capacity_is_non_empty() {
        let fresh = ModelCache::new();
        assert_eq!(fresh.capacity(), DEFAULT_CAPACITY);
        // A poisoned lock would yield Err; treat that as a failure (false).
        let empty = fresh.is_empty().unwrap_or(false);
        assert!(empty);
    }
}