use ndarray::Array2;
use std::collections::HashMap;
/// Cache of logits matrices and hidden-state matrix lists, keyed by a `usize` index.
///
/// The two maps are independent, but a single pair of hit/miss counters is
/// updated by both lookup methods (`get_logits` and `get_hidden_states`).
// NOTE(review): the index presumably identifies a training sample or batch — confirm with callers.
#[derive(Clone, Debug)]
pub struct TeacherCache {
    /// Logits matrix cached per index.
    logits: HashMap<usize, Array2<f32>>,
    /// Hidden-state matrices cached per index.
    hidden_states: HashMap<usize, Vec<Array2<f32>>>,
    /// Lookups that found an entry.
    hits: usize,
    /// Lookups that found nothing.
    misses: usize,
}
impl Default for TeacherCache {
fn default() -> Self {
Self::new()
}
}
impl TeacherCache {
    /// Creates an empty cache with both hit/miss counters at zero.
    #[must_use]
    pub fn new() -> Self {
        Self { logits: HashMap::new(), hidden_states: HashMap::new(), hits: 0, misses: 0 }
    }

    /// Looks up the cached logits for `index`.
    ///
    /// Every call is counted. A present entry increments the hit counter and
    /// an absent one increments the miss counter. These counters are shared
    /// with [`get_hidden_states`](Self::get_hidden_states).
    pub fn get_logits(&mut self, index: usize) -> Option<&Array2<f32>> {
        // One `get` both fetches the entry and decides hit vs. miss; no extra
        // `contains_key` probe is needed. Borrowing `self.logits` while
        // bumping `self.hits`/`self.misses` is fine: they are disjoint fields.
        match self.logits.get(&index) {
            Some(logits) => {
                self.hits += 1;
                Some(logits)
            }
            None => {
                self.misses += 1;
                None
            }
        }
    }

    /// Stores `logits` under `index`, replacing any entry already cached there.
    ///
    /// Does not affect the hit/miss counters.
    pub fn cache_logits(&mut self, index: usize, logits: Array2<f32>) {
        self.logits.insert(index, logits);
    }

    /// Looks up the cached hidden-state matrices for `index`.
    ///
    /// Every call is counted. A present entry increments the hit counter and
    /// an absent one increments the miss counter. These counters are shared
    /// with [`get_logits`](Self::get_logits).
    pub fn get_hidden_states(&mut self, index: usize) -> Option<&Vec<Array2<f32>>> {
        // Single lookup, as in `get_logits`; the field borrows are disjoint.
        match self.hidden_states.get(&index) {
            Some(states) => {
                self.hits += 1;
                Some(states)
            }
            None => {
                self.misses += 1;
                None
            }
        }
    }

    /// Stores `states` under `index`, replacing any entry already cached there.
    ///
    /// Does not affect the hit/miss counters.
    pub fn cache_hidden_states(&mut self, index: usize, states: Vec<Array2<f32>>) {
        self.hidden_states.insert(index, states);
    }

    /// Returns a snapshot of the shared hit/miss counters and the number of
    /// entries currently held in each map.
    #[must_use]
    pub fn stats(&self) -> CacheStats {
        CacheStats {
            hits: self.hits,
            misses: self.misses,
            logits_cached: self.logits.len(),
            hidden_states_cached: self.hidden_states.len(),
        }
    }

    /// Removes every cached entry and resets both counters to zero.
    ///
    /// `HashMap::clear` keeps the maps' allocated capacity for reuse.
    pub fn clear(&mut self) {
        self.logits.clear();
        self.hidden_states.clear();
        self.hits = 0;
        self.misses = 0;
    }
}
/// Point-in-time snapshot of a `TeacherCache`'s counters and sizes, as
/// returned by `TeacherCache::stats`.
///
/// `PartialEq`/`Eq` let snapshots be compared directly (e.g. in `assert_eq!`).
/// `Default` yields an all-zero snapshot.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Lookups (logits or hidden states) that found an entry.
    pub hits: usize,
    /// Lookups (logits or hidden states) that found nothing.
    pub misses: usize,
    /// Number of indices with cached logits.
    pub logits_cached: usize,
    /// Number of indices with cached hidden states.
    pub hidden_states_cached: usize,
}
impl CacheStats {
    /// Fraction of recorded lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded, rather than dividing by zero.
    #[must_use]
    pub fn hit_rate(&self) -> f32 {
        match self.hits + self.misses {
            0 => 0.0,
            lookups => self.hits as f32 / lookups as f32,
        }
    }
}