Skip to main content

microsandbox_network/tls/
cache.rs

//! In-memory LRU certificate cache with TTL eviction.
2
3use std::{
4    io,
5    num::NonZeroUsize,
6    sync::{Arc, RwLock},
7    time::{Duration, Instant},
8};
9
10use lru::LruCache;
11use rustls::sign::CertifiedKey;
12
13use super::{CaKeyPair, CertCacheConfig};
14
15//--------------------------------------------------------------------------------------------------
16// Types
17//--------------------------------------------------------------------------------------------------
18
19/// Thread-safe LRU cache for per-domain TLS certificates.
20///
21/// On cache miss, generates a new certificate signed by the CA. Expired
22/// entries are evicted on access. The cache is bounded by `max_entries`.
23pub struct CertCache {
24    inner: RwLock<LruCache<String, CacheEntry>>,
25    ca: CaKeyPair,
26    ttl: Duration,
27}
28
29/// A cached certificate with its creation timestamp.
30struct CacheEntry {
31    key: Arc<CertifiedKey>,
32    created: Instant,
33}
34
35//--------------------------------------------------------------------------------------------------
36// Methods
37//--------------------------------------------------------------------------------------------------
38
39impl CertCache {
40    /// Creates a new certificate cache.
41    pub fn new(ca: CaKeyPair, config: &CertCacheConfig) -> Self {
42        let capacity = NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::MIN);
43        Self {
44            inner: RwLock::new(LruCache::new(capacity)),
45            ca,
46            ttl: Duration::from_secs(config.ttl_secs),
47        }
48    }
49
50    /// Returns a cached `CertifiedKey` for the domain, generating one if
51    /// absent or expired.
52    pub fn get_or_generate(&self, domain: &str) -> io::Result<Arc<CertifiedKey>> {
53        let lower = domain.to_ascii_lowercase();
54
55        // Fast path: read lock. Use unwrap_or_else to recover from poisoned
56        // locks (a previous holder panicked) rather than cascading the panic.
57        {
58            let cache = self.inner.read().unwrap_or_else(|e| e.into_inner());
59            if let Some(entry) = cache.peek(&lower)
60                && entry.created.elapsed() < self.ttl
61            {
62                return Ok(Arc::clone(&entry.key));
63            }
64        }
65
66        // Slow path: write lock, generate.
67        let mut cache = self.inner.write().unwrap_or_else(|e| e.into_inner());
68
69        // Double-check after acquiring write lock.
70        if let Some(entry) = cache.get(&lower)
71            && entry.created.elapsed() < self.ttl
72        {
73            return Ok(Arc::clone(&entry.key));
74        }
75
76        // Generate.
77        let cert = super::certgen::generate_cert(&lower, &self.ca)?;
78        let certified = super::certgen::to_certified_key(&cert)?;
79
80        cache.put(
81            lower,
82            CacheEntry {
83                key: Arc::clone(&certified),
84                created: Instant::now(),
85            },
86        );
87
88        Ok(certified)
89    }
90}
91
92//--------------------------------------------------------------------------------------------------
93// Tests
94//--------------------------------------------------------------------------------------------------
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::ca;

    /// Builds a cache backed by a freshly generated CA that keeps up to ten
    /// certificates for an hour each.
    fn new_test_cache() -> CertCache {
        let authority = ca::generate_ca(&crate::tls::CaConfig::default()).unwrap();
        let limits = CertCacheConfig {
            max_entries: 10,
            ttl_secs: 3600,
        };
        CertCache::new(authority, &limits)
    }

    #[test]
    fn test_cache_miss_generates() {
        let certified = new_test_cache().get_or_generate("example.com").unwrap();
        // The served chain is expected to hold exactly two certificates.
        assert_eq!(certified.cert.len(), 2);
    }

    #[test]
    fn test_cache_hit_returns_same() {
        let cache = new_test_cache();
        let [first, second] =
            ["example.com"; 2].map(|domain| cache.get_or_generate(domain).unwrap());
        // A hit hands back the very same shared allocation, not a copy.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_different_domains() {
        let cache = new_test_cache();
        let [a, b] = ["a.com", "b.com"].map(|domain| cache.get_or_generate(domain).unwrap());
        assert!(!Arc::ptr_eq(&a, &b));
    }
}