use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::mpsc::Sender;
use tokio::time::{Duration, interval, sleep};
use self::errors::Result;
mod errors;
/// Keeps a background-refreshed copy of an access token obtained from `inner`.
///
/// Build it with [`AsyncTokenProvider::new`] and optionally change the refresh
/// period with [`AsyncTokenProvider::with_interval`]. Call
/// [`AsyncTokenProvider::watch_updates`] to start refreshing. After that,
/// [`AsyncTokenProvider::access_token`] returns the most recently cached token.
#[derive(Clone)]
pub struct AsyncTokenProvider<T> {
    /// Source of fresh tokens; a clone of it is handed to the background watcher.
    inner: T,
    /// Latest token received from the watcher; an empty string until the first refresh.
    cached_token: Arc<Mutex<String>>,
    /// Refresh period in seconds.
    interval: u64,
}

// `Debug` is implemented by hand so that formatting the provider (e.g. in a log
// line) never prints the bearer token. A derived impl would format
// `cached_token`, and `tokio::sync::Mutex`'s `Debug` prints the guarded value
// whenever the lock is free.
impl<T: std::fmt::Debug> std::fmt::Debug for AsyncTokenProvider<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AsyncTokenProvider")
            .field("inner", &self.inner)
            .field("cached_token", &"<redacted>")
            .field("interval", &self.interval)
            .finish()
    }
}
/// A source of access tokens.
///
/// The `Send` supertrait lets the blanket [`Watcher`] impl in this module
/// produce `Send` futures. `#[async_trait]` requires that by default, and a
/// future holding `&mut Self` is only `Send` when `Self: Send`.
#[async_trait]
pub trait TokenProvider: Send {
    /// Returns an access token from the underlying source.
    async fn access_token(&mut self) -> Result<String>;
}
#[async_trait]
impl TokenProvider for crate::serv_account::ServiceAccount {
    /// Delegates to `ServiceAccount`'s inherent `access_token` and converts its
    /// error with `?`.
    ///
    /// Path syntax is used on purpose. Inherent items are looked up before trait
    /// items, whatever their receiver type. With method syntax,
    /// `self.access_token()` on a `&mut Self` receiver would pick *this* trait
    /// method whenever the inherent method takes `&self`, and would recurse forever.
    /// NOTE(review): assumes `ServiceAccount` has an inherent `access_token`; if it
    /// does not, this still resolves to the trait method and recurses.
    async fn access_token(&mut self) -> Result<String> {
        let token = crate::serv_account::ServiceAccount::access_token(self).await?;
        Ok(token)
    }
}
#[async_trait]
impl TokenProvider for crate::app::Auth {
    /// Delegates to `Auth`'s inherent `access_token` and converts its error with `?`.
    ///
    /// Path syntax is used on purpose. Inherent items are looked up before trait
    /// items, whatever their receiver type. With method syntax,
    /// `self.access_token()` on a `&mut Self` receiver would pick *this* trait
    /// method whenever the inherent method takes `&self`, and would recurse forever.
    /// NOTE(review): assumes `Auth` has an inherent `access_token`; if it does not,
    /// this still resolves to the trait method and recurses.
    async fn access_token(&mut self) -> Result<String> {
        let token = crate::app::Auth::access_token(self).await?;
        Ok(token)
    }
}
/// Something that pushes access tokens into a channel over time.
///
/// Implemented for every [`TokenProvider`] by the blanket impl in this module.
#[async_trait]
pub trait Watcher {
    /// Sends tokens on `tx`, refreshing roughly every `interval_sec` seconds,
    /// until the implementation decides to stop.
    async fn watch_updates(&mut self, tx: Sender<String>, interval_sec: u64);
}
#[async_trait]
impl<T: TokenProvider> Watcher for T {
async fn watch_updates(&mut self, tx: Sender<String>, interval_sec: u64) {
let mut interval = interval(Duration::from_secs(interval_sec));
let retries = 3;
let mut attempt = 0;
loop {
let res = self.access_token().await;
match send_token(res, &tx).await {
Ok(_) => {}
Err(err) => {
if attempt == retries {
log::error!("{}", err);
break;
}
attempt += 1;
let backoff = 1 << attempt;
let delay = Duration::from_secs(backoff);
log::error!("{}. retry in: {}s", err, backoff);
sleep(delay).await;
continue;
}
}
attempt = 0;
interval.tick().await;
}
}
}
/// Forwards a successfully fetched token to `tx`.
///
/// A fetch error is returned unchanged. A send error (the receiver has been
/// dropped) is converted into the module's error type.
async fn send_token(access_token_res: Result<String>, tx: &Sender<String>) -> Result<()> {
    let token = access_token_res?;
    tx.send(token).await?;
    Ok(())
}
impl<T> AsyncTokenProvider<T>
where
    T: Watcher + Clone + Send + 'static,
{
    /// Wraps `inner` with an empty cached token and a 60-second refresh interval.
    pub fn new(inner: T) -> Self {
        let cached_token = Arc::new(Mutex::new(String::new()));
        Self {
            interval: 60,
            cached_token,
            inner,
        }
    }

    /// Builder-style setter for the refresh interval, in seconds.
    pub fn with_interval(self, interval: u64) -> Self {
        Self { interval, ..self }
    }

    /// Returns a copy of the currently cached token.
    ///
    /// The result is an empty string until the background task has stored the
    /// first token.
    ///
    /// # Errors
    /// The lock is taken with `try_lock`, so this fails instead of waiting when
    /// the cache mutex is held at that moment, e.g. by the refresh task or by a
    /// concurrent caller.
    pub fn access_token(&self) -> Result<String> {
        let guard = self.cached_token.try_lock()?;
        Ok(guard.as_str().to_owned())
    }

    /// Starts refreshing the cache in the background and returns immediately.
    ///
    /// Spawns two detached Tokio tasks:
    /// 1. A producer that runs a clone of `inner`'s `Watcher::watch_updates`,
    ///    sending tokens into a channel with capacity 1.
    /// 2. A consumer that stores each received token in the shared cache.
    pub async fn watch_updates(&self) {
        let (token_tx, mut token_rx) = tokio::sync::mpsc::channel(1);
        let mut source = self.inner.clone();
        let cache = Arc::clone(&self.cached_token);
        let period = self.interval;

        tokio::spawn(async move {
            source.watch_updates(token_tx, period).await;
        });

        tokio::spawn(async move {
            while let Some(fresh) = token_rx.recv().await {
                log::debug!("access token refreshed");
                *cache.lock().await = fresh;
            }
        });
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
#[derive(Clone)]
struct FakeProvider {
calls: Arc<AtomicU32>,
fail_first_n: u32,
}
impl FakeProvider {
fn new() -> Self {
Self {
calls: Arc::new(AtomicU32::new(0)),
fail_first_n: 0,
}
}
fn with_initial_failures(n: u32) -> Self {
Self {
calls: Arc::new(AtomicU32::new(0)),
fail_first_n: n,
}
}
}
#[async_trait]
impl TokenProvider for FakeProvider {
async fn access_token(&mut self) -> Result<String> {
let call = self.calls.fetch_add(1, Ordering::SeqCst);
if call < self.fail_first_n {
Err(errors::TokenProviderError::SendError(
tokio::sync::mpsc::error::SendError("simulated".to_string()),
))
} else {
Ok(format!("tok-{}", call))
}
}
}
#[test]
fn new_defaults_to_sixty_second_interval() {
let provider = AsyncTokenProvider::new(FakeProvider::new());
assert_eq!(provider.interval, 60);
}
#[test]
fn with_interval_overrides_default() {
let provider = AsyncTokenProvider::new(FakeProvider::new()).with_interval(5);
assert_eq!(provider.interval, 5);
}
#[test]
fn access_token_returns_empty_string_initially() {
let provider = AsyncTokenProvider::new(FakeProvider::new());
assert_eq!(provider.access_token().unwrap(), "");
}
#[tokio::test(start_paused = true)]
async fn watch_updates_populates_cache_from_provider() {
let provider = AsyncTokenProvider::new(FakeProvider::new()).with_interval(60);
provider.watch_updates().await;
for _ in 0..32 {
tokio::task::yield_now().await;
if let Ok(t) = provider.access_token()
&& !t.is_empty()
{
assert!(t.starts_with("tok-"), "unexpected cached value: {t}");
return;
}
}
panic!("cache never populated within 32 yields");
}
#[tokio::test(start_paused = true)]
async fn watch_updates_recovers_after_transient_failures() {
let provider = AsyncTokenProvider::new(FakeProvider::with_initial_failures(2));
provider.watch_updates().await;
for _ in 0..16 {
tokio::time::advance(Duration::from_secs(1)).await;
tokio::task::yield_now().await;
if let Ok(t) = provider.access_token()
&& !t.is_empty()
{
assert!(t.starts_with("tok-"), "unexpected cached value: {t}");
return;
}
}
panic!("cache never populated after 16 simulated seconds");
}
}