// oxidized_transformers/kv_cache.rs
use std::collections::HashMap;
use std::ops::{Index, IndexMut};
use candle_core::Tensor;
use snafu::{ResultExt, Snafu};
/// Errors returned by `LayerKeyValueCache::update`.
///
/// Each variant wraps the underlying `candle_core::Error` as its source.
#[derive(Debug, Snafu)]
pub enum LayerKeyValueCacheError {
/// The cached key tensor could not be extended with the new key.
#[snafu(display("Failed to extend key"))]
ExtendKey { source: candle_core::Error },
/// Error describing a failure to extend the cached value tensor.
#[snafu(display("Failed to extend value"))]
ExtendValue { source: candle_core::Error },
}
/// Internal state of a `LayerKeyValueCache`.
enum LayerKeyValueCacheEnum {
/// Caching is enabled but nothing has been stored yet; the first update
/// switches the state to `Cache`.
Empty,
/// Caching is enabled and holds the key/value tensors stored so far.
Cache { key: Tensor, value: Tensor },
/// Caching is disabled: updates are ignored and no tensors are stored.
NoCache,
}
pub struct LayerKeyValueCache(LayerKeyValueCacheEnum);
impl LayerKeyValueCache {
pub fn empty() -> LayerKeyValueCache {
LayerKeyValueCache(LayerKeyValueCacheEnum::Empty)
}
pub fn no_cache() -> Self {
Self(LayerKeyValueCacheEnum::NoCache)
}
pub fn key(&self) -> Option<&Tensor> {
use LayerKeyValueCacheEnum::*;
match &self.0 {
Cache { key, .. } => Some(key),
Empty => None,
NoCache => None,
}
}
pub fn value(&self) -> Option<&Tensor> {
use LayerKeyValueCacheEnum::*;
match &self.0 {
Cache { value, .. } => Some(value),
Empty => None,
NoCache => None,
}
}
pub fn update(
&mut self,
new_key: &Tensor,
new_value: &Tensor,
) -> Result<(), LayerKeyValueCacheError> {
use LayerKeyValueCacheEnum::*;
match &mut self.0 {
Cache { key, value } => {
*key = Tensor::cat(&[&*key, new_key], 2).context(ExtendKeySnafu)?;
*value = Tensor::cat(&[&*value, new_value], 2).context(ExtendKeySnafu)?;
}
Empty => {
self.0 = Cache {
key: new_key.clone(),
value: new_value.clone(),
}
}
NoCache => (),
}
Ok(())
}
}
/// Internal state of a `KeyValueCache`.
enum KeyValueCacheEnum {
/// Caching is enabled: per-layer caches, keyed by the `usize` index used
/// with `Index`/`IndexMut`. Entries are created on first mutable access.
#[allow(private_interfaces)]
Cache {
layer_caches: HashMap<usize, LayerKeyValueCache>,
},
/// Caching is disabled: `stub` is the single layer cache handed out for
/// every index.
#[allow(private_interfaces)]
NoCache { stub: LayerKeyValueCache },
}
/// Collection of per-layer key/value caches, addressed by `usize` index.
///
/// Built either with caching enabled ([`KeyValueCache::cache`]) or disabled
/// ([`KeyValueCache::no_cache`]).
pub struct KeyValueCache(KeyValueCacheEnum);

impl KeyValueCache {
    /// Create a cache with caching enabled and no layer entries yet.
    pub fn cache() -> KeyValueCache {
        let layer_caches = HashMap::new();
        KeyValueCache(KeyValueCacheEnum::Cache { layer_caches })
    }

    /// Create a cache with caching disabled; every index resolves to the same
    /// disabled layer cache.
    pub fn no_cache() -> Self {
        let stub = LayerKeyValueCache::no_cache();
        KeyValueCache(KeyValueCacheEnum::NoCache { stub })
    }
}
impl Index<usize> for KeyValueCache {
type Output = LayerKeyValueCache;
fn index(&self, index: usize) -> &Self::Output {
match &self.0 {
KeyValueCacheEnum::Cache { layer_caches } => &layer_caches[&index],
KeyValueCacheEnum::NoCache { stub, .. } => stub,
}
}
}
impl IndexMut<usize> for KeyValueCache {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
match &mut self.0 {
KeyValueCacheEnum::Cache { layer_caches } => layer_caches
.entry(index)
.or_insert(LayerKeyValueCache::empty()),
KeyValueCacheEnum::NoCache { stub, .. } => stub,
}
}
}