gauth 0.10.1

HTTP Client for Google OAuth2
Documentation
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio::sync::mpsc::Sender;
use tokio::time::{Duration, interval, sleep};

use self::errors::Result;

mod errors;

/// Caching wrapper around a token provider whose token is refreshed in the background.
///
/// The cache is filled by the tasks started with `watch_updates` and read synchronously
/// with `access_token`. Because `cached_token` is an `Arc`, clones of this value share
/// the same cached token.
#[derive(Debug, Clone)]
pub struct AsyncTokenProvider<T> {
    /// Underlying provider; a clone of it is moved into the background refresh task.
    inner: T,
    /// Most recently received access token (an empty string until the first refresh).
    cached_token: Arc<Mutex<String>>,
    /// Refresh period, in seconds, handed to the watcher.
    interval: u64,
}

/// Source of access tokens that can be polled repeatedly.
///
/// Every implementor also gets the blanket `Watcher` implementation in this module.
#[async_trait]
pub trait TokenProvider: Send {
    /// Obtains an access token string from the underlying provider.
    async fn access_token(&mut self) -> Result<String>;
}

/// Exposes a service account as a [`TokenProvider`].
///
/// Delegates to the account's own `access_token`; any error is propagated with `?`,
/// which applies the `From` conversion into this module's error type.
#[async_trait]
impl TokenProvider for crate::serv_account::ServiceAccount {
    async fn access_token(&mut self) -> Result<String> {
        // NOTE(review): relies on an inherent `access_token(&mut self)` taking precedence
        // over this trait method during method resolution; otherwise this would recurse.
        let token = self.access_token().await?;
        Ok(token)
    }
}

/// Exposes an application `Auth` as a [`TokenProvider`].
///
/// Delegates to `Auth`'s own `access_token`; any error is propagated with `?`,
/// which applies the `From` conversion into this module's error type.
#[async_trait]
impl TokenProvider for crate::app::Auth {
    async fn access_token(&mut self) -> Result<String> {
        // NOTE(review): relies on an inherent `access_token(&mut self)` taking precedence
        // over this trait method during method resolution; otherwise this would recurse.
        let token = self.access_token().await?;
        Ok(token)
    }
}

/// Something that repeatedly pushes access tokens into a channel.
#[async_trait]
pub trait Watcher {
    /// Sends access tokens through `tx`; `interval_sec` is the period between refreshes,
    /// in seconds (that is how the blanket impl for `TokenProvider` uses it).
    async fn watch_updates(&mut self, tx: Sender<String>, interval_sec: u64);
}

/// Blanket [`Watcher`] implementation for every [`TokenProvider`].
///
/// Fetches a token and pushes it into `tx`, then waits `interval_sec` seconds and repeats.
/// A failed fetch or send is retried up to `MAX_RETRIES` times with exponential backoff
/// (2s, 4s, 8s). After that the error is logged and the watcher stops. The watcher also
/// stops on the first failure once the receiving half of the channel is closed, because no
/// token could ever be delivered again.
#[async_trait]
impl<T: TokenProvider> Watcher for T {
    async fn watch_updates(&mut self, tx: Sender<String>, interval_sec: u64) {
        const MAX_RETRIES: u32 = 3;

        // `tokio::time::interval` panics on a zero period, which would silently kill the
        // spawned watcher task; treat 0 as the smallest supported period instead.
        let mut ticker = interval(Duration::from_secs(interval_sec.max(1)));
        // The first tick of a tokio interval completes immediately. Consume it up front so
        // the token is fetched once at startup instead of twice back-to-back.
        ticker.tick().await;

        let mut attempt: u32 = 0;
        loop {
            let res = self.access_token().await;
            if let Err(err) = send_token(res, &tx).await {
                // A closed channel never reopens, so retrying (and re-fetching) is futile.
                if attempt == MAX_RETRIES || tx.is_closed() {
                    log::error!("{}", err);
                    break;
                }

                attempt += 1;
                let backoff: u64 = 1 << attempt;
                log::error!("{}. retry in: {}s", err, backoff);
                sleep(Duration::from_secs(backoff)).await;
                continue;
            }

            attempt = 0;
            ticker.tick().await;
        }
    }
}

/// Unwraps a freshly fetched token and forwards it over `tx`.
///
/// If `access_token_res` is an error, that error is returned unchanged and nothing is
/// sent. Otherwise a channel send failure is converted into the module's error type.
async fn send_token(access_token_res: Result<String>, tx: &Sender<String>) -> Result<()> {
    let token = access_token_res?;
    tx.send(token).await?;
    Ok(())
}

impl<T> AsyncTokenProvider<T> {
    /// Creates a provider around `inner` with an empty cached token and a 60-second
    /// refresh interval.
    ///
    /// No token is fetched until [`AsyncTokenProvider::watch_updates`] is called.
    pub fn new(inner: T) -> Self {
        Self {
            inner,
            cached_token: Arc::new(Mutex::new(String::new())),
            interval: 60,
        }
    }

    /// Sets the refresh interval in seconds (builder style).
    pub fn with_interval(mut self, interval: u64) -> Self {
        self.interval = interval;
        self
    }

    /// Returns a copy of the most recently cached access token.
    ///
    /// The cache starts out as an empty string. Until the background task started by
    /// [`AsyncTokenProvider::watch_updates`] has delivered a token, this returns `Ok("")`.
    ///
    /// # Errors
    ///
    /// Returns the lock error if the cache is still locked after a short bounded spin.
    pub fn access_token(&self) -> Result<String> {
        // The lock is only ever held briefly: for a String move by the cache task or for a
        // clone by another reader. Retry a few times instead of failing on the first
        // collision; the final attempt propagates the lock error exactly as before.
        const TRY_LOCK_ATTEMPTS: usize = 64;
        for _ in 1..TRY_LOCK_ATTEMPTS {
            if let Ok(token) = self.cached_token.try_lock() {
                return Ok(token.clone());
            }
            std::hint::spin_loop();
        }
        Ok(self.cached_token.try_lock()?.clone())
    }
}

impl<T> AsyncTokenProvider<T>
where
    T: Watcher + Clone + Send + 'static,
{
    /// Starts background token refreshing and returns without waiting for a token.
    ///
    /// Spawns two detached tokio tasks:
    /// - one runs `watch_updates` on a clone of the inner provider and pushes tokens into
    ///   a channel;
    /// - the other stores each received token in the cache read by
    ///   [`AsyncTokenProvider::access_token`].
    pub async fn watch_updates(&self) {
        // Single-slot buffer between the refresh task and the cache task.
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);

        let mut watcher = self.inner.clone();
        let interval = self.interval;
        tokio::spawn(async move {
            watcher.watch_updates(tx, interval).await;
        });

        let cache = Arc::clone(&self.cached_token);
        tokio::spawn(async move {
            while let Some(token) = rx.recv().await {
                log::debug!("access token refreshed");
                let stale = {
                    let mut cached = cache.lock().await;
                    std::mem::replace(&mut *cached, token)
                };
                // Free the previous token only after the lock has been released.
                drop(stale);
            }
        });
    }
}