microsandbox_network/tls/
cache.rs1use std::{
4 io,
5 num::NonZeroUsize,
6 sync::{Arc, RwLock},
7 time::{Duration, Instant},
8};
9
10use lru::LruCache;
11use rustls::sign::CertifiedKey;
12
13use super::{CaKeyPair, CertCacheConfig};
14
/// Bounded, time-limited cache of per-domain leaf certificates.
///
/// Entries are keyed by the ASCII-lowercased domain name. A cached entry is
/// served until it is older than `ttl`, after which the next lookup
/// regenerates it. The number of cached domains is capped by the LRU
/// capacity chosen in `new`.
pub struct CertCache {
    /// LRU map from normalized domain to cached certificate.
    ///
    /// Wrapped in an `RwLock` so cache hits can be served under a shared
    /// (read) lock; only misses and stale entries take the write lock.
    inner: RwLock<LruCache<String, CacheEntry>>,
    /// Key pair passed to `certgen::generate_cert` when a new certificate is
    /// minted (presumably as the issuing CA — confirm in `certgen`).
    ca: CaKeyPair,
    /// Maximum age of a cached entry before it is treated as a miss.
    ttl: Duration,
}
28
/// A cached certificate together with the time it was generated.
struct CacheEntry {
    /// Shared handle returned to callers; handing it out is an `Arc` clone,
    /// not a copy of the certificate.
    key: Arc<CertifiedKey>,
    /// Creation time, compared against the cache TTL on every lookup.
    created: Instant,
}
34
35impl CertCache {
40 pub fn new(ca: CaKeyPair, config: &CertCacheConfig) -> Self {
42 let capacity = NonZeroUsize::new(config.max_entries).unwrap_or(NonZeroUsize::MIN);
43 Self {
44 inner: RwLock::new(LruCache::new(capacity)),
45 ca,
46 ttl: Duration::from_secs(config.ttl_secs),
47 }
48 }
49
50 pub fn get_or_generate(&self, domain: &str) -> io::Result<Arc<CertifiedKey>> {
53 let lower = domain.to_ascii_lowercase();
54
55 {
58 let cache = self.inner.read().unwrap_or_else(|e| e.into_inner());
59 if let Some(entry) = cache.peek(&lower)
60 && entry.created.elapsed() < self.ttl
61 {
62 return Ok(Arc::clone(&entry.key));
63 }
64 }
65
66 let mut cache = self.inner.write().unwrap_or_else(|e| e.into_inner());
68
69 if let Some(entry) = cache.get(&lower)
71 && entry.created.elapsed() < self.ttl
72 {
73 return Ok(Arc::clone(&entry.key));
74 }
75
76 let cert = super::certgen::generate_cert(&lower, &self.ca)?;
78 let certified = super::certgen::to_certified_key(&cert)?;
79
80 cache.put(
81 lower,
82 CacheEntry {
83 key: Arc::clone(&certified),
84 created: Instant::now(),
85 },
86 );
87
88 Ok(certified)
89 }
90}
91
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::ca;

    /// Builds a cache backed by a freshly generated CA with the given
    /// capacity and TTL.
    fn make_cache_with(max_entries: usize, ttl_secs: u64) -> CertCache {
        let ca_config = crate::tls::CaConfig::default();
        let ca_kp = ca::generate_ca(&ca_config).unwrap();
        let cache_config = CertCacheConfig {
            max_entries,
            ttl_secs,
        };
        CertCache::new(ca_kp, &cache_config)
    }

    /// Default cache used by most tests: room for 10 domains, 1 hour TTL.
    fn make_cache() -> CertCache {
        make_cache_with(10, 3600)
    }

    #[test]
    fn test_cache_miss_generates() {
        let cache = make_cache();
        let cert = cache.get_or_generate("example.com").unwrap();
        // Expected to be leaf + issuing CA — confirm against certgen.
        assert_eq!(cert.cert.len(), 2);
    }

    #[test]
    fn test_cache_hit_returns_same() {
        let cache = make_cache();
        let cert1 = cache.get_or_generate("example.com").unwrap();
        let cert2 = cache.get_or_generate("example.com").unwrap();
        assert!(Arc::ptr_eq(&cert1, &cert2));
    }

    #[test]
    fn test_different_domains() {
        let cache = make_cache();
        let cert1 = cache.get_or_generate("a.com").unwrap();
        let cert2 = cache.get_or_generate("b.com").unwrap();
        assert!(!Arc::ptr_eq(&cert1, &cert2));
    }

    #[test]
    fn test_domain_lookup_is_case_insensitive() {
        // Keys are ASCII-lowercased, so differently-cased names share an entry.
        let cache = make_cache();
        let upper = cache.get_or_generate("Example.COM").unwrap();
        let lower = cache.get_or_generate("example.com").unwrap();
        assert!(Arc::ptr_eq(&upper, &lower));
    }

    #[test]
    fn test_zero_ttl_always_regenerates() {
        // With a zero TTL, `elapsed() < ttl` can never hold, so every lookup
        // must mint a new certificate.
        let cache = make_cache_with(10, 0);
        let cert1 = cache.get_or_generate("example.com").unwrap();
        let cert2 = cache.get_or_generate("example.com").unwrap();
        assert!(!Arc::ptr_eq(&cert1, &cert2));
    }

    #[test]
    fn test_lru_evicts_least_recent_domain() {
        // Capacity 1: caching b.com must evict a.com.
        let cache = make_cache_with(1, 3600);
        let first_a = cache.get_or_generate("a.com").unwrap();
        let _b = cache.get_or_generate("b.com").unwrap();
        let second_a = cache.get_or_generate("a.com").unwrap();
        assert!(!Arc::ptr_eq(&first_a, &second_a));
    }

    #[test]
    fn test_zero_capacity_is_clamped_to_one() {
        // A zero capacity must not panic and still caches one entry.
        let cache = make_cache_with(0, 3600);
        let cert1 = cache.get_or_generate("example.com").unwrap();
        let cert2 = cache.get_or_generate("example.com").unwrap();
        assert!(Arc::ptr_eq(&cert1, &cert2));
    }
}