Skip to main content

pas_external/
session_version.rs

//! Break-glass `sv` claim validation (#005 spec).
//!
//! Consumers who accept Human-entity access tokens should validate the
//! `sv` claim against the PAS source of truth to enforce break-glass
//! revocation. Without this check, a token stolen before break-glass
//! would remain valid until its 1-hour TTL expiry.
//!
//! Architecture (cache-then-fetch, matches `paseto-sv-claim.md §R5`):
//!
//! 1. Validator gets `token_sv` from `VerifiedClaims::session_version()`.
//! 2. If `None` (legacy token, AI agent, delegated) → admit (R6 bypass).
//! 3. Look up `sv:{ppnum_id}` in a pluggable [`SessionVersionCache`] —
//!    default is the in-memory [`MemorySessionVersionCache`] (60 s TTL);
//!    consumers that already run KVRocks/Redis can plug in an adapter.
//! 4. Cache miss → a [`SessionVersionFetcher`] reads the current value and
//!    the result is written back to the cache. The default implementation,
//!    [`HttpUserInfoFetcher`], does an HTTP GET on PAS `/oauth/userinfo`
//!    with the caller's own bearer token.
//! 5. Compare against the current value (cached or freshly fetched):
//!    `token_sv < current_sv` → reject with [`ValidateSvError::Stale`];
//!    equal or greater → admit.
//!
//! Fail-closed on fetch failure: a transient DB / network outage surfaces
//! as [`ValidateSvError::Transient`] and the caller rejects the request.
//! Silently admitting on a transient failure would defeat break-glass,
//! because an attacker who could force a cache eviction plus a DB blip
//! would skip the check entirely.

26use std::collections::HashMap;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use async_trait::async_trait;
31use tokio::sync::RwLock;
32
33use crate::oauth::AuthClient;
34
/// Prefix applied to every session-version cache key (`sv:{ppnum_id}`).
/// Kept identical to the chat-auth and is_admin cache namespaces.
pub const SV_CACHE_KEY_PREFIX: &str = "sv:";

/// Cache lifetime mandated by `paseto-sv-claim.md §R5`: 60 seconds.
/// Deliberately not configurable.
pub const SV_CACHE_TTL: Duration = Duration::from_secs(60);
40
41/// Cache abstraction for `sv:{ppnum_id}` lookups.
42///
43/// Default implementation is [`MemorySessionVersionCache`]. Consumers
44/// that already run KVRocks/Redis can write their own adapter — the
45/// `get` / `set` contract is minimal.
46///
47/// `get` returns `None` on cache miss OR any transient backend error.
48/// `set` is best-effort and swallows failures internally (a failed set
49/// only costs us one extra fetch on the next validate).
50#[async_trait]
51pub trait SessionVersionCache: Send + Sync {
52    async fn get(&self, key: &str) -> Option<i64>;
53    async fn set(&self, key: &str, sv: i64, ttl: Duration);
54}
55
56/// Fresh-read source for the cache-miss path.
57///
58/// The SDK's default implementation ([`HttpUserInfoFetcher`]) calls
59/// `/oauth/userinfo` with the caller's own bearer token. Consumers with
60/// direct DB access (e.g., cross-schema `SELECT` in a monolith) can write
61/// an adapter that bypasses HTTP.
62///
63/// `bearer_token` is the ACCESS TOKEN whose `sv` claim is being
64/// validated — it's always in hand at the validate call site, so
65/// threading it through the trait avoids a second auth layer.
66#[async_trait]
67pub trait SessionVersionFetcher: Send + Sync {
68    async fn fetch(&self, ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError>;
69}
70
71#[derive(Debug, thiserror::Error)]
72#[error("session_version fetch failed: {0}")]
73pub struct FetchError(pub String);
74
75#[derive(Debug, thiserror::Error)]
76pub enum ValidateSvError {
77    #[error("session_version stale: token_sv={token_sv} < current_sv={current_sv}")]
78    Stale { token_sv: i64, current_sv: i64 },
79
80    #[error("session_version lookup transient failure: {0}")]
81    Transient(FetchError),
82}
83
84/// In-memory [`SessionVersionCache`]. Default choice for SDK consumers.
85///
86/// `tokio::sync::RwLock<HashMap<String, (sv, Instant)>>` with lazy
87/// eviction on read: entries past their TTL are treated as miss.
88/// Production consumers with many pods may want to plug in a shared
89/// cache (Redis, KVRocks) so a break-glass on one pod converges on all
90/// pods within the same 60 s window; the in-memory default is per-pod.
91pub struct MemorySessionVersionCache {
92    inner: Arc<RwLock<HashMap<String, (i64, Instant)>>>,
93}
94
95impl MemorySessionVersionCache {
96    #[must_use]
97    pub fn new() -> Self {
98        Self {
99            inner: Arc::new(RwLock::new(HashMap::new())),
100        }
101    }
102}
103
104impl Default for MemorySessionVersionCache {
105    fn default() -> Self {
106        Self::new()
107    }
108}
109
110#[async_trait]
111impl SessionVersionCache for MemorySessionVersionCache {
112    async fn get(&self, key: &str) -> Option<i64> {
113        let guard = self.inner.read().await;
114        let (sv, written_at) = guard.get(key)?;
115        if written_at.elapsed() >= SV_CACHE_TTL {
116            return None;
117        }
118        Some(*sv)
119    }
120
121    async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
122        // TTL is governed by the SV_CACHE_TTL constant; ignore the param
123        // so callers can't accidentally drift this substrate's TTL away
124        // from the contract.
125        let mut guard = self.inner.write().await;
126        guard.insert(key.to_string(), (sv, Instant::now()));
127    }
128}
129
130/// [`SessionVersionFetcher`] backed by an [`AuthClient`]'s userinfo call.
131///
132/// Uses the caller's bearer token to GET `/oauth/userinfo` and reads
133/// `response.session_version`. Returns [`FetchError`] if the HTTP call
134/// fails, the endpoint returns non-2xx, the response is unparseable, or
135/// the response omits `session_version` (which signals the caller is
136/// not a Human-entity account — but the validator should only ever
137/// call the fetcher for tokens that carried a `sv` claim in the first
138/// place, so `None` here is a protocol violation).
139pub struct HttpUserInfoFetcher {
140    client: Arc<AuthClient>,
141}
142
143impl HttpUserInfoFetcher {
144    #[must_use]
145    pub fn new(client: Arc<AuthClient>) -> Self {
146        Self { client }
147    }
148}
149
150#[async_trait]
151impl SessionVersionFetcher for HttpUserInfoFetcher {
152    async fn fetch(&self, _ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError> {
153        // `get_user_info` surfaces all failure modes (transport, 4xx/5xx,
154        // parse) via `Error::OAuth { operation, status, detail }`. Any
155        // Err here is treated as transient — fail-closed at the caller.
156        match self.client.get_user_info(bearer_token).await {
157            Ok(info) => info.session_version.ok_or_else(|| {
158                FetchError(
159                    "userinfo response omitted session_version — \
160                     token may not be Human-entity"
161                        .to_string(),
162                )
163            }),
164            Err(e) => Err(FetchError(format!("userinfo call failed: {e}"))),
165        }
166    }
167}
168
169/// Validates a token's `sv` claim against the cached / fresh current
170/// value.
171///
172/// See [module docs] for the full algorithm.
173///
174/// [module docs]: self
175pub async fn validate_sv(
176    token_sv: Option<i64>,
177    ppnum_id: &str,
178    bearer_token: &str,
179    cache: &dyn SessionVersionCache,
180    fetcher: &dyn SessionVersionFetcher,
181) -> Result<(), ValidateSvError> {
182    let Some(token_sv) = token_sv else {
183        // R6 legacy bypass. Bounded by token TTL + refresh cycle.
184        return Ok(());
185    };
186
187    let key = format!("{SV_CACHE_KEY_PREFIX}{ppnum_id}");
188
189    let current_sv = match cache.get(&key).await {
190        Some(v) => v,
191        None => {
192            let fresh = fetcher
193                .fetch(ppnum_id, bearer_token)
194                .await
195                .map_err(ValidateSvError::Transient)?;
196            cache.set(&key, fresh, SV_CACHE_TTL).await;
197            fresh
198        }
199    };
200
201    if token_sv < current_sv {
202        return Err(ValidateSvError::Stale {
203            token_sv,
204            current_sv,
205        });
206    }
207    Ok(())
208}
209
#[cfg(test)]
// Test fixtures unwrap freely: a panic is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU64, Ordering};

    const PPNUM_ID: &str = "01HZXY12345678901234567890";
    const BEARER: &str = "v4.public.placeholder";

    fn cache_key() -> String {
        format!("{SV_CACHE_KEY_PREFIX}{PPNUM_ID}")
    }

    /// Fetcher that always succeeds with `sv` and counts its invocations.
    struct MockFetcher {
        sv: i64,
        calls: AtomicU64,
    }

    impl MockFetcher {
        fn returning(sv: i64) -> Self {
            Self {
                sv,
                calls: AtomicU64::new(0),
            }
        }

        fn call_count(&self) -> u64 {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl SessionVersionFetcher for MockFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.sv)
        }
    }

    /// Fetcher that simulates a PAS / network outage.
    struct FailingFetcher;

    #[async_trait]
    impl SessionVersionFetcher for FailingFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            Err(FetchError("simulated outage".to_string()))
        }
    }

    /// TTL-less cache whose contents tests can seed and inspect directly.
    #[derive(Default)]
    struct RawCache {
        store: Mutex<HashMap<String, i64>>,
    }

    impl RawCache {
        fn seeded(sv: i64) -> Self {
            let cache = Self::default();
            cache.store.lock().unwrap().insert(cache_key(), sv);
            cache
        }

        fn stored(&self) -> Option<i64> {
            self.store.lock().unwrap().get(&cache_key()).copied()
        }
    }

    #[async_trait]
    impl SessionVersionCache for RawCache {
        async fn get(&self, key: &str) -> Option<i64> {
            self.store.lock().unwrap().get(key).copied()
        }

        async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
            self.store.lock().unwrap().insert(key.to_string(), sv);
        }
    }

    #[tokio::test]
    async fn admits_legacy_none() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(999);
        validate_sv(None, PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .expect("None must admit");
        assert_eq!(fetcher.call_count(), 0);
        assert_eq!(cache.stored(), None, "bypass must not touch the cache");
    }

    #[tokio::test]
    async fn admits_current_cache_hit() {
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);
        validate_sv(Some(5), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 0);
    }

    #[tokio::test]
    async fn admits_newer_token_than_cache() {
        // Spec: "equal or greater → admit".
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);
        validate_sv(Some(6), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 0);
        assert_eq!(cache.stored(), Some(5));
    }

    #[tokio::test]
    async fn rejects_stale_cache_hit() {
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);
        let err = validate_sv(Some(4), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();
        assert!(matches!(err, ValidateSvError::Stale { token_sv: 4, current_sv: 5 }));
        assert_eq!(fetcher.call_count(), 0, "a cache hit must not fetch");
    }

    #[tokio::test]
    async fn cache_miss_fetches_and_populates() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(7);
        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 1);
        assert_eq!(cache.stored(), Some(7));

        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 1, "second call must hit cache");
    }

    #[tokio::test]
    async fn cache_miss_rejects_stale_fresh_value() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(7);
        let err = validate_sv(Some(6), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();
        assert!(matches!(err, ValidateSvError::Stale { token_sv: 6, current_sv: 7 }));
        assert_eq!(fetcher.call_count(), 1);
        assert_eq!(cache.stored(), Some(7), "fresh value is cached even on reject");
    }

    #[tokio::test]
    async fn fetch_failure_surfaces_transient() {
        let cache = RawCache::default();
        let fetcher = FailingFetcher;
        let err = validate_sv(Some(1), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();
        assert!(matches!(err, ValidateSvError::Transient(_)));
        assert_eq!(cache.stored(), None, "a failed fetch must not populate the cache");
    }

    #[tokio::test]
    async fn memory_cache_respects_ttl() {
        // Exercises MemorySessionVersionCache's within-TTL read path. The
        // wall clock can't be advanced here, so the expiry branch is covered
        // by construction (elapsed >= SV_CACHE_TTL → None).
        let cache = MemorySessionVersionCache::new();
        cache.set("sv:abc", 42, SV_CACHE_TTL).await;
        assert_eq!(cache.get("sv:abc").await, Some(42));
        assert_eq!(cache.get("sv:missing").await, None);
    }

    #[tokio::test]
    async fn memory_cache_overwrites_existing_key() {
        let cache = MemorySessionVersionCache::new();
        cache.set("sv:abc", 42, SV_CACHE_TTL).await;
        cache.set("sv:abc", 43, SV_CACHE_TTL).await;
        assert_eq!(cache.get("sv:abc").await, Some(43));
    }
}