use serde::{Deserialize, Serialize};
use std::hash::Hash;
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "caching")]
pub use moka;
/// Tunable settings for a [`ReasonKitCache`]; serializable so configs can be
/// loaded from files or transmitted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Maximum number of entries moka may hold before evicting.
pub max_capacity: u64,
/// Expire-after-write, in seconds; `None` disables time-to-live.
pub ttl_secs: Option<u64>,
/// Expire-after-read, in seconds; `None` disables time-to-idle.
pub tti_secs: Option<u64>,
// NOTE(review): `enable_stats` is never read anywhere in this file —
// confirm it is consumed elsewhere or remove it.
pub enable_stats: bool,
/// Human-readable cache name (see the presets in `CacheManager::new`).
pub name: String,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_capacity: 10_000,
ttl_secs: Some(3600), tti_secs: Some(1800), enable_stats: true,
name: "reasonkit_cache".to_string(),
}
}
}
/// Snapshot of cache activity counters.
///
/// NOTE(review): nothing in this file constructs or populates `CacheStats`
/// (aggregate stats use `CacheManagerStats` instead) — confirm it is used
/// elsewhere or consider removing it.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
/// Number of lookups that found an entry.
pub hits: u64,
/// Number of lookups that missed.
pub misses: u64,
/// Number of entries evicted (capacity or expiry).
pub evictions: u64,
/// Current entry count.
pub size: u64,
/// hits / (hits + misses); 0.0 when no lookups have occurred.
pub hit_rate: f64,
}
/// Async cache wrapping [`moka::future::Cache`], retaining the
/// [`CacheConfig`] it was built from.
///
/// NOTE(review): trait bounds on a struct definition are usually better left
/// to the impl blocks that need them; these mirror moka's method
/// requirements — confirm moka's type allows unbounded params before moving.
pub struct ReasonKitCache<K, V>
where
K: Hash + Eq + Send + Sync + Clone + 'static,
V: Clone + Send + Sync + 'static,
{
// The underlying moka future cache doing all the work.
inner: moka::future::Cache<K, V>,
// Config kept for introspection via `config()`.
config: CacheConfig,
}
impl<K, V> ReasonKitCache<K, V>
where
    K: Hash + Eq + Send + Sync + Clone + 'static,
    V: Clone + Send + Sync + 'static,
{
    /// Builds a cache from `config`, applying capacity, TTL, TTI, and name.
    ///
    /// Fix: `config.name` was previously stored but never applied to the
    /// cache; it is now passed to the moka builder so the cache is named
    /// (useful in eviction listeners / debugging output).
    ///
    /// NOTE(review): `config.enable_stats` is not consumed here — the moka
    /// builder exposes no stats toggle; confirm the flag is honored
    /// elsewhere or drop it from `CacheConfig`.
    pub fn new(config: CacheConfig) -> Self {
        let mut builder = moka::future::Cache::builder()
            .name(&config.name)
            .max_capacity(config.max_capacity);
        if let Some(ttl) = config.ttl_secs {
            builder = builder.time_to_live(Duration::from_secs(ttl));
        }
        if let Some(tti) = config.tti_secs {
            builder = builder.time_to_idle(Duration::from_secs(tti));
        }
        Self {
            inner: builder.build(),
            config,
        }
    }

    /// Convenience constructor using [`CacheConfig::default`].
    pub fn default_cache() -> Self {
        Self::new(CacheConfig::default())
    }

    /// Returns a clone of the cached value for `key`, if present.
    pub async fn get(&self, key: &K) -> Option<V> {
        self.inner.get(key).await
    }

    /// Inserts or replaces the value for `key`.
    pub async fn insert(&self, key: K, value: V) {
        self.inner.insert(key, value).await;
    }

    /// Returns the value for `key`, computing it with `init` on a miss.
    /// Concurrent callers for the same key share a single `init` execution
    /// (moka's `get_with` semantics).
    pub async fn get_or_insert_with<F>(&self, key: K, init: F) -> V
    where
        F: std::future::Future<Output = V>,
    {
        self.inner.get_with(key, init).await
    }

    /// Removes the entry for `key`, if any.
    pub async fn invalidate(&self, key: &K) {
        self.inner.invalidate(key).await;
    }

    /// Marks every entry invalid (moka processes this lazily).
    pub fn invalidate_all(&self) {
        self.inner.invalidate_all();
    }

    /// Approximate entry count; may lag behind recent writes until
    /// [`Self::run_pending_tasks`] is awaited (see the tests below).
    pub fn entry_count(&self) -> u64 {
        self.inner.entry_count()
    }

    /// Flushes moka's internal maintenance queue so counts/evictions settle.
    pub async fn run_pending_tasks(&self) {
        self.inner.run_pending_tasks().await;
    }

    /// The configuration this cache was built with.
    pub fn config(&self) -> &CacheConfig {
        &self.config
    }
}
/// Cache keyed by string (see [`generate_cache_key`]) holding LLM responses.
pub type LlmResponseCache = ReasonKitCache<String, CachedLlmResponse>;
/// A cached LLM completion plus provenance metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedLlmResponse {
/// The model's response text.
pub content: String,
/// Model identifier that produced the response (e.g. "gpt-4").
pub model: String,
/// Token count of the response.
pub tokens: usize,
/// UTC timestamp when the entry was created.
pub created_at: chrono::DateTime<chrono::Utc>,
/// The key this entry was stored under, duplicated for traceability.
pub cache_key: String,
}
/// Cache keyed by string holding embedding vectors.
pub type EmbeddingCache = ReasonKitCache<String, CachedEmbedding>;
/// A cached embedding vector plus provenance metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedEmbedding {
/// The embedding values.
pub vector: Vec<f32>,
/// Model identifier that produced the embedding.
pub model: String,
// NOTE(review): presumably equals `vector.len()` — not enforced here.
pub dimensions: usize,
/// UTC timestamp when the entry was created.
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Cache keyed by string holding reasoning results.
pub type ReasoningCache = ReasonKitCache<String, CachedReasoning>;
/// A cached reasoning outcome plus the steps that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedReasoning {
/// Final reasoning output.
pub result: String,
/// Confidence score — range not enforced here; presumably 0.0..=1.0.
pub confidence: f32,
/// Name of the thinktool that produced the result.
pub thinktool: String,
/// Intermediate reasoning steps.
pub steps: Vec<String>,
/// UTC timestamp when the entry was created.
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Bundles the three domain caches behind `Arc`s so they can be shared
/// across tasks/threads cheaply (cloning the `Arc`, not the cache).
pub struct CacheManager {
/// LLM completion cache (5k entries, 2h TTL — see `new`).
pub llm_responses: Arc<LlmResponseCache>,
/// Embedding cache (50k entries, 24h TTL — see `new`).
pub embeddings: Arc<EmbeddingCache>,
/// Reasoning-result cache (10k entries, 1h TTL — see `new`).
pub reasoning: Arc<ReasoningCache>,
}
impl CacheManager {
pub fn new() -> Self {
Self {
llm_responses: Arc::new(ReasonKitCache::new(CacheConfig {
max_capacity: 5_000,
ttl_secs: Some(7200), name: "llm_responses".to_string(),
..Default::default()
})),
embeddings: Arc::new(ReasonKitCache::new(CacheConfig {
max_capacity: 50_000,
ttl_secs: Some(86400), name: "embeddings".to_string(),
..Default::default()
})),
reasoning: Arc::new(ReasonKitCache::new(CacheConfig {
max_capacity: 10_000,
ttl_secs: Some(3600), name: "reasoning".to_string(),
..Default::default()
})),
}
}
pub fn stats(&self) -> CacheManagerStats {
CacheManagerStats {
llm_response_count: self.llm_responses.entry_count(),
embedding_count: self.embeddings.entry_count(),
reasoning_count: self.reasoning.entry_count(),
}
}
pub fn invalidate_all(&self) {
self.llm_responses.invalidate_all();
self.embeddings.invalidate_all();
self.reasoning.invalidate_all();
}
}
impl Default for CacheManager {
/// Equivalent to [`CacheManager::new`].
fn default() -> Self {
Self::new()
}
}
/// Per-cache entry counts reported by [`CacheManager::stats`].
/// Counts are moka's approximate `entry_count` values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheManagerStats {
/// Entries currently in the LLM response cache.
pub llm_response_count: u64,
/// Entries currently in the embedding cache.
pub embedding_count: u64,
/// Entries currently in the reasoning cache.
pub reasoning_count: u64,
}
/// Produces a stable, hex-encoded hash of `content` for use as a cache key.
///
/// With the `fast-hash` feature the key is a BLAKE3 digest; otherwise
/// SHA-256. Note the two builds produce different (incompatible) key spaces,
/// so toggling the feature effectively invalidates existing keys.
pub fn generate_cache_key(content: &str) -> String {
    #[cfg(feature = "fast-hash")]
    {
        blake3::hash(content.as_bytes()).to_hex().to_string()
    }
    #[cfg(not(feature = "fast-hash"))]
    {
        use sha2::{Digest, Sha256};
        // One-shot digest replaces the explicit new/update/finalize sequence.
        hex::encode(Sha256::digest(content.as_bytes()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip: inserted values come back; absent keys yield None.
    #[tokio::test]
    async fn test_cache_insert_get() {
        let cache = ReasonKitCache::<String, String>::default_cache();
        cache.insert("key1".into(), "value1".into()).await;
        assert_eq!(
            cache.get(&"key1".to_string()).await,
            Some("value1".to_string())
        );
        assert_eq!(cache.get(&"missing".to_string()).await, None);
    }

    // The manager's stats reflect inserts once pending tasks are flushed.
    #[tokio::test]
    async fn test_cache_manager() {
        let manager = CacheManager::new();
        let response = CachedLlmResponse {
            content: "Hello".to_string(),
            model: "gpt-4".to_string(),
            tokens: 10,
            created_at: chrono::Utc::now(),
            cache_key: "test_key".to_string(),
        };
        manager
            .llm_responses
            .insert("test_key".to_string(), response)
            .await;
        // entry_count is eventually consistent; settle it before asserting.
        manager.llm_responses.run_pending_tasks().await;
        assert_eq!(manager.stats().llm_response_count, 1);
    }

    // Keys are deterministic for equal inputs, distinct for different ones.
    #[test]
    fn test_cache_key_generation() {
        let first = generate_cache_key("test content");
        let again = generate_cache_key("test content");
        let other = generate_cache_key("different content");
        assert_eq!(first, again);
        assert_ne!(first, other);
    }
}