use super::{simd_dot, simd_softmax};

/// Per-layer key/value cache with row-major storage: position `p` of layer `l`
/// occupies `k_cache[l][p * hidden_dim..(p + 1) * hidden_dim]` (same for V).
#[derive(Clone)]
pub struct KVCache {
    k_cache: Vec<Vec<f32>>,
    v_cache: Vec<Vec<f32>>,
    seq_len: usize,
    hidden_dim: usize,
}

impl KVCache {
    /// Allocates zeroed K/V buffers of `max_seq_len * hidden_dim` floats per layer.
    #[must_use]
    pub fn new(num_layers: usize, hidden_dim: usize, max_seq_len: usize) -> Self {
        let k_cache = vec![vec![0.0; max_seq_len * hidden_dim]; num_layers];
        let v_cache = vec![vec![0.0; max_seq_len * hidden_dim]; num_layers];
        Self {
            k_cache,
            v_cache,
            seq_len: 0,
            hidden_dim,
        }
    }

    /// Writes K/V for the current position into `layer`; silently drops the
    /// write once the cache is full.
    pub fn store(&mut self, layer: usize, k: &[f32], v: &[f32]) {
        let start = self.seq_len * self.hidden_dim;
        let end = start + self.hidden_dim;
        if end <= self.k_cache[layer].len() {
            self.k_cache[layer][start..end].copy_from_slice(k);
            self.v_cache[layer][start..end].copy_from_slice(v);
        }
    }

    /// Marks the current position as filled; call once per decode step, after
    /// `store` has run for every layer.
    pub fn advance(&mut self) {
        self.seq_len += 1;
    }

    /// Returns the filled prefix of the K cache for `layer`.
    #[must_use]
    pub fn get_k(&self, layer: usize) -> &[f32] {
        &self.k_cache[layer][..self.seq_len * self.hidden_dim]
    }

    /// Returns the filled prefix of the V cache for `layer`.
    #[must_use]
    pub fn get_v(&self, layer: usize) -> &[f32] {
        &self.v_cache[layer][..self.seq_len * self.hidden_dim]
    }

    /// Number of cached positions.
    #[must_use]
    pub fn len(&self) -> usize {
        self.seq_len
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.seq_len == 0
    }

    /// Clears the cache logically; buffers stay allocated.
    pub fn reset(&mut self) {
        self.seq_len = 0;
    }
}
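
// A minimal usage sketch of `KVCache` (added for illustration; the dims and
// values below are arbitrary assumptions, not taken from the surrounding crate).
#[cfg(test)]
mod kv_cache_usage {
    use super::KVCache;

    #[test]
    fn store_then_read_back_one_position() {
        let mut cache = KVCache::new(2, 4, 8);
        assert!(cache.is_empty());
        // One decode step: store K/V for each layer, then advance once.
        for layer in 0..2 {
            cache.store(layer, &[1.0; 4], &[2.0; 4]);
        }
        cache.advance();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get_k(0), &[1.0; 4]);
        assert_eq!(cache.get_v(1), &[2.0; 4]);
    }
}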

/// Single-token multi-head attention over a row-major K/V cache plus the
/// current (not yet stored) K/V. Returns one `hidden_dim`-length output vector.
#[must_use]
pub fn attention_with_cache(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    current_k: &[f32],
    current_v: &[f32],
    num_heads: usize,
) -> Vec<f32> {
    let hidden_dim = q.len();
    let head_dim = hidden_dim / num_heads;
    let cache_len = if hidden_dim > 0 {
        k_cache.len() / hidden_dim
    } else {
        0
    };
    let total_len = cache_len + 1;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut output = vec![0.0; hidden_dim];
    for h in 0..num_heads {
        let head_offset = h * head_dim;
        let q_head = &q[head_offset..head_offset + head_dim];
        // Scaled dot-product scores against every cached position...
        let mut scores = Vec::with_capacity(total_len);
        for pos in 0..cache_len {
            let k_start = pos * hidden_dim + head_offset;
            let k_head = &k_cache[k_start..k_start + head_dim];
            let score = simd_dot(q_head, k_head) * scale;
            scores.push(score);
        }
        // ...plus the current position, then softmax over all of them.
        let current_k_head = &current_k[head_offset..head_offset + head_dim];
        scores.push(simd_dot(q_head, current_k_head) * scale);
        simd_softmax(&mut scores);
        // Weighted sum of the cached V rows.
        let out_head = &mut output[head_offset..head_offset + head_dim];
        for (pos, &weight) in scores.iter().enumerate().take(cache_len) {
            let v_start = pos * hidden_dim + head_offset;
            let v_head = &v_cache[v_start..v_start + head_dim];
            for (i, &v) in v_head.iter().enumerate() {
                out_head[i] += weight * v;
            }
        }
        // And the current V, weighted by the last score.
        let current_v_head = &current_v[head_offset..head_offset + head_dim];
        let current_weight = scores[cache_len];
        for (i, &v) in current_v_head.iter().enumerate() {
            out_head[i] += current_weight * v;
        }
    }
    output
}
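
// Illustrative check (assumed values): with an empty cache there is a single
// score, softmax maps it to 1.0 (assuming `simd_softmax` normalizes in place),
// so the output must equal `current_v` exactly.
#[cfg(test)]
mod attention_with_cache_usage {
    use super::attention_with_cache;

    #[test]
    fn empty_cache_returns_current_v() {
        let q = [0.5_f32; 8];
        let k = [0.1_f32; 8];
        let v = [3.0_f32; 8];
        let out = attention_with_cache(&q, &[], &[], &k, &v, 2);
        for x in out {
            assert!((x - 3.0).abs() < 1e-6);
        }
    }
}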

/// KV cache with V stored transposed (dimension-major, `[hidden_dim][max_seq_len]`)
/// so the weighted sum over positions in attention reads V contiguously.
#[derive(Clone)]
pub struct OptimizedKVCache {
    k_cache: Vec<Vec<f32>>,
    v_cache: Vec<Vec<f32>>,
    seq_len: usize,
    hidden_dim: usize,
    max_seq_len: usize,
}

impl OptimizedKVCache {
    #[must_use]
    pub fn new(num_layers: usize, hidden_dim: usize, max_seq_len: usize) -> Self {
        // K is row-major by position; V is dimension-major (transposed).
        let k_cache = vec![vec![0.0; max_seq_len * hidden_dim]; num_layers];
        let v_cache = vec![vec![0.0; hidden_dim * max_seq_len]; num_layers];
        Self {
            k_cache,
            v_cache,
            seq_len: 0,
            hidden_dim,
            max_seq_len,
        }
    }

    /// Writes K/V for the current position; silently drops the write once full.
    pub fn store(&mut self, layer: usize, k: &[f32], v: &[f32]) {
        if self.seq_len >= self.max_seq_len {
            return;
        }
        let k_start = self.seq_len * self.hidden_dim;
        let k_end = k_start + self.hidden_dim;
        self.k_cache[layer][k_start..k_end].copy_from_slice(k);
        // Scatter V into the transposed layout: dim `i`, position `seq_len`.
        for (i, &val) in v.iter().enumerate() {
            self.v_cache[layer][i * self.max_seq_len + self.seq_len] = val;
        }
    }

    /// Marks the current position as filled, saturating at `max_seq_len`.
    pub fn advance(&mut self) {
        if self.seq_len < self.max_seq_len {
            self.seq_len += 1;
        }
    }

    /// Returns the filled prefix of the K cache for `layer`.
    #[must_use]
    pub fn get_k(&self, layer: usize) -> &[f32] {
        &self.k_cache[layer][..self.seq_len * self.hidden_dim]
    }

    /// Returns the whole transposed V buffer; only the first `len()` entries
    /// of each `max_seq_len`-long row are valid.
    #[must_use]
    pub fn get_v_transposed(&self, layer: usize) -> &[f32] {
        &self.v_cache[layer]
    }

    /// Number of cached positions.
    #[must_use]
    pub fn len(&self) -> usize {
        self.seq_len
    }

    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.seq_len == 0
    }

    /// Clears the cache logically; buffers stay allocated.
    pub fn reset(&mut self) {
        self.seq_len = 0;
    }

    /// Capacity in positions.
    #[must_use]
    pub fn max_len(&self) -> usize {
        self.max_seq_len
    }
}
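
// Layout sketch for the transposed V buffer (illustrative dims): after one
// store at position 0, dimension `i` of V lives at `i * max_seq_len`.
#[cfg(test)]
mod optimized_kv_cache_usage {
    use super::OptimizedKVCache;

    #[test]
    fn v_is_stored_dimension_major() {
        let mut cache = OptimizedKVCache::new(1, 2, 4);
        cache.store(0, &[1.0, 2.0], &[10.0, 20.0]);
        cache.advance();
        let v = cache.get_v_transposed(0);
        assert_eq!(v[0], 10.0); // dim 0, position 0: index 0 * 4 + 0
        assert_eq!(v[4], 20.0); // dim 1, position 0: index 1 * 4 + 0
    }
}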

/// Single-token attention against a transposed V cache: the inner loop walks
/// positions for a fixed output dimension, so V reads are contiguous.
#[must_use]
pub fn attention_with_transposed_v(
    q: &[f32],
    k_cache: &[f32],
    v_cache_transposed: &[f32],
    current_k: &[f32],
    current_v: &[f32],
    num_heads: usize,
    max_seq_len: usize,
) -> Vec<f32> {
    let hidden_dim = q.len();
    let head_dim = hidden_dim / num_heads;
    let cache_len = if hidden_dim > 0 {
        k_cache.len() / hidden_dim
    } else {
        0
    };
    let total_len = cache_len + 1;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut output = vec![0.0; hidden_dim];
    for h in 0..num_heads {
        let head_offset = h * head_dim;
        let q_head = &q[head_offset..head_offset + head_dim];
        // Scores: K is row-major, same as in `attention_with_cache`.
        let mut scores = Vec::with_capacity(total_len);
        for pos in 0..cache_len {
            let k_start = pos * hidden_dim + head_offset;
            let k_head = &k_cache[k_start..k_start + head_dim];
            scores.push(simd_dot(q_head, k_head) * scale);
        }
        let current_k_head = &current_k[head_offset..head_offset + head_dim];
        scores.push(simd_dot(q_head, current_k_head) * scale);
        simd_softmax(&mut scores);
        let out_head = &mut output[head_offset..head_offset + head_dim];
        for i in 0..head_dim {
            // Row for output dim `head_offset + i`: positions are adjacent.
            let v_idx = (head_offset + i) * max_seq_len;
            let mut sum = 0.0;
            for (pos, &weight) in scores.iter().enumerate().take(cache_len) {
                sum += weight * v_cache_transposed[v_idx + pos];
            }
            sum += scores[cache_len] * current_v[head_offset + i];
            out_head[i] = sum;
        }
    }
    output
}
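
// Cross-check sketch (assumed inputs): the row-major and transposed attention
// paths should agree on identical cache contents, up to float rounding.
#[cfg(test)]
mod attention_layout_equivalence {
    use super::{attention_with_cache, attention_with_transposed_v, KVCache, OptimizedKVCache};

    #[test]
    fn row_major_and_transposed_agree() {
        let (dim, heads, max_len) = (4, 2, 8);
        let mut plain = KVCache::new(1, dim, max_len);
        let mut opt = OptimizedKVCache::new(1, dim, max_len);
        for step in 0..3 {
            let kv: Vec<f32> = (0..dim).map(|i| (step * dim + i) as f32 * 0.1).collect();
            plain.store(0, &kv, &kv);
            opt.store(0, &kv, &kv);
            plain.advance();
            opt.advance();
        }
        let q = [0.3_f32; 4];
        let (k, v) = ([0.2_f32; 4], [0.7_f32; 4]);
        let a = attention_with_cache(&q, plain.get_k(0), plain.get_v(0), &k, &v, heads);
        let b = attention_with_transposed_v(
            &q,
            opt.get_k(0),
            opt.get_v_transposed(0),
            &k,
            &v,
            heads,
            max_len,
        );
        for (x, y) in a.iter().zip(&b) {
            assert!((x - y).abs() < 1e-5);
        }
    }
}
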
include!("kv_cache_cache.rs");