use std::collections::HashMap;
use std::time::{Duration, Instant};
/// One alternative service, in the shape `parse_alt_svc` produces for each
/// entry of an `Alt-Svc` header value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AltSvc {
/// Protocol identifier: the text before the entry's first `=` (e.g. `h3`, `h3-29`).
pub protocol_id: String,
/// Host part of the alternative authority. Empty when the entry gives only a
/// port, as in `h3=":443"`.
pub host: String,
/// Port part of the alternative authority.
pub port: u16,
/// How long the alternative may be used: the `ma` parameter when present and
/// parseable, otherwise the parser's default.
pub max_age: Duration,
}
/// Lifetime in seconds (86400 s = 24 h) that `parse_max_age` falls back to when
/// an entry carries no parseable `ma=` parameter.
const DEFAULT_MAX_AGE_SECS: u64 = 86400;
/// Parses an `Alt-Svc` header value into the alternative services it lists.
///
/// Surrounding whitespace is ignored. The exact value `clear` produces an empty
/// list, and so does a value in which no entry can be parsed; callers that must
/// tell those two cases apart have to test for `clear` themselves. Entries that
/// fail to parse are dropped individually instead of rejecting the whole header.
#[must_use]
pub fn parse_alt_svc(value: &str) -> Vec<AltSvc> {
    let header = value.trim();
    if header == "clear" {
        return Vec::new();
    }
    split_entries(header)
        .into_iter()
        .filter_map(|raw| parse_single_entry(raw.trim()))
        .collect()
}
/// Splits a header value into its comma-separated entries.
///
/// Commas inside a double-quoted string do not separate entries. Within a quoted
/// string a backslash escapes the following character, so `\"` does not end the
/// string. Pieces are returned untrimmed; an empty piece between two commas is
/// kept, while an empty piece after a trailing comma is dropped.
fn split_entries(value: &str) -> Vec<&str> {
    let mut entries = Vec::new();
    let mut start = 0;
    let mut in_quotes = false;
    // Set right after a backslash inside quotes: the next character is literal.
    let mut escaped = false;
    for (i, ch) in value.char_indices() {
        if escaped {
            escaped = false;
            continue;
        }
        match ch {
            '\\' if in_quotes => escaped = true,
            '"' => in_quotes = !in_quotes,
            ',' if !in_quotes => {
                entries.push(&value[start..i]);
                // `,` is a single byte, so the next entry starts right after it.
                start = i + 1;
            }
            _ => {}
        }
    }
    if start < value.len() {
        entries.push(&value[start..]);
    }
    entries
}
/// Parses one `protocol-id=authority[; parameters]` entry.
///
/// Returns `None` when the entry has no `=`, when the protocol identifier is
/// empty, or when the authority has no valid `:port` suffix. Quotes around the
/// authority are stripped; a missing or unparseable `ma` parameter yields the
/// default lifetime.
fn parse_single_entry(entry: &str) -> Option<AltSvc> {
    let (proto_id, rest) = entry.split_once('=')?;
    let protocol_id = proto_id.trim();
    // An entry such as `=":443"` names no protocol, so no lookup could ever use it.
    if protocol_id.is_empty() {
        return None;
    }
    // Everything before the first `;` is the authority; the remainder holds parameters.
    let (authority, params) = rest.split_once(';').unwrap_or((rest, ""));
    let (host, port) = parse_authority(authority.trim().trim_matches('"'))?;
    Some(AltSvc {
        protocol_id: protocol_id.to_string(),
        host,
        port,
        max_age: parse_max_age(params),
    })
}
/// Splits an authority of the form `[host]:port` at its last `:`.
///
/// Returns the (possibly empty) host and the port, or `None` when there is no
/// `:` or the text after it is not a decimal number that fits in a `u16`.
fn parse_authority(authority: &str) -> Option<(String, u16)> {
    let (host, port) = authority.rsplit_once(':')?;
    // `str::parse` would accept a leading `+`, which is not valid in a port.
    if !port.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    let port = port.parse::<u16>().ok()?;
    Some((host.to_string(), port))
}
/// Extracts the lifetime from the `;`-separated parameters of an entry.
///
/// The first `ma=` parameter whose value is a plain decimal number that fits in
/// a `u64` wins. When there is none, the lifetime is `DEFAULT_MAX_AGE_SECS`.
fn parse_max_age(params: &str) -> Duration {
    params
        .split(';')
        .filter_map(|param| param.trim().strip_prefix("ma="))
        .map(str::trim)
        // `str::parse` would accept a leading `+`; only bare digits are a valid value.
        .filter(|digits| digits.bytes().all(|b| b.is_ascii_digit()))
        .find_map(|digits| digits.parse::<u64>().ok())
        .map_or(Duration::from_secs(DEFAULT_MAX_AGE_SECS), Duration::from_secs)
}
/// Parses the delay-seconds form of a `Retry-After` header value.
///
/// Surrounding whitespace is ignored. Returns `None` for anything that is not a
/// plain decimal number fitting in a `u64`; in particular the HTTP-date form is
/// not supported.
#[must_use]
pub fn parse_retry_after(value: &str) -> Option<Duration> {
    let digits = value.trim();
    // `str::parse` would accept a leading `+`, which is not a valid delay value.
    if !digits.bytes().all(|b| b.is_ascii_digit()) {
        return None;
    }
    digits.parse::<u64>().ok().map(Duration::from_secs)
}
/// A cached alternative service together with the instant it stops being usable.
#[derive(Debug, Clone)]
struct AltSvcEntry {
/// The alternative service as it was handed to `AltSvcCache::store`.
alt_svc: AltSvc,
/// Store time plus the service's `max_age`; lookups only return the entry
/// while the current instant is strictly before this.
expires_at: Instant,
}
/// In-memory cache of alternative services, keyed by the origin string passed
/// to `store`.
#[derive(Debug)]
pub struct AltSvcCache {
/// Origin string -> services stored for it, each with its own expiry.
entries: HashMap<String, Vec<AltSvcEntry>>,
}
impl AltSvcCache {
#[must_use]
pub fn new() -> Self {
Self { entries: HashMap::new() }
}
pub fn store(&mut self, origin: &str, services: &[AltSvc]) {
let now = Instant::now();
let entries: Vec<AltSvcEntry> = services
.iter()
.map(|svc| AltSvcEntry { alt_svc: svc.clone(), expires_at: now + svc.max_age })
.collect();
let _ = self.entries.insert(origin.to_string(), entries);
}
pub fn clear_origin(&mut self, origin: &str) {
let _ = self.entries.remove(origin);
}
#[must_use]
pub fn get(&self, origin: &str) -> Vec<&AltSvc> {
let now = Instant::now();
self.entries
.get(origin)
.map(|entries| {
entries.iter().filter(|e| now < e.expires_at).map(|e| &e.alt_svc).collect()
})
.unwrap_or_default()
}
#[must_use]
pub fn get_protocol(&self, origin: &str, protocol_id: &str) -> Option<&AltSvc> {
let now = Instant::now();
self.entries.get(origin).and_then(|entries| {
entries
.iter()
.find(|e| now < e.expires_at && e.alt_svc.protocol_id == protocol_id)
.map(|e| &e.alt_svc)
})
}
pub fn purge_expired(&mut self) {
let now = Instant::now();
self.entries.retain(|_, entries| {
entries.retain(|e| now < e.expires_at);
!entries.is_empty()
});
}
pub fn clear(&mut self) {
self.entries.clear();
}
#[must_use]
pub fn len(&self) -> usize {
self.entries.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
impl Default for AltSvcCache {
fn default() -> Self {
Self::new()
}
}
impl Clone for AltSvcCache {
    /// Copies the whole origin map into an independent cache.
    fn clone(&self) -> Self {
        let entries = self.entries.clone();
        Self { entries }
    }
}
#[cfg(test)]
// Tests may unwrap: a panic is exactly how a failed expectation should surface.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Origin key shared by most cache tests.
    const ORIGIN: &str = "https://example.com:443";
    /// Lifetime for entries that must still be fresh when queried.
    const ONE_HOUR: Duration = Duration::from_secs(3600);

    /// Builds a host-less alternative on port 443 for `protocol_id`.
    fn service(protocol_id: &str, max_age: Duration) -> AltSvc {
        AltSvc { protocol_id: protocol_id.to_string(), host: String::new(), port: 443, max_age }
    }

    // ---- Alt-Svc header parsing ----

    #[test]
    fn parse_clear() {
        assert!(parse_alt_svc("clear").is_empty());
    }

    #[test]
    fn parse_single_h3() {
        let parsed = parse_alt_svc(r#"h3=":443""#);
        assert_eq!(parsed.len(), 1);
        let alt = &parsed[0];
        assert_eq!(alt.protocol_id, "h3");
        assert_eq!(alt.host, "");
        assert_eq!(alt.port, 443);
        assert_eq!(alt.max_age, Duration::from_secs(DEFAULT_MAX_AGE_SECS));
    }

    #[test]
    fn parse_with_max_age() {
        let parsed = parse_alt_svc(r#"h3=":443"; ma=2592000"#);
        assert_eq!(parsed.len(), 1);
        let alt = &parsed[0];
        assert_eq!(alt.protocol_id, "h3");
        assert_eq!(alt.port, 443);
        assert_eq!(alt.max_age, Duration::from_secs(2_592_000));
    }

    #[test]
    fn parse_multiple_entries() {
        let parsed = parse_alt_svc(r#"h3=":443"; ma=2592000, h2=":443""#);
        let protocols: Vec<&str> = parsed.iter().map(|alt| alt.protocol_id.as_str()).collect();
        assert_eq!(protocols, ["h3", "h2"]);
    }

    #[test]
    fn parse_with_host() {
        let parsed = parse_alt_svc(r#"h3="alt.example.com:8443""#);
        assert_eq!(parsed.len(), 1);
        let alt = &parsed[0];
        assert_eq!(alt.host, "alt.example.com");
        assert_eq!(alt.port, 8443);
    }

    #[test]
    fn parse_versioned_protocol() {
        let parsed = parse_alt_svc(r#"h3-29=":443""#);
        assert_eq!(parsed.len(), 1);
        assert_eq!(parsed[0].protocol_id, "h3-29");
    }

    #[test]
    fn parse_empty_string() {
        assert!(parse_alt_svc("").is_empty());
    }

    #[test]
    fn parse_whitespace() {
        assert!(parse_alt_svc(" ").is_empty());
    }

    #[test]
    fn parse_invalid_port() {
        assert!(parse_alt_svc(r#"h3=":notaport""#).is_empty());
    }

    #[test]
    fn parse_no_port() {
        assert!(parse_alt_svc(r#"h3="noport""#).is_empty());
    }

    // ---- Retry-After parsing ----

    #[test]
    fn retry_after_seconds() {
        assert_eq!(parse_retry_after("120"), Some(Duration::from_secs(120)));
    }

    #[test]
    fn retry_after_zero() {
        assert_eq!(parse_retry_after("0"), Some(Duration::from_secs(0)));
    }

    #[test]
    fn retry_after_with_whitespace() {
        assert_eq!(parse_retry_after(" 60 "), Some(Duration::from_secs(60)));
    }

    #[test]
    fn retry_after_invalid() {
        assert_eq!(parse_retry_after("not a number"), None);
    }

    #[test]
    fn retry_after_http_date_not_supported() {
        assert_eq!(parse_retry_after("Fri, 31 Dec 1999 23:59:59 GMT"), None);
    }

    // ---- Parser helpers ----

    #[test]
    fn split_entries_basic() {
        assert_eq!(split_entries(r#"h3=":443", h2=":443""#).len(), 2);
    }

    #[test]
    fn split_entries_with_quoted_comma() {
        assert_eq!(split_entries(r#"h3="host,name:443""#).len(), 1);
    }

    #[test]
    fn parse_authority_basic() {
        let (host, port) = parse_authority(":443").unwrap();
        assert_eq!((host.as_str(), port), ("", 443));
    }

    #[test]
    fn parse_authority_with_host() {
        let (host, port) = parse_authority("example.com:8080").unwrap();
        assert_eq!((host.as_str(), port), ("example.com", 8080));
    }

    // ---- Cache behaviour ----

    #[test]
    fn cache_new_is_empty() {
        let cache = AltSvcCache::new();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }

    #[test]
    fn cache_store_and_get() {
        let mut cache = AltSvcCache::new();
        let services = vec![service("h3", ONE_HOUR)];
        cache.store(ORIGIN, &services);
        let found = cache.get(ORIGIN);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].protocol_id, "h3");
        assert_eq!(found[0].port, 443);
    }

    #[test]
    fn cache_get_missing_returns_empty() {
        let cache = AltSvcCache::new();
        assert!(cache.get(ORIGIN).is_empty());
    }

    #[test]
    fn cache_get_protocol() {
        let mut cache = AltSvcCache::new();
        let services = vec![service("h3", ONE_HOUR), service("h2", ONE_HOUR)];
        cache.store(ORIGIN, &services);
        assert!(cache.get_protocol(ORIGIN, "h3").is_some());
        assert!(cache.get_protocol(ORIGIN, "h2").is_some());
        assert!(cache.get_protocol(ORIGIN, "h1").is_none());
    }

    #[test]
    fn cache_expired_entries_not_returned() {
        let mut cache = AltSvcCache::new();
        // A zero lifetime expires at the store instant, so it is never fresh.
        let services = vec![service("h3", Duration::ZERO)];
        cache.store(ORIGIN, &services);
        assert!(cache.get(ORIGIN).is_empty());
        assert!(cache.get_protocol(ORIGIN, "h3").is_none());
    }

    #[test]
    fn cache_clear_origin() {
        let mut cache = AltSvcCache::new();
        let services = vec![service("h3", ONE_HOUR)];
        cache.store(ORIGIN, &services);
        assert_eq!(cache.len(), 1);
        cache.clear_origin(ORIGIN);
        assert!(cache.is_empty());
    }

    #[test]
    fn cache_clear_all() {
        let mut cache = AltSvcCache::new();
        for origin in ["https://a.com:443", "https://b.com:443"] {
            cache.store(origin, &[service("h3", ONE_HOUR)]);
        }
        assert_eq!(cache.len(), 2);
        cache.clear();
        assert!(cache.is_empty());
    }

    #[test]
    fn cache_purge_expired() {
        let mut cache = AltSvcCache::new();
        cache.store("https://expired.com:443", &[service("h3", Duration::ZERO)]);
        cache.store("https://valid.com:443", &[service("h3", ONE_HOUR)]);
        assert_eq!(cache.len(), 2);
        cache.purge_expired();
        assert_eq!(cache.len(), 1);
        assert!(!cache.get("https://valid.com:443").is_empty());
    }

    #[test]
    fn cache_default_is_empty() {
        assert!(AltSvcCache::default().is_empty());
    }

    #[test]
    fn cache_clone() {
        let mut cache = AltSvcCache::new();
        cache.store(ORIGIN, &[service("h3", ONE_HOUR)]);
        let copy = cache.clone();
        assert_eq!(copy.len(), 1);
    }

    // ---- Header-to-cache round trips ----

    #[test]
    fn cache_h3_upgrade_from_header() {
        let parsed = parse_alt_svc(r#"h3=":443"; ma=86400"#);
        assert_eq!(parsed.len(), 1);
        let mut cache = AltSvcCache::new();
        cache.store(ORIGIN, &parsed);
        let h3 = cache.get_protocol(ORIGIN, "h3");
        assert!(h3.is_some());
        assert_eq!(h3.unwrap().port, 443);
    }

    #[test]
    fn cache_h3_upgrade_different_port() {
        let parsed = parse_alt_svc(r#"h3=":8443"; ma=86400"#);
        let mut cache = AltSvcCache::new();
        cache.store(ORIGIN, &parsed);
        let h3 = cache.get_protocol(ORIGIN, "h3");
        assert!(h3.is_some());
        assert_eq!(h3.unwrap().port, 8443);
    }

    #[test]
    fn cache_h3_upgrade_not_available_for_http() {
        let parsed = parse_alt_svc(r#"h3=":443"; ma=86400"#);
        let mut cache = AltSvcCache::new();
        // Stored under the plain-HTTP origin, so the HTTPS origin must not see it.
        cache.store("http://example.com:80", &parsed);
        assert!(cache.get_protocol(ORIGIN, "h3").is_none());
    }

    #[test]
    fn cache_clear_disables_h3_upgrade() {
        let parsed = parse_alt_svc(r#"h3=":443"; ma=86400"#);
        let mut cache = AltSvcCache::new();
        cache.store(ORIGIN, &parsed);
        assert!(cache.get_protocol(ORIGIN, "h3").is_some());
        let cleared = parse_alt_svc("clear");
        assert!(cleared.is_empty());
        cache.clear_origin(ORIGIN);
        assert!(cache.get_protocol(ORIGIN, "h3").is_none());
    }

    #[test]
    fn cache_h3_upgrade_with_h2_fallback() {
        let parsed = parse_alt_svc(r#"h3=":443"; ma=86400, h2=":443"; ma=86400"#);
        let mut cache = AltSvcCache::new();
        cache.store(ORIGIN, &parsed);
        assert!(cache.get_protocol(ORIGIN, "h3").is_some());
        assert!(cache.get_protocol(ORIGIN, "h2").is_some());
    }
}