use crate::error::AcdpError;
#[cfg(feature = "client")]
use {
super::document::DidDocument,
crate::limits::{CONNECT_TIMEOUT, MAX_METADATA_BYTES, MAX_REDIRECTS, REQUEST_TIMEOUT},
crate::safe_http::SsrfPolicy,
lru::LruCache,
reqwest::redirect,
std::num::NonZeroUsize,
std::sync::{Arc, Mutex},
std::time::{Duration, Instant},
};
/// Maximum age of a cached DID document before `WebResolver::resolve`
/// fetches it again (24 hours).
#[cfg(feature = "client")]
const CACHE_MAX: Duration = Duration::from_secs(86_400);

/// Number of DID documents the resolver cache holds when no explicit
/// capacity is supplied.
#[cfg(feature = "client")]
const DEFAULT_CACHE_CAPACITY: usize = 1000;
/// A resolved DID document plus the instant it was stored in the cache.
#[cfg(feature = "client")]
struct CacheEntry {
    /// When the entry was inserted; compared against `CACHE_MAX` on lookup.
    cached_at: Instant,
    /// The parsed DID document returned to callers on a cache hit.
    doc: DidDocument,
}
#[cfg(feature = "client")]
/// Resolves `did:web` DIDs to DID documents over HTTPS, keeping recently
/// resolved documents in an in-memory LRU cache.
pub struct WebResolver {
    /// HTTP client configured with the SSRF-policy DNS resolver, the
    /// same-authority redirect policy and the connect/request timeouts.
    http: reqwest::Client,
    /// LRU cache keyed by the DID string; shared behind a mutex so `resolve`
    /// can take `&self`.
    cache: Arc<Mutex<LruCache<String, CacheEntry>>>,
    /// Policy checked against each document URL before any request is sent.
    ssrf_policy: SsrfPolicy,
    /// Extra trusted root certificate (PEM), retained so the HTTP client can be
    /// rebuilt with the same trust store by `with_ssrf_policy`.
    root_cert_pem: Option<Vec<u8>>,
}
#[cfg(feature = "client")]
impl WebResolver {
    /// Creates a resolver with `DEFAULT_CACHE_CAPACITY` cache slots and the
    /// default [`SsrfPolicy`].
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_CACHE_CAPACITY)
    }

    /// Creates a resolver whose document cache holds at most `capacity`
    /// entries, using the default [`SsrfPolicy`].
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero or if the HTTP client cannot be built.
    pub fn with_capacity(capacity: usize) -> Self {
        Self::from_parts(capacity, SsrfPolicy::default(), None)
            .expect("failed to build HTTP client for DID resolver")
    }

    /// Creates a resolver with `DEFAULT_CACHE_CAPACITY` cache slots whose HTTP
    /// client also trusts the root certificate in `pem`.
    ///
    /// # Errors
    ///
    /// Returns [`AcdpError::Http`] if `pem` cannot be parsed as a certificate
    /// or the HTTP client cannot be built.
    pub fn with_root_cert_pem(pem: &[u8]) -> Result<Self, AcdpError> {
        Self::from_parts(
            DEFAULT_CACHE_CAPACITY,
            SsrfPolicy::default(),
            Some(pem.to_vec()),
        )
    }

    /// Like [`WebResolver::with_root_cert_pem`], but with a cache of at most
    /// `capacity` entries.
    ///
    /// # Errors
    ///
    /// Returns [`AcdpError::Http`] if `pem` cannot be parsed as a certificate
    /// or the HTTP client cannot be built.
    ///
    /// # Panics
    ///
    /// Panics if `capacity` is zero.
    pub fn with_capacity_and_root_cert_pem(capacity: usize, pem: &[u8]) -> Result<Self, AcdpError> {
        Self::from_parts(capacity, SsrfPolicy::default(), Some(pem.to_vec()))
    }

    /// Test-only constructor: uses `SsrfPolicy::allow_test_loopback()`, trusts
    /// `pem` as an extra root certificate, and pins DNS for `host` to `target`.
    ///
    /// NOTE(review): calling `with_ssrf_policy` on the result rebuilds the
    /// client without the `host` -> `target` pin.
    #[doc(hidden)]
    pub fn with_test_endpoint(
        pem: &[u8],
        host: &str,
        target: std::net::SocketAddr,
    ) -> Result<Self, AcdpError> {
        let cap = NonZeroUsize::new(DEFAULT_CACHE_CAPACITY).expect("capacity > 0");
        let policy = SsrfPolicy::allow_test_loopback();
        let http = build_http_client_pinned(Some(pem), &policy, Some((host, target)))?;
        Ok(Self {
            http,
            cache: Arc::new(Mutex::new(LruCache::new(cap))),
            ssrf_policy: policy,
            root_cert_pem: Some(pem.to_vec()),
        })
    }

    /// Shared constructor: builds the HTTP client for `ssrf_policy` and
    /// `root_cert_pem` and an empty cache of `capacity` slots.
    ///
    /// Panics if `capacity` is zero.
    fn from_parts(
        capacity: usize,
        ssrf_policy: SsrfPolicy,
        root_cert_pem: Option<Vec<u8>>,
    ) -> Result<Self, AcdpError> {
        let cap = NonZeroUsize::new(capacity).expect("WebResolver capacity must be > 0");
        let http = build_http_client(root_cert_pem.as_deref(), &ssrf_policy)?;
        Ok(Self {
            http,
            cache: Arc::new(Mutex::new(LruCache::new(cap))),
            ssrf_policy,
            root_cert_pem,
        })
    }

    /// Replaces the SSRF policy, rebuilding the HTTP client so its DNS resolver
    /// enforces `policy` as well. The extra root certificate (if any) is kept;
    /// the document cache is left untouched.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be rebuilt.
    pub fn with_ssrf_policy(mut self, policy: SsrfPolicy) -> Self {
        let http = build_http_client(self.root_cert_pem.as_deref(), &policy)
            .expect("rebuild HTTP client for DID resolver");
        self.http = http;
        self.ssrf_policy = policy;
        self
    }

    /// Resolves the `did:web` DID `did` to its DID document.
    ///
    /// A cached document younger than `CACHE_MAX` is returned without network
    /// access. Otherwise the DID is mapped to its HTTPS URL with
    /// [`did_web_to_url`], the URL is checked against the SSRF policy, and the
    /// document is fetched (body capped at `MAX_METADATA_BYTES`), parsed as
    /// JSON and cached.
    ///
    /// # Errors
    ///
    /// - [`AcdpError::KeyResolution`] if `did` is not a valid `did:web` DID, the
    ///   SSRF policy rejects the URL, the server answers with a non-success
    ///   status, the body exceeds `MAX_METADATA_BYTES`, or the body does not
    ///   parse as a DID document.
    /// - [`AcdpError::KeyResolutionUnreachable`] if reading the body fails.
    /// - For failures while sending the request, the variant chosen by
    ///   `classify_reqwest_error` (timeouts/connect errors are
    ///   `KeyResolutionUnreachable`; SSRF refusals and others `KeyResolution`).
    #[cfg_attr(feature = "tracing", tracing::instrument(skip(self), fields(did = did)))]
    pub async fn resolve(&self, did: &str) -> Result<DidDocument, AcdpError> {
        // Cache lookup. The std mutex guard lives only inside this block so it
        // is never held across an `.await` below.
        {
            let mut cache = self.cache.lock().unwrap();
            if let Some(entry) = cache.get(did) {
                if entry.cached_at.elapsed() < CACHE_MAX {
                    return Ok(entry.doc.clone());
                }
                // Expired: evict it now instead of leaving a dead entry in a
                // slot until it is overwritten or aged out by the LRU.
                cache.pop(did);
            }
        }

        let url = did_web_to_url(did)?;
        // Refuse disallowed targets before sending anything; the client's
        // policy-aware DNS resolver additionally screens resolved addresses.
        self.ssrf_policy.check_url(&url).map_err(|e| {
            AcdpError::KeyResolution(format!("SSRF policy blocked did:web resolution: {e}"))
        })?;

        let mut resp = self
            .http
            .get(&url)
            .header("Accept", "application/did+json, application/json")
            .send()
            .await
            .map_err(|e| classify_reqwest_error(&e))?;
        if !resp.status().is_success() {
            return Err(AcdpError::KeyResolution(format!(
                "DID document fetch returned HTTP {}",
                resp.status()
            )));
        }

        // Fail fast on a declared oversize body. Compare in u64: casting `len`
        // to usize would truncate on 32-bit targets and let a huge declared
        // length pass this check. (usize -> u64 is lossless on every target.)
        if let Some(len) = resp.content_length() {
            if len > MAX_METADATA_BYTES as u64 {
                return Err(AcdpError::KeyResolution(format!(
                    "DID document Content-Length {len} exceeds {MAX_METADATA_BYTES}-byte cap"
                )));
            }
        }

        // Content-Length may be absent or untruthful, so enforce the cap while
        // streaming as well.
        let mut buf: Vec<u8> = Vec::with_capacity(8 * 1024);
        while let Some(chunk) = resp
            .chunk()
            .await
            .map_err(|e| AcdpError::KeyResolutionUnreachable(e.to_string()))?
        {
            if buf.len() + chunk.len() > MAX_METADATA_BYTES {
                return Err(AcdpError::KeyResolution(format!(
                    "DID document body exceeded {MAX_METADATA_BYTES}-byte cap"
                )));
            }
            buf.extend_from_slice(&chunk);
        }

        // NOTE(review): the parsed document's `id` is not compared with `did`
        // here; confirm that callers perform that binding check.
        let doc: DidDocument = serde_json::from_slice(&buf)
            .map_err(|e| AcdpError::KeyResolution(format!("DID document parse: {e}")))?;

        {
            let mut cache = self.cache.lock().unwrap();
            cache.put(
                did.to_string(),
                CacheEntry {
                    doc: doc.clone(),
                    cached_at: Instant::now(),
                },
            );
        }
        Ok(doc)
    }

    /// Drops any cached document for `did`, so the next
    /// [`WebResolver::resolve`] call fetches it again.
    pub fn invalidate(&self, did: &str) {
        self.cache.lock().unwrap().pop(did);
    }
}
#[cfg(feature = "client")]
impl Default for WebResolver {
    /// Builds a resolver with `DEFAULT_CACHE_CAPACITY` cache slots and the
    /// default SSRF policy (the same configuration as `WebResolver::new`).
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    fn default() -> Self {
        Self::with_capacity(DEFAULT_CACHE_CAPACITY)
    }
}
/// Builds the resolver's HTTP client without any host pinning, so every host
/// is looked up through the policy-aware DNS resolver.
///
/// Errors are those of `build_http_client_pinned`.
#[cfg(feature = "client")]
fn build_http_client(
    extra_root_pem: Option<&[u8]>,
    ssrf_policy: &SsrfPolicy,
) -> Result<reqwest::Client, AcdpError> {
    let no_pin: Option<(&str, std::net::SocketAddr)> = None;
    build_http_client_pinned(extra_root_pem, ssrf_policy, no_pin)
}
/// Builds the resolver's HTTP client.
///
/// The client uses rustls with the crate's connect/request timeouts and a DNS
/// resolver that enforces `ssrf_policy`. It follows at most `MAX_REDIRECTS`
/// redirects, and only while each hop stays on the same fetch authority as the
/// original URL. `extra_root_pem` adds one trusted root certificate; `pin`
/// (test use) overrides DNS for one host.
///
/// # Errors
///
/// Returns [`AcdpError::Http`] if the PEM is invalid or the client fails to build.
#[cfg(feature = "client")]
fn build_http_client_pinned(
    extra_root_pem: Option<&[u8]>,
    ssrf_policy: &SsrfPolicy,
    pin: Option<(&str, std::net::SocketAddr)>,
) -> Result<reqwest::Client, AcdpError> {
    let redirect_policy = redirect::Policy::custom(|attempt| {
        let hops = attempt.previous();
        if hops.len() >= MAX_REDIRECTS {
            return attempt.error(format!("DID resolver: exceeded {MAX_REDIRECTS} redirects"));
        }
        // `previous()` starts with the URL originally requested; every hop must
        // stay on that authority.
        if let Some(origin) = hops.first() {
            if !crate::safe_http::same_fetch_authority(origin, attempt.url()) {
                let rejection = format!(
                    "DID resolver: cross-authority redirect rejected ({} -> {})",
                    origin,
                    attempt.url()
                );
                return attempt.error(rejection);
            }
        }
        attempt.follow()
    });

    let mut builder = reqwest::Client::builder()
        .use_rustls_tls()
        .timeout(REQUEST_TIMEOUT)
        .connect_timeout(CONNECT_TIMEOUT)
        .redirect(redirect_policy)
        .dns_resolver(crate::safe_http::SafeDnsResolver::arc(ssrf_policy.clone()));

    if let Some(pem) = extra_root_pem {
        let root = reqwest::Certificate::from_pem(pem)
            .map_err(|e| AcdpError::Http(format!("invalid root cert PEM: {e}")))?;
        builder = builder.add_root_certificate(root);
    }

    if let Some((pinned_host, pinned_addr)) = pin {
        builder = builder.resolve(pinned_host, pinned_addr);
    }

    builder
        .build()
        .map_err(|e| AcdpError::Http(format!("DID resolver client build: {e}")))
}
/// Maps a `reqwest` send error to an [`AcdpError`], embedding the whole
/// source chain (`outer: inner: ...`) in the message.
///
/// Errors whose chain mentions "SSRF policy" are permanent
/// (`KeyResolution`), even though they may surface as connect failures.
/// Otherwise timeouts and connect errors are `KeyResolutionUnreachable`, and
/// anything else is `KeyResolution`.
#[cfg(feature = "client")]
fn classify_reqwest_error(e: &reqwest::Error) -> AcdpError {
    let chain = std::iter::successors(std::error::Error::source(e), |&cause| cause.source())
        .fold(e.to_string(), |mut text, cause| {
            text.push_str(": ");
            text.push_str(&cause.to_string());
            text
        });

    let is_ssrf_refusal = chain.contains("SSRF policy");
    let is_transient = !is_ssrf_refusal && (e.is_timeout() || e.is_connect());
    if is_transient {
        AcdpError::KeyResolutionUnreachable(chain)
    } else {
        AcdpError::KeyResolution(chain)
    }
}
/// Converts a `did:web` DID into the HTTPS URL of its DID document.
///
/// - `did:web:<authority>` -> `https://<authority>/.well-known/did.json`
/// - `did:web:<authority>:<seg1>:<seg2>` -> `https://<authority>/<seg1>/<seg2>/did.json`
///
/// The authority is percent-decoded, so a port is written as `%3A`
/// (`did:web:localhost%3A8443` -> `https://localhost:8443/.well-known/did.json`).
/// Path segments are used verbatim.
///
/// # Errors
///
/// Returns [`AcdpError::KeyResolution`] in any of these cases:
///
/// - `did` does not start with `did:web:`.
/// - The authority cannot be percent-decoded.
/// - The decoded authority is empty, or contains `/`, `\`, `?`, `#`, `@`,
///   whitespace or control characters. These would let the DID redirect the
///   fetch to a different host or resource, e.g. `trusted.com%40evil.com`
///   would otherwise fetch from `evil.com`.
/// - A path segment is empty or contains `/`, `\`, `?`, `#`, whitespace or
///   control characters.
pub fn did_web_to_url(did: &str) -> Result<String, AcdpError> {
    let rest = did
        .strip_prefix("did:web:")
        .ok_or_else(|| AcdpError::KeyResolution(format!("not a did:web DID: {did}")))?;

    // Characters that would change the URL's structure if spliced in raw: they
    // begin a new path, query or fragment section.
    let splits_url =
        |c: char| matches!(c, '/' | '\\' | '?' | '#') || c.is_whitespace() || c.is_control();

    let mut parts = rest.split(':');
    // `split` always yields at least one item; an empty authority is rejected below.
    let authority = urlencoding::decode(parts.next().unwrap_or_default())
        .map_err(|e| AcdpError::KeyResolution(format!("authority decode: {e}")))?;
    // '@' is also forbidden in the authority: it would turn the prefix into
    // userinfo and make the text after it the real host.
    if authority.is_empty() || authority.chars().any(|c| c == '@' || splits_url(c)) {
        return Err(AcdpError::KeyResolution(format!(
            "invalid did:web authority in {did}"
        )));
    }

    let segments: Vec<&str> = parts.collect();
    if segments
        .iter()
        .any(|seg| seg.is_empty() || seg.chars().any(splits_url))
    {
        return Err(AcdpError::KeyResolution(format!(
            "invalid did:web path segment in {did}"
        )));
    }

    if segments.is_empty() {
        Ok(format!("https://{authority}/.well-known/did.json"))
    } else {
        Ok(format!("https://{authority}/{}/did.json", segments.join("/")))
    }
}
/// Builds a `did:web` DID for `authority`, percent-encoding every `:` as
/// `%3A` (e.g. `localhost:8443` -> `did:web:localhost%3A8443`).
pub fn authority_to_did_web(authority: &str) -> String {
    let mut did = String::from("did:web:");
    for ch in authority.chars() {
        if ch == ':' {
            did.push_str("%3A");
        } else {
            did.push(ch);
        }
    }
    did
}
/// Extracts the `host[:port]` authority from a `did:web` DID, undoing the
/// `%3A` port-colon encoding produced by [`authority_to_did_web`]. Any path
/// segments after the authority are ignored.
///
/// Percent-encoding hex digits are case-insensitive (RFC 3986 §2.1). Both
/// `%3A` and `%3a` are therefore decoded to `:`, matching the authority that
/// [`did_web_to_url`] derives for the same DID.
///
/// Returns `None` if `did` is not a `did:web` DID or its authority is empty.
pub fn did_web_to_authority(did: &str) -> Option<String> {
    let rest = did.strip_prefix("did:web:")?;
    let authority = rest.split(':').next().filter(|a| !a.is_empty())?;
    Some(authority.replace("%3A", ":").replace("%3a", ":"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A DID with no path maps to the well-known location.
    #[test]
    fn bare_authority() {
        let expected = "https://example.com/.well-known/did.json";
        assert_eq!(did_web_to_url("did:web:example.com").unwrap(), expected);
    }

    /// Colon-separated path segments become URL path segments.
    #[test]
    fn path_authority() {
        let expected = "https://example.com/users/alice/did.json";
        assert_eq!(
            did_web_to_url("did:web:example.com:users:alice").unwrap(),
            expected
        );
    }

    #[test]
    fn authority_to_did_web_bare_hostname() {
        let did = authority_to_did_web("registry.example.com");
        assert_eq!(did, "did:web:registry.example.com");
    }

    #[test]
    fn authority_to_did_web_with_port() {
        let did = authority_to_did_web("localhost:8443");
        assert_eq!(did, "did:web:localhost%3A8443");
    }

    /// Encoding then decoding an authority returns the original text.
    #[test]
    fn did_web_to_authority_round_trips() {
        let samples = ["registry.example.com", "localhost:8443", "127.0.0.1:9000"];
        for authority in samples.iter().copied() {
            let did = authority_to_did_web(authority);
            let back = match did_web_to_authority(&did) {
                Some(found) => found,
                None => panic!("did_web_to_authority returned None for '{did}'"),
            };
            assert_eq!(back, authority, "round-trip for '{authority}' failed");
        }
    }

    /// The `%3A`-encoded port survives conversion back to a URL.
    #[test]
    fn authority_to_did_web_then_to_url_keeps_port() {
        let url = did_web_to_url(&authority_to_did_web("localhost:8443")).unwrap();
        assert_eq!(url, "https://localhost:8443/.well-known/did.json");
    }

    #[cfg(feature = "client")]
    #[tokio::test]
    async fn did_resolver_rejects_loopback_did() {
        let err = WebResolver::new()
            .resolve("did:web:127.0.0.1")
            .await
            .unwrap_err();
        let blocked = matches!(err, AcdpError::KeyResolution(_));
        assert!(
            blocked,
            "did-ssrf-001: loopback did:web MUST be blocked by SSRF policy, got {err:?}"
        );
    }

    #[cfg(feature = "client")]
    #[tokio::test]
    async fn did_resolver_rejects_imds_did() {
        let err = WebResolver::new()
            .resolve("did:web:169.254.169.254")
            .await
            .unwrap_err();
        let blocked = matches!(err, AcdpError::KeyResolution(_));
        assert!(
            blocked,
            "did-ssrf-002: IMDS did:web MUST be blocked by SSRF policy, got {err:?}"
        );
    }

    #[cfg(feature = "client")]
    #[tokio::test]
    async fn did_resolver_rejects_private_range_did() {
        let resolver = WebResolver::new();
        let private_dids = [
            "did:web:192.168.1.1",
            "did:web:10.0.0.1",
            "did:web:172.16.0.1",
        ];
        for did in private_dids {
            let err = resolver.resolve(did).await.unwrap_err();
            let blocked = matches!(err, AcdpError::KeyResolution(_));
            assert!(
                blocked,
                "did-ssrf-003: private-range did:web '{did}' MUST be blocked, got {err:?}"
            );
        }
    }

    /// A hostname that resolves to loopback must be refused at DNS time with a
    /// permanent, SSRF-labelled error.
    #[cfg(feature = "client")]
    #[tokio::test]
    async fn did_resolver_rejects_hostname_resolving_to_loopback() {
        let resolver = WebResolver::new();
        let outcome = resolver.resolve("did:web:localhost%3A12345").await;
        let err = outcome
            .expect_err("DNS-rebinding protection MUST refuse localhost under default policy");
        let msg = err.to_string();
        let permanent = matches!(err, AcdpError::KeyResolution(_));
        assert!(
            permanent,
            "DNS-rebinding refusal MUST be permanent KeyResolution, got {err:?}"
        );
        assert!(
            msg.contains("SSRF policy"),
            "DNS-rebinding refusal MUST identify the SSRF policy in its message; got: {msg}"
        );
        assert!(!err.is_transient(), "SSRF refusal MUST NOT be transient");
    }
}