1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
//! A `TokenProvider` that renews its token using a backing `TokenProvider` when the current
//! token is expired or about to expire.

use std::sync::Arc;

use chrono::{Duration, Utc};
use tokio::sync::Mutex;

use super::{Token, TokenProvider};

/// A [`TokenProvider`] that caches the token handed out by a wrapped provider and asks that
/// provider for a replacement whenever the cached token is due for renewal.
///
/// Clones share the same cached token.
#[derive(Clone)]
pub struct RenewingTokenProvider<T> {
    /// Backing provider consulted whenever the cached token has to be replaced.
    provider: T,

    /// The cached token, shared (via the outer `Arc`) between all clones of this provider.
    ///
    /// A plain mutex rather than something fancier such as an `ArcSwap` is deliberate. While one
    /// caller is fetching a replacement for an expired token, no other authenticated request can
    /// reasonably continue. Letting every caller race to refresh would turn into a thundering herd.
    /// The mutex is the async (tokio) kind because its guard is held across the `.await` that
    /// fetches the new token.
    token: Arc<Mutex<Arc<Token>>>,
}

impl<T: TokenProvider> RenewingTokenProvider<T> {
    /// Returns new `RenewingTokenProvider` that fetches a new token using the provided `provider`
    /// when the current one expires.
    pub fn new(provider: T) -> Self {
        Self {
            provider,
            token: Default::default(),
        }
    }
}

#[async_trait::async_trait]
impl<T: TokenProvider> TokenProvider for RenewingTokenProvider<T> {
    /// Returns the cached token. It is first replaced with a fresh one from the backing provider
    /// if it has already expired or will expire within the renewal margin.
    ///
    /// The lock is held while the replacement is fetched. Concurrent callers (including clones of
    /// this provider) therefore wait for the single in-flight renewal instead of each issuing
    /// their own request.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the backing provider's `get_token`. In that case the
    /// cached token is left unchanged.
    async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
        // Tokens this close to expiry are replaced before being handed out, so callers are not
        // given a token that is about to lapse.
        const RENEWAL_MARGIN_MINUTES: i64 = 5;

        let mut token = self.token.lock().await;
        // The margin is *added* to "now": renew when the expiry falls before now + margin.
        // Subtracting it would keep serving a token for up to five minutes after it expired.
        if token.expiry <= Utc::now() + Duration::minutes(RENEWAL_MARGIN_MINUTES) {
            *token = self.provider.get_token().await?;
        }
        Ok(Arc::clone(&token))
    }
}

#[cfg(test)]
mod test {
    use super::*;

    use crate::token_provider::MockTokenProvider;

    /// Builds a dummy token whose expiry is `offset` away from now (negative = already expired).
    fn token_expiring_in(offset: Duration) -> Arc<Token> {
        Arc::new(Token {
            token: "dummy token".into(),
            expiry: Utc::now() + offset,
        })
    }

    /// Wraps a mock that must hand out `token` exactly `calls` times.
    fn renewing_with(token: Arc<Token>, calls: usize) -> RenewingTokenProvider<MockTokenProvider> {
        let mut mock_provider = MockTokenProvider::new();
        mock_provider
            .expect_get_token()
            .times(calls)
            .returning(move || Ok(Arc::clone(&token)));
        RenewingTokenProvider::new(mock_provider)
    }

    #[tokio::test]
    async fn renewing_provider_renews_expired() {
        // The initial placeholder is renewed once; a long-lived token is then served from cache.
        let new_token = token_expiring_in(Duration::hours(1));
        let renewing = renewing_with(Arc::clone(&new_token), 1);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
    }

    #[tokio::test]
    async fn renewing_provider_renews_token_about_to_expire() {
        // A token inside the renewal margin must be replaced on every call.
        let new_token = token_expiring_in(Duration::minutes(1));
        let renewing = renewing_with(Arc::clone(&new_token), 2);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
    }

    #[tokio::test]
    async fn renewing_provider_renews_recently_expired_token() {
        // Regression: a token that expired moments ago must not be served from cache.
        let new_token = token_expiring_in(Duration::minutes(-1));
        let renewing = renewing_with(Arc::clone(&new_token), 2);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
    }
}