use std::sync::Arc;
use tokio::sync::RwLock;
/// An access token string together with its absolute expiry time.
///
/// `expires_at` is compared against `now_secs()`, i.e. whole seconds since
/// the Unix epoch.
#[derive(Clone)]
pub struct AccessToken {
    /// The raw token value. Treated as a secret: it is never shown by `Debug`.
    pub token: String,
    /// Expiry instant, in whole seconds since the Unix epoch.
    pub expires_at: u64,
}

/// Hand-written instead of derived so that `{:?}` (logs, panic messages, or a
/// containing type's derived `Debug`) never prints the secret token value.
impl std::fmt::Debug for AccessToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AccessToken")
            .field("token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
impl AccessToken {
    /// Creates a token from its string value and an absolute expiry time
    /// (whole seconds since the Unix epoch, the same clock `now_secs` reads).
    pub fn new(token: impl Into<String>, expires_at: u64) -> Self {
        Self {
            token: token.into(),
            expires_at,
        }
    }

    /// Returns how many whole seconds remain until expiry, or `0` if the
    /// token has already expired.
    pub fn seconds_remaining(&self) -> u64 {
        self.expires_at.saturating_sub(now_secs())
    }

    /// Returns `true` if the token has already expired or will expire within
    /// `margin_secs` seconds from now.
    ///
    /// The check is inclusive: a token expiring exactly at `now + margin_secs`
    /// counts as expiring soon.
    pub fn expires_soon(&self, margin_secs: u64) -> bool {
        // Saturate rather than use `+`: a huge margin (e.g. `u64::MAX`) must mean
        // "always expiring soon", not a debug-build panic or a release-build
        // wraparound to a small value that would report the token as fresh.
        now_secs().saturating_add(margin_secs) >= self.expires_at
    }
}
/// Shared cache holding at most one [`AccessToken`], guarded by a `tokio`
/// read/write lock.
///
/// Cloning is cheap and yields a handle to the *same* slot, since the lock
/// lives behind an `Arc`.
#[derive(Clone, Debug)]
pub struct CachedToken {
    /// `None` when empty (initially, and after `clear`); `Some` after `set`.
    inner: Arc<RwLock<Option<AccessToken>>>,
}
/// Refresh margin used by `CachedToken::get`: a cached token with this many
/// seconds or fewer left before expiry is treated as unusable (5 minutes).
const MIN_REMAINING_SECS: u64 = 5 * 60;
impl Default for CachedToken {
fn default() -> Self {
Self::new()
}
}
impl CachedToken {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(None)),
}
}
pub async fn get(&self) -> Option<AccessToken> {
let guard = self.inner.read().await;
if let Some(ref tok) = *guard
&& !tok.expires_soon(MIN_REMAINING_SECS)
{
return Some(tok.clone());
}
None
}
pub async fn set(&self, token: AccessToken) {
let mut guard = self.inner.write().await;
*guard = Some(token);
}
pub async fn clear(&self) {
let mut guard = self.inner.write().await;
*guard = None;
}
}
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// Returns `0` instead of panicking if the system clock reports a time
/// before the epoch.
fn now_secs() -> u64 {
    std::time::UNIX_EPOCH
        .elapsed()
        .map_or(0, |since_epoch| since_epoch.as_secs())
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn token_seconds_remaining_future() {
        let now = now_secs();
        let tok = AccessToken::new("t", now + 3600);
        assert!(tok.seconds_remaining() > 3500);
    }

    #[test]
    fn token_seconds_remaining_past() {
        let tok = AccessToken::new("t", 0);
        assert_eq!(tok.seconds_remaining(), 0);
    }

    #[test]
    fn token_expires_soon_short_margin() {
        let now = now_secs();
        let tok = AccessToken::new("t", now + 100);
        assert!(tok.expires_soon(200));
        assert!(!tok.expires_soon(50));
    }

    #[test]
    fn token_expires_soon_inclusive_boundary() {
        // `expires_soon` compares with `>=`, so expiring exactly at
        // now + margin counts as "soon" (time only moves forward here).
        let now = now_secs();
        let tok = AccessToken::new("t", now + 100);
        assert!(tok.expires_soon(100));
    }

    #[test]
    fn token_already_expired_with_zero_margin() {
        let tok = AccessToken::new("t", 0);
        assert!(tok.expires_soon(0));
    }

    #[tokio::test]
    async fn cached_token_empty() {
        let cache = CachedToken::new();
        assert!(cache.get().await.is_none());
    }

    #[tokio::test]
    async fn cached_token_valid_token() {
        let now = now_secs();
        let cache = CachedToken::new();
        cache.set(AccessToken::new("abc", now + 3600)).await;
        let tok = cache.get().await.expect("should have token");
        assert_eq!(tok.token, "abc");
    }

    #[tokio::test]
    async fn cached_token_expiring_within_margin_not_returned() {
        // Not yet expired, but inside MIN_REMAINING_SECS, so `get` hides it.
        let cache = CachedToken::new();
        let now = now_secs();
        cache.set(AccessToken::new("stale", now + 1)).await;
        assert!(cache.get().await.is_none());
    }

    #[tokio::test]
    async fn cached_token_past_expiry_not_returned() {
        let cache = CachedToken::new();
        cache.set(AccessToken::new("old", 0)).await;
        assert!(cache.get().await.is_none());
    }

    #[tokio::test]
    async fn cached_token_just_outside_margin_returned() {
        let now = now_secs();
        let cache = CachedToken::new();
        cache
            .set(AccessToken::new("fresh", now + MIN_REMAINING_SECS + 60))
            .await;
        let tok = cache.get().await.expect("outside margin should be usable");
        assert_eq!(tok.token, "fresh");
    }

    #[tokio::test]
    async fn cached_token_set_replaces_previous() {
        let now = now_secs();
        let cache = CachedToken::new();
        cache.set(AccessToken::new("first", now + 3600)).await;
        cache.set(AccessToken::new("second", now + 3600)).await;
        let tok = cache.get().await.expect("should have token");
        assert_eq!(tok.token, "second");
    }

    #[tokio::test]
    async fn cached_token_clone_shares_state() {
        // Clones share the same Arc'd slot, so a write through one handle is
        // visible through the other.
        let now = now_secs();
        let cache = CachedToken::new();
        let handle = cache.clone();
        handle.set(AccessToken::new("shared", now + 3600)).await;
        let tok = cache.get().await.expect("clone should share state");
        assert_eq!(tok.token, "shared");
    }

    #[tokio::test]
    async fn cached_token_clear() {
        let now = now_secs();
        let cache = CachedToken::new();
        cache.set(AccessToken::new("x", now + 3600)).await;
        cache.clear().await;
        assert!(cache.get().await.is_none());
    }
}