use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use ppoppo_clock::ArcClock;
use ppoppo_clock::native::WallClock;
use tokio::sync::Mutex;
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum TokenCacheError {
#[error("token fetch failed: {0}")]
Fetch(String),
#[error("token response is malformed: {0}")]
Malformed(String),
}
/// A provider of tokens that [`TokenCache`] consults whenever its cached token
/// is missing or stale.
#[async_trait]
pub trait TokenSource: Sync + Send {
    /// Obtains a fresh token.
    ///
    /// Returns `(token, ttl)`. The cache measures `ttl` from the moment this
    /// call returns and subtracts the configured refresh skew from it to decide
    /// when to fetch again.
    ///
    /// # Errors
    ///
    /// Returns a [`TokenCacheError`] when no token can be obtained.
    async fn fetch_token(&self) -> Result<(String, Duration), TokenCacheError>;
}
/// Tuning knobs for [`TokenCache`].
// `PartialEq`/`Eq` let callers compare configurations. `Copy` is intentionally
// not derived so existing `.clone()` calls do not trip `clippy::clone_on_copy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenCacheConfig {
    /// How long before a token's reported expiry the cache treats it as stale
    /// and fetches a replacement.
    ///
    /// If a token's TTL is not longer than this value, the token is returned to
    /// the caller that fetched it but is considered stale from then on.
    pub refresh_skew: Duration,
}

impl Default for TokenCacheConfig {
    /// Refresh tokens 60 seconds before they expire.
    fn default() -> Self {
        Self {
            refresh_skew: Duration::from_secs(60),
        }
    }
}
/// Mutable state kept behind `TokenCache`'s data mutex.
struct CacheInner {
    /// The last successfully fetched token paired with its expiry in unix
    /// milliseconds (already reduced by the refresh skew); `None` until the
    /// first successful fetch.
    cached: Option<(String, i64)>,
}
/// Caches a single token obtained from a [`TokenSource`] and refetches it once
/// the cached copy reaches its skew-adjusted expiry.
///
/// Refreshes are funneled through the `refresh` lock, so `get` calls the source
/// while holding that lock and at most one fetch runs at a time per cache.
pub struct TokenCache {
    /// Cached token and its expiry. `get` holds this lock only for short
    /// reads/writes, never across the call to the source.
    data: Mutex<CacheInner>,
    /// Held for the whole refresh path in `get`, serializing fetches.
    refresh: Mutex<()>,
    /// Refresh tuning (currently the refresh skew).
    config: TokenCacheConfig,
    /// Where new tokens come from on a cache miss.
    source: Box<dyn TokenSource>,
    /// Time source for expiry checks; the wall clock unless replaced via `with_clock`.
    clock: ArcClock,
}
impl TokenCache {
pub fn new(source: Box<dyn TokenSource>, config: TokenCacheConfig) -> Self {
Self {
data: Mutex::new(CacheInner { cached: None }),
refresh: Mutex::new(()),
config,
source,
clock: Arc::new(WallClock),
}
}
#[must_use]
pub fn with_clock(mut self, clock: ArcClock) -> Self {
self.clock = clock;
self
}
pub async fn get(&self) -> Result<String, TokenCacheError> {
{
let inner = self.data.lock().await;
if let Some((token, exp)) = &inner.cached {
if self.clock.now_unix_millis() < *exp {
return Ok(token.clone());
}
}
}
let _refresh_guard = self.refresh.lock().await;
{
let inner = self.data.lock().await;
if let Some((token, exp)) = &inner.cached {
if self.clock.now_unix_millis() < *exp {
return Ok(token.clone());
}
}
}
let (token, ttl) = self.source.fetch_token().await?;
let skewed_ttl = ttl.saturating_sub(self.config.refresh_skew);
let exp = self.clock.now_unix_millis() + skewed_ttl.as_millis() as i64;
self.data.lock().await.cached = Some((token.clone(), exp));
Ok(token)
}
}
#[cfg(feature = "client-credentials")]
mod credentials;
#[cfg(feature = "client-credentials")]
pub use credentials::ClientCredentialsSource;
#[cfg(test)]
// Tests should fail loudly on unexpected `Err`/`None`, so panicking unwraps are fine here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    };

    /// Test double that hands out a fixed token and TTL and records how many
    /// times it was asked to fetch.
    struct CountingSource {
        count: Arc<AtomicUsize>,
        ttl: Duration,
        token: &'static str,
    }

    #[async_trait]
    impl TokenSource for CountingSource {
        async fn fetch_token(&self) -> Result<(String, Duration), TokenCacheError> {
            self.count.fetch_add(1, Ordering::SeqCst);
            // Simulated latency widens the window in which concurrent callers
            // can race into the refresh path.
            tokio::time::sleep(Duration::from_millis(10)).await;
            Ok((self.token.to_string(), self.ttl))
        }
    }

    /// Builds a source issuing `token` with a one-hour TTL, plus the shared
    /// counter it increments on every fetch.
    fn hour_long_source(token: &'static str) -> (Arc<AtomicUsize>, CountingSource) {
        let calls = Arc::new(AtomicUsize::new(0));
        let source = CountingSource {
            count: Arc::clone(&calls),
            ttl: Duration::from_secs(3600),
            token,
        };
        (calls, source)
    }

    #[tokio::test]
    async fn token_cache_returns_cached_until_near_expiry() {
        let (calls, source) = hour_long_source("tok-abc");
        let cache = TokenCache::new(Box::new(source), TokenCacheConfig::default());

        let first = cache.get().await.unwrap();
        let second = cache.get().await.unwrap();

        assert_eq!(first, "tok-abc");
        assert_eq!(second, "tok-abc");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "source called more than once for valid cache"
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn token_cache_single_flights_concurrent_refresh() {
        let (calls, source) = hour_long_source("tok-xyz");
        let cache = Arc::new(TokenCache::new(Box::new(source), TokenCacheConfig::default()));

        let handles: Vec<_> = (0..8)
            .map(|_| {
                let cache = Arc::clone(&cache);
                tokio::spawn(async move { cache.get().await })
            })
            .collect();

        for handle in handles {
            assert_eq!(handle.await.unwrap().unwrap(), "tok-xyz");
        }
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "source called more than once (single-flight broken)"
        );
    }
}