use super::auth_state::AuthStateManager;
use super::errors::TokenRefreshError;
use super::provider::AuthProvider;
use super::types::CachedToken;
use crate::service_bus_manager::ServiceBusError;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tokio::time::{interval, sleep};
pub type RefreshFailureCallback = Arc<dyn Fn(TokenRefreshError) + Send + Sync>;
pub struct TokenRefreshService {
auth_state: Arc<AuthStateManager>,
check_interval: Duration,
shutdown_signal: Arc<RwLock<bool>>,
failure_callback: Option<RefreshFailureCallback>,
}
impl TokenRefreshService {
pub fn new(auth_state: Arc<AuthStateManager>) -> Self {
Self {
auth_state,
check_interval: Duration::from_secs(120), shutdown_signal: Arc::new(RwLock::new(false)),
failure_callback: None,
}
}
pub fn with_failure_callback(mut self, callback: RefreshFailureCallback) -> Self {
self.failure_callback = Some(callback);
self
}
pub fn start(self: Arc<Self>) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
self.run().await;
})
}
pub async fn shutdown(&self) {
let mut shutdown = self.shutdown_signal.write().await;
*shutdown = true;
}
async fn run(&self) {
let mut check_interval = interval(self.check_interval);
check_interval.tick().await;
loop {
if *self.shutdown_signal.read().await {
log::info!("Token refresh service shutting down");
break;
}
check_interval.tick().await;
if let Err(e) = self.check_and_refresh_tokens().await {
log::error!("Error during token refresh check: {e}");
}
}
}
async fn check_and_refresh_tokens(&self) -> Result<(), ServiceBusError> {
log::debug!("Checking tokens for refresh...");
if let Some(provider) = self.auth_state.get_service_bus_provider().await {
self.refresh_if_needed("service_bus", provider).await?;
}
if let Some(provider) = self.auth_state.get_management_provider().await {
self.refresh_if_needed("management_api", provider).await?;
}
Ok(())
}
async fn refresh_if_needed(
&self,
cache_key: &str,
provider: Arc<dyn AuthProvider>,
) -> Result<(), ServiceBusError> {
let token_cache = self.auth_state.get_token_cache();
if token_cache.needs_refresh(cache_key).await {
log::info!("Token for '{cache_key}' needs refresh, attempting refresh...");
match self.refresh_with_retry(provider, 3).await {
Ok(auth_token) => {
let cached_token = CachedToken::new(
auth_token.token,
Duration::from_secs(auth_token.expires_in_secs.unwrap_or(3600)),
auth_token.token_type,
);
token_cache.set(cache_key.to_string(), cached_token).await;
log::info!("Successfully refreshed token for '{cache_key}'");
}
Err(e) => {
log::error!("Failed to refresh token for '{cache_key}': {e}");
token_cache.invalidate(cache_key).await;
if let Some(callback) = &self.failure_callback {
callback(e.clone());
}
return Err(e.into());
}
}
}
Ok(())
}
async fn refresh_with_retry(
&self,
provider: Arc<dyn AuthProvider>,
max_attempts: u32,
) -> Result<super::provider::AuthToken, TokenRefreshError> {
let mut last_error = None;
for attempt in 1..=max_attempts {
match provider.refresh().await {
Ok(token) => return Ok(token),
Err(e) => {
let refresh_error = match &e {
ServiceBusError::AuthenticationFailed(_) => {
TokenRefreshError::InvalidRefreshToken
}
ServiceBusError::AuthenticationError(msg) if msg.contains("expired") => {
TokenRefreshError::RefreshTokenExpired
}
ServiceBusError::ConnectionFailed(reason) => {
TokenRefreshError::NetworkError {
reason: reason.clone(),
}
}
ServiceBusError::OperationTimeout(msg) => {
if msg.contains("rate") {
TokenRefreshError::RateLimited {
retry_after_seconds: None,
}
} else {
TokenRefreshError::ServiceUnavailable {
reason: msg.clone(),
}
}
}
_ => TokenRefreshError::Internal(e.to_string()),
};
last_error = Some(refresh_error);
if attempt < max_attempts {
let delay = Duration::from_secs(2u64.pow(attempt - 1)); log::warn!(
"Token refresh attempt {attempt} failed, retrying in {delay:?}..."
);
sleep(delay).await;
}
}
}
}
Err(last_error.unwrap_or(TokenRefreshError::MaxRetriesExceeded {
attempts: max_attempts,
}))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::auth_state::AuthStateManager;
use crate::auth::provider::{AuthProvider, AuthToken};
use crate::auth::types::AuthType;
use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
struct MockAuthProvider {
refresh_count: Arc<AtomicU32>,
should_fail: bool,
}
#[async_trait]
impl AuthProvider for MockAuthProvider {
async fn authenticate(&self) -> Result<AuthToken, ServiceBusError> {
Ok(AuthToken {
token: "test_token".to_string(),
token_type: "Bearer".to_string(),
expires_in_secs: Some(3600),
})
}
async fn refresh(&self) -> Result<AuthToken, ServiceBusError> {
self.refresh_count.fetch_add(1, Ordering::SeqCst);
if self.should_fail {
Err(ServiceBusError::AuthenticationError(
"Mock refresh failure".to_string(),
))
} else {
self.authenticate().await
}
}
fn auth_type(&self) -> AuthType {
AuthType::ConnectionString
}
}
#[tokio::test]
async fn test_refresh_with_retry_success() {
let auth_state = Arc::new(AuthStateManager::new());
let service = TokenRefreshService::new(auth_state);
let provider = Arc::new(MockAuthProvider {
refresh_count: Arc::new(AtomicU32::new(0)),
should_fail: false,
});
let result = service.refresh_with_retry(provider.clone(), 3).await;
assert!(result.is_ok());
assert_eq!(provider.refresh_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_refresh_with_retry_failure() {
let auth_state = Arc::new(AuthStateManager::new());
let service = TokenRefreshService::new(auth_state);
let provider = Arc::new(MockAuthProvider {
refresh_count: Arc::new(AtomicU32::new(0)),
should_fail: true,
});
let result = service.refresh_with_retry(provider.clone(), 3).await;
assert!(result.is_err());
assert_eq!(provider.refresh_count.load(Ordering::SeqCst), 3); }
}