1use oxillama_arch::traits::KvCacheAccess;
14use oxillama_arch::ArchResult;
15
/// Number of token slots held by a single [`Page`].
const PAGE_SIZE: usize = 16;

/// Fixed-capacity block of `f32` storage with room for [`PAGE_SIZE`] token vectors.
///
/// The token in `slot` occupies `data[slot * kv_dim..(slot + 1) * kv_dim]`.
struct Page {
    /// Flat buffer of `PAGE_SIZE * kv_dim` floats, zero-filled on creation.
    data: Vec<f32>,
}

impl Page {
    /// Allocates a zero-filled page sized for `PAGE_SIZE` tokens of `kv_dim` floats each.
    fn new(kv_dim: usize) -> Self {
        let capacity = PAGE_SIZE * kv_dim;
        Page {
            data: vec![0.0; capacity],
        }
    }

    /// Copies the first `kv_dim` elements of `src` into the given `slot`.
    ///
    /// # Panics
    ///
    /// Panics if the slot range lies outside the page or `src` holds fewer than
    /// `kv_dim` elements.
    fn write_token(&mut self, slot: usize, kv_dim: usize, src: &[f32]) {
        let start = slot * kv_dim;
        let dst = &mut self.data[start..start + kv_dim];
        dst.copy_from_slice(&src[..kv_dim]);
    }

    /// Borrows the `kv_dim` floats stored in the given `slot`.
    ///
    /// # Panics
    ///
    /// Panics if the slot range lies outside the page.
    fn read_token(&self, slot: usize, kv_dim: usize) -> &[f32] {
        let start = slot * kv_dim;
        let end = start + kv_dim;
        &self.data[start..end]
    }
}
47
48struct LayerCache {
50 key_pages: Vec<Page>,
51 value_pages: Vec<Page>,
52}
53
54impl LayerCache {
55 fn new() -> Self {
56 Self {
57 key_pages: Vec::new(),
58 value_pages: Vec::new(),
59 }
60 }
61
62 fn ensure_capacity(&mut self, token_pos: usize, kv_dim: usize) {
64 let needed_pages = token_pos / PAGE_SIZE + 1;
65 while self.key_pages.len() < needed_pages {
66 self.key_pages.push(Page::new(kv_dim));
67 self.value_pages.push(Page::new(kv_dim));
68 }
69 }
70
71 fn store(&mut self, token_pos: usize, kv_dim: usize, key: &[f32], value: &[f32]) {
73 self.ensure_capacity(token_pos, kv_dim);
74 let page_idx = token_pos / PAGE_SIZE;
75 let slot = token_pos % PAGE_SIZE;
76 self.key_pages[page_idx].write_token(slot, kv_dim, key);
77 self.value_pages[page_idx].write_token(slot, kv_dim, value);
78 }
79
80 fn num_pages(&self) -> usize {
82 self.key_pages.len()
83 }
84
85 fn shrink_to(&mut self, seq_len: usize) {
87 let needed = if seq_len == 0 {
88 0
89 } else {
90 seq_len / PAGE_SIZE + 1
91 };
92 self.key_pages.truncate(needed);
93 self.value_pages.truncate(needed);
94 }
95}
96
97pub struct PagedKvCache {
104 layers: Vec<LayerCache>,
106 seq_len: usize,
108 max_seq_len: usize,
110 kv_dim: usize,
112 num_layers: usize,
114}
115
116impl PagedKvCache {
117 pub fn new(num_layers: usize, max_seq_len: usize, kv_dim: usize) -> Self {
122 let layers = (0..num_layers).map(|_| LayerCache::new()).collect();
123
124 Self {
125 layers,
126 seq_len: 0,
127 max_seq_len,
128 kv_dim,
129 num_layers,
130 }
131 }
132
133 pub fn page_size(&self) -> usize {
135 PAGE_SIZE
136 }
137
138 pub fn max_seq_len(&self) -> usize {
140 self.max_seq_len
141 }
142
143 pub fn kv_dim(&self) -> usize {
145 self.kv_dim
146 }
147
148 pub fn num_layers(&self) -> usize {
150 self.num_layers
151 }
152
153 pub fn total_pages(&self) -> usize {
155 self.layers.iter().map(|l| l.num_pages()).sum()
156 }
157
158 pub fn memory_bytes(&self) -> usize {
160 self.total_pages() * PAGE_SIZE * self.kv_dim * 4 * 2 }
162
163 pub fn clear(&mut self) {
165 self.seq_len = 0;
166 for layer in &mut self.layers {
167 layer.key_pages.clear();
168 layer.value_pages.clear();
169 }
170 }
171
172 pub fn shrink_to_fit(&mut self) {
175 for layer in &mut self.layers {
176 layer.shrink_to(self.seq_len);
177 }
178 }
179
180 fn assemble_keys(&self, layer: usize, buf: &mut Vec<f32>) {
184 let total = self.seq_len * self.kv_dim;
185 buf.clear();
186 buf.reserve(total);
187
188 let layer_cache = &self.layers[layer];
189 for pos in 0..self.seq_len {
190 let page_idx = pos / PAGE_SIZE;
191 let slot = pos % PAGE_SIZE;
192 let token_data = layer_cache.key_pages[page_idx].read_token(slot, self.kv_dim);
193 buf.extend_from_slice(token_data);
194 }
195 }
196
197 fn assemble_values(&self, layer: usize, buf: &mut Vec<f32>) {
199 let total = self.seq_len * self.kv_dim;
200 buf.clear();
201 buf.reserve(total);
202
203 let layer_cache = &self.layers[layer];
204 for pos in 0..self.seq_len {
205 let page_idx = pos / PAGE_SIZE;
206 let slot = pos % PAGE_SIZE;
207 let token_data = layer_cache.value_pages[page_idx].read_token(slot, self.kv_dim);
208 buf.extend_from_slice(token_data);
209 }
210 }
211}
212
213impl KvCacheAccess for PagedKvCache {
214 fn seq_len(&self) -> usize {
215 self.seq_len
216 }
217
218 fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()> {
219 if layer >= self.num_layers {
220 return Err(oxillama_arch::ArchError::ForwardPassError {
221 layer,
222 message: format!("layer index {layer} out of range (max {})", self.num_layers),
223 });
224 }
225
226 if self.seq_len >= self.max_seq_len {
227 return Err(oxillama_arch::ArchError::ForwardPassError {
228 layer,
229 message: format!(
230 "sequence length {} exceeds max {}",
231 self.seq_len, self.max_seq_len
232 ),
233 });
234 }
235
236 self.layers[layer].store(self.seq_len, self.kv_dim, key, value);
237 Ok(())
238 }
239
240 fn get_keys(&self, layer: usize) -> ArchResult<&[f32]> {
241 if layer >= self.num_layers {
242 return Err(oxillama_arch::ArchError::ForwardPassError {
243 layer,
244 message: format!("layer index {layer} out of range (max {})", self.num_layers),
245 });
246 }
247
248 if self.seq_len == 0 {
257 return Ok(&[]);
258 }
259
260 let pages_used = (self.seq_len - 1) / PAGE_SIZE + 1;
262 if pages_used == 1 {
263 let end = self.seq_len * self.kv_dim;
264 return Ok(&self.layers[layer].key_pages[0].data[..end]);
265 }
266
267 Err(oxillama_arch::ArchError::ForwardPassError {
272 layer,
273 message: format!(
274 "paged KV cache: sequence length {} spans {} pages; \
275 use get_keys_into() for multi-page access",
276 self.seq_len, pages_used
277 ),
278 })
279 }
280
281 fn get_values(&self, layer: usize) -> ArchResult<&[f32]> {
282 if layer >= self.num_layers {
283 return Err(oxillama_arch::ArchError::ForwardPassError {
284 layer,
285 message: format!("layer index {layer} out of range (max {})", self.num_layers),
286 });
287 }
288
289 if self.seq_len == 0 {
290 return Ok(&[]);
291 }
292
293 let pages_used = (self.seq_len - 1) / PAGE_SIZE + 1;
294 if pages_used == 1 {
295 let end = self.seq_len * self.kv_dim;
296 return Ok(&self.layers[layer].value_pages[0].data[..end]);
297 }
298
299 Err(oxillama_arch::ArchError::ForwardPassError {
300 layer,
301 message: format!(
302 "paged KV cache: sequence length {} spans {} pages; \
303 use get_values_into() for multi-page access",
304 self.seq_len, pages_used
305 ),
306 })
307 }
308
309 fn advance(&mut self) {
310 if self.seq_len < self.max_seq_len {
311 self.seq_len += 1;
312 }
313 }
314
315 fn kv_dim(&self) -> usize {
316 self.kv_dim
317 }
318
319 fn for_each_key(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
320 self.iter_keys(layer, |pos, slice| f(pos, slice))
321 }
322
323 fn for_each_value(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
324 self.iter_values(layer, |pos, slice| f(pos, slice))
325 }
326}
327
328impl PagedKvCache {
330 pub fn get_keys_into(&self, layer: usize, buf: &mut Vec<f32>) -> ArchResult<()> {
335 if layer >= self.num_layers {
336 return Err(oxillama_arch::ArchError::ForwardPassError {
337 layer,
338 message: format!("layer index {layer} out of range (max {})", self.num_layers),
339 });
340 }
341 self.assemble_keys(layer, buf);
342 Ok(())
343 }
344
345 pub fn get_values_into(&self, layer: usize, buf: &mut Vec<f32>) -> ArchResult<()> {
347 if layer >= self.num_layers {
348 return Err(oxillama_arch::ArchError::ForwardPassError {
349 layer,
350 message: format!("layer index {layer} out of range (max {})", self.num_layers),
351 });
352 }
353 self.assemble_values(layer, buf);
354 Ok(())
355 }
356
357 pub fn get_key_token(&self, layer: usize, pos: usize) -> ArchResult<&[f32]> {
359 if layer >= self.num_layers {
360 return Err(oxillama_arch::ArchError::ForwardPassError {
361 layer,
362 message: format!("layer index {layer} out of range (max {})", self.num_layers),
363 });
364 }
365 if pos >= self.seq_len {
366 return Err(oxillama_arch::ArchError::ForwardPassError {
367 layer,
368 message: format!("position {pos} out of range (seq_len {})", self.seq_len),
369 });
370 }
371 let page_idx = pos / PAGE_SIZE;
372 let slot = pos % PAGE_SIZE;
373 Ok(self.layers[layer].key_pages[page_idx].read_token(slot, self.kv_dim))
374 }
375
376 pub fn get_value_token(&self, layer: usize, pos: usize) -> ArchResult<&[f32]> {
378 if layer >= self.num_layers {
379 return Err(oxillama_arch::ArchError::ForwardPassError {
380 layer,
381 message: format!("layer index {layer} out of range (max {})", self.num_layers),
382 });
383 }
384 if pos >= self.seq_len {
385 return Err(oxillama_arch::ArchError::ForwardPassError {
386 layer,
387 message: format!("position {pos} out of range (seq_len {})", self.seq_len),
388 });
389 }
390 let page_idx = pos / PAGE_SIZE;
391 let slot = pos % PAGE_SIZE;
392 Ok(self.layers[layer].value_pages[page_idx].read_token(slot, self.kv_dim))
393 }
394
395 pub fn iter_keys<F>(&self, layer: usize, mut f: F) -> ArchResult<()>
397 where
398 F: FnMut(usize, &[f32]),
399 {
400 if layer >= self.num_layers {
401 return Err(oxillama_arch::ArchError::ForwardPassError {
402 layer,
403 message: format!("layer index {layer} out of range (max {})", self.num_layers),
404 });
405 }
406 let layer_cache = &self.layers[layer];
407 for pos in 0..self.seq_len {
408 let page_idx = pos / PAGE_SIZE;
409 let slot = pos % PAGE_SIZE;
410 let data = layer_cache.key_pages[page_idx].read_token(slot, self.kv_dim);
411 f(pos, data);
412 }
413 Ok(())
414 }
415
416 pub fn iter_values<F>(&self, layer: usize, mut f: F) -> ArchResult<()>
418 where
419 F: FnMut(usize, &[f32]),
420 {
421 if layer >= self.num_layers {
422 return Err(oxillama_arch::ArchError::ForwardPassError {
423 layer,
424 message: format!("layer index {layer} out of range (max {})", self.num_layers),
425 });
426 }
427 let layer_cache = &self.layers[layer];
428 for pos in 0..self.seq_len {
429 let page_idx = pos / PAGE_SIZE;
430 let slot = pos % PAGE_SIZE;
431 let data = layer_cache.value_pages[page_idx].read_token(slot, self.kv_dim);
432 f(pos, data);
433 }
434 Ok(())
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    /// A single token stored in one layer is readable back through the contiguous API,
    /// and only the written layer allocates a page.
    #[test]
    fn test_paged_basic_store_retrieve() {
        let mut cache = PagedKvCache::new(2, 64, 4);
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.total_pages(), 0);

        let (key, val) = ([1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]);
        cache.store_kv(0, &key, &val).unwrap();
        cache.advance();

        assert_eq!(cache.seq_len(), 1);
        assert_eq!(cache.layers[0].num_pages(), 1);
        assert_eq!(cache.layers[1].num_pages(), 0);

        assert_eq!(cache.get_keys(0).unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(cache.get_values(0).unwrap(), &[5.0, 6.0, 7.0, 8.0]);
    }

    /// Filling exactly one page keeps the contiguous fast path available.
    #[test]
    fn test_paged_multi_token_single_page() {
        let mut cache = PagedKvCache::new(1, 64, 2);

        (0..PAGE_SIZE).for_each(|i| {
            let key = [i as f32, (i * 10) as f32];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        });

        assert_eq!(cache.seq_len(), PAGE_SIZE);
        assert_eq!(cache.layers[0].num_pages(), 1);

        let keys = cache.get_keys(0).unwrap();
        assert_eq!(keys.len(), PAGE_SIZE * 2);
        assert_eq!(&keys[..4], &[0.0, 0.0, 1.0, 10.0]);
    }

    /// One token past a page boundary forces the copying accessor.
    #[test]
    fn test_paged_multi_page_assembly() {
        let mut cache = PagedKvCache::new(1, 64, 2);

        (0..=PAGE_SIZE).for_each(|i| {
            let key = [i as f32, (i * 10) as f32];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        });

        assert_eq!(cache.seq_len(), PAGE_SIZE + 1);
        assert_eq!(cache.layers[0].num_pages(), 2);

        // The borrowed-slice API cannot span pages.
        assert!(cache.get_keys(0).is_err());

        let mut assembled = Vec::new();
        cache.get_keys_into(0, &mut assembled).unwrap();
        assert_eq!(assembled.len(), (PAGE_SIZE + 1) * 2);

        assert_eq!(&assembled[..2], &[0.0, 0.0]);
        let tail = &assembled[PAGE_SIZE * 2..];
        assert_eq!(tail[0], PAGE_SIZE as f32);
        assert_eq!(tail[1], (PAGE_SIZE * 10) as f32);
    }

    /// Individual tokens are addressable on either side of a page boundary.
    #[test]
    fn test_paged_per_token_access() {
        let mut cache = PagedKvCache::new(1, 64, 3);

        for i in 0..20 {
            let key = [i as f32, (i * 2) as f32, (i * 3) as f32];
            let val = [(i + 50) as f32, (i + 60) as f32, (i + 70) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }

        assert_eq!(cache.get_key_token(0, 5).unwrap(), &[5.0, 10.0, 15.0]);
        assert_eq!(cache.get_value_token(0, 17).unwrap(), &[67.0, 77.0, 87.0]);

        // Position 20 was never committed.
        assert!(cache.get_key_token(0, 20).is_err());
    }

    /// `iter_keys` visits every committed position with the right payload.
    #[test]
    fn test_paged_iteration() {
        let mut cache = PagedKvCache::new(1, 64, 2);

        for i in 0..20 {
            let key = [i as f32, (i + 1) as f32];
            let val = [(i + 100) as f32, (i + 101) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }

        let mut visited = 0;
        let outcome = cache.iter_keys(0, |pos, data| {
            assert_eq!(data[0], pos as f32);
            assert_eq!(data[1], (pos + 1) as f32);
            visited += 1;
        });
        outcome.unwrap();
        assert_eq!(visited, 20);
    }

    /// `clear` resets the length and frees every page.
    #[test]
    fn test_paged_clear() {
        let mut cache = PagedKvCache::new(2, 64, 4);

        for i in 0..20 {
            let row = [i as f32; 4];
            for layer in 0..2 {
                cache.store_kv(layer, &row, &row).unwrap();
            }
            cache.advance();
        }

        assert!(cache.total_pages() > 0);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.total_pages(), 0);
    }

    /// After the sequence is truncated, `shrink_to_fit` releases the surplus pages.
    #[test]
    fn test_paged_shrink_to_fit() {
        let mut cache = PagedKvCache::new(1, 128, 4);

        for i in 0..40 {
            let row = [i as f32; 4];
            cache.store_kv(0, &row, &row).unwrap();
            cache.advance();
        }
        assert_eq!(cache.layers[0].num_pages(), 3);

        // Simulate truncating the sequence to ten tokens.
        cache.seq_len = 10;
        cache.shrink_to_fit();
        assert_eq!(cache.layers[0].num_pages(), 1);
    }

    /// A short sequence uses far less memory than a fully pre-allocated buffer would.
    #[test]
    fn test_paged_memory_efficiency() {
        let (num_layers, max_seq, kv_dim) = (32, 4096, 128);

        // Size of a contiguous cache: keys and values, four bytes per float.
        let contiguous_bytes = num_layers * max_seq * kv_dim * 4 * 2;

        let mut cache = PagedKvCache::new(num_layers, max_seq, kv_dim);
        for i in 0..10 {
            let row = vec![i as f32; kv_dim];
            for layer in 0..num_layers {
                cache.store_kv(layer, &row, &row).unwrap();
            }
            cache.advance();
        }

        let paged_bytes = cache.memory_bytes();
        assert!(
            paged_bytes < contiguous_bytes / 10,
            "paged={paged_bytes} should be << contiguous={contiguous_bytes}"
        );
    }

    /// Storing beyond `max_seq_len` is reported as an error.
    #[test]
    fn test_paged_max_seq_len_error() {
        let mut cache = PagedKvCache::new(1, 2, 2);

        cache.store_kv(0, &[1.0, 2.0], &[3.0, 4.0]).unwrap();
        cache.advance();
        cache.store_kv(0, &[5.0, 6.0], &[7.0, 8.0]).unwrap();
        cache.advance();

        assert!(cache.store_kv(0, &[9.0, 10.0], &[11.0, 12.0]).is_err());
    }

    /// The trait-level key visitor works across a page boundary.
    #[test]
    fn paged_for_each_key_multi_page() {
        use oxillama_arch::traits::KvCacheAccess;

        let kv_dim = 2usize;
        let seq_len = PAGE_SIZE + 4;
        let mut cache = PagedKvCache::new(1, seq_len + 10, kv_dim);

        for t in 0..seq_len {
            let key = [t as f32, t as f32 * 2.0];
            cache.store_kv(0, &key, &[0.0f32; 2]).expect("store_kv");
            cache.advance();
        }

        assert_eq!(cache.seq_len(), seq_len);
        assert_eq!(cache.layers[0].num_pages(), 2, "must span two pages");

        // The contiguous accessor refuses multi-page sequences.
        assert!(cache.get_keys(0).is_err());

        let mut positions_seen: Vec<usize> = Vec::new();
        let mut first_elements: Vec<f32> = Vec::new();
        cache
            .for_each_key(0, &mut |pos, slice| {
                positions_seen.push(pos);
                first_elements.push(slice[0]);
            })
            .expect("for_each_key must succeed on paged cache");

        assert_eq!(
            positions_seen.len(),
            seq_len,
            "must visit all {} positions",
            seq_len
        );
        assert_eq!(positions_seen, (0..seq_len).collect::<Vec<_>>());

        for (t, &first) in first_elements.iter().enumerate() {
            let expected = t as f32;
            assert!(
                (first - expected).abs() < 1e-6,
                "token {t}: expected first element {expected}, got {first}"
            );
        }
    }

    /// The trait-level value visitor works across a page boundary.
    #[test]
    fn paged_for_each_value_multi_page() {
        use oxillama_arch::traits::KvCacheAccess;

        let kv_dim = 3usize;
        let seq_len = PAGE_SIZE + 2;
        let mut cache = PagedKvCache::new(1, seq_len + 10, kv_dim);

        for t in 0..seq_len {
            let val = [t as f32, t as f32 * 10.0, t as f32 * 100.0];
            cache.store_kv(0, &[0.0f32; 3], &val).expect("store_kv");
            cache.advance();
        }

        let mut count = 0usize;
        let mut sum_first: f32 = 0.0;
        cache
            .for_each_value(0, &mut |_pos, slice| {
                count += 1;
                sum_first += slice[0];
            })
            .expect("for_each_value must succeed");

        assert_eq!(count, seq_len, "must visit all tokens");
        // 0 + 1 + ... + (seq_len - 1)
        let expected_sum = (seq_len * (seq_len - 1) / 2) as f32;
        assert!(
            (sum_first - expected_sum).abs() < 1e-4,
            "sum of first value elements: expected {expected_sum}, got {sum_first}"
        );
    }
}