use std::num::NonZeroUsize;
use std::sync::Arc;
use lru::LruCache;
use serde::Deserialize;
use crate::prelude::*;
/// Seconds subtracted from a token's `exp` claim when computing how long it may be
/// served from the cache, so a cached token is not handed out right before it lapses.
const SAFETY_MARGIN_SECS: i64 = 60;

/// Maximum number of entries held by a cache built with `ProxyTokenCache::new`.
///
/// The literal is non-zero, so the `NonZeroUsize::MIN` fallback is never taken; it only
/// exists so the constant can be built without `unwrap` in a `const` context.
const DEFAULT_CAPACITY: NonZeroUsize = if let Some(cap) = NonZeroUsize::new(256) {
    cap
} else {
    NonZeroUsize::MIN
};
/// One cache entry: an access token plus the instant after which it must no longer
/// be served from the cache.
///
/// NOTE(review): the derived `Debug` prints the raw `token`; be careful not to log
/// this type (or the cache) at a level where secrets would leak.
#[derive(Debug, Clone)]
struct CachedAccessToken {
    /// The token string exactly as passed to `ProxyTokenCache::insert`.
    token: Box<str>,
    /// Serving cut-off: the JWT `exp` minus `SAFETY_MARGIN_SECS`, or a short fallback
    /// TTL when `exp` could not be read. `get` only returns the token while this is
    /// strictly in the future.
    valid_until: Timestamp,
}
/// Cache key: the `TnId` (presumably the tenant id — confirm) plus the `id_tag`
/// string passed to `get` / `insert` / `invalidate`.
type TokenCacheKey = (TnId, Box<str>);
/// The underlying LRU map from key to cached token.
type TokenCacheInner = LruCache<TokenCacheKey, CachedAccessToken>;
/// Size-bounded LRU cache of access tokens keyed by `(TnId, id_tag)`.
///
/// All access goes through a `parking_lot::Mutex`, so the `&self` methods may be
/// called from multiple threads. When the cache is full, inserting a new key evicts
/// the least-recently-used entry.
#[derive(Debug)]
pub struct ProxyTokenCache {
    // NOTE(review): this `Arc` is never cloned in the code visible here; a plain
    // `Mutex` would suffice unless the map is shared elsewhere — confirm.
    entries: Arc<parking_lot::Mutex<TokenCacheInner>>,
}
impl ProxyTokenCache {
    /// Creates an empty cache holding at most `DEFAULT_CAPACITY` (256) entries.
    pub fn new() -> Self {
        Self::with_capacity(DEFAULT_CAPACITY)
    }

    /// Creates an empty cache holding at most `capacity` entries. Once it is full,
    /// inserting a new key evicts the least-recently-used entry.
    pub fn with_capacity(capacity: NonZeroUsize) -> Self {
        Self { entries: Arc::new(parking_lot::Mutex::new(LruCache::new(capacity))) }
    }

    /// Returns the cached token for (`tn_id`, `id_tag`) if present and its
    /// `valid_until` is still strictly in the future.
    ///
    /// A hit promotes the entry to most-recently-used. An expired entry is removed
    /// rather than promoted, so stale tokens do not crowd out live ones.
    pub fn get(&self, tn_id: TnId, id_tag: &str) -> Option<Box<str>> {
        let key = (tn_id, Box::<str>::from(id_tag));
        // Read the clock before taking the lock to keep the critical section short.
        let now = Timestamp::now();
        let mut cache = self.entries.lock();
        match cache.get(&key) {
            Some(entry) if entry.valid_until.0 > now.0 => Some(entry.token.clone()),
            Some(_) => {
                // Expired: evict it instead of leaving it as the MRU item.
                cache.pop(&key);
                None
            }
            None => None,
        }
    }

    /// Caches `token` for (`tn_id`, `id_tag`), replacing any previous token for that key.
    ///
    /// The entry is served until the token's JWT `exp` minus `SAFETY_MARGIN_SECS`. If
    /// `exp` cannot be read, a warning is logged and the fallback `Timestamp::from_now(60)`
    /// is used instead (presumably 60 seconds from now).
    ///
    /// A token that is already stale by that measure is not stored, because `get` would
    /// never return it and it would only evict a live entry. Any previous token for the
    /// key is still dropped, so `get` returns `None` exactly as if the stale token had
    /// been stored.
    pub fn insert(&self, tn_id: TnId, id_tag: &str, token: Box<str>) {
        let valid_until = match read_jwt_exp(&token) {
            // `exp` comes from an unverified token. Saturate so an extreme value cannot
            // overflow: that would panic in debug, and in release it would wrap to a
            // far-future cut-off, i.e. "never expires".
            Ok(exp) => Timestamp(exp.0.saturating_sub(SAFETY_MARGIN_SECS)),
            Err(e) => {
                warn!(id_tag = %id_tag, error = %e,
                    "failed to read access-token exp; using minimal cache TTL");
                Timestamp::from_now(60)
            }
        };
        let key = (tn_id, Box::<str>::from(id_tag));
        let now = Timestamp::now();
        let mut cache = self.entries.lock();
        if valid_until.0 > now.0 {
            cache.put(key, CachedAccessToken { token, valid_until });
        } else {
            cache.pop(&key);
        }
    }

    /// Removes any cached token for (`tn_id`, `id_tag`). This is a no-op if none is present.
    pub fn invalidate(&self, tn_id: TnId, id_tag: &str) {
        let mut cache = self.entries.lock();
        cache.pop(&(tn_id, Box::<str>::from(id_tag)));
    }
}
impl Default for ProxyTokenCache {
fn default() -> Self {
Self::new()
}
}
/// Minimal view of a JWT payload that reads only the `exp` claim. Any other claims
/// are ignored, which is serde's default for structs without `deny_unknown_fields`.
#[derive(Deserialize)]
struct AccessTokenExp {
    /// Expiry claim, wrapped into a `Timestamp` by `read_jwt_exp`.
    /// Presumably Unix seconds, per the JWT convention — confirm against `Timestamp`.
    exp: i64,
}
fn read_jwt_exp(jwt: &str) -> ClResult<Timestamp> {
let claim: AccessTokenExp = cloudillo_types::utils::decode_jwt_no_verify(jwt)?;
Ok(Timestamp(claim.exp))
}