use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use candle_core::Tensor;
use indexmap::IndexMap;
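/// Bounded LRU cache for vision-encoder outputs, keyed by a content hash of
/// the input image. Each entry is a `Vec<Tensor>` so encoders that return
/// several tensors per image can be cached as one unit. Recency is tracked by
/// `IndexMap` insertion order: the front of the map is the least recently
/// used entry, the back is the most recently used.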
pub struct EncoderCacheManager {
cache: IndexMap<u64, Vec<Tensor>>,
max_entries: usize,
hits: Arc<AtomicUsize>,
misses: Arc<AtomicUsize>,
}
impl EncoderCacheManager {
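    /// Creates a cache holding at most `max_entries` entries.
    /// Passing 0 disables eviction, so the cache grows without bound.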
pub fn new(max_entries: usize) -> Self {
Self {
cache: IndexMap::with_capacity(max_entries),
max_entries,
hits: Arc::new(AtomicUsize::new(0)),
misses: Arc::new(AtomicUsize::new(0)),
}
}
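    /// Returns shared handles to the hit and miss counters so callers can
    /// observe cache statistics without locking the cache itself.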
pub fn counters(&self) -> (Arc<AtomicUsize>, Arc<AtomicUsize>) {
(self.hits.clone(), self.misses.clone())
}
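    /// Looks up `content_hash`, bumping the entry to most-recently-used on a
    /// hit. The returned tensors are clones, which share storage with the
    /// cached ones rather than copying data.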
pub fn get(&mut self, content_hash: u64) -> Option<Vec<Tensor>> {
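        // Remove and re-insert so the entry moves to the back of the map,
        // i.e. becomes the most recently used.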
if let Some(entry) = self.cache.shift_remove(&content_hash) {
let cloned = entry.clone();
self.cache.insert(content_hash, entry);
self.hits.fetch_add(1, Ordering::Relaxed);
Some(cloned)
} else {
self.misses.fetch_add(1, Ordering::Relaxed);
None
}
}
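    /// Inserts (or refreshes) an entry. Re-inserting an existing hash bumps
    /// it to most-recently-used; when the cache is full, the least recently
    /// used entry is evicted first.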
pub fn insert(&mut self, content_hash: u64, outputs: Vec<Tensor>) {
if self.cache.contains_key(&content_hash) {
self.cache.shift_remove(&content_hash);
self.cache.insert(content_hash, outputs);
return;
}
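        // At capacity: evict the least recently used entry (front of the map).
        // `max_entries == 0` means unbounded, so no eviction in that case.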
if self.max_entries > 0 && self.cache.len() >= self.max_entries {
self.cache.shift_remove_index(0);
}
self.cache.insert(content_hash, outputs);
}
}
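/// Encodes a batch of images, reusing cached encoder outputs where possible.
///
/// Only the images whose hash is missing from `cache` are passed to
/// `encode_fn`; the results are then merged with the cached entries and
/// returned in the original image order. This assumes `pixel_values` holds
/// one image per index of dimension 0, that `image_hashes[i]` identifies
/// image `i`, and that every tensor returned by `encode_fn` also carries the
/// image batch on dimension 0. If `image_hashes` is empty the call falls
/// through to `encode_fn` unchanged.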
pub fn cached_encode_images(
image_hashes: &[u64],
pixel_values: &Tensor,
cache: &Mutex<EncoderCacheManager>,
encode_fn: impl FnOnce(&Tensor) -> candle_core::Result<Vec<Tensor>>,
) -> candle_core::Result<Vec<Tensor>> {
let n_images = image_hashes.len();
if n_images == 0 {
return encode_fn(pixel_values);
}
debug_assert_eq!(
n_images,
pixel_values.dim(0)?,
"image_hashes length must match pixel_values dim-0"
);
let mut hits: Vec<Option<Vec<Tensor>>> = vec![None; n_images];
let mut miss_indices: Vec<usize> = Vec::new();
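    // First pass: look up every hash under the lock and record which images
    // still need to be encoded.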
{
let mut guard = cache.lock().expect("encoder cache lock poisoned");
for (i, &hash) in image_hashes.iter().enumerate() {
if let Some(cached) = guard.get(hash) {
hits[i] = Some(cached);
} else {
miss_indices.push(i);
}
}
}
if miss_indices.is_empty() {
return assemble(hits, n_images);
}
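    // Gather only the missed images into one batch for the encoder; if every
    // image missed, reuse the original batch as-is.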
let miss_pixels = if miss_indices.len() == n_images {
pixel_values.clone()
} else {
let slices: Vec<Tensor> = miss_indices
.iter()
.map(|&i| pixel_values.get(i))
.collect::<candle_core::Result<Vec<_>>>()?;
Tensor::stack(&slices, 0)?
};
let encoded = encode_fn(&miss_pixels)?;
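    // Second pass: slice the freshly encoded batch back into per-image
    // tensors, cache them, and fill the remaining slots.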
{
let mut guard = cache.lock().expect("encoder cache lock poisoned");
for (batch_idx, &orig_idx) in miss_indices.iter().enumerate() {
let per_image: Vec<Tensor> = encoded
.iter()
.map(|t| t.get(batch_idx))
.collect::<candle_core::Result<Vec<_>>>()?;
guard.insert(image_hashes[orig_idx], per_image.clone());
hits[orig_idx] = Some(per_image);
}
}
assemble(hits, n_images)
}
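/// Stacks per-image tensors back into batched outputs, preserving the
/// original image order. Every slot in `hits` must be `Some` by the time this
/// is called.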
fn assemble(hits: Vec<Option<Vec<Tensor>>>, n_images: usize) -> candle_core::Result<Vec<Tensor>> {
    let n_outputs = hits[0].as_ref().map(|v| v.len()).expect("all images should be resolved");
let mut result = Vec::with_capacity(n_outputs);
for out_idx in 0..n_outputs {
let slices: Vec<Tensor> = (0..n_images)
.map(|i| hits[i].as_ref().expect("all images should be resolved")[out_idx].clone())
.collect();
result.push(Tensor::stack(&slices, 0)?);
}
Ok(result)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{Device, Tensor};
fn dummy_tensor(val: f32) -> Tensor {
Tensor::new(&[val], &Device::Cpu).unwrap()
}
#[test]
fn test_insert_and_get() {
let mut cache = EncoderCacheManager::new(4);
let t = dummy_tensor(1.0);
cache.insert(100, vec![t.clone()]);
let result = cache.get(100);
assert!(result.is_some());
let result = result.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(
result[0].to_vec1::<f32>().unwrap(),
t.to_vec1::<f32>().unwrap()
);
}
#[test]
fn test_get_miss() {
let mut cache = EncoderCacheManager::new(4);
assert!(cache.get(999).is_none());
}
#[test]
fn test_lru_eviction() {
let mut cache = EncoderCacheManager::new(3);
cache.insert(1, vec![dummy_tensor(1.0)]);
cache.insert(2, vec![dummy_tensor(2.0)]);
cache.insert(3, vec![dummy_tensor(3.0)]);
cache.insert(4, vec![dummy_tensor(4.0)]);
assert!(cache.get(1).is_none(), "key 1 should have been evicted");
assert!(cache.get(2).is_some());
assert!(cache.get(3).is_some());
assert!(cache.get(4).is_some());
}
#[test]
fn test_get_bumps_lru_order() {
let mut cache = EncoderCacheManager::new(3);
cache.insert(1, vec![dummy_tensor(1.0)]);
cache.insert(2, vec![dummy_tensor(2.0)]);
cache.insert(3, vec![dummy_tensor(3.0)]);
let _ = cache.get(1);
cache.insert(4, vec![dummy_tensor(4.0)]);
assert!(cache.get(1).is_some(), "key 1 was accessed, should survive");
assert!(cache.get(2).is_none(), "key 2 should have been evicted");
assert!(cache.get(3).is_some());
assert!(cache.get(4).is_some());
}
#[test]
fn test_insert_duplicate_updates_lru() {
let mut cache = EncoderCacheManager::new(3);
cache.insert(1, vec![dummy_tensor(1.0)]);
cache.insert(2, vec![dummy_tensor(2.0)]);
cache.insert(3, vec![dummy_tensor(3.0)]);
cache.insert(1, vec![dummy_tensor(10.0)]);
cache.insert(4, vec![dummy_tensor(4.0)]);
assert!(
cache.get(1).is_some(),
"key 1 was re-inserted, should survive"
);
assert!(cache.get(2).is_none(), "key 2 should have been evicted");
let val = cache.get(1).unwrap()[0].to_vec1::<f32>().unwrap();
assert_eq!(val, vec![10.0]);
}
#[test]
fn test_multi_tensor_entries() {
let mut cache = EncoderCacheManager::new(4);
let t1 = dummy_tensor(1.0);
let t2 = dummy_tensor(2.0);
cache.insert(42, vec![t1, t2]);
let result = cache.get(42).unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].to_vec1::<f32>().unwrap(), vec![1.0]);
assert_eq!(result[1].to_vec1::<f32>().unwrap(), vec![2.0]);
}
fn make_pixels(vals: &[f32]) -> Tensor {
Tensor::from_slice(vals, (vals.len(), 1), &Device::Cpu).unwrap()
}
#[test]
fn test_cached_encode_all_miss() {
let cache = Mutex::new(EncoderCacheManager::new(32));
let pixels = make_pixels(&[10.0, 20.0, 30.0]);
let hashes = [1u64, 2, 3];
let result = cached_encode_images(&hashes, &pixels, &cache, |pv| {
Ok(vec![pv.clone()])
})
.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].dims(), &[3, 1]);
assert_eq!(
result[0].to_vec2::<f32>().unwrap(),
vec![vec![10.0], vec![20.0], vec![30.0]]
);
let mut guard = cache.lock().unwrap();
assert!(guard.get(1).is_some());
assert!(guard.get(2).is_some());
assert!(guard.get(3).is_some());
}
#[test]
fn test_cached_encode_all_hit() {
let cache = Mutex::new(EncoderCacheManager::new(32));
{
let mut guard = cache.lock().unwrap();
guard.insert(1, vec![Tensor::new(&[100.0f32], &Device::Cpu).unwrap()]);
guard.insert(2, vec![Tensor::new(&[200.0f32], &Device::Cpu).unwrap()]);
}
let pixels = make_pixels(&[10.0, 20.0]);
let hashes = [1u64, 2];
let encode_called = std::sync::atomic::AtomicBool::new(false);
let result = cached_encode_images(&hashes, &pixels, &cache, |pv| {
encode_called.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(vec![pv.clone()])
})
.unwrap();
assert!(
!encode_called.load(std::sync::atomic::Ordering::SeqCst),
"encode_fn should NOT be called when everything is cached"
);
assert_eq!(
result[0].to_vec2::<f32>().unwrap(),
vec![vec![100.0], vec![200.0]]
);
}
#[test]
fn test_cached_encode_partial_hit() {
let cache = Mutex::new(EncoderCacheManager::new(32));
{
let mut guard = cache.lock().unwrap();
guard.insert(2, vec![Tensor::new(&[200.0f32], &Device::Cpu).unwrap()]);
}
let pixels = make_pixels(&[10.0, 20.0, 30.0]);
let hashes = [1u64, 2, 3];
let result = cached_encode_images(&hashes, &pixels, &cache, |pv| {
Ok(vec![(pv * 2.0)?])
})
.unwrap();
let output = result[0].to_vec2::<f32>().unwrap();
assert_eq!(output[0], vec![20.0]);
assert_eq!(output[1], vec![200.0]);
assert_eq!(output[2], vec![60.0]);
}
#[test]
fn test_cached_encode_multi_output() {
let cache = Mutex::new(EncoderCacheManager::new(32));
let pixels = make_pixels(&[5.0, 6.0]);
let hashes = [10u64, 20];
let result = cached_encode_images(&hashes, &pixels, &cache, |pv| {
let main = pv.clone();
let aux = (pv * 10.0)?;
Ok(vec![main, aux])
})
.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(
result[0].to_vec2::<f32>().unwrap(),
vec![vec![5.0], vec![6.0]]
);
assert_eq!(
result[1].to_vec2::<f32>().unwrap(),
vec![vec![50.0], vec![60.0]]
);
let result2 = cached_encode_images(&hashes, &pixels, &cache, |_| {
panic!("should not be called on full cache hit");
})
.unwrap();
assert_eq!(
result2[0].to_vec2::<f32>().unwrap(),
vec![vec![5.0], vec![6.0]]
);
assert_eq!(
result2[1].to_vec2::<f32>().unwrap(),
vec![vec![50.0], vec![60.0]]
);
}
}