use crate::error::{CvError, CvResult};
use oxionnx::Session;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
/// Thread-safe cache of loaded ONNX [`Session`]s, keyed by the path they were loaded from.
///
/// Cloning a `ModelCache` is cheap and yields a handle to the *same* underlying
/// map (it lives behind an [`Arc`]), so clones observe each other's loads and
/// `clear` calls.
///
/// Paths are used as keys exactly as given; no canonicalization is performed,
/// so e.g. `model.onnx` and `./model.onnx` are cached as separate entries.
#[derive(Clone, Default)]
pub struct ModelCache {
    /// Loaded sessions, keyed by the path passed to [`ModelCache::get_or_load`].
    cache: Arc<Mutex<HashMap<PathBuf, Arc<Mutex<Session>>>>>,
}

impl ModelCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the session loaded from `path`, loading and caching it on first use.
    ///
    /// The first call for a given path builds a session with
    /// `oxionnx::OptLevel::All`; later calls with the same path return a clone of
    /// the cached `Arc` without touching the file again.
    ///
    /// The cache lock is held while a model loads, so loads are serialized across
    /// all handles to this cache and the same path is never loaded twice.
    ///
    /// # Errors
    ///
    /// Returns a model-load [`CvError`] if the session cannot be built from
    /// `path`. Nothing is cached in that case, so a later call retries the load.
    pub fn get_or_load(&self, path: impl AsRef<Path>) -> CvResult<Arc<Mutex<Session>>> {
        let mut cache = self.lock_cache();
        match cache.entry(path.as_ref().to_path_buf()) {
            std::collections::hash_map::Entry::Occupied(entry) => Ok(Arc::clone(entry.get())),
            std::collections::hash_map::Entry::Vacant(entry) => {
                // Load before inserting: a failed (or panicking) load leaves the map untouched.
                let session = Session::builder()
                    .with_optimization_level(oxionnx::OptLevel::All)
                    .load(entry.key())
                    .map_err(|e| CvError::model_load(format!("Failed to load model: {e}")))?;
                Ok(Arc::clone(entry.insert(Arc::new(Mutex::new(session)))))
            }
        }
    }

    /// Removes every cached session.
    ///
    /// Sessions already returned by [`ModelCache::get_or_load`] stay alive for as
    /// long as callers hold their `Arc`; only the cache's own references are dropped.
    pub fn clear(&self) {
        self.lock_cache().clear();
    }

    /// Returns the number of cached sessions.
    #[must_use]
    pub fn len(&self) -> usize {
        self.lock_cache().len()
    }

    /// Returns `true` if no sessions are cached.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.lock_cache().is_empty()
    }

    /// Locks the map, recovering the guard if a previous holder panicked.
    ///
    /// Ignoring poisoning is sound here: the map is only ever mutated by a single
    /// `insert` (after a successful load) or `clear`, so a panic while the lock is
    /// held cannot leave it half-updated. Recovering keeps the cache usable
    /// instead of making every later call fail or report a bogus empty state.
    fn lock_cache(&self) -> std::sync::MutexGuard<'_, HashMap<PathBuf, Arc<Mutex<Session>>>> {
        self.cache
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh cache, via either constructor, starts out empty.
    #[test]
    fn test_model_cache() {
        let cache = ModelCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
        let cache2 = ModelCache::default();
        assert!(cache2.is_empty());
    }

    /// A load failure is reported and leaves nothing behind in the cache,
    /// including when the failing call goes through a cloned handle.
    #[test]
    fn failed_load_is_not_cached() {
        let cache = ModelCache::new();
        let missing = std::env::temp_dir().join("model_cache_test_missing_model.onnx");
        assert!(cache.get_or_load(&missing).is_err());
        assert!(cache.is_empty());

        let shared = cache.clone();
        assert!(shared.get_or_load(&missing).is_err());
        assert!(cache.is_empty());
        assert!(shared.is_empty());
    }

    /// Clearing an already-empty cache is a harmless no-op.
    #[test]
    fn clear_on_empty_cache_is_noop() {
        let cache = ModelCache::new();
        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }
}