use anyhow::{Context, Result};
use burn::tensor::{Int, Tensor, TensorData, backend::Backend, module::embedding};
use crate::model::{Whisper, qkv_attention};
/// Key/value cache for incremental (one token per call) decoding.
///
/// Built by [`KvCache::new`] and advanced by [`forward_decoder_cached`].
pub struct KvCache<B: Backend> {
/// Cross-attention `(key, value)` projections of the encoder output, one
/// pair per decoder block. Computed once in `new` and only read afterwards.
cross_kv: Vec<(Tensor<B, 3>, Tensor<B, 3>)>,
/// Self-attention `(key, value)` tensors accumulated so far, one slot per
/// decoder block. A slot is `None` until the first decoding step fills it;
/// later steps concatenate the new token's key/value along dimension 1.
self_kv: Vec<Option<(Tensor<B, 3>, Tensor<B, 3>)>>,
/// Number of decoding steps already run against this cache. Also used as
/// the row index into the decoder's positional embedding for the next token.
pub step: usize,
}
impl<B: Backend> KvCache<B> {
    /// Creates a cache for a fresh decoding run over `encoder_output`.
    ///
    /// For every decoder block of `model`, the block's cross-attention key and
    /// value projections are applied to `encoder_output` and stored, so later
    /// decoding steps can reuse them instead of recomputing them. The
    /// self-attention slots start out empty and the step counter starts at 0.
    pub fn new(model: &Whisper<B>, encoder_output: Tensor<B, 3>) -> Self {
        let mut cross_kv = Vec::new();
        for block in model.decoder.blocks.iter() {
            let attn = &block.cross_attn;
            let keys = attn.key.forward(encoder_output.clone());
            let values = attn.value.forward(encoder_output.clone());
            cross_kv.push((keys, values));
        }

        // One (initially empty) self-attention slot per decoder block.
        let self_kv = vec![None; cross_kv.len()];

        Self {
            cross_kv,
            self_kv,
            step: 0,
        }
    }
}
/// Runs the decoder for one new token, reusing the keys/values held in `cache`.
///
/// The token is embedded and combined with the positional embedding row at
/// index `cache.step`. Each decoder block then:
///
/// 1. projects the new token to query/key/value, appends the key/value to the
///    block's cached self-attention tensors (along dimension 1) and attends
///    over the full cached sequence;
/// 2. attends over the cross-attention keys/values precomputed in
///    [`KvCache::new`];
/// 3. applies the block's MLP with a residual connection.
///
/// No attention mask is passed: the single query only sees keys for tokens that
/// were already fed through this cache.
///
/// On success `cache.step` has been incremented by one and the returned vector
/// holds the logits for the new position (one value per vocabulary entry).
///
/// # Errors
///
/// Returns an error, leaving `cache` untouched, when
/// - the cache's layer count does not match the decoder's block count,
/// - `cache.step` is past the end of the positional embedding table,
/// - `token` is outside the token embedding table or does not fit in `i32`.
///
/// Also returns an error if the logits cannot be extracted as `f32`; in that
/// case the cache has already been advanced by this step.
pub fn forward_decoder_cached<B: Backend>(
    model: &Whisper<B>,
    token: u32,
    cache: &mut KvCache<B>,
    device: &B::Device,
) -> Result<Vec<f32>> {
    let decoder = &model.decoder;
    let step = cache.step;

    // Validate everything up front so that bad input is reported as an error
    // (not a panic) and never leaves the cache half-updated.
    let n_layers = decoder.blocks.len();
    anyhow::ensure!(
        cache.cross_kv.len() == n_layers && cache.self_kv.len() == n_layers,
        "KV cache holds {} cross-attention / {} self-attention layers but the decoder has {n_layers} blocks",
        cache.cross_kv.len(),
        cache.self_kv.len(),
    );

    let positional_embedding = decoder.positional_embedding.val();
    let n_ctx = positional_embedding.dims()[0];
    anyhow::ensure!(
        step < n_ctx,
        "decoder step {step} is out of range: the positional embedding has only {n_ctx} positions"
    );

    let token_embedding = decoder.token_embedding.val();
    let n_vocab = token_embedding.dims()[0];
    anyhow::ensure!(
        usize::try_from(token).map_or(false, |id| id < n_vocab),
        "token id {token} is out of range: the token embedding has only {n_vocab} entries"
    );
    let token_id = i32::try_from(token)
        .with_context(|| format!("token id {token} does not fit in i32"))?;

    let token_tensor: Tensor<B, 2, Int> =
        Tensor::from_data(TensorData::new(vec![token_id], [1, 1]), device);
    let mut x = embedding(token_embedding.clone(), token_tensor)
        + positional_embedding
            .slice([step..(step + 1)])
            .unsqueeze::<3>();

    for (layer_idx, block) in decoder.blocks.iter().enumerate() {
        // Self-attention over everything decoded so far.
        let x_norm = block.attn_ln.forward(x.clone());
        let sa = &block.attn;
        let q = sa.query.forward(x_norm.clone());
        let k_new = sa.key.forward(x_norm.clone());
        let v_new = sa.value.forward(x_norm);
        // Take the previous tensors out of the slot so they can be moved into
        // `cat` rather than cloned.
        let (k_full, v_full) = match cache.self_kv[layer_idx].take() {
            Some((k_prev, v_prev)) => (
                Tensor::cat(vec![k_prev, k_new], 1),
                Tensor::cat(vec![v_prev, v_new], 1),
            ),
            None => (k_new, v_new),
        };
        cache.self_kv[layer_idx] = Some((k_full.clone(), v_full.clone()));
        let sa_out = qkv_attention(q, k_full, v_full, None, sa.n_head);
        x = x + sa.out.forward(sa_out);

        // Cross-attention against the precomputed encoder keys/values.
        let x_norm = block.cross_attn_ln.forward(x.clone());
        let ca = &block.cross_attn;
        let q = ca.query.forward(x_norm);
        let (k_cross, v_cross) = &cache.cross_kv[layer_idx];
        let ca_out = qkv_attention(q, k_cross.clone(), v_cross.clone(), None, ca.n_head);
        x = x + ca.out.forward(ca_out);

        // Feed-forward with residual connection.
        x = x.clone() + block.mlp.forward(block.mlp_ln.forward(x));
    }

    // Every layer's cache now includes this token, so advance the counter
    // before the fallible host-side extraction to keep the two in sync.
    cache.step += 1;

    let x = decoder.ln.forward(x);
    // Output projection reuses the token embedding matrix (transposed).
    let logits = x.matmul(token_embedding.transpose().unsqueeze::<3>());
    let [_, _, vocab_size] = logits.dims();
    logits
        .squeeze::<1>()
        .into_data()
        .to_vec::<f32>()
        .map_err(|e| anyhow::anyhow!("logit extraction failed: {:?}", e))
        .with_context(|| format!("forward_decoder_cached step {step}, vocab_size={vocab_size}"))
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::{Distribution, Int, TensorData};
    use burn_flex::Flex;
    use burn_flex::FlexDevice;

    /// Backend every test in this module runs on.
    type TestBackend = Flex<f32>;

    /// Builds a `tiny_en`-configured model with freshly initialised weights.
    fn tiny_en_random() -> (crate::model::Whisper<TestBackend>, FlexDevice) {
        let device = FlexDevice::default();
        let model = crate::model::WhisperConfig::tiny_en().init::<TestBackend>(&device);
        (model, device)
    }

    /// All-zero stand-in for an encoder output of shape `[1, 1500, 384]`.
    fn zero_encoder_output(device: &FlexDevice) -> Tensor<TestBackend, 3> {
        Tensor::<TestBackend, 3>::zeros([1, 1500, 384], device)
    }

    /// Each successful call must advance `cache.step` by exactly one.
    #[test]
    fn test_kv_cache_step_counter() {
        let (model, device) = tiny_en_random();
        let mut cache = KvCache::new(&model, zero_encoder_output(&device));
        assert_eq!(cache.step, 0);

        for (calls_so_far, &token) in [50258u32, 50259u32].iter().enumerate() {
            forward_decoder_cached(&model, token, &mut cache, &device).unwrap();
            assert_eq!(cache.step, calls_so_far + 1);
        }
    }

    /// A single step yields one logit per vocabulary entry.
    #[test]
    fn test_kv_cache_logit_shape() {
        let (model, device) = tiny_en_random();
        let mut cache = KvCache::new(&model, zero_encoder_output(&device));

        let n_logits = forward_decoder_cached(&model, 50258u32, &mut cache, &device)
            .unwrap()
            .len();

        assert_eq!(n_logits, 51864);
    }

    /// Feeding a prompt token by token through the cache must give the same
    /// last-position logits as one uncached pass over the whole prompt.
    #[test]
    fn test_kv_cache_matches_forward_decoder() {
        let (model, device) = tiny_en_random();
        let encoder_out = Tensor::<TestBackend, 3>::random(
            [1, 1500, 384],
            Distribution::Normal(0.0, 0.1),
            &device,
        );
        let prompt: [u32; 4] = [50258, 50259, 50359, 50363];

        // Reference: uncached decoder over the full prompt, last position only.
        let prompt_ids: Vec<i32> = prompt.iter().map(|&t| t as i32).collect();
        let prompt_tensor: Tensor<TestBackend, 2, Int> =
            Tensor::from_data(TensorData::new(prompt_ids, [1, 4]), &device);
        let full_logits = model.forward_decoder(prompt_tensor, encoder_out.clone());
        let [batch, seq_len, vocab] = full_logits.dims();
        let last_position = full_logits.slice([0..batch, (seq_len - 1)..seq_len, 0..vocab]);
        let expected: Vec<f32> = last_position
            .squeeze::<1>()
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Cached path: keep only the logits produced by the final token.
        let mut cache = KvCache::new(&model, encoder_out);
        let actual = prompt.iter().fold(Vec::new(), |_, &tok| {
            forward_decoder_cached(&model, tok, &mut cache, &device).unwrap()
        });

        assert_eq!(expected.len(), actual.len());

        let mut max_diff = 0.0f32;
        for (a, b) in expected.iter().zip(actual.iter()) {
            max_diff = max_diff.max((a - b).abs());
        }
        assert!(
            max_diff < 1e-4,
            "KV-cached logits diverge from forward_decoder by {max_diff:.2e} (expected < 1e-4)"
        );
    }
}