// cake_core/models/common/cache.rs

1use std::collections::HashMap;
2
3use candle_core::{DType, Device, Result, Tensor};
4
5use super::Config;
6
/// Abstraction over cosine and sine tables, kv-caching and attention masking.
#[derive(Debug, Clone)]
pub struct Cache {
    // Precomputed RoPE cosine table, built once in `new` from the config's
    // rope_theta / partial_rotary_factor (and optional llama3 scaling),
    // stored in the model dtype. Rows are sliced by position in `cosine`.
    cos: Tensor,
    // Precomputed RoPE sine table, same layout as `cos`.
    sin: Tensor,

    // Lazily-built causal attention masks, keyed by sequence length.
    // Created on `self.device` and copied to the caller's device on demand.
    masks: HashMap<usize, Tensor>,
    // When false, `process_kv` is a pass-through and nothing is cached.
    use_kv_cache: bool,
    // Per-layer (k, v) cache entries, indexed by block index; `None` until
    // the layer has been processed at least once.
    kvs: Vec<Option<(Tensor, Tensor)>>,
    // Upper bound on the cached sequence length; `process_kv` truncates
    // older positions beyond this limit.
    max_seq_len: usize,

    /// Recurrent state matrices for linear attention layers (Gated DeltaNet).
    /// Shape per entry: (batch=1, num_heads, key_dim, value_dim).
    recurrent_states: Vec<Option<Tensor>>,
    /// Conv1d history for linear attention layers.
    /// Shape per entry: (batch=1, channels, kernel_size-1).
    conv_states: Vec<Option<Tensor>>,

    // Device the tables and masks live on; per-call `device` arguments may
    // differ (multi-GPU workers) and trigger a copy.
    device: Device,
}
27
28impl Cache {
29    /// Creates a new cache instance with the provided configuration.
30    /// Set `use_kv_cache` to false to disable kv-caching.
31    pub fn new(use_kv_cache: bool, dtype: DType, config: &Config, device: &Device) -> Result<Self> {
32        // Compute rotary dimension, respecting partial_rotary_factor
33        let head_dim = config
34            .head_dim
35            .unwrap_or(config.hidden_size / config.num_attention_heads);
36        let rotary_dim = (head_dim as f32 * config.partial_rotary_factor) as usize;
37        let max_seq_len = config.max_seq_len;
38
39        log::debug!("cache::head_dim = {head_dim}");
40        log::debug!("cache::rotary_dim = {rotary_dim}");
41        log::debug!("cache::max_seq_len = {max_seq_len}");
42
43        let mut theta: Vec<_> = (0..rotary_dim)
44            .step_by(2)
45            .map(|i| 1f32 / config.rope_theta.powf(i as f32 / rotary_dim as f32))
46            .collect();
47
48        // Apply LLaMA3 RoPE frequency scaling if configured
49        if let Some(ref rope_scaling) = config.rope_scaling {
50            let is_llama3 = rope_scaling
51                .rope_type
52                .as_deref()
53                .map_or(false, |t| t == "llama3");
54
55            if is_llama3 && rope_scaling.original_max_position_embeddings > 0 {
56                let factor = rope_scaling.factor;
57                let low_freq_factor = rope_scaling.low_freq_factor;
58                let high_freq_factor = rope_scaling.high_freq_factor;
59                let old_context_len = rope_scaling.original_max_position_embeddings as f32;
60
61                let low_freq_wavelen = old_context_len / low_freq_factor;
62                let high_freq_wavelen = old_context_len / high_freq_factor;
63
64                for freq in theta.iter_mut() {
65                    let wavelen = 2.0 * std::f32::consts::PI / *freq;
66                    if wavelen < high_freq_wavelen {
67                        // High frequency: keep as-is
68                    } else if wavelen > low_freq_wavelen {
69                        // Low frequency: scale down by factor
70                        *freq /= factor;
71                    } else {
72                        // Medium frequency: smooth interpolation
73                        let smooth = (old_context_len / wavelen - low_freq_factor)
74                            / (high_freq_factor - low_freq_factor);
75                        *freq = (1.0 - smooth) * (*freq / factor) + smooth * *freq;
76                    }
77                }
78
79                log::debug!("cache: applied llama3 rope scaling (factor={factor}, low_freq={low_freq_factor}, high_freq={high_freq_factor})");
80            }
81        }
82
83        let theta = Tensor::new(theta.as_slice(), device)?;
84
85        log::debug!("cache::theta = {}", &theta);
86
87        let idx_theta = Tensor::arange(0, max_seq_len as u32, device)?
88            .to_dtype(DType::F32)?
89            .reshape((max_seq_len, 1))?
90            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
91
92        log::debug!("cache::idx_theta = {}", &idx_theta);
93
94        // This is different from the paper, see:
95        // https://github.com/huggingface/transformers/blob/6112b1c6442aaf7affd2b0676a1cd4eee30c45cf/src/transformers/models/llama/modeling_llama.py#L112
96        let cos = idx_theta.cos()?.to_dtype(dtype)?;
97        let sin = idx_theta.sin()?.to_dtype(dtype)?;
98
99        log::debug!("cache::cos = {}", &cos);
100        log::debug!("cache::sin = {}", &sin);
101
102        let num_layers = config.num_hidden_layers;
103
104        Ok(Self {
105            masks: HashMap::new(),
106            use_kv_cache,
107            kvs: vec![None; num_layers],
108            max_seq_len,
109            recurrent_states: vec![None; num_layers],
110            conv_states: vec![None; num_layers],
111            device: device.clone(),
112            cos,
113            sin,
114        })
115    }
116
117    /// Return true if kv-caching is enabled.
118    pub fn with_kv_cache(&self) -> bool {
119        self.use_kv_cache
120    }
121
122    /// Return the cached cosine value for the given position and sequence length.
123    /// When `device` differs from the cache's own device, the result is copied
124    /// to that device (enables multi-GPU workers).
125    pub fn cosine(&self, index_pos: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
126        self.cos.narrow(0, index_pos, seq_len)?.to_device(device)
127    }
128
129    /// Return the cached sine value for the given position and sequence length.
130    pub fn sine(&self, index_pos: usize, seq_len: usize, device: &Device) -> Result<Tensor> {
131        self.sin.narrow(0, index_pos, seq_len)?.to_device(device)
132    }
133
134    /// Get the attention mask for the given sequence length.
135    pub fn mask(&mut self, seq_len: usize, device: &Device) -> Result<Tensor> {
136        // Always create/cache on self.device, then copy to target if needed
137        if !self.masks.contains_key(&seq_len) {
138            let mask: Vec<_> = (0..seq_len)
139                .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
140                .collect();
141            let mask = Tensor::from_slice(&mask, (seq_len, seq_len), &self.device)?;
142            self.masks.insert(seq_len, mask);
143        }
144        self.masks.get(&seq_len).unwrap().clone().to_device(device)
145    }
146
147    /// Process the input k and v by either generating their cache entry or applying a previously cached one.
148    pub fn process_kv(
149        &mut self,
150        block_idx: usize,
151        mut k: Tensor,
152        mut v: Tensor,
153    ) -> Result<(Tensor, Tensor)> {
154        if self.use_kv_cache {
155            // if this block_idx in cache
156            if let Some((cache_k, cache_v)) = &self.kvs[block_idx] {
157                // update cache entry: concatenate on dim 2 (seq_len)
158                // tensor shape is (batch, num_heads, seq_len, head_dim)
159                k = Tensor::cat(&[cache_k, &k], 2)?.contiguous()?;
160                v = Tensor::cat(&[cache_v, &v], 2)?.contiguous()?;
161
162                // truncate on dim 2 (seq_len) if over limit
163                let k_seq_len = k.dims()[2];
164                if k_seq_len > self.max_seq_len {
165                    k = k
166                        .narrow(2, k_seq_len - self.max_seq_len, self.max_seq_len)?
167                        .contiguous()?;
168                }
169                let v_seq_len = v.dims()[2];
170                if v_seq_len > self.max_seq_len {
171                    v = v
172                        .narrow(2, v_seq_len - self.max_seq_len, self.max_seq_len)?
173                        .contiguous()?;
174                }
175            }
176            // set entry for this block
177            self.kvs[block_idx] = Some((k.clone(), v.clone()))
178        }
179        Ok((k, v))
180    }
181
182    /// Get the recurrent state for a linear attention layer.
183    pub fn get_recurrent_state(&self, block_idx: usize) -> Option<&Tensor> {
184        self.recurrent_states[block_idx].as_ref()
185    }
186
187    /// Set the recurrent state for a linear attention layer.
188    pub fn set_recurrent_state(&mut self, block_idx: usize, state: Tensor) {
189        self.recurrent_states[block_idx] = Some(state);
190    }
191
192    /// Get the conv state for a linear attention layer.
193    pub fn get_conv_state(&self, block_idx: usize) -> Option<&Tensor> {
194        self.conv_states[block_idx].as_ref()
195    }
196
197    /// Set the conv state for a linear attention layer.
198    pub fn set_conv_state(&mut self, block_idx: usize, state: Tensor) {
199        self.conv_states[block_idx] = Some(state);
200    }
201
202    /// Return a copy of this cache with the same state but new kv table.
203    pub fn as_new(&self) -> Self {
204        let mut copy = self.clone();
205        copy.clear();
206        copy
207    }
208
209    /// Clear the cache.
210    pub fn clear(&mut self) {
211        self.masks.clear();
212        self.kvs = vec![None; self.kvs.len()];
213        self.recurrent_states = vec![None; self.recurrent_states.len()];
214        self.conv_states = vec![None; self.conv_states.len()];
215    }
216}