cake_core/models/common/cache.rs
use std::collections::HashMap;

use candle_core::{DType, Device, Result, Tensor};

use super::Config;

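/// Per-model runtime state shared by all transformer blocks: precomputed RoPE
/// cos/sin tables, memoized causal masks, and per-layer KV, recurrent, and
/// convolutional states.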
#[derive(Debug, Clone)]
pub struct Cache {
    cos: Tensor,
    sin: Tensor,

    masks: HashMap<usize, Tensor>,
    use_kv_cache: bool,
    kvs: Vec<Option<(Tensor, Tensor)>>,
    max_seq_len: usize,

    // Per-layer hidden state for recurrent (e.g. SSM-style) blocks.
    recurrent_states: Vec<Option<Tensor>>,
    // Per-layer rolling buffer for causal convolution blocks.
    conv_states: Vec<Option<Tensor>>,

    device: Device,
}

impl Cache {
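    /// Builds the cache: derives the rotary dimension from the config,
    /// precomputes the RoPE cos/sin tables for all `max_seq_len` positions
    /// (applying llama3-style rope scaling when configured), and allocates
    /// empty per-layer state slots.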
    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
        let head_dim = config
            .head_dim
            .unwrap_or(config.hidden_size / config.num_attention_heads);
        let rotary_dim = (head_dim as f32 * config.partial_rotary_factor) as usize;
        let max_seq_len = config.max_seq_len;

        log::debug!("cache::head_dim = {head_dim}");
        log::debug!("cache::rotary_dim = {rotary_dim}");
        log::debug!("cache::max_seq_len = {max_seq_len}");

        let mut theta: Vec<_> = (0..rotary_dim)
            .step_by(2)
            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / rotary_dim as f32))
            .collect();
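        // theta[k] = 1 / rope_theta^(2k / rotary_dim) for k = 0..rotary_dim/2:
        // the standard RoPE inverse frequencies, so low k rotates quickly across
        // positions and high k rotates slowly.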

        if let Some(ref rope_scaling) = config.rope_scaling {
            let is_llama3 = rope_scaling
                .rope_type
                .as_deref()
                .map_or(false, |t| t == "llama3");

            if is_llama3 && rope_scaling.original_max_position_embeddings > 0 {
                let factor = rope_scaling.factor;
                let low_freq_factor = rope_scaling.low_freq_factor;
                let high_freq_factor = rope_scaling.high_freq_factor;
                let old_context_len = rope_scaling.original_max_position_embeddings as f32;

                let low_freq_wavelen = old_context_len / low_freq_factor;
                let high_freq_wavelen = old_context_len / high_freq_factor;
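
                // The loop below partitions frequencies by wavelength (2*pi/freq):
                // wavelengths shorter than high_freq_wavelen are kept as-is, those
                // longer than low_freq_wavelen are scaled down by `factor`, and the
                // band in between is linearly interpolated between the two.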
                for freq in theta.iter_mut() {
                    let wavelen = 2.0 * std::f32::consts::PI / *freq;
                    if wavelen < high_freq_wavelen {
                        // High-frequency band: no scaling.
                    } else if wavelen > low_freq_wavelen {
                        // Low-frequency band: scale down by the full factor.
                        *freq /= factor;
                    } else {
                        // Transition band: interpolate between scaled and unscaled.
                        let smooth = (old_context_len / wavelen - low_freq_factor)
                            / (high_freq_factor - low_freq_factor);
                        *freq = (1.0 - smooth) * (*freq / factor) + smooth * *freq;
                    }
                }

                log::debug!("cache: applied llama3 rope scaling (factor={factor}, low_freq={low_freq_factor}, high_freq={high_freq_factor})");
            }
        }

        let theta = Tensor::new(theta.as_slice(), device)?;

        log::debug!("cache::theta = {}", &theta);

        let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((max_seq_len, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
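        // The matmul is an outer product: idx_theta[p][k] = p * theta[k], of
        // shape (max_seq_len, rotary_dim / 2), i.e. the rotation angle at
        // position p for frequency k.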

        log::debug!("cache::idx_theta = {}", &idx_theta);

        let cos = idx_theta.cos()?.to_dtype(dtype)?;
        let sin = idx_theta.sin()?.to_dtype(dtype)?;

        log::debug!("cache::cos = {}", &cos);
        log::debug!("cache::sin = {}", &sin);

        let num_layers = config.num_hidden_layers;

        Ok(Self {
            masks: HashMap::new(),
            use_kv_cache,
            kvs: vec![None; num_layers],
            max_seq_len,
            recurrent_states: vec![None; num_layers],
            conv_states: vec![None; num_layers],
            device: device.clone(),
            cos,
            sin,
        })
    }

    /// Whether the KV cache is enabled.
    pub fn with_kv_cache(&self) -> bool {
        self.use_kv_cache
    }

    /// Slice of the cos table for positions `index_pos .. index_pos + seq_len`.
    pub fn cosine(&self, index_pos: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
        self.cos.narrow(0, index_pos, seq_len)?.to_device(device)
    }

    /// Slice of the sin table for positions `index_pos .. index_pos + seq_len`.
    pub fn sine(&self, index_pos: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
        self.sin.narrow(0, index_pos, seq_len)?.to_device(device)
    }

    /// Returns the `seq_len x seq_len` causal mask, building and memoizing it
    /// on first use.
    pub fn mask(&mut self, seq_len: usize, device: &Device) -> Result<Tensor> {
        if !self.masks.contains_key(&seq_len) {
            let mask: Vec<_> = (0..seq_len)
                .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
                .collect();
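            // For seq_len = 3 this produces the upper-triangular pattern
            //   0 1 1
            //   0 0 1
            //   0 0 0
            // where a 1 marks a future position (j > i) to be masked out.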
            let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &self.device)?;
            self.masks.insert(seq_len, mask);
        }
        self.masks.get(&seq_len).unwrap().clone().to_device(device)
    }

    /// Appends `k` and `v` for the given block to the KV cache (when enabled)
    /// and returns the full key/value tensors to attend over.
    pub fn process_kv(
        &mut self,
        block_idx: usize,
        mut k: Tensor,
        mut v: Tensor,
    ) -> Result<(Tensor, Tensor)> {
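        // Assumes the usual (batch, num_heads, seq_len, head_dim) layout, so
        // dim 2 below is the sequence axis; nothing in this file enforces that,
        // it is a convention the callers are expected to follow.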
        if self.use_kv_cache {
            if let Some((cache_k, cache_v)) = &self.kvs[block_idx] {
                // Append the new keys/values after the cached ones.
                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;

                // Keep only the most recent max_seq_len positions.
                let k_seq_len = k.dims()[2];
                if k_seq_len > self.max_seq_len {
                    k = k
                        .narrow(2, k_seq_len - self.max_seq_len, self.max_seq_len)?
                        .contiguous()?;
                }
                let v_seq_len = v.dims()[2];
                if v_seq_len > self.max_seq_len {
                    v = v
                        .narrow(2, v_seq_len - self.max_seq_len, self.max_seq_len)?
                        .contiguous()?;
                }
            }
            self.kvs[block_idx] = Some((k.clone(), v.clone()))
        }
        Ok((k, v))
    }

    pub fn get_recurrent_state(&self, block_idx: usize) -> Option<&Tensor> {
        self.recurrent_states[block_idx].as_ref()
    }

    pub fn set_recurrent_state(&mut self, block_idx: usize, state: Tensor) {
        self.recurrent_states[block_idx] = Some(state);
    }

    pub fn get_conv_state(&self, block_idx: usize) -> Option<&Tensor> {
        self.conv_states[block_idx].as_ref()
    }

    pub fn set_conv_state(&mut self, block_idx: usize, state: Tensor) {
        self.conv_states[block_idx] = Some(state);
    }

    /// Returns a fresh cache with the same precomputed cos/sin tables but no
    /// per-sequence state, e.g. for starting a new generation.
    pub fn as_new(&self) -> Self {
        let mut copy = self.clone();
        copy.clear();
        copy
    }

    /// Drops all per-sequence state (masks, KV, recurrent and conv states)
    /// while keeping the rotary tables.
    pub fn clear(&mut self) {
        self.masks.clear();
        self.kvs = vec![None; self.kvs.len()];
        self.recurrent_states = vec![None; self.recurrent_states.len()];
        self.conv_states = vec![None; self.conv_states.len()];
    }
}
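
// A minimal usage sketch (not from this file) of how a decode step might drive
// the cache. `config`, `block_idx`, `index_pos`, and the k/v tensors are assumed
// to come from the surrounding model code:
//
//   let mut cache = Cache::new(true, DType::F32, &config, &device)?;
//   let mask = cache.mask(seq_len, &device)?;             // causal mask for this step
//   let cos = cache.cosine(index_pos, seq_len, &device)?; // RoPE angles for positions
//   let sin = cache.sine(index_pos, seq_len, &device)?;   //   index_pos .. index_pos + seq_len
//   let (k, v) = cache.process_kv(block_idx, k, v)?;      // read-modify-write the KV cache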