Skip to main content

oxillama_runtime/kv_cache/
paged.rs

1//! Paged KV cache implementation.
2//!
3//! Uses fixed-size blocks (pages) for memory-efficient KV storage.
//! Each page holds `PAGE_SIZE` tokens' worth of KV data. Pages are
//! allocated individually on demand as the sequence grows (there is no
//! shared pool yet) and are released by `clear` or `shrink_to_fit`.
7//!
8//! Benefits over contiguous cache:
9//! - Memory-efficient for variable-length sequences
10//! - No wasted memory for sequences shorter than max context
11//! - Foundation for continuous batching (share pool across sequences)
12
13use oxillama_arch::traits::KvCacheAccess;
14use oxillama_arch::ArchResult;
15
/// Number of tokens per page.
///
/// Smaller pages waste less memory in a sequence's last, partially filled
/// page; larger pages mean fewer allocations and less per-page bookkeeping.
const PAGE_SIZE: usize = 16;
19
20/// A single page of KV data for one layer.
21///
22/// Stores `PAGE_SIZE` tokens worth of key or value data.
23/// Each token occupies `kv_dim` floats.
24struct Page {
25    data: Vec<f32>,
26}
27
28impl Page {
29    fn new(kv_dim: usize) -> Self {
30        Self {
31            data: vec![0.0f32; PAGE_SIZE * kv_dim],
32        }
33    }
34
35    /// Write one token's data at the given slot within this page.
36    fn write_token(&mut self, slot: usize, kv_dim: usize, src: &[f32]) {
37        let offset = slot * kv_dim;
38        self.data[offset..offset + kv_dim].copy_from_slice(&src[..kv_dim]);
39    }
40
41    /// Read one token's data at the given slot within this page.
42    fn read_token(&self, slot: usize, kv_dim: usize) -> &[f32] {
43        let offset = slot * kv_dim;
44        &self.data[offset..offset + kv_dim]
45    }
46}
47
48/// Per-layer page table: maps logical page index → physical page.
49struct LayerCache {
50    key_pages: Vec<Page>,
51    value_pages: Vec<Page>,
52}
53
54impl LayerCache {
55    fn new() -> Self {
56        Self {
57            key_pages: Vec::new(),
58            value_pages: Vec::new(),
59        }
60    }
61
62    /// Ensure enough pages are allocated to cover `token_pos` (0-based).
63    fn ensure_capacity(&mut self, token_pos: usize, kv_dim: usize) {
64        let needed_pages = token_pos / PAGE_SIZE + 1;
65        while self.key_pages.len() < needed_pages {
66            self.key_pages.push(Page::new(kv_dim));
67            self.value_pages.push(Page::new(kv_dim));
68        }
69    }
70
71    /// Store a KV pair at the given token position.
72    fn store(&mut self, token_pos: usize, kv_dim: usize, key: &[f32], value: &[f32]) {
73        self.ensure_capacity(token_pos, kv_dim);
74        let page_idx = token_pos / PAGE_SIZE;
75        let slot = token_pos % PAGE_SIZE;
76        self.key_pages[page_idx].write_token(slot, kv_dim, key);
77        self.value_pages[page_idx].write_token(slot, kv_dim, value);
78    }
79
80    /// Get the number of allocated pages.
81    fn num_pages(&self) -> usize {
82        self.key_pages.len()
83    }
84
85    /// Free all pages beyond what's needed for `seq_len` tokens.
86    fn shrink_to(&mut self, seq_len: usize) {
87        let needed = if seq_len == 0 {
88            0
89        } else {
90            seq_len / PAGE_SIZE + 1
91        };
92        self.key_pages.truncate(needed);
93        self.value_pages.truncate(needed);
94    }
95}
96
/// Paged KV cache.
///
/// Memory is allocated in fixed-size pages (blocks) of `PAGE_SIZE` tokens.
/// Pages grow on demand, so short sequences don't waste memory on unused
/// context positions. The `assemble_*` methods copy the paged data into
/// contiguous buffers for attention computation.
pub struct PagedKvCache {
    /// Per-layer page storage, one entry per layer (`num_layers` entries).
    layers: Vec<LayerCache>,
    /// Number of token positions committed so far (incremented by `advance`).
    seq_len: usize,
    /// Maximum sequence length; `store_kv` fails once `seq_len` reaches it.
    max_seq_len: usize,
    /// KV dimension per token (num_kv_heads * head_dim).
    kv_dim: usize,
    /// Number of layers.
    num_layers: usize,
}
115
116impl PagedKvCache {
117    /// Create a new paged KV cache.
118    ///
119    /// Unlike the contiguous cache, this does NOT pre-allocate all memory.
120    /// Pages are allocated on demand as tokens are processed.
121    pub fn new(num_layers: usize, max_seq_len: usize, kv_dim: usize) -> Self {
122        let layers = (0..num_layers).map(|_| LayerCache::new()).collect();
123
124        Self {
125            layers,
126            seq_len: 0,
127            max_seq_len,
128            kv_dim,
129            num_layers,
130        }
131    }
132
133    /// Returns the page size (tokens per page).
134    pub fn page_size(&self) -> usize {
135        PAGE_SIZE
136    }
137
138    /// Returns the maximum sequence length.
139    pub fn max_seq_len(&self) -> usize {
140        self.max_seq_len
141    }
142
143    /// Returns the KV dimension per token.
144    pub fn kv_dim(&self) -> usize {
145        self.kv_dim
146    }
147
148    /// Returns the number of layers.
149    pub fn num_layers(&self) -> usize {
150        self.num_layers
151    }
152
153    /// Returns total number of allocated pages across all layers.
154    pub fn total_pages(&self) -> usize {
155        self.layers.iter().map(|l| l.num_pages()).sum()
156    }
157
158    /// Returns total memory usage in bytes (approximate).
159    pub fn memory_bytes(&self) -> usize {
160        self.total_pages() * PAGE_SIZE * self.kv_dim * 4 * 2 // *2 for K+V, *4 for f32
161    }
162
163    /// Reset the cache, freeing all pages.
164    pub fn clear(&mut self) {
165        self.seq_len = 0;
166        for layer in &mut self.layers {
167            layer.key_pages.clear();
168            layer.value_pages.clear();
169        }
170    }
171
172    /// Shrink allocated pages to fit the current sequence length.
173    /// Useful after trimming context.
174    pub fn shrink_to_fit(&mut self) {
175        for layer in &mut self.layers {
176            layer.shrink_to(self.seq_len);
177        }
178    }
179
180    /// Assemble contiguous key data for a layer into the provided buffer.
181    ///
182    /// Copies from paged storage into a flat `[seq_len * kv_dim]` buffer.
183    fn assemble_keys(&self, layer: usize, buf: &mut Vec<f32>) {
184        let total = self.seq_len * self.kv_dim;
185        buf.clear();
186        buf.reserve(total);
187
188        let layer_cache = &self.layers[layer];
189        for pos in 0..self.seq_len {
190            let page_idx = pos / PAGE_SIZE;
191            let slot = pos % PAGE_SIZE;
192            let token_data = layer_cache.key_pages[page_idx].read_token(slot, self.kv_dim);
193            buf.extend_from_slice(token_data);
194        }
195    }
196
197    /// Assemble contiguous value data for a layer into the provided buffer.
198    fn assemble_values(&self, layer: usize, buf: &mut Vec<f32>) {
199        let total = self.seq_len * self.kv_dim;
200        buf.clear();
201        buf.reserve(total);
202
203        let layer_cache = &self.layers[layer];
204        for pos in 0..self.seq_len {
205            let page_idx = pos / PAGE_SIZE;
206            let slot = pos % PAGE_SIZE;
207            let token_data = layer_cache.value_pages[page_idx].read_token(slot, self.kv_dim);
208            buf.extend_from_slice(token_data);
209        }
210    }
211}
212
213impl KvCacheAccess for PagedKvCache {
214    fn seq_len(&self) -> usize {
215        self.seq_len
216    }
217
218    fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()> {
219        if layer >= self.num_layers {
220            return Err(oxillama_arch::ArchError::ForwardPassError {
221                layer,
222                message: format!("layer index {layer} out of range (max {})", self.num_layers),
223            });
224        }
225
226        if self.seq_len >= self.max_seq_len {
227            return Err(oxillama_arch::ArchError::ForwardPassError {
228                layer,
229                message: format!(
230                    "sequence length {} exceeds max {}",
231                    self.seq_len, self.max_seq_len
232                ),
233            });
234        }
235
236        self.layers[layer].store(self.seq_len, self.kv_dim, key, value);
237        Ok(())
238    }
239
240    fn get_keys(&self, layer: usize) -> ArchResult<&[f32]> {
241        if layer >= self.num_layers {
242            return Err(oxillama_arch::ArchError::ForwardPassError {
243                layer,
244                message: format!("layer index {layer} out of range (max {})", self.num_layers),
245            });
246        }
247
248        // For now, we need to return a contiguous slice. The paged layout
249        // means we can't return a zero-copy slice if data spans multiple pages.
250        // This is a known limitation — the trait will need to evolve for
251        // truly zero-copy paged access (page-aware attention kernels).
252        //
253        // SAFETY: We use interior mutability via the assemble buffer approach.
254        // Since we can't mutate &self, we return a reference to assembled data
255        // that lives in the pages themselves when seq_len fits in one page.
256        if self.seq_len == 0 {
257            return Ok(&[]);
258        }
259
260        // Fast path: all data fits in a single page — return zero-copy slice
261        let pages_used = (self.seq_len - 1) / PAGE_SIZE + 1;
262        if pages_used == 1 {
263            let end = self.seq_len * self.kv_dim;
264            return Ok(&self.layers[layer].key_pages[0].data[..end]);
265        }
266
267        // Multi-page: we can't return a contiguous &[f32] without copying.
268        // This is a fundamental limitation of returning &[f32] from paged storage.
269        // For now, panic with a message pointing to the solution.
270        // TODO: Change trait to support page-aware iteration or accept a callback.
271        Err(oxillama_arch::ArchError::ForwardPassError {
272            layer,
273            message: format!(
274                "paged KV cache: sequence length {} spans {} pages; \
275                 use get_keys_into() for multi-page access",
276                self.seq_len, pages_used
277            ),
278        })
279    }
280
281    fn get_values(&self, layer: usize) -> ArchResult<&[f32]> {
282        if layer >= self.num_layers {
283            return Err(oxillama_arch::ArchError::ForwardPassError {
284                layer,
285                message: format!("layer index {layer} out of range (max {})", self.num_layers),
286            });
287        }
288
289        if self.seq_len == 0 {
290            return Ok(&[]);
291        }
292
293        let pages_used = (self.seq_len - 1) / PAGE_SIZE + 1;
294        if pages_used == 1 {
295            let end = self.seq_len * self.kv_dim;
296            return Ok(&self.layers[layer].value_pages[0].data[..end]);
297        }
298
299        Err(oxillama_arch::ArchError::ForwardPassError {
300            layer,
301            message: format!(
302                "paged KV cache: sequence length {} spans {} pages; \
303                 use get_values_into() for multi-page access",
304                self.seq_len, pages_used
305            ),
306        })
307    }
308
309    fn advance(&mut self) {
310        if self.seq_len < self.max_seq_len {
311            self.seq_len += 1;
312        }
313    }
314
315    fn kv_dim(&self) -> usize {
316        self.kv_dim
317    }
318
319    fn for_each_key(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
320        self.iter_keys(layer, |pos, slice| f(pos, slice))
321    }
322
323    fn for_each_value(&self, layer: usize, f: &mut dyn FnMut(usize, &[f32])) -> ArchResult<()> {
324        self.iter_values(layer, |pos, slice| f(pos, slice))
325    }
326}
327
328/// Extended paged KV cache operations (not part of the base trait).
329impl PagedKvCache {
330    /// Copy all cached keys for a layer into a contiguous buffer.
331    ///
332    /// This is the multi-page alternative to `get_keys()`. The caller
333    /// provides a reusable buffer to avoid repeated allocation.
334    pub fn get_keys_into(&self, layer: usize, buf: &mut Vec<f32>) -> ArchResult<()> {
335        if layer >= self.num_layers {
336            return Err(oxillama_arch::ArchError::ForwardPassError {
337                layer,
338                message: format!("layer index {layer} out of range (max {})", self.num_layers),
339            });
340        }
341        self.assemble_keys(layer, buf);
342        Ok(())
343    }
344
345    /// Copy all cached values for a layer into a contiguous buffer.
346    pub fn get_values_into(&self, layer: usize, buf: &mut Vec<f32>) -> ArchResult<()> {
347        if layer >= self.num_layers {
348            return Err(oxillama_arch::ArchError::ForwardPassError {
349                layer,
350                message: format!("layer index {layer} out of range (max {})", self.num_layers),
351            });
352        }
353        self.assemble_values(layer, buf);
354        Ok(())
355    }
356
357    /// Read a specific token's key data from the cache.
358    pub fn get_key_token(&self, layer: usize, pos: usize) -> ArchResult<&[f32]> {
359        if layer >= self.num_layers {
360            return Err(oxillama_arch::ArchError::ForwardPassError {
361                layer,
362                message: format!("layer index {layer} out of range (max {})", self.num_layers),
363            });
364        }
365        if pos >= self.seq_len {
366            return Err(oxillama_arch::ArchError::ForwardPassError {
367                layer,
368                message: format!("position {pos} out of range (seq_len {})", self.seq_len),
369            });
370        }
371        let page_idx = pos / PAGE_SIZE;
372        let slot = pos % PAGE_SIZE;
373        Ok(self.layers[layer].key_pages[page_idx].read_token(slot, self.kv_dim))
374    }
375
376    /// Read a specific token's value data from the cache.
377    pub fn get_value_token(&self, layer: usize, pos: usize) -> ArchResult<&[f32]> {
378        if layer >= self.num_layers {
379            return Err(oxillama_arch::ArchError::ForwardPassError {
380                layer,
381                message: format!("layer index {layer} out of range (max {})", self.num_layers),
382            });
383        }
384        if pos >= self.seq_len {
385            return Err(oxillama_arch::ArchError::ForwardPassError {
386                layer,
387                message: format!("position {pos} out of range (seq_len {})", self.seq_len),
388            });
389        }
390        let page_idx = pos / PAGE_SIZE;
391        let slot = pos % PAGE_SIZE;
392        Ok(self.layers[layer].value_pages[page_idx].read_token(slot, self.kv_dim))
393    }
394
395    /// Iterate over key tokens for a layer, calling `f` for each (pos, key_data).
396    pub fn iter_keys<F>(&self, layer: usize, mut f: F) -> ArchResult<()>
397    where
398        F: FnMut(usize, &[f32]),
399    {
400        if layer >= self.num_layers {
401            return Err(oxillama_arch::ArchError::ForwardPassError {
402                layer,
403                message: format!("layer index {layer} out of range (max {})", self.num_layers),
404            });
405        }
406        let layer_cache = &self.layers[layer];
407        for pos in 0..self.seq_len {
408            let page_idx = pos / PAGE_SIZE;
409            let slot = pos % PAGE_SIZE;
410            let data = layer_cache.key_pages[page_idx].read_token(slot, self.kv_dim);
411            f(pos, data);
412        }
413        Ok(())
414    }
415
416    /// Iterate over value tokens for a layer.
417    pub fn iter_values<F>(&self, layer: usize, mut f: F) -> ArchResult<()>
418    where
419        F: FnMut(usize, &[f32]),
420    {
421        if layer >= self.num_layers {
422            return Err(oxillama_arch::ArchError::ForwardPassError {
423                layer,
424                message: format!("layer index {layer} out of range (max {})", self.num_layers),
425            });
426        }
427        let layer_cache = &self.layers[layer];
428        for pos in 0..self.seq_len {
429            let page_idx = pos / PAGE_SIZE;
430            let slot = pos % PAGE_SIZE;
431            let data = layer_cache.value_pages[page_idx].read_token(slot, self.kv_dim);
432            f(pos, data);
433        }
434        Ok(())
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    // Store a token for layer 0 only, then read it back through the
    // single-page zero-copy accessors.
    #[test]
    fn test_paged_basic_store_retrieve() {
        let mut cache = PagedKvCache::new(2, 64, 4);
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.total_pages(), 0);

        // Store first token in layer 0
        let key = [1.0, 2.0, 3.0, 4.0];
        let val = [5.0, 6.0, 7.0, 8.0];
        cache.store_kv(0, &key, &val).unwrap();
        cache.advance();

        assert_eq!(cache.seq_len(), 1);
        // Only layer 0 received a store, so only layer 0 has a page allocated
        assert_eq!(cache.layers[0].num_pages(), 1);
        assert_eq!(cache.layers[1].num_pages(), 0);

        // Retrieve via single-page fast path
        let keys = cache.get_keys(0).unwrap();
        assert_eq!(keys, &[1.0, 2.0, 3.0, 4.0]);

        let vals = cache.get_values(0).unwrap();
        assert_eq!(vals, &[5.0, 6.0, 7.0, 8.0]);
    }

    // Exactly one full page of tokens must still be served zero-copy.
    #[test]
    fn test_paged_multi_token_single_page() {
        let mut cache = PagedKvCache::new(1, 64, 2);

        // Store PAGE_SIZE tokens (all fit in one page)
        for i in 0..PAGE_SIZE {
            let key = [i as f32, (i * 10) as f32];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }

        assert_eq!(cache.seq_len(), PAGE_SIZE);
        assert_eq!(cache.layers[0].num_pages(), 1);

        // Should still work via fast path (single page)
        let keys = cache.get_keys(0).unwrap();
        assert_eq!(keys.len(), PAGE_SIZE * 2);
        // Token i's key is [i, 10 * i]; token 1 starts at index 2.
        assert_eq!(keys[0], 0.0);
        assert_eq!(keys[1], 0.0);
        assert_eq!(keys[2], 1.0);
        assert_eq!(keys[3], 10.0);
    }

    // One token past a page boundary: zero-copy access fails, copying works.
    #[test]
    fn test_paged_multi_page_assembly() {
        let mut cache = PagedKvCache::new(1, 64, 2);

        // Store PAGE_SIZE + 1 tokens (spans 2 pages)
        for i in 0..=PAGE_SIZE {
            let key = [i as f32, (i * 10) as f32];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }

        assert_eq!(cache.seq_len(), PAGE_SIZE + 1);
        assert_eq!(cache.layers[0].num_pages(), 2);

        // get_keys returns error for multi-page
        assert!(cache.get_keys(0).is_err());

        // Use get_keys_into for multi-page access
        let mut buf = Vec::new();
        cache.get_keys_into(0, &mut buf).unwrap();
        assert_eq!(buf.len(), (PAGE_SIZE + 1) * 2);

        // Verify first and last tokens (the last one lives on the second page)
        assert_eq!(buf[0], 0.0);
        assert_eq!(buf[1], 0.0);
        let last_off = PAGE_SIZE * 2;
        assert_eq!(buf[last_off], PAGE_SIZE as f32);
        assert_eq!(buf[last_off + 1], (PAGE_SIZE * 10) as f32);
    }

    // Per-token reads on both sides of the page boundary, plus out-of-range.
    #[test]
    fn test_paged_per_token_access() {
        let mut cache = PagedKvCache::new(1, 64, 3);

        for i in 0..20 {
            let key = [i as f32, (i * 2) as f32, (i * 3) as f32];
            let val = [(i + 50) as f32, (i + 60) as f32, (i + 70) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }

        // Access specific tokens (5 is on the first page, 17 on the second)
        let k5 = cache.get_key_token(0, 5).unwrap();
        assert_eq!(k5, &[5.0, 10.0, 15.0]);

        let v17 = cache.get_value_token(0, 17).unwrap();
        assert_eq!(v17, &[67.0, 77.0, 87.0]);

        // Out of range
        assert!(cache.get_key_token(0, 20).is_err());
    }

    // iter_keys must visit every position in order, across pages.
    #[test]
    fn test_paged_iteration() {
        let mut cache = PagedKvCache::new(1, 64, 2);

        for i in 0..20 {
            let key = [i as f32, (i + 1) as f32];
            let val = [(i + 100) as f32, (i + 101) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }

        let mut count = 0;
        cache
            .iter_keys(0, |pos, data| {
                assert_eq!(data[0], pos as f32);
                assert_eq!(data[1], (pos + 1) as f32);
                count += 1;
            })
            .unwrap();
        assert_eq!(count, 20);
    }

    // clear() drops every page in every layer and resets the position.
    #[test]
    fn test_paged_clear() {
        let mut cache = PagedKvCache::new(2, 64, 4);

        for i in 0..20 {
            let key = [i as f32; 4];
            let val = [i as f32; 4];
            cache.store_kv(0, &key, &val).unwrap();
            cache.store_kv(1, &key, &val).unwrap();
            cache.advance();
        }

        assert!(cache.total_pages() > 0);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.total_pages(), 0);
    }

    // shrink_to_fit() releases pages beyond a (manually) reduced seq_len.
    #[test]
    fn test_paged_shrink_to_fit() {
        let mut cache = PagedKvCache::new(1, 128, 4);

        // Fill 40 tokens (3 pages)
        for i in 0..40 {
            cache.store_kv(0, &[i as f32; 4], &[i as f32; 4]).unwrap();
            cache.advance();
        }
        assert_eq!(cache.layers[0].num_pages(), 3);

        // Manually reduce seq_len to simulate context trimming
        cache.seq_len = 10;
        cache.shrink_to_fit();
        // 10 tokens → 1 page needed
        assert_eq!(cache.layers[0].num_pages(), 1);
    }

    // A short sequence should use far less memory than a fully
    // pre-allocated contiguous cache of the same shape.
    #[test]
    fn test_paged_memory_efficiency() {
        // Compare memory: contiguous pre-allocates everything,
        // paged only allocates what's used.
        let num_layers = 32;
        let max_seq = 4096;
        let kv_dim = 128;

        let contiguous_bytes = num_layers * max_seq * kv_dim * 4 * 2; // K+V

        let mut cache = PagedKvCache::new(num_layers, max_seq, kv_dim);
        // Store just 10 tokens
        for i in 0..10 {
            for layer in 0..num_layers {
                cache
                    .store_kv(layer, &vec![i as f32; kv_dim], &vec![i as f32; kv_dim])
                    .unwrap();
            }
            cache.advance();
        }

        let paged_bytes = cache.memory_bytes();
        // Paged should use much less memory than contiguous
        assert!(
            paged_bytes < contiguous_bytes / 10,
            "paged={paged_bytes} should be << contiguous={contiguous_bytes}"
        );
    }

    // store_kv must refuse to write once max_seq_len tokens are committed.
    #[test]
    fn test_paged_max_seq_len_error() {
        let mut cache = PagedKvCache::new(1, 2, 2);

        cache.store_kv(0, &[1.0, 2.0], &[3.0, 4.0]).unwrap();
        cache.advance();
        cache.store_kv(0, &[5.0, 6.0], &[7.0, 8.0]).unwrap();
        cache.advance();

        // Should fail — at max
        let result = cache.store_kv(0, &[9.0, 10.0], &[11.0, 12.0]);
        assert!(result.is_err());
    }

    // ── for_each_key / for_each_value override ───────────────────────────────

    // The trait's page-aware iteration must work where get_keys() cannot.
    #[test]
    fn paged_for_each_key_multi_page() {
        use oxillama_arch::traits::KvCacheAccess;

        let kv_dim = 2usize;
        // Fill PAGE_SIZE + 4 tokens so we span two pages.
        let seq_len = PAGE_SIZE + 4;
        let mut cache = PagedKvCache::new(1, seq_len + 10, kv_dim);

        for t in 0..seq_len {
            let key = [t as f32, t as f32 * 2.0];
            let val = [0.0f32; 2];
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        assert_eq!(cache.seq_len(), seq_len);
        assert_eq!(cache.layers[0].num_pages(), 2, "must span two pages");

        // get_keys() returns error for multi-page
        assert!(cache.get_keys(0).is_err());

        // for_each_key must visit all tokens via page-aware path
        let mut positions_seen: Vec<usize> = Vec::new();
        let mut first_elements: Vec<f32> = Vec::new();
        cache
            .for_each_key(0, &mut |pos, slice| {
                positions_seen.push(pos);
                first_elements.push(slice[0]);
            })
            .expect("for_each_key must succeed on paged cache");

        assert_eq!(
            positions_seen.len(),
            seq_len,
            "must visit all {} positions",
            seq_len
        );
        assert_eq!(positions_seen, (0..seq_len).collect::<Vec<_>>());

        // Verify key data at every position (first element equals the token index)
        for (t, &first) in first_elements.iter().enumerate() {
            let expected = t as f32;
            assert!(
                (first - expected).abs() < 1e-6,
                "token {t}: expected first element {expected}, got {first}"
            );
        }
    }

    // Same as above for values, checked via count and a running sum.
    #[test]
    fn paged_for_each_value_multi_page() {
        use oxillama_arch::traits::KvCacheAccess;

        let kv_dim = 3usize;
        let seq_len = PAGE_SIZE + 2;
        let mut cache = PagedKvCache::new(1, seq_len + 10, kv_dim);

        for t in 0..seq_len {
            let key = [0.0f32; 3];
            let val = [t as f32, t as f32 * 10.0, t as f32 * 100.0];
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        let mut count = 0usize;
        let mut sum_first: f32 = 0.0;
        cache
            .for_each_value(0, &mut |_pos, slice| {
                count += 1;
                sum_first += slice[0];
            })
            .expect("for_each_value must succeed");

        assert_eq!(count, seq_len, "must visit all tokens");
        // sum of 0..seq_len = seq_len*(seq_len-1)/2
        let expected_sum = (seq_len * (seq_len - 1) / 2) as f32;
        assert!(
            (sum_first - expected_sum).abs() < 1e-4,
            "sum of first value elements: expected {expected_sum}, got {sum_first}"
        );
    }
}