use super::stats::SharedCacheStatistics;
use super::{
CacheConfig, CacheEntry, CacheError, CacheHealth, CacheKey, CacheKeyGenerator, CacheMetrics,
CacheResult, CacheStatistics, CacheWarmer, WarmingConfig,
};
use crate::core::traits::{GenerationParams, LanguageModel, ModelInfo};
use crate::core::Result;
use moka::future::Cache;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
pub struct CachedLLMClient<T: LanguageModel> {
inner: Arc<T>,
cache: Cache<String, CacheEntry>,
config: CacheConfig,
key_generator: Arc<RwLock<CacheKeyGenerator>>,
stats: SharedCacheStatistics,
warming_config: Option<WarmingConfig>,
}
impl<T: LanguageModel + Send + Sync> CachedLLMClient<T> {
pub async fn new(inner: T, config: CacheConfig) -> CacheResult<Self> {
config.validate()?;
let cache = Self::build_cache(&config).await?;
let key_generator = Arc::new(RwLock::new(CacheKeyGenerator::new()));
let stats = Arc::new(CacheStatistics::new());
Ok(Self {
inner: Arc::new(inner),
cache,
config,
key_generator,
stats,
warming_config: None,
})
}
pub async fn with_warming(
inner: T,
config: CacheConfig,
warming_config: WarmingConfig,
) -> CacheResult<Self> {
let mut client = Self::new(inner, config).await?;
client.warming_config = Some(warming_config);
Ok(client)
}
async fn build_cache(config: &CacheConfig) -> CacheResult<Cache<String, CacheEntry>> {
let mut cache_builder = Cache::builder();
cache_builder = cache_builder.max_capacity(config.max_capacity);
if let Some(initial) = config.initial_capacity {
cache_builder = cache_builder.initial_capacity(initial as usize);
}
cache_builder = cache_builder.time_to_live(config.ttl_duration());
cache_builder = cache_builder.time_to_idle(config.cleanup_interval());
let cache = cache_builder.build();
Ok(cache)
}
async fn execute_with_cache<F, Fut>(&self, cache_key: CacheKey, operation: F) -> Result<String>
where
F: FnOnce() -> Fut,
Fut: std::future::Future<Output = Result<String>>,
{
let start_time = Instant::now();
if let Some(mut entry) = self.cache.get(&cache_key.key_hash).await {
entry.access();
self.cache
.insert(cache_key.key_hash.clone(), entry.clone())
.await;
let time_saved = start_time.elapsed();
self.stats.record_hit(time_saved);
return Ok(entry.response);
}
self.stats.record_miss();
let response = operation().await?;
if response.len() > self.config.max_entry_size {
self.stats.record_error();
return Err(crate::core::GraphRAGError::Generation {
message: format!(
"Response size ({} bytes) exceeds maximum cache entry size ({} bytes)",
response.len(),
self.config.max_entry_size
),
});
}
let entry = CacheEntry::new(response.clone());
let entry_size = self.estimate_entry_size(&entry);
self.cache.insert(cache_key.key_hash, entry).await;
self.stats.record_insertion(entry_size);
Ok(response)
}
fn estimate_entry_size(&self, entry: &CacheEntry) -> usize {
entry.response.len()
+ entry
.metadata
.iter()
.map(|(k, v)| k.len() + v.len())
.sum::<usize>()
+ 100 }
async fn generate_cache_key(
&self,
prompt: &str,
params: Option<&GenerationParams>,
) -> CacheResult<CacheKey> {
let key_gen = self.key_generator.read().await;
let model_info = self.inner.model_info();
key_gen.generate_key_with_params(prompt, params, Some(&model_info.name))
}
pub async fn is_cached(&self, prompt: &str, params: Option<&GenerationParams>) -> bool {
if let Ok(cache_key) = self.generate_cache_key(prompt, params).await {
self.cache.get(&cache_key.key_hash).await.is_some()
} else {
false
}
}
pub fn cache_statistics(&self) -> CacheMetrics {
self.stats.snapshot()
}
pub fn cache_health(&self) -> CacheHealth {
let metrics = self.cache_statistics();
CacheHealth::evaluate(metrics, self.config.max_capacity)
}
pub async fn clear_cache(&self) {
self.cache.invalidate_all();
}
pub async fn invalidate(
&self,
prompt: &str,
params: Option<&GenerationParams>,
) -> CacheResult<bool> {
let cache_key = self.generate_cache_key(prompt, params).await?;
if let Some(entry) = self.cache.get(&cache_key.key_hash).await {
let entry_size = self.estimate_entry_size(&entry);
self.cache.invalidate(&cache_key.key_hash).await;
self.stats.record_eviction(entry_size);
Ok(true)
} else {
Ok(false)
}
}
pub fn cache_utilization(&self) -> f64 {
let current_size = self.cache.entry_count();
if self.config.max_capacity == 0 {
0.0
} else {
(current_size as f64 / self.config.max_capacity as f64).min(1.0)
}
}
pub async fn warm_cache(&self) -> CacheResult<()> {
if let Some(warming_config) = &self.warming_config {
let warmer = CacheWarmer::new(warming_config.clone())?;
let results = warmer.warm_cache(self).await?;
results.print();
Ok(())
} else {
Err(CacheError::Configuration(
"Cache warming not configured".to_string(),
))
}
}
pub async fn update_key_strategy(&self, new_generator: CacheKeyGenerator) {
let mut key_gen = self.key_generator.write().await;
*key_gen = new_generator;
}
pub fn inner(&self) -> &T {
&self.inner
}
pub fn config(&self) -> &CacheConfig {
&self.config
}
pub async fn export_cache(&self) -> CacheResult<Vec<(String, CacheEntry)>> {
let entries = Vec::new();
Ok(entries)
}
pub fn print_statistics(&self) {
self.stats.print();
}
pub fn print_health(&self) {
let health = self.cache_health();
health.print();
}
pub fn cache_info(&self) -> CacheInfo {
let metrics = self.cache_statistics();
let health = self.cache_health();
CacheInfo {
config: self.config.clone(),
metrics,
health,
entry_count: self.cache.entry_count(),
weighted_size: self.cache.weighted_size(),
}
}
}
#[derive(Debug)]
pub struct CacheInfo {
pub config: CacheConfig,
pub metrics: CacheMetrics,
pub health: CacheHealth,
pub entry_count: u64,
pub weighted_size: u64,
}
impl CacheInfo {
pub fn print(&self) {
tracing::info!(
max_capacity = self.config.max_capacity,
ttl_seconds = self.config.ttl_seconds,
eviction_policy = ?self.config.eviction_policy,
statistics_enabled = self.config.enable_statistics,
entry_count = self.entry_count,
weighted_size = self.weighted_size,
utilization = format!("{:.1}%", (self.entry_count as f64 / self.config.max_capacity as f64 * 100.0).min(100.0)),
"Cache information"
);
self.health.print();
}
}
impl<T: LanguageModel + Send + Sync> LanguageModel for CachedLLMClient<T> {
type Error = crate::core::GraphRAGError;
fn complete(&self, prompt: &str) -> Result<String> {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
let cache_key = self.generate_cache_key(prompt, None).await.map_err(|e| {
crate::core::GraphRAGError::Generation {
message: format!("Cache key generation failed: {e}"),
}
})?;
self.execute_with_cache(cache_key, || async { self.inner.complete(prompt) })
.await
})
})
}
fn complete_with_params(&self, prompt: &str, params: GenerationParams) -> Result<String> {
tokio::task::block_in_place(|| {
tokio::runtime::Handle::current().block_on(async {
let cache_key = self
.generate_cache_key(prompt, Some(¶ms))
.await
.map_err(|e| crate::core::GraphRAGError::Generation {
message: format!("Cache key generation failed: {e}"),
})?;
self.execute_with_cache(cache_key, || async {
self.inner.complete_with_params(prompt, params.clone())
})
.await
})
})
}
fn is_available(&self) -> bool {
self.inner.is_available()
}
fn model_info(&self) -> ModelInfo {
let mut info = self.inner.model_info();
info.name = format!("Cached({})", info.name);
info
}
}
impl<T: LanguageModel + Send + Sync> CachedLLMClient<T> {
pub async fn complete_async(&self, prompt: &str) -> Result<String> {
let cache_key = self.generate_cache_key(prompt, None).await.map_err(|e| {
crate::core::GraphRAGError::Generation {
message: format!("Cache key generation failed: {e}"),
}
})?;
self.execute_with_cache(cache_key, || async { self.inner.complete(prompt) })
.await
}
pub async fn complete_with_params_async(
&self,
prompt: &str,
params: GenerationParams,
) -> Result<String> {
let cache_key = self
.generate_cache_key(prompt, Some(¶ms))
.await
.map_err(|e| crate::core::GraphRAGError::Generation {
message: format!("Cache key generation failed: {e}"),
})?;
self.execute_with_cache(cache_key, || async {
self.inner.complete_with_params(prompt, params.clone())
})
.await
}
}
impl<T: LanguageModel> Clone for CachedLLMClient<T> {
fn clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
cache: self.cache.clone(),
config: self.config.clone(),
key_generator: Arc::clone(&self.key_generator),
stats: Arc::clone(&self.stats),
warming_config: self.warming_config.clone(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::traits::GenerationParams;
use crate::generation::MockLLM;
#[tokio::test]
async fn test_cache_hit_miss() {
let mock_llm = MockLLM::new().unwrap();
let config = CacheConfig::development();
let client = CachedLLMClient::new(mock_llm, config).await.unwrap();
let prompt = "What is AI?";
let response1 = client.complete_async(prompt).await.unwrap();
let stats1 = client.cache_statistics();
assert_eq!(stats1.cache_misses, 1);
assert_eq!(stats1.cache_hits, 0);
let response2 = client.complete_async(prompt).await.unwrap();
let stats2 = client.cache_statistics();
assert_eq!(stats2.cache_misses, 1);
assert_eq!(stats2.cache_hits, 1);
assert_eq!(response1, response2);
}
#[tokio::test]
async fn test_cache_with_params() {
let mock_llm = MockLLM::new().unwrap();
let config = CacheConfig::development();
let client = CachedLLMClient::new(mock_llm, config).await.unwrap();
let prompt = "What is AI?";
let params1 = GenerationParams {
temperature: Some(0.7),
..Default::default()
};
let params2 = GenerationParams {
temperature: Some(0.9),
..Default::default()
};
let _response1 = client
.complete_with_params_async(prompt, params1)
.await
.unwrap();
let _response2 = client
.complete_with_params_async(prompt, params2)
.await
.unwrap();
let stats = client.cache_statistics();
assert_eq!(stats.cache_misses, 2); assert_eq!(stats.cache_hits, 0);
}
#[tokio::test]
async fn test_cache_invalidation() {
let mock_llm = MockLLM::new().unwrap();
let config = CacheConfig::development();
let client = CachedLLMClient::new(mock_llm, config).await.unwrap();
let prompt = "What is AI?";
let _response1 = client.complete_async(prompt).await.unwrap();
assert!(client.is_cached(prompt, None).await);
let was_cached = client.invalidate(prompt, None).await.unwrap();
assert!(was_cached);
assert!(!client.is_cached(prompt, None).await);
}
#[tokio::test]
async fn test_cache_utilization() {
let mock_llm = MockLLM::new().unwrap();
let mut config = CacheConfig::development();
config.max_capacity = 10; let client = CachedLLMClient::new(mock_llm, config).await.unwrap();
assert_eq!(client.cache_utilization(), 0.0);
for i in 0..5 {
let prompt = format!("Query {i}");
let _ = client.complete_async(&prompt).await.unwrap();
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let utilization = client.cache_utilization();
assert!(
utilization >= 0.0,
"Utilization should be >= 0.0, got: {}",
utilization
);
assert!(
utilization <= 1.0,
"Utilization should be <= 1.0, got: {}",
utilization
);
}
#[tokio::test]
async fn test_cache_health() {
let mock_llm = MockLLM::new().unwrap();
let config = CacheConfig::development();
let client = CachedLLMClient::new(mock_llm, config).await.unwrap();
let queries = vec!["Query A", "Query B", "Query C"];
for _ in 0..3 {
for query in &queries {
let _ = client.complete_async(query).await.unwrap();
}
}
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let health = client.cache_health();
let metrics = client.cache_statistics();
assert!(
metrics.hit_rate >= 0.5,
"Hit rate should be >= 50%, got: {}",
metrics.hit_rate
);
assert!(
matches!(
health.status,
super::super::stats::HealthStatus::Healthy
| super::super::stats::HealthStatus::Warning
),
"Expected Healthy/Warning but got: {:?}",
health.status
);
}
#[test]
fn test_language_model_trait() {
let mock_llm = MockLLM::new().unwrap();
let config = CacheConfig::development();
tokio::runtime::Runtime::new().unwrap().block_on(async {
let client = CachedLLMClient::new(mock_llm, config).await.unwrap();
assert!(client.is_available());
let model_info = client.model_info();
assert!(model_info.name.contains("Cached"));
let response = client.complete("Test prompt").unwrap();
assert!(!response.is_empty());
});
}
}