use crate::{
    error::{RealizarError, Result},
    generate::{sample_token, GenerationConfig},
    tensor::Tensor,
};
use super::{FeedForward, LayerNorm, Linear, MultiHeadAttention};
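
/// Per-layer key/value cache for autoregressive decoding.
///
/// Each layer owns a flat `max_seq_len * head_dim` buffer for keys and one
/// for values; `current_pos` is the write position shared by all layers.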
#[derive(Debug, Clone)]
pub struct KVCache {
    /// Number of transformer layers being cached.
    num_layers: usize,
    /// Maximum number of token positions the cache can hold.
    max_seq_len: usize,
    /// Length of each cached key/value vector.
    head_dim: usize,
    /// Next write position (also the number of valid cached positions).
    current_pos: usize,
    /// Per-layer key storage, each `max_seq_len * head_dim` floats.
    keys: Vec<Vec<f32>>,
    /// Per-layer value storage, each `max_seq_len * head_dim` floats.
    values: Vec<Vec<f32>>,
}
impl KVCache {
    /// Creates a zero-filled cache for `num_layers` layers, each holding up
    /// to `max_seq_len` positions of `head_dim`-wide keys and values.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::InvalidShape` if any dimension is zero.
    pub fn new(num_layers: usize, max_seq_len: usize, head_dim: usize) -> Result<Self> {
        if num_layers == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "num_layers must be > 0".to_string(),
            });
        }
        if max_seq_len == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "max_seq_len must be > 0".to_string(),
            });
        }
        if head_dim == 0 {
            return Err(RealizarError::InvalidShape {
                reason: "head_dim must be > 0".to_string(),
            });
        }
        let cache_size = max_seq_len * head_dim;
        let keys = vec![vec![0.0; cache_size]; num_layers];
        let values = vec![vec![0.0; cache_size]; num_layers];
        Ok(Self {
            num_layers,
            max_seq_len,
            head_dim,
            current_pos: 0,
            keys,
            values,
        })
    }

    /// Writes one key/value vector (each of length `head_dim`) for `layer`
    /// at the current position. Update every layer for the token, then call
    /// [`advance`](Self::advance) once so all layers share the same position.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::InvalidShape` if `layer` is out of bounds, the
    /// cache is full, or either tensor's length differs from `head_dim`.
    pub fn update(&mut self, layer: usize, key: &Tensor<f32>, value: &Tensor<f32>) -> Result<()> {
        if layer >= self.num_layers {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Layer {} out of bounds (max {})",
                    layer,
                    self.num_layers - 1
                ),
            });
        }
        if self.current_pos >= self.max_seq_len {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Cache full at position {} (max {})",
                    self.current_pos, self.max_seq_len
                ),
            });
        }
        let k_data = key.data();
        let v_data = value.data();
        if k_data.len() != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!("Key size {} != head_dim {}", k_data.len(), self.head_dim),
            });
        }
        if v_data.len() != self.head_dim {
            return Err(RealizarError::InvalidShape {
                reason: format!("Value size {} != head_dim {}", v_data.len(), self.head_dim),
            });
        }
        // Each position occupies a contiguous `head_dim` slice of the flat buffer.
        let offset = self.current_pos * self.head_dim;
        self.keys[layer][offset..offset + self.head_dim].copy_from_slice(k_data);
        self.values[layer][offset..offset + self.head_dim].copy_from_slice(v_data);
        Ok(())
    }

    /// Advances the shared write position by one token, saturating at
    /// `max_seq_len`.
    pub fn advance(&mut self) {
        if self.current_pos < self.max_seq_len {
            self.current_pos += 1;
        }
    }

    /// Returns the cached keys for `layer` as a `[current_pos, head_dim]`
    /// tensor. An empty cache yields a `[1, head_dim]` zero tensor rather
    /// than a zero-length shape.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::InvalidShape` if `layer` is out of bounds.
    pub fn get_key(&self, layer: usize) -> Result<Tensor<f32>> {
        if layer >= self.num_layers {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Layer {} out of bounds (max {})",
                    layer,
                    self.num_layers - 1
                ),
            });
        }
        if self.current_pos == 0 {
            return Tensor::from_vec(vec![1, self.head_dim], vec![0.0; self.head_dim]);
        }
        let size = self.current_pos * self.head_dim;
        let data = self.keys[layer][..size].to_vec();
        Tensor::from_vec(vec![self.current_pos, self.head_dim], data)
    }

    /// Returns the cached values for `layer` as a `[current_pos, head_dim]`
    /// tensor. An empty cache yields a `[1, head_dim]` zero tensor rather
    /// than a zero-length shape.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::InvalidShape` if `layer` is out of bounds.
    pub fn get_value(&self, layer: usize) -> Result<Tensor<f32>> {
        if layer >= self.num_layers {
            return Err(RealizarError::InvalidShape {
                reason: format!(
                    "Layer {} out of bounds (max {})",
                    layer,
                    self.num_layers - 1
                ),
            });
        }
        if self.current_pos == 0 {
            return Tensor::from_vec(vec![1, self.head_dim], vec![0.0; self.head_dim]);
        }
        let size = self.current_pos * self.head_dim;
        let data = self.values[layer][..size].to_vec();
        Tensor::from_vec(vec![self.current_pos, self.head_dim], data)
    }

    /// Resets the write position and zeroes all cached keys and values.
    pub fn clear(&mut self) {
        self.current_pos = 0;
        for layer in 0..self.num_layers {
            self.keys[layer].fill(0.0);
            self.values[layer].fill(0.0);
        }
    }

    /// Current write position (the number of cached token positions).
    #[must_use]
    pub fn current_pos(&self) -> usize {
        self.current_pos
    }

    /// Number of layers this cache covers.
    #[must_use]
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }

    /// Maximum number of token positions the cache can hold.
    #[must_use]
    pub fn max_seq_len(&self) -> usize {
        self.max_seq_len
    }

    /// Length of each cached key/value vector.
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// `true` once `current_pos` has reached `max_seq_len`.
    #[must_use]
    pub fn is_full(&self) -> bool {
        self.current_pos >= self.max_seq_len
    }
}
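
/// One transformer layer: multi-head self-attention and a feed-forward
/// network, each paired with its own `LayerNorm` (the forward pass lives in
/// `model_transformer_block.rs`).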
#[derive(Debug, Clone)]
pub struct TransformerBlock {
    /// Normalization for the attention sub-layer.
    attn_norm: LayerNorm,
    /// Multi-head self-attention.
    attention: MultiHeadAttention,
    /// Normalization for the feed-forward sub-layer.
    ffn_norm: LayerNorm,
    /// Position-wise feed-forward network.
    ffn: FeedForward,
    /// Model embedding width.
    hidden_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
}
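
// Method implementations are split across sibling files and spliced in
// textually with `include!`, so they share this file's imports.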
include!("model_transformer_block.rs");
include!("model_model.rs");
include!("model_cache.rs");