moshi_db/kv_cache.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::RotatingKvCache;

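/// The scatter indices and the matching attention mask for one batched step on
/// a fixed-size KV cache. Produced by `ScatteredCacheBuilder::indices_and_mask`
/// and consumed by `ScatteredKvCache::append`.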
#[derive(Debug, Clone)]
pub struct IndicesAndMask {
    indices: Tensor,
    mask: Tensor,
}

impl IndicesAndMask {
    pub fn mask(&self) -> &Tensor {
        &self.mask
    }
}

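/// A key/value cache with a fixed number of `context` slots. New entries are
/// written at arbitrary per-batch slots via `scatter_set`, which lets each
/// batch element rotate through the cache independently.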
#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
    k: Tensor,
    v: Tensor,
    context: usize,
}

impl ScatteredKvCache {
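    /// Scatters the incoming keys and values into the cache at the per-batch
    /// slots given by `iam`, then returns the full cached tensors. When the
    /// incoming sequence already covers the whole context, the cache is
    /// bypassed and the inputs are returned as-is.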
    pub fn append(
        &mut self,
        k: &Tensor,
        v: &Tensor,
        iam: &IndicesAndMask,
    ) -> Result<(Tensor, Tensor)> {
        if self.context <= k.dim(2)? {
            return Ok((k.clone(), v.clone()));
        }
        // Expand the (b, t) scatter indices to the (b, h, t, d) shape that
        // scatter_set expects on the sequence dimension.
        let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
        let indices = indices.broadcast_as(k.shape())?.contiguous()?;
        self.k.scatter_set(&indices, k, 2)?;
        self.v.scatter_set(&indices, v, 2)?;
        Ok((self.k.clone(), self.v.clone()))
    }

    pub fn k(&self) -> &Tensor {
        &self.k
    }

    pub fn v(&self) -> &Tensor {
        &self.v
    }
}

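/// Builds [`ScatteredKvCache`]s and tracks, for each batch element, both the
/// absolute position in the stream and the slot where the next entry will be
/// written inside the fixed-size context window.
///
/// A minimal usage sketch (illustrative only: the head count, head dim, and
/// the `k`/`v` tensors are assumptions, not values from this crate):
///
/// ```ignore
/// let device = Device::Cpu;
/// // Batch of 2 streams sharing a context window of 5 slots.
/// let mut builder = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
/// let mut cache = builder.make_cache(4, 64)?; // num_heads = 4, head_dim = 64
/// // One decoding step for both streams; `k` and `v` are shaped
/// // (batch, num_heads, seq_len, head_dim).
/// let iam = builder.indices_and_mask(1, &[true, true])?;
/// let (all_k, all_v) = cache.append(&k, &v, &iam)?;
/// // `iam.mask()` is then added to the attention logits.
/// ```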
#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
    context: usize,
    // The current position in the stream; this can be larger than `context`.
    positions: Vec<usize>,
    // The cache slot where the next element will be stored.
    indices: Vec<usize>,
    dtype: DType,
    device: Device,
}

impl ScatteredCacheBuilder {
    pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
        let positions = vec![0; batch_size];
        let indices = vec![0; batch_size];
        Ok(Self { positions, indices, context, dtype, device: device.clone() })
    }

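    /// Allocates a zero-initialized cache of shape
    /// `(batch_size, num_heads, context, head_dim)` for both keys and values.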
    pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
        let batch_size = self.batch_size();
        let shape = (batch_size, num_heads, self.context, head_dim);
        let k = Tensor::zeros(shape, self.dtype, self.device())?;
        let v = Tensor::zeros(shape, self.dtype, self.device())?;
        Ok(ScatteredKvCache { k, v, context: self.context })
    }

    pub fn positions(&self) -> &[usize] {
        &self.positions
    }

    pub fn reset(&mut self) {
        self.positions.fill(0);
        self.indices.fill(0);
    }

    pub fn batch_size(&self) -> usize {
        self.positions.len()
    }

    pub fn reset_batch_index(&mut self, batch_index: usize) {
        self.positions[batch_index] = 0;
        self.indices[batch_index] = 0;
    }

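    /// Computes the scatter indices and the `(b, 1, seq_len, context)`
    /// attention mask for the next `seq_len` steps, advancing the per-batch
    /// positions as a side effect. Batch elements with `batch_mask[i] ==
    /// false` do not advance and get an all-zero mask. When `seq_len` reaches
    /// the context size this falls back to the absolute-position variant.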
    #[allow(clippy::needless_range_loop)]
    pub fn indices_and_mask(
        &mut self,
        seq_len: usize,
        batch_mask: &[bool],
    ) -> Result<IndicesAndMask> {
        // The mask shape is (b, h, t, k).
        let context = self.context;
        if self.context <= seq_len {
            return self.indices_and_mask_abs(seq_len, batch_mask);
        }
        let mut attention_masks = Vec::with_capacity(self.batch_size());
        let mut cache_indices = Vec::with_capacity(self.batch_size());
        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
            if !batch_mask {
                let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
                let indices = vec![self.indices[batch_i] as u32; seq_len];
                attention_masks.push(masks);
                cache_indices.push(indices);
            } else {
                let start_index = self.indices[batch_i];
                let start_pos = self.positions[batch_i];
                let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
                let mut indices = Vec::with_capacity(seq_len);
                // all_pos[slot] is the absolute stream position currently held
                // by each cache slot, with usize::MAX marking slots that are
                // still empty.
                let mut all_pos = vec![usize::MAX; context];
                if start_pos < context {
                    for i in 0..start_pos {
                        all_pos[i] = i;
                    }
                } else {
                    // The cache has wrapped: slots before start_index hold
                    // recent positions, the remaining slots are one context
                    // window older.
                    let offset = start_pos - start_index;
                    for i in 0..context {
                        all_pos[i] =
                            if i < start_index { i + offset } else { i + offset - context };
                    }
                }
                for seq_i in 0..seq_len {
                    let index = self.indices[batch_i];
                    all_pos[index] = seq_i + start_pos;
                    indices.push(index as u32);
                    self.indices[batch_i] += 1;
                    self.positions[batch_i] += 1;
                    if self.indices[batch_i] >= self.context {
                        self.indices[batch_i] = 0;
                    }
                }

                // Each query attends only to slots holding positions at or
                // before its own.
                for seq_i in 0..seq_len {
                    let my_pos = seq_i + start_pos;
                    let mask = all_pos
                        .iter()
                        .map(|&pos| if pos <= my_pos { 0.0 } else { f32::NEG_INFINITY })
                        .collect::<Vec<f32>>();
                    masks.push(mask);
                }

                attention_masks.push(masks);
                cache_indices.push(indices);
            }
        }
        // Flattening the attention mask and then using Tensor::from_vec rather
        // than Tensor::new ends up being almost 10x faster with candle 0.9.0.
        // The slowness seems to come from the CPU copies; to be further
        // investigated.
        let attention_masks =
            attention_masks.into_iter().flat_map(|m| m.into_iter().flatten()).collect::<Vec<f32>>();
        let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
            .to_dtype(self.dtype)?;
        let indices = Tensor::new(cache_indices, self.device())?;
        Ok(IndicesAndMask { indices, mask })
    }

    pub fn device(&self) -> &Device {
        &self.device
    }

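    /// Variant used when the whole context is rewritten in one call: all batch
    /// elements share a single causal, context-limited mask built from
    /// absolute positions.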
    #[allow(clippy::needless_range_loop)]
    fn indices_and_mask_abs(
        &mut self,
        seq_len: usize,
        batch_mask: &[bool],
    ) -> Result<IndicesAndMask> {
        let mask = self.get_mask_abs(seq_len, seq_len)?;
        let mut cache_indices = Vec::with_capacity(self.batch_size());
        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
            if !batch_mask {
                let indices = vec![self.indices[batch_i] as u32; seq_len];
                cache_indices.push(indices);
            } else {
                let mut indices = Vec::with_capacity(seq_len);
                for _ in 0..seq_len {
                    let index = self.indices[batch_i];
                    indices.push(index as u32);
                    self.indices[batch_i] += 1;
                    self.positions[batch_i] += 1;
                    if self.indices[batch_i] >= self.context {
                        self.indices[batch_i] = 0;
                    }
                }
                cache_indices.push(indices);
            }
        }
        let indices = Tensor::new(cache_indices, self.device())?;
        Ok(IndicesAndMask { indices, mask })
    }

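    /// Builds a `(size1, size2)` causal mask that also hides keys more than
    /// `context` positions behind the query.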
    fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
        let context = self.context;
        let mask: Vec<_> = (0..size1)
            .flat_map(|i| {
                (0..size2).map(move |j| {
                    if size1 + j > size2 + i || size1 + j + context < size2 + i {
                        f32::NEG_INFINITY
                    } else {
                        0.0
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (size1, size2), self.device())
    }
}

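/// A thin wrapper over the kv-cache flavors provided by `candle_nn`; only the
/// rotating variant is currently instantiated.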
#[derive(Debug, Clone)]
pub enum KvCache {
    Rotating(RotatingKvCache),
}

impl KvCache {
    pub fn new(dim: usize, max_seq_len: usize) -> Self {
        let cache = RotatingKvCache::new(dim, max_seq_len);
        Self::Rotating(cache)
    }

    pub fn current_seq_len(&self) -> usize {
        match self {
            KvCache::Rotating(cache) => cache.current_seq_len(),
        }
    }

    pub fn reset(&mut self) {
        match self {
            KvCache::Rotating(cache) => cache.reset(),
        }
    }

    pub fn append(&mut self, key: &Tensor, value: &Tensor) -> Result<(Tensor, Tensor)> {
        match self {
            KvCache::Rotating(cache) => cache.append(key, value),
        }
    }

    pub fn positions(&self, seq_len: usize) -> Vec<usize> {
        match self {
            KvCache::Rotating(cache) => cache.positions(seq_len),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    #[test]
    fn test_scattered_kv_cache() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
        let inf = f32::INFINITY;

        // Single step; the second batch element is masked out and does not
        // advance.
        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
        assert_eq!(mask, [[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]);

        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
        assert_eq!(mask, [[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]);

        // Three steps for the second batch element only.
        let iam = cache.indices_and_mask(3, &[false, true])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
        assert_eq!(
            mask,
            [
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [
                    [0.0, -inf, -inf, -inf, -inf],
                    [0.0, 0.0, -inf, -inf, -inf],
                    [0.0, 0.0, 0.0, -inf, -inf]
                ]
            ]
        );

        // The second batch element wraps around to slot 0 here.
        let iam = cache.indices_and_mask(3, &[true, true])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
        assert_eq!(
            mask,
            [
                [
                    [0.0, 0.0, 0.0, -inf, -inf],
                    [0.0, 0.0, 0.0, 0.0, -inf],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [-inf, 0.0, 0.0, 0.0, -inf],
                    [-inf, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ]
            ]
        );

        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
        assert_eq!(mask, [[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]);

        // Slot 2 is written by the second token of this call, so the first
        // token must not attend to it.
        let iam = cache.indices_and_mask(2, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
        assert_eq!(
            mask,
            [
                [[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
            ]
        );

        Ok(())
    }
}