1pub mod paged;
12pub mod prefix;
13
14use oxicode::{Decode, Encode};
15use oxillama_arch::traits::KvCacheAccess;
16use oxillama_arch::ArchResult;
17
18pub use oxillama_arch::traits::{BatchedKvView, KvSlot};
19pub use paged::PagedKvCache;
20pub use prefix::{PrefixCacheConfig, PrefixKvCache};
21
/// Owned copy of a [`KvCache`]'s committed history, as produced by
/// [`KvCache::snapshot`].
///
/// `keys[layer]` and `values[layer]` each hold the first `seq_len * kv_dim`
/// floats of that layer's buffer, i.e. one row of `kv_dim` floats per
/// committed position. The fields can be handed back to
/// [`KvCache::restore_from_snapshot`] to rebuild the cache.
///
/// NOTE(review): field order is presumably part of the derived
/// `Encode`/`Decode` layout — confirm before reordering or inserting fields.
#[derive(Debug, Clone, Encode, Decode)]
pub struct KvCacheSnapshot {
    /// Per-layer key data, one `kv_dim`-wide row per position.
    pub keys: Vec<Vec<f32>>,
    /// Per-layer value data, laid out exactly like `keys`.
    pub values: Vec<Vec<f32>>,
    /// Number of positions captured in every layer.
    pub seq_len: usize,
}
37
38pub struct VecBatchedKvView {
44 slots: Vec<KvSlot>,
45 keys: Vec<Vec<f32>>,
47 values: Vec<Vec<f32>>,
49}
50
51impl VecBatchedKvView {
52 pub fn new(slots: Vec<KvSlot>, keys: Vec<Vec<f32>>, values: Vec<Vec<f32>>) -> Self {
58 assert_eq!(
59 slots.len(),
60 keys.len(),
61 "slots and keys vecs must have equal length"
62 );
63 assert_eq!(
64 slots.len(),
65 values.len(),
66 "slots and values vecs must have equal length"
67 );
68 Self {
69 slots,
70 keys,
71 values,
72 }
73 }
74}
75
76impl BatchedKvView for VecBatchedKvView {
77 fn slot_count(&self) -> usize {
78 self.slots.len()
79 }
80
81 fn kv_for_slot(&self, slot: usize) -> (&[f32], &[f32]) {
82 (&self.keys[slot], &self.values[slot])
83 }
84
85 fn position(&self, slot: usize) -> usize {
86 self.slots[slot].position
87 }
88}
89
/// Contiguous, fixed-capacity key/value cache for a single sequence.
///
/// Every layer owns one key buffer and one value buffer of
/// `max_seq_len * kv_dim` floats, allocated up front. Position `p` occupies
/// the range `[p * kv_dim, (p + 1) * kv_dim)` of each buffer.
pub struct KvCache {
    /// Number of layers; also the length of `keys` and `values`.
    num_layers: usize,
    /// Capacity in positions.
    max_seq_len: usize,
    /// Floats stored per position, per layer.
    kv_dim: usize,
    /// Positions committed via `advance`.
    seq_len: usize,
    /// Positions whose rows are readable through `get_keys`/`get_values`.
    /// Equals `seq_len`, or `seq_len + 1` after `store_kv` wrote the row for
    /// the position that has not been advanced past yet.
    stored_len: usize,
    /// Per-layer key buffers.
    keys: Vec<Vec<f32>>,
    /// Per-layer value buffers, laid out exactly like `keys`.
    values: Vec<Vec<f32>>,
}
116
117impl KvCache {
118 pub fn new(num_layers: usize, max_seq_len: usize, kv_dim: usize) -> Self {
125 let keys = (0..num_layers)
126 .map(|_| vec![0.0f32; max_seq_len * kv_dim])
127 .collect();
128 let values = (0..num_layers)
129 .map(|_| vec![0.0f32; max_seq_len * kv_dim])
130 .collect();
131
132 Self {
133 keys,
134 values,
135 seq_len: 0,
136 stored_len: 0,
137 max_seq_len,
138 kv_dim,
139 num_layers,
140 }
141 }
142
143 pub fn clear(&mut self) {
145 self.seq_len = 0;
146 self.stored_len = 0;
147 for k in &mut self.keys {
148 k.fill(0.0);
149 }
150 for v in &mut self.values {
151 v.fill(0.0);
152 }
153 }
154
155 pub fn max_seq_len(&self) -> usize {
157 self.max_seq_len
158 }
159
160 pub fn kv_dim(&self) -> usize {
162 self.kv_dim
163 }
164
165 pub fn num_layers(&self) -> usize {
167 self.num_layers
168 }
169
170 pub fn advance(&mut self) {
172 if self.seq_len < self.max_seq_len {
173 self.seq_len += 1;
174 if self.stored_len < self.seq_len {
175 self.stored_len = self.seq_len;
176 }
177 }
178 }
179
180 pub fn restore_from_snapshot(
187 &mut self,
188 keys: &[Vec<f32>],
189 values: &[Vec<f32>],
190 seq_len: usize,
191 ) {
192 let layers = keys.len().min(values.len()).min(self.num_layers);
193 let copy_len = seq_len * self.kv_dim;
194
195 for layer in 0..layers {
196 let src_k = &keys[layer];
197 let src_v = &values[layer];
198 let n = copy_len.min(src_k.len()).min(self.keys[layer].len());
199 self.keys[layer][..n].copy_from_slice(&src_k[..n]);
200 let n = copy_len.min(src_v.len()).min(self.values[layer].len());
201 self.values[layer][..n].copy_from_slice(&src_v[..n]);
202 }
203
204 self.seq_len = seq_len.min(self.max_seq_len);
205 self.stored_len = self.seq_len;
206 }
207
208 pub fn truncate(&mut self, n: usize) {
220 let n = n.min(self.seq_len);
221 self.seq_len = n;
222 self.stored_len = n;
223 }
224
225 pub fn snapshot(&self) -> KvCacheSnapshot {
230 let copy_len = self.seq_len * self.kv_dim;
231 let keys = self
232 .keys
233 .iter()
234 .map(|k| k[..copy_len.min(k.len())].to_vec())
235 .collect();
236 let values = self
237 .values
238 .iter()
239 .map(|v| v[..copy_len.min(v.len())].to_vec())
240 .collect();
241 KvCacheSnapshot {
242 keys,
243 values,
244 seq_len: self.seq_len,
245 }
246 }
247
248 pub fn to_payload(&self) -> crate::snapshot::KvStatePayload {
250 let copy_len = self.seq_len * self.kv_dim;
251 let keys = self
252 .keys
253 .iter()
254 .map(|k| k[..copy_len.min(k.len())].to_vec())
255 .collect();
256 let values = self
257 .values
258 .iter()
259 .map(|v| v[..copy_len.min(v.len())].to_vec())
260 .collect();
261 crate::snapshot::KvStatePayload {
262 keys,
263 values,
264 seq_len: self.seq_len,
265 num_layers: self.num_layers,
266 max_seq_len: self.max_seq_len,
267 kv_dim: self.kv_dim,
268 }
269 }
270
271 pub fn restore_from_payload(
276 &mut self,
277 payload: &crate::snapshot::KvStatePayload,
278 ) -> crate::error::RuntimeResult<()> {
279 use crate::error::RuntimeError;
280 if payload.num_layers != self.num_layers {
281 return Err(RuntimeError::SnapshotIncompatible {
282 detail: format!(
283 "layer count mismatch: snapshot has {}, cache has {}",
284 payload.num_layers, self.num_layers
285 ),
286 });
287 }
288 if payload.kv_dim != self.kv_dim {
289 return Err(RuntimeError::SnapshotIncompatible {
290 detail: format!(
291 "kv_dim mismatch: snapshot has {}, cache has {}",
292 payload.kv_dim, self.kv_dim
293 ),
294 });
295 }
296 self.restore_from_snapshot(&payload.keys, &payload.values, payload.seq_len);
297 Ok(())
298 }
299}
300
301impl KvCacheAccess for KvCache {
302 fn seq_len(&self) -> usize {
303 self.seq_len
304 }
305
306 fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()> {
307 if layer >= self.num_layers {
308 return Err(oxillama_arch::ArchError::ForwardPassError {
309 layer,
310 message: format!("layer index {layer} out of range (max {})", self.num_layers),
311 });
312 }
313
314 let offset = self.seq_len * self.kv_dim;
315 let end = offset + self.kv_dim;
316
317 if end <= self.keys[layer].len() {
318 self.keys[layer][offset..end].copy_from_slice(&key[..self.kv_dim]);
319 self.values[layer][offset..end].copy_from_slice(&value[..self.kv_dim]);
320 if self.stored_len <= self.seq_len {
325 self.stored_len = self.seq_len + 1;
326 }
327 }
328
329 Ok(())
330 }
331
332 fn get_keys(&self, layer: usize) -> ArchResult<&[f32]> {
333 if layer >= self.num_layers {
334 return Err(oxillama_arch::ArchError::ForwardPassError {
335 layer,
336 message: format!("layer index {layer} out of range (max {})", self.num_layers),
337 });
338 }
339 let end = self.stored_len * self.kv_dim;
340 Ok(&self.keys[layer][..end])
341 }
342
343 fn get_values(&self, layer: usize) -> ArchResult<&[f32]> {
344 if layer >= self.num_layers {
345 return Err(oxillama_arch::ArchError::ForwardPassError {
346 layer,
347 message: format!("layer index {layer} out of range (max {})", self.num_layers),
348 });
349 }
350 let end = self.stored_len * self.kv_dim;
351 Ok(&self.values[layer][..end])
352 }
353
354 fn advance(&mut self) {
355 if self.seq_len < self.max_seq_len {
356 self.seq_len += 1;
357 if self.stored_len < self.seq_len {
359 self.stored_len = self.seq_len;
360 }
361 }
362 }
363
364 fn kv_dim(&self) -> usize {
365 self.kv_dim
366 }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Stores `tokens` rows on every layer of `cache`, advancing after each token.
    ///
    /// Key element `d` of token `t` on layer `l` is `l * 100 + t * 10 + d`;
    /// values are the negated keys.
    fn fill_tokens(cache: &mut KvCache, tokens: usize) {
        let dim = cache.kv_dim();
        for t in 0..tokens {
            for l in 0..cache.num_layers() {
                let key: Vec<f32> = (0..dim).map(|d| (l * 100 + t * 10 + d) as f32).collect();
                let val: Vec<f32> = key.iter().map(|k| -k).collect();
                cache.store_kv(l, &key, &val).expect("store_kv");
            }
            cache.advance();
        }
    }

    // ── construction ────────────────────────────────────────────────────

    #[test]
    fn test_new_starts_at_zero_seq_len() {
        let cache = KvCache::new(4, 128, 64);
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn test_new_stores_dimensions() {
        let cache = KvCache::new(8, 512, 128);
        assert_eq!(cache.num_layers(), 8);
        assert_eq!(cache.max_seq_len(), 512);
        assert_eq!(cache.kv_dim(), 128);
    }

    // ── advance / clear ─────────────────────────────────────────────────

    #[test]
    fn test_advance_increments_seq_len() {
        let mut cache = KvCache::new(2, 8, 4);
        assert_eq!(cache.seq_len(), 0);
        cache.advance();
        assert_eq!(cache.seq_len(), 1);
        cache.advance();
        assert_eq!(cache.seq_len(), 2);
    }

    #[test]
    fn test_advance_capped_at_max_seq_len() {
        let max = 3;
        let mut cache = KvCache::new(1, max, 4);
        for _ in 0..max + 5 {
            cache.advance();
        }
        assert_eq!(cache.seq_len(), max, "seq_len must not exceed max_seq_len");
    }

    #[test]
    fn test_kvcache_access_advance_also_increments() {
        let mut cache = KvCache::new(2, 8, 4);
        <KvCache as KvCacheAccess>::advance(&mut cache);
        assert_eq!(cache.seq_len(), 1);
    }

    #[test]
    fn test_clear_resets_seq_len_to_zero() {
        let mut cache = KvCache::new(2, 8, 4);
        cache.advance();
        cache.advance();
        assert_eq!(cache.seq_len(), 2);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
    }

    #[test]
    fn test_clear_zeros_stored_data() {
        let kv_dim = 4;
        let mut cache = KvCache::new(1, 8, kv_dim);

        let key = vec![1.0f32, 2.0, 3.0, 4.0];
        let val = vec![5.0f32, 6.0, 7.0, 8.0];
        cache
            .store_kv(0, &key, &val)
            .expect("store_kv must succeed");
        cache.advance();

        cache.clear();

        let keys = cache.get_keys(0).expect("get_keys must succeed");
        assert!(
            keys.is_empty(),
            "after clear, get_keys should return empty slice"
        );
        // The visible slice is empty regardless of zeroing, so inspect the
        // backing storage to verify what this test's name actually claims.
        assert!(
            cache
                .keys
                .iter()
                .chain(&cache.values)
                .flatten()
                .all(|&x| x == 0.0),
            "after clear, every backing key/value float must be zero"
        );
    }

    // ── store_kv / get_keys / get_values ────────────────────────────────

    #[test]
    fn test_store_kv_and_get_keys_round_trip() {
        let kv_dim = 8;
        let mut cache = KvCache::new(2, 16, kv_dim);

        let key: Vec<f32> = (0..kv_dim as i32).map(|i| i as f32 * 0.1).collect();
        let val: Vec<f32> = (0..kv_dim as i32).map(|i| i as f32 * -0.1).collect();

        cache.store_kv(0, &key, &val).expect("store_kv layer 0");
        cache.advance();

        let stored_keys = cache.get_keys(0).expect("get_keys layer 0");
        assert_eq!(stored_keys.len(), kv_dim, "should have kv_dim floats");
        for (i, (&got, &expected)) in stored_keys.iter().zip(key.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-7,
                "key[{i}]: got {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_store_kv_and_get_values_round_trip() {
        let kv_dim = 4;
        let mut cache = KvCache::new(1, 8, kv_dim);

        let key = vec![0.0f32; kv_dim];
        let val = vec![1.1f32, 2.2, 3.3, 4.4];

        cache.store_kv(0, &key, &val).expect("store_kv");
        cache.advance();

        let stored_vals = cache.get_values(0).expect("get_values");
        assert_eq!(stored_vals.len(), kv_dim);
        for (i, (&got, &expected)) in stored_vals.iter().zip(val.iter()).enumerate() {
            assert!(
                (got - expected).abs() < 1e-6,
                "value[{i}]: got {got}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_store_kv_accumulates_across_tokens() {
        let kv_dim = 2;
        let mut cache = KvCache::new(1, 8, kv_dim);

        for t in 0..3u32 {
            let key = vec![t as f32, t as f32 + 0.5];
            let val = vec![0.0f32; kv_dim];
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        let keys = cache.get_keys(0).expect("get_keys");
        assert_eq!(
            keys.len(),
            3 * kv_dim,
            "should have 3 tokens × kv_dim floats"
        );
        // Check every stored float, not just the first few.
        let expected = [0.0f32, 0.5, 1.0, 1.5, 2.0, 2.5];
        for (i, (&got, &want)) in keys.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() < 1e-7,
                "key[{i}]: got {got}, expected {want}"
            );
        }
    }

    #[test]
    fn test_store_kv_out_of_range_layer_returns_error() {
        let mut cache = KvCache::new(2, 8, 4);
        let key = vec![0.0f32; 4];
        let val = vec![0.0f32; 4];
        let result = cache.store_kv(99, &key, &val);
        assert!(result.is_err(), "out-of-range layer should return error");
    }

    #[test]
    fn test_get_keys_out_of_range_layer_returns_error() {
        let cache = KvCache::new(2, 8, 4);
        let result = cache.get_keys(99);
        assert!(result.is_err(), "out-of-range layer should return error");
    }

    #[test]
    fn test_get_values_out_of_range_layer_returns_error() {
        let cache = KvCache::new(2, 8, 4);
        let result = cache.get_values(99);
        assert!(result.is_err(), "out-of-range layer should return error");
    }

    #[test]
    fn test_store_kv_different_layers_independent() {
        let kv_dim = 4;
        let mut cache = KvCache::new(2, 8, kv_dim);

        let key0 = vec![1.0f32; kv_dim];
        let key1 = vec![2.0f32; kv_dim];
        let val0 = vec![3.0f32; kv_dim];
        let val1 = vec![4.0f32; kv_dim];

        cache.store_kv(0, &key0, &val0).expect("layer 0 store");
        cache.store_kv(1, &key1, &val1).expect("layer 1 store");
        cache.advance();

        let stored0 = cache.get_keys(0).expect("layer 0 keys");
        let stored1 = cache.get_keys(1).expect("layer 1 keys");

        for &v in stored0 {
            assert!((v - 1.0).abs() < 1e-7, "layer 0 key should be 1.0");
        }
        for &v in stored1 {
            assert!((v - 2.0).abs() < 1e-7, "layer 1 key should be 2.0");
        }
    }

    // ── truncate / snapshot / payload ───────────────────────────────────

    #[test]
    fn test_truncate_shortens_visible_history() {
        let kv_dim = 2;
        let mut cache = KvCache::new(1, 8, kv_dim);
        fill_tokens(&mut cache, 3);

        cache.truncate(1);
        assert_eq!(cache.seq_len(), 1);
        assert_eq!(cache.get_keys(0).expect("get_keys"), &[0.0f32, 1.0][..]);
        assert_eq!(cache.get_values(0).expect("get_values").len(), kv_dim);

        // Truncating past the current length must never grow the cache.
        cache.truncate(5);
        assert_eq!(cache.seq_len(), 1);
    }

    #[test]
    fn test_snapshot_restore_round_trip() {
        let mut cache = KvCache::new(2, 8, 3);
        fill_tokens(&mut cache, 2);

        let snap = cache.snapshot();
        assert_eq!(snap.seq_len, 2);
        assert_eq!(snap.keys.len(), 2);
        assert!(
            snap.keys.iter().chain(&snap.values).all(|buf| buf.len() == 2 * 3),
            "snapshot must hold exactly seq_len * kv_dim floats per layer"
        );

        let mut restored = KvCache::new(2, 8, 3);
        restored.restore_from_snapshot(&snap.keys, &snap.values, snap.seq_len);
        assert_eq!(restored.seq_len(), 2);
        for layer in 0..2 {
            assert_eq!(
                restored.get_keys(layer).expect("restored keys"),
                cache.get_keys(layer).expect("original keys")
            );
            assert_eq!(
                restored.get_values(layer).expect("restored values"),
                cache.get_values(layer).expect("original values")
            );
        }
    }

    #[test]
    fn test_restore_from_snapshot_clamps_to_capacity() {
        let mut cache = KvCache::new(1, 2, 2);
        let keys = vec![vec![1.0f32; 4]];
        let values = vec![vec![2.0f32; 4]];

        cache.restore_from_snapshot(&keys, &values, 5);

        assert_eq!(cache.seq_len(), 2, "seq_len must be clamped to max_seq_len");
        assert_eq!(cache.get_keys(0).expect("keys"), &[1.0f32; 4][..]);
        assert_eq!(cache.get_values(0).expect("values"), &[2.0f32; 4][..]);
    }

    #[test]
    fn test_payload_round_trip() {
        let mut cache = KvCache::new(2, 8, 4);
        fill_tokens(&mut cache, 3);
        let payload = cache.to_payload();

        let mut restored = KvCache::new(2, 8, 4);
        assert!(
            restored.restore_from_payload(&payload).is_ok(),
            "payload from an identically-shaped cache must restore"
        );
        assert_eq!(restored.seq_len(), 3);
        for layer in 0..2 {
            assert_eq!(
                restored.get_keys(layer).expect("restored keys"),
                cache.get_keys(layer).expect("original keys")
            );
            assert_eq!(
                restored.get_values(layer).expect("restored values"),
                cache.get_values(layer).expect("original values")
            );
        }
    }

    #[test]
    fn test_restore_from_payload_rejects_shape_mismatch() {
        let payload = KvCache::new(2, 8, 4).to_payload();
        assert!(
            KvCache::new(3, 8, 4).restore_from_payload(&payload).is_err(),
            "layer count mismatch must be rejected"
        );
        assert!(
            KvCache::new(2, 8, 5).restore_from_payload(&payload).is_err(),
            "kv_dim mismatch must be rejected"
        );
    }

    // ── for_each_key / for_each_value ───────────────────────────────────

    #[test]
    fn kv_cache_for_each_key_contiguous() {
        let kv_dim = 4usize;
        let mut cache = KvCache::new(1, 16, kv_dim);

        for t in 0..4u32 {
            let key: Vec<f32> = (0..kv_dim).map(|d| t as f32 * 10.0 + d as f32).collect();
            let val: Vec<f32> = (0..kv_dim).map(|d| t as f32 * 100.0 + d as f32).collect();
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        let mut positions_seen: Vec<usize> = Vec::new();
        let mut keys_seen: Vec<Vec<f32>> = Vec::new();
        cache
            .for_each_key(0, &mut |pos, slice| {
                positions_seen.push(pos);
                keys_seen.push(slice.to_vec());
            })
            .expect("for_each_key must succeed");

        assert_eq!(positions_seen.len(), 4, "must visit all 4 positions");
        assert_eq!(
            positions_seen,
            vec![0, 1, 2, 3],
            "positions must be in order"
        );

        for (t, key_row) in keys_seen.iter().enumerate() {
            assert_eq!(key_row.len(), kv_dim, "key row must have kv_dim elements");
            for (d, &v) in key_row.iter().enumerate() {
                let expected = t as f32 * 10.0 + d as f32;
                assert!(
                    (v - expected).abs() < 1e-6,
                    "token {t} dim {d}: expected {expected}, got {v}"
                );
            }
        }
    }

    #[test]
    fn kv_cache_for_each_value_contiguous() {
        let kv_dim = 3usize;
        let mut cache = KvCache::new(1, 8, kv_dim);

        for t in 0..3u32 {
            let key = vec![0.0f32; kv_dim];
            let val: Vec<f32> = (0..kv_dim).map(|d| t as f32 + d as f32 * 0.1).collect();
            cache.store_kv(0, &key, &val).expect("store_kv");
            cache.advance();
        }

        let mut count = 0usize;
        cache
            .for_each_value(0, &mut |_pos, slice| {
                assert_eq!(slice.len(), kv_dim);
                count += 1;
            })
            .expect("for_each_value must succeed");
        assert_eq!(count, 3, "must visit 3 value rows");
    }
}
645}