// gcloud_auth/token_source/reuse_token_source.rs

1use async_trait::async_trait;
2
3use crate::error::Error;
4use crate::token::Token;
5use crate::token_source::TokenSource;
6
7#[derive(Debug)]
8pub struct ReuseTokenSource {
9    target: Box<dyn TokenSource>,
10    current_token: std::sync::RwLock<Token>,
11    guard: tokio::sync::Mutex<()>,
12}
13
14impl ReuseTokenSource {
15    pub(crate) fn new(target: Box<dyn TokenSource>, token: Token) -> ReuseTokenSource {
16        ReuseTokenSource {
17            target,
18            current_token: std::sync::RwLock::new(token),
19            guard: tokio::sync::Mutex::new(()),
20        }
21    }
22}
23
24#[async_trait]
25impl TokenSource for ReuseTokenSource {
26    async fn token(&self) -> Result<Token, Error> {
27        if let Some(token) = self.r_lock_token() {
28            return Ok(token);
29        }
30
31        // Only single task can refresh token
32        let _locking = self.guard.lock().await;
33
34        if let Some(token) = self.r_lock_token() {
35            return Ok(token);
36        }
37
38        let token = self.target.token().await?;
39        tracing::debug!("token refresh success : expiry={:?}", token.expiry);
40        *self.current_token.write().unwrap() = token.clone();
41        Ok(token)
42    }
43}
44
45impl ReuseTokenSource {
46    fn r_lock_token(&self) -> Option<Token> {
47        let token = self.current_token.read().unwrap();
48        if token.valid() {
49            Some(token.clone())
50        } else {
51            None
52        }
53    }
54}
55
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use async_trait::async_trait;
    use time::OffsetDateTime;
    use tracing_subscriber::filter::LevelFilter;

    use crate::error::Error;
    use crate::token::Token;
    use crate::token_source::reuse_token_source::ReuseTokenSource;
    use crate::token_source::TokenSource;

    /// Number of concurrent callers `run_task` spawns against one `ReuseTokenSource`.
    const TASK_COUNT: usize = 100;

    /// Fake upstream source: always returns a token expiring at `expiry` and
    /// records how many times it was asked, so tests can verify how often
    /// `ReuseTokenSource` actually went upstream.
    #[derive(Debug)]
    struct EmptyTokenSource {
        pub expiry: OffsetDateTime,
        /// Shared so the count stays observable after the source is moved
        /// into a `ReuseTokenSource`.
        pub calls: Arc<AtomicUsize>,
    }

    impl EmptyTokenSource {
        fn boxed(expiry: OffsetDateTime) -> Box<Self> {
            Box::new(Self {
                expiry,
                calls: Arc::new(AtomicUsize::new(0)),
            })
        }
    }

    #[async_trait]
    impl TokenSource for EmptyTokenSource {
        async fn token(&self) -> Result<Token, Error> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(Token {
                access_token: "empty".to_string(),
                token_type: "empty".to_string(),
                expiry: Some(self.expiry),
            })
        }
    }

    #[ctor::ctor]
    fn init() {
        let filter = tracing_subscriber::filter::EnvFilter::from_default_env().add_directive(LevelFilter::DEBUG.into());
        let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
    }

    /// Number of callers that received a valid token.
    fn valid_count(results: &[bool]) -> usize {
        results.iter().filter(|&&valid| valid).count()
    }

    #[tokio::test]
    async fn test_all_valid() {
        let ts = EmptyTokenSource::boxed(OffsetDateTime::now_utc() + time::Duration::seconds(100));
        let calls = Arc::clone(&ts.calls);
        let token = ts.token().await.unwrap();
        let results = run_task(ts, token).await;
        assert_eq!(valid_count(&results), TASK_COUNT);
        // The seed token stays valid, so upstream is never consulted again.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_with_invalid() {
        let mut ts = EmptyTokenSource::boxed(OffsetDateTime::now_utc());
        let calls = Arc::clone(&ts.calls);
        let token = ts.token().await.unwrap();
        ts.expiry = OffsetDateTime::now_utc() + time::Duration::seconds(100);
        let results = run_task(ts, token).await;
        assert_eq!(valid_count(&results), TASK_COUNT);
        // The seed was expired: exactly one task refreshes, all others reuse
        // the refreshed token (seed fetch + one refresh).
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_all_invalid() {
        let ts = EmptyTokenSource::boxed(OffsetDateTime::now_utc());
        let calls = Arc::clone(&ts.calls);
        let token = ts.token().await.unwrap();
        let results = run_task(ts, token).await;
        assert_eq!(results.len(), TASK_COUNT);
        assert_eq!(valid_count(&results), 0);
        // Upstream never yields a valid token, so every task refreshes once.
        assert_eq!(calls.load(Ordering::SeqCst), 1 + TASK_COUNT);
    }

    /// Spawns `TASK_COUNT` concurrent callers of one `ReuseTokenSource` built
    /// from `ts` and `first_token`; returns whether each received a valid token.
    async fn run_task(ts: Box<EmptyTokenSource>, first_token: Token) -> Vec<bool> {
        let ts = Arc::new(ReuseTokenSource::new(ts, first_token));
        let tasks: Vec<_> = (0..TASK_COUNT)
            .map(|_| {
                let ts = Arc::clone(&ts);
                tokio::spawn(async move {
                    match ts.token().await {
                        Ok(new_token) => new_token.valid(),
                        Err(_) => false,
                    }
                })
            })
            .collect();
        let mut result = Vec::with_capacity(tasks.len());
        for task in tasks {
            result.push(task.await.unwrap());
        }
        result
    }
}