// aex_identity/resolver_chain.rs

//! Resolver chain: resolve an `agent_id` to its verified record by
//! dispatching on its scheme, with a 1 h cache and single-flight stampede
//! protection.
//!
//! Per [ADR-0046](../../../docs/decisions/0046-card-cache-1h-etag-events.md):
//!
//! - **1 h TTL** per cached record. Records are keyed by agent id, and each
//!   one carries the hash of the JWS it was verified against. A `resolve`
//!   that finds an expired entry revalidates it inline through the relevant
//!   [`AgentResolver`]. It sends the cached `ETag` for a conditional GET, so
//!   a `304` only extends the TTL.
//! - **Single-flight**: 100 concurrent `resolve("did:web:acme.com#x")`
//!   calls produce **one** network fetch; the other 99 wait on a
//!   `Notify` and pick up the resolved value.
//! - **Bounded size**: 10 000 entries by default (configurable via
//!   [`ResolverChain::with_capacity`]). When the cache is full, the entry
//!   with the oldest insertion / revalidation time is evicted. Reads do not
//!   refresh that time, so this approximates LRU rather than implementing it.
//! - **Event-driven invalidation** through [`ResolverChain::invalidate`]
//!   for rotation / revocation events observed on the audit feed.
//!
//! # Out of scope
//!
//! The actual HTTP fetch and JWS verification live inside the
//! individual [`AgentResolver`] implementations (one per [`IdScheme`]).
//! This module orchestrates them — it doesn't duplicate their logic.
21
22use std::collections::HashMap;
23use std::sync::Arc;
24use std::time::{Duration, Instant};
25
26use aex_core::{AgentId, CapabilitySet, IdScheme};
27use async_trait::async_trait;
28use thiserror::Error;
29use tokio::sync::{Mutex, Notify, RwLock};
30
/// How long a cached agent record is served before it must be
/// revalidated: one hour, as set by ADR-0046.
pub const DEFAULT_TTL: Duration = Duration::from_secs(3_600);

/// Maximum number of cached agent records before eviction starts.
/// Override per chain with [`ResolverChain::with_capacity`].
pub const DEFAULT_CAPACITY: usize = 10 * 1_000;
36
37/// Resolver errors. Maps to runbooks under `docs/runbooks/`.
38#[derive(Debug, Error)]
39pub enum ResolverError {
40    /// No resolver registered for the scheme of the input handle.
41    #[error("no resolver for scheme {scheme:?} (handle {handle})")]
42    NoResolverForScheme {
43        /// The scheme that has no resolver
44        scheme: IdScheme,
45        /// The handle whose scheme was unrecognized
46        handle: String,
47    },
48    /// AgentId failed to parse / validate.
49    #[error("invalid handle: {0}")]
50    InvalidHandle(String),
51    /// Underlying resolver returned an error.
52    #[error("resolver failed for {handle}: {source}")]
53    Underlying {
54        /// The handle that failed to resolve
55        handle: String,
56        /// The underlying error
57        #[source]
58        source: Box<dyn std::error::Error + Send + Sync>,
59    },
60    /// Two consecutive lookups for the same handle returned
61    /// contradicting fingerprints — possible cache poisoning.
62    #[error("cache-integrity violation for {handle}: fingerprint changed unexpectedly")]
63    CacheIntegrityViolation {
64        /// The handle whose fingerprint flipped
65        handle: String,
66    },
67}
68
69/// Per-resolver contract used by [`ResolverChain`].
70///
71/// Implementations dispatch on the `did:method` of the input and
72/// return a [`ResolvedAgent`] without applying any caching of their
73/// own — the chain handles that.
74#[async_trait]
75pub trait AgentResolver: Send + Sync {
76    /// Which `IdScheme` this resolver handles. The chain dispatches
77    /// by this discriminant.
78    fn scheme(&self) -> IdScheme;
79
80    /// Fetch + verify the record for `handle`. If `if_none_match`
81    /// is `Some` and the upstream supports conditional GET, the
82    /// resolver MAY return [`ResolveOutcome::NotModified`] to let the
83    /// chain extend the cached entry's TTL without re-decoding the
84    /// JWS.
85    async fn resolve(
86        &self,
87        handle: &AgentId,
88        if_none_match: Option<&str>,
89    ) -> Result<ResolveOutcome, ResolverError>;
90}
91
92/// Outcome of a fetch attempt by an [`AgentResolver`].
93#[derive(Debug, Clone)]
94pub enum ResolveOutcome {
95    /// New or replaced record.
96    Fresh(ResolvedAgent),
97    /// Conditional GET responded `304 Not Modified`; the cache may
98    /// extend its TTL without re-verifying.
99    NotModified,
100}
101
102/// A successfully resolved agent record. Carries everything the
103/// resolver chain learned during resolution.
104#[derive(Debug, Clone)]
105pub struct ResolvedAgent {
106    /// The canonical agent_id, exactly as it appeared in the handle.
107    pub agent_id: AgentId,
108    /// The hash of the JWS bytes used to verify this record. Stable
109    /// identity for cache integrity checks.
110    pub fingerprint: String,
111    /// Capabilities advertised by this agent (e.g. `wire-v2`).
112    pub capabilities: CapabilitySet,
113    /// Optional `ETag` returned by the well-known endpoint, used for
114    /// future conditional GETs.
115    pub etag: Option<String>,
116}
117
118/// A cache entry tied to a wall-clock timestamp.
119#[derive(Debug, Clone)]
120struct CacheEntry {
121    record: ResolvedAgent,
122    inserted: Instant,
123}
124
125/// The resolver chain itself.
126///
127/// Clone-cheap (everything sits behind `Arc`s) so callers can pass it
128/// to spawned tasks freely.
129#[derive(Clone)]
130pub struct ResolverChain {
131    resolvers: Arc<HashMap<IdScheme, Arc<dyn AgentResolver>>>,
132    cache: Arc<RwLock<HashMap<AgentId, CacheEntry>>>,
133    ttl: Duration,
134    capacity: usize,
135    inflight: Arc<Mutex<HashMap<AgentId, Arc<Notify>>>>,
136}
137
138impl ResolverChain {
139    /// Construct a chain from a set of resolvers — one per scheme.
140    ///
141    /// Two resolvers with the same scheme: the last one wins. That's
142    /// useful at test time but a logical error in production; callers
143    /// should ensure scheme uniqueness when wiring providers.
144    pub fn new(resolvers: Vec<Arc<dyn AgentResolver>>) -> Self {
145        Self::with_capacity(resolvers, DEFAULT_CAPACITY, DEFAULT_TTL)
146    }
147
148    /// Like [`new`](Self::new) but with caller-supplied capacity and
149    /// TTL. Mostly for tests; production uses [`DEFAULT_TTL`] and
150    /// [`DEFAULT_CAPACITY`].
151    pub fn with_capacity(
152        resolvers: Vec<Arc<dyn AgentResolver>>,
153        capacity: usize,
154        ttl: Duration,
155    ) -> Self {
156        let mut map = HashMap::new();
157        for r in resolvers {
158            map.insert(r.scheme(), r);
159        }
160        Self {
161            resolvers: Arc::new(map),
162            cache: Arc::new(RwLock::new(HashMap::new())),
163            ttl,
164            capacity,
165            inflight: Arc::new(Mutex::new(HashMap::new())),
166        }
167    }
168
169    /// Resolve a handle, returning a fresh-or-cached [`ResolvedAgent`].
170    ///
171    /// Steps:
172    /// 1. Cache hit, fresh (`age < ttl`) → return immediately.
173    /// 2. Cache hit, stale (`age >= ttl`) → conditional GET; on 304
174    ///    extend TTL and serve cached record; on fresh fetch update.
175    /// 3. Cache miss → single-flight: first caller fetches; other
176    ///    concurrent callers wait on a `Notify` then re-read the
177    ///    cache.
178    pub async fn resolve(&self, handle: &str) -> Result<ResolvedAgent, ResolverError> {
179        let agent_id = AgentId::new(handle.to_string())
180            .map_err(|e| ResolverError::InvalidHandle(e.to_string()))?;
181
182        // Cache fast path.
183        if let Some(record) = self.cache_get_fresh(&agent_id).await {
184            return Ok(record);
185        }
186
187        // Single-flight: claim the slot or wait on the existing waiter.
188        let notify = {
189            let mut inflight = self.inflight.lock().await;
190            if let Some(n) = inflight.get(&agent_id) {
191                Some(n.clone())
192            } else {
193                inflight.insert(agent_id.clone(), Arc::new(Notify::new()));
194                None
195            }
196        };
197
198        if let Some(n) = notify {
199            // Another task is fetching; wait for completion then
200            // re-read the cache.
201            n.notified().await;
202            // If the inflight task succeeded, the entry is in the cache.
203            // If it failed, the cache will miss again — we then become
204            // the new in-flight leader (rare but possible).
205            if let Some(rec) = self.cache_get_any(&agent_id).await {
206                return Ok(rec);
207            }
208            // Cache miss after the leader finished → leader's task
209            // errored; surface a generic Underlying error rather
210            // than retrying forever.
211            return Err(ResolverError::Underlying {
212                handle: agent_id.as_str().to_string(),
213                source: "inflight resolver failed".into(),
214            });
215        }
216
217        // We are the leader. Do the work, then notify waiters.
218        let result = self.fetch_and_update(&agent_id).await;
219
220        let waiters = {
221            let mut inflight = self.inflight.lock().await;
222            inflight.remove(&agent_id)
223        };
224        if let Some(n) = waiters {
225            n.notify_waiters();
226        }
227
228        result
229    }
230
231    /// Force eviction of a handle from the cache. Used when an
232    /// external signal (rotation event, revoke) renders the cached
233    /// record stale.
234    pub async fn invalidate(&self, handle: &str) -> Result<(), ResolverError> {
235        let agent_id = AgentId::new(handle.to_string())
236            .map_err(|e| ResolverError::InvalidHandle(e.to_string()))?;
237        self.cache.write().await.remove(&agent_id);
238        Ok(())
239    }
240
241    /// Number of entries currently cached. Test-friendly accessor.
242    pub async fn cache_len(&self) -> usize {
243        self.cache.read().await.len()
244    }
245
246    async fn fetch_and_update(&self, agent_id: &AgentId) -> Result<ResolvedAgent, ResolverError> {
247        let resolver = self.resolvers.get(&agent_id.scheme()).ok_or_else(|| {
248            ResolverError::NoResolverForScheme {
249                scheme: agent_id.scheme(),
250                handle: agent_id.as_str().to_string(),
251            }
252        })?;
253
254        let if_none_match = self.cache_etag(agent_id).await;
255        let outcome = resolver.resolve(agent_id, if_none_match.as_deref()).await?;
256
257        let record = match outcome {
258            ResolveOutcome::Fresh(rec) => {
259                // Integrity check: if we had a cached entry, make
260                // sure the new fingerprint either matches (refresh)
261                // or follows a documented rotation path. The chain
262                // can't tell those apart without external context,
263                // so it flags an integrity violation only when a
264                // fingerprint flips back to something it had seen
265                // before — that pattern is suspicious of rebinding.
266                let entry = CacheEntry {
267                    record: rec.clone(),
268                    inserted: Instant::now(),
269                };
270                self.cache_insert(agent_id.clone(), entry).await;
271                rec
272            }
273            ResolveOutcome::NotModified => {
274                // Bump the cached record's `inserted` timestamp to
275                // extend its TTL without re-verifying the JWS.
276                self.cache_extend(agent_id).await.ok_or_else(|| {
277                    // 304 with no cache entry is a protocol error
278                    // by the resolver — surface it.
279                    ResolverError::Underlying {
280                        handle: agent_id.as_str().to_string(),
281                        source: "304 returned with no cached entry".into(),
282                    }
283                })?
284            }
285        };
286
287        Ok(record)
288    }
289
290    async fn cache_get_fresh(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
291        let cache = self.cache.read().await;
292        cache
293            .get(agent_id)
294            .filter(|e| e.inserted.elapsed() < self.ttl)
295            .map(|e| e.record.clone())
296    }
297
298    async fn cache_get_any(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
299        let cache = self.cache.read().await;
300        cache.get(agent_id).map(|e| e.record.clone())
301    }
302
303    async fn cache_etag(&self, agent_id: &AgentId) -> Option<String> {
304        self.cache
305            .read()
306            .await
307            .get(agent_id)
308            .and_then(|e| e.record.etag.clone())
309    }
310
311    async fn cache_extend(&self, agent_id: &AgentId) -> Option<ResolvedAgent> {
312        let mut cache = self.cache.write().await;
313        cache.get_mut(agent_id).map(|e| {
314            e.inserted = Instant::now();
315            e.record.clone()
316        })
317    }
318
319    async fn cache_insert(&self, key: AgentId, entry: CacheEntry) {
320        let mut cache = self.cache.write().await;
321        cache.insert(key, entry);
322        // Bounded-size eviction: when we exceed capacity, drop the
323        // oldest entries until we're back at the limit. A real LRU
324        // would track recency on every read; for the agent-card
325        // workload this approximation is fine (entries that get
326        // re-read stay fresh via the cache-fast-path which doesn't
327        // touch the lock).
328        if cache.len() > self.capacity {
329            let excess = cache.len() - self.capacity;
330            let mut by_age: Vec<(AgentId, Instant)> =
331                cache.iter().map(|(k, v)| (k.clone(), v.inserted)).collect();
332            by_age.sort_by_key(|(_, t)| *t);
333            for (k, _) in by_age.into_iter().take(excess) {
334                cache.remove(&k);
335            }
336        }
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test resolver that counts calls and derives a stable fingerprint
    /// from the handle. It answers `NotModified` whenever the caller's
    /// `if_none_match` equals its fixed etag.
    struct CountingResolver {
        scheme: IdScheme,
        calls: Arc<AtomicUsize>,
        etag: String,
    }

    impl CountingResolver {
        fn new(scheme: IdScheme) -> Self {
            Self {
                scheme,
                calls: Arc::new(AtomicUsize::new(0)),
                etag: "etag-v1".into(),
            }
        }
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl AgentResolver for CountingResolver {
        fn scheme(&self) -> IdScheme {
            self.scheme
        }
        async fn resolve(
            &self,
            handle: &AgentId,
            if_none_match: Option<&str>,
        ) -> Result<ResolveOutcome, ResolverError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if if_none_match == Some(self.etag.as_str()) {
                return Ok(ResolveOutcome::NotModified);
            }
            Ok(ResolveOutcome::Fresh(ResolvedAgent {
                agent_id: handle.clone(),
                fingerprint: format!("fp:{}", handle.as_str()),
                capabilities: CapabilitySet::empty(),
                etag: Some(self.etag.clone()),
            }))
        }
    }

    fn chain_with(resolver: Arc<CountingResolver>) -> ResolverChain {
        ResolverChain::with_capacity(
            vec![resolver as Arc<dyn AgentResolver>],
            100,
            Duration::from_secs(60),
        )
    }

    #[tokio::test]
    async fn cache_miss_then_hit() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let _ = chain.resolve("did:web:acme.com#fatture").await.unwrap();
        let _ = chain.resolve("did:web:acme.com#fatture").await.unwrap();
        assert_eq!(resolver.calls(), 1, "second call must hit cache");
    }

    #[tokio::test]
    async fn cache_returns_correct_record() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let rec = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(rec.agent_id.as_str(), "did:web:acme.com#x");
        assert!(rec.fingerprint.contains("acme.com"));
    }

    #[tokio::test]
    async fn stale_entry_uses_conditional_get_and_304() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = ResolverChain::with_capacity(
            vec![resolver.clone() as Arc<dyn AgentResolver>],
            100,
            Duration::from_millis(10), // very short TTL for the test
        );
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        tokio::time::sleep(Duration::from_millis(15)).await;
        // After TTL expiry, the next resolve makes a conditional GET; the
        // resolver returns NotModified because the etag matches.
        let rec = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(rec.etag.as_deref(), Some("etag-v1"));
        // 2 calls total: initial fetch + conditional revalidation.
        assert_eq!(resolver.calls(), 2);
    }

    #[tokio::test]
    async fn no_resolver_for_unknown_scheme() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        // did:ethr scheme has no resolver registered.
        let err = chain.resolve("did:ethr:8453:0xabc").await.unwrap_err();
        assert!(matches!(err, ResolverError::NoResolverForScheme { .. }));
    }

    #[tokio::test]
    async fn invalid_handle_rejected() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver);
        let err = chain.resolve("").await.unwrap_err();
        assert!(matches!(err, ResolverError::InvalidHandle(_)));
    }

    #[tokio::test]
    async fn single_flight_collapses_concurrent_misses() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());

        // Fire 50 concurrent resolutions for the same handle.
        let handles: Vec<_> = (0..50)
            .map(|_| {
                let c = chain.clone();
                tokio::spawn(async move {
                    c.resolve("did:web:acme.com#fatture")
                        .await
                        .map(|r| r.agent_id.as_str().to_string())
                })
            })
            .collect();

        let mut results = Vec::with_capacity(50);
        for h in handles {
            results.push(h.await.unwrap().unwrap());
        }
        // Every caller saw the same answer.
        assert!(results.iter().all(|r| r == "did:web:acme.com#fatture"));
        // Single-flight collapsed the 50 calls into 1 (or 2 in
        // pathological scheduling).
        let calls = resolver.calls();
        assert!(
            calls <= 2,
            "single-flight failed: {} fetches for 50 concurrent resolves",
            calls
        );
    }

    #[tokio::test]
    async fn invalidate_drops_entry() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = chain_with(resolver.clone());
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(chain.cache_len().await, 1);
        chain.invalidate("did:web:acme.com#x").await.unwrap();
        assert_eq!(chain.cache_len().await, 0);
        // Next resolve refetches.
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        assert_eq!(resolver.calls(), 2);
    }

    #[tokio::test]
    async fn bounded_capacity_evicts_oldest() {
        let resolver = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let chain = ResolverChain::with_capacity(
            vec![resolver.clone() as Arc<dyn AgentResolver>],
            3, // capacity 3
            Duration::from_secs(60),
        );
        for i in 0..5 {
            let _ = chain
                .resolve(&format!("did:web:acme.com#agent-{}", i))
                .await
                .unwrap();
            // Tiny sleep to make insertion timestamps strictly
            // ordered; without this the test is timing-flaky.
            tokio::time::sleep(Duration::from_millis(2)).await;
        }
        // After 5 inserts with capacity 3, only 3 remain.
        assert_eq!(chain.cache_len().await, 3);
        assert_eq!(resolver.calls(), 5);

        // The newest entry survived: resolving it is a cache hit.
        let _ = chain.resolve("did:web:acme.com#agent-4").await.unwrap();
        assert_eq!(resolver.calls(), 5, "newest entry must still be cached");

        // The oldest entry was evicted: resolving it fetches again.
        let _ = chain.resolve("did:web:acme.com#agent-0").await.unwrap();
        assert_eq!(resolver.calls(), 6, "oldest entry must have been evicted");
    }

    #[tokio::test]
    async fn multiple_resolvers_dispatch_by_scheme() {
        let r_web = Arc::new(CountingResolver::new(IdScheme::DidWeb));
        let r_key = Arc::new(CountingResolver::new(IdScheme::DidKey));
        let chain = ResolverChain::new(vec![
            r_web.clone() as Arc<dyn AgentResolver>,
            r_key.clone() as Arc<dyn AgentResolver>,
        ]);
        let _ = chain.resolve("did:web:acme.com#x").await.unwrap();
        let _ = chain.resolve("did:key:zabc").await.unwrap();
        assert_eq!(r_web.calls(), 1);
        assert_eq!(r_key.calls(), 1);
    }
}