use candle_core::{Result, Tensor};
use super::NormalCache;
/// A single growable cache buffer (the error messages call it a "kv-cache")
/// that stores appended tensors contiguously along one dimension.
///
/// The backing tensor is allocated lazily by `append` and enlarged in
/// multiples of `NormalCache::CACHE_GROW_SIZE`, never beyond `max_seq_len`.
#[derive(Debug, Clone)]
pub struct SingleCache {
/// Backing buffer; `None` until the first `append` and again after `reset`.
pub all_data: Option<Tensor>,
/// Index of the tensor dimension along which data is appended and sliced.
pub dim: usize,
/// Number of positions along `dim` that currently hold valid data.
pub current_seq_len: usize,
/// Length along `dim` that the backing buffer is allocated with.
pub capacity_seq_len: usize,
/// Upper bound for `capacity_seq_len`; `append` fails rather than grow past it.
pub max_seq_len: usize,
}
impl SingleCache {
    /// Creates an empty cache that appends along dimension `dim`.
    ///
    /// Note the argument order: `max_seq_len` comes before `capacity_seq_len`.
    /// No tensor is allocated here; the buffer is created by the first
    /// [`append`](Self::append).
    pub fn new(dim: usize, max_seq_len: usize, capacity_seq_len: usize) -> Self {
        Self {
            all_data: None,
            dim,
            current_seq_len: 0,
            max_seq_len,
            capacity_seq_len,
        }
    }

    /// Returns the index of the dimension along which data is appended.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Returns the number of positions along `dim` currently holding data.
    pub fn current_seq_len(&self) -> usize {
        self.current_seq_len
    }

    /// Returns the maximum length the buffer is allowed to grow to.
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Returns the whole backing buffer (including unused capacity), if allocated.
    pub fn all_data(&self) -> Option<&Tensor> {
        self.all_data.as_ref()
    }

    /// Returns the first `current_seq_len` positions of the buffer along `dim`.
    ///
    /// Yields `Ok(None)` when no buffer has been allocated yet.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Tensor::narrow`.
    pub fn current_data(&self) -> Result<Option<Tensor>> {
        self.all_data
            .as_ref()
            .map(|data| data.narrow(self.dim, 0, self.current_seq_len))
            .transpose()
    }

    /// Drops the backing buffer and sets the current length back to zero.
    ///
    /// `capacity_seq_len` is left untouched, so the next `append` allocates a
    /// buffer with whatever capacity was last recorded.
    pub fn reset(&mut self) {
        self.current_seq_len = 0;
        self.all_data = None;
    }

    /// Checks whether `len` could be used as the current length.
    ///
    /// # Errors
    ///
    /// Fails when `len` exceeds `capacity_seq_len` (the allocated capacity,
    /// not `max_seq_len`).
    pub fn try_set_len(&self, len: usize) -> Result<()> {
        if len > self.capacity_seq_len {
            candle_core::bail!(
                "kv-cache: requested length ({}) exceeds current capacity ({})",
                len,
                self.capacity_seq_len
            );
        }
        Ok(())
    }

    /// Sets the current length to `len` without touching the buffer contents.
    ///
    /// # Errors
    ///
    /// Fails under the same condition as [`try_set_len`](Self::try_set_len);
    /// the length is left unchanged in that case.
    pub fn set_len(&mut self, len: usize) -> Result<()> {
        self.try_set_len(len)?;
        self.current_seq_len = len;
        Ok(())
    }

    /// Writes `src` into the buffer at offset `current_seq_len` along `dim`
    /// and advances the current length by `src`'s size in that dimension.
    ///
    /// The buffer is allocated on first use and, when `src` does not fit, is
    /// enlarged by whole multiples of `NormalCache::CACHE_GROW_SIZE`; the
    /// previous buffer is copied to the start of the enlarged one.
    ///
    /// # Errors
    ///
    /// Fails when the enlarged capacity would exceed `max_seq_len`, or when a
    /// tensor operation (`dim`, `zeros`, `slice_set`) fails. A capacity
    /// rejection or allocation failure leaves the cache unchanged.
    pub fn append(&mut self, src: &Tensor) -> Result<()> {
        let seq_len = src.dim(self.dim)?;
        let required_len = self.current_seq_len + seq_len;

        // Work out the target capacity in a local so that a rejected request
        // does not leave `capacity_seq_len` out of sync with the real buffer.
        let mut new_capacity = self.capacity_seq_len;
        if required_len > new_capacity {
            let shortfall = required_len - new_capacity;
            let n_blocks_needed = shortfall.div_ceil(NormalCache::CACHE_GROW_SIZE);
            new_capacity += n_blocks_needed * NormalCache::CACHE_GROW_SIZE;
            if new_capacity > self.max_seq_len {
                candle_core::bail!(
                    "kv-cache: requested capacity ({}) above max seq len ({})",
                    new_capacity,
                    self.max_seq_len
                );
            }
        }

        // Allocate once, directly at the final capacity, both for the very
        // first append and for growth of an existing buffer.
        if self.all_data.is_none() || new_capacity > self.capacity_seq_len {
            let mut shape = src.dims().to_vec();
            shape[self.dim] = new_capacity;
            let fresh = Tensor::zeros(shape, src.dtype(), src.device())?;
            if let Some(previous) = self.all_data.as_ref() {
                fresh.slice_set(previous, self.dim, 0)?;
            }
            // Commit only after allocation and copy have both succeeded.
            self.all_data = Some(fresh);
            self.capacity_seq_len = new_capacity;
        }

        let buffer = self
            .all_data
            .as_ref()
            .expect("kv-cache buffer is allocated above");
        buffer.slice_set(src, self.dim, self.current_seq_len)?;
        self.current_seq_len += seq_len;
        Ok(())
    }
}