pas-external 3.1.0

Ppoppo Accounts System (PAS) external SDK -- OAuth2 PKCE, PASETO verification, Axum middleware, session liveness
Documentation
//! Break-glass `sv` claim validation (#005 spec).
//!
//! Consumers who accept Human-entity access tokens should validate the
//! `sv` claim against the PAS source of truth to enforce break-glass
//! revocation. Without this check, a token stolen before break-glass
//! would remain valid until its 1-hour TTL expiry.
//!
//! Architecture (cache-then-fetch, matches `paseto-sv-claim.md §R5`):
//!
//! 1. Validator gets `token_sv` from `VerifiedClaims::session_version()`.
//! 2. If `None` (legacy token, AI agent, delegated) → admit (R6 bypass).
//! 3. Look up `sv:{ppnum_id}` in a pluggable [`SessionVersionCache`] —
//!    default is the in-memory [`MemorySessionVersionCache`] (60 s TTL);
//!    consumers that already run KVRocks/Redis can plug in an adapter.
//! 4. Cache miss → [`SessionVersionFetcher`] does an HTTP GET on PAS
//!    `/oauth/userinfo` with the caller's own bearer token. Default
//!    implementation: [`HttpUserInfoFetcher`].
//! 5. Compare: `token_sv < fresh_sv` → reject with
//!    [`ValidateSvError::Stale`]; equal or greater → admit.
//!
//! Fail-closed on fetch failure: a transient DB / network outage surfaces
//! as [`ValidateSvError::Transient`] and the caller rejects the request.
//! Silently admitting on a transient failure would defeat break-glass: an
//! attacker who can force a cache miss during a DB / network blip would
//! skip the `sv` comparison entirely and keep using a revoked token.

use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use async_trait::async_trait;
use tokio::sync::RwLock;

use crate::oauth::AuthClient;

/// Prefix prepended to a `ppnum_id` to build the cache key (`sv:{ppnum_id}`).
/// Follows the same `name:` namespacing convention as the chat-auth and
/// is_admin caches.
pub const SV_CACHE_KEY_PREFIX: &str = "sv:";

/// Lifetime of a cached session version. Fixed at 60 s by
/// `paseto-sv-claim.md §R5` and intentionally not configurable.
pub const SV_CACHE_TTL: Duration = Duration::from_secs(60);

/// Pluggable store mapping `sv:{ppnum_id}` to the current session version.
///
/// [`MemorySessionVersionCache`] is the in-process default. Deployments
/// that already operate KVRocks/Redis can implement this trait themselves;
/// the contract is deliberately small:
///
/// - `get` yields `None` both on a genuine miss and on any transient
///   backend error. [`validate_sv`] treats either case as "go fetch".
/// - `set` is best-effort and must swallow its own failures. A lost write
///   only costs one extra fetch on the next validation.
#[async_trait]
pub trait SessionVersionCache: Send + Sync {
    /// Returns the cached session version for `key`, or `None` on a miss
    /// or backend error.
    async fn get(&self, key: &str) -> Option<i64>;

    /// Stores `sv` under `key` with a time-to-live hint of `ttl`.
    /// Implementations may apply their own fixed TTL instead. Best-effort:
    /// never reports failure.
    async fn set(&self, key: &str, sv: i64, ttl: Duration);
}

/// Authoritative read of the current session version, used on a cache miss.
///
/// [`HttpUserInfoFetcher`] (the SDK default) asks PAS `/oauth/userinfo`
/// using the caller's own bearer token. Consumers with direct database
/// access (for instance a monolith doing a cross-schema `SELECT`) can
/// implement this trait to skip the HTTP hop.
///
/// `bearer_token` is the access token whose `sv` claim is being validated.
/// It is always available where [`validate_sv`] is called, so passing it
/// through here avoids needing a second credential.
#[async_trait]
pub trait SessionVersionFetcher: Send + Sync {
    /// Returns the current session version for `ppnum_id`. On failure,
    /// returns a [`FetchError`], which [`validate_sv`] surfaces as
    /// [`ValidateSvError::Transient`].
    async fn fetch(&self, ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError>;
}

/// Failure reported by a [`SessionVersionFetcher`].
///
/// Carries a human-readable description. [`validate_sv`] wraps it in
/// [`ValidateSvError::Transient`] so that callers fail closed.
#[derive(Debug, thiserror::Error)]
#[error("session_version fetch failed: {0}")]
pub struct FetchError(pub String);

/// Reason a [`validate_sv`] check did not admit the token. Callers should
/// reject the request for either variant.
#[derive(Debug, thiserror::Error)]
pub enum ValidateSvError {
    /// The token's `sv` claim is lower than the current session version,
    /// typically because a break-glass bumped it after the token was issued.
    #[error("session_version stale: token_sv={token_sv} < current_sv={current_sv}")]
    Stale { token_sv: i64, current_sv: i64 },

    /// The current session version could not be determined: the cache
    /// missed and the fetcher then failed.
    #[error("session_version lookup transient failure: {0}")]
    Transient(FetchError),
}

/// In-memory [`SessionVersionCache`]. Default choice for SDK consumers.
///
/// Backed by a `tokio::sync::RwLock<HashMap<String, (sv, Instant)>>`.
/// Each entry records when it was written, and reads report entries older
/// than [`SV_CACHE_TTL`] as a miss.
///
/// Cloning is cheap and yields a handle to the *same* underlying map,
/// because the map lives behind an `Arc`. One cache can therefore be shared
/// by several validators or tasks.
///
/// The cache is per-process. Production consumers with many pods may want
/// to plug in a shared cache (Redis, KVRocks) so that a break-glass on one
/// pod converges on all pods within the same 60 s window.
#[derive(Debug, Clone)]
pub struct MemorySessionVersionCache {
    inner: Arc<RwLock<HashMap<String, (i64, Instant)>>>,
}

impl MemorySessionVersionCache {
    #[must_use]
    pub fn new() -> Self {
        Self {
            inner: Arc::new(RwLock::new(HashMap::new())),
        }
    }
}

impl Default for MemorySessionVersionCache {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl SessionVersionCache for MemorySessionVersionCache {
    /// Returns the cached `sv` for `key` if it was written less than
    /// [`SV_CACHE_TTL`] ago. Older entries are reported as a miss.
    async fn get(&self, key: &str) -> Option<i64> {
        let guard = self.inner.read().await;
        let (sv, written_at) = guard.get(key)?;
        if written_at.elapsed() >= SV_CACHE_TTL {
            return None;
        }
        Some(*sv)
    }

    /// Stores `sv` under `key`, stamped with the current time. Also drops
    /// every entry that has already outlived [`SV_CACHE_TTL`].
    ///
    /// `get` only *hides* expired entries. Without this sweep, the map would
    /// keep one entry per `ppnum_id` ever seen and grow without bound in a
    /// long-running process. The sweep is O(entries), but it only runs on
    /// the cache-miss path, which already pays for a fetch. After a sweep,
    /// the map holds only entries written within the last TTL window.
    async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
        // TTL is governed by the SV_CACHE_TTL constant; ignore the param
        // so callers can't accidentally drift this substrate's TTL away
        // from the contract.
        let mut guard = self.inner.write().await;
        let now = Instant::now();
        // Same freshness rule as `get`: keep an entry only while its age is
        // strictly below the TTL.
        guard.retain(|_, (_, written_at)| {
            now.saturating_duration_since(*written_at) < SV_CACHE_TTL
        });
        guard.insert(key.to_string(), (sv, now));
    }
}

/// [`SessionVersionFetcher`] that reads the session version from PAS
/// `/oauth/userinfo` through an [`AuthClient`].
///
/// The caller's bearer token authenticates the GET, and the value comes
/// from `response.session_version`. Any of the following yields a
/// [`FetchError`]:
///
/// - a transport failure;
/// - a non-2xx status;
/// - an unparseable body;
/// - a response without `session_version`.
///
/// A missing `session_version` signals that the caller is not a
/// Human-entity account. [`validate_sv`] only reaches the fetcher for
/// tokens that carried an `sv` claim, so that case is a protocol violation
/// rather than a reason to admit.
pub struct HttpUserInfoFetcher {
    client: Arc<AuthClient>,
}

impl HttpUserInfoFetcher {
    #[must_use]
    pub fn new(client: Arc<AuthClient>) -> Self {
        Self { client }
    }
}

#[async_trait]
impl SessionVersionFetcher for HttpUserInfoFetcher {
    /// Calls `/oauth/userinfo` with `bearer_token` and returns its
    /// `session_version`. `_ppnum_id` is unused: the endpoint identifies
    /// the account from the token itself.
    async fn fetch(&self, _ppnum_id: &str, bearer_token: &str) -> Result<i64, FetchError> {
        // `get_user_info` reports every failure mode (transport, 4xx/5xx,
        // parse) as an error. All of them map to FetchError, and the
        // caller fails closed.
        let info = self
            .client
            .get_user_info(bearer_token)
            .await
            .map_err(|e| FetchError(format!("userinfo call failed: {e}")))?;

        // A verified `sv` claim implies a Human-entity account, so a
        // missing field is a protocol violation, not an admit.
        info.session_version.ok_or_else(|| {
            let reason =
                "userinfo response omitted session_version — token may not be Human-entity";
            FetchError(reason.to_string())
        })
    }
}

/// Checks a token's `sv` claim against the current session version of
/// `ppnum_id`. The `cache` is consulted first and `fetcher` on a miss.
///
/// A token without an `sv` claim (`token_sv == None`) is admitted without
/// any lookup. On a cache miss, the fetched value is written back to `cache`
/// under `sv:{ppnum_id}` with [`SV_CACHE_TTL`] before the comparison.
///
/// See [module docs] for the full algorithm.
///
/// # Errors
///
/// - [`ValidateSvError::Stale`] when `token_sv` is lower than the current
///   session version.
/// - [`ValidateSvError::Transient`] when the cache missed and `fetcher`
///   failed. Nothing is written to the cache in that case.
///
/// [module docs]: self
pub async fn validate_sv(
    token_sv: Option<i64>,
    ppnum_id: &str,
    bearer_token: &str,
    cache: &dyn SessionVersionCache,
    fetcher: &dyn SessionVersionFetcher,
) -> Result<(), ValidateSvError> {
    // R6 legacy bypass: with no claim there is nothing to compare.
    // Exposure is bounded by the token TTL + refresh cycle.
    let claimed = match token_sv {
        Some(sv) => sv,
        None => return Ok(()),
    };

    let cache_key = format!("{SV_CACHE_KEY_PREFIX}{ppnum_id}");

    let current_sv = if let Some(cached) = cache.get(&cache_key).await {
        cached
    } else {
        // Cache miss: fetch errors fail closed (see module docs).
        let fetched = match fetcher.fetch(ppnum_id, bearer_token).await {
            Ok(sv) => sv,
            Err(e) => return Err(ValidateSvError::Transient(e)),
        };
        cache.set(&cache_key, fetched, SV_CACHE_TTL).await;
        fetched
    };

    if claimed >= current_sv {
        Ok(())
    } else {
        Err(ValidateSvError::Stale {
            token_sv: claimed,
            current_sv,
        })
    }
}

#[cfg(test)]
// Tests use `unwrap` on mutex locks and expected-Ok results; a panic here
// is the desired test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicU64, Ordering};

    const PPNUM_ID: &str = "01HZXY12345678901234567890";
    const BEARER: &str = "v4.public.placeholder";

    fn cache_key() -> String {
        format!("{SV_CACHE_KEY_PREFIX}{PPNUM_ID}")
    }

    struct MockFetcher {
        sv: i64,
        calls: AtomicU64,
    }

    #[async_trait]
    impl SessionVersionFetcher for MockFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(self.sv)
        }
    }

    struct FailingFetcher;

    #[async_trait]
    impl SessionVersionFetcher for FailingFetcher {
        async fn fetch(&self, _: &str, _: &str) -> Result<i64, FetchError> {
            Err(FetchError("simulated outage".to_string()))
        }
    }

    /// TTL-less cache so tests control hits/misses explicitly.
    #[derive(Default)]
    struct RawCache {
        store: Mutex<HashMap<String, i64>>,
    }

    #[async_trait]
    impl SessionVersionCache for RawCache {
        async fn get(&self, key: &str) -> Option<i64> {
            self.store.lock().unwrap().get(key).copied()
        }

        async fn set(&self, key: &str, sv: i64, _ttl: Duration) {
            self.store.lock().unwrap().insert(key.to_string(), sv);
        }
    }

    #[tokio::test]
    async fn admits_legacy_none() {
        let cache = RawCache::default();
        let fetcher = MockFetcher {
            sv: 999,
            calls: AtomicU64::new(0),
        };
        validate_sv(None, PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .expect("None must admit");
        assert_eq!(fetcher.calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn admits_current_cache_hit() {
        let cache = RawCache::default();
        cache.store.lock().unwrap().insert(cache_key(), 5);
        let fetcher = MockFetcher {
            sv: 999,
            calls: AtomicU64::new(0),
        };
        validate_sv(Some(5), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn admits_token_newer_than_cache() {
        // token_sv > current_sv must admit, e.g. a token minted after a
        // break-glass while the cache still holds the pre-bump value.
        let cache = RawCache::default();
        cache.store.lock().unwrap().insert(cache_key(), 5);
        let fetcher = MockFetcher {
            sv: 999,
            calls: AtomicU64::new(0),
        };
        validate_sv(Some(6), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn rejects_stale_cache_hit() {
        let cache = RawCache::default();
        cache.store.lock().unwrap().insert(cache_key(), 5);
        let fetcher = MockFetcher {
            sv: 999,
            calls: AtomicU64::new(0),
        };
        let err = validate_sv(Some(4), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();
        assert!(matches!(err, ValidateSvError::Stale { token_sv: 4, current_sv: 5 }));
    }

    #[tokio::test]
    async fn rejects_stale_after_fetch() {
        // Stale detection must also work on the cache-miss path, and the
        // fetched value is cached even when the token is rejected.
        let cache = RawCache::default();
        let fetcher = MockFetcher {
            sv: 3,
            calls: AtomicU64::new(0),
        };
        let err = validate_sv(Some(2), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();
        assert!(matches!(err, ValidateSvError::Stale { token_sv: 2, current_sv: 3 }));
        assert_eq!(fetcher.calls.load(Ordering::SeqCst), 1);
        assert_eq!(cache.store.lock().unwrap().get(&cache_key()), Some(&3));
    }

    #[tokio::test]
    async fn cache_miss_fetches_and_populates() {
        let cache = RawCache::default();
        let fetcher = MockFetcher {
            sv: 7,
            calls: AtomicU64::new(0),
        };
        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(fetcher.calls.load(Ordering::SeqCst), 1);
        assert_eq!(cache.store.lock().unwrap().get(&cache_key()), Some(&7));

        validate_sv(Some(7), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap();
        assert_eq!(
            fetcher.calls.load(Ordering::SeqCst),
            1,
            "second call must hit cache"
        );
    }

    #[tokio::test]
    async fn fetch_failure_surfaces_transient() {
        let cache = RawCache::default();
        let fetcher = FailingFetcher;
        let err = validate_sv(Some(1), PPNUM_ID, BEARER, &cache, &fetcher)
            .await
            .unwrap_err();
        assert!(matches!(err, ValidateSvError::Transient(_)));
        // A failed fetch must not leave anything behind that a later
        // validation could treat as a hit.
        assert!(cache.store.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn memory_cache_respects_ttl() {
        // Within-TTL reads hit and unknown keys miss. The expiry branch is
        // exercised separately by `memory_cache_treats_expired_entry_as_miss`.
        let cache = MemorySessionVersionCache::new();
        cache.set("sv:abc", 42, SV_CACHE_TTL).await;
        assert_eq!(cache.get("sv:abc").await, Some(42));
        assert_eq!(cache.get("sv:missing").await, None);
    }

    #[tokio::test]
    async fn memory_cache_treats_expired_entry_as_miss() {
        // We can't advance the clock, so backdate an entry by writing the
        // private map directly. `checked_sub` fails only if the monotonic
        // clock's origin is less than TTL ago; skip rather than flake then.
        let cache = MemorySessionVersionCache::new();
        let Some(expired_at) =
            Instant::now().checked_sub(SV_CACHE_TTL + Duration::from_secs(1))
        else {
            return;
        };
        cache
            .inner
            .write()
            .await
            .insert("sv:old".to_string(), (7, expired_at));
        assert_eq!(cache.get("sv:old").await, None);
    }
}