Skip to main content

oxibonsai_model/
kv_cache.rs

1//! KV Cache for autoregressive generation.
2//!
3//! Stores key and value tensors for each layer to avoid recomputation
4//! during token-by-token generation. Provides both a standard contiguous
5//! cache and a page-based cache for memory-efficient allocation.
6
/// Storage strategy selected for a KV cache.
///
/// [`KvCachePolicy::Standard`] is the variant produced by `Default`.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum KvCachePolicy {
    /// Contiguous FP32 storage; this is the default policy.
    #[default]
    Standard,
    /// FP16 storage, using half the memory of [`KvCachePolicy::Standard`].
    Fp16,
    /// Keep only the most recent `N` positions, where `N` is the payload.
    SlidingWindow(usize),
}
18
/// Per-layer KV cache storing FP32 key and value vectors.
///
/// All storage is allocated up front in two flat buffers laid out as
/// `[layer][kv_head][position][dim]`, so the positions belonging to one
/// `(layer, head)` pair are contiguous in memory.
#[derive(Debug)]
pub struct KvCache {
    /// Number of Transformer layers.
    num_layers: usize,
    /// Number of KV heads per layer.
    num_kv_heads: usize,
    /// Dimension per head.
    head_dim: usize,
    /// Maximum sequence length.
    max_seq_len: usize,
    /// Current sequence length (number of tokens cached); never exceeds `max_seq_len`.
    seq_len: usize,
    /// Key cache: [num_layers × num_kv_heads × max_seq_len × head_dim].
    keys: Vec<f32>,
    /// Value cache: [num_layers × num_kv_heads × max_seq_len × head_dim].
    values: Vec<f32>,
}

impl KvCache {
    /// Create a new KV cache with all storage zero-initialised.
    ///
    /// # Panics
    ///
    /// Panics if `num_layers * num_kv_heads * max_seq_len * head_dim`
    /// overflows `usize`.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        // Checked so that a wrapped product cannot produce a too-small buffer
        // in release builds (which would make every offset computation wrong).
        let total = num_layers
            .checked_mul(num_kv_heads)
            .and_then(|n| n.checked_mul(max_seq_len))
            .and_then(|n| n.checked_mul(head_dim))
            .expect("KvCache dimensions overflow usize");
        Self {
            num_layers,
            num_kv_heads,
            head_dim,
            max_seq_len,
            seq_len: 0,
            keys: vec![0.0; total],
            values: vec![0.0; total],
        }
    }

    /// Current number of cached tokens.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Maximum sequence length.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Store a key vector for a specific layer, head, and position.
    ///
    /// # Panics
    ///
    /// Panics if `layer`, `head` or `pos` is out of range, or if
    /// `key.len() != head_dim`.
    pub fn store_key(&mut self, layer: usize, head: usize, pos: usize, key: &[f32]) {
        let slot = self.slot_range(layer, head, pos);
        debug_assert_eq!(key.len(), self.head_dim);
        self.keys[slot].copy_from_slice(key);
    }

    /// Store a value vector for a specific layer, head, and position.
    ///
    /// # Panics
    ///
    /// Panics if `layer`, `head` or `pos` is out of range, or if
    /// `value.len() != head_dim`.
    pub fn store_value(&mut self, layer: usize, head: usize, pos: usize, value: &[f32]) {
        let slot = self.slot_range(layer, head, pos);
        debug_assert_eq!(value.len(), self.head_dim);
        self.values[slot].copy_from_slice(value);
    }

    /// Get all cached keys for a layer and head up to `seq_len`.
    ///
    /// Returns a slice of [seq_len × head_dim] in row-major order.
    ///
    /// # Panics
    ///
    /// Panics if `layer` or `head` is out of range or `seq_len > max_seq_len`.
    pub fn keys_for(&self, layer: usize, head: usize, seq_len: usize) -> &[f32] {
        &self.keys[self.prefix_range(layer, head, seq_len)]
    }

    /// Get all cached values for a layer and head up to `seq_len`.
    ///
    /// Returns a slice of [seq_len × head_dim] in row-major order.
    ///
    /// # Panics
    ///
    /// Panics if `layer` or `head` is out of range or `seq_len > max_seq_len`.
    pub fn values_for(&self, layer: usize, head: usize, seq_len: usize) -> &[f32] {
        &self.values[self.prefix_range(layer, head, seq_len)]
    }

    /// Advance the sequence position by one token.
    ///
    /// Saturates at `max_seq_len`, matching the clamping performed by
    /// [`set_seq_len`](Self::set_seq_len) and keeping
    /// [`utilization_ratio`](Self::utilization_ratio) within `[0.0, 1.0]`.
    pub fn advance(&mut self) {
        // `seq_len <= max_seq_len` always holds, so `+ 1` cannot overflow.
        self.seq_len = (self.seq_len + 1).min(self.max_seq_len);
    }

    /// Reset the cache (clear all stored KV pairs).
    ///
    /// Only the sequence length is reset; the buffers are not zeroed, so
    /// stale data remains until it is overwritten.
    pub fn clear(&mut self) {
        self.seq_len = 0;
    }

    /// Compute flat offset into cache arrays.
    fn cache_offset(&self, layer: usize, head: usize, pos: usize) -> usize {
        ((layer * self.num_kv_heads + head) * self.max_seq_len + pos) * self.head_dim
    }

    /// Panic unless `layer` and `head` are within the cache dimensions.
    ///
    /// These are hard asserts rather than `debug_assert!`s: `cache_offset`
    /// maps an out-of-range head onto the next layer's storage, which would
    /// silently read or overwrite the wrong data in release builds.
    fn check_layer_head(&self, layer: usize, head: usize) {
        assert!(
            layer < self.num_layers,
            "layer {layer} out of range (num_layers = {})",
            self.num_layers
        );
        assert!(
            head < self.num_kv_heads,
            "head {head} out of range (num_kv_heads = {})",
            self.num_kv_heads
        );
    }

    /// Flat range of the `head_dim` floats holding `(layer, head, pos)`.
    fn slot_range(&self, layer: usize, head: usize, pos: usize) -> std::ops::Range<usize> {
        self.check_layer_head(layer, head);
        // An out-of-range position would alias the next head's position 0.
        assert!(
            pos < self.max_seq_len,
            "position {pos} out of range (max_seq_len = {})",
            self.max_seq_len
        );
        let start = self.cache_offset(layer, head, pos);
        start..start + self.head_dim
    }

    /// Flat range covering positions `0..seq_len` of `(layer, head)`.
    fn prefix_range(&self, layer: usize, head: usize, seq_len: usize) -> std::ops::Range<usize> {
        self.check_layer_head(layer, head);
        // A longer prefix would run into the next head's storage.
        assert!(
            seq_len <= self.max_seq_len,
            "seq_len {seq_len} exceeds max_seq_len {}",
            self.max_seq_len
        );
        let start = self.cache_offset(layer, head, 0);
        start..start + seq_len * self.head_dim
    }

    /// Number of positions of `start_pos..start_pos + block_size` that lie
    /// inside the cache capacity.
    fn valid_block_len(&self, start_pos: usize, block_size: usize) -> usize {
        self.max_seq_len.saturating_sub(start_pos).min(block_size)
    }

    /// Total memory used by this cache in bytes.
    pub fn memory_bytes(&self) -> usize {
        (self.keys.len() + self.values.len()) * std::mem::size_of::<f32>()
    }

    /// Utilization ratio: fraction of cache capacity currently used.
    ///
    /// Returns a value in [0.0, 1.0] (0.0 for a zero-capacity cache).
    pub fn utilization_ratio(&self) -> f64 {
        if self.max_seq_len == 0 {
            return 0.0;
        }
        self.seq_len as f64 / self.max_seq_len as f64
    }

    /// Number of layers in this cache.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Number of KV heads per layer.
    pub fn num_kv_heads(&self) -> usize {
        self.num_kv_heads
    }

    /// Head dimension.
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Manually set the cached sequence length.
    ///
    /// Used by the prefix-cache integration when restoring previously
    /// computed KV blocks: after [`inject_block`](Self::inject_block) writes
    /// the block contents, the consumer must call this to advertise the
    /// number of valid positions to subsequent attention computations.
    ///
    /// `n` is clamped to `max_seq_len`.
    pub fn set_seq_len(&mut self, n: usize) {
        self.seq_len = n.min(self.max_seq_len);
    }

    /// Extract one prefix-cache block worth of KV for a single layer.
    ///
    /// Reads `block_size` consecutive positions starting at `start_pos` for
    /// every KV head in `layer` and returns them in `[head][pos_in_block][dim]`
    /// order, packed as a flat `Vec<f32>` of length
    /// `num_kv_heads * block_size * head_dim`.
    ///
    /// Mirrors the layout used by [`crate::prefix_cache::CacheBlock`].
    ///
    /// Returns `(keys, values)`. If the requested range exceeds
    /// `max_seq_len`, the trailing positions are returned as zeros.
    pub fn extract_block(
        &self,
        layer: usize,
        start_pos: usize,
        block_size: usize,
    ) -> (Vec<f32>, Vec<f32>) {
        debug_assert!(layer < self.num_layers);
        let per_layer = self.num_kv_heads * block_size * self.head_dim;
        let mut keys = vec![0.0f32; per_layer];
        let mut values = vec![0.0f32; per_layer];

        let valid = self.valid_block_len(start_pos, block_size);
        if valid == 0 {
            // Nothing inside capacity; also avoids computing an offset for
            // a `start_pos` beyond the buffer.
            return (keys, values);
        }

        // A head's positions are contiguous in the cache, so each head's
        // valid prefix of the block is a single run of `valid * head_dim` floats.
        let run = valid * self.head_dim;
        for head in 0..self.num_kv_heads {
            let src = self.cache_offset(layer, head, start_pos);
            let dst = head * block_size * self.head_dim;
            keys[dst..dst + run].copy_from_slice(&self.keys[src..src + run]);
            values[dst..dst + run].copy_from_slice(&self.values[src..src + run]);
        }

        (keys, values)
    }

    /// Inject a previously extracted block back into the cache for a single layer.
    ///
    /// `keys` and `values` must have the same `[head][pos_in_block][dim]`
    /// layout produced by [`extract_block`](Self::extract_block); they are
    /// expected to be of length `num_kv_heads * block_size * head_dim`.
    /// Positions outside `max_seq_len` are silently skipped.
    pub fn inject_block(
        &mut self,
        layer: usize,
        start_pos: usize,
        block_size: usize,
        keys: &[f32],
        values: &[f32],
    ) {
        debug_assert!(layer < self.num_layers);
        let per_layer = self.num_kv_heads * block_size * self.head_dim;
        debug_assert_eq!(keys.len(), per_layer);
        debug_assert_eq!(values.len(), per_layer);

        let valid = self.valid_block_len(start_pos, block_size);
        if valid == 0 {
            return;
        }

        // Mirror of `extract_block`: one contiguous run per head.
        let run = valid * self.head_dim;
        for head in 0..self.num_kv_heads {
            let src = head * block_size * self.head_dim;
            let dst = self.cache_offset(layer, head, start_pos);
            self.keys[dst..dst + run].copy_from_slice(&keys[src..src + run]);
            self.values[dst..dst + run].copy_from_slice(&values[src..src + run]);
        }
    }
}
239
240// ──────────────────────────────────────────────────────────────────
241// Paged KV Cache
242// ──────────────────────────────────────────────────────────────────
243
/// Default number of positions per page.
const DEFAULT_PAGE_SIZE: usize = 256;

/// A single page in the paged KV cache.
///
/// Each page holds `page_size` positions worth of key and value data
/// for a single layer and head.
#[derive(Debug, Clone)]
struct KvPage {
    /// Key data: [page_size * head_dim] floats.
    keys: Vec<f32>,
    /// Value data: [page_size * head_dim] floats.
    values: Vec<f32>,
    /// High-water mark: highest written offset in this page plus one.
    used: usize,
}

impl KvPage {
    /// Allocate a zero-filled page of `page_size` positions.
    fn new(page_size: usize, head_dim: usize) -> Self {
        Self {
            keys: vec![0.0; page_size * head_dim],
            values: vec![0.0; page_size * head_dim],
            used: 0,
        }
    }
}

/// Page-based KV cache for memory-efficient allocation.
///
/// Instead of pre-allocating the full `max_seq_len` contiguously,
/// pages of `page_size` positions are allocated on demand. This is
/// beneficial when the actual sequence length is much shorter than
/// `max_seq_len`.
#[derive(Debug)]
pub struct PagedKvCache {
    /// Pages indexed as [layer][head][page_index].
    pages: Vec<Vec<Vec<KvPage>>>,
    /// Number of transformer layers.
    num_layers: usize,
    /// Number of KV heads per layer.
    num_kv_heads: usize,
    /// Dimension per head.
    head_dim: usize,
    /// Positions per page (always non-zero).
    page_size: usize,
    /// Maximum sequence length (total capacity).
    max_seq_len: usize,
    /// Current sequence length.
    seq_len: usize,
}

impl PagedKvCache {
    /// Create a new paged KV cache using `DEFAULT_PAGE_SIZE` positions per page.
    ///
    /// Pages are allocated lazily as positions are stored.
    pub fn new(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
    ) -> Self {
        Self::with_page_size(
            num_layers,
            num_kv_heads,
            head_dim,
            max_seq_len,
            DEFAULT_PAGE_SIZE,
        )
    }

    /// Create a new paged KV cache with a custom page size.
    ///
    /// # Panics
    ///
    /// Panics if `page_size` is zero (positions are mapped to pages by
    /// dividing by `page_size`).
    pub fn with_page_size(
        num_layers: usize,
        num_kv_heads: usize,
        head_dim: usize,
        max_seq_len: usize,
        page_size: usize,
    ) -> Self {
        assert!(page_size > 0, "PagedKvCache page_size must be non-zero");
        let pages = (0..num_layers)
            .map(|_| (0..num_kv_heads).map(|_| Vec::new()).collect())
            .collect();

        Self {
            pages,
            num_layers,
            num_kv_heads,
            head_dim,
            page_size,
            max_seq_len,
            seq_len: 0,
        }
    }

    /// Store a key vector for a specific layer, head, and position.
    ///
    /// Allocates the page holding `pos` (and any earlier missing pages).
    pub fn store_key(&mut self, layer: usize, head: usize, pos: usize, key: &[f32]) {
        self.write_slot(layer, head, pos, key, |page| page.keys.as_mut_slice());
    }

    /// Store a value vector for a specific layer, head, and position.
    ///
    /// Allocates the page holding `pos` (and any earlier missing pages).
    pub fn store_value(&mut self, layer: usize, head: usize, pos: usize, value: &[f32]) {
        self.write_slot(layer, head, pos, value, |page| page.values.as_mut_slice());
    }

    /// Get all cached keys for a layer and head up to `seq_len`, assembled into a contiguous buffer.
    ///
    /// Positions not covered by an allocated page are returned as zeros.
    pub fn keys_for(&self, layer: usize, head: usize, seq_len: usize) -> Vec<f32> {
        self.gather(layer, head, seq_len, |page| page.keys.as_slice())
    }

    /// Get all cached values for a layer and head up to `seq_len`.
    ///
    /// Positions not covered by an allocated page are returned as zeros.
    pub fn values_for(&self, layer: usize, head: usize, seq_len: usize) -> Vec<f32> {
        self.gather(layer, head, seq_len, |page| page.values.as_slice())
    }

    /// Current sequence length.
    pub fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Advance the sequence position by one token.
    ///
    /// Does not check against `max_seq_len`.
    pub fn advance(&mut self) {
        self.seq_len += 1;
    }

    /// Reset the cache (deallocate all pages).
    pub fn clear(&mut self) {
        self.seq_len = 0;
        for head_pages in self.pages.iter_mut().flatten() {
            head_pages.clear();
        }
    }

    /// Total memory currently allocated by this cache in bytes.
    ///
    /// Only counts allocated pages, not the full capacity.
    pub fn memory_usage_bytes(&self) -> usize {
        // Each page has keys + values, each of page_size * head_dim floats.
        self.total_pages() * self.page_size * self.head_dim * std::mem::size_of::<f32>() * 2
    }

    /// Utilization ratio: fraction of allocated page slots that are used.
    ///
    /// A slot counts as used if it lies at or below a page's high-water mark
    /// (see `KvPage::used`). Returns 0.0 when no pages are allocated.
    pub fn utilization_ratio(&self) -> f64 {
        let (used_slots, total_slots) = self
            .pages
            .iter()
            .flatten()
            .flatten()
            .fold((0usize, 0usize), |(used, total), page| {
                (used + page.used, total + self.page_size)
            });
        if total_slots == 0 {
            return 0.0;
        }
        used_slots as f64 / total_slots as f64
    }

    /// Total number of pages allocated.
    pub fn total_pages(&self) -> usize {
        self.pages
            .iter()
            .flatten()
            .map(|head_pages| head_pages.len())
            .sum()
    }

    /// Page size (positions per page).
    pub fn page_size(&self) -> usize {
        self.page_size
    }

    /// Ensure a page exists at the given index, allocating it if needed.
    ///
    /// Pages are kept dense: all pages before `page_idx` are allocated too,
    /// so page `i` always covers positions `i * page_size..(i + 1) * page_size`.
    fn ensure_page(&mut self, layer: usize, head: usize, page_idx: usize) {
        let (page_size, head_dim) = (self.page_size, self.head_dim);
        let head_pages = &mut self.pages[layer][head];
        while head_pages.len() <= page_idx {
            head_pages.push(KvPage::new(page_size, head_dim));
        }
    }

    /// Write `data` into the key or value slot (chosen by `select`) of
    /// `(layer, head, pos)`, allocating pages as needed and updating the
    /// page's high-water mark.
    fn write_slot(
        &mut self,
        layer: usize,
        head: usize,
        pos: usize,
        data: &[f32],
        select: fn(&mut KvPage) -> &mut [f32],
    ) {
        debug_assert!(layer < self.num_layers);
        debug_assert!(head < self.num_kv_heads);
        debug_assert!(pos < self.max_seq_len);
        debug_assert_eq!(data.len(), self.head_dim);

        let page_idx = pos / self.page_size;
        let offset_in_page = pos % self.page_size;
        self.ensure_page(layer, head, page_idx);

        let head_dim = self.head_dim;
        let page = &mut self.pages[layer][head][page_idx];
        let start = offset_in_page * head_dim;
        select(page)[start..start + head_dim].copy_from_slice(data);
        page.used = page.used.max(offset_in_page + 1);
    }

    /// Concatenate positions `0..seq_len` of the buffer chosen by `select`
    /// across the pages of `(layer, head)`.
    fn gather(
        &self,
        layer: usize,
        head: usize,
        seq_len: usize,
        select: fn(&KvPage) -> &[f32],
    ) -> Vec<f32> {
        let total = seq_len * self.head_dim;
        let mut result = Vec::with_capacity(total);
        let mut remaining = seq_len;

        // Pages are dense (see `ensure_page`), so copying whole page prefixes
        // in order yields the positions in order.
        for page in &self.pages[layer][head] {
            if remaining == 0 {
                break;
            }
            let take = remaining.min(self.page_size);
            result.extend_from_slice(&select(page)[..take * self.head_dim]);
            remaining -= take;
        }

        // Positions past the last allocated page read as zeros.
        result.resize(total, 0.0);
        result
    }
}
496
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn kv_cache_store_and_retrieve() {
        let mut cache = KvCache::new(2, 8, 128, 16);

        let key = vec![1.0f32; 128];
        let value = vec![2.0f32; 128];

        cache.store_key(0, 0, 0, &key);
        cache.store_value(0, 0, 0, &value);
        cache.advance();

        let keys = cache.keys_for(0, 0, 1);
        let values = cache.values_for(0, 0, 1);

        assert_eq!(keys.len(), 128);
        assert_eq!(values.len(), 128);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((values[0] - 2.0).abs() < 1e-5);
    }

    #[test]
    fn kv_cache_multiple_positions() {
        let mut cache = KvCache::new(1, 1, 4, 8);

        cache.store_key(0, 0, 0, &[1.0, 2.0, 3.0, 4.0]);
        cache.advance();
        cache.store_key(0, 0, 1, &[5.0, 6.0, 7.0, 8.0]);
        cache.advance();

        let keys = cache.keys_for(0, 0, 2);
        assert_eq!(keys.len(), 8);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((keys[4] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn kv_cache_memory_size() {
        let cache = KvCache::new(36, 8, 128, 4096);
        // 36 layers * 8 heads * 4096 seq * 128 dim * 4 bytes * 2 (K+V)
        let expected = 36 * 8 * 4096 * 128 * 4 * 2;
        assert_eq!(cache.memory_bytes(), expected);
    }

    #[test]
    fn kv_cache_utilization() {
        let mut cache = KvCache::new(1, 1, 4, 10);
        assert!((cache.utilization_ratio() - 0.0).abs() < 1e-10);

        cache.advance();
        cache.advance();
        cache.advance();
        assert!((cache.utilization_ratio() - 0.3).abs() < 1e-10);
    }

    #[test]
    fn kv_cache_policy_default() {
        let policy = KvCachePolicy::default();
        assert_eq!(policy, KvCachePolicy::Standard);
    }

    #[test]
    fn kv_cache_set_seq_len_clamps_to_max() {
        let mut cache = KvCache::new(1, 1, 4, 8);
        cache.set_seq_len(4);
        assert_eq!(cache.seq_len(), 4);
        cache.set_seq_len(100);
        assert_eq!(cache.seq_len(), 8); // clamped
    }

    #[test]
    fn kv_cache_extract_inject_roundtrip() {
        // Two layers, two KV heads, head_dim=4, block_size=4 → per_layer = 32 floats.
        let num_layers = 2;
        let num_kv_heads = 2;
        let head_dim = 4;
        let block_size = 4;
        let max_seq = 16;
        let mut cache = KvCache::new(num_layers, num_kv_heads, head_dim, max_seq);

        // Populate layer 1 at positions 0..4 with deterministic key/value patterns.
        for head in 0..num_kv_heads {
            for pos in 0..block_size {
                let key: Vec<f32> = (0..head_dim)
                    .map(|d| (head as f32 + 1.0) * 100.0 + pos as f32 * 10.0 + d as f32)
                    .collect();
                let value: Vec<f32> = (0..head_dim)
                    .map(|d| (head as f32 + 1.0) * 1000.0 + pos as f32 * 10.0 + d as f32)
                    .collect();
                cache.store_key(1, head, pos, &key);
                cache.store_value(1, head, pos, &value);
            }
        }

        // Extract, then inject into a fresh cache and re-extract.
        let (k_block, v_block) = cache.extract_block(1, 0, block_size);
        let per_layer = num_kv_heads * block_size * head_dim;
        assert_eq!(k_block.len(), per_layer);
        assert_eq!(v_block.len(), per_layer);

        let mut fresh = KvCache::new(num_layers, num_kv_heads, head_dim, max_seq);
        fresh.inject_block(1, 0, block_size, &k_block, &v_block);
        fresh.set_seq_len(block_size);

        let (k_block_2, v_block_2) = fresh.extract_block(1, 0, block_size);
        assert_eq!(k_block_2, k_block);
        assert_eq!(v_block_2, v_block);

        // Re-read via keys_for / values_for to verify the position-major layout.
        for head in 0..num_kv_heads {
            let original_keys = cache.keys_for(1, head, block_size);
            let restored_keys = fresh.keys_for(1, head, block_size);
            assert_eq!(
                original_keys, restored_keys,
                "head {head} keys must round-trip"
            );
            let original_values = cache.values_for(1, head, block_size);
            let restored_values = fresh.values_for(1, head, block_size);
            assert_eq!(
                original_values, restored_values,
                "head {head} values must round-trip"
            );
        }
    }

    #[test]
    fn kv_cache_extract_inject_at_offset() {
        // Verify extract/inject behave correctly for non-zero start_pos.
        let mut cache = KvCache::new(1, 1, 2, 16);
        // Write a recognisable, non-zero pattern at positions 4..8 so that
        // restored data cannot be confused with never-written (zero) slots.
        for pos in 0..4 {
            let p = pos as f32 + 1.0;
            cache.store_key(0, 0, 4 + pos, &[p, p + 0.5]);
            cache.store_value(0, 0, 4 + pos, &[-p, -p - 0.5]);
        }
        let (k, v) = cache.extract_block(0, 4, 4);
        let mut other = KvCache::new(1, 1, 2, 16);
        other.inject_block(0, 4, 4, &k, &v);

        let original_k = cache.keys_for(0, 0, 8);
        let restored_k = other.keys_for(0, 0, 8);
        let original_v = cache.values_for(0, 0, 8);
        let restored_v = other.values_for(0, 0, 8);

        // head_dim = 2, so positions 0..4 occupy indices 0..8 and were never
        // written in either cache.
        assert!(restored_k[..8].iter().all(|&x| x == 0.0));
        assert!(restored_v[..8].iter().all(|&x| x == 0.0));
        // Positions 4..8 must match the source cache exactly, for both K and V.
        assert_eq!(&restored_k[8..], &original_k[8..]);
        assert_eq!(&restored_v[8..], &original_v[8..]);
        assert!((restored_k[8] - 1.0).abs() < 1e-6);
        assert!((restored_v[8] + 1.0).abs() < 1e-6);
    }

    #[test]
    fn kv_cache_extract_block_past_capacity_is_zero_filled() {
        // Positions 4 and 5 are beyond max_seq_len = 4 and must read as zeros.
        let mut cache = KvCache::new(1, 1, 2, 4);
        cache.store_key(0, 0, 2, &[1.0, 2.0]);
        cache.store_key(0, 0, 3, &[3.0, 4.0]);
        let (k, v) = cache.extract_block(0, 2, 4);
        assert_eq!(k, vec![1.0f32, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0, 0.0]);
        assert!(v.iter().all(|&x| x == 0.0));
    }

    // ── Paged KV Cache tests ──

    #[test]
    fn paged_kv_cache_store_and_retrieve() {
        let mut cache = PagedKvCache::with_page_size(2, 1, 4, 16, 4);

        let key = vec![1.0, 2.0, 3.0, 4.0];
        let value = vec![5.0, 6.0, 7.0, 8.0];

        cache.store_key(0, 0, 0, &key);
        cache.store_value(0, 0, 0, &value);
        cache.advance();

        let keys = cache.keys_for(0, 0, 1);
        let values = cache.values_for(0, 0, 1);

        assert_eq!(keys.len(), 4);
        assert_eq!(values.len(), 4);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((values[0] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn paged_kv_cache_cross_page_boundary() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 2);

        // Store in page 0 (positions 0, 1)
        cache.store_key(0, 0, 0, &[1.0, 2.0, 3.0, 4.0]);
        cache.store_key(0, 0, 1, &[5.0, 6.0, 7.0, 8.0]);
        // Store in page 1 (positions 2, 3)
        cache.store_key(0, 0, 2, &[9.0, 10.0, 11.0, 12.0]);

        let keys = cache.keys_for(0, 0, 3);
        assert_eq!(keys.len(), 12);
        assert!((keys[0] - 1.0).abs() < 1e-5);
        assert!((keys[4] - 5.0).abs() < 1e-5);
        assert!((keys[8] - 9.0).abs() < 1e-5);
    }

    #[test]
    fn paged_kv_cache_reads_past_allocation_are_zero() {
        // Only page 0 (positions 0..4) is allocated; positions 4 and 5 have no page.
        let mut cache = PagedKvCache::with_page_size(1, 1, 2, 16, 4);
        cache.store_key(0, 0, 1, &[1.0, 2.0]);
        let keys = cache.keys_for(0, 0, 6);
        assert_eq!(
            keys,
            vec![0.0f32, 0.0, 1.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
        );
        assert_eq!(cache.total_pages(), 1);
    }

    #[test]
    fn paged_kv_cache_lazy_allocation() {
        let cache = PagedKvCache::with_page_size(1, 1, 4, 1024, 256);
        assert_eq!(cache.total_pages(), 0);
        assert_eq!(cache.memory_usage_bytes(), 0);
    }

    #[test]
    fn paged_kv_cache_memory_grows() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 1024, 4);

        assert_eq!(cache.memory_usage_bytes(), 0);

        cache.store_key(0, 0, 0, &[1.0; 4]);
        // 1 page allocated: 4 positions * 4 dims * 4 bytes * 2 (K+V)
        let one_page_bytes = 4 * 4 * 4 * 2;
        assert_eq!(cache.memory_usage_bytes(), one_page_bytes);

        // Trigger second page allocation
        cache.store_key(0, 0, 4, &[1.0; 4]);
        assert_eq!(cache.memory_usage_bytes(), one_page_bytes * 2);
    }

    #[test]
    fn paged_kv_cache_clear() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 4);
        cache.store_key(0, 0, 0, &[1.0; 4]);
        cache.advance();

        assert!(cache.total_pages() > 0);
        cache.clear();
        assert_eq!(cache.total_pages(), 0);
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn paged_kv_cache_utilization() {
        let mut cache = PagedKvCache::with_page_size(1, 1, 4, 16, 4);
        assert!((cache.utilization_ratio() - 0.0).abs() < 1e-10);

        cache.store_key(0, 0, 0, &[1.0; 4]);
        // 1 used out of 4 slots in 1 page = 0.25
        assert!((cache.utilization_ratio() - 0.25).abs() < 1e-10);
    }
}