1use std::collections::HashMap;
27use std::sync::Arc;
28use std::time::{Duration, Instant};
29
30use async_trait::async_trait;
31use tokio::sync::RwLock;
32
33use crate::oauth::AuthClient;
34
/// Prefix prepended to a `ppnum_id` to build the cache key used by [`validate_sv`],
/// i.e. keys look like `sv:<ppnum_id>`.
pub const SV_CACHE_KEY_PREFIX: &str = "sv:";

/// Freshness window [`validate_sv`] requests when it caches a freshly fetched
/// `session_version` (one minute).
pub const SV_CACHE_TTL: Duration = Duration::from_secs(60);
40
41#[async_trait]
51pub trait SessionVersionCache: Send + Sync {
52 async fn get(&self, key: &str) -> Option<i64>;
53 async fn set(&self, key: &str, sv: i64, ttl: Duration);
54}
55
56#[async_trait]
67pub trait SessionVersionFetcher: Send + Sync {
68 async fn fetch(&self, ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError>;
69}
70
71#[derive(Debug, thiserror::Error)]
72#[error("session_version fetch failed: {0}")]
73pub struct FetchError(pub String);
74
75#[derive(Debug, thiserror::Error)]
76pub enum ValidateSvError {
77 #[error("session_version stale: token_sv={token_sv} < current_sv={current_sv}")]
78 Stale { token_sv: i64, current_sv: i64 },
79
80 #[error("session_version lookup transient failure: {0}")]
81 Transient(FetchError),
82}
83
84pub struct MemorySessionVersionCache {
92 inner: Arc<RwLock<HashMap<String, (i64, Instant)>>>,
93}
94
95impl MemorySessionVersionCache {
96 #[must_use]
97 pub fn new() -> Self {
98 Self {
99 inner: Arc::new(RwLock::new(HashMap::new())),
100 }
101 }
102}
103
104impl Default for MemorySessionVersionCache {
105 fn default() -> Self {
106 Self::new()
107 }
108}
109
110#[async_trait]
111impl SessionVersionCache for MemorySessionVersionCache {
112 async fn get(&self, key: &str) -> Option<i64> {
113 let guard = self.inner.read().await;
114 let (sv, written_at) = guard.get(key)?;
115 if written_at.elapsed() >= SV_CACHE_TTL {
116 return None;
117 }
118 Some(*sv)
119 }
120
121 async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
122 let mut guard = self.inner.write().await;
126 guard.insert(key.to_string(), (sv, Instant::now()));
127 }
128}
129
130pub struct HttpUserInfoFetcher {
140 client: Arc<AuthClient>,
141}
142
143impl HttpUserInfoFetcher {
144 #[must_use]
145 pub fn new(client: Arc<AuthClient>) -> Self {
146 Self { client }
147 }
148}
149
150#[async_trait]
151impl SessionVersionFetcher for HttpUserInfoFetcher {
152 async fn fetch(&self, _ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError> {
153 match self.client.get_user_info(bearer_token).await {
157 Ok(info) => info.session_version.ok_or_else(|| {
158 FetchError(
159 "userinfo response omitted session_version — \
160 token may not be Human-entity"
161 .to_string(),
162 )
163 }),
164 Err(e) => Err(FetchError(format!("userinfo call failed: {e}"))),
165 }
166 }
167}
168
169pub async fn validate_sv(
176 token_sv: Option<i64>,
177 ppnum_id: &str,
178 bearer_token: &str,
179 cache: &dyn SessionVersionCache,
180 fetcher: &dyn SessionVersionFetcher,
181) -> Result<(), ValidateSvError> {
182 let Some(token_sv) = token_sv else {
183 return Ok(());
185 };
186
187 let key = format!("{SV_CACHE_KEY_PREFIX}{ppnum_id}");
188
189 let current_sv = match cache.get(&key).await {
190 Some(v) => v,
191 None => {
192 let fresh = fetcher
193 .fetch(ppnum_id, bearer_token)
194 .await
195 .map_err(ValidateSvError::Transient)?;
196 cache.set(&key, fresh, SV_CACHE_TTL).await;
197 fresh
198 }
199 };
200
201 if token_sv < current_sv {
202 return Err(ValidateSvError::Stale {
203 token_sv,
204 current_sv,
205 });
206 }
207 Ok(())
208}
209
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU64, Ordering};

    const PPNUM_ID: &str = "01HZXY12345678901234567890";
    const BEARER: &str = "v4.public.placeholder";

    /// Cache key `validate_sv` derives for `PPNUM_ID`.
    fn cache_key() -> String {
        format!("{SV_CACHE_KEY_PREFIX}{PPNUM_ID}")
    }

    /// Fetcher that always succeeds with `sv` and counts how often it was asked.
    struct MockFetcher {
        sv: i64,
        calls: AtomicU64,
    }

    impl MockFetcher {
        fn returning(sv: i64) -> Self {
            Self {
                sv,
                calls: AtomicU64::new(0),
            }
        }

        fn call_count(&self) -> u64 {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl SessionVersionFetcher for MockFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.sv)
        }
    }

    /// Fetcher that always reports an outage.
    struct FailingFetcher;

    #[async_trait]
    impl SessionVersionFetcher for FailingFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            Err(FetchError("simulated outage".to_string()))
        }
    }

    /// Cache with no expiry, so each test controls exactly what is "cached".
    #[derive(Default)]
    struct RawCache {
        store: Mutex<HashMap<String, i64>>,
    }

    impl RawCache {
        /// A cache pre-populated with `sv` under `cache_key()`.
        fn seeded(sv: i64) -> Self {
            let cache = Self::default();
            cache.store.lock().unwrap().insert(cache_key(), sv);
            cache
        }
    }

    #[async_trait]
    impl SessionVersionCache for RawCache {
        async fn get(&self, key: &str) -> Option<i64> {
            self.store.lock().unwrap().get(key).copied()
        }

        async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
            self.store.lock().unwrap().insert(key.to_string(), sv);
        }
    }

    #[tokio::test]
    async fn admits_legacy_none() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(999);
        validate_sv(None, PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .expect("None must admit");
        assert_eq!(fetcher.call_count(), 0);
    }

    #[tokio::test]
    async fn admits_current_cache_hit() {
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);
        validate_sv(Some(5), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 0);
    }

    #[tokio::test]
    async fn rejects_stale_cache_hit() {
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);
        let outcome = validate_sv(Some(4), PPNUM_ID, BEARER, &cache, &fetcher).await;
        assert!(matches!(
            outcome,
            Err(ValidateSvError::Stale {
                token_sv: 4,
                current_sv: 5
            })
        ));
    }

    #[tokio::test]
    async fn cache_miss_fetches_and_populates() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(7);

        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 1);
        assert_eq!(cache.store.lock().unwrap().get(&cache_key()), Some(&7));

        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 1, "second call must hit cache");
    }

    #[tokio::test]
    async fn fetch_failure_surfaces_transient() {
        let cache = RawCache::default();
        let outcome = validate_sv(Some(1), PPNUM_ID, BEARER, &cache, &FailingFetcher).await;
        assert!(matches!(outcome, Err(ValidateSvError::Transient(_))));
    }

    #[tokio::test]
    async fn memory_cache_respects_ttl() {
        let cache = MemorySessionVersionCache::new();
        cache.set("sv:abc", 42, SV_CACHE_TTL).await;
        assert_eq!(cache.get("sv:abc").await, Some(42));
        assert_eq!(cache.get("sv:missing").await, None);
    }
}
347}