entrenar/hf_pipeline/dataset/
cache.rs1use ndarray::Array2;
4use std::collections::HashMap;
5
6#[derive(Debug, Clone)]
8pub struct TeacherCache {
9 logits: HashMap<usize, Array2<f32>>,
11 hidden_states: HashMap<usize, Vec<Array2<f32>>>,
13 hits: usize,
15 misses: usize,
17}
18
19impl Default for TeacherCache {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl TeacherCache {
26 #[must_use]
28 pub fn new() -> Self {
29 Self { logits: HashMap::new(), hidden_states: HashMap::new(), hits: 0, misses: 0 }
30 }
31
32 pub fn get_logits(&mut self, index: usize) -> Option<&Array2<f32>> {
34 if self.logits.contains_key(&index) {
35 self.hits += 1;
36 self.logits.get(&index)
37 } else {
38 self.misses += 1;
39 None
40 }
41 }
42
43 pub fn cache_logits(&mut self, index: usize, logits: Array2<f32>) {
45 self.logits.insert(index, logits);
46 }
47
48 pub fn get_hidden_states(&mut self, index: usize) -> Option<&Vec<Array2<f32>>> {
50 if self.hidden_states.contains_key(&index) {
51 self.hits += 1;
52 self.hidden_states.get(&index)
53 } else {
54 self.misses += 1;
55 None
56 }
57 }
58
59 pub fn cache_hidden_states(&mut self, index: usize, states: Vec<Array2<f32>>) {
61 self.hidden_states.insert(index, states);
62 }
63
64 #[must_use]
66 pub fn stats(&self) -> CacheStats {
67 CacheStats {
68 hits: self.hits,
69 misses: self.misses,
70 logits_cached: self.logits.len(),
71 hidden_states_cached: self.hidden_states.len(),
72 }
73 }
74
75 pub fn clear(&mut self) {
77 self.logits.clear();
78 self.hidden_states.clear();
79 self.hits = 0;
80 self.misses = 0;
81 }
82}
83
/// Point-in-time snapshot of a `TeacherCache`'s counters and sizes.
#[derive(Clone, Copy, Debug)]
pub struct CacheStats {
    /// Lookups (logits or hidden states) that returned a cached entry.
    pub hits: usize,
    /// Lookups (logits or hidden states) that returned nothing.
    pub misses: usize,
    /// Number of sample indices with cached logits.
    pub logits_cached: usize,
    /// Number of sample indices with cached hidden states.
    pub hidden_states_cached: usize,
}
96
97impl CacheStats {
98 #[must_use]
100 pub fn hit_rate(&self) -> f32 {
101 let total = self.hits + self.misses;
102 if total > 0 {
103 self.hits as f32 / total as f32
104 } else {
105 0.0
106 }
107 }
108}