use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::RwLock;
use crate::oauth::AuthClient;
/// Prefix prepended to a ppnum id to build its session-version cache key,
/// e.g. `"sv:<ppnum_id>"` (see `validate_sv`).
pub const SV_CACHE_KEY_PREFIX: &str = "sv:";
/// Time-to-live that `validate_sv` passes to [`SessionVersionCache::set`]
/// when it stores a freshly fetched session version.
pub const SV_CACHE_TTL: Duration = Duration::from_secs(60);
/// Storage for the most recently observed session version per cache key.
///
/// `validate_sv` calls `get` first; on `None` it fetches a fresh value and
/// writes it back with `set`.
#[async_trait]
pub trait SessionVersionCache: Send + Sync {
    /// Returns the cached session version for `key`, or `None` when there is
    /// no usable entry (which makes `validate_sv` fetch from upstream).
    async fn get(&self, key: &str) -> Option<i64>;
    /// Stores `sv` under `key`. `ttl` is the lifetime the caller requests for
    /// the entry; NOTE(review): enforcement is up to the implementation and
    /// not every implementation honours it.
    async fn set(&self, key: &str, sv: i64, ttl: Duration);
}
/// Upstream source of a user's current session version, consulted by
/// `validate_sv` on a cache miss.
#[async_trait]
pub trait SessionVersionFetcher: Send + Sync {
    /// Fetches the current session version, given the user's `ppnum_id` and
    /// the caller's `bearer_token`. Implementations may use only one of them.
    ///
    /// # Errors
    /// Returns [`FetchError`] when no value can be obtained; `validate_sv`
    /// surfaces this as [`ValidateSvError::Transient`].
    async fn fetch(&self, ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError>;
}
/// Failure to obtain a session version from a [`SessionVersionFetcher`].
/// The payload is a human-readable description of what went wrong.
#[derive(Debug, thiserror::Error)]
#[error("session_version fetch failed: {0}")]
pub struct FetchError(pub String);
/// Reasons `validate_sv` can refuse to admit a token.
#[derive(Debug, thiserror::Error)]
pub enum ValidateSvError {
    /// The token's session version is lower than the current one, so
    /// `validate_sv` rejects it (presumably the session was revoked or
    /// rotated after the token was issued — confirm with the issuer).
    #[error("session_version stale: token_sv={token_sv} < current_sv={current_sv}")]
    Stale { token_sv: i64, current_sv: i64 },
    /// The current session version could not be determined: the cache had no
    /// entry and the fetch failed. Staleness is unknown in this case.
    #[error("session_version lookup transient failure: {0}")]
    Transient(FetchError),
}
/// In-process [`SessionVersionCache`] that keeps entries in a `HashMap` behind
/// a Tokio `RwLock`.
///
/// Each entry records the session version, the instant it was written and the
/// TTL passed to `set`; `get` reports `None` once that TTL has elapsed.
///
/// NOTE(review): expired entries are never removed, only overwritten by a later
/// `set` for the same key, so memory grows with the number of distinct keys
/// ever written.
#[derive(Debug)]
pub struct MemorySessionVersionCache {
    /// key -> (session version, written at, TTL requested by the writer)
    inner: Arc<RwLock<HashMap<String, (i64, Instant, Duration)>>>,
}

impl MemorySessionVersionCache {
    /// Creates an empty cache.
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

impl Default for MemorySessionVersionCache {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl SessionVersionCache for MemorySessionVersionCache {
    /// Returns the stored version for `key` if the entry exists and its TTL
    /// has not yet fully elapsed.
    async fn get(&self, key: &str) -> Option<i64> {
        let guard = self.inner.read().await;
        let &(sv, written_at, ttl) = guard.get(key)?;
        // Expiry uses the TTL recorded by `set` rather than a global constant,
        // so a caller asking for a shorter or longer lifetime actually gets it.
        (written_at.elapsed() < ttl).then_some(sv)
    }

    /// Stores (or replaces) the entry for `key`, stamping it with the current
    /// instant and the requested `ttl`.
    async fn set(&self, key: &str, sv: i64, ttl: Duration) {
        let mut guard = self.inner.write().await;
        guard.insert(key.to_string(), (sv, Instant::now(), ttl));
    }
}
/// [`SessionVersionFetcher`] that reads `session_version` from the userinfo
/// response returned by `AuthClient::get_user_info`.
///
/// NOTE(review): the version comes from whatever subject `bearer_token`
/// belongs to; `ppnum_id` is not consulted. `validate_sv` caches the result
/// under `ppnum_id`, so callers are assumed to pass a matching pair — confirm.
pub struct HttpUserInfoFetcher {
    client: Arc<AuthClient>,
}

impl HttpUserInfoFetcher {
    /// Wraps a shared auth client.
    #[must_use]
    pub fn new(client: Arc<AuthClient>) -> Self {
        Self { client }
    }
}

#[async_trait]
impl SessionVersionFetcher for HttpUserInfoFetcher {
    async fn fetch(&self, _ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError> {
        let user_info = self
            .client
            .get_user_info(bearer_token)
            .await
            .map_err(|err| FetchError(format!("userinfo call failed: {err}")))?;

        // The field is optional in the userinfo payload; without it there is
        // nothing to compare the token against.
        user_info.session_version.ok_or_else(|| {
            FetchError(
                "userinfo response omitted session_version — \
                 token may not be Human-entity"
                    .to_string(),
            )
        })
    }
}
/// Checks a token's session version against the current one for `ppnum_id`.
///
/// A token with no session-version claim (`token_sv == None`) is admitted
/// without touching `cache` or `fetcher`. Otherwise the current version is
/// read from `cache` under `"{SV_CACHE_KEY_PREFIX}{ppnum_id}"`. On a miss it is
/// fetched through `fetcher` (given `ppnum_id` and `bearer_token`) and written
/// back with [`SV_CACHE_TTL`]. The token is admitted when
/// `token_sv >= current_sv`.
///
/// # Errors
/// - [`ValidateSvError::Stale`] when `token_sv < current_sv`.
/// - [`ValidateSvError::Transient`] when the cache missed and the fetch failed;
///   nothing is written to the cache in that case.
pub async fn validate_sv(
    token_sv: Option<i64>,
    ppnum_id: &str,
    bearer_token: &str,
    cache: &dyn SessionVersionCache,
    fetcher: &dyn SessionVersionFetcher,
) -> Result<(), ValidateSvError> {
    let token_sv = match token_sv {
        Some(sv) => sv,
        None => return Ok(()),
    };

    let cache_key = format!("{SV_CACHE_KEY_PREFIX}{ppnum_id}");
    let current_sv = if let Some(cached) = cache.get(&cache_key).await {
        cached
    } else {
        let fetched = match fetcher.fetch(ppnum_id, bearer_token).await {
            Ok(value) => value,
            Err(err) => return Err(ValidateSvError::Transient(err)),
        };
        cache.set(&cache_key, fetched, SV_CACHE_TTL).await;
        fetched
    };

    if token_sv >= current_sv {
        Ok(())
    } else {
        Err(ValidateSvError::Stale {
            token_sv,
            current_sv,
        })
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)] // tests: a poisoned mutex or unexpected Err should fail loudly
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU64, Ordering};

    const PPNUM_ID: &str = "01HZXY12345678901234567890";
    const BEARER: &str = "v4.public.placeholder";

    /// The cache key `validate_sv` derives for `PPNUM_ID`.
    fn cache_key() -> String {
        format!("{SV_CACHE_KEY_PREFIX}{PPNUM_ID}")
    }

    /// Fetcher that always succeeds with a fixed `sv` and counts its calls.
    struct MockFetcher {
        sv: i64,
        calls: AtomicU64,
    }

    impl MockFetcher {
        fn returning(sv: i64) -> Self {
            Self {
                sv,
                calls: AtomicU64::new(0),
            }
        }

        fn call_count(&self) -> u64 {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl SessionVersionFetcher for MockFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.sv)
        }
    }

    /// Fetcher that always fails, standing in for an upstream outage.
    struct FailingFetcher;

    #[async_trait]
    impl SessionVersionFetcher for FailingFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            Err(FetchError("simulated outage".to_string()))
        }
    }

    /// Cache without any expiry, so each test controls exactly what is present.
    #[derive(Default)]
    struct RawCache {
        store: Mutex<HashMap<String, i64>>,
    }

    impl RawCache {
        /// A cache that already holds `sv` under `cache_key()`.
        fn seeded(sv: i64) -> Self {
            let cache = Self::default();
            cache.store.lock().unwrap().insert(cache_key(), sv);
            cache
        }

        /// The value currently stored under `cache_key()`, if any.
        fn stored(&self) -> Option<i64> {
            self.store.lock().unwrap().get(&cache_key()).copied()
        }
    }

    #[async_trait]
    impl SessionVersionCache for RawCache {
        async fn get(&self, key: &str) -> Option<i64> {
            self.store.lock().unwrap().get(key).copied()
        }

        async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
            self.store.lock().unwrap().insert(key.to_string(), sv);
        }
    }

    #[tokio::test]
    async fn admits_legacy_none() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(999);

        validate_sv(None, PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .expect("None must admit");

        assert_eq!(fetcher.call_count(), 0);
    }

    #[tokio::test]
    async fn admits_current_cache_hit() {
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);

        validate_sv(Some(5), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();

        assert_eq!(fetcher.call_count(), 0);
    }

    #[tokio::test]
    async fn rejects_stale_cache_hit() {
        let cache = RawCache::seeded(5);
        let fetcher = MockFetcher::returning(999);

        let err = validate_sv(Some(4), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();

        assert!(matches!(err, ValidateSvError::Stale { token_sv: 4, current_sv: 5 }));
    }

    #[tokio::test]
    async fn cache_miss_fetches_and_populates() {
        let cache = RawCache::default();
        let fetcher = MockFetcher::returning(7);

        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 1);
        assert_eq!(cache.stored(), Some(7));

        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.call_count(), 1, "second call must hit cache");
    }

    #[tokio::test]
    async fn fetch_failure_surfaces_transient() {
        let cache = RawCache::default();

        let err = validate_sv(Some(1), PPNUM_ID, BEARER, &cache, &FailingFetcher)
            .await
            .unwrap_err();

        assert!(matches!(err, ValidateSvError::Transient(_)));
    }

    /// Covers a fresh hit and a missing key; expiry itself is not exercised.
    #[tokio::test]
    async fn memory_cache_respects_ttl() {
        let cache = MemorySessionVersionCache::new();
        cache.set("sv:abc", 42, SV_CACHE_TTL).await;

        assert_eq!(cache.get("sv:abc").await, Some(42));
        assert_eq!(cache.get("sv:missing").await, None);
    }
}