use crate::{Result, VisionError};
use image::{DynamicImage, GenericImageView};
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::{Duration, Instant};
use super::image_conversion::image_to_tensor;
use super::image_processing::resize_image;
use torsh_tensor::Tensor;
/// A single cached decoded image plus the bookkeeping used for LRU
/// eviction and hit statistics.
#[derive(Clone)]
pub struct CacheEntry {
    /// The decoded image payload.
    pub image: DynamicImage,
    /// Last time this entry was served; the eviction loop removes the
    /// entry with the oldest `access_time` first.
    pub access_time: Instant,
    /// Number of times this entry has been served (initial insert counts as 1).
    pub access_count: usize,
    /// Estimated in-memory size of `image`, charged against the cache budget.
    pub size_bytes: usize,
}
/// Thread-safe image cache with a byte-size budget and LRU eviction.
///
/// NOTE(review): instances are shared as `Arc<ImageCache>` (see
/// `ImagePrefetcher`), so the per-field `Arc` wrappers are redundant —
/// plain `Mutex` fields would suffice. Kept as-is here; documenting only.
pub struct ImageCache {
    /// Cached entries keyed by the lossily-stringified source path.
    cache: Arc<Mutex<HashMap<String, CacheEntry>>>,
    /// Byte budget the cache tries to stay under.
    max_size_bytes: usize,
    /// Running sum of `size_bytes` over all live entries.
    current_size_bytes: Arc<Mutex<usize>>,
    /// Requests served from the cache.
    hit_count: Arc<Mutex<usize>>,
    /// Requests that fell through to a disk load.
    miss_count: Arc<Mutex<usize>>,
}
impl ImageCache {
    /// Creates an empty cache with a budget of `max_size_mb` megabytes of
    /// (estimated) decoded image data.
    pub fn new(max_size_mb: usize) -> Self {
        Self {
            cache: Arc::new(Mutex::new(HashMap::new())),
            max_size_bytes: max_size_mb * 1024 * 1024,
            current_size_bytes: Arc::new(Mutex::new(0)),
            hit_count: Arc::new(Mutex::new(0)),
            miss_count: Arc::new(Mutex::new(0)),
        }
    }

    /// Returns the image at `path`, serving it from the cache when possible.
    ///
    /// A hit refreshes the entry's LRU timestamp and access count and returns
    /// a clone of the cached image. A miss loads through
    /// `crate::io::global::load_image`, caches the result, and returns it.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying image load.
    pub fn get_or_load<P: AsRef<Path>>(&self, path: P) -> Result<DynamicImage> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        // Fast path, scoped so the cache lock is released before any load.
        {
            let mut cache = self.cache.lock().expect("lock should not be poisoned");
            if let Some(entry) = cache.get_mut(&path_str) {
                entry.access_time = Instant::now();
                entry.access_count += 1;
                *self.hit_count.lock().expect("lock should not be poisoned") += 1;
                return Ok(entry.image.clone());
            }
        }
        *self.miss_count.lock().expect("lock should not be poisoned") += 1;
        let image = crate::io::global::load_image(path)?;
        // Estimate the decoded footprint as RGBA8 (4 bytes per pixel). Widen
        // to usize *before* multiplying: `width * height * 4` performed in u32
        // overflows for images past ~1 gigapixel.
        let estimated_size = image.width() as usize * image.height() as usize * 4;
        self.insert(path_str, image.clone(), estimated_size);
        Ok(image)
    }

    /// Inserts `image` under `key`, evicting least-recently-used entries
    /// until the new entry fits within `max_size_bytes`.
    fn insert(&self, key: String, image: DynamicImage, size_bytes: usize) {
        // An entry larger than the entire budget can never fit; refuse to
        // cache it rather than evicting everything and still exceeding the
        // configured limit.
        if size_bytes > self.max_size_bytes {
            return;
        }
        let entry = CacheEntry {
            // `image` is not used again below, so it can be moved in directly.
            image,
            access_time: Instant::now(),
            access_count: 1,
            size_bytes,
        };
        let mut cache = self.cache.lock().expect("lock should not be poisoned");
        let mut current_size = self
            .current_size_bytes
            .lock()
            .expect("lock should not be poisoned");
        // Replacing an existing entry must release its accounted size first.
        if let Some(old_entry) = cache.remove(&key) {
            *current_size -= old_entry.size_bytes;
        }
        // Evict by oldest access time until the new entry fits.
        while *current_size + size_bytes > self.max_size_bytes && !cache.is_empty() {
            let lru_key = cache
                .iter()
                .min_by_key(|(_, entry)| entry.access_time)
                .map(|(k, _)| k.clone());
            if let Some(lru_key) = lru_key {
                if let Some(lru_entry) = cache.remove(&lru_key) {
                    *current_size -= lru_entry.size_bytes;
                }
            } else {
                break;
            }
        }
        cache.insert(key, entry);
        *current_size += size_bytes;
    }

    /// Returns a snapshot of the hit/miss counters and current occupancy.
    ///
    /// Each field is read under its own lock, so the snapshot is not atomic
    /// with respect to concurrent loads — values may be slightly skewed.
    pub fn stats(&self) -> CacheStats {
        let hit_count = *self.hit_count.lock().expect("lock should not be poisoned");
        let miss_count = *self.miss_count.lock().expect("lock should not be poisoned");
        let total_requests = hit_count + miss_count;
        let hit_rate = if total_requests > 0 {
            hit_count as f64 / total_requests as f64
        } else {
            0.0
        };
        CacheStats {
            hit_count,
            miss_count,
            hit_rate,
            current_size_bytes: *self
                .current_size_bytes
                .lock()
                .expect("lock should not be poisoned"),
            max_size_bytes: self.max_size_bytes,
            entry_count: self
                .cache
                .lock()
                .expect("lock should not be poisoned")
                .len(),
        }
    }

    /// Empties the cache and resets the size accounting and hit/miss counters.
    pub fn clear(&self) {
        self.cache
            .lock()
            .expect("lock should not be poisoned")
            .clear();
        *self
            .current_size_bytes
            .lock()
            .expect("lock should not be poisoned") = 0;
        *self.hit_count.lock().expect("lock should not be poisoned") = 0;
        *self.miss_count.lock().expect("lock should not be poisoned") = 0;
    }
}
/// Point-in-time snapshot of an `ImageCache`'s counters and occupancy.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// `get_or_load` calls served from the cache.
    pub hit_count: usize,
    /// `get_or_load` calls that required a fresh load.
    pub miss_count: usize,
    /// `hit_count / (hit_count + miss_count)`, or 0.0 when no requests yet.
    pub hit_rate: f64,
    /// Estimated bytes currently held by cached images.
    pub current_size_bytes: usize,
    /// Configured byte budget of the cache.
    pub max_size_bytes: usize,
    /// Number of images currently cached.
    pub entry_count: usize,
}
/// Background prefetcher that warms a shared `ImageCache` from a worker thread.
pub struct ImagePrefetcher {
    /// Cache the worker loads images into.
    cache: Arc<ImageCache>,
    /// Pending paths; the worker pops from the back, so order is LIFO.
    prefetch_queue: Arc<Mutex<Vec<String>>>,
    /// Join handle for the worker; `None` after `shutdown` has joined it.
    worker_handle: Option<thread::JoinHandle<()>>,
    /// Polled by the worker each iteration; set to `true` to stop it.
    shutdown_signal: Arc<Mutex<bool>>,
}
impl ImagePrefetcher {
    /// Spawns a background worker that drains the prefetch queue into `cache`.
    pub fn new(cache: Arc<ImageCache>) -> Self {
        let prefetch_queue = Arc::new(Mutex::new(Vec::new()));
        let shutdown_signal = Arc::new(Mutex::new(false));
        let queue_clone = Arc::clone(&prefetch_queue);
        let cache_clone = Arc::clone(&cache);
        let shutdown_clone = Arc::clone(&shutdown_signal);
        let worker_handle = thread::spawn(move || {
            Self::worker_thread(queue_clone, cache_clone, shutdown_clone);
        });
        Self {
            cache,
            prefetch_queue,
            worker_handle: Some(worker_handle),
            shutdown_signal,
        }
    }

    /// Queues `paths` for background loading and returns immediately.
    ///
    /// The worker pops from the back of the queue, so later-queued paths
    /// may be warmed first.
    pub fn prefetch_paths<P: AsRef<Path>>(&self, paths: &[P]) {
        let mut queue = self
            .prefetch_queue
            .lock()
            .expect("lock should not be poisoned");
        queue.extend(
            paths
                .iter()
                .map(|p| p.as_ref().to_string_lossy().to_string()),
        );
    }

    /// Worker loop: pops queued paths and loads them through the cache
    /// until `shutdown` is set. Sleeps briefly when idle to avoid spinning.
    fn worker_thread(
        queue: Arc<Mutex<Vec<String>>>,
        cache: Arc<ImageCache>,
        shutdown: Arc<Mutex<bool>>,
    ) {
        loop {
            if *shutdown.lock().expect("lock should not be poisoned") {
                break;
            }
            // The lock guard is a temporary, dropped before the load starts.
            let path = queue.lock().expect("lock should not be poisoned").pop();
            if let Some(path) = path {
                // Prefetching is best-effort: a failed load will simply be
                // retried on demand by the caller, so the error is dropped.
                let _ = cache.get_or_load(&path);
            } else {
                thread::sleep(Duration::from_millis(10));
            }
        }
    }

    /// Fetches an image through the shared cache (synchronous).
    pub fn get_image<P: AsRef<Path>>(&self, path: P) -> Result<DynamicImage> {
        self.cache.get_or_load(path)
    }

    /// Signals the worker to stop and joins it.
    ///
    /// Idempotent: after the handle has been taken, later calls are no-ops.
    pub fn shutdown(&mut self) {
        *self
            .shutdown_signal
            .lock()
            .expect("lock should not be poisoned") = true;
        if let Some(handle) = self.worker_handle.take() {
            let _ = handle.join();
        }
    }
}
impl Drop for ImagePrefetcher {
    /// Ensures the worker thread is signalled and joined when the
    /// prefetcher goes out of scope, so no detached thread outlives it.
    fn drop(&mut self) {
        self.shutdown();
    }
}
/// High-level batch loader combining a shared cache, a background
/// prefetcher, and optional resize/normalize post-processing.
pub struct BatchImageLoader {
    /// Cache shared with `prefetcher`.
    cache: Arc<ImageCache>,
    /// Warms `cache` in the background during batch loads.
    prefetcher: ImagePrefetcher,
    /// Optional `(width, height)` every loaded image is resized to.
    target_size: Option<(u32, u32)>,
    /// When true, tensor values are divided by 255 after conversion.
    normalize: bool,
}
impl BatchImageLoader {
    /// Creates a loader backed by an image cache of `cache_size_mb`
    /// megabytes and a background prefetcher over that same cache.
    pub fn new(cache_size_mb: usize) -> Self {
        let cache = Arc::new(ImageCache::new(cache_size_mb));
        let prefetcher = ImagePrefetcher::new(Arc::clone(&cache));
        Self {
            cache,
            prefetcher,
            target_size: None,
            normalize: false,
        }
    }

    /// Builder: resize every loaded image to `width` x `height` (Lanczos3).
    pub fn with_target_size(mut self, width: u32, height: u32) -> Self {
        self.target_size = Some((width, height));
        self
    }

    /// Builder: divide pixel values by 255 so tensor values land in `[0, 1]`.
    pub fn with_normalization(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Loads `paths` into tensors, queueing everything after the first path
    /// for background prefetch while earlier entries are processed.
    ///
    /// # Errors
    ///
    /// Fails on the first image that cannot be loaded or converted.
    pub fn load_batch<P: AsRef<Path>>(&self, paths: &[P]) -> Result<Vec<Tensor<f32>>> {
        if paths.len() > 1 {
            let upcoming: Vec<String> = paths
                .iter()
                .skip(1)
                .map(|p| p.as_ref().to_string_lossy().to_string())
                .collect();
            // `String: AsRef<Path>`, so the owned strings can be handed to
            // the prefetcher directly — no intermediate `Vec<&str>` needed.
            self.prefetcher.prefetch_paths(&upcoming);
        }
        let mut tensors = Vec::with_capacity(paths.len());
        for path in paths {
            let mut image = self.prefetcher.get_image(path)?;
            if let Some((width, height)) = self.target_size {
                image = resize_image(&image, width, height, image::imageops::FilterType::Lanczos3);
            }
            let mut tensor = image_to_tensor(&image)?;
            if self.normalize {
                // In-place scale; `tensor` is owned here, so no clone is needed.
                tensor.div_scalar_(255.0)?;
            }
            tensors.push(tensor);
        }
        Ok(tensors)
    }

    /// Snapshot of the underlying cache's statistics.
    pub fn cache_stats(&self) -> CacheStats {
        self.cache.stats()
    }

    /// Drops every cached image and resets the cache counters.
    pub fn clear_cache(&self) {
        self.cache.clear();
    }
}
/// Loads images by memory-mapping files instead of reading them into a
/// heap buffer, caching the mappings per path.
pub struct MemoryMappedLoader {
    /// Open file handles retained alongside their maps.
    /// NOTE(review): `memmap2` maps generally remain valid after the `File`
    /// is dropped, so this cache may be unnecessary — confirm intent.
    file_handles: HashMap<String, std::fs::File>,
    /// Cached read-only mappings keyed by stringified path (unbounded).
    mmap_cache: HashMap<String, memmap2::Mmap>,
}
impl MemoryMappedLoader {
    /// Creates a loader with empty file-handle and mapping caches.
    pub fn new() -> Self {
        Self {
            file_handles: HashMap::new(),
            mmap_cache: HashMap::new(),
        }
    }

    /// Decodes the image at `path` through a memory-mapped view of the file.
    ///
    /// The mapping (and its file handle) is cached for subsequent calls.
    /// NOTE(review): the caches grow without bound — call `clear`
    /// periodically in long-lived use.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened, mapped, or decoded.
    pub fn load_image_mmap<P: AsRef<Path>>(&mut self, path: P) -> Result<DynamicImage> {
        let path_str = path.as_ref().to_string_lossy().to_string();
        if let Some(mmap) = self.mmap_cache.get(&path_str) {
            return self.decode_from_mmap(mmap);
        }
        let file = std::fs::File::open(&path)?;
        // SAFETY: the file was just opened read-only; the map is only unsound
        // if another process truncates or rewrites the file while mapped —
        // the standard, unavoidable memmap2 caveat for file-backed maps.
        let mmap = unsafe { memmap2::Mmap::map(&file)? };
        let image = self.decode_from_mmap(&mmap)?;
        self.file_handles.insert(path_str.clone(), file);
        self.mmap_cache.insert(path_str, mmap);
        Ok(image)
    }

    /// Decodes an image from the mapped bytes, detecting the format from
    /// the buffer's magic bytes rather than assuming JPEG for every file.
    fn decode_from_mmap(&self, mmap: &memmap2::Mmap) -> Result<DynamicImage> {
        // `load_from_memory` sniffs the actual format (PNG, JPEG, BMP, ...)
        // from the content; hard-coding `ImageFormat::Jpeg` mis-decoded or
        // rejected every non-JPEG input.
        let image = image::load_from_memory(&mmap[..])?;
        Ok(image)
    }

    /// Drops all cached mappings and their file handles (maps cleared first).
    pub fn clear(&mut self) {
        self.mmap_cache.clear();
        self.file_handles.clear();
    }
}
impl Default for MemoryMappedLoader {
fn default() -> Self {
Self::new()
}
}
/// Aggregate timing and cache statistics for recorded image loads.
#[derive(Debug, Default)]
pub struct LoadingMetrics {
    /// Total number of loads recorded.
    pub total_images_loaded: usize,
    /// Sum of all recorded load durations.
    pub total_loading_time: Duration,
    /// Loads flagged as cache hits.
    pub cache_hits: usize,
    /// Loads flagged as cache misses.
    pub cache_misses: usize,
    /// `total_loading_time / total_images_loaded`, maintained incrementally.
    pub average_loading_time: Duration,
}
impl LoadingMetrics {
    /// Folds one load into the totals and refreshes the running average.
    pub fn record_load(&mut self, duration: Duration, cache_hit: bool) {
        self.total_images_loaded += 1;
        self.total_loading_time += duration;
        let bucket = if cache_hit {
            &mut self.cache_hits
        } else {
            &mut self.cache_misses
        };
        *bucket += 1;
        // `total_images_loaded` is at least 1 here, so this never divides by zero.
        self.average_loading_time = self.total_loading_time / self.total_images_loaded as u32;
    }

    /// Fraction of recorded loads that hit the cache; 0.0 before any load.
    pub fn cache_hit_rate(&self) -> f64 {
        match self.total_images_loaded {
            0 => 0.0,
            total => self.cache_hits as f64 / total as f64,
        }
    }

    /// Recorded loads per second of cumulative loading time; 0.0 when no
    /// time has been accumulated yet.
    pub fn throughput_ips(&self) -> f64 {
        if self.total_loading_time.is_zero() {
            return 0.0;
        }
        self.total_images_loaded as f64 / self.total_loading_time.as_secs_f64()
    }

    /// Resets every counter back to its zero/default state.
    pub fn reset(&mut self) {
        *self = LoadingMetrics::default();
    }
}