use crate::providers::Message;
use crate::compress::complexity::ComplexityLevel;
use crate::compress::focus_point::FocusPoint;
use std::collections::HashMap;
use std::time::{Duration, Instant};
use chrono::{DateTime, Utc};
/// A single cached compression result.
#[derive(Clone, Debug)]
pub struct CacheEntry {
    /// The compressed replacement for the original message.
    pub compressed: Message,
    /// The key this entry is stored under (a hash of the original message).
    pub hash: u64,
    /// Insertion time; drives both TTL expiry and oldest-first eviction.
    pub created_at: Instant,
    /// Number of successful lookups served by this entry since it was stored.
    pub hit_count: usize,
}
/// Running counters describing how effective a compression cache has been.
#[derive(Clone, Debug, Default)]
pub struct CacheStats {
    /// Lookups that returned a live entry.
    pub hits: usize,
    /// Lookups that found nothing usable.
    pub misses: usize,
    /// Number of stored entries, as last recorded by the cache.
    pub entries: usize,
    /// Accumulated token savings reported by the caller.
    pub total_saved_tokens: u32,
}

impl CacheStats {
    /// Fraction of lookups that were hits, in `[0.0, 1.0]`.
    ///
    /// Returns `0.0` when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f32 {
        let lookups = self.hits + self.misses;
        match lookups {
            0 => 0.0,
            total => self.hits as f32 / total as f32,
        }
    }
}
/// A priority score together with the keywords it was derived from and a validity window.
#[derive(Debug, Clone)]
pub struct CachedPriorityScore {
    /// The cached score value.
    pub score: f32,
    /// UTC time at which the score was computed (or last refreshed).
    pub calculated_at: DateTime<Utc>,
    /// How long after `calculated_at` the score is considered valid.
    pub valid_for: Duration,
    /// Keywords associated with this score.
    pub keywords: Vec<String>,
}

impl CachedPriorityScore {
    /// Creates a score stamped with the current UTC time.
    pub fn new(score: f32, keywords: Vec<String>, valid_for: Duration) -> Self {
        Self {
            score,
            calculated_at: Utc::now(),
            valid_for,
            keywords,
        }
    }

    /// Returns `true` while less than `valid_for` has elapsed since `calculated_at`.
    ///
    /// Never panics:
    /// - A `calculated_at` in the future (e.g. after a wall-clock adjustment) counts as valid.
    /// - A `valid_for` too large for chrono to represent simply never expires.
    pub fn is_valid(&self) -> bool {
        // Convert the elapsed time to std, rather than `valid_for` to chrono.
        // `chrono::Duration::from_std` fails for very large values such as `Duration::MAX`,
        // and unwrapping that failure used to panic here.
        match (Utc::now() - self.calculated_at).to_std() {
            Ok(elapsed) => elapsed < self.valid_for,
            // Negative elapsed time: the score was stamped "in the future"; treat it as fresh.
            Err(_) => true,
        }
    }
}
/// A predicted focus point with its confidence, stamped at prediction time.
#[derive(Clone, Debug)]
pub struct CachedFocusPrediction {
    /// The predicted focus point.
    pub focus: FocusPoint,
    /// Confidence attached to the prediction by the caller.
    pub confidence: f32,
    /// UTC time at which the prediction was recorded.
    pub predicted_at: DateTime<Utc>,
}

impl CachedFocusPrediction {
    /// Wraps `focus` and `confidence`, recording the current UTC time as `predicted_at`.
    pub fn new(focus: FocusPoint, confidence: f32) -> Self {
        let predicted_at = Utc::now();
        CachedFocusPrediction {
            focus,
            confidence,
            predicted_at,
        }
    }
}
/// A complexity classification stamped with the time it was computed.
#[derive(Clone, Debug)]
pub struct CachedComplexity {
    /// The computed complexity level.
    pub level: ComplexityLevel,
    /// UTC time at which the analysis result was recorded.
    pub analyzed_at: DateTime<Utc>,
}

impl CachedComplexity {
    /// Wraps `level`, recording the current UTC time as `analyzed_at`.
    pub fn new(level: ComplexityLevel) -> Self {
        let analyzed_at = Utc::now();
        CachedComplexity { level, analyzed_at }
    }
}
/// Tuning parameters for the compression cache.
#[derive(Clone, Debug)]
pub struct CacheConfig {
    /// Upper bound on the number of stored entries.
    pub max_entries: usize,
    /// How long an entry remains usable after insertion.
    pub ttl: Duration,
    /// Originals smaller than this are not cached.
    ///
    /// Size is the byte length of text content, or the summed byte length of each block's
    /// `Debug` rendering.
    pub min_size_to_cache: usize,
}

impl Default for CacheConfig {
    /// Defaults: 100 entries, a five-minute TTL, and a minimum original size of 100.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        CacheConfig {
            max_entries: 100,
            ttl: FIVE_MINUTES,
            min_size_to_cache: 100,
        }
    }
}
/// Bounded, TTL-based cache mapping a hash of an original message to its compressed form.
#[derive(Debug)]
pub struct CompressionCache {
    /// Stored entries, keyed by the hash of the original message.
    entries: HashMap<u64, CacheEntry>,
    /// Capacity, TTL and minimum-size settings.
    config: CacheConfig,
    /// Hit/miss/size counters exposed through `stats()`.
    stats: CacheStats,
}
impl Default for CompressionCache {
fn default() -> Self {
Self::new(CacheConfig::default())
}
}
impl CompressionCache {
    /// Creates an empty cache governed by `config`.
    pub fn new(config: CacheConfig) -> Self {
        Self {
            entries: HashMap::new(),
            config,
            stats: CacheStats::default(),
        }
    }

    /// Hashes a message's role and content into the cache key.
    ///
    /// Block content is hashed through each block's `Debug` rendering. A content-variant tag
    /// is mixed in so that a `Text` message cannot share a key with a `Blocks` message whose
    /// rendering happens to match. `DefaultHasher` output is not guaranteed stable across Rust
    /// releases, so keys must not be persisted.
    fn hash_message(message: &Message) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};

        let mut hasher = DefaultHasher::new();
        let role_str = match message.role {
            crate::providers::Role::User => "user",
            crate::providers::Role::Assistant => "assistant",
            crate::providers::Role::System => "system",
            crate::providers::Role::Tool => "tool",
        };
        role_str.hash(&mut hasher);
        match &message.content {
            crate::providers::MessageContent::Text(text) => {
                0u8.hash(&mut hasher);
                text.hash(&mut hasher);
            }
            crate::providers::MessageContent::Blocks(blocks) => {
                1u8.hash(&mut hasher);
                blocks.len().hash(&mut hasher);
                for block in blocks {
                    format!("{:?}", block).hash(&mut hasher);
                }
            }
        }
        hasher.finish()
    }

    /// Size compared against `min_size_to_cache`.
    ///
    /// This is the byte length of text content, or the summed byte lengths of each block's
    /// `Debug` rendering.
    fn content_size(message: &Message) -> usize {
        match &message.content {
            crate::providers::MessageContent::Text(text) => text.len(),
            crate::providers::MessageContent::Blocks(blocks) => {
                blocks.iter().map(|b| format!("{:?}", b).len()).sum()
            }
        }
    }

    /// Looks up the cached compressed form of `message`.
    ///
    /// - A live entry counts as a hit and has its `hit_count` incremented.
    /// - A missing entry counts as a miss.
    /// - An entry older than the configured TTL is removed and counts as a miss.
    pub fn get(&mut self, message: &Message) -> Option<&CacheEntry> {
        let hash = Self::hash_message(message);
        let ttl = self.config.ttl;

        // Decide freshness through a short-lived shared borrow so the map can be mutated below.
        let fresh = match self.entries.get(&hash) {
            Some(entry) => entry.created_at.elapsed() < ttl,
            None => {
                self.stats.misses += 1;
                return None;
            }
        };

        if !fresh {
            self.entries.remove(&hash);
            // Keep the reported size in sync after dropping the stale entry.
            self.stats.entries = self.entries.len();
            self.stats.misses += 1;
            return None;
        }

        self.stats.hits += 1;
        let entry = self.entries.get_mut(&hash)?;
        entry.hit_count += 1;
        Some(entry)
    }

    /// Stores `compressed` as the cached form of `original`.
    ///
    /// Nothing is stored when `original` is smaller than `min_size_to_cache`, or when
    /// `max_entries` is zero. Inserting a *new* key into a full cache first evicts the entry
    /// with the oldest `created_at`. Re-putting an existing key replaces it in place (resetting
    /// its age and hit count) without evicting anything else.
    pub fn put(&mut self, original: &Message, compressed: Message) {
        if self.config.max_entries == 0
            || Self::content_size(original) < self.config.min_size_to_cache
        {
            return;
        }

        let hash = Self::hash_message(original);
        // A replacement does not grow the map, so it must not push out an unrelated entry.
        if !self.entries.contains_key(&hash) && self.entries.len() >= self.config.max_entries {
            self.evict_oldest();
        }

        self.entries.insert(
            hash,
            CacheEntry {
                compressed,
                hash,
                created_at: Instant::now(),
                hit_count: 0,
            },
        );
        self.stats.entries = self.entries.len();
    }

    /// Removes the entry with the earliest `created_at` (insertion order, not recency of use).
    fn evict_oldest(&mut self) {
        if let Some((&oldest_hash, _)) = self
            .entries
            .iter()
            .min_by_key(|(_, entry)| entry.created_at)
        {
            self.entries.remove(&oldest_hash);
        }
    }

    /// Drops every entry older than the configured TTL and refreshes `stats.entries`.
    pub fn evict_expired(&mut self) {
        let now = Instant::now();
        // Copy the TTL out so the closure does not need to capture `self`.
        let ttl = self.config.ttl;
        self.entries
            .retain(|_, entry| now.duration_since(entry.created_at) < ttl);
        self.stats.entries = self.entries.len();
    }

    /// Removes all entries. Hit/miss counters and token savings are kept.
    pub fn clear(&mut self) {
        self.entries.clear();
        self.stats.entries = 0;
    }

    /// Current counters.
    pub fn stats(&self) -> &CacheStats {
        &self.stats
    }

    /// Number of stored entries (expired entries not yet evicted included).
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Adds `tokens` to the running savings total, saturating at `u32::MAX` instead of
    /// overflowing.
    pub fn record_token_savings(&mut self, tokens: u32) {
        self.stats.total_saved_tokens = self.stats.total_saved_tokens.saturating_add(tokens);
    }
}
/// A compression cache plus side caches for priority scores, focus predictions and
/// complexity results.
#[derive(Debug)]
pub struct ExtendedCompressionCache {
    /// Message-level compression cache.
    base_cache: CompressionCache,
    /// Focus predictions keyed by message id.
    focus_predictions: HashMap<String, CachedFocusPrediction>,
    /// Priority scores keyed by message id.
    priority_scores: HashMap<String, CachedPriorityScore>,
    /// Complexity results keyed by conversation id.
    complexity_cache: HashMap<String, CachedComplexity>,
    /// Capacity and validity settings for the side caches.
    config: ExtendedCacheConfig,
}
/// Tuning knobs for [`ExtendedCompressionCache`].
#[derive(Debug, Clone)]
pub struct ExtendedCacheConfig {
    /// Validity window intended for new priority scores (presumably passed as their
    /// `valid_for`). It is not read by the cache itself.
    pub priority_validity: Duration,
    /// Maximum number of cached focus predictions.
    pub max_focus_predictions: usize,
    /// Maximum number of cached complexity results.
    pub max_complexity_entries: usize,
    /// Maximum number of cached priority scores.
    pub max_priority_scores: usize,
}

impl Default for ExtendedCacheConfig {
    fn default() -> Self {
        Self {
            priority_validity: Duration::from_secs(600),
            max_focus_predictions: 50,
            max_complexity_entries: 20,
            // Same as `max_focus_predictions`, which previously (by mistake) bounded priority
            // scores too; keeps the default capacity unchanged.
            max_priority_scores: 50,
        }
    }
}

impl Default for ExtendedCompressionCache {
    fn default() -> Self {
        Self::new(ExtendedCacheConfig::default())
    }
}

impl ExtendedCompressionCache {
    /// Creates empty side caches governed by `config`, on top of a default `CompressionCache`.
    pub fn new(config: ExtendedCacheConfig) -> Self {
        Self {
            base_cache: CompressionCache::default(),
            focus_predictions: HashMap::new(),
            priority_scores: HashMap::new(),
            complexity_cache: HashMap::new(),
            config,
        }
    }

    /// Returns the score cached for `message_id`, or `None` if it is absent or expired.
    ///
    /// Expired scores are not removed here; `cleanup_expired` does that.
    pub fn get_priority_score(&self, message_id: &str) -> Option<&CachedPriorityScore> {
        self.priority_scores
            .get(message_id)
            .filter(|cached| cached.is_valid())
    }

    /// Caches `score` for `message_id`.
    ///
    /// When a new id would exceed `max_priority_scores`, the score with the oldest
    /// `calculated_at` is evicted first.
    pub fn put_priority_score(&mut self, message_id: String, score: CachedPriorityScore) {
        Self::insert_bounded(
            &mut self.priority_scores,
            self.config.max_priority_scores,
            message_id,
            score,
            |cached| cached.calculated_at,
        );
    }

    /// Returns the focus prediction cached for `message_id`. No validity check is applied.
    pub fn get_focus_prediction(&self, message_id: &str) -> Option<&CachedFocusPrediction> {
        self.focus_predictions.get(message_id)
    }

    /// Caches `prediction` for `message_id`.
    ///
    /// When a new id would exceed `max_focus_predictions`, the prediction with the oldest
    /// `predicted_at` is evicted first.
    pub fn put_focus_prediction(&mut self, message_id: String, prediction: CachedFocusPrediction) {
        Self::insert_bounded(
            &mut self.focus_predictions,
            self.config.max_focus_predictions,
            message_id,
            prediction,
            |cached| cached.predicted_at,
        );
    }

    /// Returns the complexity cached for `conversation_id`. No validity check is applied.
    pub fn get_complexity(&self, conversation_id: &str) -> Option<&CachedComplexity> {
        self.complexity_cache.get(conversation_id)
    }

    /// Caches `complexity` for `conversation_id`.
    ///
    /// When a new id would exceed `max_complexity_entries`, the result with the oldest
    /// `analyzed_at` is evicted first.
    pub fn put_complexity(&mut self, conversation_id: String, complexity: CachedComplexity) {
        Self::insert_bounded(
            &mut self.complexity_cache,
            self.config.max_complexity_entries,
            conversation_id,
            complexity,
            |cached| cached.analyzed_at,
        );
    }

    /// Boosts every cached priority score that shares keywords with `new_keywords`.
    ///
    /// Each overlapping keyword adds `0.1` to the score. Any boosted score also has its
    /// `calculated_at` reset to now, which restarts its validity window.
    ///
    /// `_existing_messages` is currently unused; it is kept for API compatibility.
    pub fn update_priority_incremental(
        &mut self,
        new_keywords: &[String],
        _existing_messages: &[Message],
    ) {
        let now = Utc::now();
        for cached in self.priority_scores.values_mut() {
            let overlap_count = cached
                .keywords
                .iter()
                .filter(|kw| new_keywords.contains(kw))
                .count();
            if overlap_count > 0 {
                cached.score += overlap_count as f32 * 0.1;
                cached.calculated_at = now;
            }
        }
    }

    /// Drops expired priority scores and expired base-cache entries.
    pub fn cleanup_expired(&mut self) {
        self.priority_scores.retain(|_, cached| cached.is_valid());
        self.base_cache.evict_expired();
    }

    /// Inserts `value` under `key`, keeping `map` within `capacity` entries.
    ///
    /// Only an insert that adds a *new* key can trigger eviction. The evicted entry is the one
    /// with the smallest `age` value (i.e. the oldest timestamp). A `capacity` of zero stores
    /// nothing.
    fn insert_bounded<V, T, F>(
        map: &mut HashMap<String, V>,
        capacity: usize,
        key: String,
        value: V,
        age: F,
    ) where
        T: Ord,
        F: Fn(&V) -> T,
    {
        if capacity == 0 {
            return;
        }
        if !map.contains_key(&key) && map.len() >= capacity {
            let oldest = map
                .iter()
                .min_by_key(|&(_, v)| age(v))
                .map(|(k, _)| k.clone());
            if let Some(oldest) = oldest {
                map.remove(&oldest);
            }
        }
        map.insert(key, value);
    }

    /// Shared access to the message-level cache.
    pub fn base_cache(&self) -> &CompressionCache {
        &self.base_cache
    }

    /// Mutable access to the message-level cache.
    pub fn base_cache_mut(&mut self) -> &mut CompressionCache {
        &mut self.base_cache
    }

    /// Empties the base cache and every side cache.
    pub fn clear_all(&mut self) {
        self.base_cache.clear();
        self.focus_predictions.clear();
        self.priority_scores.clear();
        self.complexity_cache.clear();
    }

    /// Snapshot of the base counters and the size of each side cache.
    pub fn extended_stats(&self) -> ExtendedCacheStats {
        ExtendedCacheStats {
            base_stats: self.base_cache.stats().clone(),
            focus_prediction_count: self.focus_predictions.len(),
            priority_score_count: self.priority_scores.len(),
            complexity_cache_count: self.complexity_cache.len(),
        }
    }
}
/// Snapshot of counters and sizes across an extended compression cache.
#[derive(Clone, Debug)]
pub struct ExtendedCacheStats {
    /// Copy of the base cache's counters at snapshot time.
    pub base_stats: CacheStats,
    /// Number of cached focus predictions.
    pub focus_prediction_count: usize,
    /// Number of cached priority scores (expired ones included until cleaned up).
    pub priority_score_count: usize,
    /// Number of cached complexity results.
    pub complexity_cache_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::{MessageContent, Role};

    /// Minimum size used by the fixtures below.
    ///
    /// The default `min_size_to_cache` is larger than the "long enough" test messages, which
    /// would otherwise never be cached.
    const TEST_MIN_SIZE: usize = 10;

    fn create_test_message(content: &str) -> Message {
        Message {
            role: Role::User,
            content: MessageContent::Text(content.to_string()),
        }
    }

    /// Default cache settings except for a small minimum size, so short fixtures are cacheable.
    fn test_cache() -> CompressionCache {
        CompressionCache::new(CacheConfig {
            min_size_to_cache: TEST_MIN_SIZE,
            ..Default::default()
        })
    }

    #[test]
    fn test_cache_put_and_get() {
        let mut cache = test_cache();
        let original = create_test_message("This is a test message that is long enough to be cached");
        let compressed = create_test_message("This is a test message...");
        cache.put(&original, compressed.clone());
        let entry = cache.get(&original);
        assert!(entry.is_some());
        assert_eq!(entry.unwrap().hit_count, 1);
    }

    #[test]
    fn test_cache_miss() {
        let mut cache = test_cache();
        let msg = create_test_message("Test message");
        let entry = cache.get(&msg);
        assert!(entry.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_hit_increments_counter() {
        let mut cache = test_cache();
        let original = create_test_message("This is a longer test message for caching purposes");
        let compressed = create_test_message("Longer test message...");
        cache.put(&original, compressed);
        cache.get(&original);
        cache.get(&original);
        cache.get(&original);
        assert_eq!(cache.stats().hits, 3);
    }

    #[test]
    fn test_cache_minimum_size() {
        let config = CacheConfig {
            min_size_to_cache: 50,
            ..Default::default()
        };
        let mut cache = CompressionCache::new(config);
        let small_msg = create_test_message("Short");
        let compressed = create_test_message("...");
        cache.put(&small_msg, compressed);
        assert!(cache.get(&small_msg).is_none());
    }

    #[test]
    fn test_default_min_size_skips_short_messages() {
        let threshold = CacheConfig::default().min_size_to_cache;
        let mut cache = CompressionCache::default();
        let short = create_test_message(&"x".repeat(threshold.saturating_sub(1)));
        let long = create_test_message(&"y".repeat(threshold));
        cache.put(&short, short.clone());
        cache.put(&long, long.clone());
        assert!(cache.get(&short).is_none());
        assert!(cache.get(&long).is_some());
    }

    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            max_entries: 2,
            min_size_to_cache: TEST_MIN_SIZE,
            ..Default::default()
        };
        let mut cache = CompressionCache::new(config);
        let msg1 = create_test_message("Message 1 - long enough for caching");
        let msg2 = create_test_message("Message 2 - also long enough");
        let msg3 = create_test_message("Message 3 - this one too");
        cache.put(&msg1, msg1.clone());
        // Eviction picks the smallest `created_at`. Space the inserts out so the timestamps
        // cannot tie on platforms with a coarse monotonic clock.
        std::thread::sleep(std::time::Duration::from_millis(2));
        cache.put(&msg2, msg2.clone());
        assert_eq!(cache.len(), 2);
        cache.put(&msg3, msg3.clone());
        assert_eq!(cache.len(), 2);
        assert!(cache.get(&msg1).is_none());
        assert!(cache.get(&msg2).is_some());
        assert!(cache.get(&msg3).is_some());
    }

    #[test]
    fn test_cache_clear() {
        let mut cache = test_cache();
        let msg = create_test_message("Long enough message for the cache system");
        cache.put(&msg, msg.clone());
        assert!(!cache.is_empty());
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_stats() {
        let mut cache = test_cache();
        let msg = create_test_message("This is a test message for statistics tracking");
        cache.get(&msg);
        assert_eq!(cache.stats().misses, 1);
        assert_eq!(cache.stats().hits, 0);
        cache.put(&msg, msg.clone());
        cache.get(&msg);
        assert_eq!(cache.stats().hits, 1);
        assert_eq!(cache.stats().hit_rate(), 0.5);
    }

    #[test]
    fn test_message_hash_consistency() {
        let msg1 = create_test_message("Test message");
        let msg2 = create_test_message("Test message");
        let msg3 = create_test_message("Different message");
        let hash1 = CompressionCache::hash_message(&msg1);
        let hash2 = CompressionCache::hash_message(&msg2);
        let hash3 = CompressionCache::hash_message(&msg3);
        assert_eq!(hash1, hash2);
        assert_ne!(hash1, hash3);
    }
}