use std::collections::HashSet;
use std::sync::Arc;
use std::time::{Duration, Instant};
use anyhow::{Context, Result};
use tokio::sync::RwLock;
use tracing::{debug, info};
use crate::http_client::AcceleratedClient;
/// Issues warm-up requests to hosts and remembers which hosts were already warmed.
pub struct PrefetchManager {
/// Hosts recorded as warmed; `preconnect` adds a host once its request has returned a response.
warmed: Arc<RwLock<HashSet<String>>>,
/// HTTP client used to send the preconnect `HEAD` requests.
client: AcceleratedClient,
}
impl PrefetchManager {
    /// Creates a manager with an empty warmed-host set and a new [`AcceleratedClient`].
    ///
    /// # Errors
    ///
    /// Propagates the error returned by [`AcceleratedClient::new`].
    pub fn new() -> Result<Self> {
        Ok(Self {
            warmed: Arc::new(RwLock::new(HashSet::new())),
            client: AcceleratedClient::new()?,
        })
    }

    /// Returns `true` when `host` already begins with an explicit `http://` or
    /// `https://` scheme (compared ASCII case-insensitively).
    fn has_http_scheme(host: &str) -> bool {
        let starts_with_ignore_case = |prefix: &str| {
            // `get` returns `None` when `host` is shorter than the prefix or the
            // cut would land inside a multi-byte character.
            host.get(..prefix.len())
                .is_some_and(|head| head.eq_ignore_ascii_case(prefix))
        };
        starts_with_ignore_case("http://") || starts_with_ignore_case("https://")
    }

    /// Warms up the connection to `host` by sending a `HEAD` request.
    ///
    /// `host` may be a bare host name or a full `http://` / `https://` URL; a
    /// bare host is prefixed with `https://`. Only a real scheme prefix counts
    /// as a URL, so hosts that merely start with the letters "http" (for
    /// example `httpbin.org`) are still given a scheme.
    ///
    /// If `host` is already in the warmed set, no request is sent and
    /// `Duration::ZERO` is returned. Otherwise the host is recorded as warmed
    /// once a response arrives (whatever its status) and the elapsed time
    /// since the call started is returned.
    ///
    /// # Errors
    ///
    /// Returns an error if sending the request fails; the request is given a
    /// 5-second timeout.
    pub async fn preconnect(&self, host: &str) -> Result<Duration> {
        let start = Instant::now();

        // The read guard is a temporary and is released at the end of this statement.
        let already_warmed = self.warmed.read().await.contains(host);
        if already_warmed {
            debug!("Host already warmed: {}", host);
            return Ok(Duration::ZERO);
        }

        info!("Preconnecting to {}", host);
        let url = if Self::has_http_scheme(host) {
            host.to_string()
        } else {
            format!("https://{host}")
        };

        let response = self
            .client
            .inner()
            .head(&url)
            .timeout(Duration::from_secs(5))
            .send()
            .await
            .with_context(|| format!("Preconnect HEAD request failed for {host}"))?;
        let elapsed = start.elapsed();

        {
            let mut warmed = self.warmed.write().await;
            warmed.insert(host.to_string());
        }

        info!(
            "Preconnected to {} in {:?} (status: {})",
            host,
            elapsed,
            response.status()
        );
        Ok(elapsed)
    }

    /// Preconnects to every entry of `hosts` concurrently.
    ///
    /// Returns one `(host, result)` pair per input, in input order; a failure
    /// for one host does not affect the others.
    pub async fn preconnect_many(&self, hosts: &[&str]) -> Vec<(String, Result<Duration>)> {
        let futures: Vec<_> = hosts
            .iter()
            .map(|host| {
                let host = (*host).to_string();
                let manager = self;
                async move { (host.clone(), manager.preconnect(&host).await) }
            })
            .collect();
        futures::future::join_all(futures).await
    }

    /// Returns whether `host` is currently recorded as warmed.
    pub async fn is_warmed(&self, host: &str) -> bool {
        self.warmed.read().await.contains(host)
    }

    /// Forgets every warmed host.
    pub async fn clear(&self) {
        self.warmed.write().await.clear();
    }
}
impl Default for PrefetchManager {
fn default() -> Self {
Self::new().expect("Failed to create prefetch manager")
}
}
/// Resource hints parsed from HTTP `Link` header values.
#[derive(Debug, Clone)]
pub struct EarlyHints {
    /// Every link that was parsed successfully, in input order.
    pub links: Vec<EarlyHintLink>,
}

/// A single resource hint: a target plus the link parameters relevant to prefetching.
#[derive(Debug, Clone)]
pub struct EarlyHintLink {
    /// Link target (for a `Link` header, the text between `<` and `>`).
    pub url: String,
    /// The `rel` value as written; it may hold several space-separated relation types.
    pub rel: String,
    /// The `as` value, if present.
    pub as_type: Option<String>,
    /// The `crossorigin` value, if present.
    pub crossorigin: Option<String>,
}

impl EarlyHints {
    /// Parses each entry of `link_headers` as one `Link` header value.
    ///
    /// Entries with an empty target or without a `rel` parameter are skipped.
    #[must_use]
    pub fn parse(link_headers: &[&str]) -> Self {
        let links = link_headers
            .iter()
            .filter_map(|header| Self::parse_link(header))
            .collect();
        Self { links }
    }

    /// Parses a single `<target>; key=value; ...` link value.
    ///
    /// Returns `None` when the target is empty or no non-empty `rel` is given.
    /// Parameter names are matched case-insensitively and surrounding double
    /// quotes are removed from values. A bare `crossorigin` parameter is
    /// recorded as `"anonymous"`.
    fn parse_link(header: &str) -> Option<EarlyHintLink> {
        let header = header.trim();

        // The target may itself contain `;` (e.g. `/api;v=1/data.json`), so find
        // the closing `>` first and only split what follows into parameters.
        let bracketed = header
            .strip_prefix('<')
            .and_then(|rest| rest.split_once('>'));
        let (target, params) = match bracketed {
            Some(pair) => pair,
            // Lenient fallback for targets without a well-formed `<...>` wrapper:
            // everything before the first `;` is the target.
            None => {
                let (raw, params) = header.split_once(';').unwrap_or((header, ""));
                (
                    raw.trim().trim_start_matches('<').trim_end_matches('>'),
                    params,
                )
            }
        };
        if target.is_empty() {
            return None;
        }

        let mut rel = String::new();
        let mut as_type = None;
        let mut crossorigin = None;
        for param in params.split(';') {
            let (key, value) = match param.split_once('=') {
                Some((key, value)) => (key, Some(value.trim().trim_matches('"').to_string())),
                None => (param, None),
            };
            match key.trim().to_ascii_lowercase().as_str() {
                "rel" => rel = value.unwrap_or_default(),
                "as" => as_type = value,
                "crossorigin" => {
                    crossorigin = Some(value.unwrap_or_else(|| "anonymous".to_string()));
                }
                _ => {}
            }
        }
        if rel.is_empty() {
            return None;
        }

        Some(EarlyHintLink {
            url: target.to_string(),
            rel,
            as_type,
            crossorigin,
        })
    }

    /// Returns the links whose `rel` lists `relation` among its relation types.
    ///
    /// `rel` is a space-separated list and relation types compare
    /// case-insensitively, so each token is checked on its own.
    fn with_rel(&self, relation: &str) -> Vec<&EarlyHintLink> {
        self.links
            .iter()
            .filter(|link| {
                link.rel
                    .split_ascii_whitespace()
                    .any(|token| token.eq_ignore_ascii_case(relation))
            })
            .collect()
    }

    /// Links with the `preload` relation type.
    #[must_use]
    pub fn preloads(&self) -> Vec<&EarlyHintLink> {
        self.with_rel("preload")
    }

    /// Links with the `preconnect` relation type.
    #[must_use]
    pub fn preconnects(&self) -> Vec<&EarlyHintLink> {
        self.with_rel("preconnect")
    }

    /// Links with the `dns-prefetch` relation type.
    #[must_use]
    pub fn dns_prefetches(&self) -> Vec<&EarlyHintLink> {
        self.with_rel("dns-prefetch")
    }
}
/// Scans `html` for `<link ...>` tags and returns those that look like resource hints.
///
/// A tag is kept when it has both an `href` and a `rel` attribute and its
/// `rel` value contains `preconnect`, `dns-prefetch`, `preload` or `prefetch`
/// (compared ASCII case-insensitively). The returned `rel` is stored as
/// written in the document.
///
/// This is a lightweight textual scan rather than an HTML parser: a tag is
/// taken to end at the first `>` after `<link`.
#[must_use]
pub fn extract_link_hints(html: &str) -> Vec<EarlyHintLink> {
    const HINT_RELS: [&str; 4] = ["preconnect", "dns-prefetch", "preload", "prefetch"];

    let mut links = Vec::new();
    // ASCII-only lowercasing never changes byte lengths, so offsets found in
    // `haystack` are valid indices into `html`. Full Unicode lowercasing can
    // change the length of some characters and would misalign the two strings.
    let haystack = html.to_ascii_lowercase();
    let mut pos = 0;

    while let Some(found) = haystack[pos..].find("<link") {
        let tag_start = pos + found;
        let Some(tag_len) = haystack[tag_start..].find('>') else {
            // Unterminated tag: nothing further can be parsed.
            break;
        };
        let tag_end = tag_start + tag_len;
        pos = tag_end + 1;

        let tag = &html[tag_start..=tag_end];
        let (Some(url), Some(rel)) = (extract_attr(tag, "href"), extract_attr(tag, "rel")) else {
            continue;
        };

        let rel_lower = rel.to_ascii_lowercase();
        if HINT_RELS.iter().any(|&hint| rel_lower.contains(hint)) {
            links.push(EarlyHintLink {
                url,
                rel,
                as_type: extract_attr(tag, "as"),
                crossorigin: extract_attr(tag, "crossorigin"),
            });
        }
    }
    links
}
/// Returns the value of attribute `attr` in the HTML tag text `tag`.
///
/// The tag is scanned attribute by attribute, so `attr` only matches a whole
/// attribute name (ASCII case-insensitively) and never text inside another
/// attribute's name or value. Values may be double-quoted, single-quoted or
/// unquoted (ending at whitespace or `>`); whitespace around `=` is allowed.
/// When the attribute appears more than once, the first occurrence that has a
/// value wins.
///
/// Returns `None` when the attribute is absent, appears only without a value,
/// or when a quoted value is never closed.
fn extract_attr(tag: &str, attr: &str) -> Option<String> {
    fn ends_name(c: char) -> bool {
        c.is_whitespace() || c == '=' || c == '>' || c == '/'
    }
    fn ends_unquoted_value(c: char) -> bool {
        c.is_whitespace() || c == '>'
    }
    fn is_separator(c: char) -> bool {
        c.is_whitespace() || c == '/'
    }

    // Skip `<` and the element name so scanning starts at the attribute list.
    // Text that does not start with `<` is treated as a bare attribute list.
    let mut rest = match tag.strip_prefix('<') {
        Some(after) => {
            let name_len = after
                .find(|c: char| is_separator(c) || c == '>')
                .unwrap_or(after.len());
            &after[name_len..]
        }
        None => tag,
    };

    loop {
        rest = rest.trim_start_matches(is_separator);
        if rest.is_empty() || rest.starts_with('>') {
            return None;
        }

        let name_len = rest.find(ends_name).unwrap_or(rest.len());
        let name = &rest[..name_len];
        rest = rest[name_len..].trim_start();

        // A bare attribute (no `=`) carries no value; keep scanning.
        let after_eq = match rest.strip_prefix('=') {
            Some(after) => after.trim_start(),
            None => continue,
        };

        let (value, remaining) = match after_eq.chars().next() {
            // `name=` at the very end of the text has no value to return.
            None => return None,
            Some(quote) if quote == '"' || quote == '\'' => {
                let content = &after_eq[1..];
                // An unterminated quote makes the rest of the tag unparseable.
                let end = content.find(quote)?;
                (&content[..end], &content[end + 1..])
            }
            Some(_) => {
                let end = after_eq
                    .find(ends_unquoted_value)
                    .unwrap_or(after_eq.len());
                (&after_eq[..end], &after_eq[end..])
            }
        };

        if name.eq_ignore_ascii_case(attr) {
            return Some(value.to_string());
        }
        rest = remaining;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Records `host` as warmed without touching the network.
    async fn mark_warmed(manager: &PrefetchManager, host: &str) {
        manager.warmed.write().await.insert(host.to_string());
    }

    // ---- `Link` header parsing ----

    /// Three well-formed headers yield three links, split by relation type.
    #[test]
    fn test_parse_link_header() {
        let parsed = EarlyHints::parse(&[
            "</style.css>; rel=preload; as=style",
            "</script.js>; rel=preload; as=script; crossorigin",
            "<https://cdn.example.com>; rel=preconnect",
        ]);
        assert_eq!(parsed.links.len(), 3);
        assert_eq!(parsed.preloads().len(), 2);
        assert_eq!(parsed.preconnects().len(), 1);
    }

    /// No headers in, no links out.
    #[test]
    fn test_parse_link_header_empty() {
        let parsed = EarlyHints::parse(&[]);
        assert!(parsed.links.is_empty());
    }

    /// A header without a `rel` parameter is dropped.
    #[test]
    fn test_parse_link_header_missing_rel() {
        let parsed = EarlyHints::parse(&["</style.css>; as=style"]);
        assert!(parsed.links.is_empty(), "link without rel should be skipped");
    }

    /// A header whose target is `<>` is dropped.
    #[test]
    fn test_parse_link_header_empty_url() {
        let parsed = EarlyHints::parse(&["<>; rel=preload"]);
        assert!(parsed.links.is_empty(), "empty URL should be skipped");
    }

    /// Double quotes around parameter values are stripped.
    #[test]
    fn test_parse_link_header_quoted_values() {
        let parsed = EarlyHints::parse(&[
            "</font.woff2>; rel=\"preload\"; as=\"font\"; crossorigin=\"anonymous\"",
        ]);
        assert_eq!(parsed.links.len(), 1);
        let link = &parsed.links[0];
        assert_eq!(link.rel, "preload");
        assert_eq!(link.as_type.as_deref(), Some("font"));
        assert_eq!(link.crossorigin.as_deref(), Some("anonymous"));
    }

    /// A valueless `crossorigin` parameter is reported as `anonymous`.
    #[test]
    fn test_parse_link_header_crossorigin_bare() {
        let parsed = EarlyHints::parse(&["</script.js>; rel=preload; crossorigin"]);
        assert_eq!(parsed.links.len(), 1);
        assert_eq!(
            parsed.links[0].crossorigin.as_deref(),
            Some("anonymous"),
            "bare crossorigin should default to 'anonymous'"
        );
    }

    /// `dns-prefetch` links are reported separately from `preconnect` links.
    #[test]
    fn test_dns_prefetches() {
        let parsed = EarlyHints::parse(&[
            "<//cdn.example.com>; rel=dns-prefetch",
            "<//img.example.com>; rel=dns-prefetch",
            "<https://api.example.com>; rel=preconnect",
        ]);
        assert_eq!(parsed.dns_prefetches().len(), 2);
        assert_eq!(parsed.preconnects().len(), 1);
    }

    // ---- HTML `<link>` extraction ----

    /// Only the three hint links are returned; the stylesheet link is ignored.
    #[test]
    fn test_extract_link_hints() {
        let document = r#"
<head>
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="dns-prefetch" href="//cdn.example.com">
<link rel="preload" href="/main.js" as="script">
<link rel="stylesheet" href="/style.css">
</head>
"#;
        let found = extract_link_hints(document);
        assert_eq!(found.len(), 3);
    }

    /// Empty input yields no hints.
    #[test]
    fn test_extract_link_hints_empty_html() {
        assert!(extract_link_hints("").is_empty());
    }

    /// A document without `<link>` tags yields no hints.
    #[test]
    fn test_extract_link_hints_no_link_tags() {
        let document = "<html><head><title>Test</title></head><body></body></html>";
        assert!(extract_link_hints(document).is_empty());
    }

    /// `rel="prefetch"` counts as a hint.
    #[test]
    fn test_extract_link_hints_prefetch_rel() {
        let found = extract_link_hints(r#"<link rel="prefetch" href="/next-page.js">"#);
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].rel, "prefetch");
    }

    /// `as` and `crossorigin` attributes are carried over to the hint.
    #[test]
    fn test_extract_link_hints_with_crossorigin() {
        let found = extract_link_hints(
            r#"<link rel="preload" href="/font.woff2" as="font" crossorigin="anonymous">"#,
        );
        assert_eq!(found.len(), 1);
        let link = &found[0];
        assert_eq!(link.crossorigin.as_deref(), Some("anonymous"));
        assert_eq!(link.as_type.as_deref(), Some("font"));
    }

    /// Unquoted attribute values end at whitespace or `>`.
    #[test]
    fn test_extract_attr_unquoted() {
        let markup = r"<link rel=preconnect href=https://cdn.example.com>";
        assert_eq!(extract_attr(markup, "rel"), Some("preconnect".to_string()));
        assert_eq!(
            extract_attr(markup, "href"),
            Some("https://cdn.example.com".to_string())
        );
    }

    /// An attribute that is not present yields `None`.
    #[test]
    fn test_extract_attr_missing() {
        let markup = r#"<link rel="preconnect" href="https://cdn.example.com">"#;
        assert_eq!(extract_attr(markup, "as"), None);
    }

    // ---- `PrefetchManager` bookkeeping (no network access) ----

    /// A host that was never preconnected is not warmed.
    #[tokio::test]
    async fn test_is_warmed_unknown_host() {
        let manager = PrefetchManager::new().unwrap();
        let warmed = manager.is_warmed("never-connected.example.com").await;
        assert!(!warmed);
    }

    /// `clear` drops previously warmed hosts.
    #[tokio::test]
    async fn test_clear_removes_warmed_state() {
        let manager = PrefetchManager::new().unwrap();
        mark_warmed(&manager, "test-host.example.com").await;
        assert!(manager.is_warmed("test-host.example.com").await);

        manager.clear().await;
        assert!(
            !manager.is_warmed("test-host.example.com").await,
            "cleared host should no longer be warmed"
        );
    }

    /// Preconnecting to an already-warmed host short-circuits with a zero duration.
    #[tokio::test]
    async fn test_preconnect_already_warmed_returns_zero() {
        let manager = PrefetchManager::new().unwrap();
        mark_warmed(&manager, "pre-warmed.example.com").await;

        let outcome = manager.preconnect("pre-warmed.example.com").await;
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), Duration::ZERO);
    }
}