/// 64-bit hash identifying a token prefix.
pub type PrefixHash = u64;

/// Hashes a token sequence using the FNV-1a mixing scheme (64-bit offset
/// basis and prime), folding each token in as a whole `u64`.
///
/// An empty slice hashes to the FNV offset basis. The hash is order-
/// sensitive, so `[1, 2]` and `[2, 1]` produce different values.
pub fn compute_prefix_hash(tokens: &[u32]) -> PrefixHash {
    const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01b3;
    tokens.iter().fold(FNV_OFFSET_BASIS, |acc, &tok| {
        (acc ^ u64::from(tok)).wrapping_mul(FNV_PRIME)
    })
}
/// Bookkeeping for one token prefix whose KV-cache pages are retained for reuse.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrefix {
/// Hash of the token prefix (see `compute_prefix_hash`).
pub hash: PrefixHash,
/// Number of tokens this prefix covers.
pub num_tokens: usize,
/// KV-cache pages backing the prefix.
pub page_ids: Vec<PageId>,
/// Number of active users; starts at 1 on creation (see `CachedPrefix::new`).
pub ref_count: usize,
/// Logical timestamp of the most recent access, used for LRU ordering.
pub last_access: u64,
}
impl CachedPrefix {
    /// Builds a freshly cached prefix; the creator holds the first reference,
    /// and the entry has never been accessed (`last_access == 0`).
    pub fn new(hash: PrefixHash, num_tokens: usize, page_ids: Vec<PageId>) -> Self {
        Self {
            hash,
            num_tokens,
            page_ids,
            ref_count: 1,
            last_access: 0,
        }
    }

    /// Registers one more user of this prefix.
    pub fn add_ref(&mut self) {
        self.ref_count += 1;
    }

    /// Releases one user; reports `true` once nobody references the prefix.
    /// Extra calls past zero are tolerated — the count never underflows.
    pub fn remove_ref(&mut self) -> bool {
        if self.ref_count > 0 {
            self.ref_count -= 1;
        }
        self.ref_count == 0
    }
}
/// Bounded cache of prefix KV allocations, keyed by prefix hash.
/// Eviction picks the least-recently-used entry among those with
/// `ref_count == 0` (see `evict_lru`); `access_counter` is the logical clock
/// stamped into entries' `last_access`.
pub struct PrefixCache {
// hash -> cached prefix bookkeeping
cache: HashMap<PrefixHash, CachedPrefix>,
// capacity limit on the number of cached prefixes
max_entries: usize,
// monotonically increasing logical LRU clock
access_counter: u64,
// cumulative hit/miss/eviction counters
stats: PrefixCacheStats,
}
/// Cumulative counters describing prefix-cache effectiveness.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PrefixCacheStats {
/// Lookups that found a cached prefix.
pub hits: u64,
/// Lookups that found nothing.
pub misses: u64,
/// Prefixes inserted into the cache (counted at insert time).
pub prefixes_cached: u64,
/// Prefixes removed by LRU eviction.
pub prefixes_evicted: u64,
/// Sum of `num_tokens` over inserted prefixes — presumably tokens whose
/// prefill can be skipped on reuse; confirm against the caller's accounting.
pub tokens_saved: u64,
}
impl PrefixCacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    /// Returns 0.0 before any lookup has been recorded.
    pub fn hit_rate(&self) -> f64 {
        match self.hits + self.misses {
            0 => 0.0,
            total => self.hits as f64 / total as f64,
        }
    }
}
impl PrefixCache {
    /// Creates a cache holding at most `max_entries` prefixes.
    pub fn new(max_entries: usize) -> Self {
        Self {
            cache: HashMap::with_capacity(max_entries),
            max_entries,
            access_counter: 0,
            stats: PrefixCacheStats::default(),
        }
    }

    /// Looks up a prefix by hash, refreshing its LRU timestamp on a hit.
    /// Records the outcome in `stats.hits` / `stats.misses`.
    pub fn lookup(&mut self, hash: PrefixHash) -> Option<&CachedPrefix> {
        if let Some(entry) = self.cache.get_mut(&hash) {
            self.access_counter += 1;
            entry.last_access = self.access_counter;
            self.stats.hits += 1;
            // Reborrow immutably instead of re-hashing with a second
            // `self.cache.get(&hash)` lookup; only disjoint fields are
            // touched in the miss branch, so the borrow checker permits this.
            Some(&*entry)
        } else {
            self.stats.misses += 1;
            None
        }
    }

    /// Convenience wrapper: hashes `tokens` and performs a [`Self::lookup`].
    pub fn lookup_tokens(&mut self, tokens: &[u32]) -> Option<&CachedPrefix> {
        let hash = compute_prefix_hash(tokens);
        self.lookup(hash)
    }

    /// Returns whether `hash` is currently cached (no stats or LRU update).
    pub fn contains(&self, hash: PrefixHash) -> bool {
        self.cache.contains_key(&hash)
    }

    /// Inserts (or replaces) a cached prefix.
    ///
    /// Returns `false` only when the cache is at capacity and every resident
    /// entry is still referenced (nothing evictable).
    ///
    /// Fixes two defects of the previous version: replacing an entry whose
    /// key was already present used to fail when the cache was full, even
    /// though replacement does not grow the map; and re-inserting an existing
    /// hash double-counted `prefixes_cached` / `tokens_saved`.
    pub fn insert(&mut self, prefix: CachedPrefix) -> bool {
        let hash = prefix.hash;
        if self.cache.contains_key(&hash) {
            // In-place replacement: capacity is unaffected and this prefix
            // was already counted in the statistics when first inserted.
            self.cache.insert(hash, prefix);
            return true;
        }
        if self.cache.len() >= self.max_entries {
            self.evict_lru();
        }
        if self.cache.len() < self.max_entries {
            self.stats.prefixes_cached += 1;
            self.stats.tokens_saved += prefix.num_tokens as u64;
            self.cache.insert(hash, prefix);
            true
        } else {
            // evict_lru found no victim: every entry is still referenced.
            false
        }
    }

    /// Adds a reference to a cached prefix and refreshes its LRU timestamp.
    /// Returns `false` if `hash` is not cached.
    pub fn add_ref(&mut self, hash: PrefixHash) -> bool {
        if let Some(entry) = self.cache.get_mut(&hash) {
            entry.add_ref();
            self.access_counter += 1;
            entry.last_access = self.access_counter;
            true
        } else {
            false
        }
    }

    /// Drops one reference; removes the entry once its count reaches zero.
    /// Returns `true` only when the entry was actually removed.
    ///
    /// NOTE(review): removing at zero means no `ref_count == 0` entry ever
    /// stays resident, so `evict_lru` (which only evicts zero-ref entries)
    /// can never find a victim and a full cache of released prefixes cannot
    /// form. Confirm whether released prefixes should instead stay cached
    /// until evicted.
    pub fn remove_ref(&mut self, hash: PrefixHash) -> bool {
        if let Some(entry) = self.cache.get_mut(&hash) {
            if entry.remove_ref() {
                self.cache.remove(&hash);
                return true;
            }
        }
        false
    }

    /// Evicts the least-recently-used entry among those with no references.
    /// No-op when every entry is still referenced.
    fn evict_lru(&mut self) {
        if let Some((&hash, _)) = self
            .cache
            .iter()
            .filter(|(_, v)| v.ref_count == 0)
            .min_by_key(|(_, v)| v.last_access)
        {
            self.cache.remove(&hash);
            self.stats.prefixes_evicted += 1;
        }
    }

    /// Number of prefixes currently cached.
    pub fn len(&self) -> usize {
        self.cache.len()
    }

    /// Whether the cache holds no prefixes.
    pub fn is_empty(&self) -> bool {
        self.cache.is_empty()
    }

    /// Read-only view of the running statistics.
    pub fn stats(&self) -> &PrefixCacheStats {
        &self.stats
    }

    /// Drops every entry and resets the LRU clock. Statistics are left
    /// untouched — presumably cumulative across clears; confirm.
    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_counter = 0;
    }

    /// Fill ratio in `[0.0, 1.0]`; 0.0 for a zero-capacity cache.
    pub fn utilization(&self) -> f64 {
        if self.max_entries == 0 {
            0.0
        } else {
            self.cache.len() as f64 / self.max_entries as f64
        }
    }
}
impl Default for PrefixCache {
fn default() -> Self {
Self::new(100)
}
}
/// Quantization format for stored KV-cache values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum KvQuantType {
/// Unquantized 32-bit floats (4 bytes per value) — the default.
#[default]
FP32,
/// 8-bit quantization (1 byte per value); see `Q8KvBlock`.
Q8,
/// 4-bit quantization (0.5 byte per value, two per byte); see `Q4KvBlock`.
Q4,
}
impl KvQuantType {
pub fn bytes_per_value(&self) -> f32 {
match self {
Self::FP32 => 4.0,
Self::Q8 => 1.0, Self::Q4 => 0.5, }
}
pub fn memory_reduction(&self) -> f32 {
4.0 / self.bytes_per_value()
}
}
/// Number of values per quantization block.
pub const KV_QUANT_BLOCK_SIZE: usize = 32;

/// One block of KV-cache values quantized to signed 8 bits with a single
/// shared absmax scale.
#[derive(Debug, Clone)]
pub struct Q8KvBlock {
    pub scale: f32,
    pub quants: [i8; KV_QUANT_BLOCK_SIZE],
}

impl Q8KvBlock {
    /// An all-zero block (scale 0, all quants 0).
    pub fn new() -> Self {
        Self {
            scale: 0.0,
            quants: [0; KV_QUANT_BLOCK_SIZE],
        }
    }

    /// Quantizes a block with symmetric absmax scaling into `[-127, 127]`.
    pub fn quantize(values: &[f32; KV_QUANT_BLOCK_SIZE]) -> Self {
        let amax = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        // All-zero (or vanishingly small) input: keep scale 0 so that
        // dequantize reproduces exact zeros.
        if amax < 1e-10 {
            return Self::new();
        }
        let scale = amax / 127.0;
        let inv = 1.0 / scale;
        let mut quants = [0i8; KV_QUANT_BLOCK_SIZE];
        for (slot, &v) in quants.iter_mut().zip(values.iter()) {
            *slot = ((v * inv).round() as i32).clamp(-127, 127) as i8;
        }
        Self { scale, quants }
    }

    /// Reconstructs the floats: `q * scale` per element.
    pub fn dequantize(&self) -> [f32; KV_QUANT_BLOCK_SIZE] {
        let mut out = [0.0f32; KV_QUANT_BLOCK_SIZE];
        for (dst, &q) in out.iter_mut().zip(self.quants.iter()) {
            *dst = f32::from(q) * self.scale;
        }
        out
    }
}

impl Default for Q8KvBlock {
    /// Same as [`Q8KvBlock::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// One block of KV-cache values quantized to 4 bits — two values packed per
/// byte — with a single shared absmax scale.
#[derive(Debug, Clone)]
pub struct Q4KvBlock {
    pub scale: f32,
    pub quants: [u8; KV_QUANT_BLOCK_SIZE / 2],
}

impl Q4KvBlock {
    /// An all-zero block (scale 0, all nibbles 0).
    pub fn new() -> Self {
        Self {
            scale: 0.0,
            quants: [0; KV_QUANT_BLOCK_SIZE / 2],
        }
    }

    /// Quantizes a block with symmetric absmax scaling into `[-8, 7]`,
    /// stored biased by +8. Even-indexed values occupy the low nibble,
    /// odd-indexed values the high nibble.
    pub fn quantize(values: &[f32; KV_QUANT_BLOCK_SIZE]) -> Self {
        let amax = values.iter().fold(0.0f32, |m, v| m.max(v.abs()));
        // All-zero (or vanishingly small) input: keep scale 0 so that
        // dequantize reproduces exact zeros.
        if amax < 1e-10 {
            return Self::new();
        }
        let scale = amax / 7.0;
        let inv = 1.0 / scale;
        let mut packed = [0u8; KV_QUANT_BLOCK_SIZE / 2];
        for (byte, pair) in packed.iter_mut().zip(values.chunks_exact(2)) {
            let lo = ((pair[0] * inv).round() as i32).clamp(-8, 7) + 8;
            let hi = ((pair[1] * inv).round() as i32).clamp(-8, 7) + 8;
            *byte = ((hi as u8) << 4) | (lo as u8);
        }
        Self {
            scale,
            quants: packed,
        }
    }

    /// Unpacks the biased nibbles and rescales: `(nibble - 8) * scale`.
    pub fn dequantize(&self) -> [f32; KV_QUANT_BLOCK_SIZE] {
        let mut out = [0.0f32; KV_QUANT_BLOCK_SIZE];
        for (pair, &byte) in out.chunks_exact_mut(2).zip(self.quants.iter()) {
            pair[0] = ((byte & 0x0F) as i32 - 8) as f32 * self.scale;
            // A u8 shifted right by 4 is already at most 0x0F; no mask needed.
            pair[1] = ((byte >> 4) as i32 - 8) as f32 * self.scale;
        }
        out
    }
}

impl Default for Q4KvBlock {
    /// Same as [`Q4KvBlock::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// KV-cache payload stored in one of the supported quantization formats;
/// keys and values are held in separate buffers.
#[derive(Debug, Clone)]
pub enum QuantizedKvData {
/// Raw 32-bit floats, no quantization.
FP32 {
keys: Vec<f32>,
values: Vec<f32>,
},
/// 8-bit blocks, `KV_QUANT_BLOCK_SIZE` values per block.
Q8 {
key_blocks: Vec<Q8KvBlock>,
value_blocks: Vec<Q8KvBlock>,
},
/// 4-bit blocks, `KV_QUANT_BLOCK_SIZE` values per block (two per byte).
Q4 {
key_blocks: Vec<Q4KvBlock>,
value_blocks: Vec<Q4KvBlock>,
},
}