1use candle::{DType, Device, Result, Tensor};
4
/// Growable single-stream cache buffer.
///
/// The backing tensor is lazily allocated on the first `append` (fixing dtype,
/// device and the non-sequence dimensions) and grows along `dim` in `grow_by`
/// increments whenever appended data would overflow the current capacity.
#[derive(Debug, Clone)]
pub struct Cache {
    // Backing storage; `None` until the first `append`.
    all_data: Option<Tensor>,
    // Dimension index along which sequence data is appended.
    dim: usize,
    // Number of positions currently filled along `dim`.
    current_seq_len: usize,
    // Growth increment; initialized to the original `max_seq_len`.
    grow_by: usize,
    // Current capacity of the backing tensor along `dim`.
    max_seq_len: usize,
}
17
18impl Cache {
19 pub fn new(dim: usize, max_seq_len: usize) -> Self {
20 Self {
21 all_data: None,
22 dim,
23 current_seq_len: 0,
24 grow_by: max_seq_len,
25 max_seq_len,
26 }
27 }
28
29 pub fn dim(&self) -> usize {
30 self.dim
31 }
32
33 pub fn current_seq_len(&self) -> usize {
34 self.current_seq_len
35 }
36
37 pub fn max_seq_len(&self) -> usize {
38 self.max_seq_len
39 }
40
41 pub fn all_data(&self) -> &Option<Tensor> {
42 &self.all_data
43 }
44
45 pub fn current_data(&self) -> Result<Option<Tensor>> {
46 let data = match self.all_data.as_ref() {
47 None => None,
48 Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
49 };
50 Ok(data)
51 }
52
53 pub fn reset(&mut self) {
54 self.current_seq_len = 0;
55 self.all_data = None;
56 }
57
58 pub fn append(&mut self, src: &Tensor) -> Result<()> {
59 let seq_len = src.dim(self.dim)?;
60 if self.all_data.is_none() {
63 let mut shape = src.dims().to_vec();
64 shape[self.dim] = self.max_seq_len;
65 let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
66 self.all_data = Some(ad)
67 };
68 let ad = self.all_data.as_mut().unwrap();
69 while self.current_seq_len + seq_len > self.max_seq_len {
70 let mut shape = src.dims().to_vec();
71 shape[self.dim] = self.grow_by;
72 let next_ad = Tensor::zeros(shape, src.dtype(), src.device())?;
73 *ad = Tensor::cat(&[&*ad, &next_ad], self.dim)?;
74 self.max_seq_len += self.grow_by;
75 }
76 ad.slice_set(src, self.dim, self.current_seq_len)?;
77 self.current_seq_len += seq_len;
78 Ok(())
79 }
80}
81
/// A pair of growable caches for attention keys and values.
#[derive(Debug, Clone)]
pub struct KvCache {
    // Key cache.
    k: Cache,
    // Value cache; constructed with the same `dim`/`max_seq_len` as `k`.
    v: Cache,
}
87
88impl KvCache {
89 pub fn new(dim: usize, max_seq_len: usize) -> Self {
90 let k = Cache::new(dim, max_seq_len);
91 let v = Cache::new(dim, max_seq_len);
92 Self { k, v }
93 }
94
95 pub fn k_cache(&self) -> &Cache {
96 &self.k
97 }
98
99 pub fn v_cache(&self) -> &Cache {
100 &self.v
101 }
102
103 pub fn k_cache_mut(&mut self) -> &mut Cache {
104 &mut self.k
105 }
106
107 pub fn v_cache_mut(&mut self) -> &mut Cache {
108 &mut self.v
109 }
110
111 pub fn k(&self) -> Result<Option<Tensor>> {
112 self.k.current_data()
113 }
114
115 pub fn v(&self) -> Result<Option<Tensor>> {
116 self.v.current_data()
117 }
118
119 pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
120 self.k.append(k)?;
121 self.v.append(v)?;
122 let out_k = self.k.current_data()?;
123 let out_v = self.v.current_data()?;
124 let k = match out_k {
125 None => {
126 let mut shape = k.dims().to_vec();
127 shape[self.k.dim] = 0;
128 Tensor::zeros(shape, k.dtype(), k.device())?
129 }
130 Some(k) => k,
131 };
132 let v = match out_v {
133 None => {
134 let mut shape = v.dims().to_vec();
135 shape[self.k.dim] = 0;
136 Tensor::zeros(shape, v.dtype(), v.device())?
137 }
138 Some(v) => v,
139 };
140 Ok((k, v))
141 }
142
143 pub fn current_seq_len(&self) -> usize {
144 self.k.current_seq_len()
145 }
146
147 pub fn reset(&mut self) {
148 self.k.reset();
149 self.v.reset();
150 }
151}
152
/// Fixed-capacity ring-buffer cache for a single tensor stream.
///
/// Once `max_seq_len` positions have been written, new appends wrap around and
/// overwrite the oldest entries (sliding-window attention style).
#[derive(Debug, Clone)]
pub struct RotatingCache {
    // Backing storage; `None` until the first `append`.
    all_data: Option<Tensor>,
    // Dimension index along which sequence data is appended.
    dim: usize,
    // Next write position within the buffer, in [0, max_seq_len).
    offset: usize,
    // Total number of positions appended so far (not capped by capacity).
    current_seq_len: usize,
    // Fixed capacity of the ring buffer along `dim`.
    max_seq_len: usize,
}
165
impl RotatingCache {
    /// Creates an empty rotating cache with capacity `max_seq_len` along `dim`.
    pub fn new(dim: usize, max_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            offset: 0,
            current_seq_len: 0,
            max_seq_len,
        }
    }

    /// Next write position within the ring buffer.
    pub fn offset(&self) -> usize {
        self.offset
    }

    /// The dimension along which data is appended.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Total number of positions appended so far; may exceed `max_seq_len`
    /// once the buffer has wrapped.
    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    /// Fixed capacity of the ring buffer along `dim`.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// The whole backing tensor, `None` before the first `append`.
    pub fn all_data(&self) -> &Option<Tensor> {
        &self.all_data
    }

    /// Returns the valid cached data: the full buffer once it has wrapped,
    /// otherwise just the filled prefix. `None` before the first `append`.
    pub fn current_data(&self) -> Result<Option<Tensor>> {
        let data = match self.all_data.as_ref() {
            None => None,
            Some(d) => {
                if self.current_seq_len >= self.max_seq_len {
                    // Buffer has wrapped at least once: every slot is valid.
                    Some(d.clone())
                } else {
                    Some(d.narrow(self.dim, 0, self.current_seq_len)?)
                }
            }
        };
        Ok(data)
    }

    /// Clears the cache; the backing tensor is released and will be
    /// re-allocated on the next `append`.
    pub fn reset(&mut self) {
        self.offset = 0;
        self.current_seq_len = 0;
        self.all_data = None;
    }

    /// Appends `src` into the ring buffer, wrapping around when the end is
    /// reached, and returns the tensor attention should run against.
    pub fn append(&mut self, src: &Tensor) -> Result<Tensor> {
        let seq_len = src.dim(self.dim)?;
        if self.all_data.is_none() {
            // Lazy allocation: shape from the first `src`, sequence dimension
            // widened to the fixed capacity.
            let mut shape = src.dims().to_vec();
            shape[self.dim] = self.max_seq_len;
            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
            self.all_data = Some(ad)
        };
        let ad = self.all_data.as_mut().unwrap();

        self.current_seq_len += seq_len;
        if seq_len >= self.max_seq_len {
            // The input alone covers (or exceeds) the whole window: keep only
            // its last `max_seq_len` positions and hand back the full input
            // unchanged so the caller can attend over all of it.
            let to_copy = src
                .narrow(self.dim, seq_len - self.max_seq_len, self.max_seq_len)?
                .contiguous()?;
            ad.slice_set(&to_copy, self.dim, 0)?;
            self.offset = 0;
            Ok(src.clone())
        } else {
            let rem_len = self.max_seq_len - self.offset;
            if seq_len <= rem_len {
                // Fits before the end of the buffer: single copy.
                ad.slice_set(&src.contiguous()?, self.dim, self.offset)?;
                self.offset = (self.offset + seq_len) % self.max_seq_len;
            } else {
                // Wraps around: copy the head up to the buffer end, then the
                // remaining tail at the start.
                if rem_len > 0 {
                    let src1 = src.narrow(self.dim, 0, rem_len)?.contiguous()?;
                    ad.slice_set(&src1, self.dim, self.offset)?;
                }
                let src2 = src
                    .narrow(self.dim, rem_len, seq_len - rem_len)?
                    .contiguous()?;
                ad.slice_set(&src2, self.dim, 0)?;
                self.offset = seq_len - rem_len;
            }
            if self.current_seq_len >= self.max_seq_len {
                // Wrapped: the whole buffer is valid cache content.
                Ok(ad.clone())
            } else {
                Ok(ad.narrow(self.dim, 0, self.current_seq_len)?)
            }
        }
    }

    // Causal + sliding-window mask for the non-rotated case (query and cache
    // positions share the same absolute ordering). 1 marks masked-out slots.
    fn get_mask_abs(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
        let context = self.max_seq_len;
        let mask: Vec<_> = (0..size1)
            .flat_map(|i| {
                (0..size2).map(move |j| {
                    // Masked when the key is in the future, or further than
                    // `context` positions in the past.
                    u8::from(size1 + j > size2 + i || size1 + j + context < size2 + i)
                })
            })
            .collect();
        Tensor::from_slice(&mask, (size1, size2), device)
    }

    // Mask for the rotated case: cache column `pos_cache_rel` is translated
    // back to an absolute position based on the post-update write offset.
    fn get_mask_rel(&self, size1: usize, size2: usize, device: &Device) -> Result<Tensor> {
        let context = self.max_seq_len;
        // Write offset after `size1` new positions have been appended.
        let upd_offset = (self.offset + size1) % self.max_seq_len;
        let mask: Vec<_> = (0..size1)
            .flat_map(|pos_src| {
                // Absolute position of this query row.
                let pos_src = self.current_seq_len + pos_src;
                (0..size2).map(move |pos_cache_rel| {
                    // Absolute position of this cache slot; slots at or past
                    // `upd_offset` belong to the previous wrap cycle.
                    let pos_cache = self.current_seq_len + size1 + pos_cache_rel - upd_offset;
                    let pos_cache = if pos_cache_rel < upd_offset {
                        pos_cache
                    } else {
                        pos_cache - self.max_seq_len
                    };
                    u8::from(pos_cache > pos_src || pos_cache + context < pos_src)
                })
            })
            .collect();
        Tensor::from_slice(&mask, (size1, size2), device)
    }

    /// Absolute sequence positions of the cache slots after `seq_len` more
    /// positions are appended (slot order matches the buffer layout). For
    /// `seq_len` larger than the window this is simply the contiguous range
    /// of new positions.
    pub fn positions(&self, seq_len: usize) -> Vec<usize> {
        if seq_len <= self.max_seq_len {
            let upd_offset = (self.offset + seq_len) % self.max_seq_len;
            let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
            (0..cache_out_len)
                .map(|i| {
                    let pos_cache = self.current_seq_len + seq_len + i - upd_offset;
                    if i < upd_offset {
                        pos_cache
                    } else {
                        pos_cache - self.max_seq_len
                    }
                })
                .collect()
        } else {
            (self.current_seq_len..(self.current_seq_len + seq_len)).collect()
        }
    }

    /// Attention mask for a query of length `seq_len`, or `None` for a single
    /// query position (which can attend to everything cached).
    pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
        let mask = if seq_len == 1 {
            None
        } else {
            let mask = if seq_len < self.max_seq_len {
                let cache_out_len = (self.current_seq_len + seq_len).min(self.max_seq_len);
                self.get_mask_rel(seq_len, cache_out_len, device)?
            } else {
                // Window fully covered by the new input: plain absolute mask.
                self.get_mask_abs(seq_len, seq_len, device)?
            };
            Some(mask)
        };
        Ok(mask)
    }
}
334
/// A pair of rotating (ring-buffer) caches for attention keys and values.
#[derive(Debug, Clone)]
pub struct RotatingKvCache {
    // Key cache.
    k: RotatingCache,
    // Value cache; constructed with the same `dim`/`max_seq_len` as `k`.
    v: RotatingCache,
}
340
341impl RotatingKvCache {
342 pub fn new(dim: usize, max_seq_len: usize) -> Self {
343 let k = RotatingCache::new(dim, max_seq_len);
344 let v = RotatingCache::new(dim, max_seq_len);
345 Self { k, v }
346 }
347
348 pub fn k_cache(&self) -> &RotatingCache {
349 &self.k
350 }
351
352 pub fn v_cache(&self) -> &RotatingCache {
353 &self.v
354 }
355
356 pub fn k_cache_mut(&mut self) -> &mut RotatingCache {
357 &mut self.k
358 }
359
360 pub fn v_cache_mut(&mut self) -> &mut RotatingCache {
361 &mut self.v
362 }
363
364 pub fn k(&self) -> Result<Option<Tensor>> {
365 self.k.current_data()
366 }
367
368 pub fn v(&self) -> Result<Option<Tensor>> {
369 self.v.current_data()
370 }
371
372 pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
373 let out_k = self.k.append(k)?;
374 let out_v = self.v.append(v)?;
375 Ok((out_k, out_v))
376 }
377
378 pub fn offset(&self) -> usize {
379 self.k.offset()
380 }
381
382 pub fn current_seq_len(&self) -> usize {
383 self.k.current_seq_len()
384 }
385
386 pub fn attn_mask(&self, seq_len: usize, device: &Device) -> Result<Option<Tensor>> {
388 self.k.attn_mask(seq_len, device)
389 }
390
391 pub fn positions(&self, seq_len: usize) -> Vec<usize> {
394 self.k.positions(seq_len)
395 }
396
397 pub fn reset(&mut self) {
398 self.k.reset();
399 self.v.reset();
400 }
401}
402
/// Scatter indices and the matching attention mask produced by
/// `ScatteredCacheBuilder::indices_and_mask`.
#[derive(Debug, Clone)]
pub struct IndicesAndMask {
    // Per-batch cache slot indices (u32) for each new position.
    indices: Tensor,
    // Additive attention mask (0 or -inf) aligned with the cache layout.
    mask: Tensor,
}
408
409impl IndicesAndMask {
410 pub fn mask(&self) -> &Tensor {
411 &self.mask
412 }
413}
414
/// Pre-allocated key/value cache written via scatter indices, supporting
/// per-batch-entry rotating positions.
#[derive(Debug, Clone)]
pub struct ScatteredKvCache {
    // Key storage, shape (batch, num_heads, context, head_dim).
    k: Tensor,
    // Value storage, same shape as `k`.
    v: Tensor,
    // Number of cache slots along the sequence dimension (dim 2).
    context: usize,
}
421
422impl ScatteredKvCache {
423 pub fn append(
424 &mut self,
425 k: &Tensor,
426 v: &Tensor,
427 iam: &IndicesAndMask,
428 ) -> Result<(Tensor, Tensor)> {
429 if self.context <= k.dim(2)? {
430 return Ok((k.clone(), v.clone()));
431 }
432 let indices = iam.indices.unsqueeze(2)?.unsqueeze(1)?;
433 let indices = indices.broadcast_as(k.shape())?.contiguous()?;
434 self.k.scatter_set(&indices, k, 2)?;
435 self.v.scatter_set(&indices, v, 2)?;
436 Ok((self.k.clone(), self.v.clone()))
437 }
438
439 pub fn k(&self) -> &Tensor {
440 &self.k
441 }
442
443 pub fn v(&self) -> &Tensor {
444 &self.v
445 }
446}
447
/// Tracks per-batch-entry write state for `ScatteredKvCache` and builds the
/// scatter indices / attention masks for each step.
#[derive(Debug, Clone)]
pub struct ScatteredCacheBuilder {
    // Number of cache slots per batch entry.
    context: usize,
    // Absolute sequence position per batch entry (monotonically increasing).
    positions: Vec<usize>,
    // Next cache slot to write per batch entry, wraps at `context`.
    indices: Vec<usize>,
    // Dtype used for the mask and the cache tensors.
    dtype: DType,
    // Device used for all created tensors.
    device: Device,
}
458
impl ScatteredCacheBuilder {
    /// Creates a builder for `batch_size` independent streams, each with
    /// `context` cache slots.
    pub fn new(batch_size: usize, context: usize, dtype: DType, device: &Device) -> Result<Self> {
        let positions = vec![0; batch_size];
        let indices = vec![0; batch_size];
        Ok(Self {
            positions,
            indices,
            context,
            dtype,
            device: device.clone(),
        })
    }

    /// Allocates a zeroed cache of shape (batch, num_heads, context, head_dim)
    /// using the builder's dtype and device.
    pub fn make_cache(&self, num_heads: usize, head_dim: usize) -> Result<ScatteredKvCache> {
        let batch_size = self.batch_size();
        let shape = (batch_size, num_heads, self.context, head_dim);
        let k = Tensor::zeros(shape, self.dtype, self.device())?;
        let v = Tensor::zeros(shape, self.dtype, self.device())?;
        Ok(ScatteredKvCache {
            k,
            v,
            context: self.context,
        })
    }

    /// Absolute sequence position for each batch entry.
    pub fn positions(&self) -> &[usize] {
        &self.positions
    }

    /// Resets every batch entry back to position 0.
    pub fn reset(&mut self) {
        self.positions.fill(0);
        self.indices.fill(0);
    }

    /// Number of batch entries tracked by this builder.
    pub fn batch_size(&self) -> usize {
        self.positions.len()
    }

    /// Resets a single batch entry (e.g. when its sequence restarts).
    pub fn reset_batch_index(&mut self, batch_index: usize) {
        self.positions[batch_index] = 0;
        self.indices[batch_index] = 0;
    }

    /// Builds the scatter indices and attention mask for appending `seq_len`
    /// positions; `batch_mask[i]` selects whether batch entry `i` advances.
    /// Note: this mutates the per-batch positions/indices as a side effect.
    #[allow(clippy::needless_range_loop)]
    pub fn indices_and_mask(
        &mut self,
        seq_len: usize,
        batch_mask: &[bool],
    ) -> Result<IndicesAndMask> {
        let context = self.context;
        if self.context <= seq_len {
            // The new chunk fills the whole context: fall back to the simpler
            // absolute (non-rotated) mask construction.
            return self.indices_and_mask_abs(seq_len, batch_mask);
        }
        let mut attention_masks = Vec::with_capacity(self.batch_size());
        let mut cache_indices = Vec::with_capacity(self.batch_size());
        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
            if !batch_mask {
                // Inactive entry: state is not advanced; an all-zero mask is
                // emitted and the current slot index is repeated.
                let masks: Vec<Vec<f32>> = vec![vec![0.0; context]; seq_len];
                let indices = vec![self.indices[batch_i] as u32; seq_len];
                attention_masks.push(masks);
                cache_indices.push(indices);
            } else {
                let start_index = self.indices[batch_i];
                let start_pos = self.positions[batch_i];
                let mut masks: Vec<Vec<f32>> = Vec::with_capacity(seq_len);
                let mut indices = Vec::with_capacity(seq_len);
                // `all_pos[slot]` = absolute position stored in that cache
                // slot; usize::MAX marks a never-written slot.
                let mut all_pos = vec![usize::MAX; context];
                if start_pos < context {
                    // Not wrapped yet: slots [0, start_pos) hold positions 0..
                    for i in 0..start_pos {
                        all_pos[i] = i;
                    }
                } else {
                    // Wrapped: reconstruct each slot's absolute position from
                    // the current write index.
                    let offset = start_pos - start_index;
                    for i in 0..context {
                        all_pos[i] = if i < start_index {
                            i + offset
                        } else {
                            i + offset - context
                        };
                    }
                }
                // Advance the write index/position for each new token while
                // recording which slot it lands in.
                for seq_i in 0..seq_len {
                    let index = self.indices[batch_i];
                    all_pos[index] = seq_i + start_pos;
                    indices.push(index as u32);
                    self.indices[batch_i] += 1;
                    self.positions[batch_i] += 1;
                    if self.indices[batch_i] >= self.context {
                        self.indices[batch_i] = 0;
                    }
                }

                // Causal mask: a query can see any slot holding a position at
                // or before its own (usize::MAX slots are always masked).
                for seq_i in 0..seq_len {
                    let my_pos = seq_i + start_pos;
                    let mask = all_pos
                        .iter()
                        .map(|&pos| {
                            if pos <= my_pos {
                                0.0
                            } else {
                                f32::NEG_INFINITY
                            }
                        })
                        .collect::<Vec<f32>>();
                    masks.push(mask);
                }

                attention_masks.push(masks);
                cache_indices.push(indices);
            }
        }
        // Flatten (batch, seq, context) into one vec; the `()` dim lets the
        // tensor infer the batch size.
        let attention_masks = attention_masks
            .into_iter()
            .flat_map(|m| m.into_iter().flatten())
            .collect::<Vec<f32>>();
        let mask = Tensor::from_vec(attention_masks, ((), 1, seq_len, context), self.device())?
            .to_dtype(self.dtype)?;
        let indices = Tensor::new(cache_indices, self.device())?;
        Ok(IndicesAndMask { indices, mask })
    }

    /// The device all tensors are created on.
    pub fn device(&self) -> &Device {
        &self.device
    }

    // Variant used when `seq_len >= context`: positions are not rotated, so a
    // plain causal/windowed mask applies. State is still advanced per token.
    #[allow(clippy::needless_range_loop)]
    fn indices_and_mask_abs(
        &mut self,
        seq_len: usize,
        batch_mask: &[bool],
    ) -> Result<IndicesAndMask> {
        let mask = self.get_mask_abs(seq_len, seq_len)?;
        let mut cache_indices = Vec::with_capacity(self.batch_size());
        for (batch_i, &batch_mask) in batch_mask.iter().enumerate() {
            if !batch_mask {
                let indices = vec![self.indices[batch_i] as u32; seq_len];
                cache_indices.push(indices);
            } else {
                let mut indices = Vec::with_capacity(seq_len);
                for _ in 0..seq_len {
                    let index = self.indices[batch_i];
                    indices.push(index as u32);
                    self.indices[batch_i] += 1;
                    self.positions[batch_i] += 1;
                    if self.indices[batch_i] >= self.context {
                        self.indices[batch_i] = 0;
                    }
                }
                cache_indices.push(indices);
            }
        }
        let indices = Tensor::new(cache_indices, self.device())?;
        Ok(IndicesAndMask { indices, mask })
    }

    // Causal + windowed mask in absolute coordinates: -inf where the key is
    // in the future or more than `context` positions in the past.
    fn get_mask_abs(&self, size1: usize, size2: usize) -> Result<Tensor> {
        let context = self.context;
        let mask: Vec<_> = (0..size1)
            .flat_map(|i| {
                (0..size2).map(move |j| {
                    if size1 + j > size2 + i || size1 + j + context < size2 + i {
                        f32::NEG_INFINITY
                    } else {
                        0.0
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (size1, size2), self.device())
    }
}
633
/// Unbounded key/value cache built by concatenating each append onto the
/// previous tensors (no pre-allocation, grows without limit).
#[derive(Debug, Clone)]
pub struct ConcatKvCache {
    // Accumulated keys; `None` while empty.
    k: Option<Tensor>,
    // Accumulated values; `None` while empty.
    v: Option<Tensor>,
    // Dimension index along which appends are concatenated.
    dim: usize,
}
672
673impl ConcatKvCache {
674 pub fn new(dim: usize) -> Self {
687 Self {
688 k: None,
689 v: None,
690 dim,
691 }
692 }
693
694 pub fn current_seq_len(&self) -> usize {
698 self.k
699 .as_ref()
700 .and_then(|k| k.dims().get(self.dim).copied())
701 .unwrap_or(0)
702 }
703
704 pub fn is_empty(&self) -> bool {
706 self.k.is_none()
707 }
708
709 pub fn dim(&self) -> usize {
711 self.dim
712 }
713
714 pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
726 let k = k.contiguous()?;
728 let v = v.contiguous()?;
729 self.k = Some(match &self.k {
731 None => k.clone(),
732 Some(k_cache) => {
733 Tensor::cat(&[k_cache, &k], self.dim)?
739 }
740 });
741
742 self.v = Some(match &self.v {
744 None => v.clone(),
745 Some(v_cache) => Tensor::cat(&[v_cache, &v], self.dim)?,
746 });
747
748 Ok((
749 self.k.as_ref().unwrap().clone(),
750 self.v.as_ref().unwrap().clone(),
751 ))
752 }
753
754 pub fn reset(&mut self) {
759 self.k = None;
760 self.v = None;
761 }
762
763 pub fn k(&self) -> Option<&Tensor> {
767 self.k.as_ref()
768 }
769
770 pub fn v(&self) -> Option<&Tensor> {
774 self.v.as_ref()
775 }
776
777 pub fn k_mut(&mut self) -> Option<&mut Tensor> {
781 self.k.as_mut()
782 }
783
784 pub fn v_mut(&mut self) -> Option<&mut Tensor> {
788 self.v.as_mut()
789 }
790
791 pub fn into_inner(self) -> Option<(Tensor, Tensor)> {
795 match (self.k, self.v) {
796 (Some(k), Some(v)) => Some((k, v)),
797 _ => None,
798 }
799 }
800}
801
#[cfg(test)]
mod tests {
    use super::*;
    use candle::IndexOp;

    // Exercises ScatteredCacheBuilder slot assignment and mask generation,
    // including inactive batch entries and wrap-around of the 5-slot context.
    #[test]
    fn test_scattered_kv_cache() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ScatteredCacheBuilder::new(2, 5, DType::F32, &device)?;
        let inf = f32::INFINITY;

        // One token, batch 1 inactive: only batch 0 advances.
        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [0]]);
        assert_eq!(
            mask,
            [[[0.0, -inf, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1], [0]]);
        assert_eq!(
            mask,
            [[[0.0, 0.0, -inf, -inf, -inf]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        // Three tokens, batch 0 inactive this time.
        let iam = cache.indices_and_mask(3, &[false, true])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 2, 2], [0, 1, 2]]);
        assert_eq!(
            mask,
            [
                [
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [0.0, -inf, -inf, -inf, -inf],
                    [0.0, 0.0, -inf, -inf, -inf],
                    [0.0, 0.0, 0.0, -inf, -inf]
                ]
            ]
        );

        // Both active; batch 1 wraps around to slot 0.
        let iam = cache.indices_and_mask(3, &[true, true])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[2, 3, 4], [3, 4, 0]]);
        assert_eq!(
            mask,
            [
                [
                    [0.0, 0.0, 0.0, -inf, -inf],
                    [0.0, 0.0, 0.0, 0.0, -inf],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ],
                [
                    [-inf, 0.0, 0.0, 0.0, -inf],
                    [-inf, 0.0, 0.0, 0.0, 0.0],
                    [0.0, 0.0, 0.0, 0.0, 0.0]
                ]
            ]
        );

        // Batch 0 wraps to slot 0; the whole window is now visible.
        let iam = cache.indices_and_mask(1, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[0], [1]]);
        assert_eq!(
            mask,
            [[[0.0, 0.0, 0.0, 0.0, 0.0]], [[0.0, 0.0, 0.0, 0.0, 0.0]]]
        );

        let iam = cache.indices_and_mask(2, &[true, false])?;
        let mask = iam.mask.i((.., 0))?.to_vec3::<f32>()?;
        assert_eq!(iam.indices.to_vec2::<u32>()?, [[1, 2], [1, 1]]);
        assert_eq!(
            mask,
            [
                [[0.0, 0.0, -inf, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]],
                [[0.0, 0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0]]
            ]
        );

        Ok(())
    }

    // Basic append semantics of ConcatKvCache on dim 2.
    #[test]
    fn test_concat_cache_basic() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ConcatKvCache::new(2);

        assert!(cache.is_empty());
        assert_eq!(cache.current_seq_len(), 0);

        let k1 = Tensor::zeros((1, 8, 3, 64), DType::F32, &device)?;
        let v1 = Tensor::zeros((1, 8, 3, 64), DType::F32, &device)?;
        let (k, v) = cache.append(&k1, &v1)?;

        assert_eq!(k.dims(), &[1, 8, 3, 64]);
        assert_eq!(v.dims(), &[1, 8, 3, 64]);
        assert_eq!(cache.current_seq_len(), 3);
        assert!(!cache.is_empty());

        let k2 = Tensor::zeros((1, 8, 2, 64), DType::F32, &device)?;
        let v2 = Tensor::zeros((1, 8, 2, 64), DType::F32, &device)?;
        let (k, v) = cache.append(&k2, &v2)?;

        assert_eq!(k.dims(), &[1, 8, 5, 64]);
        assert_eq!(v.dims(), &[1, 8, 5, 64]);
        assert_eq!(cache.current_seq_len(), 5);

        Ok(())
    }

    // Reset returns the cache to its pristine empty state.
    #[test]
    fn test_concat_cache_reset() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ConcatKvCache::new(2);

        let k = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
        let v = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
        cache.append(&k, &v)?;

        assert_eq!(cache.current_seq_len(), 10);

        cache.reset();

        assert!(cache.is_empty());
        assert_eq!(cache.current_seq_len(), 0);
        assert!(cache.k().is_none());
        assert!(cache.v().is_none());

        Ok(())
    }

    // Prefill followed by token-by-token decode appends.
    #[test]
    fn test_concat_cache_multiple_appends() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ConcatKvCache::new(2);

        let k_prefill = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
        let v_prefill = Tensor::zeros((1, 8, 10, 64), DType::F32, &device)?;
        cache.append(&k_prefill, &v_prefill)?;

        assert_eq!(cache.current_seq_len(), 10);

        for i in 1..=5 {
            let k_token = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
            let v_token = Tensor::zeros((1, 8, 1, 64), DType::F32, &device)?;
            let (k, v) = cache.append(&k_token, &v_token)?;
            assert_eq!(k.dims()[2], 10 + i);
            assert_eq!(v.dims()[2], 10 + i);
        }

        assert_eq!(cache.current_seq_len(), 15);

        Ok(())
    }

    // Concatenation along a non-default dimension (dim 1).
    #[test]
    fn test_concat_cache_different_dim() -> Result<()> {
        let device = Device::Cpu;
        let mut cache = ConcatKvCache::new(1);

        let k1 = Tensor::zeros((1, 3, 8, 64), DType::F32, &device)?;
        let v1 = Tensor::zeros((1, 3, 8, 64), DType::F32, &device)?;
        let (k, _v) = cache.append(&k1, &v1)?;

        assert_eq!(k.dims(), &[1, 3, 8, 64]);

        let k2 = Tensor::zeros((1, 2, 8, 64), DType::F32, &device)?;
        let v2 = Tensor::zeros((1, 2, 8, 64), DType::F32, &device)?;
        let (k, _v) = cache.append(&k2, &v2)?;

        assert_eq!(k.dims(), &[1, 5, 8, 64]);
        assert_eq!(cache.current_seq_len(), 5);

        Ok(())
    }
}