use candle_core::Tensor;
/// A key/value cache whose backing storage grows on demand.
///
/// Wraps [`candle_nn::kv_cache::KvCache`], which has a fixed capacity along
/// `concat_dim`. Storage starts small. When an append would overflow it, the
/// storage is reallocated to the next power of two of the required length,
/// clamped to `max_seq_len`, and the cached keys/values are copied across.
#[derive(Debug, Clone)]
pub struct KvCache {
    /// Underlying fixed-capacity cache; its capacity is the current allocation.
    cache: candle_nn::kv_cache::KvCache,
    /// Dimension along which new keys/values are concatenated.
    concat_dim: usize,
    /// Upper bound on the number of cached entries along `concat_dim`.
    max_seq_len: usize,
}

impl KvCache {
    /// Capacity (along `concat_dim`) of the first allocation.
    const INITIAL_CAPACITY: usize = 8;

    /// Creates an empty cache that concatenates along `concat_dim` and holds
    /// at most `max_seq_len` entries.
    pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
        Self {
            // Never start above the cap; otherwise a `max_seq_len` smaller than
            // the initial capacity would not be enforced until the first regrow.
            cache: candle_nn::kv_cache::KvCache::new(
                concat_dim,
                Self::INITIAL_CAPACITY.min(max_seq_len),
            ),
            concat_dim,
            max_seq_len,
        }
    }

    /// Returns the underlying candle cache.
    pub fn cache(&self) -> &candle_nn::kv_cache::KvCache {
        &self.cache
    }

    /// Returns the underlying candle cache mutably.
    pub fn cache_mut(&mut self) -> &mut candle_nn::kv_cache::KvCache {
        &mut self.cache
    }

    /// Clears the cache by delegating to the underlying cache's `reset`.
    pub fn reset(&mut self) {
        self.cache.reset()
    }

    /// Appends `k` and `v` along `concat_dim` and returns the resulting cached
    /// keys and values.
    ///
    /// Grows the backing storage first if the current allocation is too small.
    ///
    /// # Errors
    /// Returns an error if the append would exceed `max_seq_len`, or if any
    /// underlying tensor operation fails.
    pub fn append(&mut self, k: &Tensor, v: &Tensor) -> candle_core::Result<(Tensor, Tensor)> {
        let k = k.contiguous()?;
        let v = v.contiguous()?;
        let seq_len = k.dim(self.concat_dim)?;
        debug_assert_eq!(seq_len, v.dim(self.concat_dim)?);
        let required = self.cache.current_seq_len() + seq_len;
        if required > self.cache.k_cache().max_seq_len() {
            self.grow(required)?;
        }
        self.cache.append(&k, &v)
    }

    /// Replaces the underlying cache with one that can hold at least `required`
    /// entries, copying the currently cached keys and values into it.
    fn grow(&mut self, required: usize) -> candle_core::Result<()> {
        // Fail before allocating and copying. A clamped reallocation could
        // never fit the append anyway.
        if required > self.max_seq_len {
            candle_core::bail!(
                "kv-cache: {} entries required, above max-seq-len {}",
                required,
                self.max_seq_len
            )
        }
        let new_capacity = required.next_power_of_two().min(self.max_seq_len);
        let mut new_cache = candle_nn::kv_cache::KvCache::new(self.concat_dim, new_capacity);
        // Propagate read errors instead of silently dropping the cached data.
        if let (Some(k), Some(v)) = (self.cache.k()?, self.cache.v()?) {
            new_cache.k_cache_mut().append(&k.contiguous()?)?;
            new_cache.v_cache_mut().append(&v.contiguous()?)?;
        }
        self.cache = new_cache;
        Ok(())
    }
}
/// A single-tensor cache whose backing storage grows on demand.
///
/// Wraps [`candle_nn::kv_cache::Cache`], which has a fixed capacity along
/// `concat_dim`. Storage starts small. When an append would overflow it, the
/// storage is reallocated to the next power of two of the required length,
/// clamped to `max_seq_len`, and the entries appended so far are copied across.
#[derive(Debug, Clone)]
pub struct TensorCache {
    /// Underlying fixed-capacity cache; its capacity is the current allocation.
    cache: candle_nn::kv_cache::Cache,
    /// Dimension along which new entries are concatenated.
    concat_dim: usize,
    /// Upper bound on the number of cached entries along `concat_dim`.
    max_seq_len: usize,
}

impl TensorCache {
    /// Capacity (along `concat_dim`) of the first allocation.
    const INITIAL_CAPACITY: usize = 8;

    /// Creates an empty cache that concatenates along `concat_dim` and holds
    /// at most `max_seq_len` entries.
    pub fn new(concat_dim: usize, max_seq_len: usize) -> Self {
        Self {
            // Never start above the cap; otherwise a `max_seq_len` smaller than
            // the initial capacity would not be enforced until the first regrow.
            cache: candle_nn::kv_cache::Cache::new(
                concat_dim,
                Self::INITIAL_CAPACITY.min(max_seq_len),
            ),
            concat_dim,
            max_seq_len,
        }
    }

    /// Returns the underlying candle cache.
    pub fn cache(&self) -> &candle_nn::kv_cache::Cache {
        &self.cache
    }

    /// Returns the underlying candle cache mutably.
    pub fn cache_mut(&mut self) -> &mut candle_nn::kv_cache::Cache {
        &mut self.cache
    }

    /// Returns the underlying cache's `all_data()` buffer.
    ///
    /// NOTE: in candle this is the whole allocated buffer, which may be longer
    /// than the number of appended entries. Use `cache().current_data()` to get
    /// only the valid prefix.
    pub fn all_data(&self) -> &Option<Tensor> {
        self.cache.all_data()
    }

    /// Clears the cache by delegating to the underlying cache's `reset`.
    pub fn reset(&mut self) {
        self.cache.reset()
    }

    /// Appends `v` along `concat_dim`, growing the backing storage first if the
    /// current allocation is too small.
    ///
    /// # Errors
    /// Returns an error if the append would exceed `max_seq_len`, or if any
    /// underlying tensor operation fails.
    pub fn append(&mut self, v: &Tensor) -> candle_core::Result<()> {
        let v = v.contiguous()?;
        let seq_len = v.dim(self.concat_dim)?;
        let required = self.cache.current_seq_len() + seq_len;
        if required > self.cache.max_seq_len() {
            self.grow(required)?;
        }
        self.cache.append(&v)
    }

    /// Replaces the underlying cache with one that can hold at least `required`
    /// entries, copying the entries appended so far into it.
    fn grow(&mut self, required: usize) -> candle_core::Result<()> {
        // Fail before allocating and copying. A clamped reallocation could
        // never fit the append anyway.
        if required > self.max_seq_len {
            candle_core::bail!(
                "tensor-cache: {} entries required, above max-seq-len {}",
                required,
                self.max_seq_len
            )
        }
        let new_capacity = required.next_power_of_two().min(self.max_seq_len);
        let mut new_cache = candle_nn::kv_cache::Cache::new(self.concat_dim, new_capacity);
        // Copy only the valid prefix. `all_data()` is the whole allocated
        // buffer, and copying its unused tail would insert padding rows before
        // every later append.
        if let Some(data) = self.cache.current_data()? {
            new_cache.append(&data.contiguous()?)?;
        }
        self.cache = new_cache;
        Ok(())
    }
}