/// Fixed-capacity key/value cache that stores `f32` rows per layer.
///
/// Every layer owns one flat key buffer and one flat value buffer, each
/// holding `max_positions` rows of `num_heads * head_dim` elements. Rows are
/// written at the slot selected by `position`; that cursor wraps back to zero
/// after `max_positions` writes, so once the cache is full the oldest slot is
/// overwritten.
pub struct StreamingKVCache {
    /// Number of layers; one key buffer and one value buffer exist per layer.
    num_layers: usize,
    /// Number of row slots available in each layer buffer.
    max_positions: usize,
    /// Head count; together with `head_dim` it fixes the row length.
    num_heads: usize,
    /// Elements per head.
    head_dim: usize,
    /// Per-layer flat key storage.
    keys: Vec<Vec<f32>>,
    /// Per-layer flat value storage.
    values: Vec<Vec<f32>>,
    /// Slot index the next `append` writes to.
    position: usize,
    /// Count of slots reported as valid; never exceeds `max_positions`.
    valid_positions: usize,
}

impl StreamingKVCache {
    /// Builds a cache whose key and value buffers are zero-filled.
    ///
    /// Each of the `num_layers` buffers holds
    /// `max_positions * num_heads * head_dim` elements.
    #[must_use]
    pub fn new(num_layers: usize, max_positions: usize, num_heads: usize, head_dim: usize) -> Self {
        let elems_per_layer = max_positions * num_heads * head_dim;
        let zeroed_layers = || vec![vec![0.0f32; elems_per_layer]; num_layers];
        Self {
            keys: zeroed_layers(),
            values: zeroed_layers(),
            position: 0,
            valid_positions: 0,
            num_layers,
            max_positions,
            num_heads,
            head_dim,
        }
    }

    /// Writes one key row and one value row for `layer` at the current slot.
    ///
    /// The write cursor and the valid count only move when `layer` is the
    /// last layer (`num_layers - 1`); writes to any other layer leave them
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`, or if `key` or `value` does not have
    /// exactly `num_heads * head_dim` elements.
    pub fn append(&mut self, layer: usize, key: &[f32], value: &[f32]) {
        let row_len = self.num_heads * self.head_dim;
        assert!(layer < self.num_layers, "Layer index out of bounds");
        assert_eq!(key.len(), row_len, "Key dimension mismatch");
        assert_eq!(value.len(), row_len, "Value dimension mismatch");

        let begin = self.position * row_len;
        let slot = begin..begin + row_len;
        self.keys[layer][slot.clone()].copy_from_slice(key);
        self.values[layer][slot].copy_from_slice(value);

        // `layer < num_layers` was asserted above, so `layer + 1` cannot overflow.
        let wrote_last_layer = layer + 1 == self.num_layers;
        if wrote_last_layer {
            self.position = (self.position + 1) % self.max_positions;
            self.valid_positions = self.max_positions.min(self.valid_positions + 1);
        }
    }

    /// Returns the key and value elements of `layer` for slots `start..end`.
    ///
    /// The slices are in storage (slot) order.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index or if the requested slot
    /// range does not lie inside the layer buffer.
    #[must_use]
    pub fn get_range(&self, layer: usize, start: usize, end: usize) -> (&[f32], &[f32]) {
        let row_len = self.num_heads * self.head_dim;
        let span = (start * row_len)..(end * row_len);
        let key_rows = &self.keys[layer][span.clone()];
        let value_rows = &self.values[layer][span];
        (key_rows, value_rows)
    }

    /// Returns the key and value elements of `layer` for slots
    /// `0..valid_positions`, in storage (slot) order.
    #[must_use]
    pub fn get_valid(&self, layer: usize) -> (&[f32], &[f32]) {
        let filled = self.valid_positions;
        self.get_range(layer, 0, filled)
    }

    /// Number of slots currently counted as valid.
    #[must_use]
    pub fn len(&self) -> usize {
        self.valid_positions
    }

    /// Returns `true` when no slot is counted as valid.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of slots available per layer.
    #[must_use]
    pub fn max_positions(&self) -> usize {
        self.max_positions
    }

    /// Resets the write cursor and the valid count to zero.
    ///
    /// The stored elements are left as they are; only the bookkeeping is reset.
    pub fn clear(&mut self) {
        self.valid_positions = 0;
        self.position = 0;
    }

    /// Size in bytes of all key and value buffers combined.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let elems_per_layer = self.max_positions * self.num_heads * self.head_dim;
        let elems_per_buffer_set = self.num_layers * elems_per_layer;
        // Two buffer sets (keys and values), each element one `f32`.
        elems_per_buffer_set * 2 * std::mem::size_of::<f32>()
    }

    /// Same quantity as [`Self::memory_bytes`], expressed in units of 1024 * 1024 bytes.
    #[must_use]
    pub fn memory_mb(&self) -> f64 {
        const BYTES_PER_MB: f64 = 1024.0 * 1024.0;
        self.memory_bytes() as f64 / BYTES_PER_MB
    }
}
/// Fixed-capacity key/value cache that stores rows as 16-bit values.
///
/// Inputs arrive as `f32` and are stored as the raw `u16` bit patterns
/// produced by `half::f16`. The layout matches the `f32` cache: per layer, one
/// flat key buffer and one flat value buffer of `max_positions` rows with
/// `num_heads * head_dim` elements each, written at a cursor that wraps back
/// to zero after `max_positions` writes.
pub struct StreamingKVCacheFp16 {
    /// Number of layers; one key buffer and one value buffer exist per layer.
    num_layers: usize,
    /// Number of row slots available in each layer buffer.
    max_positions: usize,
    /// Head count; together with `head_dim` it fixes the row length.
    num_heads: usize,
    /// Elements per head.
    head_dim: usize,
    /// Per-layer flat key storage (`half::f16` bit patterns).
    keys: Vec<Vec<u16>>,
    /// Per-layer flat value storage (`half::f16` bit patterns).
    values: Vec<Vec<u16>>,
    /// Slot index the next `append` writes to.
    position: usize,
    /// Count of slots reported as valid; never exceeds `max_positions`.
    valid_positions: usize,
}

impl StreamingKVCacheFp16 {
    /// Builds a cache whose key and value buffers are filled with `0u16`.
    ///
    /// Each of the `num_layers` buffers holds
    /// `max_positions * num_heads * head_dim` elements.
    #[must_use]
    pub fn new(num_layers: usize, max_positions: usize, num_heads: usize, head_dim: usize) -> Self {
        let elems_per_layer = max_positions * num_heads * head_dim;
        let zeroed_layers = || vec![vec![0u16; elems_per_layer]; num_layers];
        Self {
            keys: zeroed_layers(),
            values: zeroed_layers(),
            position: 0,
            valid_positions: 0,
            num_layers,
            max_positions,
            num_heads,
            head_dim,
        }
    }

    /// Converts an `f32` to the bit pattern of the corresponding `half::f16`.
    #[inline]
    pub(crate) fn f32_to_f16(value: f32) -> u16 {
        let narrowed = half::f16::from_f32(value);
        narrowed.to_bits()
    }

    /// Interprets `bits` as a `half::f16` and widens it to `f32`.
    #[inline]
    pub(crate) fn f16_to_f32(bits: u16) -> f32 {
        let narrow = half::f16::from_bits(bits);
        narrow.to_f32()
    }

    /// Converts one key row and one value row to 16-bit storage and writes
    /// them for `layer` at the current slot.
    ///
    /// The write cursor and the valid count only move when `layer` is the
    /// last layer (`num_layers - 1`); writes to any other layer leave them
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `layer >= num_layers`, or if `key` or `value` does not have
    /// exactly `num_heads * head_dim` elements.
    pub fn append(&mut self, layer: usize, key: &[f32], value: &[f32]) {
        let row_len = self.num_heads * self.head_dim;
        assert!(layer < self.num_layers, "Layer index out of bounds");
        assert_eq!(key.len(), row_len, "Key dimension mismatch");
        assert_eq!(value.len(), row_len, "Value dimension mismatch");

        let begin = self.position * row_len;
        let slot = begin..begin + row_len;

        let key_row = &mut self.keys[layer][slot.clone()];
        for (dst, &src) in key_row.iter_mut().zip(key) {
            *dst = Self::f32_to_f16(src);
        }
        let value_row = &mut self.values[layer][slot];
        for (dst, &src) in value_row.iter_mut().zip(value) {
            *dst = Self::f32_to_f16(src);
        }

        // `layer < num_layers` was asserted above, so `layer + 1` cannot overflow.
        let wrote_last_layer = layer + 1 == self.num_layers;
        if wrote_last_layer {
            self.position = (self.position + 1) % self.max_positions;
            self.valid_positions = self.max_positions.min(self.valid_positions + 1);
        }
    }

    /// Returns freshly allocated `f32` copies of the key and value elements
    /// of `layer` for slots `start..end`, in storage (slot) order.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index or if the requested slot
    /// range does not lie inside the layer buffer.
    #[must_use]
    pub fn get_range_f32(&self, layer: usize, start: usize, end: usize) -> (Vec<f32>, Vec<f32>) {
        let (raw_keys, raw_values) = self.get_range_raw(layer, start, end);
        let widen = |raw: &[u16]| -> Vec<f32> {
            raw.iter().copied().map(Self::f16_to_f32).collect()
        };
        (widen(raw_keys), widen(raw_values))
    }

    /// Returns the stored 16-bit key and value elements of `layer` for slots
    /// `start..end` without converting them.
    ///
    /// # Panics
    ///
    /// Panics if `layer` is not a valid layer index or if the requested slot
    /// range does not lie inside the layer buffer.
    #[must_use]
    pub fn get_range_raw(&self, layer: usize, start: usize, end: usize) -> (&[u16], &[u16]) {
        let row_len = self.num_heads * self.head_dim;
        let span = (start * row_len)..(end * row_len);
        let key_rows = &self.keys[layer][span.clone()];
        let value_rows = &self.values[layer][span];
        (key_rows, value_rows)
    }

    /// Returns `f32` copies of the key and value elements of `layer` for
    /// slots `0..valid_positions`, in storage (slot) order.
    #[must_use]
    pub fn get_valid_f32(&self, layer: usize) -> (Vec<f32>, Vec<f32>) {
        let filled = self.valid_positions;
        self.get_range_f32(layer, 0, filled)
    }

    /// Number of slots currently counted as valid.
    #[must_use]
    pub fn len(&self) -> usize {
        self.valid_positions
    }

    /// Returns `true` when no slot is counted as valid.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Number of slots available per layer.
    #[must_use]
    pub fn max_positions(&self) -> usize {
        self.max_positions
    }

    /// Resets the write cursor and the valid count to zero.
    ///
    /// The stored elements are left as they are; only the bookkeeping is reset.
    pub fn clear(&mut self) {
        self.valid_positions = 0;
        self.position = 0;
    }

    /// Size in bytes of all key and value buffers combined.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let elems_per_layer = self.max_positions * self.num_heads * self.head_dim;
        let elems_per_buffer_set = self.num_layers * elems_per_layer;
        // Two buffer sets (keys and values), each element one `u16`.
        elems_per_buffer_set * 2 * std::mem::size_of::<u16>()
    }

    /// Same quantity as [`Self::memory_bytes`], expressed in units of 1024 * 1024 bytes.
    #[must_use]
    pub fn memory_mb(&self) -> f64 {
        const BYTES_PER_MB: f64 = 1024.0 * 1024.0;
        self.memory_bytes() as f64 / BYTES_PER_MB
    }
}
include!("streaming_kv_streaming.rs");
/// Contract tests for the byte-size accounting of the KV caches.
///
/// Every case constructs a real cache through `new`, so the buffers for the
/// listed shape are allocated while the test runs.
#[cfg(test)]
mod kv_contract_tests {
    use super::*;

    /// KV-001: `memory_bytes` equals `layers * positions * heads * head_dim * 2 * 4`.
    #[test]
    fn falsify_kv_001_memory_formula() {
        const SHAPES: [(usize, usize, usize, usize); 4] = [
            (1, 1, 1, 1),
            (32, 2048, 32, 128),
            (40, 4096, 40, 128),
            (1, 512, 8, 64),
        ];
        for (nl, mp, nh, hd) in SHAPES {
            let expected = nl * mp * nh * hd * 2 * 4;
            let cache = StreamingKVCache::new(nl, mp, nh, hd);
            assert_eq!(
                cache.memory_bytes(),
                expected,
                "FALSIFIED KV-001: memory_bytes({nl}, {mp}, {nh}, {hd}) = {}, expected {expected}",
                cache.memory_bytes()
            );
        }
    }

    /// KV-002: with the other dimensions fixed, more positions means strictly more bytes.
    #[test]
    fn falsify_kv_002_monotonic_sequence_length() {
        let mut prev_bytes: usize = 0;
        for sl in [128, 256, 512, 1024, 2048, 4096] {
            let bytes = StreamingKVCache::new(32, sl, 32, 128).memory_bytes();
            assert!(
                bytes > prev_bytes,
                "FALSIFIED KV-002: memory({sl}) = {bytes} not > memory(prev) = {prev_bytes}"
            );
            prev_bytes = bytes;
        }
    }

    /// KV-002b: with the other dimensions fixed, more layers means strictly more bytes.
    #[test]
    fn falsify_kv_002b_monotonic_layers() {
        let mut prev_bytes: usize = 0;
        for nl in [1, 8, 16, 32, 40, 64] {
            let bytes = StreamingKVCache::new(nl, 2048, 32, 128).memory_bytes();
            assert!(
                bytes > prev_bytes,
                "FALSIFIED KV-002b: memory(layers={nl}) = {bytes} not > {prev_bytes}"
            );
            prev_bytes = bytes;
        }
    }

    /// KV-001b: the 16-bit cache reports exactly half the bytes of the `f32`
    /// cache for the same shape.
    #[test]
    fn falsify_kv_001b_fp16_half_memory() {
        let (nl, mp, nh, hd) = (32, 2048, 32, 128);
        let f32_bytes = StreamingKVCache::new(nl, mp, nh, hd).memory_bytes();
        let f16_bytes = StreamingKVCacheFp16::new(nl, mp, nh, hd).memory_bytes();
        assert_eq!(
            f16_bytes * 2,
            f32_bytes,
            "FALSIFIED KV-001b: FP16 ({f16_bytes}) * 2 != FP32 ({f32_bytes})"
        );
    }
}