use oxillama_arch::traits::KvCacheAccess;
use oxillama_arch::ArchResult;
/// Number of token positions stored in a single page.
const PAGE_SIZE: usize = 16;

/// Fixed-size storage for `PAGE_SIZE` token vectors of `kv_dim` floats each.
///
/// Slot `i` occupies `data[i * kv_dim..(i + 1) * kv_dim]`.
struct Page {
    data: Vec<f32>,
}

impl Page {
    /// Allocates a zero-filled page with room for `PAGE_SIZE` tokens of width `kv_dim`.
    fn new(kv_dim: usize) -> Self {
        Self {
            data: vec![0.0f32; PAGE_SIZE * kv_dim],
        }
    }

    /// Copies the first `kv_dim` elements of `src` into slot `slot`.
    ///
    /// # Panics
    /// Panics if `slot >= PAGE_SIZE` or `src.len() < kv_dim`.
    fn write_token(&mut self, slot: usize, kv_dim: usize, src: &[f32]) {
        let offset = slot * kv_dim;
        self.data[offset..offset + kv_dim].copy_from_slice(&src[..kv_dim]);
    }

    /// Returns the `kv_dim`-wide vector stored in slot `slot`.
    ///
    /// # Panics
    /// Panics if `slot >= PAGE_SIZE`.
    fn read_token(&self, slot: usize, kv_dim: usize) -> &[f32] {
        let offset = slot * kv_dim;
        &self.data[offset..offset + kv_dim]
    }
}

/// Page lists for one layer.
///
/// `key_pages` and `value_pages` are always grown, truncated and cleared together, so
/// they have the same length.
struct LayerCache {
    key_pages: Vec<Page>,
    value_pages: Vec<Page>,
}

impl LayerCache {
    /// Creates a layer with no pages allocated.
    fn new() -> Self {
        Self {
            key_pages: Vec::new(),
            value_pages: Vec::new(),
        }
    }

    /// Appends zeroed key/value pages until position `token_pos` has a backing page.
    fn ensure_capacity(&mut self, token_pos: usize, kv_dim: usize) {
        let needed_pages = token_pos / PAGE_SIZE + 1;
        while self.key_pages.len() < needed_pages {
            self.key_pages.push(Page::new(kv_dim));
            self.value_pages.push(Page::new(kv_dim));
        }
    }

    /// Writes `key`/`value` at position `token_pos`, allocating pages as needed.
    ///
    /// # Panics
    /// Panics if `key` or `value` is shorter than `kv_dim`.
    fn store(&mut self, token_pos: usize, kv_dim: usize, key: &[f32], value: &[f32]) {
        self.ensure_capacity(token_pos, kv_dim);
        let page_idx = token_pos / PAGE_SIZE;
        let slot = token_pos % PAGE_SIZE;
        self.key_pages[page_idx].write_token(slot, kv_dim, key);
        self.value_pages[page_idx].write_token(slot, kv_dim, value);
    }

    /// Number of allocated key pages (equal to the number of value pages).
    fn num_pages(&self) -> usize {
        self.key_pages.len()
    }

    /// Drops every page that is not needed to hold the first `seq_len` positions.
    fn shrink_to(&mut self, seq_len: usize) {
        // Ceiling division gives exactly enough pages: a `seq_len` that is a multiple
        // of PAGE_SIZE fills its last page completely and needs no spare page.
        let needed = seq_len.div_ceil(PAGE_SIZE);
        self.key_pages.truncate(needed);
        self.value_pages.truncate(needed);
    }
}
/// A KV cache that allocates per-layer storage lazily, in pages of `PAGE_SIZE` tokens,
/// instead of reserving `max_seq_len` slots up front.
///
/// `seq_len` is shared by all layers: `store_kv` writes position `seq_len` of one layer,
/// and `advance` moves every layer on to the next position.
pub struct PagedKvCache {
    /// One page list per layer.
    layers: Vec<LayerCache>,
    /// Number of positions committed via `advance` (common to all layers).
    seq_len: usize,
    /// Upper bound on `seq_len`.
    max_seq_len: usize,
    /// Number of `f32`s in each key vector and each value vector.
    kv_dim: usize,
    /// Number of layers; equal to `layers.len()`.
    num_layers: usize,
}

impl PagedKvCache {
    /// Creates an empty cache. No pages are allocated until data is stored.
    pub fn new(num_layers: usize, max_seq_len: usize, kv_dim: usize) -> Self {
        let layers = (0..num_layers).map(|_| LayerCache::new()).collect();
        Self {
            layers,
            seq_len: 0,
            max_seq_len,
            kv_dim,
            num_layers,
        }
    }

    /// Number of token positions held by one page.
    pub fn page_size(&self) -> usize {
        PAGE_SIZE
    }

    /// Maximum number of positions this cache accepts.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Width (in `f32`s) of each stored key or value vector.
    pub fn kv_dim(&self) -> usize {
        self.kv_dim
    }

    /// Number of layers in the cache.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Number of key pages allocated across all layers (each has a matching value page).
    pub fn total_pages(&self) -> usize {
        self.layers.iter().map(|l| l.num_pages()).sum()
    }

    /// Bytes of `f32` payload held by all allocated key and value pages.
    pub fn memory_bytes(&self) -> usize {
        // `total_pages` counts key pages only; each has a value twin, hence `* 2`.
        self.total_pages() * PAGE_SIZE * self.kv_dim * std::mem::size_of::<f32>() * 2
    }

    /// Drops every page in every layer and resets `seq_len` to 0.
    pub fn clear(&mut self) {
        self.seq_len = 0;
        for layer in &mut self.layers {
            layer.key_pages.clear();
            layer.value_pages.clear();
        }
    }

    /// Truncates each layer's page lists with `shrink_to(seq_len)`, releasing pages
    /// beyond the current sequence.
    pub fn shrink_to_fit(&mut self) {
        for layer in &mut self.layers {
            layer.shrink_to(self.seq_len);
        }
    }

    /// Copies the first `seq_len` key vectors of `layer` into `buf` (cleared first).
    ///
    /// # Panics
    /// Panics if `layer` is out of range or its pages cover fewer than `seq_len` tokens.
    fn assemble_keys(&self, layer: usize, buf: &mut Vec<f32>) {
        let pages = self.layers[layer].key_pages.iter().map(|p| p.data.as_slice());
        self.assemble_into(pages, buf);
    }

    /// Copies the first `seq_len` value vectors of `layer` into `buf` (cleared first).
    ///
    /// # Panics
    /// Panics if `layer` is out of range or its pages cover fewer than `seq_len` tokens.
    fn assemble_values(&self, layer: usize, buf: &mut Vec<f32>) {
        let pages = self.layers[layer].value_pages.iter().map(|p| p.data.as_slice());
        self.assemble_into(pages, buf);
    }

    /// Concatenates the first `seq_len * kv_dim` floats spread over `pages` into `buf`.
    ///
    /// Token `pos` lives in page `pos / PAGE_SIZE`, slot `pos % PAGE_SIZE`. The first
    /// `seq_len` tokens are therefore whole pages followed by a prefix of the last one,
    /// which allows one bulk copy per page instead of one per token.
    fn assemble_into<'a>(&self, pages: impl Iterator<Item = &'a [f32]>, buf: &mut Vec<f32>) {
        let mut remaining = self.seq_len * self.kv_dim;
        buf.clear();
        buf.reserve(remaining);
        for page in pages {
            if remaining == 0 {
                break;
            }
            let take = remaining.min(page.len());
            buf.extend_from_slice(&page[..take]);
            remaining -= take;
        }
        assert_eq!(
            remaining, 0,
            "layer pages hold fewer than seq_len ({}) tokens",
            self.seq_len
        );
    }
}
/// Builds the error reported when `layer` is not below `num_layers`.
fn layer_out_of_range(layer: usize, num_layers: usize) -> oxillama_arch::ArchError {
    oxillama_arch::ArchError::ForwardPassError {
        layer,
        message: format!("layer index {layer} out of range (num_layers {num_layers})"),
    }
}

/// Returns the first `seq_len` tokens (`seq_len * kv_dim` floats) of a layer as one
/// borrowed slice of its first page.
///
/// A borrowed contiguous view exists only while every cached token lives in the first
/// page. Once the sequence spans several pages, this returns an error naming
/// `into_method`, the copying accessor callers should use instead.
fn first_page_prefix<'a>(
    first_page: Option<&'a [f32]>,
    layer: usize,
    seq_len: usize,
    kv_dim: usize,
    into_method: &str,
) -> ArchResult<&'a [f32]> {
    if seq_len == 0 {
        return Ok(&[]);
    }
    let pages_used = seq_len.div_ceil(PAGE_SIZE);
    if pages_used > 1 {
        return Err(oxillama_arch::ArchError::ForwardPassError {
            layer,
            message: format!(
                "paged KV cache: sequence length {seq_len} spans {pages_used} pages; \
                 use {into_method}() for multi-page access"
            ),
        });
    }
    // `seq_len` is shared by all layers, so a layer that was never written can have no
    // pages while `seq_len > 0`. Report that instead of panicking on an index.
    let page = first_page.ok_or_else(|| oxillama_arch::ArchError::ForwardPassError {
        layer,
        message: format!("layer {layer} has no stored KV data for sequence length {seq_len}"),
    })?;
    Ok(&page[..seq_len * kv_dim])
}

impl KvCacheAccess for PagedKvCache {
    /// Number of positions committed via `advance`.
    fn seq_len(&self) -> usize {
        self.seq_len
    }

    /// Writes `key`/`value` for `layer` at the current position (`seq_len`).
    ///
    /// The position only moves forward on `advance`, so all layers written before the
    /// next `advance` share the same position. Only the first `kv_dim` elements of each
    /// slice are copied.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range, the cache already holds
    /// `max_seq_len` positions, or `key`/`value` is shorter than `kv_dim`.
    fn store_kv(&mut self, layer: usize, key: &[f32], value: &[f32]) -> ArchResult<()> {
        if layer >= self.num_layers {
            return Err(layer_out_of_range(layer, self.num_layers));
        }
        if self.seq_len >= self.max_seq_len {
            return Err(oxillama_arch::ArchError::ForwardPassError {
                layer,
                message: format!(
                    "sequence length {} exceeds max {}",
                    self.seq_len, self.max_seq_len
                ),
            });
        }
        // The page write slices `src[..kv_dim]`; reject short inputs instead of panicking.
        if key.len() < self.kv_dim || value.len() < self.kv_dim {
            return Err(oxillama_arch::ArchError::ForwardPassError {
                layer,
                message: format!(
                    "key/value lengths ({}, {}) are shorter than kv_dim {}",
                    key.len(),
                    value.len(),
                    self.kv_dim
                ),
            });
        }
        self.layers[layer].store(self.seq_len, self.kv_dim, key, value);
        Ok(())
    }

    /// Borrows all cached keys of `layer` as one contiguous slice.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range, if the layer holds no data while
    /// `seq_len > 0`, or if the sequence spans more than one page (use `get_keys_into`).
    fn get_keys(&self, layer: usize) -> ArchResult<&[f32]> {
        if layer >= self.num_layers {
            return Err(layer_out_of_range(layer, self.num_layers));
        }
        let first = self.layers[layer]
            .key_pages
            .first()
            .map(|p| p.data.as_slice());
        first_page_prefix(first, layer, self.seq_len, self.kv_dim, "get_keys_into")
    }

    /// Borrows all cached values of `layer` as one contiguous slice.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range, if the layer holds no data while
    /// `seq_len > 0`, or if the sequence spans more than one page (use `get_values_into`).
    fn get_values(&self, layer: usize) -> ArchResult<&[f32]> {
        if layer >= self.num_layers {
            return Err(layer_out_of_range(layer, self.num_layers));
        }
        let first = self.layers[layer]
            .value_pages
            .first()
            .map(|p| p.data.as_slice());
        first_page_prefix(first, layer, self.seq_len, self.kv_dim, "get_values_into")
    }

    /// Commits the current position. Does nothing once `max_seq_len` is reached.
    fn advance(&mut self) {
        if self.seq_len < self.max_seq_len {
            self.seq_len += 1;
        }
    }
}
impl PagedKvCache {
    /// Returns an error if `layer` is not a valid layer index.
    fn check_layer(&self, layer: usize) -> ArchResult<()> {
        if layer >= self.num_layers {
            return Err(oxillama_arch::ArchError::ForwardPassError {
                layer,
                message: format!(
                    "layer index {layer} out of range (num_layers {})",
                    self.num_layers
                ),
            });
        }
        Ok(())
    }

    /// Returns an error unless `layer` (already range-checked) has pages backing its
    /// first `tokens` positions.
    ///
    /// `seq_len` is shared by all layers, so a layer that was not written by `store_kv`
    /// can lag behind it. Reading such a layer would otherwise index past its page list.
    fn check_pages(&self, layer: usize, tokens: usize) -> ArchResult<()> {
        let needed = tokens.div_ceil(PAGE_SIZE);
        let available = self.layers[layer].num_pages();
        if available < needed {
            return Err(oxillama_arch::ArchError::ForwardPassError {
                layer,
                message: format!(
                    "layer {layer} has {available} pages but {tokens} tokens need {needed}"
                ),
            });
        }
        Ok(())
    }

    /// Validates `layer` and `pos` and returns the `(page index, slot)` holding `pos`.
    fn locate(&self, layer: usize, pos: usize) -> ArchResult<(usize, usize)> {
        self.check_layer(layer)?;
        if pos >= self.seq_len {
            return Err(oxillama_arch::ArchError::ForwardPassError {
                layer,
                message: format!("position {pos} out of range (seq_len {})", self.seq_len),
            });
        }
        self.check_pages(layer, pos + 1)?;
        Ok((pos / PAGE_SIZE, pos % PAGE_SIZE))
    }

    /// Copies all cached keys of `layer` into `buf` (cleared first) as `seq_len` rows of
    /// `kv_dim` floats in position order. Works for any number of pages.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range or its pages do not cover `seq_len`
    /// tokens; `buf` is left untouched in that case.
    pub fn get_keys_into(&self, layer: usize, buf: &mut Vec<f32>) -> ArchResult<()> {
        self.check_layer(layer)?;
        self.check_pages(layer, self.seq_len)?;
        self.assemble_keys(layer, buf);
        Ok(())
    }

    /// Copies all cached values of `layer` into `buf` (cleared first) as `seq_len` rows
    /// of `kv_dim` floats in position order. Works for any number of pages.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range or its pages do not cover `seq_len`
    /// tokens; `buf` is left untouched in that case.
    pub fn get_values_into(&self, layer: usize, buf: &mut Vec<f32>) -> ArchResult<()> {
        self.check_layer(layer)?;
        self.check_pages(layer, self.seq_len)?;
        self.assemble_values(layer, buf);
        Ok(())
    }

    /// Borrows the key vector stored at position `pos` of `layer`.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range, `pos >= seq_len`, or the layer has
    /// no page for `pos`.
    pub fn get_key_token(&self, layer: usize, pos: usize) -> ArchResult<&[f32]> {
        let (page_idx, slot) = self.locate(layer, pos)?;
        Ok(self.layers[layer].key_pages[page_idx].read_token(slot, self.kv_dim))
    }

    /// Borrows the value vector stored at position `pos` of `layer`.
    ///
    /// # Errors
    /// Returns an error if `layer` is out of range, `pos >= seq_len`, or the layer has
    /// no page for `pos`.
    pub fn get_value_token(&self, layer: usize, pos: usize) -> ArchResult<&[f32]> {
        let (page_idx, slot) = self.locate(layer, pos)?;
        Ok(self.layers[layer].value_pages[page_idx].read_token(slot, self.kv_dim))
    }

    /// Calls `f(pos, key)` for every cached position of `layer`, in order, without copying.
    ///
    /// # Errors
    /// Returns an error before calling `f` if `layer` is out of range or its pages do not
    /// cover `seq_len` tokens.
    pub fn iter_keys<F>(&self, layer: usize, mut f: F) -> ArchResult<()>
    where
        F: FnMut(usize, &[f32]),
    {
        self.check_layer(layer)?;
        self.check_pages(layer, self.seq_len)?;
        let pages = &self.layers[layer].key_pages;
        for pos in 0..self.seq_len {
            f(pos, pages[pos / PAGE_SIZE].read_token(pos % PAGE_SIZE, self.kv_dim));
        }
        Ok(())
    }

    /// Calls `f(pos, value)` for every cached position of `layer`, in order, without copying.
    ///
    /// # Errors
    /// Returns an error before calling `f` if `layer` is out of range or its pages do not
    /// cover `seq_len` tokens.
    pub fn iter_values<F>(&self, layer: usize, mut f: F) -> ArchResult<()>
    where
        F: FnMut(usize, &[f32]),
    {
        self.check_layer(layer)?;
        self.check_pages(layer, self.seq_len)?;
        let pages = &self.layers[layer].value_pages;
        for pos in 0..self.seq_len {
            f(pos, pages[pos / PAGE_SIZE].read_token(pos % PAGE_SIZE, self.kv_dim));
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_paged_basic_store_retrieve() {
        let mut cache = PagedKvCache::new(2, 64, 4);
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.total_pages(), 0);
        let key = [1.0, 2.0, 3.0, 4.0];
        let val = [5.0, 6.0, 7.0, 8.0];
        cache.store_kv(0, &key, &val).unwrap();
        cache.advance();
        assert_eq!(cache.seq_len(), 1);
        assert_eq!(cache.layers[0].num_pages(), 1);
        assert_eq!(cache.layers[1].num_pages(), 0);
        let keys = cache.get_keys(0).unwrap();
        assert_eq!(keys, &[1.0, 2.0, 3.0, 4.0]);
        let vals = cache.get_values(0).unwrap();
        assert_eq!(vals, &[5.0, 6.0, 7.0, 8.0]);
    }

    #[test]
    fn test_paged_multi_token_single_page() {
        let mut cache = PagedKvCache::new(1, 64, 2);
        for i in 0..PAGE_SIZE {
            let key = [i as f32, (i * 10) as f32];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }
        assert_eq!(cache.seq_len(), PAGE_SIZE);
        assert_eq!(cache.layers[0].num_pages(), 1);
        let keys = cache.get_keys(0).unwrap();
        assert_eq!(keys.len(), PAGE_SIZE * 2);
        assert_eq!(keys[0], 0.0);
        assert_eq!(keys[1], 0.0);
        assert_eq!(keys[2], 1.0);
        assert_eq!(keys[3], 10.0);
    }

    #[test]
    fn test_paged_multi_page_assembly() {
        let mut cache = PagedKvCache::new(1, 64, 2);
        for i in 0..=PAGE_SIZE {
            let key = [i as f32, (i * 10) as f32];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }
        assert_eq!(cache.seq_len(), PAGE_SIZE + 1);
        assert_eq!(cache.layers[0].num_pages(), 2);
        assert!(cache.get_keys(0).is_err());
        let mut buf = Vec::new();
        cache.get_keys_into(0, &mut buf).unwrap();
        assert_eq!(buf.len(), (PAGE_SIZE + 1) * 2);
        assert_eq!(buf[0], 0.0);
        assert_eq!(buf[1], 0.0);
        let last_off = PAGE_SIZE * 2;
        assert_eq!(buf[last_off], PAGE_SIZE as f32);
        assert_eq!(buf[last_off + 1], (PAGE_SIZE * 10) as f32);
    }

    #[test]
    fn test_paged_values_multi_page_assembly() {
        let mut cache = PagedKvCache::new(1, 64, 2);
        for i in 0..=PAGE_SIZE {
            let key = [i as f32, 0.0];
            let val = [(i + 100) as f32, (i + 200) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }
        assert!(cache.get_values(0).is_err());
        // Stale buffer contents must be discarded by the copy.
        let mut buf = vec![-1.0; 3];
        cache.get_values_into(0, &mut buf).unwrap();
        assert_eq!(buf.len(), (PAGE_SIZE + 1) * 2);
        for (pos, row) in buf.chunks_exact(2).enumerate() {
            assert_eq!(row, &[(pos + 100) as f32, (pos + 200) as f32]);
        }
    }

    #[test]
    fn test_paged_per_token_access() {
        let mut cache = PagedKvCache::new(1, 64, 3);
        for i in 0..20 {
            let key = [i as f32, (i * 2) as f32, (i * 3) as f32];
            let val = [(i + 50) as f32, (i + 60) as f32, (i + 70) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }
        let k5 = cache.get_key_token(0, 5).unwrap();
        assert_eq!(k5, &[5.0, 10.0, 15.0]);
        let v17 = cache.get_value_token(0, 17).unwrap();
        assert_eq!(v17, &[67.0, 77.0, 87.0]);
        assert!(cache.get_key_token(0, 20).is_err());
    }

    #[test]
    fn test_paged_iteration() {
        let mut cache = PagedKvCache::new(1, 64, 2);
        for i in 0..20 {
            let key = [i as f32, (i + 1) as f32];
            let val = [(i + 100) as f32, (i + 101) as f32];
            cache.store_kv(0, &key, &val).unwrap();
            cache.advance();
        }
        let mut count = 0;
        cache
            .iter_keys(0, |pos, data| {
                assert_eq!(data[0], pos as f32);
                assert_eq!(data[1], (pos + 1) as f32);
                count += 1;
            })
            .unwrap();
        assert_eq!(count, 20);
    }

    #[test]
    fn test_paged_iter_values_matches_per_token_access() {
        let mut cache = PagedKvCache::new(1, 64, 2);
        for i in 0..20 {
            cache
                .store_kv(0, &[i as f32, 0.0], &[(i * 3) as f32, 1.0])
                .unwrap();
            cache.advance();
        }
        let mut seen = Vec::new();
        cache
            .iter_values(0, |pos, data| {
                assert_eq!(data, cache.get_value_token(0, pos).unwrap());
                assert_eq!(data, &[(pos * 3) as f32, 1.0]);
                seen.push(pos);
            })
            .unwrap();
        assert_eq!(seen, (0..20).collect::<Vec<usize>>());
    }

    #[test]
    fn test_paged_empty_and_out_of_range() {
        let mut cache = PagedKvCache::new(2, 8, 4);
        assert!(cache.get_keys(0).unwrap().is_empty());
        assert!(cache.get_values(1).unwrap().is_empty());
        let mut buf = vec![1.0];
        cache.get_keys_into(0, &mut buf).unwrap();
        assert!(buf.is_empty());
        assert!(cache.get_keys(2).is_err());
        assert!(cache.get_values_into(2, &mut buf).is_err());
        assert!(cache.store_kv(2, &[0.0; 4], &[0.0; 4]).is_err());
        assert!(cache.iter_keys(2, |_, _| {}).is_err());
        assert_eq!(cache.total_pages(), 0);
    }

    #[test]
    fn test_paged_clear() {
        let mut cache = PagedKvCache::new(2, 64, 4);
        for i in 0..20 {
            let key = [i as f32; 4];
            let val = [i as f32; 4];
            cache.store_kv(0, &key, &val).unwrap();
            cache.store_kv(1, &key, &val).unwrap();
            cache.advance();
        }
        assert!(cache.total_pages() > 0);
        cache.clear();
        assert_eq!(cache.seq_len(), 0);
        assert_eq!(cache.total_pages(), 0);
    }

    #[test]
    fn test_paged_memory_bytes_and_clear_reuse() {
        let mut cache = PagedKvCache::new(1, 64, 4);
        cache.store_kv(0, &[1.0; 4], &[2.0; 4]).unwrap();
        cache.advance();
        // One key page plus one value page, each PAGE_SIZE * kv_dim f32s of 4 bytes.
        assert_eq!(cache.memory_bytes(), 2 * PAGE_SIZE * 4 * 4);
        cache.clear();
        assert_eq!(cache.memory_bytes(), 0);
        cache.store_kv(0, &[3.0; 4], &[4.0; 4]).unwrap();
        cache.advance();
        assert_eq!(cache.get_keys(0).unwrap(), &[3.0; 4]);
        assert_eq!(cache.get_values(0).unwrap(), &[4.0; 4]);
    }

    #[test]
    fn test_paged_shrink_to_fit() {
        let mut cache = PagedKvCache::new(1, 128, 4);
        for i in 0..40 {
            cache.store_kv(0, &[i as f32; 4], &[i as f32; 4]).unwrap();
            cache.advance();
        }
        assert_eq!(cache.layers[0].num_pages(), 3);
        cache.seq_len = 10;
        cache.shrink_to_fit();
        assert_eq!(cache.layers[0].num_pages(), 1);
    }

    #[test]
    fn test_paged_memory_efficiency() {
        let num_layers = 32;
        let max_seq = 4096;
        let kv_dim = 128;
        let contiguous_bytes = num_layers * max_seq * kv_dim * 4 * 2;
        let mut cache = PagedKvCache::new(num_layers, max_seq, kv_dim);
        for i in 0..10 {
            for layer in 0..num_layers {
                cache
                    .store_kv(layer, &vec![i as f32; kv_dim], &vec![i as f32; kv_dim])
                    .unwrap();
            }
            cache.advance();
        }
        let paged_bytes = cache.memory_bytes();
        assert!(
            paged_bytes < contiguous_bytes / 10,
            "paged={paged_bytes} should be << contiguous={contiguous_bytes}"
        );
    }

    #[test]
    fn test_paged_max_seq_len_error() {
        let mut cache = PagedKvCache::new(1, 2, 2);
        cache.store_kv(0, &[1.0, 2.0], &[3.0, 4.0]).unwrap();
        cache.advance();
        cache.store_kv(0, &[5.0, 6.0], &[7.0, 8.0]).unwrap();
        cache.advance();
        let result = cache.store_kv(0, &[9.0, 10.0], &[11.0, 12.0]);
        assert!(result.is_err());
    }
}