1use ggen_utils::error::Result;
47use moka::future::Cache;
48use serde::{Deserialize, Serialize};
49use sha2::Sha256;
50use std::sync::Arc;
51use std::time::Duration;
52use tracing::{debug, info};
53
/// Configuration for an [`LlmCache`]: capacity and expiry policy.
#[derive(Debug, Clone)]
pub struct CacheConfig {
    /// Maximum number of entries the cache may hold before eviction.
    pub max_capacity: u64,
    /// Time-to-live: an entry expires this long after insertion.
    pub ttl: Duration,
    /// Optional time-to-idle: an entry expires this long after its last
    /// access. `None` disables idle-based expiry.
    pub tti: Option<Duration>,
}
64
65impl Default for CacheConfig {
66 fn default() -> Self {
67 Self {
68 max_capacity: 10_000,
69 ttl: Duration::from_secs(3600), tti: Some(Duration::from_secs(600)), }
72 }
73}
74
/// A single cached LLM response, stored as the cache value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResponse {
    /// The generated response text.
    pub content: String,
    /// The model that produced the response.
    pub model: String,
    /// Token count for the generation, when known (`None` for entries
    /// created by `get_or_generate`).
    pub tokens_used: Option<usize>,
    /// Unix timestamp (seconds, UTC) at which the entry was cached.
    pub cached_at: i64,
}
83
/// An async, TTL/TTI-bounded cache for LLM responses, keyed by a hash of
/// `(prompt, model)`.
pub struct LlmCache {
    // Values are Arc-wrapped so a hit shares the entry instead of cloning it.
    cache: Cache<String, Arc<CachedResponse>>,
    // Retained so callers can inspect the active policy via `config()`.
    config: CacheConfig,
    // Lock-free hit counter, read by `stats()`.
    hits: Arc<std::sync::atomic::AtomicU64>,
    // Lock-free miss counter, read by `stats()`.
    misses: Arc<std::sync::atomic::AtomicU64>,
}
91
92impl LlmCache {
93 pub fn new() -> Self {
95 Self::with_config(CacheConfig::default())
96 }
97
98 pub fn with_config(config: CacheConfig) -> Self {
100 let mut builder = Cache::builder()
101 .max_capacity(config.max_capacity)
102 .time_to_live(config.ttl);
103
104 if let Some(tti) = config.tti {
105 builder = builder.time_to_idle(tti);
106 }
107
108 Self {
109 cache: builder.build(),
110 config,
111 hits: Arc::new(std::sync::atomic::AtomicU64::new(0)),
112 misses: Arc::new(std::sync::atomic::AtomicU64::new(0)),
113 }
114 }
115
116 fn cache_key(prompt: &str, model: &str) -> String {
118 use sha2::Digest;
119 let mut hasher = Sha256::new();
120 hasher.update(prompt.as_bytes());
121 hasher.update(model.as_bytes());
122 format!("{:x}", hasher.finalize())
123 }
124
125 pub async fn get_or_generate<F, Fut>(
127 &self, prompt: &str, model: &str, generator: F,
128 ) -> Result<String>
129 where
130 F: FnOnce() -> Fut,
131 Fut: std::future::Future<Output = Result<String>>,
132 {
133 let key = Self::cache_key(prompt, model);
134
135 if let Some(cached) = self.cache.get(&key).await {
137 self.hits.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
138 debug!(
139 "Cache hit for model {} (cached {} seconds ago)",
140 model,
141 chrono::Utc::now().timestamp() - cached.cached_at
142 );
143 return Ok(cached.content.clone());
144 }
145
146 self.misses
148 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
149 debug!("Cache miss for model {}", model);
150
151 let response = generator().await?;
152
153 let cached = Arc::new(CachedResponse {
155 content: response.clone(),
156 model: model.to_string(),
157 tokens_used: None,
158 cached_at: chrono::Utc::now().timestamp(),
159 });
160
161 self.cache.insert(key, cached).await;
162
163 Ok(response)
164 }
165
166 pub async fn insert(&self, prompt: &str, model: &str, response: String, tokens: Option<usize>) {
168 let key = Self::cache_key(prompt, model);
169 let cached = Arc::new(CachedResponse {
170 content: response,
171 model: model.to_string(),
172 tokens_used: tokens,
173 cached_at: chrono::Utc::now().timestamp(),
174 });
175
176 self.cache.insert(key, cached).await;
177 info!("Manually cached response for model {}", model);
178 }
179
180 pub async fn get(&self, prompt: &str, model: &str) -> Option<String> {
182 let key = Self::cache_key(prompt, model);
183 self.cache.get(&key).await.map(|c| c.content.clone())
184 }
185
186 pub async fn clear(&self) {
188 self.cache.invalidate_all();
189 self.cache.run_pending_tasks().await;
190 info!("Cache cleared");
191 }
192
193 pub fn stats(&self) -> CacheStats {
195 let hits = self.hits.load(std::sync::atomic::Ordering::Relaxed);
196 let misses = self.misses.load(std::sync::atomic::Ordering::Relaxed);
197 let total = hits + misses;
198 let hit_rate = if total > 0 {
199 (hits as f64 / total as f64) * 100.0
200 } else {
201 0.0
202 };
203
204 CacheStats {
205 hits,
206 misses,
207 hit_rate,
208 entry_count: self.cache.entry_count(),
209 weighted_size: self.cache.weighted_size(),
210 }
211 }
212
213 pub fn config(&self) -> &CacheConfig {
215 &self.config
216 }
217}
218
219impl Default for LlmCache {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
/// Point-in-time cache statistics returned by [`LlmCache::stats`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStats {
    /// Number of lookups served from the cache.
    pub hits: u64,
    /// Number of lookups that required generation.
    pub misses: u64,
    /// Hit percentage in `[0, 100]`; 0.0 when no lookups have occurred.
    pub hit_rate: f64,
    /// Approximate number of live entries (moka's `entry_count`).
    pub entry_count: u64,
    /// Approximate total weighted size (moka's `weighted_size`).
    pub weighted_size: u64,
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// A repeated lookup for the same (prompt, model) pair must be served
    /// from the cache; the second generator must never run.
    #[tokio::test]
    async fn test_cache_hit() {
        let cache = LlmCache::new();

        let first = cache
            .get_or_generate("test prompt", "gpt-4", || async {
                Ok("response".to_string())
            })
            .await
            .unwrap();
        assert_eq!(first, "response");

        // Same key as above: this generator's output must be ignored.
        let second = cache
            .get_or_generate("test prompt", "gpt-4", || async {
                Ok("different response".to_string())
            })
            .await
            .unwrap();
        assert_eq!(second, "response");

        // One miss (first call) followed by one hit (second call).
        let stats = cache.stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
    }

    /// The model name participates in the cache key, so the same prompt
    /// against different models yields independent entries.
    #[tokio::test]
    async fn test_different_models_different_cache() {
        let cache = LlmCache::new();

        let gpt4_result = cache
            .get_or_generate("test prompt", "gpt-4", || async {
                Ok("gpt-4 response".to_string())
            })
            .await
            .unwrap();

        let claude_result = cache
            .get_or_generate("test prompt", "claude-3", || async {
                Ok("claude response".to_string())
            })
            .await
            .unwrap();

        assert_eq!(gpt4_result, "gpt-4 response");
        assert_eq!(claude_result, "claude response");
    }
}