Skip to main content

cloud_storage_lite/token_provider/
renewing.rs

//! A `TokenProvider` that renews its token using a backing `TokenProvider` when the current
//! token is expired or about to expire.
3
4use std::sync::Arc;
5
6use anyhow::Context;
7use chrono::{Duration, Utc};
8use tokio::sync::Mutex;
9use tracing::debug;
10
11use super::{Token, TokenProvider};
12
13/// A `TokenProvider` that renews its token whenever the current token expires.
14#[derive(Clone)]
15pub struct RenewingTokenProvider<T> {
16    provider: T,
17
18    // This is a mutex instead of anything fancier like an ArcSwap because no other authenticated
19    // requests can reasonably continue (without becoming a thundering herd) while the thread that
20    // discovered the expired token is fetching a replacement.
21    // This is an async mutex instead of a regular mutex because fetching the new token is async.
22    token: Arc<Mutex<Arc<Token>>>,
23}
24
25impl<T: TokenProvider> RenewingTokenProvider<T> {
26    /// Returns new `RenewingTokenProvider` that fetches a new token using the provided `provider`
27    /// when the current one expires.
28    pub fn new(provider: T) -> Self {
29        Self {
30            provider,
31            token: Default::default(),
32        }
33    }
34}
35
36#[async_trait::async_trait]
37impl<T: TokenProvider> TokenProvider for RenewingTokenProvider<T> {
38    async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
39        let mut token = self.token.lock().await;
40        if Utc::now() >= token.expiry - Duration::minutes(10) {
41            debug!(expiry = %token.expiry, now = %Utc::now(), "renewing token");
42            let new_token = self
43                .provider
44                .get_token()
45                .await
46                .context("failed to renew token")?;
47            *token = new_token;
48            debug!("successfully renewed token");
49            Ok(Arc::clone(&token))
50        } else {
51            Ok(token.clone())
52        }
53    }
54
55    async fn invalidate_token(&self) {
56        let mut token = self.token.lock().await;
57        *token = Default::default();
58    }
59}
60
#[cfg(test)]
mod test {
    use super::*;

    use crate::token_provider::MockTokenProvider;

    /// Builds a dummy token whose expiry is `minutes` from now.
    fn token_expiring_in(minutes: i64) -> Arc<Token> {
        Arc::new(Token {
            token: "dummy token".into(),
            expiry: Utc::now() + Duration::minutes(minutes),
        })
    }

    /// The first call must fetch a token from the backing provider; the second call must be
    /// served from the cache (the mock accepts exactly one fetch).
    #[tokio::test]
    async fn renewing_provider_renews_expired() {
        let fresh = token_expiring_in(30);

        let mut backing = MockTokenProvider::new();
        backing.expect_get_token().times(1).returning({
            let fresh = Arc::clone(&fresh);
            move || Ok(Arc::clone(&fresh))
        });

        let renewing = RenewingTokenProvider::new(backing);
        assert_eq!(renewing.get_token().await.unwrap(), fresh);
        assert_eq!(renewing.get_token().await.unwrap(), fresh);
    }

    /// A token that is about to expire must be replaced on the next call, and the
    /// replacement must then be served from the cache.
    #[tokio::test]
    async fn renewing_provider_renews_almost_expired() {
        let expiring_soon = token_expiring_in(1);
        let fresh = token_expiring_in(30);

        let mut backing = MockTokenProvider::new();
        let mut seq = mockall::Sequence::new();
        backing
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning({
                let expiring_soon = Arc::clone(&expiring_soon);
                move || Ok(Arc::clone(&expiring_soon))
            });
        backing
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning({
                let fresh = Arc::clone(&fresh);
                move || Ok(Arc::clone(&fresh))
            });

        let renewing = RenewingTokenProvider::new(backing);
        // First call: nothing usable cached yet, so the soon-to-expire token is fetched.
        assert_eq!(renewing.get_token().await.unwrap(), expiring_soon);
        // Second call: the cached token is inside the renewal margin, so it is replaced.
        assert_eq!(renewing.get_token().await.unwrap(), fresh);
        // Third call: the replacement is still valid and is served without another fetch.
        assert_eq!(renewing.get_token().await.unwrap(), fresh);
    }
}