use dashmap::DashMap;
use std::{
ops::Deref,
sync::Arc,
time::{Duration, Instant},
};
use trillium_server_common::url::{Origin, Url};
pub const DEFAULT_BROKEN_DURATION: Duration = Duration::from_secs(300);
#[derive(Debug, Clone, Default)]
pub struct AltSvcCache(Arc<DashMap<Origin, AltSvcEntry>>);
impl Deref for AltSvcCache {
type Target = DashMap<Origin, AltSvcEntry>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl AltSvcCache {
pub(crate) fn update(&self, alt_svc: &str, url: &Url) {
let origin = url.origin();
let origin_host = url.host_str().unwrap_or("");
match parse_alt_svc_h3(alt_svc, origin_host) {
None => {
self.remove(&origin);
}
Some(mut entries) => {
if let Some((host, port, max_age)) = entries.next() {
self.insert(origin, AltSvcEntry::new(host, port, max_age));
}
}
}
}
}
#[derive(Debug, Clone)]
pub struct AltSvcEntry {
pub host: String,
pub port: u16,
expires: Instant,
broken_until: Option<Instant>,
}
impl AltSvcEntry {
pub fn new(host: String, port: u16, max_age: Duration) -> Self {
Self {
host,
port,
expires: Instant::now() + max_age,
broken_until: None,
}
}
pub fn is_usable(&self) -> bool {
let now = Instant::now();
now < self.expires && self.broken_until.is_none_or(|t| now >= t)
}
pub fn mark_broken(&mut self, duration: Duration) {
self.broken_until = Some(Instant::now() + duration);
}
}
pub fn parse_alt_svc_h3<'a>(
value: &'a str,
origin_host: &'a str,
) -> Option<impl Iterator<Item = (String, u16, Duration)> + 'a> {
if value.trim().eq_ignore_ascii_case("clear") {
return None;
}
Some(value.split(',').filter_map(move |entry| {
let entry = entry.trim();
let (alpn, rest) = entry.split_once('=')?;
if !alpn.trim().eq_ignore_ascii_case("h3") {
return None;
}
let rest = rest.trim().strip_prefix('"')?;
let (alt_authority, params) = rest.split_once('"')?;
let (host, port_str) = alt_authority.rsplit_once(':')?;
let port = port_str.parse::<u16>().ok()?;
let host = if host.is_empty() {
origin_host.to_string()
} else {
host.to_string()
};
let max_age = parse_max_age(params).unwrap_or(Duration::from_secs(86400));
Some((host, port, max_age))
}))
}
fn parse_max_age(params: &str) -> Option<Duration> {
for param in params.split(';') {
if let Some(val) = param.trim().strip_prefix("ma=")
&& let Ok(secs) = val.trim().parse::<u64>()
{
return Some(Duration::from_secs(secs));
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_same_host() {
let entries: Vec<_> = parse_alt_svc_h3(r#"h3=":443"; ma=86400"#, "example.com")
.unwrap()
.collect();
assert_eq!(
entries,
[("example.com".into(), 443, Duration::from_secs(86400))]
);
}
#[test]
fn parse_different_host() {
let entries: Vec<_> = parse_alt_svc_h3(r#"h3="alt.example.com:8443""#, "example.com")
.unwrap()
.collect();
assert_eq!(
entries,
[("alt.example.com".into(), 8443, Duration::from_secs(86400))]
);
}
#[test]
fn skip_other_alpns() {
let entries: Vec<_> = parse_alt_svc_h3(r#"h2=":443", h3=":443"; ma=3600"#, "example.com")
.unwrap()
.collect();
assert_eq!(
entries,
[("example.com".into(), 443, Duration::from_secs(3600))]
);
}
#[test]
fn clear_returns_none() {
assert!(parse_alt_svc_h3("clear", "example.com").is_none());
assert!(parse_alt_svc_h3(" CLEAR ", "example.com").is_none());
}
#[test]
fn multiple_h3_entries_yields_all() {
let entries: Vec<_> =
parse_alt_svc_h3(r#"h3=":443"; ma=3600, h3=":8443"; ma=600"#, "example.com")
.unwrap()
.collect();
assert_eq!(entries.len(), 2);
assert_eq!(entries[0].1, 443);
assert_eq!(entries[1].1, 8443);
}
}