use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use crate::client::{RegistryClient, VerifiedContext};
use crate::did::WebResolver;
use crate::error::AcdpError;
use crate::safe_http::SsrfPolicy;
use crate::types::body::Body;
use crate::types::primitives::CtxId;
use crate::types::CapabilitiesDocument;
/// Tunable limits applied by `CrossRegistryResolver` while it resolves
/// contexts and walks `derived_from` ancestry.
#[derive(Debug, Clone)]
pub struct ResolverOptions {
    /// Deepest ancestry level the walk will accept. Direct parents of the
    /// root are at depth 1; reaching a node deeper than this fails the walk.
    pub max_depth: usize,
    /// Upper bound on how many contexts a single walk may resolve.
    pub max_nodes: usize,
    /// Largest `derived_from` list tolerated on any one context (root
    /// included); a longer list fails the walk.
    pub max_fanout: usize,
    /// Wall-clock budget for one complete `walk_derived_from` call.
    pub total_timeout: Duration,
    /// Ceiling on how long a cached capabilities document is reused; the
    /// effective TTL is the smaller of this and the TTL the registry reports.
    pub capabilities_ttl: Duration,
}
impl Default for ResolverOptions {
    /// Stock limits: depth 10, 100 nodes, fanout 32, a 30-second walk
    /// budget, and a 5-minute cap on capabilities caching.
    fn default() -> Self {
        let total_timeout = Duration::from_secs(30);
        let capabilities_ttl = Duration::from_secs(300);
        Self {
            max_depth: 10,
            max_nodes: 100,
            max_fanout: 32,
            total_timeout,
            capabilities_ttl,
        }
    }
}
/// Resolves contexts that may live on registries other than the caller's
/// own, caching one client and one capabilities document per authority.
pub struct CrossRegistryResolver {
    /// Resolver used to fetch registry DID documents (and handed on to
    /// `VerifiedContext::fetch`).
    did_resolver: WebResolver,
    /// Limits for ancestry walks and capabilities caching.
    options: ResolverOptions,
    /// When `Some`, only these authorities may be resolved; `None` applies
    /// no allowlist restriction.
    allowlist: Option<HashSet<String>>,
    /// Policy every registry base URL is checked against before use.
    ssrf_policy: SsrfPolicy,
    /// Registry clients keyed by authority.
    client_cache: Mutex<HashMap<String, RegistryClient>>,
    /// Capabilities documents keyed by authority, each stored with the
    /// instant it was cached and the TTL it stays fresh for.
    caps_cache: Mutex<HashMap<String, (CapabilitiesDocument, Instant, Duration)>>,
}
impl Default for CrossRegistryResolver {
    /// Same as [`CrossRegistryResolver::new`]; delegating keeps the two
    /// constructors from drifting apart.
    fn default() -> Self {
        Self::new()
    }
}
impl CrossRegistryResolver {
pub fn new() -> Self {
Self {
did_resolver: WebResolver::new(),
options: ResolverOptions::default(),
allowlist: None,
ssrf_policy: SsrfPolicy::default(),
client_cache: Mutex::new(HashMap::new()),
caps_cache: Mutex::new(HashMap::new()),
}
}
pub fn with_ssrf_policy(mut self, policy: SsrfPolicy) -> Self {
self.ssrf_policy = policy;
self
}
/// Builder: overrides only `options.max_depth`, leaving every other limit
/// as it was, and returns the updated resolver.
pub fn with_max_depth(self, depth: usize) -> Self {
    let mut resolver = self;
    resolver.options.max_depth = depth;
    resolver
}
pub fn with_options(mut self, options: ResolverOptions) -> Self {
self.options = options;
self
}
pub fn options(&self) -> &ResolverOptions {
&self.options
}
/// Builder: replaces the `WebResolver` used for DID document lookups and
/// returns the updated resolver.
pub fn with_did_resolver(self, resolver: WebResolver) -> Self {
    let mut updated = self;
    updated.did_resolver = resolver;
    updated
}
/// Stores `client` in the client cache under `authority`, replacing any
/// client already cached for it. Later lookups for that authority return
/// the seeded client instead of building a new one.
///
/// # Panics
///
/// Panics if the client-cache mutex is poisoned.
pub fn seed_client(&self, authority: impl Into<String>, client: RegistryClient) {
    let mut clients = self.client_cache.lock().unwrap();
    clients.insert(authority.into(), client);
}
pub fn with_allowlist<I, S>(mut self, authorities: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
self.allowlist = Some(authorities.into_iter().map(Into::into).collect());
self
}
/// Resolves a single context from whichever registry its id names.
///
/// The steps run in this order, and the first failure aborts the call:
///
/// 1. Re-parse `ctx_id` and take its authority.
/// 2. Reject the authority if an allowlist is configured and does not
///    contain it.
/// 3. Check `https://{authority}` against the SSRF policy.
/// 4. Obtain the (possibly cached) registry client and capabilities
///    document for the authority.
/// 5. Require `capabilities.registry_did` to equal the DID derived from the
///    authority by `crate::did::authority_to_did_web`.
/// 6. Resolve that DID and require the returned document's `id` to match.
/// 7. Hand over to `VerifiedContext::fetch` for the context itself.
///
/// # Errors
///
/// Returns `AcdpError::CrossRegistryResolutionFailed` for allowlist, SSRF,
/// DID-mismatch, and DID-resolution failures. Errors from `CtxId::parse`,
/// client construction, the capabilities fetch, and `VerifiedContext::fetch`
/// are propagated as produced by those calls.
pub async fn resolve(&self, ctx_id: &CtxId) -> Result<VerifiedContext, AcdpError> {
    let parsed = CtxId::parse(ctx_id.as_str())?;
    let authority = parsed.authority().to_string();
    self.check_allowlist(&authority)?;

    // The SSRF check runs before any client is built or reused for this
    // authority.
    let base = format!("https://{authority}");
    self.ssrf_policy
        .check_url(&base)
        .map_err(|e| AcdpError::CrossRegistryResolutionFailed(format!("SSRF policy: {e}")))?;

    let registry = self.client_for(&authority, &base).await?;
    let caps = self.capabilities_for(&authority, &registry).await?;

    // The registry must claim exactly the DID implied by its own authority.
    let expected_did = crate::did::authority_to_did_web(&authority);
    if caps.registry_did != expected_did {
        return Err(AcdpError::CrossRegistryResolutionFailed(format!(
            "registry DID '{}' does not match expected '{expected_did}'",
            caps.registry_did
        )));
    }

    let registry_doc = self
        .did_resolver
        .resolve(&caps.registry_did)
        .await
        .map_err(|e| {
            AcdpError::CrossRegistryResolutionFailed(format!(
                "could not resolve registry DID document for '{}': {e}",
                caps.registry_did
            ))
        })?;
    if registry_doc.id != caps.registry_did {
        return Err(AcdpError::CrossRegistryResolutionFailed(format!(
            "registry DID document `id` '{}' does not match capabilities.registry_did '{}'",
            registry_doc.id, caps.registry_did
        )));
    }

    VerifiedContext::fetch(&registry, &self.did_resolver, &parsed).await
}
/// Resolves the `derived_from` ancestry of `body`, bounded by
/// `options.total_timeout`.
///
/// # Errors
///
/// Returns `AcdpError::CrossRegistryResolutionFailed` if the walk does not
/// finish within the timeout; otherwise returns whatever the inner walk
/// produced, success or error.
pub async fn walk_derived_from(&self, body: &Body) -> Result<Vec<VerifiedContext>, AcdpError> {
    let budget = self.options.total_timeout;
    tokio::time::timeout(budget, self.walk_derived_from_inner(body))
        .await
        .unwrap_or_else(|_elapsed| {
            Err(AcdpError::CrossRegistryResolutionFailed(format!(
                "derived_from walk exceeded total_timeout={:?}",
                budget
            )))
        })
}
/// Breadth-first walk over the `derived_from` graph rooted at `body`.
///
/// Each ancestor is resolved via `resolve` at most once and returned in
/// the order it was dequeued; the root itself is not part of the result.
/// Ids already visited (the root's included) are skipped, so cycles and
/// shared ancestors do not cause repeat resolution.
///
/// # Errors
///
/// Fails with `AcdpError::CrossRegistryResolutionFailed` when any context
/// lists more than `max_fanout` parents, when a node deeper than
/// `max_depth` is reached, or when a further node would exceed
/// `max_nodes`. Any error from `resolve` aborts the walk and is propagated.
async fn walk_derived_from_inner(
    &self,
    body: &Body,
) -> Result<Vec<VerifiedContext>, AcdpError> {
    let limits = &self.options;

    // The root counts as visited so a cycle back to it is never resolved.
    let mut visited: HashSet<String> = HashSet::new();
    visited.insert(body.ctx_id.0.clone());

    let root_fanout = body.derived_from.len();
    if root_fanout > limits.max_fanout {
        return Err(AcdpError::CrossRegistryResolutionFailed(format!(
            "root context {} has derived_from fanout {} > max_fanout={}",
            body.ctx_id.0, root_fanout, limits.max_fanout
        )));
    }

    let mut resolved: Vec<VerifiedContext> = Vec::new();
    // Direct parents of the root start at depth 1.
    let mut queue: VecDeque<(CtxId, usize)> = body
        .derived_from
        .iter()
        .cloned()
        .map(|id| (id, 1usize))
        .collect();

    while let Some((current, depth)) = queue.pop_front() {
        // Ids are marked visited on dequeue, so the queue may hold
        // duplicates; drop them here.
        let first_visit = visited.insert(current.0.clone());
        if !first_visit {
            continue;
        }

        if depth > limits.max_depth {
            return Err(AcdpError::CrossRegistryResolutionFailed(format!(
                "derived_from walk exceeded max_depth={} at {}",
                limits.max_depth, current.0
            )));
        }
        if resolved.len() >= limits.max_nodes {
            return Err(AcdpError::CrossRegistryResolutionFailed(format!(
                "derived_from walk exceeded max_nodes={} (last attempted: {})",
                limits.max_nodes, current.0
            )));
        }

        let verified = self.resolve(&current).await?;
        let parents = &verified.body().derived_from;
        let fanout = parents.len();
        if fanout > limits.max_fanout {
            return Err(AcdpError::CrossRegistryResolutionFailed(format!(
                "context {} has derived_from fanout {} > max_fanout={}",
                current.0, fanout, limits.max_fanout
            )));
        }

        queue.extend(
            parents
                .iter()
                .filter(|parent| !visited.contains(parent.as_str()))
                .map(|parent| (parent.clone(), depth + 1)),
        );
        resolved.push(verified);
    }

    Ok(resolved)
}
/// Accepts `authority` when no allowlist is configured or when the
/// allowlist contains it exactly; otherwise returns
/// `AcdpError::CrossRegistryResolutionFailed`.
fn check_allowlist(&self, authority: &str) -> Result<(), AcdpError> {
    let permitted = self
        .allowlist
        .as_ref()
        .map_or(true, |allowed| allowed.contains(authority));
    if permitted {
        Ok(())
    } else {
        Err(AcdpError::CrossRegistryResolutionFailed(format!(
            "authority '{authority}' is not on the resolver allowlist"
        )))
    }
}
/// Returns the registry client for `authority`, building a pinned client
/// for `base` on a cache miss.
///
/// The cache lock is released before the client is built, so two callers
/// may both build one; whichever is inserted first is kept and returned to
/// both.
///
/// # Panics
///
/// Panics if the client-cache mutex is poisoned.
async fn client_for(&self, authority: &str, base: &str) -> Result<RegistryClient, AcdpError> {
    let cached = {
        let guard = self.client_cache.lock().unwrap();
        guard.get(authority).cloned()
    };
    if let Some(existing) = cached {
        return Ok(existing);
    }

    let fresh = RegistryClient::new_pinned(base, &self.ssrf_policy).await?;

    let mut guard = self.client_cache.lock().unwrap();
    let slot = guard.entry(authority.to_string()).or_insert(fresh);
    Ok(slot.clone())
}
/// Returns the capabilities document for `authority`, reusing a cached
/// copy while it is younger than its stored TTL and fetching from
/// `registry` otherwise.
///
/// A freshly fetched document is cached with the smaller of the TTL the
/// registry reported and `options.capabilities_ttl`.
///
/// # Errors
///
/// `AcdpError::Http` and `AcdpError::KeyResolutionUnreachable` from the
/// fetch are rewrapped as `AcdpError::CrossRegistryResolutionFailed`; every
/// other error is passed through unchanged.
///
/// # Panics
///
/// Panics if the capabilities-cache mutex is poisoned.
async fn capabilities_for(
    &self,
    authority: &str,
    registry: &RegistryClient,
) -> Result<CapabilitiesDocument, AcdpError> {
    // Scope the guard so the lock is released before the `.await` below.
    let still_fresh = {
        let guard = self.caps_cache.lock().unwrap();
        let entry = guard.get(authority);
        match entry {
            Some((caps, fetched_at, ttl)) if fetched_at.elapsed() < *ttl => Some(caps.clone()),
            _ => None,
        }
    };
    if let Some(caps) = still_fresh {
        return Ok(caps);
    }

    let (caps, response_ttl) = registry.capabilities_with_ttl().await.map_err(|e| {
        let unreachable = matches!(
            e,
            AcdpError::Http(_) | AcdpError::KeyResolutionUnreachable(_)
        );
        if unreachable {
            AcdpError::CrossRegistryResolutionFailed(format!(
                "could not reach registry '{authority}': {e}"
            ))
        } else {
            e
        }
    })?;

    let effective_ttl = response_ttl.min(self.options.capabilities_ttl);
    let mut guard = self.caps_cache.lock().unwrap();
    let entry = (caps.clone(), Instant::now(), effective_ttl);
    guard.insert(authority.to_string(), entry);
    Ok(caps)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An authority outside the allowlist is rejected; one on it passes.
    #[test]
    fn allowlist_rejects_outside_authorities() {
        let resolver =
            CrossRegistryResolver::new().with_allowlist(["registry.example.com".to_string()]);
        let err = resolver.check_allowlist("evil.com").unwrap_err();
        assert!(matches!(err, AcdpError::CrossRegistryResolutionFailed(_)));
        resolver.check_allowlist("registry.example.com").unwrap();
    }

    /// Pins the default limits so an accidental change is caught.
    #[test]
    fn options_default_values_match_doc() {
        let o = ResolverOptions::default();
        assert_eq!(o.max_depth, 10);
        assert_eq!(o.max_nodes, 100);
        assert_eq!(o.max_fanout, 32);
        assert_eq!(o.total_timeout, Duration::from_secs(30));
        assert_eq!(o.capabilities_ttl, Duration::from_secs(300));
    }

    /// `with_options` must replace every field, not just the walk limits.
    #[test]
    fn with_options_replaces_full_struct() {
        let r = CrossRegistryResolver::new().with_options(ResolverOptions {
            max_depth: 3,
            max_nodes: 7,
            max_fanout: 2,
            total_timeout: Duration::from_secs(5),
            capabilities_ttl: Duration::from_secs(60),
        });
        assert_eq!(r.options().max_depth, 3);
        assert_eq!(r.options().max_nodes, 7);
        assert_eq!(r.options().max_fanout, 2);
        // All five values differ from the defaults, so these two also prove
        // the durations were replaced rather than left untouched.
        assert_eq!(r.options().total_timeout, Duration::from_secs(5));
        assert_eq!(r.options().capabilities_ttl, Duration::from_secs(60));
    }

    /// Checks the `HashSet::insert` contract the walk's cycle detection
    /// relies on: the second insert of the same id reports "already present".
    ///
    /// NOTE(review): this does not drive the resolver's walk itself; a test
    /// that feeds a cyclic `Body` graph through `walk_derived_from` would
    /// cover the real code path.
    #[test]
    fn cycle_detection_short_circuits() {
        let _resolver = CrossRegistryResolver::new();
        let mut seen: HashSet<String> = HashSet::new();
        let id = "acdp://r/12345678-1234-4321-8123-123456781234".to_string();
        assert!(seen.insert(id.clone()));
        assert!(!seen.insert(id));
    }
}