use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};
use crate::error::Result;
use crate::types::Kernel;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct CacheKey {
x_hash: u64,
y_hash: u64,
}
impl CacheKey {
fn from_inputs(x: &[f64], y: &[f64]) -> Self {
Self {
x_hash: Self::hash_vector(x),
y_hash: Self::hash_vector(y),
}
}
fn hash_vector(v: &[f64]) -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
for &val in v {
val.to_bits().hash(&mut hasher);
}
hasher.finish()
}
}
pub struct CachedKernel {
inner: Box<dyn Kernel>,
cache: Arc<Mutex<HashMap<CacheKey, f64>>>,
stats: Arc<Mutex<CacheStats>>,
}
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
pub hits: usize,
pub misses: usize,
pub size: usize,
}
impl CacheStats {
pub fn hit_rate(&self) -> f64 {
let total = self.hits + self.misses;
if total == 0 {
0.0
} else {
self.hits as f64 / total as f64
}
}
}
impl CachedKernel {
pub fn new(inner: Box<dyn Kernel>) -> Self {
Self {
inner,
cache: Arc::new(Mutex::new(HashMap::new())),
stats: Arc::new(Mutex::new(CacheStats::default())),
}
}
pub fn stats(&self) -> CacheStats {
self.stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
pub fn clear(&mut self) {
self.cache
.lock()
.expect("lock should not be poisoned")
.clear();
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.hits = 0;
stats.misses = 0;
stats.size = 0;
}
pub fn cache_size(&self) -> usize {
self.cache
.lock()
.expect("lock should not be poisoned")
.len()
}
}
impl Kernel for CachedKernel {
fn compute(&self, x: &[f64], y: &[f64]) -> Result<f64> {
let key = CacheKey::from_inputs(x, y);
{
let cache = self.cache.lock().expect("lock should not be poisoned");
if let Some(&value) = cache.get(&key) {
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.hits += 1;
return Ok(value);
}
}
let value = self.inner.compute(x, y)?;
{
let mut cache = self.cache.lock().expect("lock should not be poisoned");
cache.insert(key, value);
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.misses += 1;
stats.size = cache.len();
}
Ok(value)
}
fn name(&self) -> &str {
self.inner.name()
}
fn is_psd(&self) -> bool {
self.inner.is_psd()
}
}
pub struct KernelMatrixCache {
cache: HashMap<u64, Vec<Vec<f64>>>,
}
impl KernelMatrixCache {
pub fn new() -> Self {
Self {
cache: HashMap::new(),
}
}
fn hash_data(data: &[Vec<f64>]) -> u64 {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
for row in data {
for &val in row {
val.to_bits().hash(&mut hasher);
}
}
hasher.finish()
}
pub fn get_or_compute(
&mut self,
data: &[Vec<f64>],
kernel: &dyn Kernel,
) -> Result<Vec<Vec<f64>>> {
let key = Self::hash_data(data);
if let Some(matrix) = self.cache.get(&key) {
return Ok(matrix.clone());
}
let matrix = kernel.compute_matrix(data)?;
self.cache.insert(key, matrix.clone());
Ok(matrix)
}
pub fn clear(&mut self) {
self.cache.clear();
}
pub fn size(&self) -> usize {
self.cache.len()
}
}
impl Default for KernelMatrixCache {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor_kernels::LinearKernel;
#[test]
fn test_cached_kernel() {
let base = LinearKernel::new();
let cached = CachedKernel::new(Box::new(base));
let x = vec![1.0, 2.0, 3.0];
let y = vec![4.0, 5.0, 6.0];
let result1 = cached.compute(&x, &y).expect("unwrap");
let stats1 = cached.stats();
assert_eq!(stats1.misses, 1);
assert_eq!(stats1.hits, 0);
let result2 = cached.compute(&x, &y).expect("unwrap");
let stats2 = cached.stats();
assert_eq!(stats2.misses, 1);
assert_eq!(stats2.hits, 1);
assert_eq!(result1, result2);
}
#[test]
fn test_cached_kernel_clear() {
let base = LinearKernel::new();
let mut cached = CachedKernel::new(Box::new(base));
let x = vec![1.0, 2.0, 3.0];
let y = vec![4.0, 5.0, 6.0];
cached.compute(&x, &y).expect("unwrap");
assert_eq!(cached.cache_size(), 1);
cached.clear();
assert_eq!(cached.cache_size(), 0);
let stats = cached.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
}
#[test]
fn test_cache_stats_hit_rate() {
let stats = CacheStats {
hits: 7,
misses: 3,
size: 10,
};
assert!((stats.hit_rate() - 0.7).abs() < 1e-10);
}
#[test]
fn test_cache_stats_empty() {
let stats = CacheStats::default();
assert_eq!(stats.hit_rate(), 0.0);
}
#[test]
fn test_kernel_matrix_cache() {
let kernel = LinearKernel::new();
let mut cache = KernelMatrixCache::new();
let data = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
let matrix1 = cache.get_or_compute(&data, &kernel).expect("unwrap");
assert_eq!(cache.size(), 1);
let matrix2 = cache.get_or_compute(&data, &kernel).expect("unwrap");
assert_eq!(cache.size(), 1);
assert_eq!(matrix1.len(), matrix2.len());
for i in 0..matrix1.len() {
for j in 0..matrix1[i].len() {
assert_eq!(matrix1[i][j], matrix2[i][j]);
}
}
}
#[test]
fn test_kernel_matrix_cache_clear() {
let kernel = LinearKernel::new();
let mut cache = KernelMatrixCache::new();
let data = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
cache.get_or_compute(&data, &kernel).expect("unwrap");
assert_eq!(cache.size(), 1);
cache.clear();
assert_eq!(cache.size(), 0);
}
#[test]
fn test_cached_kernel_name() {
let base = LinearKernel::new();
let cached = CachedKernel::new(Box::new(base));
assert_eq!(cached.name(), "Linear");
}
#[test]
fn test_cached_kernel_psd() {
let base = LinearKernel::new();
let cached = CachedKernel::new(Box::new(base));
assert!(cached.is_psd());
}
}