ppoppo-sdk-core 0.2.0

Internal shared primitives for the Ppoppo SDK family (pas-external, pas-plims, pcs-external) — verifier port, audit trait, session liveness port, OIDC discovery, perimeter Bearer-auth Layer kit, identity types. Not a stable public API; do not depend on this crate directly. Consume the SDK crates that re-export from it (e.g. `pas-external`).
Documentation
//! JWT/token retention layer for Ppoppo client SDKs.
//!
//! [`TokenCache`] wraps a [`TokenSource`] and adds near-expiry refresh with
//! single-flight semantics: when multiple tasks call [`TokenCache::get`]
//! concurrently during an expiry window, only one performs the fetch; the
//! rest queue on the refresh lock and then reuse the result.
//!
//! ## Typical wiring
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! use ppoppo_sdk_core::token_cache::{
//!     ClientCredentialsSource, TokenCache, TokenCacheConfig,
//! };
//!
//! let source = ClientCredentialsSource::new(token_url, client_id, client_secret);
//! let cache = Arc::new(TokenCache::new(Box::new(source), TokenCacheConfig::default()));
//! let token: String = cache.get().await?;
//! ```

use std::sync::Arc;
use std::time::Duration;

use async_trait::async_trait;
use ppoppo_clock::ArcClock;
use ppoppo_clock::native::WallClock;
use tokio::sync::Mutex;

/// Errors returned by [`TokenCache`] and [`TokenSource`] implementations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TokenCacheError {
    /// The token could not be fetched; the payload describes the failure.
    #[error("token fetch failed: {0}")]
    Fetch(String),
    /// A response was received but could not be interpreted as a token.
    #[error("token response is malformed: {0}")]
    Malformed(String),
}

/// Object-safe async port for acquiring a fresh `(token, ttl)` pair.
///
/// Implement this for custom token acquisition strategies. The built-in
/// implementation is [`ClientCredentialsSource`].
#[async_trait]
pub trait TokenSource: Send + Sync {
    /// Fetch a fresh token. Returns the raw JWT string and its lifetime.
    /// The cache subtracts [`TokenCacheConfig::refresh_skew`] from the TTL
    /// before storing the expiry, so the token is treated as stale slightly
    /// before the server considers it expired.
    ///
    /// # Errors
    /// Returns a [`TokenCacheError`] when the token cannot be acquired or
    /// the response cannot be interpreted.
    async fn fetch_token(&self) -> Result<(String, Duration), TokenCacheError>;
}

/// Configuration for [`TokenCache`].
///
/// Construct via [`Default`] (60 s skew) and override fields as needed.
#[derive(Debug, Clone)]
pub struct TokenCacheConfig {
    /// How far before a token's actual expiry to treat it as stale.
    /// Defaults to 60 s — guards against clock skew and network latency on
    /// the refresh call.
    pub refresh_skew: Duration,
}

impl Default for TokenCacheConfig {
    fn default() -> Self {
        Self { refresh_skew: Duration::from_secs(60) }
    }
}

/// Mutable cache state guarded by [`TokenCache`]'s data lock.
struct CacheInner {
    /// `(token, expiry)` where `expiry` is a unix-epoch timestamp in
    /// milliseconds, already reduced by the configured refresh skew.
    /// `None` until the first successful fetch.
    cached: Option<(String, i64)>,
}

/// Thread-safe JWT cache with near-expiry refresh and single-flight semantics.
///
/// Wrap in `Arc<TokenCache>` and share across tasks.
pub struct TokenCache {
    /// Cached `(token, skew-adjusted expiry in unix millis)`; `None` before
    /// the first successful fetch.
    data: Mutex<CacheInner>,
    /// Serialises concurrent refreshes (single-flight pattern).
    /// Only the task that acquires this lock performs the HTTP round-trip;
    /// the rest wait, then double-check the cache before deciding to fetch.
    refresh: Mutex<()>,
    /// Skew settings applied when storing a fetched token's expiry.
    config: TokenCacheConfig,
    /// Upstream token acquisition strategy.
    source: Box<dyn TokenSource>,
    /// Time source — defaults to the wall clock, swappable via
    /// [`TokenCache::with_clock`] (e.g. for deterministic tests).
    clock: ArcClock,
}

impl TokenCache {
    /// Create a new cache backed by `source`, using the system wall clock.
    pub fn new(source: Box<dyn TokenSource>, config: TokenCacheConfig) -> Self {
        Self {
            data: Mutex::new(CacheInner { cached: None }),
            refresh: Mutex::new(()),
            config,
            source,
            clock: Arc::new(WallClock),
        }
    }

    /// Builder-style override of the time source (e.g. for deterministic
    /// tests with a controllable clock).
    #[must_use]
    pub fn with_clock(mut self, clock: ArcClock) -> Self {
        self.clock = clock;
        self
    }

    /// Return the cached token if it has not yet reached its (skew-adjusted)
    /// expiry. Acquires and releases the data lock; used for both the fast
    /// path and the post-lock double-check in [`TokenCache::get`].
    async fn fresh_token(&self) -> Option<String> {
        let inner = self.data.lock().await;
        match &inner.cached {
            Some((token, exp)) if self.clock.now_unix_millis() < *exp => Some(token.clone()),
            _ => None,
        }
    }

    /// Return the current token, refreshing via the source if near-expired.
    ///
    /// # Errors
    /// Propagates any [`TokenCacheError`] produced by the underlying
    /// [`TokenSource::fetch_token`] call.
    pub async fn get(&self) -> Result<String, TokenCacheError> {
        // Fast path: cache is fresh.
        if let Some(token) = self.fresh_token().await {
            return Ok(token);
        }

        // Slow path: acquire the refresh lock (serialises concurrent refreshes).
        let _refresh_guard = self.refresh.lock().await;

        // Double-check: a previous lock holder may have already refreshed
        // while we were queued.
        if let Some(token) = self.fresh_token().await {
            return Ok(token);
        }

        // We are the refresh leader. Fetch a fresh token.
        let (token, ttl) = self.source.fetch_token().await?;
        // Treat the token as stale `refresh_skew` before its real expiry.
        // saturating_sub guards against a TTL shorter than the skew; in that
        // case the token is stored but immediately considered stale.
        let skewed_ttl = ttl.saturating_sub(self.config.refresh_skew);
        let exp = self.clock.now_unix_millis() + skewed_ttl.as_millis() as i64;
        self.data.lock().await.cached = Some((token.clone(), exp));

        Ok(token)
    }
}

#[cfg(feature = "client-credentials")]
mod credentials;

#[cfg(feature = "client-credentials")]
pub use credentials::ClientCredentialsSource;

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    /// Token source that records how many times `fetch_token` runs and
    /// always hands back the same static token with a fixed TTL.
    struct CountingSource {
        count: Arc<AtomicUsize>,
        ttl: Duration,
        token: &'static str,
    }

    #[async_trait]
    impl TokenSource for CountingSource {
        async fn fetch_token(&self) -> Result<(String, Duration), TokenCacheError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            // Simulate a brief network hop so concurrent callers actually
            // overlap on the refresh lock.
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok((self.token.to_string(), self.ttl))
        }
    }

    #[tokio::test]
    async fn token_cache_returns_cached_until_near_expiry() {
        let fetches = Arc::new(AtomicUsize::new(0));
        let cache = TokenCache::new(
            Box::new(CountingSource {
                count: Arc::clone(&fetches),
                ttl: Duration::from_secs(3600),
                token: "tok-abc",
            }),
            TokenCacheConfig::default(),
        );

        let first = cache.get().await.unwrap();
        let second = cache.get().await.unwrap();

        assert_eq!(first, "tok-abc");
        assert_eq!(second, "tok-abc");
        assert_eq!(fetches.load(Ordering::SeqCst), 1, "source called more than once for valid cache");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn token_cache_single_flights_concurrent_refresh() {
        let fetches = Arc::new(AtomicUsize::new(0));
        let cache = Arc::new(TokenCache::new(
            Box::new(CountingSource {
                count: Arc::clone(&fetches),
                ttl: Duration::from_secs(3600),
                token: "tok-xyz",
            }),
            TokenCacheConfig::default(),
        ));

        let tasks: Vec<_> = (0..8)
            .map(|_| {
                let cache = Arc::clone(&cache);
                tokio::spawn(async move { cache.get().await })
            })
            .collect();
        for task in tasks {
            assert_eq!(task.await.unwrap().unwrap(), "tok-xyz");
        }

        // All 8 concurrent callers must share one fetch (single-flight).
        assert_eq!(fetches.load(Ordering::SeqCst), 1, "source called more than once (single-flight broken)");
    }
}