1use serde::{Deserialize, Serialize};
34use std::collections::HashMap;
35use std::path::{Path, PathBuf};
36use std::sync::{Mutex, RwLock};
37use std::time::{Duration, SystemTime, UNIX_EPOCH};
38
/// Maximum age of a cached ticket: one day. Older tickets are never returned by
/// `TicketCache::get` and are discarded when the cache file is loaded.
pub const MAX_AGE: Duration = Duration::from_secs(60 * 60 * 24);

/// Version marker written into the cache file; a file carrying any other value
/// is ignored on load. Bump it whenever the on-disk layout changes.
const FORMAT_MAGIC: &str = "aube-tls-tickets/v1";
44
45#[inline]
47pub fn is_disabled() -> bool {
48 crate::env::embedder_env("DISABLE_TLS_TICKET_CACHE").is_some()
49}
50
/// One cached TLS session ticket together with the server identity it belongs to.
///
/// Entries are serialized as-is into the cache file by `TicketCache::save`, so
/// renaming a field changes the on-disk format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TicketEntry {
    /// Opaque ticket bytes. This module never interprets them; presumably they are
    /// the session ticket handed back for resumption — confirm against the TLS layer.
    pub ticket: Vec<u8>,
    /// 32-byte fingerprint of the server key. NOTE(review): presumably SHA-256 over
    /// the SubjectPublicKeyInfo, checked by the caller before resuming — verify;
    /// nothing in this module compares it.
    pub spki_fp: [u8; 32],
    /// When the ticket was stored, in whole seconds since the Unix epoch. Entries
    /// whose age reaches `MAX_AGE` are filtered out on read and on load.
    pub stored_at_unix_secs: u64,
}
64
/// Cache key: a (host, port) pair whose host is normalized to ASCII lowercase.
///
/// Build keys with `HostPort::new` so that lookups are case-insensitive. The
/// fields are public, and the derived `Deserialize` does not lowercase either.
/// Keys built by a struct literal or read back from disk are therefore used
/// exactly as given.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HostPort {
    /// Host name; ASCII-lowercased when constructed through `HostPort::new`.
    pub host: String,
    /// Port number.
    pub port: u16,
}
71
72impl HostPort {
73 pub fn new(host: impl Into<String>, port: u16) -> Self {
74 Self {
75 host: host.into().to_ascii_lowercase(),
76 port,
77 }
78 }
79}
80
/// Serialized layout of the cache file, written as JSON by `TicketCache::save`.
#[derive(Debug, Default, Serialize, Deserialize)]
struct OnDisk {
    /// Format/version marker. Files whose value differs from `FORMAT_MAGIC` are
    /// ignored by `load_from_disk`.
    magic: String,
    /// Per-host ticket lists. Stored as a list of pairs rather than a map —
    /// presumably because serde_json only supports string map keys and `HostPort`
    /// is a struct.
    entries: Vec<(HostPort, Vec<TicketEntry>)>,
}
91
/// In-memory TLS session-ticket cache that is persisted to a single file on demand.
///
/// All state sits behind interior locks, so every method takes `&self`. The file
/// is read once by `TicketCache::open` and written only by `TicketCache::save`.
#[derive(Debug)]
pub struct TicketCache {
    /// Location of the backing cache file.
    path: PathBuf,
    /// Tickets per host/port, each bucket kept in insertion order (oldest first).
    inner: RwLock<HashMap<HostPort, Vec<TicketEntry>>>,
    /// Held for the duration of `save` so concurrent saves do not interleave.
    file_lock: Mutex<()>,
}
104
105impl TicketCache {
106 pub fn open(path: impl Into<PathBuf>) -> Self {
110 let path = path.into();
111 let inner = if is_disabled() {
112 HashMap::new()
113 } else {
114 load_from_disk(&path).unwrap_or_default()
115 };
116 Self {
117 path,
118 inner: RwLock::new(inner),
119 file_lock: Mutex::new(()),
120 }
121 }
122
123 pub fn get(&self, host: &str, port: u16) -> Vec<TicketEntry> {
127 if is_disabled() {
128 return Vec::new();
129 }
130 let key = HostPort::new(host, port);
131 let now = unix_now();
132 let inner = self.inner.read().unwrap_or_else(|e| e.into_inner());
133 inner
134 .get(&key)
135 .map(|tickets| {
136 tickets
137 .iter()
138 .filter(|t| now.saturating_sub(t.stored_at_unix_secs) < MAX_AGE.as_secs())
139 .cloned()
140 .collect()
141 })
142 .unwrap_or_default()
143 }
144
145 pub fn put(&self, host: &str, port: u16, entry: TicketEntry) {
149 if is_disabled() {
150 return;
151 }
152 const MAX_PER_HOST: usize = 4;
153 let key = HostPort::new(host, port);
154 let mut inner = self.inner.write().unwrap_or_else(|e| e.into_inner());
155 let bucket = inner.entry(key).or_default();
156 bucket.push(entry);
157 if bucket.len() > MAX_PER_HOST {
158 let drop = bucket.len() - MAX_PER_HOST;
159 bucket.drain(..drop);
160 }
161 }
162
163 pub fn invalidate(&self, host: &str, port: u16) {
168 let key = HostPort::new(host, port);
169 let mut inner = self.inner.write().unwrap_or_else(|e| e.into_inner());
170 inner.remove(&key);
171 }
172
173 pub fn save(&self) -> std::io::Result<()> {
176 if is_disabled() {
177 return Ok(());
178 }
179 let _guard = self.file_lock.lock().unwrap_or_else(|e| e.into_inner());
180 let inner = self.inner.read().unwrap_or_else(|e| e.into_inner());
181 let payload = OnDisk {
182 magic: FORMAT_MAGIC.to_string(),
183 entries: inner.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
184 };
185 let bytes = serde_json::to_vec(&payload).map_err(std::io::Error::other)?;
186 crate::fs_atomic::atomic_write(&self.path, &bytes)?;
187 #[cfg(unix)]
192 {
193 use std::os::unix::fs::PermissionsExt as _;
194 let _ = std::fs::set_permissions(&self.path, std::fs::Permissions::from_mode(0o600));
195 }
196 Ok(())
197 }
198
199 pub fn len(&self) -> usize {
201 let inner = self.inner.read().unwrap_or_else(|e| e.into_inner());
202 inner.values().map(|v| v.len()).sum()
203 }
204
205 pub fn is_empty(&self) -> bool {
206 self.len() == 0
207 }
208}
209
210fn load_from_disk(path: &Path) -> Option<HashMap<HostPort, Vec<TicketEntry>>> {
211 let bytes = std::fs::read(path).ok()?;
212 let payload: OnDisk = serde_json::from_slice(&bytes).ok()?;
213 if payload.magic != FORMAT_MAGIC {
214 return None;
215 }
216 let now = unix_now();
217 let map: HashMap<HostPort, Vec<TicketEntry>> = payload
218 .entries
219 .into_iter()
220 .filter_map(|(k, v)| {
221 let fresh: Vec<TicketEntry> = v
222 .into_iter()
223 .filter(|t| now.saturating_sub(t.stored_at_unix_secs) < MAX_AGE.as_secs())
224 .collect();
225 if fresh.is_empty() {
226 None
227 } else {
228 Some((k, fresh))
229 }
230 })
231 .collect();
232 Some(map)
233}
234
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` instead of an error if the system clock reports a time before the
/// epoch.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
241
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Serializes tests against the process-wide kill-switch environment variable.
    ///
    /// Every `TicketCache` operation consults `is_disabled()`, and the test harness
    /// runs tests on parallel threads. Without this lock, the moment in which
    /// `killswitch_short_circuits` has `AUBE_DISABLE_TLS_TICKET_CACHE` set would
    /// make concurrently running tests see a disabled cache and fail spuriously.
    /// Ordinary tests take a shared guard; the kill-switch test takes an
    /// exclusive one.
    static ENV_LOCK: std::sync::RwLock<()> = std::sync::RwLock::new(());

    /// Shared guard for tests that need the kill switch to stay unset.
    fn env_shared() -> std::sync::RwLockReadGuard<'static, ()> {
        // Ignore poisoning so one failing test does not cascade into the others.
        ENV_LOCK.read().unwrap_or_else(|e| e.into_inner())
    }

    /// Exclusive guard for tests that mutate process environment variables.
    fn env_exclusive() -> std::sync::RwLockWriteGuard<'static, ()> {
        ENV_LOCK.write().unwrap_or_else(|e| e.into_inner())
    }

    /// Builds a fresh (just-stored) entry whose bytes are derived from `label`.
    fn entry(label: u8) -> TicketEntry {
        TicketEntry {
            ticket: vec![label, label + 1, label + 2],
            spki_fp: [label; 32],
            stored_at_unix_secs: unix_now(),
        }
    }

    #[test]
    fn roundtrip_persists_across_open() {
        let _env = env_shared();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        {
            let cache = TicketCache::open(&path);
            cache.put("registry.npmjs.org", 443, entry(1));
            cache.save().unwrap();
        }
        let reopened = TicketCache::open(&path);
        let tickets = reopened.get("registry.npmjs.org", 443);
        assert_eq!(tickets.len(), 1);
        assert_eq!(tickets[0].ticket, vec![1, 2, 3]);
    }

    #[test]
    fn host_port_lowercases() {
        let a = HostPort::new("Registry.NPMJS.ORG", 443);
        let b = HostPort::new("registry.npmjs.org", 443);
        assert_eq!(a, b);
    }

    #[test]
    fn invalidate_removes_all_for_host() {
        let _env = env_shared();
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("tickets.json"));
        cache.put("a.example", 443, entry(1));
        cache.put("a.example", 443, entry(2));
        assert_eq!(cache.len(), 2);
        cache.invalidate("a.example", 443);
        assert!(cache.is_empty());
    }

    #[test]
    fn max_per_host_evicts_oldest() {
        let _env = env_shared();
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("tickets.json"));
        for i in 0..6u8 {
            cache.put("a.example", 443, entry(i));
        }
        let kept = cache.get("a.example", 443);
        assert_eq!(kept.len(), 4, "MAX_PER_HOST = 4");
        assert!(kept.iter().all(|t| t.ticket[0] >= 2));
    }

    #[test]
    fn stale_entries_filtered_at_load() {
        let _env = env_shared();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        {
            let cache = TicketCache::open(&path);
            let mut stale = entry(9);
            stale.stored_at_unix_secs = 0;
            cache.put("a.example", 443, stale);
            cache.save().unwrap();
        }
        let reopened = TicketCache::open(&path);
        assert!(reopened.get("a.example", 443).is_empty());
    }

    /// Removes `key` from the process environment when dropped.
    struct EnvVarGuard {
        key: &'static str,
    }
    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: only constructed in `killswitch_short_circuits`, after the
            // exclusive `ENV_LOCK` guard was taken. That guard is declared earlier
            // and therefore dropped later, so it is still held here.
            unsafe { std::env::remove_var(self.key) };
        }
    }

    #[test]
    fn killswitch_short_circuits() {
        // Declared first so it is dropped last, i.e. only after `_cleanup` has
        // removed the variable again.
        let _env = env_exclusive();
        // SAFETY: `set_var` is unsound if another thread touches the environment
        // concurrently. The exclusive guard keeps the other env-sensitive tests in
        // this module (which all take the shared guard) from running meanwhile.
        // NOTE(review): tests elsewhere in the crate are not covered by this lock.
        unsafe { std::env::set_var("AUBE_DISABLE_TLS_TICKET_CACHE", "1") };
        let _cleanup = EnvVarGuard {
            key: "AUBE_DISABLE_TLS_TICKET_CACHE",
        };
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("tickets.json"));
        cache.put("a.example", 443, entry(1));
        assert!(cache.get("a.example", 443).is_empty());
    }

    #[test]
    fn missing_file_loads_empty() {
        let _env = env_shared();
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("nonexistent.json"));
        assert!(cache.is_empty());
    }

    #[test]
    fn corrupt_magic_loads_empty() {
        let _env = env_shared();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        std::fs::write(&path, br#"{"magic":"wrong","entries":[]}"#).unwrap();
        let cache = TicketCache::open(&path);
        assert!(cache.is_empty());
    }
}