1use std::collections::HashMap;
34use std::path::{Path, PathBuf};
35use std::sync::{Arc, Mutex};
36
37use crate::device::DeviceType;
38use crate::error::{MlError, MlResult};
39use crate::model::{canonical_path, OnnxModel};
40
/// Maximum number of models held by a cache created with [`ModelCache::new`]
/// (and therefore by `ModelCache::default()`).
pub const DEFAULT_CAPACITY: usize = 8;
43
/// Bounded cache of loaded [`OnnxModel`]s, keyed by the path returned from
/// `canonical_path`.
///
/// All state lives behind a [`Mutex`], so every method takes `&self`. When an
/// insertion pushes the number of entries above `capacity`, the entry that was
/// used least recently is dropped from the cache.
pub struct ModelCache {
    /// Lock-protected map and recency bookkeeping.
    inner: Mutex<Inner>,
    /// Maximum number of models kept at once; the constructors reject zero.
    capacity: usize,
}
53
/// Mutable cache state guarded by the mutex in [`ModelCache`].
struct Inner {
    /// Cached models by cache key.
    map: HashMap<PathBuf, Arc<OnnxModel>>,
    /// Cache keys in recency order: index 0 is the next eviction candidate,
    /// the last element is the most recently inserted or hit key.
    order: Vec<PathBuf>,
}
58
59impl ModelCache {
60 #[must_use]
62 pub fn new() -> Self {
63 match Self::with_capacity(DEFAULT_CAPACITY) {
65 Ok(c) => c,
66 Err(_) => unreachable!("DEFAULT_CAPACITY is non-zero"),
67 }
68 }
69
70 pub fn with_capacity(capacity: usize) -> MlResult<Self> {
76 if capacity == 0 {
77 return Err(MlError::CacheCapacityZero);
78 }
79 Ok(Self {
80 inner: Mutex::new(Inner {
81 map: HashMap::with_capacity(capacity),
82 order: Vec::with_capacity(capacity),
83 }),
84 capacity,
85 })
86 }
87
88 #[must_use]
90 pub fn capacity(&self) -> usize {
91 self.capacity
92 }
93
94 pub fn len(&self) -> MlResult<usize> {
96 let g = self
97 .inner
98 .lock()
99 .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
100 Ok(g.map.len())
101 }
102
103 pub fn is_empty(&self) -> MlResult<bool> {
105 self.len().map(|n| n == 0)
106 }
107
108 pub fn get_or_load(
120 &self,
121 path: impl AsRef<Path>,
122 device: DeviceType,
123 ) -> MlResult<Arc<OnnxModel>> {
124 let key = canonical_path(path.as_ref());
125 let mut guard = self
126 .inner
127 .lock()
128 .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
129
130 if let Some(existing) = guard.map.get(&key) {
131 let arc = Arc::clone(existing);
132 Self::touch(&mut guard.order, &key);
133 return Ok(arc);
134 }
135
136 let model = Arc::new(OnnxModel::load(&key, device)?);
137 guard.map.insert(key.clone(), Arc::clone(&model));
138 guard.order.push(key);
139 Self::evict_if_needed(&mut guard, self.capacity);
140 Ok(model)
141 }
142
143 pub fn remove(&self, path: impl AsRef<Path>) -> MlResult<Option<Arc<OnnxModel>>> {
145 let key = canonical_path(path.as_ref());
146 let mut guard = self
147 .inner
148 .lock()
149 .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
150 let removed = guard.map.remove(&key);
151 guard.order.retain(|p| p != &key);
152 Ok(removed)
153 }
154
155 pub fn clear(&self) -> MlResult<()> {
157 let mut guard = self
158 .inner
159 .lock()
160 .map_err(|_| MlError::pipeline("cache", "cache mutex poisoned"))?;
161 guard.map.clear();
162 guard.order.clear();
163 Ok(())
164 }
165
166 fn touch(order: &mut Vec<PathBuf>, key: &Path) {
167 if let Some(pos) = order.iter().position(|p| p == key) {
168 let entry = order.remove(pos);
169 order.push(entry);
170 }
171 }
172
173 fn evict_if_needed(inner: &mut Inner, capacity: usize) {
174 while inner.map.len() > capacity {
175 if inner.order.is_empty() {
176 break;
177 }
178 let oldest = inner.order.remove(0);
179 inner.map.remove(&oldest);
180 }
181 }
182}
183
184impl Default for ModelCache {
185 fn default() -> Self {
186 Self::new()
187 }
188}
189
190impl std::fmt::Debug for ModelCache {
191 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192 let len = self.len().unwrap_or(0);
193 f.debug_struct("ModelCache")
194 .field("capacity", &self.capacity)
195 .field("len", &len)
196 .finish()
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero capacity must be refused with the dedicated error variant.
    #[test]
    fn zero_capacity_is_rejected() {
        let err = ModelCache::with_capacity(0).expect_err("expected error");
        assert!(matches!(err, MlError::CacheCapacityZero));
    }

    /// `new` must succeed and yield an empty cache with the default capacity.
    #[test]
    fn default_capacity_is_non_empty() {
        let cache = ModelCache::new();
        assert_eq!(cache.capacity(), DEFAULT_CAPACITY);
        assert_ne!(cache.capacity(), 0);
        assert!(matches!(cache.is_empty(), Ok(true)));
        assert!(matches!(cache.len(), Ok(0)));
    }

    /// A non-zero capacity is accepted and reported back unchanged.
    #[test]
    fn explicit_capacity_is_reported() {
        let capacity = ModelCache::with_capacity(3)
            .ok()
            .map(|cache| cache.capacity());
        assert_eq!(capacity, Some(3));
    }

    /// `Default` builds the same cache as `new`.
    #[test]
    fn default_matches_new() {
        let cache = ModelCache::default();
        assert_eq!(cache.capacity(), DEFAULT_CAPACITY);
        assert!(matches!(cache.is_empty(), Ok(true)));
    }

    /// Clearing a cache that holds nothing succeeds and leaves it empty.
    #[test]
    fn clear_on_empty_cache_is_a_no_op() {
        let cache = ModelCache::new();
        assert!(cache.clear().is_ok());
        assert!(matches!(cache.len(), Ok(0)));
    }

    /// The debug representation exposes the capacity and the entry count.
    #[test]
    fn debug_output_reports_capacity_and_len() {
        let rendered = format!("{:?}", ModelCache::new());
        let expected = format!("ModelCache {{ capacity: {}, len: 0 }}", DEFAULT_CAPACITY);
        assert_eq!(rendered, expected);
    }
}