use super::{AutoRefreshable, ClearTokenError, Credentials, GetTokenError};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::task::JoinHandle;
/// Wrapper around an [`AutoRefreshable`] credential source that keeps it
/// refreshed from a background Tokio task.
///
/// The trait bound lives on the type itself (rather than only on the `impl`
/// blocks) because the `Drop` implementation for this type must carry exactly
/// the same bounds as the struct definition.
pub struct AutoRefresh<C>
where
    C: AutoRefreshable,
{
    /// The wrapped credentials, shared with the background refresh task.
    refreshable: Arc<C>,
    /// Handle of the spawned refresh loop; aborted when this value is dropped.
    job: JoinHandle<()>,
}
impl<C: AutoRefreshable> AutoRefresh<C> {
    /// Wraps `credentials` and immediately spawns a background task that
    /// refreshes them forever.
    ///
    /// The task calls `refresh()` as soon as it first runs, then sleeps for
    /// the `Duration` returned by that call before refreshing again. Each
    /// cycle emits two `debug` events on the `auto_refresh` target: one before
    /// the refresh and one after it, with the attempt number and elapsed time.
    ///
    /// # Panics
    ///
    /// Panics if called outside of a Tokio runtime, because `tokio::spawn`
    /// needs a runtime context to schedule the task on.
    pub fn new(credentials: C) -> Self {
        let refreshable = Arc::new(credentials);
        // Second handle to the same credentials, owned by the background task.
        let worker = Arc::clone(&refreshable);

        let job = tokio::spawn(async move {
            let started_at = tokio::time::Instant::now();
            // 1-based counter of refresh cycles, reported as the `attempt` field.
            let mut cycle: u64 = 0;

            loop {
                cycle += 1;

                tracing::debug!(
                    target: "auto_refresh",
                    attempt = cycle,
                    elapsed_secs = started_at.elapsed().as_secs(),
                    "Refreshing credentials"
                );

                // The credentials decide how long to wait until the next refresh.
                let wait = worker.refresh().await;

                tracing::debug!(
                    target: "auto_refresh",
                    attempt = cycle,
                    elapsed_secs = started_at.elapsed().as_secs(),
                    next_refresh_secs = wait.as_secs_f64(),
                    "Refresh complete, sleeping until next refresh"
                );

                tokio::time::sleep(wait).await;
            }
        });

        Self { refreshable, job }
    }
}
/// `AutoRefresh` is itself a [`Credentials`] source: both operations are
/// forwarded unchanged to the wrapped credentials, and no lock is taken
/// around the calls here.
#[async_trait]
impl<C: AutoRefreshable> Credentials for AutoRefresh<C> {
    type Token = C::Token;

    /// Returns whatever the wrapped credentials return from `get_token`.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        let inner: &C = &self.refreshable;
        inner.get_token().await
    }

    /// Forwards to the wrapped credentials' `clear_token`.
    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        let inner: &C = &self.refreshable;
        inner.clear_token().await
    }
}
/// Stops the background refresh loop when the wrapper is dropped.
///
/// Dropping a `JoinHandle` on its own only detaches the task, so without an
/// explicit abort the refresh loop spawned in `new` would keep running (and
/// keep its `Arc` clone of the credentials alive) after the wrapper is gone.
impl<C: AutoRefreshable> Drop for AutoRefresh<C> {
    fn drop(&mut self) {
        tracing::debug!(target: "auto_refresh", "Aborting refresh job");
        // `abort` only requests cancellation; it does not wait for the task
        // to finish, so `drop` never blocks.
        self.job.abort();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Test double that records how many `get_token` calls overlap in time.
    struct ConcurrencyTracker {
        /// Number of `get_token` calls currently in flight.
        current: Arc<AtomicUsize>,
        /// Highest in-flight count observed so far.
        peak: Arc<AtomicUsize>,
    }

    impl ConcurrencyTracker {
        /// Builds a tracker together with a shared handle to its peak counter,
        /// so the test can still read the peak after the tracker is moved
        /// into `AutoRefresh`.
        fn new() -> (Self, Arc<AtomicUsize>) {
            let peak = Arc::new(AtomicUsize::new(0));
            let tracker = Self {
                current: Arc::new(AtomicUsize::new(0)),
                peak: Arc::clone(&peak),
            };
            (tracker, peak)
        }
    }

    #[async_trait]
    impl Credentials for ConcurrencyTracker {
        type Token = String;

        /// Marks the call as in flight, holds it open for 50 ms so that
        /// concurrent callers can overlap, then returns a fixed token.
        async fn get_token(&self) -> Result<String, GetTokenError> {
            let in_flight = self.current.fetch_add(1, Ordering::SeqCst) + 1;
            self.peak.fetch_max(in_flight, Ordering::SeqCst);
            tokio::time::sleep(Duration::from_millis(50)).await;
            self.current.fetch_sub(1, Ordering::SeqCst);
            Ok(String::from("token"))
        }

        async fn clear_token(&self) -> Result<(), ClearTokenError> {
            Ok(())
        }
    }

    #[async_trait]
    impl AutoRefreshable for ConcurrencyTracker {
        /// Does no work and asks for the next refresh in five minutes, so the
        /// background loop refreshes once and then sleeps for the rest of the test.
        async fn refresh(&self) -> Duration {
            Duration::from_secs(300)
        }
    }

    /// Ten concurrent `get_token` calls through `AutoRefresh` must overlap:
    /// the peak in-flight count recorded by the tracker has to exceed one.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_auto_refresh_does_not_serialize_concurrent_calls() {
        let (tracker, peak) = ConcurrencyTracker::new();
        let creds = Arc::new(AutoRefresh::new(tracker));

        // Spawn every caller before awaiting any of them.
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let creds = Arc::clone(&creds);
                tokio::spawn(async move { creds.get_token().await.unwrap() })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }

        let peak = peak.load(Ordering::SeqCst);
        // NOTE(review): the failure message mentions an "outer Mutex" that the
        // `AutoRefresh` in this file does not have; presumably this guards
        // against reintroducing one. Confirm against history.
        assert!(
            peak > 1,
            "Expected concurrent get_token() calls but peak concurrency was {peak}. \
             AutoRefresh is serializing calls through its outer Mutex.",
        );
    }
}