Skip to main content

candle_nn/
kv_cache.rs

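//! Key/value cache implementations for autoregressive attention.
//!
//! - [`Cache`] / [`KvCache`]: a pre-allocated buffer that is enlarged when it fills up.
//! - [`RotatingCache`] / [`RotatingKvCache`]: a fixed-size ring buffer keeping the most recent
//!   entries, with helpers computing the matching positions and attention mask.
//! - [`ScatteredKvCache`] / [`ScatteredCacheBuilder`]: a fixed-size buffer written with scatter
//!   operations, tracking a separate write slot per batch element.
//! - [`ConcatKvCache`]: grows by concatenating each new chunk onto the stored tensors.
//!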
3use candle::{DType, Device, Result, Tensor};
4
5#[derive(Debug, Clone)]
6pub struct Cache {
7    // all_data is an option on a Tensor, this makes it possible to only create the actual tensor
8    // on the first call where the batch size is easily known.
9    // Also this makes it safe to clone a KvCache that has been reset (as in it will not share
10    // its internal state with the cloned instance).
11    all_data: Option<Tensor>,
12    dim: usize,
13    current_seq_len: usize,
14    grow_by: usize,
15    max_seq_len: usize,
16}
17
18impl Cache {
19    pub fn new(dim: usize, max_seq_len: usize) -> Self {
20        Self {
21            all_data: None,
22            dim,
23            current_seq_len: 0,
24            grow_by: max_seq_len,
25            max_seq_len,
26        }
27    }
28
29    pub fn dim(&self) -> usize {
30        self.dim
31    }
32
33    pub fn current_seq_len(&self) -> usize {
34        self.current_seq_len
35    }
36
37    pub fn max_seq_len(&self) -> usize {
38        self.max_seq_len
39    }
40
41    pub fn all_data(&self) -> &Option<Tensor> {
42        &self.all_data
43    }
44
45    pub fn current_data(&self) -> Result<Option<Tensor>> {
46        let data = match self.all_data.as_ref() {
47            None => None,
48            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
49        };
50        Ok(data)
51    }
52
53    pub fn reset(&mut self) {
54        self.current_seq_len = 0;
55        self.all_data = None;
56    }
57
58    pub fn append(&mut self, src: &Tensor) -> Result<()> {
59        let seq_len = src.dim(self.dim)?;
60        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
61        // self.all_data.get_or_insert_with.
62        if self.all_data.is_none() {
63            let mut shape = src.dims().to_vec();
64            shape[self.dim] = self.max_seq_len;
65            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
66            self.all_data = Some(ad)
67        };
68        let ad = self.all_data.as_mut().unwrap();
69        while self.current_seq_len + seq_len > self.max_seq_len {
70            let mut shape = src.dims().to_vec();
71            shape[self.dim] = self.grow_by;
72            let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
73            *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
74            self.max_seq_len += self.grow_by;
75        }
76        ad.slice_set(src, self.dim, self.current_seq_len)?;
77        self.current_seq_len += seq_len;
78        Ok(())
79    }
80}
81
82#[derive(Debug, Clone)]
83pub struct KvCache {
84    k: Cache,
85    v: Cache,
86}
87
88impl KvCache {
89    pub fn new(dim: usize, max_seq_len: usize) -> Self {
90        let k = Cache::new(dim, max_seq_len);
91        let v = Cache::new(dim, max_seq_len);
92        Self { k, v }
93    }
94
95    pub fn k_cache(&self) -> &Cache {
96        &self.k
97    }
98
99    pub fn v_cache(&self) -> &Cache {
100        &self.v
101    }
102
103    pub fn k_cache_mut(&mut self) -> &mut Cache {
104        &mut self.k
105    }
106
107    pub fn v_cache_mut(&mut self) -> &mut Cache {
108        &mut self.v
109    }
110
111    pub fn k(&self) -> Result<Option<Tensor>> {
112        self.k.current_data()
113    }
114
115    pub fn v(&self) -> Result<Option<Tensor>> {
116        self.v.current_data()
117    }
118
119    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
120        self.k.append(k)?;
121        self.v.append(v)?;
122        let out_k = self.k.current_data()?;
123        let out_v = self.v.current_data()?;
124        let k = match out_k {
125            None => {
126                let mut shape = k.dims().to_vec();
127                shape[self.k.dim] = 0;
128                Tensor::zeros(shape, k.dtype(), k.device())?
129            }
130            Some(k) => k,
131        };
132        let v = match out_v {
133            None => {
134                let mut shape = v.dims().to_vec();
135                shape[self.k.dim] = 0;
136                Tensor::zeros(shape, v.dtype(), v.device())?
137            }
138            Some(v) => v,
139        };
140        Ok((k, v))
141    }
142
143    pub fn current_seq_len(&self) -> usize {
144        self.k.current_seq_len()
145    }
146
147    pub fn reset(&mut self) {
148        self.k.reset();
149        self.v.reset();
150    }
151}
152
153#[derive(Debug, Clone)]
154pub struct RotatingCache {
155    all_data: Option<Tensor>,
156    dim: usize,
157    // `offset` is the current write index in the buffer
158    offset: usize,
159    // The total size of the sequence seen so far.
160    current_seq_len: usize,
161    // max_seq_len is the size of the rotating buffer, it is actually allowed for the full
162    // sequence to grow past this limit.
163    max_seq_len: usize,
164}
165
166impl RotatingCache {
167    pub fn new(dim: usize, max_seq_len: usize) -> Self {
168        Self {
169            all_data: None,
170            dim,
171            offset: 0,
172            current_seq_len: 0,
173            max_seq_len,
174        }
175    }
176
177    pub fn offset(&self) -> usize {
178        self.offset
179    }
180
181    pub fn dim(&self) -> usize {
182        self.dim
183    }
184
185    pub fn current_seq_len(&self) -> usize {
186        self.current_seq_len
187    }
188
189    pub fn max_seq_len(&self) -> usize {
190        self.max_seq_len
191    }
192
193    pub fn all_data(&self) -> &Option<Tensor> {
194        &self.all_data
195    }
196
197    pub fn current_data(&self) -> Result<Option<Tensor>> {
198        let data = match self.all_data.as_ref() {
199            None => None,
200            Some(d) => {
201                if self.current_seq_len >= self.max_seq_len {
202                    Some(d.clone())
203                } else {
204                    Some(d.narrow(self.dim, 0, self.current_seq_len)?)
205                }
206            }
207        };
208        Ok(data)
209    }
210
211    pub fn reset(&mut self) {
212        self.offset = 0;
213        self.current_seq_len = 0;
214        self.all_data = None;
215    }
216
217    pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
218        let seq_len = src.dim(self.dim)?;
219        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
220        // self.all_data.get_or_insert_with.
221        if self.all_data.is_none() {
222            let mut shape = src.dims().to_vec();
223            shape[self.dim] = self.max_seq_len;
224            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
225            self.all_data = Some(ad)
226        };
227        let ad = self.all_data.as_mut().unwrap();
228
229        self.current_seq_len += seq_len;
230        if seq_len >= self.max_seq_len {
231            let to_copy = src
232                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
233                .contiguous()?;
234            ad.slice_set(&to_copy, self.dim, 0)?;
235            self.offset = 0;
236            // Here we return `src` rather than `ad` so that all the past can be used.
237            Ok(src.clone())
238        } else {
239            let rem_len = self.max_seq_len - self.offset;
240            if seq_len <= rem_len {
241                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
242                self.offset = (self.offset + seq_len) % self.max_seq_len;
243            } else {
244                // We have to make two copies here as we go over the boundary of the cache.
245                if rem_len > 0 {
246                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
247                    ad.slice_set(&src1, self.dim, self.offset)?;
248                }
249                let src2 = src
250                    .narrow(self.dim, rem_len, seq_len - rem_len)?
251                    .contiguous()?;
252                ad.slice_set(&src2, self.dim, 0)?;
253                self.offset = seq_len - rem_len;
254            }
255            if self.current_seq_len >= self.max_seq_len {
256                Ok(ad.clone())
257            } else {
258                Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
259            }
260        }
261    }
262
263    fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
264        let context = self.max_seq_len;
265        let mask: Vec<_> = (0..size1)
266            .flat_map(|i| {
267                (0..size2).map(move |j| {
268                    u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
269                })
270            })
271            .collect();
272        Tensor::from_slice(&mask, (size1, size2), device)
273    }
274
275    fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
276        let context = self.max_seq_len;
277        let upd_offset = (self.offset + size1) % self.max_seq_len;
278        let mask: Vec<_> = (0..size1)
279            .flat_map(|pos_src| {
280                // The absolute position of the elements that will get added to the cache.
281                let pos_src = self.current_seq_len + pos_src;
282                (0..size2).map(move |pos_cache_rel| {
283                    // The absolute position of the cache elements after the addition.
284                    let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
285                    let pos_cache = if pos_cache_rel < upd_offset {
286                        pos_cache
287                    } else {
288                        pos_cache - self.max_seq_len
289                    };
290                    u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
291                })
292            })
293            .collect();
294        Tensor::from_slice(&mask, (size1, size2), device)
295    }
296
297    /// Returns the positions corresponding to all the elements that will be returned
298    /// *after* adding `seq_len` to the cache.
299    pub fn positions(&self, seq_len: usize) -> Vec<usize> {
300        if seq_len <= self.max_seq_len {
301            let upd_offset = (self.offset + seq_len) % self.max_seq_len;
302            let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
303            (0..cache_out_len)
304                .map(|i| {
305                    let pos_cache = self.current_seq_len + seq_len + i - upd_offset;
306                    if i < upd_offset {
307                        pos_cache
308                    } else {
309                        pos_cache - self.max_seq_len
310                    }
311                })
312                .collect()
313        } else {
314            (self.current_seq_len..(self.current_seq_len + seq_len)).collect()
315        }
316    }
317
318    /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
319    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
320        let mask = if seq_len == 1 {
321            None
322        } else {
323            let mask = if seq_len < self.max_seq_len {
324                let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
325                self.get_mask_rel(seq_len, cache_out_len, device)?
326            } else {
327                self.get_mask_abs(seq_len, seq_len, device)?
328            };
329            Some(mask)
330        };
331        Ok(mask)
332    }
333}
334
335#[derive(Debug, Clone)]
336pub struct RotatingKvCache {
337    k: RotatingCache,
338    v: RotatingCache,
339}
340
341impl RotatingKvCache {
342    pub fn new(dim: usize, max_seq_len: usize) -> Self {
343        let k = RotatingCache::new(dim, max_seq_len);
344        let v = RotatingCache::new(dim, max_seq_len);
345        Self { k, v }
346    }
347
348    pub fn k_cache(&self) -> &RotatingCache {
349        &self.k
350    }
351
352    pub fn v_cache(&self) -> &RotatingCache {
353        &self.v
354    }
355
356    pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
357        &mut self.k
358    }
359
360    pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
361        &mut self.v
362    }
363
364    pub fn k(&self) -> Result<Option<Tensor>> {
365        self.k.current_data()
366    }
367
368    pub fn v(&self) -> Result<Option<Tensor>> {
369        self.v.current_data()
370    }
371
372    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
373        let out_k = self.k.append(k)?;
374        let out_v = self.v.append(v)?;
375        Ok((out_k, out_v))
376    }
377
378    pub fn offset(&self) -> usize {
379        self.k.offset()
380    }
381
382    pub fn current_seq_len(&self) -> usize {
383        self.k.current_seq_len()
384    }
385
386    /// Returns the attn_mask to be applied *after* adding `seq_len` to the cache.
387    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
388        self.k.attn_mask(seq_len, device)
389    }
390
391    /// Returns the positions corresponding to all the elements that will be returned
392    /// *after* adding `seq_len` to the cache.
393    pub fn positions(&self, seq_len: usize) -> Vec<usize> {
394        self.k.positions(seq_len)
395    }
396
397    pub fn reset(&mut self) {
398        self.k.reset();
399        self.v.reset();
400    }
401}
402
403#[derive(Debug, Clone)]
404pub struct IndicesAndMask {
405    indices: Tensor,
406    mask: Tensor,
407}
408
409impl IndicesAndMask {
410    pub fn mask(&self) -> &Tensor {
411        &self.mask
412    }
413}
414
415#[derive(Debug, Clone)]
416pub struct ScatteredKvCache {
417    k: Tensor,
418    v: Tensor,
419    context: usize,
420}
421
422impl ScatteredKvCache {
423    pub fn append(
424        &mut self,
425        k: &Tensor,
426        v: &Tensor,
427        iam: &IndicesAndMask,
428    ) -> Result<(Tensor, Tensor)> {
429        if self.context <= k.dim(2)? {
430            return Ok((k.clone(), v.clone()));
431        }
432        let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
433        let indices = indices.broadcast_as(k.shape())?.contiguous()?;
434        self.k.scatter_set(&indices, k, 2)?;
435        self.v.scatter_set(&indices, v, 2)?;
436        Ok((self.k.clone(), self.v.clone()))
437    }
438
439    pub fn k(&self) -> &Tensor {
440        &self.k
441    }
442
443    pub fn v(&self) -> &Tensor {
444        &self.v
445    }
446}
447
448#[derive(Debug, Clone)]
449pub struct ScatteredCacheBuilder {
450    context: usize,
451    // The current position in the stream, this can be larger than context.
452    positions: Vec<usize>,
453    // The index where the next element will be stored.
454    indices: Vec<usize>,
455    dtype: DType,
456    device: Device,
457}
458
459impl ScatteredCacheBuilder {
460    pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
461        let positions = vec![0; batch_size];
462        let indices = vec![0; batch_size];
463        Ok(Self {
464            positions,
465            indices,
466            context,
467            dtype,
468            device: device.clone(),
469        })
470    }
471
472    pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
473        let batch_size = self.batch_size();
474        let shape = (batch_size, num_heads, self.context, head_dim);
475        let k = Tensor::zeros(shape, self.dtype, self.device())?;
476        let v = Tensor::zeros(shape, self.dtype, self.device())?;
477        Ok(ScatteredKvCache {
478            k,
479            v,
480            context: self.context,
481        })
482    }
483
484    pub fn positions(&self) -> &[usize] {
485        &self.positions
486    }
487
488    pub fn reset(&mut self) {
489        self.positions.fill(0);
490        self.indices.fill(0);
491    }
492
493    pub fn batch_size(&self) -> usize {
494        self.positions.len()
495    }
496
497    pub fn reset_batch_index(&mut self, batch_index: usize) {
498        self.positions[batch_index] = 0;
499        self.indices[batch_index] = 0;
500    }
501
502    #[allow(clippy::needless_range_loop)]
503    pub fn indices_and_mask(
504        &mut self,
505        seq_len: usize,
506        batch_mask: &[bool],
507    ) -> Result<IndicesAndMask> {
508        // mask shape is (b, h, t, k)
509        let context = self.context;
510        if self.context <= seq_len {
511            return self.indices_and_mask_abs(seq_len, batch_mask);
512        }
513        let mut attention_masks = Vec::with_capacity(self.batch_size());
514        let mut cache_indices = Vec::with_capacity(self.batch_size());
515        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
516            if !batch_mask {
517                let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
518                let indices = vec![self.indices[batch_i] as u32; seq_len];
519                attention_masks.push(masks);
520                cache_indices.push(indices);
521            } else {
522                let start_index = self.indices[batch_i];
523                let start_pos = self.positions[batch_i];
524                let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
525                let mut indices = Vec::with_capacity(seq_len);
526                let mut all_pos = vec![usize::MAX; context];
527                if start_pos < context {
528                    for i in 0..start_pos {
529                        all_pos[i] = i;
530                    }
531                } else {
532                    let offset = start_pos - start_index;
533                    for i in 0..context {
534                        all_pos[i] = if i < start_index {
535                            i + offset
536                        } else {
537                            i + offset - context
538                        };
539                    }
540                }
541                for seq_i in 0..seq_len {
542                    let index = self.indices[batch_i];
543                    all_pos[index] = seq_i + start_pos;
544                    indices.push(index as u32);
545                    self.indices[batch_i] += 1;
546                    self.positions[batch_i] += 1;
547                    if self.indices[batch_i] >= self.context {
548                        self.indices[batch_i] = 0;
549                    }
550                }
551
552                for seq_i in 0..seq_len {
553                    let my_pos = seq_i + start_pos;
554                    let mask = all_pos
555                        .iter()
556                        .map(|&pos| {
557                            if pos <= my_pos {
558                                0.0
559                            } else {
560                                f32::NEG_INFINITY
561                            }
562                        })
563                        .collect::<Vec<f32>>();
564                    masks.push(mask);
565                }
566
567                attention_masks.push(masks);
568                cache_indices.push(indices);
569            }
570        }
571        // Flattening the attention mask then using Tensor::from_vec rather using Tensor::new ends
572        // up being almost 10x faster with candle 0.9.0. This has been fixed in candle 0.9.1.
573        let attention_masks = attention_masks
574            .into_iter()
575            .flat_map(|m| m.into_iter().flatten())
576            .collect::<Vec<f32>>();
577        let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
578            .to_dtype(self.dtype)?;
579        let indices = Tensor::new(cache_indices, self.device())?;
580        Ok(IndicesAndMask { indices, mask })
581    }
582
583    pub fn device(&self) -> &Device {
584        &self.device
585    }
586
587    #[allow(clippy::needless_range_loop)]
588    fn indices_and_mask_abs(
589        &mut self,
590        seq_len: usize,
591        batch_mask: &[bool],
592    ) -> Result<IndicesAndMask> {
593        let mask = self.get_mask_abs(seq_len, seq_len)?;
594        let mut cache_indices = Vec::with_capacity(self.batch_size());
595        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
596            if !batch_mask {
597                let indices = vec![self.indices[batch_i] as u32; seq_len];
598                cache_indices.push(indices);
599            } else {
600                let mut indices = Vec::with_capacity(seq_len);
601                for _ in 0..seq_len {
602                    let index = self.indices[batch_i];
603                    indices.push(index as u32);
604                    self.indices[batch_i] += 1;
605                    self.positions[batch_i] += 1;
606                    if self.indices[batch_i] >= self.context {
607                        self.indices[batch_i] = 0;
608                    }
609                }
610                cache_indices.push(indices);
611            }
612        }
613        let indices = Tensor::new(cache_indices, self.device())?;
614        Ok(IndicesAndMask { indices, mask })
615    }
616
617    fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
618        let context = self.context;
619        let mask: Vec<_> = (0..size1)
620            .flat_map(|i| {
621                (0..size2).map(move |j| {
622                    if size1 + j > size2 + i || size1 + j + context < size2 + i {
623                        f32::NEG_INFINITY
624                    } else {
625                        0.0
626                    }
627                })
628            })
629            .collect();
630        Tensor::from_slice(&mask, (size1, size2), self.device())
631    }
632}
633
634/// KV-Cache using concatenation for append operations
635///
636/// This implementation uses `Tensor::cat` instead of `slice_set` for updates,
637/// providing significant GPU performance improvements for autoregressive generation.
638///
639/// # When to Use
640///
641/// **Recommended for:**
642/// - GPU inference (CUDA, Metal)
643/// - Autoregressive generation (token-by-token decoding)
644///
645/// **Use `KvCache` instead for:**
646/// - CPU-only inference
647/// - When you need fixed memory allocation upfront
648///
649/// # Example
650///
651/// ```ignore
652/// use candle_nn::kv_cache::ConcatKvCache;
653///
654/// let mut cache = ConcatKvCache::new(2); // dim=2 for sequence dimension
655///
656/// // First token (prefill)
657/// let k1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
658/// let v1 = Tensor::randn(0f32, 1., (1, 8, 10, 64), &device)?;
659/// let (k, v) = cache.append(&k1, &v1)?;
660///
661/// // Subsequent tokens (decode)
662/// let k_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
663/// let v_new = Tensor::randn(0f32, 1., (1, 8, 1, 64), &device)?;
664/// let (k, v) = cache.append(&k_new, &v_new)?;
665/// ```
666#[derive(Debug, Clone)]
667pub struct ConcatKvCache {
668    k: Option<Tensor>,
669    v: Option<Tensor>,
670    dim: usize,
671}
672
673impl ConcatKvCache {
674    /// Create a new empty concatenation-based KV-cache
675    ///
676    /// # Arguments
677    /// * `dim` - The dimension along which to concatenate
678    ///   - For attention with shape `[batch, heads, seq, head_dim]`, use `dim=2`
679    ///   - For attention with shape `[batch, seq, heads, head_dim]`, use `dim=1`
680    ///
681    /// # Example
682    /// ```ignore
683    /// // For standard transformer attention: [B, H, S, D]
684    /// let cache = ConcatKvCache::new(2);
685    /// ```
686    pub fn new(dim: usize) -> Self {
687        Self {
688            k: None,
689            v: None,
690            dim,
691        }
692    }
693
694    /// Get current sequence length in the cache
695    ///
696    /// Returns 0 if the cache is empty.
697    pub fn current_seq_len(&self) -> usize {
698        self.k
699            .as_ref()
700            .and_then(|k| k.dims().get(self.dim).copied())
701            .unwrap_or(0)
702    }
703
704    /// Check if cache is empty
705    pub fn is_empty(&self) -> bool {
706        self.k.is_none()
707    }
708
709    /// Get the concatenation dimension
710    pub fn dim(&self) -> usize {
711        self.dim
712    }
713
714    /// Append key and value tensors to the cache
715    ///
716    /// This is the core operation that uses optimized concatenation kernels.
717    ///
718    /// # Arguments
719    /// * `k` - Key tensor to append (shape: [..., seq_len, ...])
720    /// * `v` - Value tensor to append (shape: [..., seq_len, ...])
721    ///
722    /// # Returns
723    /// Tuple of `(full_k, full_v)` containing all cached keys and values,
724    /// including the newly appended data.
725    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
726        // Ensure inputs are contiguous for optimal concatenation performance
727        let k = k.contiguous()?;
728        let v = v.contiguous()?;
729        // Update K cache using concatenation
730        self.k = Some(match &self.k {
731            None => k.clone(),
732            Some(k_cache) => {
733                // Concatenate along the sequence dimension
734                // GPU kernel for cat is highly optimized:
735                // - Fused allocation + copy
736                // - Coalesced memory access
737                // - Single kernel launch
738                Tensor::cat(&[k_cache, &k], self.dim)?
739            }
740        });
741
742        // Update V cache using concatenation
743        self.v = Some(match &self.v {
744            None => v.clone(),
745            Some(v_cache) => Tensor::cat(&[v_cache, &v], self.dim)?,
746        });
747
748        Ok((
749            self.k.as_ref().unwrap().clone(),
750            self.v.as_ref().unwrap().clone(),
751        ))
752    }
753
754    /// Reset the cache (clear all stored keys and values)
755    ///
756    /// After calling this, `is_empty()` will return `true` and
757    /// `current_seq_len()` will return 0.
758    pub fn reset(&mut self) {
759        self.k = None;
760        self.v = None;
761    }
762
763    /// Get reference to current K cache data
764    ///
765    /// Returns `None` if the cache is empty.
766    pub fn k(&self) -> Option<&Tensor> {
767        self.k.as_ref()
768    }
769
770    /// Get reference to current V cache data
771    ///
772    /// Returns `None` if the cache is empty.
773    pub fn v(&self) -> Option<&Tensor> {
774        self.v.as_ref()
775    }
776
777    /// Get mutable reference to K cache data
778    ///
779    /// Returns `None` if the cache is empty.
780    pub fn k_mut(&mut self) -> Option<&mut Tensor> {
781        self.k.as_mut()
782    }
783
784    /// Get mutable reference to V cache data
785    ///
786    /// Returns `None` if the cache is empty.
787    pub fn v_mut(&mut self) -> Option<&mut Tensor> {
788        self.v.as_mut()
789    }
790
791    /// Get owned K and V tensors, consuming the cache
792    ///
793    /// Returns `None` if the cache is empty.
794    pub fn into_inner(self) -> Option<(Tensor, Tensor)> {
795        match (self.k, self.v) {
796            (Some(k), Some(v)) => Some((k, v)),
797            _ => None,
798        }
799    }
800}
801
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    /// Converts an [`IndicesAndMask`] into plain vectors: the write slots and the mask with
    /// its singleton head dimension dropped.
    fn unpack(iam: &IndicesAndMask) -> Result<(Vec<Vec<u32>>, Vec<Vec<Vec<f32>>>)> {
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        let indices = iam.indices.to_vec2::<u32>()?;
        Ok((indices, mask))
    }

    /// Zero-filled f32 tensor on the CPU.
    fn zeros(shape: (usize, usize, usize, usize)) -> Result<Tensor> {
        Tensor::zeros(shape, DType::F32, &Device::Cpu)
    }

    #[test]
    fn test_scattered_kv_cache() -> Result<()> {
        let device = Device::Cpu;
        let mut builder = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
        let inf = f32::INFINITY;

        // One element for the first stream only.
        let (indices, mask) = unpack(&builder.indices_and_mask(1, &[true, false])?)?;
        assert_eq!(indices, [[0], [0]]);
        assert_eq!(
            mask,
            [[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        let (indices, mask) = unpack(&builder.indices_and_mask(1, &[true, false])?)?;
        assert_eq!(indices, [[1], [0]]);
        assert_eq!(
            mask,
            [[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        // Three elements for the second stream only.
        let (indices, mask) = unpack(&builder.indices_and_mask(3, &[false, true])?)?;
        assert_eq!(indices, [[2, 2, 2], [0, 1, 2]]);
        assert_eq!(
            mask,
            [
                [
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [0.0, -inf, -inf, -inf, -inf],
                    [0.0, 0.0, -inf, -inf, -inf],
                    [0.0, 0.0, 0.0, -inf, -inf]
                ]
            ]
        );

        // Both streams advance; the second one wraps around the end of the cache.
        let (indices, mask) = unpack(&builder.indices_and_mask(3, &[true, true])?)?;
        assert_eq!(indices, [[2, 3, 4], [3, 4, 0]]);
        assert_eq!(
            mask,
            [
                [
                    [0.0, 0.0, 0.0, -inf, -inf],
                    [0.0, 0.0, 0.0, 0.0, -inf],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [-inf, 0.0, 0.0, 0.0, -inf],
                    [-inf, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ]
            ]
        );

        let (indices, mask) = unpack(&builder.indices_and_mask(1, &[true, false])?)?;
        assert_eq!(indices, [[0], [1]]);
        assert_eq!(
            mask,
            [[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        let (indices, mask) = unpack(&builder.indices_and_mask(2, &[true, false])?)?;
        assert_eq!(indices, [[1, 2], [1, 1]]);
        assert_eq!(
            mask,
            [
                [[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
            ]
        );

        Ok(())
    }

    #[test]
    fn test_concat_cache_basic() -> Result<()> {
        let mut cache = ConcatKvCache::new(2);

        assert!(cache.is_empty());
        assert_eq!(cache.current_seq_len(), 0);

        // First append
        let (k, v) = cache.append(&zeros((1, 8, 3, 64))?, &zeros((1, 8, 3, 64))?)?;
        assert_eq!(k.dims(), &[1, 8, 3, 64]);
        assert_eq!(v.dims(), &[1, 8, 3, 64]);
        assert_eq!(cache.current_seq_len(), 3);
        assert!(!cache.is_empty());

        // Second append: 3 + 2 entries along dim 2.
        let (k, v) = cache.append(&zeros((1, 8, 2, 64))?, &zeros((1, 8, 2, 64))?)?;
        assert_eq!(k.dims(), &[1, 8, 5, 64]);
        assert_eq!(v.dims(), &[1, 8, 5, 64]);
        assert_eq!(cache.current_seq_len(), 5);

        Ok(())
    }

    #[test]
    fn test_concat_cache_reset() -> Result<()> {
        let mut cache = ConcatKvCache::new(2);

        cache.append(&zeros((1, 8, 10, 64))?, &zeros((1, 8, 10, 64))?)?;
        assert_eq!(cache.current_seq_len(), 10);

        cache.reset();

        assert!(cache.is_empty());
        assert_eq!(cache.current_seq_len(), 0);
        assert!(cache.k().is_none());
        assert!(cache.v().is_none());

        Ok(())
    }

    #[test]
    fn test_concat_cache_multiple_appends() -> Result<()> {
        let mut cache = ConcatKvCache::new(2);

        // Prefill with 10 tokens.
        cache.append(&zeros((1, 8, 10, 64))?, &zeros((1, 8, 10, 64))?)?;
        assert_eq!(cache.current_seq_len(), 10);

        // Decode one token per step.
        for step in 1..=5 {
            let (k, v) = cache.append(&zeros((1, 8, 1, 64))?, &zeros((1, 8, 1, 64))?)?;
            assert_eq!(k.dims()[2], 10 + step);
            assert_eq!(v.dims()[2], 10 + step);
        }

        assert_eq!(cache.current_seq_len(), 15);

        Ok(())
    }

    #[test]
    fn test_concat_cache_different_dim() -> Result<()> {
        // Concatenate along dim 1 rather than 2.
        let mut cache = ConcatKvCache::new(1);

        let (k, _) = cache.append(&zeros((1, 3, 8, 64))?, &zeros((1, 3, 8, 64))?)?;
        assert_eq!(k.dims(), &[1, 3, 8, 64]);

        let (k, _) = cache.append(&zeros((1, 2, 8, 64))?, &zeros((1, 2, 8, 64))?)?;
        assert_eq!(k.dims(), &[1, 5, 8, 64]);
        assert_eq!(cache.current_seq_len(), 5);

        Ok(())
    }
}