use crate::error::Result;
use crate::factor::Factor;
use std::collections::HashMap;
use std::hash::Hash;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
/// Identifies one memoized factor operation by its kind and operand names.
///
/// Operands are referenced by factor *name* only, not by factor contents.
/// Tuple order is part of the key, so `(a, b)` and `(b, a)` are distinct.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
enum CacheKey {
/// Product of two factors: `(first_name, second_name)`.
Product(String, String),
/// Marginalizing a variable out of a factor: `(factor_name, var)`.
Marginalize(String, String),
/// Dividing the first-named factor by the second: `(first_name, second_name)`.
Divide(String, String),
/// Reduction of a factor: `(factor_name, var, value)`, as passed to `Factor::reduce`.
Reduce(String, String, usize),
}
pub struct FactorCache {
cache: Arc<Mutex<HashMap<CacheKey, Factor>>>,
max_size: usize,
hits: Arc<Mutex<usize>>,
misses: Arc<Mutex<usize>>,
}
impl Default for FactorCache {
fn default() -> Self {
Self::new(1000)
}
}
impl FactorCache {
pub fn new(max_size: usize) -> Self {
Self {
cache: Arc::new(Mutex::new(HashMap::new())),
max_size,
hits: Arc::new(Mutex::new(0)),
misses: Arc::new(Mutex::new(0)),
}
}
pub fn get_product(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
self.get(&key)
}
pub fn put_product(&self, f1_name: &str, f2_name: &str, result: Factor) {
let key = CacheKey::Product(f1_name.to_string(), f2_name.to_string());
self.put(key, result);
}
pub fn get_marginalize(&self, factor_name: &str, var: &str) -> Option<Factor> {
let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
self.get(&key)
}
pub fn put_marginalize(&self, factor_name: &str, var: &str, result: Factor) {
let key = CacheKey::Marginalize(factor_name.to_string(), var.to_string());
self.put(key, result);
}
pub fn get_divide(&self, f1_name: &str, f2_name: &str) -> Option<Factor> {
let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
self.get(&key)
}
pub fn put_divide(&self, f1_name: &str, f2_name: &str, result: Factor) {
let key = CacheKey::Divide(f1_name.to_string(), f2_name.to_string());
self.put(key, result);
}
pub fn get_reduce(&self, factor_name: &str, var: &str, value: usize) -> Option<Factor> {
let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
self.get(&key)
}
pub fn put_reduce(&self, factor_name: &str, var: &str, value: usize, result: Factor) {
let key = CacheKey::Reduce(factor_name.to_string(), var.to_string(), value);
self.put(key, result);
}
fn get(&self, key: &CacheKey) -> Option<Factor> {
let cache = self.cache.lock().expect("lock should not be poisoned");
if let Some(factor) = cache.get(key) {
*self.hits.lock().expect("lock should not be poisoned") += 1;
Some(factor.clone())
} else {
*self.misses.lock().expect("lock should not be poisoned") += 1;
None
}
}
fn put(&self, key: CacheKey, factor: Factor) {
let mut cache = self.cache.lock().expect("lock should not be poisoned");
if cache.len() >= self.max_size {
if let Some(first_key) = cache.keys().next().cloned() {
cache.remove(&first_key);
}
}
cache.insert(key, factor);
}
pub fn clear(&self) {
self.cache
.lock()
.expect("lock should not be poisoned")
.clear();
*self.hits.lock().expect("lock should not be poisoned") = 0;
*self.misses.lock().expect("lock should not be poisoned") = 0;
}
pub fn stats(&self) -> CacheStats {
let hits = *self.hits.lock().expect("lock should not be poisoned");
let misses = *self.misses.lock().expect("lock should not be poisoned");
let size = self
.cache
.lock()
.expect("lock should not be poisoned")
.len();
CacheStats {
hits,
misses,
size,
hit_rate: if hits + misses > 0 {
hits as f64 / (hits + misses) as f64
} else {
0.0
},
}
}
pub fn size(&self) -> usize {
self.cache
.lock()
.expect("lock should not be poisoned")
.len()
}
}
/// Snapshot of a factor cache's usage counters.
///
/// `Default` is the state of a fresh cache: no lookups, no entries,
/// and a hit rate of `0.0`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct CacheStats {
    /// Number of lookups that found a cached result.
    pub hits: usize,
    /// Number of lookups that found nothing.
    pub misses: usize,
    /// Number of entries stored at the time of the snapshot.
    pub size: usize,
    /// `hits / (hits + misses)`, or `0.0` when there were no lookups.
    pub hit_rate: f64,
}
/// A `Factor` paired with a `FactorCache` through which its `*_cached`
/// operations are memoized.
///
/// The cache keys on `factor.name`, so factors that share a cache should
/// have distinct names.
pub struct CachedFactor {
/// The wrapped factor; its `name` field is what cache keys are built from.
pub factor: Factor,
/// Cache handle, possibly shared (via `Arc`) with other `CachedFactor`s.
cache: Arc<FactorCache>,
}
impl CachedFactor {
    /// Pairs `factor` with `cache`, through which its `*_cached` operations
    /// will be memoized.
    pub fn new(factor: Factor, cache: Arc<FactorCache>) -> Self {
        Self { factor, cache }
    }

    /// Returns the product of `self` and `other`, consulting the cache first.
    ///
    /// The cache key is `(self.factor.name, other.factor.name)` in that order.
    ///
    /// # Errors
    /// Propagates any error from `Factor::product`; nothing is cached then.
    pub fn product_cached(&self, other: &CachedFactor) -> Result<Factor> {
        let lhs = &self.factor.name;
        let rhs = &other.factor.name;
        let computed = match self.cache.get_product(lhs, rhs) {
            Some(hit) => return Ok(hit),
            None => self.factor.product(&other.factor)?,
        };
        self.cache.put_product(lhs, rhs, computed.clone());
        Ok(computed)
    }

    /// Returns this factor with `var` marginalized out, consulting the cache first.
    ///
    /// # Errors
    /// Propagates any error from `Factor::marginalize_out`; nothing is cached then.
    pub fn marginalize_out_cached(&self, var: &str) -> Result<Factor> {
        let own = &self.factor.name;
        let computed = match self.cache.get_marginalize(own, var) {
            Some(hit) => return Ok(hit),
            None => self.factor.marginalize_out(var)?,
        };
        self.cache.put_marginalize(own, var, computed.clone());
        Ok(computed)
    }

    /// Returns `self / other`, consulting the cache first.
    ///
    /// The cache key is `(self.factor.name, other.factor.name)` in that order.
    ///
    /// # Errors
    /// Propagates any error from `Factor::divide`; nothing is cached then.
    pub fn divide_cached(&self, other: &CachedFactor) -> Result<Factor> {
        let numerator = &self.factor.name;
        let denominator = &other.factor.name;
        let computed = match self.cache.get_divide(numerator, denominator) {
            Some(hit) => return Ok(hit),
            None => self.factor.divide(&other.factor)?,
        };
        self.cache
            .put_divide(numerator, denominator, computed.clone());
        Ok(computed)
    }

    /// Returns this factor reduced on `var` = `value`, consulting the cache first.
    ///
    /// # Errors
    /// Propagates any error from `Factor::reduce`; nothing is cached then.
    pub fn reduce_cached(&self, var: &str, value: usize) -> Result<Factor> {
        let own = &self.factor.name;
        let computed = match self.cache.get_reduce(own, var, value) {
            Some(hit) => return Ok(hit),
            None => self.factor.reduce(var, value)?,
        };
        self.cache.put_reduce(own, var, value, computed.clone());
        Ok(computed)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array;

    /// Builds a 2x2 factor over `X` and `Y` with fixed values, named `name`.
    fn create_test_factor(name: &str) -> Factor {
        let values = vec![0.1, 0.2, 0.3, 0.4];
        let array = Array::from_shape_vec(vec![2, 2], values)
            .expect("unwrap")
            .into_dyn();
        Factor::new(
            name.to_string(),
            vec!["X".to_string(), "Y".to_string()],
            array,
        )
        .expect("unwrap")
    }

    #[test]
    fn test_cache_product() {
        let cache = Arc::new(FactorCache::new(100));
        let f1 = CachedFactor::new(create_test_factor("f1"), cache.clone());
        let f2 = CachedFactor::new(create_test_factor("f2"), cache.clone());
        let result1 = f1.product_cached(&f2).expect("unwrap");
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);
        assert_eq!(stats1.size, 1);
        let result2 = f1.product_cached(&f2).expect("unwrap");
        let stats2 = cache.stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);
        assert_eq!(stats2.size, 1);
        assert_eq!(result1.name, result2.name);
    }

    #[test]
    fn test_cache_marginalize() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());
        let _result1 = f.marginalize_out_cached("Y").expect("unwrap");
        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);
        let _result2 = f.marginalize_out_cached("Y").expect("unwrap");
        let stats2 = cache.stats();
        assert_eq!(stats2.hits, 1);
        // A hit must not also be counted as a miss.
        assert_eq!(stats2.misses, 1);
    }

    #[test]
    fn test_cache_stats() {
        let cache = FactorCache::new(100);
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.size, 0);
        assert_eq!(stats.hit_rate, 0.0);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = FactorCache::new(10);
        cache.put_marginalize("f", "X", create_test_factor("m"));
        assert!(cache.get_marginalize("f", "X").is_some());
        assert!(cache.get_marginalize("f", "Y").is_none());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.size, 1);
        assert!((stats.hit_rate - 0.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_cache_keys_distinguish_operation_and_order() {
        let cache = FactorCache::new(10);
        cache.put_product("a", "b", create_test_factor("ab"));
        // Swapped operands and a different operation are different keys.
        assert!(cache.get_product("b", "a").is_none());
        assert!(cache.get_divide("a", "b").is_none());
        assert!(cache.get_product("a", "b").is_some());
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 2);
    }

    #[test]
    fn test_cache_clear() {
        let cache = Arc::new(FactorCache::new(100));
        let f = CachedFactor::new(create_test_factor("f"), cache.clone());
        let _ = f.marginalize_out_cached("Y").expect("unwrap");
        assert_eq!(cache.size(), 1);
        cache.clear();
        assert_eq!(cache.size(), 0);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    #[test]
    fn test_cache_eviction() {
        let cache = Arc::new(FactorCache::new(2));
        cache.put_marginalize("f1", "X", create_test_factor("result1"));
        cache.put_marginalize("f2", "Y", create_test_factor("result2"));
        cache.put_marginalize("f3", "Z", create_test_factor("result3"));
        // Three distinct keys into a capacity-2 cache: exactly one is evicted,
        // and the newest entry is always kept.
        assert_eq!(cache.size(), 2);
        assert!(cache.get_marginalize("f3", "Z").is_some());
    }
}