use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// Identifies one cached section by name, profile and an optional
/// caller-supplied hash.
///
/// All three fields take part in the derived `Hash`/`Eq`, so two keys that
/// differ only in `content_hash` address different cache entries.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct CacheKey {
    /// Section name.
    pub name: String,
    /// Profile the section belongs to; `SectionCache::clear_profile` matches on this.
    pub profile: String,
    /// Optional extra discriminator (presumably a hash of the section's inputs — confirm with callers).
    pub content_hash: Option<u64>,
}

impl CacheKey {
    /// Builds a key for `name` under `profile` with no content hash.
    pub fn new(name: impl Into<String>, profile: impl Into<String>) -> Self {
        let (name, profile) = (name.into(), profile.into());
        Self { name, profile, content_hash: None }
    }

    /// Consumes the key and returns it with `content_hash` set to `Some(hash)`,
    /// overwriting any hash that was already present.
    pub fn with_hash(mut self, hash: u64) -> Self {
        self.content_hash = Some(hash);
        self
    }
}
/// A cached section body plus the bookkeeping used for expiry and statistics.
#[derive(Debug, Clone)]
pub struct CachedEntry {
    /// The cached text.
    pub content: String,
    /// Creation time; age for expiry checks is measured from here.
    pub cached_at: Instant,
    /// Token estimate for `content`, computed once when the entry is built.
    pub token_count: usize,
    /// How many times `mark_used` has been called on this entry.
    pub use_count: u64,
}

impl CachedEntry {
    /// Wraps `content`, estimating its tokens, stamping the current time and
    /// starting the use count at zero.
    pub fn new(content: String) -> Self {
        // Field initialisers run top to bottom: the token estimate borrows
        // `content` before it is moved into the struct on the last line.
        Self {
            token_count: estimate_tokens(&content),
            cached_at: Instant::now(),
            use_count: 0,
            content,
        }
    }

    /// Returns `true` once strictly more than `max_age` has elapsed since creation.
    pub fn is_expired(&self, max_age: Duration) -> bool {
        let age = self.cached_at.elapsed();
        max_age < age
    }

    /// Records one more use of this entry.
    pub fn mark_used(&mut self) {
        self.use_count += 1;
    }
}
/// Thread-safe, time-expiring cache of section content keyed by [`CacheKey`].
///
/// Entries and statistics live behind two separate `RwLock`s.
#[derive(Debug)]
pub struct SectionCache {
    /// Cached entries. Expired entries are not purged in the background; they
    /// stay here until an operation on the cache removes them.
    entries: RwLock<HashMap<CacheKey, CachedEntry>>,
    /// Age beyond which an entry is treated as expired.
    max_age: Duration,
    /// Running hit/miss/eviction counters.
    stats: RwLock<CacheStats>,
}
/// Snapshot of a section cache's counters.
#[derive(Debug, Clone, Default)]
pub struct CacheStats {
    /// Number of entries held, as of the last operation that refreshed this count.
    pub total_entries: usize,
    /// Lookups served from a live entry.
    pub total_hits: u64,
    /// Lookups that found no entry or only an expired one.
    pub total_misses: u64,
    /// Entries removed by expiry or by clearing.
    pub total_evictions: u64,
    /// Sum of the token estimates of every entry served from the cache.
    pub tokens_saved: u64,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `0.0..=1.0`; `0.0` before any lookup.
    pub fn hit_rate(&self) -> f64 {
        let lookups = self.total_hits + self.total_misses;
        match lookups {
            0 => 0.0,
            n => self.total_hits as f64 / n as f64,
        }
    }
}
impl SectionCache {
    /// Entry lifetime used by [`SectionCache::new`]: one hour.
    pub const DEFAULT_MAX_AGE: Duration = Duration::from_secs(3600);

    /// Creates an empty cache whose entries expire after [`Self::DEFAULT_MAX_AGE`].
    pub fn new() -> Self {
        Self::with_max_age(Self::DEFAULT_MAX_AGE)
    }

    /// Creates an empty cache whose entries expire once older than `max_age`.
    pub fn with_max_age(max_age: Duration) -> Self {
        Self {
            entries: RwLock::new(HashMap::new()),
            max_age,
            stats: RwLock::new(CacheStats::default()),
        }
    }

    /// Looks up `key` and returns a clone of its content if a live entry exists.
    ///
    /// - Live entry: bumps its use count, records a hit and adds its token
    ///   estimate to `tokens_saved`.
    /// - Expired entry: removes it, records a miss and an eviction, and
    ///   refreshes `total_entries`.
    /// - No entry: records a miss.
    ///
    /// Both locks are taken for writing (entries first, then stats), since every
    /// path mutates state.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    pub fn get(&self, key: &CacheKey) -> Option<String> {
        let mut entries = self.entries.write().unwrap();
        let mut stats = self.stats.write().unwrap();
        if let Some(entry) = entries.get_mut(key) {
            if entry.is_expired(self.max_age) {
                entries.remove(key);
                stats.total_misses += 1;
                stats.total_evictions += 1;
                // Previously left stale: keep the reported count in sync with the map.
                stats.total_entries = entries.len();
                None
            } else {
                entry.mark_used();
                stats.total_hits += 1;
                stats.tokens_saved += entry.token_count as u64;
                Some(entry.content.clone())
            }
        } else {
            stats.total_misses += 1;
            None
        }
    }

    /// Stores `content` under `key`, replacing any existing entry. A replaced
    /// entry's timestamp and use count are reset.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    pub fn set(&self, key: CacheKey, content: String) {
        let mut entries = self.entries.write().unwrap();
        let mut stats = self.stats.write().unwrap();
        entries.insert(key, CachedEntry::new(content));
        stats.total_entries = entries.len();
    }

    /// Returns the cached content for `key`, or runs `compute`, caches its result
    /// and returns it.
    ///
    /// No lock is held while `compute` runs, so `compute` may itself use this
    /// cache without deadlocking. The flip side is that concurrent callers for the
    /// same missing key may each run `compute`; the last `set` wins.
    pub fn get_or_compute<F>(&self, key: &CacheKey, compute: F) -> String
    where
        F: FnOnce() -> String,
    {
        if let Some(cached) = self.get(key) {
            cached
        } else {
            let content = compute();
            self.set(key.clone(), content.clone());
            content
        }
    }

    /// Removes every entry, counting each one as an eviction.
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    pub fn clear(&self) {
        let mut entries = self.entries.write().unwrap();
        let mut stats = self.stats.write().unwrap();
        let evicted = entries.len();
        entries.clear();
        stats.total_entries = 0;
        stats.total_evictions += evicted as u64;
    }

    /// Removes every entry whose key belongs to `profile`, counting each removed
    /// entry as an eviction (consistent with [`Self::clear`] and expiry in [`Self::get`]).
    ///
    /// # Panics
    /// Panics if either lock is poisoned.
    pub fn clear_profile(&self, profile: &str) {
        let mut entries = self.entries.write().unwrap();
        let mut stats = self.stats.write().unwrap();
        let before = entries.len();
        entries.retain(|k, _| k.profile != profile);
        stats.total_evictions += (before - entries.len()) as u64;
        stats.total_entries = entries.len();
    }

    /// Returns a snapshot copy of the current statistics.
    pub fn stats(&self) -> CacheStats {
        self.stats.read().unwrap().clone()
    }

    /// Sum of the token estimates of all stored entries, including expired
    /// entries that have not yet been removed.
    pub fn cached_tokens(&self) -> usize {
        let entries = self.entries.read().unwrap();
        entries.values().map(|e| e.token_count).sum()
    }

    /// Returns `true` if no entries are stored (expired-but-unremoved entries count as stored).
    pub fn is_empty(&self) -> bool {
        self.entries.read().unwrap().is_empty()
    }

    /// Number of stored entries, including expired entries not yet removed.
    pub fn size(&self) -> usize {
        self.entries.read().unwrap().len()
    }
}
impl Default for SectionCache {
fn default() -> Self {
Self::new()
}
}
/// Roughly estimates how many tokens `content` will consume.
///
/// Heuristic:
/// - each whitespace-separated word counts as one token;
/// - each multi-byte alphabetic character (CJK ideographs, but also accented
///   Latin, Cyrillic, etc.) adds one third of a token, floored over the whole
///   string.
///
/// A run of CJK text without spaces therefore counts as one word plus its
/// character count / 3. This is an approximation for budgeting, not an exact
/// tokenizer count.
pub fn estimate_tokens(content: &str) -> usize {
    // `len_utf8() > 1` excludes ASCII letters, which are already covered by the word count.
    let multibyte_alphabetic = content
        .chars()
        .filter(|c| c.is_alphabetic() && c.len_utf8() > 1)
        .count();
    // There is deliberately no "no words but some non-whitespace" fallback:
    // `split_whitespace` and `char::is_whitespace` use the same Unicode
    // White_Space definition, so any non-whitespace character already yields
    // at least one word and such a branch could never fire.
    let words = content.split_whitespace().count();
    multibyte_alphabetic / 3 + words
}
/// Lazily-initialised process-wide cache instance.
static GLOBAL_CACHE: std::sync::OnceLock<Arc<SectionCache>> = std::sync::OnceLock::new();

/// Returns a shared handle to the process-wide cache.
///
/// The cache is built with `SectionCache::new()` on the first call; every call
/// returns a handle to that same instance.
pub fn global_cache() -> Arc<SectionCache> {
    let shared = GLOBAL_CACHE.get_or_init(|| Arc::new(SectionCache::new()));
    Arc::clone(shared)
}
/// Empties the process-wide cache returned by `global_cache`, recording each
/// removed entry as an eviction in that cache's statistics.
pub fn clear_global_cache() {
    let cache = global_cache();
    cache.clear();
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_basic() {
        let cache = SectionCache::new();
        let key = CacheKey::new("test", "default");

        // Cold lookup misses; the lookup after `set` hits.
        assert_eq!(cache.get(&key), None);
        cache.set(key.clone(), String::from("test content"));
        assert_eq!(cache.get(&key).as_deref(), Some("test content"));

        let snapshot = cache.stats();
        assert_eq!(snapshot.total_hits, 1);
        assert_eq!(snapshot.total_misses, 1);
    }

    #[test]
    fn test_cache_expiry() {
        let cache = SectionCache::with_max_age(Duration::from_millis(10));
        let key = CacheKey::new("test", "default");
        cache.set(key.clone(), String::from("test"));

        // Outlive the 10 ms lifetime so the next lookup sees an expired entry.
        std::thread::sleep(Duration::from_millis(20));

        assert_eq!(cache.get(&key), None);
        assert_eq!(cache.stats().total_evictions, 1);
    }

    #[test]
    fn test_get_or_compute() {
        let cache = SectionCache::new();
        let key = CacheKey::new("compute", "default");

        let first = cache.get_or_compute(&key, || String::from("computed"));
        assert_eq!(first, "computed");

        // The cached value is returned rather than the new closure's output.
        let second = cache.get_or_compute(&key, || String::from("different"));
        assert_eq!(second, "computed");
    }

    #[test]
    fn test_clear_profile() {
        let cache = SectionCache::new();
        let doomed = CacheKey::new("a", "default");
        let survivor = CacheKey::new("b", "safe");
        cache.set(doomed.clone(), String::from("a"));
        cache.set(survivor.clone(), String::from("b"));

        cache.clear_profile("default");

        assert_eq!(cache.get(&doomed), None);
        assert_eq!(cache.get(&survivor).as_deref(), Some("b"));
    }

    #[test]
    fn test_estimate_tokens() {
        let english = estimate_tokens("Hello world this is a test");
        assert!((5..=10).contains(&english), "English tokens: {}", english);

        let chinese = estimate_tokens("你好世界这是一个测试");
        assert!((2..=10).contains(&chinese), "Chinese tokens: {}", chinese);
    }

    #[test]
    fn test_global_cache() {
        clear_global_cache();
        let key = CacheKey::new("global_test", "default");
        global_cache().set(key.clone(), String::from("global content"));

        // A second handle observes the same underlying cache.
        let other_handle = global_cache();
        assert_eq!(other_handle.get(&key).as_deref(), Some("global content"));

        clear_global_cache();
        assert_eq!(other_handle.get(&key), None);
    }
}