use ndarray::{ArrayD, IxDyn};
use std::collections::HashMap;
use super::engine::ModelVariant;
use super::model::MoonshineError;
/// Store of past key/value attention tensors that are fed back into the model
/// on the next run.
///
/// Entries are keyed by the input name `past_key_values.{layer}.{attention}.{kv}`,
/// where `attention` is `decoder` or `encoder` and `kv` is `key` or `value`.
#[derive(Debug)]
pub struct KVCache {
    /// Cached tensors, indexed by their `past_key_values.*` input name.
    cache: HashMap<String, ArrayD<f32>>,
    /// Number of layers the cache holds entries for (four entries per layer).
    num_layers: usize,
}
impl KVCache {
    /// Attention kinds stored per layer, in the order they are emitted.
    const ATTENTION_TYPES: &'static [&'static str] = &["decoder", "encoder"];
    /// Tensor kinds stored per attention kind, in the order they are emitted.
    const KV_TYPES: &'static [&'static str] = &["key", "value"];

    /// Yields every `(layer, attention_type, kv_type)` combination, ordered by
    /// layer, then attention kind, then tensor kind.
    fn slots(num_layers: usize) -> impl Iterator<Item = (usize, &'static str, &'static str)> {
        (0..num_layers).flat_map(|layer| {
            Self::ATTENTION_TYPES
                .iter()
                .copied()
                .flat_map(move |attention_type| {
                    Self::KV_TYPES
                        .iter()
                        .copied()
                        .map(move |kv_type| (layer, attention_type, kv_type))
                })
        })
    }

    /// Builds the `past_key_values.*` name under which a slot is cached.
    fn past_name(layer: usize, attention_type: &str, kv_type: &str) -> String {
        format!("past_key_values.{}.{}.{}", layer, attention_type, kv_type)
    }

    /// Creates a cache holding one placeholder tensor per slot.
    ///
    /// Layer count, key/value head count and head dimension are read from
    /// `variant`. Every placeholder has shape `[0, num_heads, 1, head_dim]`,
    /// i.e. it contains no elements.
    pub fn new(variant: &ModelVariant) -> Self {
        let num_layers = variant.num_layers();
        let num_heads = variant.num_key_value_heads();
        let head_dim = variant.head_dim();
        let empty_shape = [0, num_heads, 1, head_dim];

        // The number of entries is known up front, so size the map once.
        let slots_per_layer = Self::ATTENTION_TYPES.len() * Self::KV_TYPES.len();
        let mut cache: HashMap<String, ArrayD<f32>> =
            HashMap::with_capacity(num_layers * slots_per_layer);
        cache.extend(
            Self::slots(num_layers).map(|(layer, attention_type, kv_type)| {
                (
                    Self::past_name(layer, attention_type, kv_type),
                    ArrayD::<f32>::zeros(IxDyn(&empty_shape)),
                )
            }),
        );

        Self { cache, num_layers }
    }

    /// Returns a copy of every cached tensor paired with its
    /// `past_key_values.*` name.
    ///
    /// The result is ordered by layer, then `decoder`/`encoder`, then
    /// `key`/`value`. Slots that are not present in the cache are omitted.
    pub fn get_inputs(&self) -> Vec<(String, ArrayD<f32>)> {
        let mut inputs = Vec::with_capacity(self.cache.len());
        inputs.extend(Self::slots(self.num_layers).filter_map(
            |(layer, attention_type, kv_type)| {
                let name = Self::past_name(layer, attention_type, kv_type);
                let tensor = self.cache.get(&name)?.clone();
                Some((name, tensor))
            },
        ));
        inputs
    }

    /// Replaces cached tensors with the matching `present.*` session outputs.
    ///
    /// For each slot the output named `present.{layer}.{attention}.{kv}` is
    /// copied into the cache under `past_key_values.{layer}.{attention}.{kv}`.
    /// An output that is absent from `outputs` leaves the existing entry
    /// unchanged.
    ///
    /// When `use_cache_branch` is `true`, `encoder` slots are never
    /// overwritten; only `decoder` slots are refreshed.
    /// NOTE(review): presumably the encoder tensors stay valid across decode
    /// steps on that branch — confirm against the exported model.
    ///
    /// # Errors
    ///
    /// Returns [`MoonshineError::Ort`] if an output cannot be extracted as an
    /// `f32` array. Slots processed before the failure remain updated.
    pub fn update_from_outputs(
        &mut self,
        outputs: &ort::session::SessionOutputs,
        use_cache_branch: bool,
    ) -> Result<(), MoonshineError> {
        for (layer, attention_type, kv_type) in Self::slots(self.num_layers) {
            if use_cache_branch && attention_type == "encoder" {
                continue;
            }

            let output_key = format!("present.{}.{}.{}", layer, attention_type, kv_type);
            if let Some(output) = outputs.get(&output_key) {
                let tensor = output
                    .try_extract_array::<f32>()
                    .map_err(MoonshineError::Ort)?;
                self.cache.insert(
                    Self::past_name(layer, attention_type, kv_type),
                    tensor.to_owned(),
                );
            }
        }
        Ok(())
    }
}