use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Mutex, RwLock};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
/// How long a stored ticket stays usable: one day (86 400 s).
///
/// Tickets whose age is at least this value are skipped by lookups and
/// dropped when the cache file is loaded.
pub const MAX_AGE: Duration = Duration::from_secs(60 * 60 * 24);
/// Version tag written into the cache file; a file carrying any other value
/// is ignored on load, so change it whenever the on-disk layout changes.
const FORMAT_MAGIC: &str = "aube-tls-tickets/v1";
/// Reports whether the TLS session-ticket cache has been switched off.
///
/// The killswitch is the mere presence of `AUBE_DISABLE_TLS_TICKET_CACHE` in
/// the process environment; its value is never inspected. The variable is
/// re-read on every call, so toggling it at runtime takes effect immediately.
#[inline]
pub fn is_disabled() -> bool {
    let killswitch = std::env::var_os("AUBE_DISABLE_TLS_TICKET_CACHE");
    killswitch.is_some()
}
/// One cached TLS session ticket together with the metadata stored alongside it.
///
/// Serialized (via serde) into the cache file by `TicketCache::save`, so the
/// field names below are part of the on-disk JSON format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TicketEntry {
    /// Opaque ticket bytes, stored and returned unchanged by the cache.
    pub ticket: Vec<u8>,
    /// 32-byte fingerprint kept with the ticket. The cache never inspects it.
    /// NOTE(review): the name suggests a SHA-256 of the server's
    /// SubjectPublicKeyInfo — confirm against the code that fills it in.
    pub spki_fp: [u8; 32],
    /// Unix time (seconds) the ticket was stored; entries whose age reaches
    /// `MAX_AGE` are treated as stale.
    pub stored_at_unix_secs: u64,
}
/// Cache key identifying a server endpoint.
///
/// Build it with `HostPort::new`, which lowercases the host; the fields are
/// public, so constructing it directly skips that normalization.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HostPort {
    /// Host name (ASCII-lowercased when built via `HostPort::new`).
    pub host: String,
    /// TCP port.
    pub port: u16,
}
impl HostPort {
pub fn new(host: impl Into<String>, port: u16) -> Self {
Self {
host: host.into().to_ascii_lowercase(),
port,
}
}
}
/// JSON layout of the cache file.
#[derive(Debug, Default, Serialize, Deserialize)]
struct OnDisk {
    /// Must equal `FORMAT_MAGIC`; any other value makes the loader ignore the file.
    magic: String,
    /// Per-host ticket lists stored as `(key, tickets)` pairs. Presumably a list
    /// rather than a map because serde_json only accepts string map keys and
    /// `HostPort` is a struct.
    entries: Vec<(HostPort, Vec<TicketEntry>)>,
}
/// TLS session-ticket cache held in memory and persisted to a single JSON file.
///
/// All methods take `&self`; interior locks let one instance be used from
/// several threads (TODO confirm how callers share it, e.g. behind an `Arc`).
#[derive(Debug)]
pub struct TicketCache {
    /// File read by `open` and rewritten by `save`.
    path: PathBuf,
    /// Tickets per host, in insertion order within each bucket.
    inner: RwLock<HashMap<HostPort, Vec<TicketEntry>>>,
    /// Held for the whole of `save` so concurrent saves do not interleave.
    file_lock: Mutex<()>,
}
impl TicketCache {
pub fn open(path: impl Into<PathBuf>) -> Self {
let path = path.into();
let inner = if is_disabled() {
HashMap::new()
} else {
load_from_disk(&path).unwrap_or_default()
};
Self {
path,
inner: RwLock::new(inner),
file_lock: Mutex::new(()),
}
}
pub fn get(&self, host: &str, port: u16) -> Vec<TicketEntry> {
if is_disabled() {
return Vec::new();
}
let key = HostPort::new(host, port);
let now = unix_now();
let inner = self.inner.read().unwrap_or_else(|e| e.into_inner());
inner
.get(&key)
.map(|tickets| {
tickets
.iter()
.filter(|t| now.saturating_sub(t.stored_at_unix_secs) < MAX_AGE.as_secs())
.cloned()
.collect()
})
.unwrap_or_default()
}
pub fn put(&self, host: &str, port: u16, entry: TicketEntry) {
if is_disabled() {
return;
}
const MAX_PER_HOST: usize = 4;
let key = HostPort::new(host, port);
let mut inner = self.inner.write().unwrap_or_else(|e| e.into_inner());
let bucket = inner.entry(key).or_default();
bucket.push(entry);
if bucket.len() > MAX_PER_HOST {
let drop = bucket.len() - MAX_PER_HOST;
bucket.drain(..drop);
}
}
pub fn invalidate(&self, host: &str, port: u16) {
let key = HostPort::new(host, port);
let mut inner = self.inner.write().unwrap_or_else(|e| e.into_inner());
inner.remove(&key);
}
pub fn save(&self) -> std::io::Result<()> {
if is_disabled() {
return Ok(());
}
let _guard = self.file_lock.lock().unwrap_or_else(|e| e.into_inner());
let inner = self.inner.read().unwrap_or_else(|e| e.into_inner());
let payload = OnDisk {
magic: FORMAT_MAGIC.to_string(),
entries: inner.iter().map(|(k, v)| (k.clone(), v.clone())).collect(),
};
let bytes = serde_json::to_vec(&payload).map_err(std::io::Error::other)?;
crate::fs_atomic::atomic_write(&self.path, &bytes)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt as _;
let _ = std::fs::set_permissions(&self.path, std::fs::Permissions::from_mode(0o600));
}
Ok(())
}
pub fn len(&self) -> usize {
let inner = self.inner.read().unwrap_or_else(|e| e.into_inner());
inner.values().map(|v| v.len()).sum()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
fn load_from_disk(path: &Path) -> Option<HashMap<HostPort, Vec<TicketEntry>>> {
let bytes = std::fs::read(path).ok()?;
let payload: OnDisk = serde_json::from_slice(&bytes).ok()?;
if payload.magic != FORMAT_MAGIC {
return None;
}
let now = unix_now();
let map: HashMap<HostPort, Vec<TicketEntry>> = payload
.entries
.into_iter()
.filter_map(|(k, v)| {
let fresh: Vec<TicketEntry> = v
.into_iter()
.filter(|t| now.saturating_sub(t.stored_at_unix_secs) < MAX_AGE.as_secs())
.collect();
if fresh.is_empty() {
None
} else {
Some((k, fresh))
}
})
.collect();
Some(map)
}
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn unix_now() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};
    use tempfile::tempdir;

    /// Serializes the tests in this module.
    ///
    /// Every `TicketCache` method re-reads `AUBE_DISABLE_TLS_TICKET_CACHE`
    /// through `is_disabled()`, and `killswitch_short_circuits` sets that
    /// process-wide variable. The default test harness runs tests on parallel
    /// threads, so without this lock another test could observe the
    /// killswitch, have its `put` silently ignored, and fail spuriously.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Takes [`ENV_LOCK`], ignoring poisoning: the guarded value is `()`, so a
    /// test that panicked while holding it leaves nothing inconsistent behind.
    fn env_lock() -> MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// A ticket stored "now" whose bytes are `[label, label + 1, label + 2]`.
    fn entry(label: u8) -> TicketEntry {
        TicketEntry {
            ticket: vec![label, label + 1, label + 2],
            spki_fp: [label; 32],
            stored_at_unix_secs: unix_now(),
        }
    }

    #[test]
    fn roundtrip_persists_across_open() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        {
            let cache = TicketCache::open(&path);
            cache.put("registry.npmjs.org", 443, entry(1));
            cache.save().unwrap();
        }
        let reopened = TicketCache::open(&path);
        let tickets = reopened.get("registry.npmjs.org", 443);
        assert_eq!(tickets.len(), 1);
        assert_eq!(tickets[0].ticket, vec![1, 2, 3]);
    }

    #[test]
    fn host_port_lowercases() {
        let a = HostPort::new("Registry.NPMJS.ORG", 443);
        let b = HostPort::new("registry.npmjs.org", 443);
        assert_eq!(a, b);
    }

    #[test]
    fn invalidate_removes_all_for_host() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("tickets.json"));
        cache.put("a.example", 443, entry(1));
        cache.put("a.example", 443, entry(2));
        assert_eq!(cache.len(), 2);
        cache.invalidate("a.example", 443);
        assert!(cache.is_empty());
    }

    #[test]
    fn max_per_host_evicts_oldest() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("tickets.json"));
        for i in 0..6u8 {
            cache.put("a.example", 443, entry(i));
        }
        let kept = cache.get("a.example", 443);
        assert_eq!(kept.len(), 4, "MAX_PER_HOST = 4");
        assert!(kept.iter().all(|t| t.ticket[0] >= 2));
    }

    #[test]
    fn stale_entries_filtered_at_load() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        {
            let cache = TicketCache::open(&path);
            let mut stale = entry(9);
            stale.stored_at_unix_secs = 0;
            cache.put("a.example", 443, stale);
            cache.save().unwrap();
        }
        let reopened = TicketCache::open(&path);
        assert!(reopened.get("a.example", 443).is_empty());
    }

    #[test]
    fn load_drops_stale_but_keeps_fresh() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        let mut stale = entry(9);
        stale.stored_at_unix_secs = 0;
        // Write the file directly so the load-time filter is exercised on its
        // own, independent of whatever `save` chooses to persist.
        let payload = OnDisk {
            magic: FORMAT_MAGIC.to_string(),
            entries: vec![(HostPort::new("a.example", 443), vec![stale, entry(1)])],
        };
        std::fs::write(&path, serde_json::to_vec(&payload).unwrap()).unwrap();
        let reopened = TicketCache::open(&path);
        assert_eq!(reopened.len(), 1, "stale ticket must be dropped, not just hidden");
        let kept = reopened.get("a.example", 443);
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0].ticket, vec![1, 2, 3]);
    }

    /// Removes `key` from the environment when dropped. Only construct it
    /// while holding [`ENV_LOCK`].
    struct EnvVarGuard {
        key: &'static str,
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            // SAFETY: constructed only while ENV_LOCK is held, so no other test
            // in this module reads the environment concurrently.
            unsafe { std::env::remove_var(self.key) };
        }
    }

    #[test]
    fn killswitch_short_circuits() {
        // Declared first so it is dropped last, i.e. after `_cleanup` has
        // removed the variable again.
        let _lock = env_lock();
        // SAFETY: ENV_LOCK keeps the other tests in this module from touching
        // the environment concurrently. NOTE(review): tests in other modules
        // of the same test binary are not covered by this lock.
        unsafe { std::env::set_var("AUBE_DISABLE_TLS_TICKET_CACHE", "1") };
        let _cleanup = EnvVarGuard {
            key: "AUBE_DISABLE_TLS_TICKET_CACHE",
        };
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("tickets.json"));
        cache.put("a.example", 443, entry(1));
        assert!(cache.get("a.example", 443).is_empty());
    }

    #[test]
    fn missing_file_loads_empty() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let cache = TicketCache::open(dir.path().join("nonexistent.json"));
        assert!(cache.is_empty());
    }

    #[test]
    fn corrupt_magic_loads_empty() {
        let _lock = env_lock();
        let dir = tempdir().unwrap();
        let path = dir.path().join("tickets.json");
        std::fs::write(&path, br#"{"magic":"wrong","entries":[]}"#).unwrap();
        let cache = TicketCache::open(&path);
        assert!(cache.is_empty());
    }
}