use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use aex_core::{AgentId, CapabilitySet, IdScheme};
use async_trait::async_trait;
use thiserror::Error;
use tokio::sync::{Mutex, Notify, RwLock};
/// Time-to-live used by [`ResolverChain::new`]: a cached record is served without
/// revalidation for one hour.
pub const DEFAULT_TTL: Duration = Duration::from_secs(3_600);
/// Maximum number of cached records used by [`ResolverChain::new`].
pub const DEFAULT_CAPACITY: usize = 10_000;
/// Errors produced while resolving an agent handle.
#[derive(Debug, Error)]
pub enum ResolverError {
    /// No resolver is registered for the handle's identifier scheme.
    #[error("no resolver for scheme {scheme:?} (handle {handle})")]
    NoResolverForScheme { scheme: IdScheme, handle: String },

    /// The handle text could not be parsed into an [`AgentId`].
    #[error("invalid handle: {0}")]
    InvalidHandle(String),

    /// A resolver failed. Also used when a conditional fetch or a shared in-flight fetch
    /// could not produce a record.
    #[error("resolver failed for {handle}: {source}")]
    Underlying {
        handle: String,
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },

    /// A record's fingerprint changed unexpectedly.
    /// NOTE(review): `ResolverChain` in this file never constructs this variant.
    #[error("cache-integrity violation for {handle}: fingerprint changed unexpectedly")]
    CacheIntegrityViolation { handle: String },
}
/// A backend that resolves agent handles belonging to a single identifier scheme.
#[async_trait]
pub trait AgentResolver: Send + Sync {
    /// The scheme this resolver serves; `ResolverChain` dispatches handles on it.
    fn scheme(&self) -> IdScheme;

    /// Resolves `handle`.
    ///
    /// `if_none_match` carries the ETag of the caller's cached record, if any. A resolver
    /// may then answer [`ResolveOutcome::NotModified`] instead of resending the record.
    async fn resolve(
        &self,
        handle: &AgentId,
        if_none_match: Option<&str>,
    ) -> Result<ResolveOutcome, ResolverError>;
}
/// What a single [`AgentResolver::resolve`] call produced.
#[derive(Clone, Debug)]
pub enum ResolveOutcome {
    /// A complete, freshly fetched record.
    Fresh(ResolvedAgent),
    /// The record identified by the supplied `if_none_match` ETag is still current.
    NotModified,
}
/// A resolved agent record, as returned by resolvers and held in the resolution cache.
#[derive(Clone, Debug)]
pub struct ResolvedAgent {
    /// Identifier the record describes.
    pub agent_id: AgentId,
    /// Fingerprint reported by the resolver.
    /// NOTE(review): presumably a key fingerprint; confirm with resolver implementations.
    pub fingerprint: String,
    /// Capabilities advertised for the agent.
    pub capabilities: CapabilitySet,
    /// Validator sent back as `if_none_match` when the cached record is revalidated.
    /// `None` means revalidation is always an unconditional fetch.
    pub etag: Option<String>,
}
/// One cache slot: the record plus the instant it was stored or last revalidated.
#[derive(Clone, Debug)]
struct CacheEntry {
    record: ResolvedAgent,
    /// Compared against the TTL for freshness and used to pick the eviction victim.
    inserted: Instant,
}
/// Bookkeeping for one in-flight fetch, shared by its leader and the callers waiting on it.
struct Flight {
    /// Wakes registered waiters once the leader finishes or is dropped.
    notify: Notify,
    /// Set immediately before `notify` fires.
    ///
    /// `notify_waiters` stores no permit, so a waiter that registers late checks this flag
    /// instead of waiting. New callers also use it to recognise an entry left behind by a
    /// leader that was cancelled or panicked before deregistering.
    done: std::sync::atomic::AtomicBool,
}

impl Flight {
    fn new() -> Self {
        Self {
            notify: Notify::new(),
            done: std::sync::atomic::AtomicBool::new(false),
        }
    }

    fn is_done(&self) -> bool {
        self.done.load(std::sync::atomic::Ordering::SeqCst)
    }

    /// Marks the flight finished and wakes every waiter registered so far.
    fn finish(&self) {
        self.done.store(true, std::sync::atomic::Ordering::SeqCst);
        self.notify.notify_waiters();
    }
}

/// Owned by a flight's leader.
///
/// Finishing the flight in `Drop` guarantees waiters are released even when the leader's
/// future is cancelled or unwinds in the middle of a fetch.
struct FlightGuard(Arc<Flight>);

impl Drop for FlightGuard {
    fn drop(&mut self) {
        self.0.finish();
    }
}

/// What a caller that missed the cache does next.
enum Role {
    /// Registered a new flight: fetch, then deregister it.
    Leader(FlightGuard),
    /// Another caller is already fetching this handle: wait for it.
    Waiter(Arc<Flight>),
}

/// Resolves agent handles through per-scheme [`AgentResolver`]s.
///
/// Results are cached with a TTL and a size bound. Concurrent cache misses for the same
/// handle are collapsed into a single upstream fetch.
///
/// Cloning is cheap: clones share the resolvers, the cache and the in-flight table.
#[derive(Clone)]
pub struct ResolverChain {
    resolvers: Arc<HashMap<IdScheme, Arc<dyn AgentResolver>>>,
    cache: Arc<RwLock<HashMap<AgentId, CacheEntry>>>,
    ttl: Duration,
    capacity: usize,
    /// Fetches currently in progress, keyed by handle.
    inflight: Arc<Mutex<HashMap<AgentId, Arc<Flight>>>>,
}

impl ResolverChain {
    /// Creates a chain using [`DEFAULT_CAPACITY`] and [`DEFAULT_TTL`].
    pub fn new(resolvers: Vec<Arc<dyn AgentResolver>>) -> Self {
        Self::with_capacity(resolvers, DEFAULT_CAPACITY, DEFAULT_TTL)
    }

    /// Creates a chain that caches at most `capacity` records, each fresh for `ttl`.
    ///
    /// Resolvers are keyed by [`AgentResolver::scheme`]. If several report the same scheme,
    /// the last one in `resolvers` wins.
    pub fn with_capacity(
        resolvers: Vec<Arc<dyn AgentResolver>>,
        capacity: usize,
        ttl: Duration,
    ) -> Self {
        let mut map = HashMap::new();
        for r in resolvers {
            map.insert(r.scheme(), r);
        }
        Self {
            resolvers: Arc::new(map),
            cache: Arc::new(RwLock::new(HashMap::new())),
            ttl,
            capacity,
            inflight: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Resolves `handle`, serving from cache while the entry is younger than the TTL.
    ///
    /// On a miss or an expired entry, a single caller per handle fetches from the scheme's
    /// resolver. It revalidates with the cached ETag when there is one. Concurrent callers
    /// for the same handle wait for that fetch instead of issuing their own.
    ///
    /// # Errors
    ///
    /// * [`ResolverError::InvalidHandle`] if `handle` does not parse.
    /// * [`ResolverError::NoResolverForScheme`] if no resolver serves its scheme.
    /// * Any error returned by the resolver, for the caller that performed the fetch.
    /// * [`ResolverError::Underlying`] for a waiting caller whose shared fetch left nothing
    ///   in the cache, because the fetch failed or its leader was cancelled. A waiter falls
    ///   back to a stale cached record when one exists.
    pub async fn resolve(&self, handle: &str) -> Result<ResolvedAgent, ResolverError> {
        let agent_id = Self::parse_handle(handle)?;
        if let Some(record) = self.cache_get_fresh(&agent_id).await {
            return Ok(record);
        }
        match self.join_or_lead(&agent_id).await {
            Role::Leader(guard) => self.lead(&agent_id, guard).await,
            Role::Waiter(flight) => self.wait_for(&agent_id, &flight).await,
        }
    }

    /// Drops the cached record for `handle`, if any.
    ///
    /// A fetch already in flight for the handle is not cancelled. It may repopulate the
    /// entry when it completes.
    ///
    /// # Errors
    ///
    /// [`ResolverError::InvalidHandle`] if `handle` does not parse.
    pub async fn invalidate(&self, handle: &str) -> Result<(), ResolverError> {
        let agent_id = Self::parse_handle(handle)?;
        self.cache.write().await.remove(&agent_id);
        Ok(())
    }

    /// Number of records currently cached, fresh or stale.
    pub async fn cache_len(&self) -> usize {
        self.cache.read().await.len()
    }

    /// Parses a textual handle into an [`AgentId`].
    fn parse_handle(handle: &str) -> Result<AgentId, ResolverError> {
        AgentId::new(handle.to_string()).map_err(|e| ResolverError::InvalidHandle(e.to_string()))
    }

    /// Joins the in-flight fetch for `agent_id`, or registers a new one and becomes its leader.
    async fn join_or_lead(&self, agent_id: &AgentId) -> Role {
        let mut inflight = self.inflight.lock().await;
        if let Some(flight) = inflight.get(agent_id) {
            // A finished flight still in the map was abandoned by a leader that was
            // cancelled or panicked before deregistering. Replace it rather than wait on
            // it forever.
            if !flight.is_done() {
                return Role::Waiter(Arc::clone(flight));
            }
        }
        let flight = Arc::new(Flight::new());
        inflight.insert(agent_id.clone(), Arc::clone(&flight));
        Role::Leader(FlightGuard(flight))
    }

    /// Performs the fetch as leader, then deregisters the flight and wakes its waiters.
    async fn lead(
        &self,
        agent_id: &AgentId,
        guard: FlightGuard,
    ) -> Result<ResolvedAgent, ResolverError> {
        // The previous leader may have filled the cache between our freshness check and our
        // registration. Re-checking here avoids a redundant upstream fetch.
        let result = match self.cache_get_fresh(agent_id).await {
            Some(record) => Ok(record),
            None => self.fetch_and_update(agent_id).await,
        };
        {
            let mut inflight = self.inflight.lock().await;
            // Remove only our own entry, never one registered by a later leader.
            if matches!(inflight.get(agent_id), Some(f) if Arc::ptr_eq(f, &guard.0)) {
                inflight.remove(agent_id);
            }
        }
        // Dropping the guard marks the flight done and wakes the waiters. The same happens
        // automatically if this future is cancelled before reaching this point.
        drop(guard);
        result
    }

    /// Waits for another caller's fetch to finish, then serves whatever it left in the cache.
    async fn wait_for(
        &self,
        agent_id: &AgentId,
        flight: &Flight,
    ) -> Result<ResolvedAgent, ResolverError> {
        // Register for the wake-up *before* checking `done`. A `Notified` receives every
        // `notify_waiters` call made after its creation. Checking the flag afterwards
        // therefore closes the window in which the leader could finish unobserved.
        let notified = flight.notify.notified();
        if !flight.is_done() {
            notified.await;
        }
        // Any cached record will do here, even a stale one (e.g. the leader's refresh failed).
        match self.cache_get_any(agent_id).await {
            Some(record) => Ok(record),
            None => Err(ResolverError::Underlying {
                handle: agent_id.as_str().to_string(),
                source: "inflight resolver failed".into(),
            }),
        }
    }

    /// Fetches `agent_id` from its scheme's resolver and updates the cache.
    ///
    /// The request is conditional on the cached ETag, if one exists. A `Fresh` answer
    /// replaces the cache entry. A `NotModified` answer restarts the existing entry's TTL.
    ///
    /// # Errors
    ///
    /// * [`ResolverError::NoResolverForScheme`] if no resolver serves the scheme.
    /// * Any error the resolver returns.
    /// * [`ResolverError::Underlying`] if the resolver answers `NotModified` but the entry
    ///   is no longer cached, e.g. it was invalidated or evicted mid-request.
    async fn fetch_and_update(&self, agent_id: &AgentId) -> Result<ResolvedAgent, ResolverError> {
        let resolver = self.resolvers.get(&agent_id.scheme()).ok_or_else(|| {
            ResolverError::NoResolverForScheme {
                scheme: agent_id.scheme(),
                handle: agent_id.as_str().to_string(),
            }
        })?;
        let if_none_match = self.cache_etag(agent_id).await;
        let outcome = resolver.resolve(agent_id, if_none_match.as_deref()).await?;
        let record = match outcome {
            ResolveOutcome::Fresh(rec) => {
                let entry = CacheEntry {
                    record: rec.clone(),
                    inserted: Instant::now(),
                };
                self.cache_insert(agent_id.clone(), entry).await;
                rec
            }
            ResolveOutcome::NotModified => self.cache_extend(agent_id).await.ok_or_else(|| {
                ResolverError::Underlying {
                    handle: agent_id.as_str().to_string(),
                    source: "304 returned with no cached entry".into(),
                }
            })?,
        };
        Ok(record)
    }

    /// Returns the cached record if it is younger than the TTL.
    async fn cache_get_fresh(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
        let cache = self.cache.read().await;
        cache
            .get(agent_id)
            .filter(|e| e.inserted.elapsed() < self.ttl)
            .map(|e| e.record.clone())
    }

    /// Returns the cached record regardless of its age.
    async fn cache_get_any(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
        let cache = self.cache.read().await;
        cache.get(agent_id).map(|e| e.record.clone())
    }

    /// Returns the ETag of the cached record, if any, for a conditional fetch.
    async fn cache_etag(&self, agent_id: &AgentId) -> Option<String> {
        self.cache
            .read()
            .await
            .get(agent_id)
            .and_then(|e| e.record.etag.clone())
    }

    /// Restarts the TTL of an existing entry and returns its record.
    async fn cache_extend(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
        let mut cache = self.cache.write().await;
        cache.get_mut(agent_id).map(|e| {
            e.inserted = Instant::now();
            e.record.clone()
        })
    }

    /// Inserts `entry`, then evicts the least recently stored/refreshed entries until the
    /// cache is back within capacity.
    async fn cache_insert(&self, key: AgentId, entry: CacheEntry) {
        let mut cache = self.cache.write().await;
        cache.insert(key, entry);
        // Entries are added one at a time, so this normally evicts exactly one.
        // A linear scan replaces sorting (and cloning) every key on each insert.
        while cache.len() > self.capacity {
            let oldest = cache
                .iter()
                .min_by_key(|(_, e)| e.inserted)
                .map(|(k, _)| k.clone());
            match oldest {
                Some(k) => {
                    cache.remove(&k);
                }
                None => break,
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test resolver that counts its calls. It answers `NotModified` whenever the caller
    /// presents its (fixed) ETag.
    struct CountingResolver {
        scheme: IdScheme,
        calls: Arc<AtomicUsize>,
        etag: String,
    }

    impl CountingResolver {
        fn new(scheme: IdScheme) -> Self {
            Self {
                scheme,
                calls: Arc::new(AtomicUsize::new(0)),
                etag: "etag-v1".into(),
            }
        }

        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl AgentResolver for CountingResolver {
        fn scheme(&self) -> IdScheme {
            self.scheme
        }

        async fn resolve(
            &self,
            handle: &AgentId,
            if_none_match: Option<&str>,
        ) -> Result<ResolveOutcome, ResolverError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if if_none_match == Some(self.etag.as_str()) {
                return Ok(ResolveOutcome::NotModified);
            }
            Ok(ResolveOutcome::Fresh(ResolvedAgent {
                agent_id: handle.clone(),
                fingerprint: format!("fp:{}", handle.as_str()),
                capabilities: CapabilitySet::empty(),
                etag: Some(self.etag.clone()),
            }))
        }
    }

    fn chain_with(resolver: Arc<CountingResolver>) -> ResolverChain {
        ResolverChain::with_capacity(
            vec![resolver as Arc<dyn AgentResolver>],
            100,
            Duration::from_secs(60),
        )
    }

    #[tokio::test]
    async fn cache_miss_then_hit() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let _ = chain.resolve("did:web:acme.com#fatture").await.unwrap();
        let _ = chain.resolve("did:web:acme.com#fatture").await.unwrap();
        assert_eq!(resolver.calls(), 1, "second call must hit cache");
    }

    #[tokio::test]
    async fn cache_returns_correct_record() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let rec = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(rec.agent_id.as_str(), "did:web:acme.com#x");
        assert!(rec.fingerprint.contains("acme.com"));
    }

    #[tokio::test]
    async fn stale_entry_uses_conditional_get_and_304() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = ResolverChain::with_capacity(
            vec![resolver.clone() as Arc<dyn AgentResolver>],
            100,
            Duration::from_millis(10),
        );
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        tokio::time::sleep(Duration::from_millis(15)).await;
        let rec = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(rec.etag.as_deref(), Some("etag-v1"));
        assert_eq!(resolver.calls(), 2);
    }

    #[tokio::test]
    async fn no_resolver_for_unknown_scheme() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let err = chain.resolve("did:ethr:8453:0xabc").await.unwrap_err();
        assert!(matches!(err, ResolverError::NoResolverForScheme { .. }));
    }

    #[tokio::test]
    async fn invalid_handle_rejected() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let err = chain.resolve("").await.unwrap_err();
        assert!(matches!(err, ResolverError::InvalidHandle(_)));
    }

    #[tokio::test]
    async fn invalidate_rejects_invalid_handle() {
        let chain = chain_with(Arc::new(CountingResolver::new(IdScheme::DidWeb)));
        let err = chain.invalidate("").await.unwrap_err();
        assert!(matches!(err, ResolverError::InvalidHandle(_)));
    }

    #[tokio::test]
    async fn single_flight_collapses_concurrent_misses() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let handles: Vec<_> = (0..50)
            .map(|_| {
                let c = chain.clone();
                tokio::spawn(async move {
                    c.resolve("did:web:acme.com#fatture")
                        .await
                        .map(|r| r.agent_id.as_str().to_string())
                })
            })
            .collect();
        let mut results = Vec::with_capacity(50);
        for h in handles {
            results.push(h.await.unwrap().unwrap());
        }
        assert!(results.iter().all(|r| r == "did:web:acme.com#fatture"));
        let calls = resolver.calls();
        assert!(
            calls <= 2,
            "single-flight failed: {} fetches for 50 concurrent resolves",
            calls
        );
    }

    #[tokio::test]
    async fn invalidate_drops_entry() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(chain.cache_len().await, 1);
        chain.invalidate("did:web:acme.com#x").await.unwrap();
        assert_eq!(chain.cache_len().await, 0);
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(resolver.calls(), 2);
    }

    #[tokio::test]
    async fn bounded_capacity_evicts_oldest() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = ResolverChain::with_capacity(
            vec![resolver.clone() as Arc<dyn AgentResolver>],
            3,
            Duration::from_secs(60),
        );
        for i in 0..5 {
            let _ = chain
                .resolve(&format!("did:web:acme.com#agent-{}", i))
                .await
                .unwrap();
            // Distinct insertion instants make "oldest" unambiguous.
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
        assert_eq!(chain.cache_len().await, 3);
        assert_eq!(resolver.calls(), 5);

        // The three newest entries survive and are served from cache...
        for i in 2..5 {
            let _ = chain
                .resolve(&format!("did:web:acme.com#agent-{}", i))
                .await
                .unwrap();
        }
        assert_eq!(resolver.calls(), 5, "newest entries must still be cached");

        // ...while the oldest ones were evicted and must be fetched again.
        let _ = chain.resolve("did:web:acme.com#agent-0").await.unwrap();
        assert_eq!(resolver.calls(), 6, "oldest entry must have been evicted");
    }

    #[tokio::test]
    async fn multiple_resolvers_dispatch_by_scheme() {
        let r_web = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let r_key = Arc::new(CountingResolver::new(IdScheme::DidKey));
        let chain = ResolverChain::new(vec![
            r_web.clone() as Arc<dyn AgentResolver>,
            r_key.clone() as Arc<dyn AgentResolver>,
        ]);
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        let _ = chain.resolve("did:key:zabc").await.unwrap();
        assert_eq!(r_web.calls(), 1);
        assert_eq!(r_key.calls(), 1);
    }
}