use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use async_trait::async_trait;
use parking_lot::RwLock;
use tokio::sync::Notify;
use crate::CloudError;
/// Default safety margin used by `CachedTokenProvider::new`: a cached token is
/// treated as stale once its remaining lifetime is no longer greater than this.
pub const DEFAULT_REFRESH_BUFFER: Duration = Duration::from_secs(5 * 60);
/// Source of new tokens for [`CachedTokenProvider`].
///
/// Implementations perform the actual fetch; the provider takes care of caching
/// the result and of ensuring only one refresh runs at a time.
#[async_trait]
pub trait TokenRefresher<T: Clone + Send + Sync + 'static>: Send + Sync {
    /// Fetches a new token together with the instant at which it expires.
    async fn refresh(&self) -> Result<TokenSnapshot<T>, CloudError>;
}
/// A token returned by a [`TokenRefresher`], paired with its expiry instant.
#[derive(Debug, Clone)]
pub struct TokenSnapshot<T> {
    /// The token value handed back to callers.
    pub value: T,
    /// Instant at which `value` expires; the provider refreshes before this,
    /// once the remaining lifetime drops to its refresh buffer.
    pub expires_at: Instant,
}
/// Internal cached copy of the most recent successful [`TokenSnapshot`].
#[derive(Clone)]
struct TokenState<T> {
    /// Cached token value.
    value: T,
    /// Expiry instant copied from the snapshot.
    expires_at: Instant,
}
/// Caches a token produced by a [`TokenRefresher`] and refreshes it on demand.
///
/// At most one refresh runs at a time; concurrent callers of
/// `CachedTokenProvider::current` wait for the in-flight refresh rather than
/// starting their own.
pub struct CachedTokenProvider<T: Clone + Send + Sync + 'static> {
    /// Most recently fetched token; `None` before the first refresh or after `invalidate`.
    cached: RwLock<Option<TokenState<T>>>,
    /// Single-flight flag: set while one caller is running `refresher.refresh()`.
    refresh_in_progress: AtomicBool,
    /// Signalled when the in-flight refresh finishes.
    refresh_done: Notify,
    /// Backend used to fetch new tokens.
    refresher: Arc<dyn TokenRefresher<T>>,
    /// Cached tokens whose remaining lifetime is not greater than this are stale.
    refresh_buffer: Duration,
}
impl<T> CachedTokenProvider<T>
where
T: Clone + Send + Sync + 'static,
{
pub fn new(refresher: Arc<dyn TokenRefresher<T>>) -> Self {
Self::with_refresh_buffer(refresher, DEFAULT_REFRESH_BUFFER)
}
pub fn with_refresh_buffer(
refresher: Arc<dyn TokenRefresher<T>>,
refresh_buffer: Duration,
) -> Self {
Self {
cached: RwLock::new(None),
refresh_in_progress: AtomicBool::new(false),
refresh_done: Notify::new(),
refresher,
refresh_buffer,
}
}
pub async fn current(&self) -> Result<T, CloudError> {
loop {
if let Some(state) = self.read_fresh() {
return Ok(state);
}
if self
.refresh_in_progress
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
let result = self.refresher.refresh().await;
if let Ok(snap) = &result {
*self.cached.write() = Some(TokenState {
value: snap.value.clone(),
expires_at: snap.expires_at,
});
}
self.refresh_in_progress.store(false, Ordering::Release);
self.refresh_done.notify_waiters();
return result.map(|s| s.value);
}
let waiter = self.refresh_done.notified();
tokio::pin!(waiter);
waiter.as_mut().enable();
if let Some(state) = self.read_fresh() {
return Ok(state);
}
waiter.await;
}
}
pub fn invalidate(&self) {
*self.cached.write() = None;
}
fn read_fresh(&self) -> Option<T> {
let snapshot = {
let guard = self.cached.read();
guard.as_ref().cloned()
};
let state = snapshot?;
if Instant::now() + self.refresh_buffer < state.expires_at {
Some(state.value)
} else {
None
}
}
}