cloud_storage_lite/token_provider/
renewing.rs1use std::sync::Arc;
5
6use anyhow::Context;
7use chrono::{Duration, Utc};
8use tokio::sync::Mutex;
9use tracing::debug;
10
11use super::{Token, TokenProvider};
12
13#[derive(Clone)]
15pub struct RenewingTokenProvider<T> {
16 provider: T,
17
18 token: Arc<Mutex<Arc<Token>>>,
23}
24
25impl<T: TokenProvider> RenewingTokenProvider<T> {
26 pub fn new(provider: T) -> Self {
29 Self {
30 provider,
31 token: Default::default(),
32 }
33 }
34}
35
36#[async_trait::async_trait]
37impl<T: TokenProvider> TokenProvider for RenewingTokenProvider<T> {
38 async fn get_token(&self) -> anyhow::Result<Arc<Token>> {
39 let mut token = self.token.lock().await;
40 if Utc::now() >= token.expiry - Duration::minutes(10) {
41 debug!(expiry = %token.expiry, now = %Utc::now(), "renewing token");
42 let new_token = self
43 .provider
44 .get_token()
45 .await
46 .context("failed to renew token")?;
47 *token = new_token;
48 debug!("successfully renewed token");
49 Ok(Arc::clone(&token))
50 } else {
51 Ok(token.clone())
52 }
53 }
54
55 async fn invalidate_token(&self) {
56 let mut token = self.token.lock().await;
57 *token = Default::default();
58 }
59}
60
#[cfg(test)]
mod test {
    use super::*;

    use crate::token_provider::MockTokenProvider;

    /// Builds a token with value `token` that expires `minutes_left` minutes from now.
    fn token_expiring_in(token: &'static str, minutes_left: i64) -> Arc<Token> {
        Arc::new(Token {
            token: token.into(),
            expiry: Utc::now() + Duration::minutes(minutes_left),
        })
    }

    /// The initial (default) cached token is renewed on first use, and the
    /// renewed token is then served from the cache without another renewal.
    #[tokio::test]
    async fn renewing_provider_renews_expired() {
        let new_token = token_expiring_in("dummy token", 30);
        let new_token_return = Arc::clone(&new_token);
        let mut mock_provider = MockTokenProvider::new();
        mock_provider
            .expect_get_token()
            .times(1)
            .returning(move || Ok(Arc::clone(&new_token_return)));

        let renewing = RenewingTokenProvider::new(mock_provider);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
        assert_eq!(renewing.get_token().await.unwrap(), new_token);
    }

    /// A token within the renewal margin is replaced. The replacement is
    /// returned immediately and then served from the cache.
    #[tokio::test]
    async fn renewing_provider_renews_almost_expired() {
        let short_lived = token_expiring_in("short-lived token", 1);
        let fresh = token_expiring_in("fresh token", 30);
        let short_lived_return = Arc::clone(&short_lived);
        let fresh_return = Arc::clone(&fresh);

        let mut mock_provider = MockTokenProvider::new();
        let mut seq = mockall::Sequence::new();
        mock_provider
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move || Ok(Arc::clone(&short_lived_return)));
        mock_provider
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move || Ok(Arc::clone(&fresh_return)));

        let renewing = RenewingTokenProvider::new(mock_provider);
        // First call: the default token is renewed into the short-lived one.
        assert_eq!(renewing.get_token().await.unwrap(), short_lived);
        // Second call: the short-lived token is inside the margin, so renew again.
        assert_eq!(renewing.get_token().await.unwrap(), fresh);
        // Third call: the fresh token is served from the cache.
        assert_eq!(renewing.get_token().await.unwrap(), fresh);
    }

    /// Invalidating the cache forces the next call to fetch a new token,
    /// even when the cached one is still far from expiry.
    #[tokio::test]
    async fn renewing_provider_renews_after_invalidation() {
        let first = token_expiring_in("first token", 30);
        let second = token_expiring_in("second token", 30);
        let first_return = Arc::clone(&first);
        let second_return = Arc::clone(&second);

        let mut mock_provider = MockTokenProvider::new();
        let mut seq = mockall::Sequence::new();
        mock_provider
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move || Ok(Arc::clone(&first_return)));
        mock_provider
            .expect_get_token()
            .times(1)
            .in_sequence(&mut seq)
            .returning(move || Ok(Arc::clone(&second_return)));

        let renewing = RenewingTokenProvider::new(mock_provider);
        assert_eq!(renewing.get_token().await.unwrap(), first);
        renewing.invalidate_token().await;
        assert_eq!(renewing.get_token().await.unwrap(), second);
        assert_eq!(renewing.get_token().await.unwrap(), second);
    }
}