// zeta_kv_cache/lib.rs

// Copyright 2025 ZETA RETICULA INC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
//! Unified KV Cache Implementation for Zeta Reticula
//!
//! This module consolidates all KV cache functionality from:
//! - kvquant_rs/src/block.rs (LogStructuredKVCache)
//! - llm-rs/src/kv_cache.rs
//! - llm-rs/src/kv_cache_manager.rs
//! - zeta-vault-synergy implementations
22
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use anyhow::Result;
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::RwLock;
30
/// Errors produced by the unified KV cache.
#[derive(Error, Debug)]
pub enum KVCacheError {
    /// The cache cannot accept any more items.
    #[error("Cache capacity exceeded")]
    CapacityExceeded,
    /// The supplied key is malformed or unusable.
    #[error("Invalid cache key: {0}")]
    InvalidKey(String),
    /// No entry exists for the requested key.
    #[error("Cache miss for key: {0}")]
    CacheMiss(String),
    /// Wrapper for `serde_json` failures during (de)serialization.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
42
/// Configuration for the unified KV cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KVCacheConfig {
    /// Quantization precision for cached values.
    pub precision: PrecisionLevel,
    /// Used both as the key-sharding modulus (`key % block_size`) and as the
    /// per-block buffer capacity passed to `DataBlock::new`.
    pub block_size: usize,
    /// NOTE(review): not referenced anywhere in this file — confirm usage.
    pub spot_capacity: usize,
    /// Eviction trigger: compared against the number of *blocks* in `store`.
    pub max_cache_items: usize,
    /// Items with salience below this threshold are skipped by `store`.
    pub salience_threshold: f32,
    /// NOTE(review): not referenced anywhere in this file — confirm usage.
    pub enable_debug_logging: bool,
    /// Policy used when the cache must shed blocks.
    pub eviction_policy: EvictionPolicy,
    /// NOTE(review): not referenced anywhere in this file — confirm usage.
    pub compression_enabled: bool,
}
54
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
56pub enum PrecisionLevel {
57    Int1,
58    Int2,
59    Int4,
60    Int8,
61    FP16,
62    FP32,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub enum EvictionPolicy {
67    LRU,
68    LFU,
69    SalienceBased,
70    Adaptive,
71}
72
73impl Default for KVCacheConfig {
74    fn default() -> Self {
75        Self {
76            precision: PrecisionLevel::Int4,
77            block_size: 1024,
78            spot_capacity: 10000,
79            max_cache_items: 50000,
80            salience_threshold: 0.7,
81            enable_debug_logging: false,
82            eviction_policy: EvictionPolicy::SalienceBased,
83            compression_enabled: true,
84        }
85    }
86}
87
/// Lifecycle state of a [`DataBlock`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BlockState {
    /// Unused; writable.
    Free,
    /// Holds live data; still writable.
    Valid,
    /// Superseded data. NOTE(review): never assigned in this file — confirm
    /// whether external code sets it.
    Obsolete,
    /// Explicitly invalidated via `DataBlock::invalidate`.
    Invalid,
}
95
/// One shard of the cache holding token values plus auxiliary metadata.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct DataBlock {
    /// Block identifier (the shard index).
    pub id: usize,
    /// Lifecycle state; see [`BlockState`].
    pub state: BlockState,
    /// token id -> cached value.
    pub data: HashMap<u32, f32>,
    /// Per-write pointer values. NOTE(review): appended on every `write` but
    /// never read in this file — confirm consumer.
    pub pointers: Vec<usize>,
    /// Per-write bias values — same review note as `pointers`.
    pub biases: Vec<f32>,
    /// Per-write vector ids — same review note as `pointers`.
    pub vector_ids: Vec<u32>,
    /// Adjacency entries supplied by callers of `write`.
    pub navigation_graph: HashMap<usize, Vec<usize>>,
    /// Count of distinct tokens held (maintained by `write`/`erase`).
    pub size: usize,
    /// Intended capacity; used to pre-size the Vec buffers.
    pub capacity: usize,
    /// token id -> salience score.
    pub salience_scores: HashMap<u32, f32>,
    /// Number of accesses (reads and writes).
    pub access_count: u64,
    /// Unix timestamp (seconds) of the most recent access.
    pub last_accessed: u64,
}
111
112impl DataBlock {
113    pub fn new(id: usize, capacity: usize) -> Self {
114        Self {
115            id,
116            state: BlockState::Free,
117            data: HashMap::new(),
118            pointers: Vec::with_capacity(capacity),
119            biases: Vec::with_capacity(capacity),
120            vector_ids: Vec::with_capacity(capacity),
121            navigation_graph: HashMap::new(),
122            salience_scores: HashMap::new(),
123            size: 0,
124            capacity,
125            access_count: 0,
126            last_accessed: 0,
127        }
128    }
129
130    pub fn write(&mut self, token_id: u32, value: f32, pointer: usize, bias: f32, vector_id: u32, graph_entry: (usize, Vec<usize>)) {
131        if self.state == BlockState::Free || self.state == BlockState::Valid {
132            self.data.insert(token_id, value);
133            self.pointers.push(pointer);
134            self.biases.push(bias);
135            self.vector_ids.push(vector_id);
136            self.navigation_graph.insert(graph_entry.0, graph_entry.1);
137            self.size += 1;
138            self.state = BlockState::Valid;
139            self.access_count += 1;
140            self.last_accessed = std::time::SystemTime::now()
141                .duration_since(std::time::UNIX_EPOCH)
142                .unwrap_or_default()
143                .as_secs();
144        }
145    }
146
147    pub fn update_salience(&mut self, token_id: u32, salience_score: f32) {
148        self.salience_scores.insert(token_id, salience_score);
149    }
150
151    pub fn get_salience(&self, token_id: u32) -> Option<f32> {
152        self.salience_scores.get(&token_id).copied()
153    }
154
155    pub fn invalidate(&mut self) {
156        self.state = BlockState::Invalid;
157    }
158
159    pub fn erase(&mut self) {
160        self.data.clear();
161        self.pointers.clear();
162        self.biases.clear();
163        self.vector_ids.clear();
164        self.navigation_graph.clear();
165        self.salience_scores.clear();
166        self.size = 0;
167        self.state = BlockState::Free;
168        self.access_count = 0;
169    }
170}
171
/// Unified KV Cache that consolidates all previous implementations.
pub struct UnifiedKVCache {
    config: KVCacheConfig,
    /// Blocks keyed by shard id (`key % config.block_size`).
    blocks: DashMap<usize, DataBlock>,
    /// NOTE(review): never read or written in this file — confirm whether the
    /// validity bitmap is still required.
    valid_bitmap: DashMap<(usize, usize), bool>,
    /// Serializes the multi-field update inside `store` (guard is dropped
    /// before any await point).
    lock: Arc<Mutex<()>>,
    access_order: Arc<RwLock<Vec<usize>>>, // For LRU
    access_frequency: Arc<RwLock<HashMap<usize, u64>>>, // For LFU
}
181
182impl UnifiedKVCache {
183    pub fn new(config: KVCacheConfig) -> Self {
184        Self {
185            config,
186            blocks: DashMap::new(),
187            valid_bitmap: DashMap::new(),
188            lock: Arc::new(Mutex::new(())),
189            access_order: Arc::new(RwLock::new(Vec::new())),
190            access_frequency: Arc::new(RwLock::new(HashMap::new())),
191        }
192    }
193
194    pub async fn store(&self, key: u32, value: f32, salience_score: f32) -> Result<(), KVCacheError> {
195        if salience_score < self.config.salience_threshold {
196            return Ok(()); // Skip low salience items
197        }
198
199        let block_id = (key as usize) % self.config.block_size;
200        
201        {
202            let _guard = self.lock.lock().unwrap();
203            let mut block = self.blocks.entry(block_id).or_insert_with(|| {
204                DataBlock::new(block_id, self.config.block_size)
205            });
206
207            block.data.insert(key, value);
208            block.update_salience(key, salience_score);
209            block.access_count += 1;
210            block.last_accessed = std::time::SystemTime::now()
211                .duration_since(std::time::UNIX_EPOCH)
212                .unwrap_or_default()
213                .as_secs();
214        } // Guard is dropped here
215
216        // Update access tracking for eviction policies
217        self.update_access_tracking(block_id).await;
218
219        // Check if eviction is needed
220        if self.blocks.len() > self.config.max_cache_items {
221            self.evict_blocks().await?;
222        }
223
224        Ok(())
225    }
226
227    pub async fn retrieve(&self, key: u32) -> Result<Option<f32>, KVCacheError> {
228        let block_id = (key as usize) % self.config.block_size;
229        
230        if let Some(mut block) = self.blocks.get_mut(&block_id) {
231            block.access_count += 1;
232            block.last_accessed = std::time::SystemTime::now()
233                .duration_since(std::time::UNIX_EPOCH)
234                .unwrap_or_default()
235                .as_secs();
236
237            self.update_access_tracking(block_id).await;
238            Ok(block.data.get(&key).copied())
239        } else {
240            Ok(None)
241        }
242    }
243
244    pub async fn get_salience(&self, key: u32) -> Option<f32> {
245        let block_id = (key as usize) % self.config.block_size;
246        self.blocks.get(&block_id)?.get_salience(key)
247    }
248
249    async fn update_access_tracking(&self, block_id: usize) {
250        match self.config.eviction_policy {
251            EvictionPolicy::LRU => {
252                let mut access_order = self.access_order.write().await;
253                access_order.retain(|&id| id != block_id);
254                access_order.push(block_id);
255            }
256            EvictionPolicy::LFU => {
257                let mut access_frequency = self.access_frequency.write().await;
258                *access_frequency.entry(block_id).or_insert(0) += 1;
259            }
260            _ => {} // Other policies handled elsewhere
261        }
262    }
263
264    async fn evict_blocks(&self) -> Result<(), KVCacheError> {
265        let blocks_to_evict = match self.config.eviction_policy {
266            EvictionPolicy::LRU => self.select_lru_blocks().await,
267            EvictionPolicy::LFU => self.select_lfu_blocks().await,
268            EvictionPolicy::SalienceBased => self.select_low_salience_blocks().await,
269            EvictionPolicy::Adaptive => self.select_adaptive_blocks().await,
270        };
271
272        for block_id in blocks_to_evict {
273            if let Some(mut block) = self.blocks.get_mut(&block_id) {
274                block.erase();
275            }
276            self.blocks.remove(&block_id);
277        }
278
279        Ok(())
280    }
281
282    async fn select_lru_blocks(&self) -> Vec<usize> {
283        let access_order = self.access_order.read().await;
284        let evict_count = (self.blocks.len() / 4).max(1); // Evict 25%
285        access_order.iter().take(evict_count).copied().collect()
286    }
287
288    async fn select_lfu_blocks(&self) -> Vec<usize> {
289        let access_frequency = self.access_frequency.read().await;
290        let mut freq_blocks: Vec<_> = access_frequency.iter().collect();
291        freq_blocks.sort_by_key(|(_, &freq)| freq);
292        
293        let evict_count = (self.blocks.len() / 4).max(1);
294        freq_blocks.iter().take(evict_count).map(|(&id, _)| id).collect()
295    }
296
297    async fn select_low_salience_blocks(&self) -> Vec<usize> {
298        let mut salience_blocks = Vec::new();
299        
300        for entry in self.blocks.iter() {
301            let (block_id, block) = entry.pair();
302            let avg_salience: f32 = block.salience_scores.values().sum::<f32>() / block.salience_scores.len().max(1) as f32;
303            salience_blocks.push((*block_id, avg_salience));
304        }
305
306        salience_blocks.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
307        let evict_count = (self.blocks.len() / 4).max(1);
308        salience_blocks.iter().take(evict_count).map(|(id, _)| *id).collect()
309    }
310
311    async fn select_adaptive_blocks(&self) -> Vec<usize> {
312        // Adaptive policy combines salience and access patterns
313        let mut adaptive_scores = Vec::new();
314        
315        for entry in self.blocks.iter() {
316            let (block_id, block) = entry.pair();
317            let avg_salience: f32 = block.salience_scores.values().sum::<f32>() / block.salience_scores.len().max(1) as f32;
318            let recency_score = 1.0 / (block.access_count as f32 + 1.0);
319            let adaptive_score = avg_salience * 0.7 + recency_score * 0.3;
320            adaptive_scores.push((*block_id, adaptive_score));
321        }
322
323        adaptive_scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
324        let evict_count = (self.blocks.len() / 4).max(1);
325        adaptive_scores.iter().take(evict_count).map(|(id, _)| *id).collect()
326    }
327
328    pub fn get_stats(&self) -> KVCacheStats {
329        let total_blocks = self.blocks.len();
330        let valid_blocks = self.blocks.iter().filter(|entry| entry.value().state == BlockState::Valid).count();
331        let total_items: usize = self.blocks.iter().map(|entry| entry.value().size).sum();
332        let memory_usage = total_blocks * self.config.block_size * std::mem::size_of::<f32>();
333
334        KVCacheStats {
335            total_blocks,
336            valid_blocks,
337            total_items,
338            memory_usage_bytes: memory_usage,
339            hit_rate: 0.0, // Would need to track hits/misses
340            eviction_count: 0, // Would need to track evictions
341        }
342    }
343}
344
/// Snapshot of cache health returned by `UnifiedKVCache::get_stats`.
#[derive(Debug, Serialize, Deserialize)]
pub struct KVCacheStats {
    /// Number of blocks currently resident.
    pub total_blocks: usize,
    /// Blocks in the `Valid` state.
    pub valid_blocks: usize,
    /// Sum of per-block item counts.
    pub total_items: usize,
    /// Estimated memory footprint (upper bound; assumes full f32 blocks).
    pub memory_usage_bytes: usize,
    /// Placeholder — always 0.0 until hit/miss tracking is added.
    pub hit_rate: f32,
    /// Placeholder — always 0 until eviction tracking is added.
    pub eviction_count: u64,
}
354
355/// Factory function to create KV cache instances
356pub fn create_kv_cache(config: KVCacheConfig) -> UnifiedKVCache {
357    UnifiedKVCache::new(config)
358}
359
/// Async trait for KV cache operations (for compatibility with existing code).
#[async_trait::async_trait]
pub trait KVCacheManager: Send + Sync {
    /// Stores raw bytes under a string key.
    async fn store(&self, key: String, value: Vec<u8>) -> Result<()>;
    /// Returns the bytes stored under `key`, or `None` on a miss.
    async fn retrieve(&self, key: &str) -> Result<Option<Vec<u8>>>;
    /// Removes `key`; returns whether an entry existed.
    async fn delete(&self, key: &str) -> Result<bool>;
    /// Removes every entry from the cache.
    async fn clear(&self) -> Result<()>;
}
368
/// Compatibility wrapper for existing KVCacheManager implementations.
pub struct KVCacheManagerAdapter {
    // Shared ownership so the adapter can be used across async tasks.
    cache: Arc<UnifiedKVCache>,
}
373
374impl KVCacheManagerAdapter {
375    pub fn new(cache: UnifiedKVCache) -> Self {
376        Self {
377            cache: Arc::new(cache),
378        }
379    }
380}
381
382#[async_trait::async_trait]
383impl KVCacheManager for KVCacheManagerAdapter {
384    async fn store(&self, key: String, value: Vec<u8>) -> Result<()> {
385        // Convert string key to u32 hash
386        let key_hash = key.chars().map(|c| c as u32).sum::<u32>();
387        
388        // Store as f32 (simplified for this trait implementation)
389        let value_f32 = value.len() as f32;
390        self.cache.store(key_hash, value_f32, 1.0).await
391            .map_err(|e| anyhow::anyhow!("Store failed: {}", e))
392    }
393
394    async fn retrieve(&self, key: &str) -> Result<Option<Vec<u8>>> {
395        let key_hash = key.chars().map(|c| c as u32).sum::<u32>();
396        match self.cache.retrieve(key_hash).await? {
397            Some(value) => {
398                // Convert f32 back to bytes (simplified)
399                let bytes = (value as u32).to_le_bytes().to_vec();
400                Ok(Some(bytes))
401            }
402            None => Ok(None),
403        }
404    }
405
406    async fn delete(&self, _key: &str) -> Result<bool> {
407        // Implementation would require adding delete method to UnifiedKVCache
408        Ok(true) // Placeholder
409    }
410
411    async fn clear(&self) -> Result<()> {
412        // Implementation would require adding clear method to UnifiedKVCache
413        Ok(()) // Placeholder
414    }
415}