llm_orchestrator_secrets/
cache.rs

1// Copyright (c) 2025 LLM DevOps
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! In-memory secret cache with TTL.
5//!
6//! Provides a caching layer for secret stores to reduce backend calls
7//! and improve performance.
8
9use crate::models::{Secret, SecretMetadata, SecretVersion};
10use crate::traits::{Result, SecretStore};
11use async_trait::async_trait;
12use chrono::{DateTime, Duration, Utc};
13use parking_lot::RwLock;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tracing::{debug, trace};
17
18/// Cached secret with expiration.
19#[derive(Debug, Clone)]
20struct CachedSecret {
21    /// The cached secret.
22    secret: Secret,
23    /// When this cache entry expires.
24    expires_at: DateTime<Utc>,
25}
26
27impl CachedSecret {
28    /// Check if this cache entry is expired.
29    fn is_expired(&self) -> bool {
30        Utc::now() >= self.expires_at
31    }
32}
33
34/// Secret cache wrapper that adds TTL-based caching to any SecretStore.
35///
36/// # Features
37///
38/// - Configurable TTL (default: 300 seconds / 5 minutes)
39/// - Thread-safe caching with read-write locks
40/// - Automatic expiration checking
41/// - Manual cache invalidation
42/// - Cache statistics tracking
43///
44/// # Example
45///
46/// ```no_run
47/// use llm_orchestrator_secrets::{SecretCache, EnvSecretStore};
48/// use std::sync::Arc;
49/// use chrono::Duration;
50///
51/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
52/// let backend = Arc::new(EnvSecretStore::new());
53/// let cache = SecretCache::new(backend, Duration::minutes(5));
54///
55/// // First call hits the backend
56/// let secret1 = cache.get("api_key").await?;
57///
58/// // Second call within TTL uses cache
59/// let secret2 = cache.get("api_key").await?;
60/// # Ok(())
61/// # }
62/// ```
63pub struct SecretCache<S: SecretStore + ?Sized> {
64    /// The underlying secret store backend.
65    backend: Arc<S>,
66    /// Cache storage.
67    cache: Arc<RwLock<HashMap<String, CachedSecret>>>,
68    /// Time-to-live for cached secrets.
69    ttl: Duration,
70    /// Cache statistics.
71    stats: Arc<RwLock<CacheStats>>,
72}
73
/// Cache statistics for monitoring.
///
/// Derives `PartialEq`/`Eq` so that snapshots can be compared directly in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CacheStats {
    /// Total number of cache hits.
    pub hits: u64,
    /// Total number of cache misses.
    pub misses: u64,
    /// Total number of expired entries encountered.
    pub expirations: u64,
    /// Total number of manual invalidations.
    pub invalidations: u64,
}

impl CacheStats {
    /// Calculate the cache hit rate as a percentage in `0.0..=100.0`.
    ///
    /// Returns `0.0` when there have been no accesses yet, instead of dividing by zero.
    #[must_use]
    pub fn hit_rate(&self) -> f64 {
        match self.total_accesses() {
            0 => 0.0,
            total => (self.hits as f64 / total as f64) * 100.0,
        }
    }

    /// Get total number of cache accesses (hits plus misses).
    ///
    /// Saturates at `u64::MAX` rather than overflowing (which would panic in debug builds).
    #[must_use]
    pub fn total_accesses(&self) -> u64 {
        self.hits.saturating_add(self.misses)
    }
}
103
104impl<S: SecretStore + ?Sized> SecretCache<S> {
105    /// Create a new secret cache.
106    ///
107    /// # Arguments
108    ///
109    /// * `backend` - The underlying secret store
110    /// * `ttl` - Time-to-live for cached entries
111    pub fn new(backend: Arc<S>, ttl: Duration) -> Self {
112        debug!("Creating secret cache with TTL of {} seconds", ttl.num_seconds());
113        Self {
114            backend,
115            cache: Arc::new(RwLock::new(HashMap::new())),
116            ttl,
117            stats: Arc::new(RwLock::new(CacheStats::default())),
118        }
119    }
120
121    /// Create a new secret cache with default TTL (5 minutes).
122    pub fn with_default_ttl(backend: Arc<S>) -> Self {
123        Self::new(backend, Duration::minutes(5))
124    }
125
126    /// Get a secret, using cache if available and not expired.
127    pub async fn get(&self, key: &str) -> Result<Secret> {
128        trace!("Cache lookup for key: {}", key);
129
130        // Try to get from cache first
131        {
132            let cache_guard = self.cache.read();
133            if let Some(cached) = cache_guard.get(key) {
134                if !cached.is_expired() {
135                    debug!("Cache hit for key: {}", key);
136                    self.stats.write().hits += 1;
137                    return Ok(cached.secret.clone());
138                } else {
139                    debug!("Cache entry expired for key: {}", key);
140                    self.stats.write().expirations += 1;
141                    // Entry is expired, fall through to fetch from backend
142                }
143            } else {
144                debug!("Cache miss for key: {}", key);
145                self.stats.write().misses += 1;
146            }
147        }
148
149        // Not in cache or expired, fetch from backend
150        let secret = self.backend.get_secret(key).await?;
151
152        // Store in cache
153        {
154            let mut cache_guard = self.cache.write();
155            let expires_at = Utc::now() + self.ttl;
156            cache_guard.insert(
157                key.to_string(),
158                CachedSecret {
159                    secret: secret.clone(),
160                    expires_at,
161                },
162            );
163            debug!("Cached secret {} until {}", key, expires_at);
164        }
165
166        Ok(secret)
167    }
168
169    /// Invalidate a specific cache entry.
170    ///
171    /// # Arguments
172    ///
173    /// * `key` - The secret key to invalidate
174    pub fn invalidate(&self, key: &str) {
175        let mut cache_guard = self.cache.write();
176        if cache_guard.remove(key).is_some() {
177            debug!("Invalidated cache entry for key: {}", key);
178            self.stats.write().invalidations += 1;
179        }
180    }
181
182    /// Clear all cached entries.
183    pub fn clear(&self) {
184        let mut cache_guard = self.cache.write();
185        let count = cache_guard.len();
186        cache_guard.clear();
187        debug!("Cleared {} cache entries", count);
188        self.stats.write().invalidations += count as u64;
189    }
190
191    /// Remove expired entries from the cache.
192    ///
193    /// This is useful for periodic cleanup to prevent memory growth.
194    pub fn cleanup_expired(&self) {
195        let mut cache_guard = self.cache.write();
196        let before_count = cache_guard.len();
197        cache_guard.retain(|key, cached| {
198            let is_valid = !cached.is_expired();
199            if !is_valid {
200                trace!("Removing expired cache entry: {}", key);
201            }
202            is_valid
203        });
204        let removed = before_count - cache_guard.len();
205        if removed > 0 {
206            debug!("Cleaned up {} expired cache entries", removed);
207            self.stats.write().expirations += removed as u64;
208        }
209    }
210
211    /// Get cache statistics.
212    pub fn stats(&self) -> CacheStats {
213        self.stats.read().clone()
214    }
215
216    /// Get the number of entries currently in the cache.
217    pub fn size(&self) -> usize {
218        self.cache.read().len()
219    }
220
221    /// Get the TTL duration.
222    pub fn ttl(&self) -> Duration {
223        self.ttl
224    }
225}
226
227#[async_trait]
228impl<S: SecretStore + ?Sized> SecretStore for SecretCache<S> {
229    async fn get_secret(&self, key: &str) -> Result<Secret> {
230        self.get(key).await
231    }
232
233    async fn put_secret(
234        &self,
235        key: &str,
236        value: &str,
237        metadata: Option<SecretMetadata>,
238    ) -> Result<()> {
239        // Invalidate cache for this key
240        self.invalidate(key);
241
242        // Forward to backend
243        self.backend.put_secret(key, value, metadata).await
244    }
245
246    async fn delete_secret(&self, key: &str) -> Result<()> {
247        // Invalidate cache for this key
248        self.invalidate(key);
249
250        // Forward to backend
251        self.backend.delete_secret(key).await
252    }
253
254    async fn list_secrets(&self, prefix: &str) -> Result<Vec<String>> {
255        // List operations are not cached
256        self.backend.list_secrets(prefix).await
257    }
258
259    async fn rotate_secret(&self, key: &str) -> Result<Secret> {
260        // Invalidate cache for this key
261        self.invalidate(key);
262
263        // Forward to backend
264        self.backend.rotate_secret(key).await
265    }
266
267    async fn health_check(&self) -> Result<()> {
268        self.backend.health_check().await
269    }
270
271    async fn get_secret_versions(&self, key: &str) -> Result<Vec<SecretVersion>> {
272        // Version listing is not cached
273        self.backend.get_secret_versions(key).await
274    }
275
276    async fn get_secret_version(&self, key: &str, version: &str) -> Result<Secret> {
277        // Versioned secrets are not cached (they are immutable)
278        self.backend.get_secret_version(key, version).await
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::EnvSecretStore;
    use std::env;

    /// Sets an environment variable and removes it again when dropped.
    ///
    /// This ensures a failing assertion cannot leak the variable into other tests.
    struct EnvVarGuard(&'static str);

    impl EnvVarGuard {
        fn set(name: &'static str, value: &str) -> Self {
            env::set_var(name, value);
            Self(name)
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            env::remove_var(self.0);
        }
    }

    #[tokio::test]
    async fn test_cache_hit() {
        let _var = EnvVarGuard::set("TEST_CACHE_KEY", "test_value");

        let backend = Arc::new(EnvSecretStore::new());
        let cache = SecretCache::new(backend, Duration::minutes(5));

        // First access - should be a miss
        let secret1 = cache.get("test/cache/key").await.unwrap();
        assert_eq!(secret1.value, "test_value");

        let stats1 = cache.stats();
        assert_eq!(stats1.misses, 1);
        assert_eq!(stats1.hits, 0);

        // Second access - should be a hit
        let secret2 = cache.get("test/cache/key").await.unwrap();
        assert_eq!(secret2.value, "test_value");

        let stats2 = cache.stats();
        assert_eq!(stats2.misses, 1);
        assert_eq!(stats2.hits, 1);
    }

    #[tokio::test]
    async fn test_cache_expiration() {
        let _var = EnvVarGuard::set("TEST_EXPIRE_KEY", "expire_value");

        let backend = Arc::new(EnvSecretStore::new());
        let cache = SecretCache::new(backend, Duration::milliseconds(100));

        // First access
        let _ = cache.get("test/expire/key").await.unwrap();

        // Wait for expiration
        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;

        // Second access after expiration - should be treated as miss
        let _ = cache.get("test/expire/key").await.unwrap();

        let stats = cache.stats();
        assert_eq!(stats.misses, 2); // Both should be misses
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.expirations, 1);
        assert_eq!(stats.hit_rate(), 0.0);

        // The refreshed entry replaced the stale one
        assert_eq!(cache.size(), 1);
    }

    #[tokio::test]
    async fn test_cache_invalidation() {
        let _var = EnvVarGuard::set("TEST_INVALIDATE_KEY", "invalidate_value");

        let backend = Arc::new(EnvSecretStore::new());
        let cache = SecretCache::new(backend, Duration::minutes(5));

        // First access
        let _ = cache.get("test/invalidate/key").await.unwrap();

        // Invalidate
        cache.invalidate("test/invalidate/key");

        // Invalidating a key that is not cached must not count
        cache.invalidate("test/invalidate/absent");

        // Second access - should be a miss due to invalidation
        let _ = cache.get("test/invalidate/key").await.unwrap();

        let stats = cache.stats();
        assert_eq!(stats.misses, 2);
        assert_eq!(stats.invalidations, 1);
    }

    #[tokio::test]
    async fn test_cache_clear() {
        let _var1 = EnvVarGuard::set("TEST_CLEAR_KEY1", "value1");
        let _var2 = EnvVarGuard::set("TEST_CLEAR_KEY2", "value2");

        let backend = Arc::new(EnvSecretStore::new());
        let cache = SecretCache::new(backend, Duration::minutes(5));

        // Access multiple keys
        let _ = cache.get("test/clear/key1").await.unwrap();
        let _ = cache.get("test/clear/key2").await.unwrap();

        assert_eq!(cache.size(), 2);

        // Clear cache
        cache.clear();

        assert_eq!(cache.size(), 0);

        let stats = cache.stats();
        assert_eq!(stats.invalidations, 2);
    }

    #[tokio::test]
    async fn test_cache_stats_hit_rate() {
        let _var = EnvVarGuard::set("TEST_STATS_KEY", "stats_value");

        let backend = Arc::new(EnvSecretStore::new());
        let cache = SecretCache::new(backend, Duration::minutes(5));

        // 1 miss
        let _ = cache.get("test/stats/key").await.unwrap();
        // 3 hits
        let _ = cache.get("test/stats/key").await.unwrap();
        let _ = cache.get("test/stats/key").await.unwrap();
        let _ = cache.get("test/stats/key").await.unwrap();

        let stats = cache.stats();
        assert_eq!(stats.total_accesses(), 4);
        assert_eq!(stats.hit_rate(), 75.0); // 3 hits out of 4 total
    }

    #[test]
    fn test_hit_rate_without_accesses() {
        let stats = CacheStats::default();
        assert_eq!(stats.total_accesses(), 0);
        assert_eq!(stats.hit_rate(), 0.0);
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let _var1 = EnvVarGuard::set("TEST_CLEANUP_KEY1", "value1");
        let _var2 = EnvVarGuard::set("TEST_CLEANUP_KEY2", "value2");

        let backend = Arc::new(EnvSecretStore::new());
        let cache = SecretCache::new(backend, Duration::milliseconds(100));

        // Access keys
        let _ = cache.get("test/cleanup/key1").await.unwrap();
        let _ = cache.get("test/cleanup/key2").await.unwrap();

        assert_eq!(cache.size(), 2);

        // Wait for expiration
        tokio::time::sleep(tokio::time::Duration::from_millis(150)).await;

        // Cleanup expired entries
        cache.cleanup_expired();

        assert_eq!(cache.size(), 0);
        assert_eq!(cache.stats().expirations, 2);
    }
}