1use std::collections::HashMap;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use aex_core::{AgentId, CapabilitySet, IdScheme};
27use async_trait::async_trait;
28use thiserror::Error;
29use tokio::sync::{Mutex, Notify, RwLock};
30
/// Default time-to-live for a cached record: one hour.
///
/// This is the TTL that [`ResolverChain::new`] passes to
/// [`ResolverChain::with_capacity`].
pub const DEFAULT_TTL: Duration = Duration::from_secs(60 * 60);
33
/// Default maximum number of records kept in the cache.
///
/// This is the capacity that [`ResolverChain::new`] passes to
/// [`ResolverChain::with_capacity`].
pub const DEFAULT_CAPACITY: usize = 10_000;
36
/// Errors produced while resolving an agent handle.
///
/// Returned by [`ResolverChain`] methods and by [`AgentResolver::resolve`].
#[derive(Debug, Error)]
pub enum ResolverError {
    /// The chain has no resolver registered for the scheme of the handle.
    #[error("no resolver for scheme {scheme:?} (handle {handle})")]
    NoResolverForScheme {
        /// Scheme of the handle that could not be dispatched.
        scheme: IdScheme,
        /// String form of the handle that was being resolved.
        handle: String,
    },
    /// The handle string was rejected by `AgentId::new`; the payload is that
    /// error rendered as a string.
    #[error("invalid handle: {0}")]
    InvalidHandle(String),
    /// Resolution failed for `handle`; `source` carries the cause.
    ///
    /// Within this file the chain also builds this variant itself for
    /// "inflight resolver failed" and "304 returned with no cached entry".
    #[error("resolver failed for {handle}: {source}")]
    Underlying {
        /// String form of the handle that was being resolved.
        handle: String,
        /// The underlying cause.
        #[source]
        source: Box<dyn std::error::Error + Send + Sync>,
    },
    /// The fingerprint recorded for `handle` changed unexpectedly.
    ///
    /// NOTE(review): nothing in this file constructs this variant; presumably
    /// it is raised by code elsewhere that compares fingerprints — confirm.
    #[error("cache-integrity violation for {handle}: fingerprint changed unexpectedly")]
    CacheIntegrityViolation {
        /// String form of the affected handle.
        handle: String,
    },
}
68
/// A resolver for agent handles of a single [`IdScheme`].
///
/// [`ResolverChain`] indexes resolvers by the value returned from
/// [`scheme`](AgentResolver::scheme) and dispatches each handle to the
/// resolver registered for the handle's scheme.
#[async_trait]
pub trait AgentResolver: Send + Sync {
    /// The scheme this resolver handles; used as the dispatch key.
    fn scheme(&self) -> IdScheme;

    /// Resolves `handle`.
    ///
    /// `if_none_match` is the `etag` of the record the chain currently has
    /// cached for `handle`, if any. An implementation may answer
    /// [`ResolveOutcome::NotModified`] to tell the chain to keep using that
    /// cached record; the chain reports an error if it receives
    /// `NotModified` while holding no cached entry for `handle`.
    ///
    /// # Errors
    ///
    /// Any [`ResolverError`]; the chain propagates it to its caller unchanged.
    async fn resolve(
        &self,
        handle: &AgentId,
        if_none_match: Option<&str>,
    ) -> Result<ResolveOutcome, ResolverError>;
}
91
/// Result of a single [`AgentResolver::resolve`] call.
#[derive(Debug, Clone)]
pub enum ResolveOutcome {
    /// A newly resolved record; the chain stores it in its cache.
    Fresh(ResolvedAgent),
    /// The record matching the supplied `if_none_match` value is still
    /// current; the chain reuses its cached record and resets that entry's
    /// insertion time.
    NotModified,
}
101
/// The record produced by resolving an agent handle.
#[derive(Debug, Clone)]
pub struct ResolvedAgent {
    /// Identifier of the resolved agent.
    pub agent_id: AgentId,
    /// Fingerprint string reported by the resolver.
    ///
    /// NOTE(review): presumably identifies the agent's key material (see
    /// `ResolverError::CacheIntegrityViolation`) — confirm with the resolvers.
    pub fingerprint: String,
    /// Capabilities reported for the agent.
    pub capabilities: CapabilitySet,
    /// Optional validator for this record. When present, the chain passes it
    /// back to the resolver as `if_none_match` on the next re-resolution.
    pub etag: Option<String>,
}
117
/// A cached record together with the time it was stored or last revalidated.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// The cached resolution result.
    record: ResolvedAgent,
    /// When the entry was inserted, or last refreshed by a `NotModified`
    /// answer. Drives both the TTL check and oldest-first eviction.
    inserted: Instant,
}
124
/// Dispatches agent handles to per-scheme resolvers and caches the results.
///
/// All mutable state lives behind `Arc`s, so clones of a chain share the same
/// resolver table, cache and in-flight table.
#[derive(Clone)]
pub struct ResolverChain {
    /// Resolvers indexed by the scheme they report.
    resolvers: Arc<HashMap<IdScheme, Arc<dyn AgentResolver>>>,
    /// Resolved records keyed by agent id.
    cache: Arc<RwLock<HashMap<AgentId, CacheEntry>>>,
    /// How long a cached entry is served without contacting the resolver.
    ttl: Duration,
    /// Maximum number of cache entries kept after an insert.
    capacity: usize,
    /// One marker per agent id that currently has a fetch in progress;
    /// concurrent callers for the same id wait on the marker instead of
    /// issuing their own fetch.
    inflight: Arc<Mutex<HashMap<AgentId, Arc<Notify>>>>,
}
137
138impl ResolverChain {
139 pub fn new(resolvers: Vec<Arc<dyn AgentResolver>>) -> Self {
145 Self::with_capacity(resolvers, DEFAULT_CAPACITY, DEFAULT_TTL)
146 }
147
148 pub fn with_capacity(
152 resolvers: Vec<Arc<dyn AgentResolver>>,
153 capacity: usize,
154 ttl: Duration,
155 ) -> Self {
156 let mut map = HashMap::new();
157 for r in resolvers {
158 map.insert(r.scheme(), r);
159 }
160 Self {
161 resolvers: Arc::new(map),
162 cache: Arc::new(RwLock::new(HashMap::new())),
163 ttl,
164 capacity,
165 inflight: Arc::new(Mutex::new(HashMap::new())),
166 }
167 }
168
169 pub async fn resolve(&self, handle: &str) -> Result<ResolvedAgent, ResolverError> {
179 let agent_id = AgentId::new(handle.to_string())
180 .map_err(|e| ResolverError::InvalidHandle(e.to_string()))?;
181
182 if let Some(record) = self.cache_get_fresh(&agent_id).await {
184 return Ok(record);
185 }
186
187 let notify = {
189 let mut inflight = self.inflight.lock().await;
190 if let Some(n) = inflight.get(&agent_id) {
191 Some(n.clone())
192 } else {
193 inflight.insert(agent_id.clone(), Arc::new(Notify::new()));
194 None
195 }
196 };
197
198 if let Some(n) = notify {
199 n.notified().await;
202 if let Some(rec) = self.cache_get_any(&agent_id).await {
206 return Ok(rec);
207 }
208 return Err(ResolverError::Underlying {
212 handle: agent_id.as_str().to_string(),
213 source: "inflight resolver failed".into(),
214 });
215 }
216
217 let result = self.fetch_and_update(&agent_id).await;
219
220 let waiters = {
221 let mut inflight = self.inflight.lock().await;
222 inflight.remove(&agent_id)
223 };
224 if let Some(n) = waiters {
225 n.notify_waiters();
226 }
227
228 result
229 }
230
231 pub async fn invalidate(&self, handle: &str) -> Result<(), ResolverError> {
235 let agent_id = AgentId::new(handle.to_string())
236 .map_err(|e| ResolverError::InvalidHandle(e.to_string()))?;
237 self.cache.write().await.remove(&agent_id);
238 Ok(())
239 }
240
241 pub async fn cache_len(&self) -> usize {
243 self.cache.read().await.len()
244 }
245
246 async fn fetch_and_update(&self, agent_id: &AgentId) -> Result<ResolvedAgent, ResolverError> {
247 let resolver = self.resolvers.get(&agent_id.scheme()).ok_or_else(|| {
248 ResolverError::NoResolverForScheme {
249 scheme: agent_id.scheme(),
250 handle: agent_id.as_str().to_string(),
251 }
252 })?;
253
254 let if_none_match = self.cache_etag(agent_id).await;
255 let outcome = resolver.resolve(agent_id, if_none_match.as_deref()).await?;
256
257 let record = match outcome {
258 ResolveOutcome::Fresh(rec) => {
259 let entry = CacheEntry {
267 record: rec.clone(),
268 inserted: Instant::now(),
269 };
270 self.cache_insert(agent_id.clone(), entry).await;
271 rec
272 }
273 ResolveOutcome::NotModified => {
274 self.cache_extend(agent_id).await.ok_or_else(|| {
277 ResolverError::Underlying {
280 handle: agent_id.as_str().to_string(),
281 source: "304 returned with no cached entry".into(),
282 }
283 })?
284 }
285 };
286
287 Ok(record)
288 }
289
290 async fn cache_get_fresh(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
291 let cache = self.cache.read().await;
292 cache
293 .get(agent_id)
294 .filter(|e| e.inserted.elapsed() < self.ttl)
295 .map(|e| e.record.clone())
296 }
297
298 async fn cache_get_any(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
299 let cache = self.cache.read().await;
300 cache.get(agent_id).map(|e| e.record.clone())
301 }
302
303 async fn cache_etag(&self, agent_id: &AgentId) -> Option<String> {
304 self.cache
305 .read()
306 .await
307 .get(agent_id)
308 .and_then(|e| e.record.etag.clone())
309 }
310
311 async fn cache_extend(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
312 let mut cache = self.cache.write().await;
313 cache.get_mut(agent_id).map(|e| {
314 e.inserted = Instant::now();
315 e.record.clone()
316 })
317 }
318
319 async fn cache_insert(&self, key: AgentId, entry: CacheEntry) {
320 let mut cache = self.cache.write().await;
321 cache.insert(key, entry);
322 if cache.len() > self.capacity {
329 let excess = cache.len() - self.capacity;
330 let mut by_age: Vec<(AgentId, Instant)> =
331 cache.iter().map(|(k, v)| (k.clone(), v.inserted)).collect();
332 by_age.sort_by_key(|(_, t)| *t);
333 for (k, _) in by_age.into_iter().take(excess) {
334 cache.remove(&k);
335 }
336 }
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test resolver that counts how often it is called and always reports
    /// the same etag, answering `NotModified` when that etag is presented.
    struct CountingResolver {
        scheme: IdScheme,
        calls: Arc<AtomicUsize>,
        etag: String,
    }

    impl CountingResolver {
        fn new(scheme: IdScheme) -> Self {
            Self {
                scheme,
                calls: Arc::new(AtomicUsize::new(0)),
                etag: "etag-v1".into(),
            }
        }

        /// Number of `resolve` calls received so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl AgentResolver for CountingResolver {
        fn scheme(&self) -> IdScheme {
            self.scheme
        }

        async fn resolve(
            &self,
            handle: &AgentId,
            if_none_match: Option<&str>,
        ) -> Result<ResolveOutcome, ResolverError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if if_none_match == Some(self.etag.as_str()) {
                return Ok(ResolveOutcome::NotModified);
            }
            Ok(ResolveOutcome::Fresh(ResolvedAgent {
                agent_id: handle.clone(),
                fingerprint: format!("fp:{}", handle.as_str()),
                capabilities: CapabilitySet::empty(),
                etag: Some(self.etag.clone()),
            }))
        }
    }

    /// Chain with a single resolver, capacity 100 and a 60 s TTL.
    fn chain_with(resolver: Arc<CountingResolver>) -> ResolverChain {
        ResolverChain::with_capacity(
            vec![resolver as Arc<dyn AgentResolver>],
            100,
            Duration::from_secs(60),
        )
    }

    #[tokio::test]
    async fn cache_miss_then_hit() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let _ = chain.resolve("did:web:acme.com#fatture").await.unwrap();
        let _ = chain.resolve("did:web:acme.com#fatture").await.unwrap();
        assert_eq!(resolver.calls(), 1, "second call must hit cache");
    }

    #[tokio::test]
    async fn cache_returns_correct_record() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let rec = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(rec.agent_id.as_str(), "did:web:acme.com#x");
        assert!(rec.fingerprint.contains("acme.com"));
    }

    #[tokio::test]
    async fn stale_entry_uses_conditional_get_and_304() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = ResolverChain::with_capacity(
            vec![resolver.clone() as Arc<dyn AgentResolver>],
            100,
            Duration::from_millis(10),
        );
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        // Let the entry outlive the 10 ms TTL so the next resolve revalidates.
        tokio::time::sleep(Duration::from_millis(15)).await;
        let rec = chain.resolve("did:web:acme.com#x").await.unwrap();
        // The record comes from the cache (the resolver answered NotModified),
        // but the resolver was still contacted a second time.
        assert_eq!(rec.etag.as_deref(), Some("etag-v1"));
        assert_eq!(resolver.calls(), 2);
    }

    #[tokio::test]
    async fn no_resolver_for_unknown_scheme() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let err = chain.resolve("did:ethr:8453:0xabc").await.unwrap_err();
        assert!(matches!(err, ResolverError::NoResolverForScheme { .. }));
    }

    #[tokio::test]
    async fn invalid_handle_rejected() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let err = chain.resolve("").await.unwrap_err();
        assert!(matches!(err, ResolverError::InvalidHandle(_)));
    }

    #[tokio::test]
    async fn single_flight_collapses_concurrent_misses() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());

        let handles: Vec<_> = (0..50)
            .map(|_| {
                let c = chain.clone();
                tokio::spawn(async move {
                    c.resolve("did:web:acme.com#fatture")
                        .await
                        .map(|r| r.agent_id.as_str().to_string())
                })
            })
            .collect();

        let mut results = Vec::with_capacity(50);
        for h in handles {
            results.push(h.await.unwrap().unwrap());
        }
        assert!(results.iter().all(|r| r == "did:web:acme.com#fatture"));

        let calls = resolver.calls();
        assert!(
            calls <= 2,
            "single-flight failed: {} fetches for 50 concurrent resolves",
            calls
        );
    }

    #[tokio::test]
    async fn invalidate_drops_entry() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(chain.cache_len().await, 1);
        chain.invalidate("did:web:acme.com#x").await.unwrap();
        assert_eq!(chain.cache_len().await, 0);
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(resolver.calls(), 2);
    }

    #[tokio::test]
    async fn bounded_capacity_evicts_oldest() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = ResolverChain::with_capacity(
            vec![resolver.clone() as Arc<dyn AgentResolver>],
            3,
            Duration::from_secs(60),
        );
        for i in 0..5 {
            let _ = chain
                .resolve(&format!("did:web:acme.com#agent-{}", i))
                .await
                .unwrap();
            // Space the inserts out so every entry has a distinct timestamp.
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
        assert_eq!(chain.cache_len().await, 3);
        assert_eq!(resolver.calls(), 5);

        // Checking the size alone would also pass if the wrong entries were
        // dropped, so verify which ones survived: the newest must still be
        // served from the cache ...
        let _ = chain.resolve("did:web:acme.com#agent-4").await.unwrap();
        assert_eq!(resolver.calls(), 5, "newest entry must still be cached");

        // ... while the oldest must have been evicted and need a new fetch.
        let _ = chain.resolve("did:web:acme.com#agent-0").await.unwrap();
        assert_eq!(resolver.calls(), 6, "oldest entry must have been evicted");
        assert_eq!(chain.cache_len().await, 3);
    }

    #[tokio::test]
    async fn multiple_resolvers_dispatch_by_scheme() {
        let r_web = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let r_key = Arc::new(CountingResolver::new(IdScheme::DidKey));
        let chain = ResolverChain::new(vec![
            r_web.clone() as Arc<dyn AgentResolver>,
            r_key.clone() as Arc<dyn AgentResolver>,
        ]);
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        let _ = chain.resolve("did:key:zabc").await.unwrap();
        assert_eq!(r_web.calls(), 1);
        assert_eq!(r_key.calls(), 1);
    }
}
531}