use crate::config::CacheConfig;
use crate::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// A recognition result stored in, and returned from, the cache.
///
/// `CacheManager` moves this value into an entry on `store` and hands back a clone on a
/// hit; it does not read or modify any of the fields below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
/// LaTeX string for the image (the tests store `\frac{x^2}{2}`).
pub latex: String,
/// Extra outputs keyed by name. Presumably alternative formats; TODO confirm with the producer.
pub alternatives: HashMap<String, String>,
/// Confidence score of the result; no range is enforced here (the tests use 0.95).
pub confidence: f32,
/// Caller-supplied timestamp. Expiry does not use it; `CacheManager` tracks its own per-entry time.
pub timestamp: u64,
/// NOTE(review): `CacheManager` never increments this, so it keeps whatever value the caller stored.
pub access_count: usize,
/// Caller-supplied hash. The cache keys entries by its own hash of the image bytes, not by this field.
pub image_hash: String,
}
/// Internal cache record: the stored result plus the data used for matching and expiry.
#[derive(Debug, Clone)]
struct CacheEntry {
/// Vector from `CacheManager::generate_embedding`, compared by cosine similarity during lookup.
embedding: Vec<f32>,
/// The result handed back (cloned) on a hit.
result: CachedResult,
/// Unix time in seconds when the entry was stored. Lookups do not refresh it, so despite
/// the name the TTL check measures age since insertion.
last_access: u64,
}
/// Cache of `CachedResult`s keyed by a hash of the image bytes.
///
/// Features: a similarity-search fallback, LRU eviction at `capacity`, and TTL-based expiry.
/// All mutable state sits behind `RwLock`s, so every operation takes `&self`.
pub struct CacheManager {
/// Settings consulted by every operation (enable flag, capacity, threshold, TTL, embedding size).
config: CacheConfig,
/// Stored entries keyed by the hex digest from `hash_image`.
entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
/// Keys ordered from least to most recently used. `update_lru` moves a key to the back and
/// `evict_lru` removes from the front; the manager keeps this in step with `entries`.
lru_order: Arc<RwLock<Vec<String>>>,
/// Running counters; a clone is returned by `stats()`.
stats: Arc<RwLock<CacheStats>>,
}
/// Snapshot of cache counters, as returned by `CacheManager::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
/// Lookups that returned a result, via either an exact or a similarity match.
pub hits: u64,
/// Lookups on an enabled cache that returned nothing.
pub misses: u64,
/// Number of stored entries, refreshed by `store`, `clear` and `cleanup`.
pub entries: usize,
/// Entries evicted by `store` to stay within `capacity`. Expiry removals in `cleanup` are not counted.
pub evictions: u64,
/// Running average of hit similarity scores, maintained by `CacheManager`'s hit accounting.
pub avg_similarity: f32,
}
impl CacheStats {
    /// Fraction of lookups that were hits: `hits / (hits + misses)`.
    ///
    /// Returns `0.0` when no lookup has been recorded yet, instead of dividing by zero.
    pub fn hit_rate(&self) -> f32 {
        let lookups = self.hits + self.misses;
        match lookups {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }
}
impl CacheManager {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            config,
            entries: Arc::new(RwLock::new(HashMap::new())),
            lru_order: Arc::new(RwLock::new(Vec::new())),
            stats: Arc::new(RwLock::new(CacheStats::default())),
        }
    }

    /// Builds a `config.vector_dimension`-long fingerprint vector for `image_data`.
    ///
    /// Every component is `+1.0` or `-1.0`, taken from the bits of `DefaultHasher` digests of
    /// the image bytes. Each digest is seeded with a block counter, so dimensions beyond 64
    /// keep drawing fresh bits. Identical inputs therefore have cosine similarity exactly 1.0.
    /// Assuming well-mixed hash bits, unrelated inputs score around 0.0, with a spread of about
    /// `1/sqrt(vector_dimension)`. Very small dimensions can still collide by chance.
    ///
    /// The components are deliberately centred on zero. An all-positive encoding (such as
    /// scaled hex digits of a hash) makes unrelated images look highly similar, typically
    /// around 0.9, and yields false cache hits.
    ///
    /// This is a content fingerprint, not a perceptual embedding: byte-different images are
    /// not expected to score as similar.
    ///
    /// # Errors
    /// Currently never fails.
    fn generate_embedding(&self, image_data: &[u8]) -> Result<Vec<f32>> {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let embedding: Vec<f32> = (0u64..)
            .flat_map(|block| {
                let mut hasher = DefaultHasher::new();
                block.hash(&mut hasher);
                image_data.hash(&mut hasher);
                let bits = hasher.finish();
                (0..64u32).map(move |bit| if (bits >> bit) & 1 == 1 { 1.0 } else { -1.0 })
            })
            .take(self.config.vector_dimension)
            .collect();
        Ok(embedding)
    }

    /// Returns the lowercase-hex `DefaultHasher` digest of `image_data`, used as the entry key.
    ///
    /// `DefaultHasher`'s algorithm is unspecified and may change between Rust releases, so
    /// these keys should not be persisted across builds.
    fn hash_image(&self, image_data: &[u8]) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        image_data.hash(&mut hasher);
        format!("{:x}", hasher.finish())
    }

    /// Cosine similarity of `a` and `b`.
    ///
    /// Returns 0.0 when the lengths differ or either vector has zero norm.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }
        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm_a == 0.0 || norm_b == 0.0 {
            return 0.0;
        }
        dot_product / (norm_a * norm_b)
    }

    /// Looks up a cached result for `image_data`.
    ///
    /// Lookup order:
    /// 1. Try an exact content match (same `hash_image` key) that has not expired.
    /// 2. Otherwise, return the non-expired entry with the highest cosine similarity, if that
    ///    similarity reaches `similarity_threshold`.
    ///
    /// Bookkeeping:
    /// - A hit moves the key to the most-recently-used position and updates hit/similarity stats.
    /// - A miss increments `misses`.
    /// - A disabled cache returns `Ok(None)` without touching statistics.
    ///
    /// The `entries` read lock is held for the whole call. `lru_order` and `stats` are locked
    /// after it, which matches the order used by `store`, `clear` and `cleanup`.
    ///
    /// # Errors
    /// Propagates failures from embedding generation (none occur today).
    pub fn lookup(&self, image_data: &[u8]) -> Result<Option<CachedResult>> {
        if !self.config.enabled {
            return Ok(None);
        }
        let hash = self.hash_image(image_data);
        let entries = self.entries.read().unwrap();

        if let Some(entry) = entries.get(&hash).filter(|entry| !self.is_expired(entry)) {
            self.record_hit();
            self.update_lru(&hash);
            return Ok(Some(entry.result.clone()));
        }

        // The embedding is only needed for the similarity scan, so build it lazily.
        let embedding = self.generate_embedding(image_data)?;
        let now = self.current_timestamp();
        let threshold = self.config.similarity_threshold;

        // Track borrows only; the winning result is cloned once, after the scan.
        let mut best: Option<(&String, &CacheEntry, f32)> = None;
        for (key, entry) in entries.iter() {
            if self.is_expired_at(entry, now) {
                continue;
            }
            let similarity = self.cosine_similarity(&embedding, &entry.embedding);
            // Strict `>` keeps the first of equally similar candidates.
            let better = match best {
                Some((_, _, best_similarity)) => similarity > best_similarity,
                None => true,
            };
            if similarity >= threshold && better {
                best = Some((key, entry, similarity));
            }
        }

        match best {
            Some((key, entry, similarity)) => {
                self.record_hit_with_similarity(similarity);
                self.update_lru(key);
                Ok(Some(entry.result.clone()))
            }
            None => {
                self.record_miss();
                Ok(None)
            }
        }
    }

    /// Stores `result` under the content hash of `image_data`.
    ///
    /// An existing entry for the same image is replaced. When the cache is full and the key
    /// is new, the least-recently-used entry is evicted first. The key becomes the most
    /// recently used, and `stats.entries` is refreshed.
    ///
    /// Nothing is stored when the cache is disabled or `capacity` is 0. Without the
    /// capacity check, one entry would slip in because eviction has nothing to remove.
    ///
    /// # Errors
    /// Propagates failures from embedding generation (none occur today).
    pub fn store(&self, image_data: &[u8], result: CachedResult) -> Result<()> {
        if !self.config.enabled || self.config.capacity == 0 {
            return Ok(());
        }
        let hash = self.hash_image(image_data);
        let entry = CacheEntry {
            embedding: self.generate_embedding(image_data)?,
            result,
            last_access: self.current_timestamp(),
        };
        let mut entries = self.entries.write().unwrap();
        if entries.len() >= self.config.capacity && !entries.contains_key(&hash) {
            self.evict_lru(&mut entries);
        }
        entries.insert(hash.clone(), entry);
        self.update_lru(&hash);
        self.update_stats_entries(entries.len());
        Ok(())
    }

    /// Whether `entry` is older than `config.ttl` seconds right now.
    fn is_expired(&self, entry: &CacheEntry) -> bool {
        self.is_expired_at(entry, self.current_timestamp())
    }

    /// Whether `entry` is older than `config.ttl` seconds at time `now` (Unix seconds).
    ///
    /// Uses saturating subtraction. If the wall clock stepped backwards so that
    /// `last_access > now`, the entry counts as fresh instead of underflowing (which would
    /// panic in debug builds and produce a huge bogus age in release builds).
    fn is_expired_at(&self, entry: &CacheEntry, now: u64) -> bool {
        now.saturating_sub(entry.last_access) > self.config.ttl
    }

    /// Current Unix time in whole seconds, or 0 if the system clock reads before the epoch.
    fn current_timestamp(&self) -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs())
    }

    /// Removes the least-recently-used entry (front of `lru_order`) and counts the eviction.
    ///
    /// Does nothing when the LRU list is empty.
    fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
        let mut lru = self.lru_order.write().unwrap();
        if !lru.is_empty() {
            let oldest = lru.remove(0);
            entries.remove(&oldest);
            self.record_eviction();
        }
    }

    /// Moves `key` to the most-recently-used end of `lru_order`, inserting it if absent.
    fn update_lru(&self, key: &str) {
        let mut lru = self.lru_order.write().unwrap();
        lru.retain(|k| k != key);
        lru.push(key.to_string());
    }

    /// Records an exact-content hit.
    ///
    /// An exact match is similarity 1.0, and it must be counted that way because
    /// `avg_similarity` is averaged over all hits.
    fn record_hit(&self) {
        self.record_hit_with_similarity(1.0);
    }

    /// Records a hit with the given `similarity` and folds it into the running average.
    fn record_hit_with_similarity(&self, similarity: f32) {
        let mut stats = self.stats.write().unwrap();
        stats.hits += 1;
        let total = stats.hits as f32;
        stats.avg_similarity = (stats.avg_similarity * (total - 1.0) + similarity) / total;
    }

    /// Records a lookup that found nothing.
    fn record_miss(&self) {
        let mut stats = self.stats.write().unwrap();
        stats.misses += 1;
    }

    /// Records one capacity eviction.
    fn record_eviction(&self) {
        let mut stats = self.stats.write().unwrap();
        stats.evictions += 1;
    }

    /// Publishes the current number of stored entries.
    fn update_stats_entries(&self, count: usize) {
        let mut stats = self.stats.write().unwrap();
        stats.entries = count;
    }

    /// Returns a snapshot of the cache counters.
    pub fn stats(&self) -> CacheStats {
        self.stats.read().unwrap().clone()
    }

    /// Drops every entry.
    ///
    /// Hit, miss and eviction counters are kept; only `stats.entries` is reset to 0.
    pub fn clear(&self) {
        let mut entries = self.entries.write().unwrap();
        let mut lru = self.lru_order.write().unwrap();
        entries.clear();
        lru.clear();
        self.update_stats_entries(0);
    }

    /// Removes every expired entry from both the map and the LRU order.
    ///
    /// Expiry removals are not counted as evictions. Both passes are linear: expired
    /// entries are dropped from the map first, then the LRU list keeps only surviving keys.
    pub fn cleanup(&self) {
        let mut entries = self.entries.write().unwrap();
        let mut lru = self.lru_order.write().unwrap();
        let now = self.current_timestamp();
        entries.retain(|_, entry| !self.is_expired_at(entry, now));
        lru.retain(|key| entries.contains_key(key));
        self.update_stats_entries(entries.len());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled cache: 100 entries, 0.95 similarity threshold, one-hour TTL, 128-dim vectors.
    fn test_config() -> CacheConfig {
        CacheConfig {
            enabled: true,
            capacity: 100,
            similarity_threshold: 0.95,
            ttl: 3600,
            vector_dimension: 128,
            persistent: false,
            cache_dir: ".cache/test".to_string(),
        }
    }

    /// A fixed result used as the cached payload.
    fn test_result() -> CachedResult {
        CachedResult {
            latex: r"\frac{x^2}{2}".to_string(),
            alternatives: HashMap::new(),
            confidence: 0.95,
            timestamp: 0,
            access_count: 0,
            image_hash: "test".to_string(),
        }
    }

    #[test]
    fn test_cache_creation() {
        let config = test_config();
        let cache = CacheManager::new(config);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_store_and_lookup() {
        let config = test_config();
        let cache = CacheManager::new(config);
        let image_data = b"test image data";
        let result = test_result();
        cache.store(image_data, result.clone()).unwrap();
        let lookup_result = cache.lookup(image_data).unwrap();
        assert!(lookup_result.is_some());
        assert_eq!(lookup_result.unwrap().latex, result.latex);
    }

    #[test]
    fn test_cache_miss() {
        let config = test_config();
        let cache = CacheManager::new(config);
        let image_data = b"nonexistent image";
        let lookup_result = cache.lookup(image_data).unwrap();
        assert!(lookup_result.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_hit_rate() {
        let config = test_config();
        let cache = CacheManager::new(config);
        let image_data = b"test image";
        let result = test_result();
        cache.store(image_data, result).unwrap();
        cache.lookup(image_data).unwrap();
        cache.lookup(image_data).unwrap();
        cache.lookup(b"different image").unwrap();
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cosine_similarity() {
        let config = test_config();
        let cache = CacheManager::new(config);
        let vec_a = vec![1.0, 0.0, 0.0];
        let vec_b = vec![1.0, 0.0, 0.0];
        let vec_c = vec![0.0, 1.0, 0.0];
        assert!((cache.cosine_similarity(&vec_a, &vec_b) - 1.0).abs() < 0.01);
        assert!((cache.cosine_similarity(&vec_a, &vec_c) - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_clear() {
        let config = test_config();
        let cache = CacheManager::new(config);
        let image_data = b"test image";
        let result = test_result();
        cache.store(image_data, result).unwrap();
        assert_eq!(cache.stats().entries, 1);
        cache.clear();
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_disabled_cache() {
        let mut config = test_config();
        config.enabled = false;
        let cache = CacheManager::new(config);
        let image_data = b"test image";
        let result = test_result();
        cache.store(image_data, result).unwrap();
        let lookup_result = cache.lookup(image_data).unwrap();
        assert!(lookup_result.is_none());
    }

    /// A lookup refreshes recency: after touching `a`, storing `c` must evict `b` instead.
    #[test]
    fn test_lru_eviction() {
        let mut config = test_config();
        config.capacity = 2;
        let cache = CacheManager::new(config);
        cache.store(b"image a", test_result()).unwrap();
        cache.store(b"image b", test_result()).unwrap();
        cache.lookup(b"image a").unwrap();
        cache.store(b"image c", test_result()).unwrap();
        let stats = cache.stats();
        assert_eq!(stats.entries, 2);
        assert_eq!(stats.evictions, 1);
        assert!(cache.lookup(b"image a").unwrap().is_some());
    }

    /// `cleanup` drops entries older than the TTL from both the map and the LRU order.
    #[test]
    fn test_cleanup_removes_expired() {
        let cache = CacheManager::new(test_config());
        cache.store(b"fresh image", test_result()).unwrap();
        let stale = CacheEntry {
            embedding: vec![0.0; 128],
            result: test_result(),
            last_access: 0,
        };
        cache.entries.write().unwrap().insert("stale".to_string(), stale);
        cache.lru_order.write().unwrap().push("stale".to_string());
        cache.cleanup();
        assert_eq!(cache.stats().entries, 1);
        assert!(!cache.entries.read().unwrap().contains_key("stale"));
        assert!(!cache.lru_order.read().unwrap().iter().any(|k| k == "stale"));
    }
}