server/auth/
token_refresh_service.rs1use super::auth_state::AuthStateManager;
2use super::errors::TokenRefreshError;
3use super::provider::AuthProvider;
4use super::types::CachedToken;
5use crate::service_bus_manager::ServiceBusError;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::RwLock;
9use tokio::time::{interval, sleep};
10
11pub type RefreshFailureCallback = Arc<dyn Fn(TokenRefreshError) + Send + Sync>;
13
14pub struct TokenRefreshService {
16 auth_state: Arc<AuthStateManager>,
17 check_interval: Duration,
18 shutdown_signal: Arc<RwLock<bool>>,
19 failure_callback: Option<RefreshFailureCallback>,
20}
21
22impl TokenRefreshService {
23 pub fn new(auth_state: Arc<AuthStateManager>) -> Self {
25 Self {
26 auth_state,
27 check_interval: Duration::from_secs(120), shutdown_signal: Arc::new(RwLock::new(false)),
29 failure_callback: None,
30 }
31 }
32
33 pub fn with_failure_callback(mut self, callback: RefreshFailureCallback) -> Self {
35 self.failure_callback = Some(callback);
36 self
37 }
38
39 pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
41 tokio::spawn(async move {
42 self.run().await;
43 })
44 }
45
46 pub async fn shutdown(&self) {
48 let mut shutdown = self.shutdown_signal.write().await;
49 *shutdown = true;
50 }
51
52 async fn run(&self) {
54 let mut check_interval = interval(self.check_interval);
55 check_interval.tick().await; loop {
58 if *self.shutdown_signal.read().await {
60 log::info!("Token refresh service shutting down");
61 break;
62 }
63
64 check_interval.tick().await;
65
66 if let Err(e) = self.check_and_refresh_tokens().await {
68 log::error!("Error during token refresh check: {e}");
69 }
70 }
71 }
72
73 async fn check_and_refresh_tokens(&self) -> Result<(), ServiceBusError> {
75 log::debug!("Checking tokens for refresh...");
76
77 if let Some(provider) = self.auth_state.get_service_bus_provider().await {
79 self.refresh_if_needed("service_bus", provider).await?;
80 }
81
82 if let Some(provider) = self.auth_state.get_management_provider().await {
84 self.refresh_if_needed("management_api", provider).await?;
85 }
86
87 Ok(())
88 }
89
90 async fn refresh_if_needed(
92 &self,
93 cache_key: &str,
94 provider: Arc<dyn AuthProvider>,
95 ) -> Result<(), ServiceBusError> {
96 let token_cache = self.auth_state.get_token_cache();
97
98 if token_cache.needs_refresh(cache_key).await {
99 log::info!("Token for '{cache_key}' needs refresh, attempting refresh...");
100
101 match self.refresh_with_retry(provider, 3).await {
102 Ok(auth_token) => {
103 let cached_token = CachedToken::new(
105 auth_token.token,
106 Duration::from_secs(auth_token.expires_in_secs.unwrap_or(3600)),
107 auth_token.token_type,
108 );
109
110 token_cache.set(cache_key.to_string(), cached_token).await;
111 log::info!("Successfully refreshed token for '{cache_key}'");
112 }
113 Err(e) => {
114 log::error!("Failed to refresh token for '{cache_key}': {e}");
115
116 token_cache.invalidate(cache_key).await;
118
119 if let Some(callback) = &self.failure_callback {
121 callback(e.clone());
122 }
123
124 return Err(e.into());
126 }
127 }
128 }
129
130 Ok(())
131 }
132
133 async fn refresh_with_retry(
135 &self,
136 provider: Arc<dyn AuthProvider>,
137 max_attempts: u32,
138 ) -> Result<super::provider::AuthToken, TokenRefreshError> {
139 let mut last_error = None;
140
141 for attempt in 1..=max_attempts {
142 match provider.refresh().await {
143 Ok(token) => return Ok(token),
144 Err(e) => {
145 let refresh_error = match &e {
147 ServiceBusError::AuthenticationFailed(_) => {
148 TokenRefreshError::InvalidRefreshToken
149 }
150 ServiceBusError::AuthenticationError(msg) if msg.contains("expired") => {
151 TokenRefreshError::RefreshTokenExpired
152 }
153 ServiceBusError::ConnectionFailed(reason) => {
154 TokenRefreshError::NetworkError {
155 reason: reason.clone(),
156 }
157 }
158 ServiceBusError::OperationTimeout(msg) => {
159 if msg.contains("rate") {
160 TokenRefreshError::RateLimited {
161 retry_after_seconds: None,
162 }
163 } else {
164 TokenRefreshError::ServiceUnavailable {
165 reason: msg.clone(),
166 }
167 }
168 }
169 _ => TokenRefreshError::Internal(e.to_string()),
170 };
171
172 last_error = Some(refresh_error);
173
174 if attempt < max_attempts {
175 let delay = Duration::from_secs(2u64.pow(attempt - 1)); log::warn!(
177 "Token refresh attempt {attempt} failed, retrying in {delay:?}..."
178 );
179 sleep(delay).await;
180 }
181 }
182 }
183 }
184
185 Err(last_error.unwrap_or(TokenRefreshError::MaxRetriesExceeded {
186 attempts: max_attempts,
187 }))
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::auth_state::AuthStateManager;
    use crate::auth::provider::{AuthProvider, AuthToken};
    use crate::auth::types::AuthType;
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Test double whose `refresh` either always succeeds with a fixed token or
    /// always fails, counting every call.
    struct MockAuthProvider {
        refresh_calls: Arc<AtomicU32>,
        fail_refresh: bool,
    }

    impl MockAuthProvider {
        /// Builds a shared mock with a zeroed call counter.
        fn shared(fail_refresh: bool) -> Arc<Self> {
            Arc::new(Self {
                refresh_calls: Arc::new(AtomicU32::new(0)),
                fail_refresh,
            })
        }

        /// Number of times `refresh` has been called so far.
        fn calls(&self) -> u32 {
            self.refresh_calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl AuthProvider for MockAuthProvider {
        async fn authenticate(&self) -> Result<AuthToken, ServiceBusError> {
            Ok(AuthToken {
                token: "test_token".to_string(),
                token_type: "Bearer".to_string(),
                expires_in_secs: Some(3600),
            })
        }

        async fn refresh(&self) -> Result<AuthToken, ServiceBusError> {
            self.refresh_calls.fetch_add(1, Ordering::SeqCst);

            if !self.fail_refresh {
                return self.authenticate().await;
            }

            Err(ServiceBusError::AuthenticationError(
                "Mock refresh failure".to_string(),
            ))
        }

        fn auth_type(&self) -> AuthType {
            AuthType::ConnectionString
        }
    }

    /// Builds a refresh service over a fresh auth state.
    fn new_service() -> TokenRefreshService {
        TokenRefreshService::new(Arc::new(AuthStateManager::new()))
    }

    #[tokio::test]
    async fn test_refresh_with_retry_success() {
        let service = new_service();
        let provider = MockAuthProvider::shared(false);

        let outcome = service.refresh_with_retry(provider.clone(), 3).await;

        assert!(outcome.is_ok());
        assert_eq!(provider.calls(), 1);
    }

    #[tokio::test]
    async fn test_refresh_with_retry_failure() {
        let service = new_service();
        let provider = MockAuthProvider::shared(true);

        let outcome = service.refresh_with_retry(provider.clone(), 3).await;

        assert!(outcome.is_err());
        assert_eq!(provider.calls(), 3);
    }
}