Skip to main content

entrenar/hf_pipeline/dataset/
cache.rs

1//! Teacher cache for distillation
2
3use ndarray::Array2;
4use std::collections::HashMap;
5
6/// Cached teacher outputs for distillation
7#[derive(Debug, Clone)]
8pub struct TeacherCache {
9    /// Cached logits by example index
10    logits: HashMap<usize, Array2<f32>>,
11    /// Cached hidden states by example index
12    hidden_states: HashMap<usize, Vec<Array2<f32>>>,
13    /// Cache hit count
14    hits: usize,
15    /// Cache miss count
16    misses: usize,
17}
18
19impl Default for TeacherCache {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl TeacherCache {
26    /// Create new empty cache
27    #[must_use]
28    pub fn new() -> Self {
29        Self { logits: HashMap::new(), hidden_states: HashMap::new(), hits: 0, misses: 0 }
30    }
31
32    /// Get cached logits
33    pub fn get_logits(&mut self, index: usize) -> Option<&Array2<f32>> {
34        if self.logits.contains_key(&index) {
35            self.hits += 1;
36            self.logits.get(&index)
37        } else {
38            self.misses += 1;
39            None
40        }
41    }
42
43    /// Cache logits
44    pub fn cache_logits(&mut self, index: usize, logits: Array2<f32>) {
45        self.logits.insert(index, logits);
46    }
47
48    /// Get cached hidden states
49    pub fn get_hidden_states(&mut self, index: usize) -> Option<&Vec<Array2<f32>>> {
50        if self.hidden_states.contains_key(&index) {
51            self.hits += 1;
52            self.hidden_states.get(&index)
53        } else {
54            self.misses += 1;
55            None
56        }
57    }
58
59    /// Cache hidden states
60    pub fn cache_hidden_states(&mut self, index: usize, states: Vec<Array2<f32>>) {
61        self.hidden_states.insert(index, states);
62    }
63
64    /// Get cache statistics
65    #[must_use]
66    pub fn stats(&self) -> CacheStats {
67        CacheStats {
68            hits: self.hits,
69            misses: self.misses,
70            logits_cached: self.logits.len(),
71            hidden_states_cached: self.hidden_states.len(),
72        }
73    }
74
75    /// Clear cache
76    pub fn clear(&mut self) {
77        self.logits.clear();
78        self.hidden_states.clear();
79        self.hits = 0;
80        self.misses = 0;
81    }
82}
83
/// Summary of cache usage: lookup counters and stored-entry counts.
#[derive(Debug, Clone, Copy)]
pub struct CacheStats {
    /// Number of cache hits
    pub hits: usize,
    /// Number of cache misses
    pub misses: usize,
    /// Number of cached logits entries
    pub logits_cached: usize,
    /// Number of cached hidden state entries
    pub hidden_states_cached: usize,
}

impl CacheStats {
    /// Fraction of recorded lookups that were hits, in the range `0.0..=1.0`.
    ///
    /// Returns `0.0` when no lookups have been recorded (`hits` and `misses`
    /// are both zero).
    #[must_use]
    pub fn hit_rate(&self) -> f32 {
        // Work in f64 rather than summing as `usize`: the fields are public,
        // so `hits + misses` could overflow (a panic in debug builds, a
        // wrapped total in release builds). f64 also keeps counts above 2^24
        // accurate until the final narrowing to f32.
        let hits = self.hits as f64;
        let total = hits + self.misses as f64;
        if total > 0.0 {
            (hits / total) as f32
        } else {
            0.0
        }
    }
}
108}