Skip to main content

oxillama_runtime/kv_cache/
mod.rs

1//! Key-Value cache for transformer attention.
2//!
3//! Stores the key and value tensors from previous tokens so they don't
4//! need to be recomputed during autoregressive generation.
5//!
6//! Three implementations are provided:
7//! - [`KvCache`]: Simple contiguous pre-allocated buffers (fast, simple)
8//! - [`PagedKvCache`]: Page-based allocation (memory-efficient, supports variable lengths)
9//! - [`PrefixKvCache`]: Radix-tree prefix sharing (reuse cached prefixes across requests)
10
11pub mod paged;
12pub mod prefix;
13
14use oxicode::{Decode, Encode};
15use oxillama_arch::traits::KvCacheAccess;
16use oxillama_arch::ArchResult;
17
18pub use oxillama_arch::traits::{BatchedKvView, KvSlot};
19pub use paged::PagedKvCache;
20pub use prefix::{PrefixCacheConfig, PrefixKvCache};
21
22/// A point-in-time snapshot of a [`KvCache`] state.
23///
24/// Created by [`KvCache::snapshot`] and restored via
25/// [`KvCache::restore_from_snapshot`].  Used by
26/// [`crate::speculative::SpeculativeDeltaSync`] to roll back the KV cache
27/// after a draft token is rejected.
28#[derive(Debug, Clone, Encode, Decode)]
29pub struct KvCacheSnapshot {
30    /// Per-layer key vectors, each of length `seq_len * kv_dim`.
31    pub keys: Vec<Vec<f32>>,
32    /// Per-layer value vectors, each of length `seq_len * kv_dim`.
33    pub values: Vec<Vec<f32>>,
34    /// The sequence length at snapshot time.
35    pub seq_len: usize,
36}
37
38/// Simple contiguous `BatchedKvView` backed by a `Vec<KvSlot>` paired with
39/// a pool of flat key/value buffers.
40///
41/// Each entry `i` refers to slot `slots[i]`.  The key buffer for slot `i`
42/// has length `position * kv_dim` and the value buffer likewise.
43pub struct VecBatchedKvView {
44    slots: Vec<KvSlot>,
45    /// Flat key buffers, one per slot, length `position * kv_dim`.
46    keys: Vec<Vec<f32>>,
47    /// Flat value buffers, one per slot, length `position * kv_dim`.
48    values: Vec<Vec<f32>>,
49}
50
51impl VecBatchedKvView {
52    /// Construct a new `VecBatchedKvView` from parallel vecs.
53    ///
54    /// # Panics
55    ///
56    /// Panics if `slots.len() != keys.len()` or `slots.len() != values.len()`.
57    pub fn new(slots: Vec<KvSlot>, keys: Vec<Vec<f32>>, values: Vec<Vec<f32>>) -> Self {
58        assert_eq!(
59            slots.len(),
60            keys.len(),
61            "slots and keys vecs must have equal length"
62        );
63        assert_eq!(
64            slots.len(),
65            values.len(),
66            "slots and values vecs must have equal length"
67        );
68        Self {
69            slots,
70            keys,
71            values,
72        }
73    }
74}
75
76impl BatchedKvView for VecBatchedKvView {
77    fn slot_count(&self) -> usize {
78        self.slots.len()
79    }
80
81    fn kv_for_slot(&self, slot: usize) -> (&[f32], &[f32]) {
82        (&self.keys[slot], &self.values[slot])
83    }
84
85    fn position(&self, slot: usize) -> usize {
86        self.slots[slot].position
87    }
88}
89
/// Contiguous, pre-allocated FP32 KV cache.
///
/// For every transformer layer there is one key buffer and one value buffer.
/// Each buffer is sized up front for the full context
/// (`max_seq_len * kv_dim` floats), and token `t` occupies the range
/// `t * kv_dim .. (t + 1) * kv_dim`.
pub struct KvCache {
    /// Number of layers (length of `keys` and `values`).
    num_layers: usize,
    /// Capacity of the cache, in tokens.
    max_seq_len: usize,
    /// Floats stored per token per layer (num_kv_heads * head_dim).
    kv_dim: usize,
    /// Count of fully-committed tokens (bumped by `advance()`).
    seq_len: usize,
    /// Count of token positions whose K/V is readable through
    /// `get_keys`/`get_values`.
    ///
    /// Always `>= seq_len`.  After `store_kv` writes position `seq_len` and
    /// before the matching `advance()`, this equals `seq_len + 1`.  That lets
    /// attention in the same forward pass see the entry it just wrote.
    stored_len: usize,
    /// Per-layer key storage, each `max_seq_len * kv_dim` floats long.
    keys: Vec<Vec<f32>>,
    /// Per-layer value storage, each `max_seq_len * kv_dim` floats long.
    values: Vec<Vec<f32>>,
}
116
117impl KvCache {
118    /// Allocate a new KV cache.
119    ///
120    /// # Arguments
121    /// * `num_layers` - Number of transformer layers.
122    /// * `max_seq_len` - Maximum context length.
123    /// * `kv_dim` - KV dimension per token (num_kv_heads * head_dim).
124    pub fn new(num_layers: usize, max_seq_len: usize, kv_dim: usize) -> Self {
125        let keys = (0..num_layers)
126            .map(|_| vec![0.0f32; max_seq_len * kv_dim])
127            .collect();
128        let values = (0..num_layers)
129            .map(|_| vec![0.0f32; max_seq_len * kv_dim])
130            .collect();
131
132        Self {
133            keys,
134            values,
135            seq_len: 0,
136            stored_len: 0,
137            max_seq_len,
138            kv_dim,
139            num_layers,
140        }
141    }
142
143    /// Reset the cache, clearing all stored KV pairs.
144    pub fn clear(&mut self) {
145        self.seq_len = 0;
146        self.stored_len = 0;
147        for k in &mut self.keys {
148            k.fill(0.0);
149        }
150        for v in &mut self.values {
151            v.fill(0.0);
152        }
153    }
154
155    /// Returns the maximum sequence length.
156    pub fn max_seq_len(&self) -> usize {
157        self.max_seq_len
158    }
159
160    /// Returns the KV dimension per token.
161    pub fn kv_dim(&self) -> usize {
162        self.kv_dim
163    }
164
165    /// Returns the number of layers.
166    pub fn num_layers(&self) -> usize {
167        self.num_layers
168    }
169
170    /// Advance the sequence position by one token.
171    pub fn advance(&mut self) {
172        if self.seq_len < self.max_seq_len {
173            self.seq_len += 1;
174            if self.stored_len < self.seq_len {
175                self.stored_len = self.seq_len;
176            }
177        }
178    }
179
180    /// Restore from a prefix cache snapshot.
181    ///
182    /// Copies the provided per-layer key/value data into internal buffers
183    /// and sets `seq_len` to the snapshot's length.  The caller must ensure
184    /// that `keys.len() == values.len() == num_layers` and that each inner
185    /// vec has `seq_len * kv_dim` elements.
186    pub fn restore_from_snapshot(
187        &mut self,
188        keys: &[Vec<f32>],
189        values: &[Vec<f32>],
190        seq_len: usize,
191    ) {
192        let layers = keys.len().min(values.len()).min(self.num_layers);
193        let copy_len = seq_len * self.kv_dim;
194
195        for layer in 0..layers {
196            let src_k = &keys[layer];
197            let src_v = &values[layer];
198            let n = copy_len.min(src_k.len()).min(self.keys[layer].len());
199            self.keys[layer][..n].copy_from_slice(&src_k[..n]);
200            let n = copy_len.min(src_v.len()).min(self.values[layer].len());
201            self.values[layer][..n].copy_from_slice(&src_v[..n]);
202        }
203
204        self.seq_len = seq_len.min(self.max_seq_len);
205        self.stored_len = self.seq_len;
206    }
207
208    /// Truncate the KV cache to `n` tokens.
209    ///
210    /// After this call `seq_len()` returns `n` (clamped to the current
211    /// `seq_len` if `n` is already beyond it — truncate never extends the
212    /// cache).  The underlying buffers are **not** zeroed; the truncated
213    /// region is simply considered invalid and will be overwritten on the
214    /// next `store_kv` call.
215    ///
216    /// This is the low-level primitive for speculative-decoding rollback: the
217    /// target engine calls `truncate(divergence_pos)` after rejecting a draft
218    /// token, then continues generating from `divergence_pos`.
219    pub fn truncate(&mut self, n: usize) {
220        let n = n.min(self.seq_len);
221        self.seq_len = n;
222        self.stored_len = n;
223    }
224
225    /// Capture a snapshot of the current KV state.
226    ///
227    /// Only the data up to `seq_len * kv_dim` is copied per layer, keeping
228    /// the snapshot compact.
229    pub fn snapshot(&self) -> KvCacheSnapshot {
230        let copy_len = self.seq_len * self.kv_dim;
231        let keys = self
232            .keys
233            .iter()
234            .map(|k| k[..copy_len.min(k.len())].to_vec())
235            .collect();
236        let values = self
237            .values
238            .iter()
239            .map(|v| v[..copy_len.min(v.len())].to_vec())
240            .collect();
241        KvCacheSnapshot {
242            keys,
243            values,
244            seq_len: self.seq_len,
245        }
246    }
247
248    /// Build a serializable [`crate::snapshot::KvStatePayload`] from the current state.
249    pub fn to_payload(&self) -> crate::snapshot::KvStatePayload {
250        let copy_len = self.seq_len * self.kv_dim;
251        let keys = self
252            .keys
253            .iter()
254            .map(|k| k[..copy_len.min(k.len())].to_vec())
255            .collect();
256        let values = self
257            .values
258            .iter()
259            .map(|v| v[..copy_len.min(v.len())].to_vec())
260            .collect();
261        crate::snapshot::KvStatePayload {
262            keys,
263            values,
264            seq_len: self.seq_len,
265            num_layers: self.num_layers,
266            max_seq_len: self.max_seq_len,
267            kv_dim: self.kv_dim,
268        }
269    }
270
271    /// Restore cache state from a [`crate::snapshot::KvStatePayload`].
272    ///
273    /// Validates that layer count and dimensions match the cache configuration,
274    /// then restores the key/value buffers and sequence length.
275    pub fn restore_from_payload(
276        &mut self,
277        payload: &crate::snapshot::KvStatePayload,
278    ) -> crate::error::RuntimeResult<()> {
279        use crate::error::RuntimeError;
280        if payload.num_layers != self.num_layers {
281            return Err(RuntimeError::SnapshotIncompatible {
282                detail: format!(
283                    "layer count mismatch: snapshot has {}, cache has {}",
284                    payload.num_layers, self.num_layers
285                ),
286            });
287        }
288        if payload.kv_dim != self.kv_dim {
289            return Err(RuntimeError::SnapshotIncompatible {
290                detail: format!(
291                    "kv_dim mismatch: snapshot has {}, cache has {}",
292                    payload.kv_dim, self.kv_dim
293                ),
294            });
295        }
296        self.restore_from_snapshot(&payload.keys, &payload.values, payload.seq_len);
297        Ok(())
298    }
299}
300
301impl KvCacheAccess for KvCache {
302    fn seq_len(&self) -> usize {
303        self.seq_len
304    }
305
306    fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()> {
307        if layer >= self.num_layers {
308            return Err(oxillama_arch::ArchError::ForwardPassError {
309                layer,
310                message: format!("layer index {layer} out of range (max {})", self.num_layers),
311            });
312        }
313
314        let offset = self.seq_len * self.kv_dim;
315        let end = offset + self.kv_dim;
316
317        if end <= self.keys[layer].len() {
318            self.keys[layer][offset..end].copy_from_slice(&key[..self.kv_dim]);
319            self.values[layer][offset..end].copy_from_slice(&value[..self.kv_dim]);
320            // Ensure get_keys/get_values can see the entry we just wrote even
321            // before advance() is called (advance is called once per token
322            // after ALL layers have written their K/V, but attention reads
323            // back during the same forward pass).
324            if self.stored_len <= self.seq_len {
325                self.stored_len = self.seq_len + 1;
326            }
327        }
328
329        Ok(())
330    }
331
332    fn get_keys(&self, layer: usize) -> ArchResult<&[f32]> {
333        if layer >= self.num_layers {
334            return Err(oxillama_arch::ArchError::ForwardPassError {
335                layer,
336                message: format!("layer index {layer} out of range (max {})", self.num_layers),
337            });
338        }
339        let end = self.stored_len * self.kv_dim;
340        Ok(&self.keys[layer][..end])
341    }
342
343    fn get_values(&self, layer: usize) -> ArchResult<&[f32]> {
344        if layer >= self.num_layers {
345            return Err(oxillama_arch::ArchError::ForwardPassError {
346                layer,
347                message: format!("layer index {layer} out of range (max {})", self.num_layers),
348            });
349        }
350        let end = self.stored_len * self.kv_dim;
351        Ok(&self.values[layer][..end])
352    }
353
354    fn advance(&mut self) {
355        if self.seq_len < self.max_seq_len {
356            self.seq_len += 1;
357            // stored_len must always be >= seq_len.
358            if self.stored_len < self.seq_len {
359                self.stored_len = self.seq_len;
360            }
361        }
362    }
363
364    fn kv_dim(&self) -> usize {
365        self.kv_dim
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    // ── Construction ─────────────────────────────────────────────────────────

    #[test]
    fn test_new_starts_at_zero_seq_len() {
        let cache = KvCache::new(4, 128, 64);
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn test_new_stores_dimensions() {
        let cache = KvCache::new(8, 512, 128);
        assert_eq!(cache.num_layers(), 8);
        assert_eq!(cache.max_seq_len(), 512);
        assert_eq!(cache.kv_dim(), 128);
    }

    // ── advance ──────────────────────────────────────────────────────────────

    #[test]
    fn test_advance_increments_seq_len() {
        let mut cache = KvCache::new(2, 8, 4);
        assert_eq!(cache.seq_len(), 0);
        cache.advance();
        assert_eq!(cache.seq_len(), 1);
        cache.advance();
        assert_eq!(cache.seq_len(), 2);
    }

    #[test]
    fn test_advance_capped_at_max_seq_len() {
        let max = 3;
        let mut cache = KvCache::new(1, max, 4);
        for _ in 0..max + 5 {
            cache.advance();
        }
        assert_eq!(cache.seq_len(), max, "seq_len must not exceed max_seq_len");
    }

    #[test]
    fn test_kvcache_access_advance_also_increments() {
        let mut cache = KvCache::new(2, 8, 4);
        // KvCacheAccess::advance should behave identically.
        <KvCache as KvCacheAccess>::advance(&mut cache);
        assert_eq!(cache.seq_len(), 1);
    }

    // ── clear ────────────────────────────────────────────────────────────────

    #[test]
    fn test_clear_resets_seq_len_to_zero() {
        let mut cache = KvCache::new(2, 8, 4);
        cache.advance();
        cache.advance();
        assert_eq!(cache.seq_len(), 2);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn test_clear_zeros_stored_data() {
        let kv_dim = 4;
        let mut cache = KvCache::new(1, 8, kv_dim);

        // Write some data and advance.
        let key = vec![1.0f32, 2.0, 3.0, 4.0];
        let val = vec![5.0f32, 6.0, 7.0, 8.0];
        cache
            .store_kv(0, &key, &val)
            .expect("store_kv must succeed");
        cache.advance();

        cache.clear();

        // After clear the seq_len is 0, so get_keys returns an empty slice.
        let keys = cache.get_keys(0).expect("get_keys must succeed");
        assert!(
            keys.is_empty(),
            "after clear, get_keys should return empty slice"
        );

        // Advancing without a store re-exposes position 0; the data written
        // there before clear() must have been zeroed.
        cache.advance();
        let zeros = vec![0.0f32; kv_dim];
        let keys = cache.get_keys(0).expect("get_keys must succeed");
        assert_eq!(keys, zeros.as_slice(), "keys must be zeroed by clear");
        let values = cache.get_values(0).expect("get_values must succeed");
        assert_eq!(values, zeros.as_slice(), "values must be zeroed by clear");
    }

    // ── store_kv / get_keys / get_values round-trip ───────────────────────

    #[test]
    fn test_store_kv_and_get_keys_round_trip() {
        let kv_dim = 8;
        let mut cache = KvCache::new(2, 16, kv_dim);

        let key: Vec<f32> = (0..kv_dim as i32).map(|i| i as f32 * 0.1).collect();
        let val: Vec<f32> = (0..kv_dim as i32).map(|i| i as f32 * -0.1).collect();

        cache.store_kv(0, &key, &val).expect("store_kv layer 0");
        cache.advance();

        let stored_keys = cache.get_keys(0).expect("get_keys layer 0");
        assert_eq!(stored_keys.len(), kv_dim, "should have kv_dim floats");
        for (i, (&got, &expected)) in stored_keys.iter().zip(key.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-7,
                "key[{i}]: got {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_store_kv_and_get_values_round_trip() {
        let kv_dim = 4;
        let mut cache = KvCache::new(1, 8, kv_dim);

        let key = vec![0.0f32; kv_dim];
        let val = vec![1.1f32, 2.2, 3.3, 4.4];

        cache.store_kv(0, &key, &val).expect("store_kv");
        cache.advance();

        let stored_vals = cache.get_values(0).expect("get_values");
        assert_eq!(stored_vals.len(), kv_dim);
        for (i, (&got, &expected)) in stored_vals.iter().zip(val.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-6,
                "value[{i}]: got {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_store_kv_accumulates_across_tokens() {
        let kv_dim = 2;
        let mut cache = KvCache::new(1, 8, kv_dim);

        for t in 0..3u32 {
            let key = vec![t as f32, t as f32 + 0.5];
            let val = vec![0.0f32; kv_dim];
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        let keys = cache.get_keys(0).expect("get_keys");
        assert_eq!(
            keys.len(),
            3 * kv_dim,
            "should have 3 tokens × kv_dim floats"
        );
        // Every token row must hold exactly what was written for it.
        for (t, row) in keys.chunks_exact(kv_dim).enumerate() {
            assert!((row[0] - t as f32).abs() < 1e-7, "token {t} dim 0");
            assert!((row[1] - (t as f32 + 0.5)).abs() < 1e-7, "token {t} dim 1");
        }
    }

    #[test]
    fn test_store_kv_visible_before_advance() {
        let mut cache = KvCache::new(1, 4, 2);

        cache
            .store_kv(0, &[1.0, 2.0], &[3.0, 4.0])
            .expect("store_kv");

        // The just-written entry must be readable within the same forward
        // pass, even though advance() has not been called yet.
        assert_eq!(cache.seq_len(), 0, "store_kv must not advance seq_len");
        assert_eq!(cache.get_keys(0).expect("get_keys"), &[1.0f32, 2.0][..]);
        assert_eq!(cache.get_values(0).expect("get_values"), &[3.0f32, 4.0][..]);
    }

    // ── out-of-range layer errors ────────────────────────────────────────────

    #[test]
    fn test_store_kv_out_of_range_layer_returns_error() {
        let mut cache = KvCache::new(2, 8, 4);
        let key = vec![0.0f32; 4];
        let val = vec![0.0f32; 4];
        let result = cache.store_kv(99, &key, &val);
        assert!(result.is_err(), "out-of-range layer should return error");
    }

    #[test]
    fn test_get_keys_out_of_range_layer_returns_error() {
        let cache = KvCache::new(2, 8, 4);
        let result = cache.get_keys(99);
        assert!(result.is_err(), "out-of-range layer should return error");
    }

    #[test]
    fn test_get_values_out_of_range_layer_returns_error() {
        let cache = KvCache::new(2, 8, 4);
        let result = cache.get_values(99);
        assert!(result.is_err(), "out-of-range layer should return error");
    }

    // ── multi-layer independence ─────────────────────────────────────────────

    #[test]
    fn test_store_kv_different_layers_independent() {
        let kv_dim = 4;
        let mut cache = KvCache::new(2, 8, kv_dim);

        let key0 = vec![1.0f32; kv_dim];
        let key1 = vec![2.0f32; kv_dim];
        let val0 = vec![3.0f32; kv_dim];
        let val1 = vec![4.0f32; kv_dim];

        cache.store_kv(0, &key0, &val0).expect("layer 0 store");
        cache.store_kv(1, &key1, &val1).expect("layer 1 store");
        cache.advance();

        let stored0 = cache.get_keys(0).expect("layer 0 keys");
        let stored1 = cache.get_keys(1).expect("layer 1 keys");

        for &v in stored0 {
            assert!((v - 1.0).abs() < 1e-7, "layer 0 key should be 1.0");
        }
        for &v in stored1 {
            assert!((v - 2.0).abs() < 1e-7, "layer 1 key should be 2.0");
        }
    }

    // ── truncate ─────────────────────────────────────────────────────────────

    #[test]
    fn test_truncate_rolls_back_and_never_extends() {
        let kv_dim = 2;
        let mut cache = KvCache::new(1, 8, kv_dim);
        for t in 0..3u32 {
            cache
                .store_kv(0, &[t as f32; 2], &[0.0; 2])
                .expect("store_kv");
            cache.advance();
        }

        cache.truncate(1);
        assert_eq!(cache.seq_len(), 1);
        assert_eq!(cache.get_keys(0).expect("get_keys").len(), kv_dim);

        cache.truncate(5);
        assert_eq!(cache.seq_len(), 1, "truncate must never extend the cache");
    }

    // ── snapshot / restore ───────────────────────────────────────────────────

    #[test]
    fn test_snapshot_restore_round_trip() {
        let kv_dim = 2;
        let mut cache = KvCache::new(2, 8, kv_dim);
        for t in 0..2u32 {
            for layer in 0..2 {
                let base = (layer * 10) as f32 + t as f32;
                cache
                    .store_kv(layer, &[base, base + 0.5], &[-base, -base - 0.5])
                    .expect("store_kv");
            }
            cache.advance();
        }

        let snap = cache.snapshot();
        assert_eq!(snap.seq_len, 2);
        assert_eq!(snap.keys.len(), 2);
        assert_eq!(snap.keys[0].len(), 2 * kv_dim, "snapshot keeps only committed data");

        // Commit a "draft" token, then roll it back.
        for layer in 0..2 {
            cache
                .store_kv(layer, &[99.0; 2], &[99.0; 2])
                .expect("store_kv");
        }
        cache.advance();
        assert_eq!(cache.seq_len(), 3);

        cache.restore_from_snapshot(&snap.keys, &snap.values, snap.seq_len);
        assert_eq!(cache.seq_len(), 2);
        for layer in 0..2 {
            assert_eq!(
                cache.get_keys(layer).expect("get_keys"),
                snap.keys[layer].as_slice()
            );
            assert_eq!(
                cache.get_values(layer).expect("get_values"),
                snap.values[layer].as_slice()
            );
        }
    }

    // ── for_each_key / for_each_value iteration ─────────────────────────────

    #[test]
    fn kv_cache_for_each_key_contiguous() {
        let kv_dim = 4usize;
        let mut cache = KvCache::new(1, 16, kv_dim);

        // Store 4 tokens in layer 0.
        for t in 0..4u32 {
            let key: Vec<f32> = (0..kv_dim).map(|d| t as f32 * 10.0 + d as f32).collect();
            let val: Vec<f32> = (0..kv_dim).map(|d| t as f32 * 100.0 + d as f32).collect();
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        // Collect callbacks via for_each_key.
        let mut positions_seen: Vec<usize> = Vec::new();
        let mut keys_seen: Vec<Vec<f32>> = Vec::new();
        cache
            .for_each_key(0, &mut |pos, slice| {
                positions_seen.push(pos);
                keys_seen.push(slice.to_vec());
            })
            .expect("for_each_key must succeed");

        assert_eq!(positions_seen.len(), 4, "must visit all 4 positions");
        assert_eq!(
            positions_seen,
            vec![0, 1, 2, 3],
            "positions must be in order"
        );

        // Check key data for each position.
        for (t, key_row) in keys_seen.iter().enumerate() {
            assert_eq!(key_row.len(), kv_dim, "key row must have kv_dim elements");
            for (d, &v) in key_row.iter().enumerate() {
                let expected = t as f32 * 10.0 + d as f32;
                assert!(
                    (v - expected).abs() < 1e-6,
                    "token {t} dim {d}: expected {expected}, got {v}"
                );
            }
        }
    }

    #[test]
    fn kv_cache_for_each_value_contiguous() {
        let kv_dim = 3usize;
        let mut cache = KvCache::new(1, 8, kv_dim);

        for t in 0..3u32 {
            let key = vec![0.0f32; kv_dim];
            let val: Vec<f32> = (0..kv_dim).map(|d| t as f32 + d as f32 * 0.1).collect();
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        let mut count = 0usize;
        cache
            .for_each_value(0, &mut |_pos, slice| {
                assert_eq!(slice.len(), kv_dim);
                count += 1;
            })
            .expect("for_each_value must succeed");
        assert_eq!(count, 3, "must visit 3 value rows");
    }
}