gcloud_auth/token_source/
reuse_token_source.rs1use async_trait::async_trait;
2
3use crate::error::Error;
4use crate::token::Token;
5use crate::token_source::TokenSource;
6
7#[derive(Debug)]
8pub struct ReuseTokenSource {
9 target: Box<dyn TokenSource>,
10 current_token: std::sync::RwLock<Token>,
11 guard: tokio::sync::Mutex<()>,
12}
13
14impl ReuseTokenSource {
15 pub(crate) fn new(target: Box<dyn TokenSource>, token: Token) -> ReuseTokenSource {
16 ReuseTokenSource {
17 target,
18 current_token: std::sync::RwLock::new(token),
19 guard: tokio::sync::Mutex::new(()),
20 }
21 }
22}
23
24#[async_trait]
25impl TokenSource for ReuseTokenSource {
26 async fn token(&self) -> Result<Token, Error> {
27 if let Some(token) = self.r_lock_token() {
28 return Ok(token);
29 }
30
31 let _locking = self.guard.lock().await;
33
34 if let Some(token) = self.r_lock_token() {
35 return Ok(token);
36 }
37
38 let token = self.target.token().await?;
39 tracing::debug!("token refresh success : expiry={:?}", token.expiry);
40 *self.current_token.write().unwrap() = token.clone();
41 Ok(token)
42 }
43}
44
45impl ReuseTokenSource {
46 fn r_lock_token(&self) -> Option<Token> {
47 let token = self.current_token.read().unwrap();
48 if token.valid() {
49 Some(token.clone())
50 } else {
51 None
52 }
53 }
54}
55
#[cfg(test)]
mod test {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use async_trait::async_trait;
    use time::OffsetDateTime;
    use tracing_subscriber::filter::LevelFilter;

    use crate::error::Error;
    use crate::token::Token;
    use crate::token_source::reuse_token_source::ReuseTokenSource;
    use crate::token_source::TokenSource;

    /// Number of concurrent callers spawned by `run_task`.
    const TASKS: usize = 100;

    /// Fake source that always returns a token with the configured expiry and
    /// counts how many times it has been asked for one.
    #[derive(Debug)]
    struct EmptyTokenSource {
        pub expiry: OffsetDateTime,
        /// Shared so the count can still be read after the source has been
        /// moved into a `ReuseTokenSource`.
        calls: Arc<AtomicUsize>,
    }

    impl EmptyTokenSource {
        /// Creates a boxed source with a fresh call counter.
        fn boxed(expiry: OffsetDateTime) -> Box<Self> {
            Box::new(Self {
                expiry,
                calls: Arc::new(AtomicUsize::new(0)),
            })
        }
    }

    #[async_trait]
    impl TokenSource for EmptyTokenSource {
        async fn token(&self) -> Result<Token, Error> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok(Token {
                access_token: "empty".to_string(),
                token_type: "empty".to_string(),
                expiry: Some(self.expiry),
            })
        }
    }

    #[ctor::ctor]
    fn init() {
        let filter = tracing_subscriber::filter::EnvFilter::from_default_env()
            .add_directive(LevelFilter::DEBUG.into());
        let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
    }

    #[tokio::test]
    async fn test_all_valid() {
        let ts = EmptyTokenSource::boxed(OffsetDateTime::now_utc() + time::Duration::seconds(100));
        let token = ts.token().await.unwrap();
        let (results, refreshes) = run_task(ts, token).await;
        assert!(results.iter().all(|&v| v));
        // The seed token is still valid, so the wrapped source must not be consulted.
        assert_eq!(refreshes, 0);
    }

    #[tokio::test]
    async fn test_with_invalid() {
        let mut ts = EmptyTokenSource::boxed(OffsetDateTime::now_utc());
        let token = ts.token().await.unwrap();
        ts.expiry = OffsetDateTime::now_utc() + time::Duration::seconds(100);
        let (results, refreshes) = run_task(ts, token).await;
        assert!(results.iter().all(|&v| v));
        // All callers see the invalid seed, but they must share one refresh.
        assert_eq!(refreshes, 1);
    }

    #[tokio::test]
    async fn test_all_invalid() {
        let ts = EmptyTokenSource::boxed(OffsetDateTime::now_utc());
        let token = ts.token().await.unwrap();
        let (results, refreshes) = run_task(ts, token).await;
        assert!(results.iter().all(|&v| !v));
        // Every refreshed token is already invalid, so each caller refreshes once.
        assert_eq!(refreshes, TASKS);
    }

    /// Wraps `ts` in a `ReuseTokenSource` seeded with `first_token` and calls
    /// `token()` from `TASKS` concurrent tasks.
    ///
    /// Returns whether each task received a valid token, together with how
    /// many times the wrapped source was called during the run.
    async fn run_task(ts: Box<EmptyTokenSource>, first_token: Token) -> (Vec<bool>, usize) {
        let calls = Arc::clone(&ts.calls);
        let calls_before = calls.load(Ordering::SeqCst);
        let ts = Arc::new(ReuseTokenSource::new(ts, first_token));

        let tasks: Vec<_> = (0..TASKS)
            .map(|_| {
                let ts = Arc::clone(&ts);
                tokio::spawn(async move { matches!(ts.token().await, Ok(token) if token.valid()) })
            })
            .collect();

        let mut results = Vec::with_capacity(tasks.len());
        for task in tasks {
            results.push(task.await.unwrap());
        }
        assert_eq!(results.len(), TASKS);

        (results, calls.load(Ordering::SeqCst) - calls_before)
    }
}