use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};
/// Identifier under which a [`KVCacheEntry`] is stored (a plain owned `String`).
pub type CacheKey = String;
/// Tunable limits for the prefix (KV) cache.
///
/// Serializable with serde so it can be loaded from / written to configuration.
/// Defaults come from the `Default` impl (1000 entries, 512 MiB, 3600 s TTL,
/// compression off).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixCacheConfig {
    /// Configured maximum number of cached entries.
    /// NOTE(review): enforcement is not visible in this file — confirm in the cache implementation.
    pub max_entries: usize,
    /// Configured memory budget, in bytes.
    pub max_memory_bytes: usize,
    /// Default time-to-live for entries, in seconds.
    pub default_ttl_secs: u64,
    /// Whether compression of cached data is requested.
    /// NOTE(review): nothing in this file reads this flag — confirm where it is honoured.
    pub enable_compression: bool,
}
impl Default for PrefixCacheConfig {
    /// Baseline configuration: 1000 entries, a 512 MiB memory budget,
    /// a one-hour default TTL and compression disabled.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        const ONE_HOUR_SECS: u64 = 60 * 60;

        Self {
            max_entries: 1000,
            max_memory_bytes: 512 * MIB,
            default_ttl_secs: ONE_HOUR_SECS,
            enable_compression: false,
        }
    }
}
impl PrefixCacheConfig {
    /// Creates a configuration with explicit entry and memory limits; every
    /// other setting is taken from [`PrefixCacheConfig::default`].
    #[must_use]
    pub fn new(max_entries: usize, max_memory_bytes: usize) -> Self {
        Self {
            max_entries,
            max_memory_bytes,
            ..Self::default()
        }
    }

    /// Builder step: replaces the default TTL (in seconds).
    #[must_use]
    pub fn with_default_ttl(self, ttl_secs: u64) -> Self {
        Self {
            default_ttl_secs: ttl_secs,
            ..self
        }
    }

    /// Builder step: turns the compression flag on or off.
    #[must_use]
    pub fn with_compression(self, enable: bool) -> Self {
        Self {
            enable_compression: enable,
            ..self
        }
    }
}
/// Event counters and memory gauges for a prefix cache.
///
/// `hits`, `misses`, `evictions` and `expirations` are event counters bumped by
/// the `record_*` methods and cleared by [`CacheStats::reset`]. `total_bytes` and
/// `entry_count` are snapshots overwritten by [`CacheStats::update_memory`].
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of lookups recorded as hits.
    pub hits: u64,
    /// Number of lookups recorded as misses.
    pub misses: u64,
    /// Number of evictions recorded.
    pub evictions: u64,
    /// Last reported total size in bytes (set by `update_memory`).
    pub total_bytes: usize,
    /// Last reported number of entries (set by `update_memory`).
    pub entry_count: usize,
    /// Number of expirations recorded.
    pub expirations: u64,
}

impl CacheStats {
    /// Returns the hit rate as a percentage in `[0.0, 100.0]`.
    ///
    /// Returns `0.0` when no hits or misses have been recorded.
    #[must_use]
    // u64 -> f64 loses precision above 2^53; irrelevant for a percentage.
    #[allow(clippy::cast_precision_loss)]
    pub fn hit_rate(&self) -> f64 {
        if self.hits == 0 && self.misses == 0 {
            return 0.0;
        }
        // Sum in f64 rather than u64: `hits + misses` can overflow for large
        // counters (panic in debug, silent wrap to a bogus rate in release).
        let hits = self.hits as f64;
        let total = hits + self.misses as f64;
        (hits / total) * 100.0
    }

    /// Clears the event counters (`hits`, `misses`, `evictions`, `expirations`).
    ///
    /// `total_bytes` and `entry_count` are left untouched: they are snapshots of
    /// the current cache contents rather than accumulated history.
    pub fn reset(&mut self) {
        self.hits = 0;
        self.misses = 0;
        self.evictions = 0;
        self.expirations = 0;
    }

    /// Records one cache hit.
    pub fn record_hit(&mut self) {
        self.hits += 1;
    }

    /// Records one cache miss.
    pub fn record_miss(&mut self) {
        self.misses += 1;
    }

    /// Records one eviction.
    pub fn record_eviction(&mut self) {
        self.evictions += 1;
    }

    /// Records one TTL expiration.
    pub fn record_expiration(&mut self) {
        self.expirations += 1;
    }

    /// Overwrites the memory gauges with a caller-supplied snapshot.
    pub fn update_memory(&mut self, bytes: usize, count: usize) {
        self.total_bytes = bytes;
        self.entry_count = count;
    }
}
/// Compact identity of a context prefix, used to match cached KV data.
///
/// Equality and hashing cover all three fields (derived).
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ContextFingerprint {
    /// Hash of the context.
    /// NOTE(review): the hashing scheme is not defined in this file — confirm with the producer.
    pub hash: u64,
    /// Length of the prefix this fingerprint covers.
    /// NOTE(review): unit (tokens vs. characters) is not shown here — TODO confirm.
    pub prefix_length: usize,
    /// Short human-readable description of the content (presumably for debugging/logging).
    pub content_summary: String,
}
impl ContextFingerprint {
    /// Assembles a fingerprint from a hash, a prefix length and a summary
    /// (anything convertible into a `String`).
    #[must_use]
    pub fn new(hash: u64, prefix_length: usize, content_summary: impl Into<String>) -> Self {
        let content_summary: String = content_summary.into();
        Self {
            hash,
            prefix_length,
            content_summary,
        }
    }

    /// Returns `true` when `other` covers at least as long a prefix as `self`.
    ///
    /// Only `prefix_length` is compared; `hash` and `content_summary` are ignored.
    /// This is therefore a necessary, not sufficient, condition for `self` to
    /// really be a prefix of `other`. A fingerprint is always a prefix of itself.
    #[must_use]
    pub fn is_prefix_of(&self, other: &Self) -> bool {
        other.prefix_length >= self.prefix_length
    }
}
/// One cached block of key/value data plus bookkeeping for expiry and access tracking.
#[derive(Debug, Clone)]
pub struct KVCacheEntry {
    /// Key under which the entry is stored.
    pub key: CacheKey,
    /// Fingerprint of the context the cached data belongs to.
    pub fingerprint: ContextFingerprint,
    /// Cached KV values as a flat `f32` buffer.
    /// NOTE(review): the layout (layers/heads/dims ordering) is not defined here — TODO confirm.
    pub kv_data: Vec<f32>,
    /// Sequence length associated with the cached data.
    pub sequence_length: usize,
    /// When the entry was created; TTL expiry is measured from this instant.
    pub created_at: Instant,
    /// When the entry was last touched via `record_access` (initially `created_at`).
    pub last_accessed: Instant,
    /// Number of `record_access` calls.
    pub access_count: u64,
    /// Optional time-to-live; `None` means the entry never expires.
    pub ttl: Option<Duration>,
}
impl KVCacheEntry {
    /// Creates a fresh entry with no TTL and zero recorded accesses.
    ///
    /// `created_at` and `last_accessed` are set from the same clock reading.
    #[must_use]
    pub fn new(
        key: impl Into<CacheKey>,
        fingerprint: ContextFingerprint,
        kv_data: Vec<f32>,
        sequence_length: usize,
    ) -> Self {
        let created_at = Instant::now();
        Self {
            key: key.into(),
            fingerprint,
            kv_data,
            sequence_length,
            created_at,
            last_accessed: created_at,
            access_count: 0,
            ttl: None,
        }
    }

    /// Builder step: sets the time-to-live.
    #[must_use]
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self {
            ttl: Some(ttl),
            ..self
        }
    }

    /// Builder step: sets the time-to-live in whole seconds.
    #[must_use]
    pub fn with_ttl_secs(self, secs: u64) -> Self {
        self.with_ttl(Duration::from_secs(secs))
    }

    /// Returns `true` once the entry's age has reached its TTL.
    ///
    /// Entries without a TTL never expire. Age is measured from creation, so
    /// accesses do not extend the lifetime. A zero TTL is expired immediately.
    #[must_use]
    pub fn is_expired(&self) -> bool {
        match self.ttl {
            Some(limit) => self.age() >= limit,
            None => false,
        }
    }

    /// Time elapsed since the entry was created.
    #[must_use]
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Time elapsed since the last recorded access (or creation).
    #[must_use]
    pub fn time_since_access(&self) -> Duration {
        self.last_accessed.elapsed()
    }

    /// Bumps the access counter and stamps `last_accessed` with the current time.
    pub fn record_access(&mut self) {
        self.access_count += 1;
        self.last_accessed = Instant::now();
    }

    /// Rough memory footprint in bytes.
    ///
    /// Sums the KV payload, the key and summary string lengths, and the inline
    /// size of the struct. Uses lengths, not capacities, so spare buffer
    /// capacity is not counted.
    #[must_use]
    pub fn estimated_size(&self) -> usize {
        let payload_bytes = std::mem::size_of_val(self.kv_data.as_slice());
        [
            payload_bytes,
            self.key.len(),
            self.fingerprint.content_summary.len(),
            std::mem::size_of::<Self>(),
        ]
        .iter()
        .sum()
    }
}
/// Outcome of looking up a context in the prefix cache.
#[derive(Debug, Clone)]
pub enum CacheLookupResult {
    /// A matching entry was found.
    Hit(KVCacheEntry),
    /// No usable entry was found.
    Miss,
    /// An entry covering only part of the requested context was found.
    PartialHit {
        /// The entry that matched.
        entry: KVCacheEntry,
        /// Length still to be computed.
        /// NOTE(review): presumably the uncovered suffix of the request — confirm with the lookup code.
        remaining_length: usize,
    },
}
impl CacheLookupResult {
    /// `true` for both full and partial hits; only `Miss` is not a hit.
    #[must_use]
    pub fn is_hit(&self) -> bool {
        match self {
            Self::Hit(_) | Self::PartialHit { .. } => true,
            Self::Miss => false,
        }
    }

    /// `true` only for `Miss` (the exact complement of [`CacheLookupResult::is_hit`]).
    #[must_use]
    pub fn is_miss(&self) -> bool {
        !self.is_hit()
    }

    /// Borrows the matched entry for full and partial hits; `None` on a miss.
    #[must_use]
    pub fn entry(&self) -> Option<&KVCacheEntry> {
        match self {
            Self::Miss => None,
            Self::Hit(found) => Some(found),
            Self::PartialHit { entry: found, .. } => Some(found),
        }
    }
}
// Unit tests for the configuration, statistics, fingerprint, entry and lookup types.
#[cfg(test)]
// NOTE(review): no test below compares floats with `==` directly (hit_rate is
// checked against an epsilon), so this allow may no longer be needed.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    // Default config values: 1000 entries, 512 MiB, 3600 s TTL, no compression.
    #[test]
    fn test_prefix_cache_config_default() {
        let config = PrefixCacheConfig::default();
        assert_eq!(config.max_entries, 1000);
        assert_eq!(config.max_memory_bytes, 512 * 1024 * 1024);
        assert_eq!(config.default_ttl_secs, 3600);
        assert!(!config.enable_compression);
    }
    // `new` plus the builder methods override the corresponding fields.
    #[test]
    fn test_prefix_cache_config_builder() {
        let config = PrefixCacheConfig::new(500, 256 * 1024 * 1024)
            .with_default_ttl(1800)
            .with_compression(true);
        assert_eq!(config.max_entries, 500);
        assert_eq!(config.max_memory_bytes, 256 * 1024 * 1024);
        assert_eq!(config.default_ttl_secs, 1800);
        assert!(config.enable_compression);
    }
    // Hit rate is 0 with no lookups and a percentage otherwise.
    #[test]
    fn test_cache_stats_hit_rate() {
        let mut stats = CacheStats::default();
        assert!(stats.hit_rate().abs() < f64::EPSILON);
        stats.hits = 75;
        stats.misses = 25;
        assert!((stats.hit_rate() - 75.0).abs() < 0.001);
    }
    // Each record_* method increments exactly its own counter.
    #[test]
    fn test_cache_stats_recording() {
        let mut stats = CacheStats::default();
        stats.record_hit();
        stats.record_hit();
        stats.record_miss();
        stats.record_eviction();
        stats.record_expiration();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert_eq!(stats.evictions, 1);
        assert_eq!(stats.expirations, 1);
    }
    // reset clears the event counters but keeps the total_bytes gauge.
    #[test]
    fn test_cache_stats_reset() {
        let mut stats = CacheStats {
            hits: 100,
            misses: 50,
            evictions: 10,
            total_bytes: 1000,
            entry_count: 5,
            expirations: 3,
        };
        stats.reset();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
        assert_eq!(stats.evictions, 0);
        assert_eq!(stats.expirations, 0);
        assert_eq!(stats.total_bytes, 1000);
    }
    // The constructor stores its arguments unchanged.
    #[test]
    fn test_context_fingerprint_creation() {
        let fp = ContextFingerprint::new(12345, 100, "Test context...");
        assert_eq!(fp.hash, 12345);
        assert_eq!(fp.prefix_length, 100);
        assert_eq!(fp.content_summary, "Test context...");
    }
    // is_prefix_of compares lengths only (the hashes differ here), and is reflexive.
    #[test]
    fn test_context_fingerprint_is_prefix_of() {
        let short_fp = ContextFingerprint::new(100, 50, "Short");
        let long_fp = ContextFingerprint::new(200, 100, "Long");
        assert!(short_fp.is_prefix_of(&long_fp));
        assert!(!long_fp.is_prefix_of(&short_fp));
        assert!(short_fp.is_prefix_of(&short_fp));
    }
    // A new entry has no TTL and no recorded accesses.
    #[test]
    fn test_kv_cache_entry_creation() {
        let fp = ContextFingerprint::new(1, 10, "test");
        let entry = KVCacheEntry::new("key1", fp.clone(), vec![1.0, 2.0, 3.0], 10);
        assert_eq!(entry.key, "key1");
        assert_eq!(entry.fingerprint, fp);
        assert_eq!(entry.kv_data.len(), 3);
        assert_eq!(entry.sequence_length, 10);
        assert_eq!(entry.access_count, 0);
        assert!(entry.ttl.is_none());
    }
    // with_ttl_secs converts whole seconds into a Duration.
    #[test]
    fn test_kv_cache_entry_with_ttl() {
        let fp = ContextFingerprint::new(1, 10, "test");
        let entry = KVCacheEntry::new("key1", fp, vec![], 10).with_ttl_secs(60);
        assert!(entry.ttl.is_some());
        assert_eq!(entry.ttl.unwrap(), Duration::from_secs(60));
    }
    // No TTL never expires, a long TTL has not expired yet, and a zero TTL is expired immediately.
    #[test]
    fn test_kv_cache_entry_expiration() {
        let fp = ContextFingerprint::new(1, 10, "test");
        let entry = KVCacheEntry::new("key1", fp.clone(), vec![], 10);
        assert!(!entry.is_expired());
        let entry_long =
            KVCacheEntry::new("key2", fp.clone(), vec![], 10).with_ttl(Duration::from_secs(3600));
        assert!(!entry_long.is_expired());
        let entry_zero = KVCacheEntry::new("key3", fp, vec![], 10).with_ttl(Duration::from_secs(0));
        assert!(entry_zero.is_expired());
    }
    // record_access increments the counter and moves last_accessed forward.
    #[test]
    fn test_kv_cache_entry_record_access() {
        let fp = ContextFingerprint::new(1, 10, "test");
        let mut entry = KVCacheEntry::new("key1", fp, vec![], 10);
        assert_eq!(entry.access_count, 0);
        let initial_access = entry.last_accessed;
        // Sleep so Instant::now() has advanced and the strict `>` below can hold.
        std::thread::sleep(std::time::Duration::from_millis(1));
        entry.record_access();
        assert_eq!(entry.access_count, 1);
        assert!(entry.last_accessed > initial_access);
    }
    // 100 f32 values contribute at least 400 bytes to the estimate.
    #[test]
    fn test_kv_cache_entry_estimated_size() {
        let fp = ContextFingerprint::new(1, 10, "summary");
        let entry = KVCacheEntry::new("testkey", fp, vec![1.0; 100], 10);
        let size = entry.estimated_size();
        assert!(size >= 400);
    }
    // A full hit is a hit and exposes its entry.
    #[test]
    fn test_cache_lookup_result_hit() {
        let fp = ContextFingerprint::new(1, 10, "test");
        let entry = KVCacheEntry::new("key1", fp, vec![], 10);
        let result = CacheLookupResult::Hit(entry.clone());
        assert!(result.is_hit());
        assert!(!result.is_miss());
        assert!(result.entry().is_some());
        assert_eq!(result.entry().unwrap().key, "key1");
    }
    // A miss has no entry.
    #[test]
    fn test_cache_lookup_result_miss() {
        let result = CacheLookupResult::Miss;
        assert!(!result.is_hit());
        assert!(result.is_miss());
        assert!(result.entry().is_none());
    }
    // A partial hit also counts as a hit and exposes its entry.
    #[test]
    fn test_cache_lookup_result_partial_hit() {
        let fp = ContextFingerprint::new(1, 10, "test");
        let entry = KVCacheEntry::new("key1", fp, vec![], 10);
        let result = CacheLookupResult::PartialHit {
            entry: entry.clone(),
            remaining_length: 5,
        };
        assert!(result.is_hit());
        assert!(!result.is_miss());
        assert!(result.entry().is_some());
    }
}