use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// One cached HTTP/3 alternative for an origin.
#[derive(Debug, Clone)]
struct Entry {
    /// The advertised alt-authority with surrounding quotes removed (e.g. `:443`).
    // Nothing in this file reads this field (derived `Debug`/`Clone` impls do
    // not count as uses for the `dead_code` lint), hence the allowance.
    #[allow(dead_code)]
    alt_authority: String,
    /// Instant after which this alternative is no longer considered fresh.
    expires: Instant,
}
/// Cache of HTTP/3 alternatives learned from `Alt-Svc` headers, keyed by
/// origin `(host, port)`.
///
/// The map lives behind `Arc<Mutex<..>>`, so clones are cheap and all clones
/// observe and mutate the same shared state.
#[derive(Debug, Clone, Default)]
pub(crate) struct AltSvcCache {
    inner: Arc<Mutex<HashMap<(String, u16), Entry>>>,
}
impl AltSvcCache {
pub(crate) fn record(&self, host: &str, port: u16, header_value: &str) {
let trimmed = header_value.trim();
if trimmed.eq_ignore_ascii_case("clear") {
self.inner
.lock()
.unwrap_or_else(|err| err.into_inner())
.remove(&(host.to_owned(), port));
return;
}
let mut best: Option<Entry> = None;
for token in trimmed.split(',') {
let token = token.trim();
if token.is_empty() {
continue;
}
let mut parts = token.splitn(2, ';');
let proto_authority = parts.next().unwrap_or("").trim();
let params = parts.next().unwrap_or("").trim();
let (proto_id, alt_authority) = match proto_authority.split_once('=') {
Some((p, a)) => (p.trim(), unquote(a.trim())),
None => continue,
};
if !proto_id.eq_ignore_ascii_case("h3")
&& !proto_id.to_ascii_lowercase().starts_with("h3-")
{
continue;
}
let ma_secs = parse_max_age(params).unwrap_or(86_400);
let expires = Instant::now() + Duration::from_secs(ma_secs);
let candidate = Entry {
alt_authority: alt_authority.to_owned(),
expires,
};
if best
.as_ref()
.map(|b| candidate.expires > b.expires)
.unwrap_or(true)
{
best = Some(candidate);
}
}
if let Some(entry) = best {
self.inner
.lock()
.unwrap_or_else(|err| err.into_inner())
.insert((host.to_owned(), port), entry);
}
}
pub(crate) fn has_h3(&self, host: &str, port: u16) -> bool {
let mut map = self.inner.lock().unwrap_or_else(|err| err.into_inner());
let key = (host.to_owned(), port);
if let Some(entry) = map.get(&key) {
if Instant::now() < entry.expires {
return true;
}
map.remove(&key);
}
false
}
pub(crate) fn clear(&self) {
self.inner
.lock()
.unwrap_or_else(|err| err.into_inner())
.clear();
}
}
/// Removes one pair of surrounding double quotes from `s`, if present.
///
/// Only a matching leading *and* trailing `"` is stripped; anything else,
/// including a lone `"`, is returned unchanged. Escape sequences inside the
/// quotes are not decoded.
fn unquote(s: &str) -> &str {
    match s.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')) {
        Some(inner) => inner,
        None => s,
    }
}
/// Extracts the `ma` (max-age, seconds) parameter from the `;`-separated
/// parameter list of a single Alt-Svc alternative.
///
/// RFC 7838's grammar allows a parameter value to be either a token
/// (`ma=3600`) or a quoted-string (`ma="3600"`); both are accepted here. The
/// parameter name is matched ASCII case-insensitively and whitespace around
/// the name and value is ignored.
///
/// Returns `None` when no `ma` parameter is present, or when the first `ma`
/// value is not a valid `u64`, leaving the choice of default to the caller.
fn parse_max_age(params: &str) -> Option<u64> {
    params
        .split(';')
        .filter_map(|param| param.split_once('='))
        .find(|(name, _)| name.trim().eq_ignore_ascii_case("ma"))
        .and_then(|(_, value)| {
            let value = value.trim();
            // Strip one pair of surrounding quotes (quoted-string form).
            let value = value
                .strip_prefix('"')
                .and_then(|v| v.strip_suffix('"'))
                .unwrap_or(value);
            value.parse().ok()
        })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn records_h3_entry_and_detects_it() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        assert!(cache.has_h3("example.com", 443));
    }

    #[test]
    fn ignores_non_h3_alternatives() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h2=":443"; ma=3600"#);
        assert!(!cache.has_h3("example.com", 443));
    }

    #[test]
    fn clear_directive_removes_entry() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        cache.record("example.com", 443, "clear");
        assert!(!cache.has_h3("example.com", 443));
    }

    #[test]
    fn clear_directive_is_case_insensitive_and_trimmed() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        cache.record("example.com", 443, "  CLEAR ");
        assert!(!cache.has_h3("example.com", 443));
    }

    #[test]
    fn expired_entry_is_not_returned() {
        let cache = AltSvcCache::default();
        // `ma=0` makes the entry expire at the instant it is recorded, and
        // `Instant` is monotonic, so the next lookup already sees it as stale
        // without sleeping.
        cache.record("example.com", 443, r#"h3=":443"; ma=0"#);
        assert!(!cache.has_h3("example.com", 443));
    }

    #[test]
    fn empty_header_keeps_existing_entry() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        cache.record("example.com", 443, "   ");
        assert!(cache.has_h3("example.com", 443));
    }

    #[test]
    fn draft_h3_version_is_accepted() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3-29=":443"; ma=86400"#);
        assert!(cache.has_h3("example.com", 443));
    }

    #[test]
    fn multiple_alternatives_picks_longest_lived() {
        let cache = AltSvcCache::default();
        cache.record(
            "example.com",
            443,
            r#"h3-29=":443"; ma=60, h3=":443"; ma=86400"#,
        );
        assert!(cache.has_h3("example.com", 443));
    }

    #[test]
    fn different_ports_are_independent() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        assert!(cache.has_h3("example.com", 443));
        assert!(!cache.has_h3("example.com", 8443));
    }

    #[test]
    fn different_hosts_are_independent() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        assert!(!cache.has_h3("example.org", 443));
    }

    #[test]
    fn clear_method_empties_every_origin() {
        let cache = AltSvcCache::default();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        cache.record("example.org", 8443, r#"h3=":8443"; ma=3600"#);
        cache.clear();
        assert!(!cache.has_h3("example.com", 443));
        assert!(!cache.has_h3("example.org", 8443));
    }

    #[test]
    fn clones_share_state() {
        let cache = AltSvcCache::default();
        let other = cache.clone();
        cache.record("example.com", 443, r#"h3=":443"; ma=3600"#);
        assert!(other.has_h3("example.com", 443));
    }

    #[test]
    fn unquote_strips_double_quotes() {
        assert_eq!(unquote(r#"":443""#), ":443");
        assert_eq!(unquote(":443"), ":443");
    }

    #[test]
    fn unquote_leaves_unbalanced_quotes_alone() {
        assert_eq!(unquote("\""), "\"");
        assert_eq!(unquote(r#""""#), "");
        assert_eq!(unquote("\"abc"), "\"abc");
    }

    #[test]
    fn parse_max_age_extracts_value() {
        assert_eq!(parse_max_age("ma=3600"), Some(3600));
        assert_eq!(parse_max_age("ma=0"), Some(0));
        assert_eq!(parse_max_age(""), None);
    }

    #[test]
    fn parse_max_age_skips_other_params_and_rejects_garbage() {
        assert_eq!(parse_max_age("persist=1; ma=60"), Some(60));
        assert_eq!(parse_max_age("persist=1"), None);
        assert_eq!(parse_max_age("ma=abc"), None);
    }
}