use std::fmt;
use crate::error::RuntimeError;
use crate::tensor::Tensor;
/// Number of token slots in every `KvBlock`; each block's buffer holds
/// `BLOCK_TOKEN_COUNT * dim` values.
const BLOCK_TOKEN_COUNT: usize = 16;
/// Fixed-capacity storage for up to `BLOCK_TOKEN_COUNT` token vectors that
/// all share the same width.
#[derive(Debug, Clone)]
pub struct KvBlock {
/// Flat buffer of `BLOCK_TOKEN_COUNT * dim` values; token `i` occupies
/// `data[i * dim..(i + 1) * dim]`.
data: Vec<f64>,
/// Number of `f64` values in a single token.
dim: usize,
/// Number of token slots written so far (tokens are written in order).
used: usize,
}
impl KvBlock {
    /// Builds an empty block whose zero-filled buffer has room for
    /// `BLOCK_TOKEN_COUNT` tokens of `dim` values each.
    fn new(dim: usize) -> Self {
        let capacity = BLOCK_TOKEN_COUNT * dim;
        Self {
            data: vec![0.0; capacity],
            dim,
            used: 0,
        }
    }

    /// Reports whether every token slot of the block has been written.
    fn is_full(&self) -> bool {
        BLOCK_TOKEN_COUNT <= self.used
    }

    /// Number of token slots that are still free.
    // NOTE(review): no caller is visible in this file; confirm the helper is
    // still needed before keeping the `dead_code` allowance.
    #[allow(dead_code)]
    fn remaining(&self) -> usize {
        BLOCK_TOKEN_COUNT - self.used
    }

    /// Copies `token` into the next free slot and marks that slot as used.
    ///
    /// # Errors
    ///
    /// * `RuntimeError::ShapeMismatch` when `token.len()` differs from the
    ///   block's `dim`.
    /// * `RuntimeError::InvalidOperation` when the block has no free slot.
    fn write_token(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
        let width = self.dim;
        if token.len() != width {
            return Err(RuntimeError::ShapeMismatch {
                expected: width,
                got: token.len(),
            });
        }
        if self.is_full() {
            return Err(RuntimeError::InvalidOperation(String::from(
                "KvBlock is full",
            )));
        }

        // Slots are filled in order, so the next free slot starts at `used * dim`.
        let start = self.used * width;
        let slot = &mut self.data[start..start + width];
        slot.copy_from_slice(token);
        self.used += 1;
        Ok(())
    }

    /// Borrows the values stored in slot `idx`.
    ///
    /// `idx` is not compared against `used`; a slot that was never written
    /// yields whatever the buffer currently holds there.
    ///
    /// # Panics
    ///
    /// Panics if the slot lies outside the block's buffer.
    fn read_token(&self, idx: usize) -> &[f64] {
        let start = idx * self.dim;
        let end = start + self.dim;
        &self.data[start..end]
    }
}
/// Token cache that stores fixed-width token vectors in `KvBlock`s of
/// `BLOCK_TOKEN_COUNT` tokens each, addressed through a block table.
#[derive(Debug, Clone)]
pub struct PagedKvCache {
/// Physical blocks; all of them are allocated by `new`.
blocks: Vec<KvBlock>,
/// Maps a logical block index to an index into `blocks`; `new` fills it
/// with the identity mapping.
block_table: Vec<usize>,
/// Number of `f64` values in a single token.
dim: usize,
/// Maximum number of tokens the cache accepts.
max_tokens: usize,
/// Number of tokens currently stored.
current_len: usize,
}
impl PagedKvCache {
pub fn new(max_tokens: usize, dim: usize) -> Self {
let num_blocks = (max_tokens + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT;
let mut blocks = Vec::with_capacity(num_blocks);
let mut block_table = Vec::with_capacity(num_blocks);
for i in 0..num_blocks {
blocks.push(KvBlock::new(dim));
block_table.push(i); }
PagedKvCache {
blocks,
block_table,
dim,
max_tokens,
current_len: 0,
}
}
pub fn len(&self) -> usize {
self.current_len
}
pub fn is_empty(&self) -> bool {
self.current_len == 0
}
pub fn max_tokens(&self) -> usize {
self.max_tokens
}
pub fn dim(&self) -> usize {
self.dim
}
pub fn num_blocks(&self) -> usize {
self.blocks.len()
}
pub fn blocks_in_use(&self) -> usize {
if self.current_len == 0 { return 0; }
(self.current_len + BLOCK_TOKEN_COUNT - 1) / BLOCK_TOKEN_COUNT
}
pub fn append(&mut self, token: &[f64]) -> Result<(), RuntimeError> {
if token.len() != self.dim {
return Err(RuntimeError::ShapeMismatch {
expected: self.dim,
got: token.len(),
});
}
if self.current_len >= self.max_tokens {
return Err(RuntimeError::InvalidOperation(
format!(
"PagedKvCache full: {} / {} tokens",
self.current_len, self.max_tokens
),
));
}
let logical_block = self.current_len / BLOCK_TOKEN_COUNT;
let physical_block = self.block_table[logical_block];
self.blocks[physical_block].write_token(token)?;
self.current_len += 1;
Ok(())
}
pub fn append_tensor(&mut self, t: &Tensor) -> Result<(), RuntimeError> {
if t.ndim() != 2 || t.shape()[1] != self.dim {
return Err(RuntimeError::InvalidOperation(
format!(
"PagedKvCache.append_tensor: expected [n, {}], got {:?}",
self.dim, t.shape()
),
));
}
let n = t.shape()[0];
if self.current_len + n > self.max_tokens {
return Err(RuntimeError::InvalidOperation(
format!(
"PagedKvCache overflow: {} + {} > {}",
self.current_len, n, self.max_tokens
),
));
}
let data = t.to_vec();
for i in 0..n {
let start = i * self.dim;
self.append(&data[start..start + self.dim])?;
}
Ok(())
}
pub fn as_tensor(&self) -> Tensor {
if self.current_len == 0 {
return Tensor::from_vec(vec![], &[0, self.dim])
.unwrap_or_else(|_| Tensor::zeros(&[0]));
}
let mut data = Vec::with_capacity(self.current_len * self.dim);
let mut remaining = self.current_len;
for &phys_idx in &self.block_table {
if remaining == 0 { break; }
let block = &self.blocks[phys_idx];
let tokens_in_block = remaining.min(block.used);
for t in 0..tokens_in_block {
data.extend_from_slice(block.read_token(t));
}
remaining -= tokens_in_block;
}
Tensor::from_vec(data, &[self.current_len, self.dim])
.expect("PagedKvCache::as_tensor shape mismatch")
}
pub fn clear(&mut self) {
for block in &mut self.blocks {
block.used = 0;
}
self.current_len = 0;
}
pub fn get_token(&self, idx: usize) -> Result<Vec<f64>, RuntimeError> {
if idx >= self.current_len {
return Err(RuntimeError::IndexOutOfBounds {
index: idx,
length: self.current_len,
});
}
let logical_block = idx / BLOCK_TOKEN_COUNT;
let offset_in_block = idx % BLOCK_TOKEN_COUNT;
let physical_block = self.block_table[logical_block];
Ok(self.blocks[physical_block].read_token(offset_in_block).to_vec())
}
}
impl fmt::Display for PagedKvCache {
    /// Writes a one-line summary: stored token count, token limit, token
    /// width, and blocks in use out of blocks allocated.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let in_use = self.blocks_in_use();
        let allocated = self.blocks.len();
        write!(
            f,
            "PagedKvCache(len={}, max={}, dim={}, blocks={}/{})",
            self.current_len, self.max_tokens, self.dim, in_use, allocated
        )
    }
}
impl fmt::Display for Tensor {
    /// Writes the tensor as `Tensor(shape=[..], data=[..])`, using the `Debug`
    /// form of its shape and of the values returned by `to_vec()`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Go through the `shape()` accessor, like the rest of this file, rather
        // than reaching into a field of a type declared in another module.
        write!(
            f,
            "Tensor(shape={:?}, data={:?})",
            self.shape(),
            self.to_vec()
        )
    }
}