use candle_core::Tensor;
2
3#[derive(Debug, Clone)]
5pub struct KvCache {
6 cache: candle_nn::kv_cache::KvCache,
7 concat_dim: usize,
8 max_seq_len: usize,
9}
10
11impl KvCache {
12 pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
14 Self {
15 cache: candle_nn::kv_cache::KvCache::new(concat_dim, 8),
16 concat_dim,
17 max_seq_len,
18 }
19 }
20
21 pub fn cache(&self) -> &candle_nn::kv_cache::KvCache {
23 &self.cache
24 }
25
26 pub fn cache_mut(&mut self) -> &mut candle_nn::kv_cache::KvCache {
28 &mut self.cache
29 }
30
31 pub fn reset(&mut self) {
33 self.cache.reset()
34 }
35
36 pub fn append(&mut self, k: &Tensor, v: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
38 let k = k.contiguous()?;
39 let v = v.contiguous()?;
40 let seq_len = k.dim(self.concat_dim)?;
41 debug_assert_eq!(seq_len, v.dim(self.concat_dim)?);
43
44 let current_allocated_size = self.cache.k_cache().max_seq_len();
45 let size_required_for_append = self.cache.current_seq_len() + seq_len;
46
47 if size_required_for_append > current_allocated_size {
49 let next_power_of_two = size_required_for_append.next_power_of_two();
52 let new_cache_max_seq_len = next_power_of_two.min(self.max_seq_len);
53
54 let mut new_cache =
56 candle_nn::kv_cache::KvCache::new(self.concat_dim, new_cache_max_seq_len);
57 if let (Ok(Some(k)), Ok(Some(v))) = (self.cache.k(), self.cache.v()) {
59 new_cache.k_cache_mut().append(&k.contiguous()?)?;
60 new_cache.v_cache_mut().append(&v.contiguous()?)?;
61 }
62 self.cache = new_cache;
64 }
65
66 self.cache.append(&k, &v)
67 }
68}
69
70#[derive(Debug, Clone)]
72pub struct TensorCache {
73 cache: candle_nn::kv_cache::Cache,
74 concat_dim: usize,
75 max_seq_len: usize,
76}
77
78impl TensorCache {
79 pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
81 Self {
82 cache: candle_nn::kv_cache::Cache::new(concat_dim, 8),
83 concat_dim,
84 max_seq_len,
85 }
86 }
87
88 pub fn cache(&self) -> &candle_nn::kv_cache::Cache {
90 &self.cache
91 }
92
93 pub fn cache_mut(&mut self) -> &mut candle_nn::kv_cache::Cache {
95 &mut self.cache
96 }
97
98 pub fn all_data(&self) -> &Option<Tensor> {
100 self.cache.all_data()
101 }
102
103 pub fn reset(&mut self) {
105 self.cache.reset()
106 }
107
108 pub fn append(&mut self, v: &Tensor) -> candle_core::Result<()> {
110 let v = v.contiguous()?;
111 let seq_len = v.dim(self.concat_dim)?;
112 debug_assert_eq!(seq_len, v.dim(self.concat_dim)?);
114
115 let current_allocated_size = self.cache.max_seq_len();
116 let size_required_for_append = self.cache.current_seq_len() + seq_len;
117
118 if size_required_for_append > current_allocated_size {
120 let next_power_of_two = size_required_for_append.next_power_of_two();
123 let new_cache_max_seq_len = next_power_of_two.min(self.max_seq_len);
124
125 let mut new_cache =
127 candle_nn::kv_cache::Cache::new(self.concat_dim, new_cache_max_seq_len);
128 if let Some(v) = self.cache.all_data() {
130 new_cache.append(&v.contiguous()?)?;
131 }
132 self.cache = new_cache;
134 }
135
136 self.cache.append(&v)
138 }
139}