use crate::auth::authenticator::Authenticator;
use crate::auth::token::AccessToken;
use crate::error::Result;
use std::sync::Arc;
use tokio::sync::{Mutex, RwLock};
#[derive(Debug)]
struct TokenState {
token: Option<Arc<AccessToken>>,
}
#[derive(Debug)]
pub struct TokenManager<A: Authenticator> {
authenticator: A,
state: Arc<RwLock<TokenState>>,
refresh_lock: Mutex<()>,
}
impl<A: Authenticator> TokenManager<A> {
#[must_use]
pub fn new(authenticator: A) -> Self {
Self {
authenticator,
state: Arc::new(RwLock::new(TokenState { token: None })),
refresh_lock: Mutex::new(()),
}
}
async fn update_token_state(
&self,
arc_token: Arc<AccessToken>,
is_initial_auth: bool,
) -> Result<Arc<AccessToken>> {
let mut state = self.state.write().await;
if let Some(current) = &state.token {
if current.issued_at() > arc_token.issued_at() || Arc::ptr_eq(current, &arc_token) {
return Ok(current.clone());
}
state.token = Some(arc_token.clone());
} else if is_initial_auth {
state.token = Some(arc_token.clone());
} else {
return Err(crate::error::ForceError::Authentication(
crate::error::AuthenticationError::InvalidToken,
));
}
Ok(arc_token)
}
async fn latest_token_or(&self, fallback: Arc<AccessToken>) -> Arc<AccessToken> {
let state = self.state.read().await;
match &state.token {
Some(current) if current.issued_at() >= fallback.issued_at() => current.clone(),
_ => fallback,
}
}
pub(crate) async fn get_token_arc(&self) -> Result<Arc<AccessToken>> {
let (is_soft_expired, is_hard_expired_actual, current_token) =
self.evaluate_token_state().await;
if let Some(token) = current_token.as_ref() {
if !is_soft_expired && !is_hard_expired_actual {
return Ok(token.clone());
}
}
if is_hard_expired_actual {
self.handle_hard_refresh().await
} else if let Some(valid_token) = current_token {
self.handle_soft_refresh(valid_token).await
} else {
Err(crate::error::ForceError::Authentication(
crate::error::AuthenticationError::InvalidToken,
))
}
}
async fn evaluate_token_state(&self) -> (bool, bool, Option<Arc<AccessToken>>) {
let state = self.state.read().await;
if let Some(token) = &state.token {
(
token.is_soft_expired(),
token.is_hard_expired(),
Some(token.clone()),
)
} else {
(false, true, None)
}
}
async fn handle_hard_refresh(&self) -> Result<Arc<AccessToken>> {
let _lock = self.refresh_lock.lock().await;
{
let state = self.state.read().await;
if let Some(token) = &state.token {
if !token.is_hard_expired() {
return Ok(token.clone());
}
}
}
let has_token = {
let state = self.state.read().await;
state.token.is_some()
};
let new_token = if has_token {
self.authenticator.refresh().await?
} else {
self.authenticator.authenticate().await?
};
let arc_token = Arc::new(new_token);
self.update_token_state(arc_token, !has_token).await
}
async fn handle_soft_refresh(&self, valid_token: Arc<AccessToken>) -> Result<Arc<AccessToken>> {
let Ok(_lock) = self.refresh_lock.try_lock() else {
return Ok(self.latest_token_or(valid_token).await);
};
{
let state = self.state.read().await;
if let Some(token) = &state.token {
if !token.is_soft_expired() && !token.is_hard_expired() {
return Ok(token.clone());
}
}
}
let refresh_result = self.authenticator.refresh().await;
match refresh_result {
Ok(new_token) => {
let arc_token = Arc::new(new_token);
self.update_token_state(arc_token, false).await
}
Err(_) => Ok(self.latest_token_or(valid_token).await),
}
}
pub async fn token(&self) -> Result<AccessToken> {
let arc_token = self.get_token_arc().await?;
Ok((*arc_token).clone())
}
pub async fn force_refresh(&self) -> Result<AccessToken> {
let current_arc = {
let state = self.state.read().await;
state.token.clone()
};
let _lock = self.refresh_lock.lock().await;
{
let state = self.state.read().await;
if let Some(token) = &state.token {
let is_same = match ¤t_arc {
Some(arc) => Arc::ptr_eq(token, arc),
None => false,
};
if !is_same {
return Ok((*token.clone()).clone());
}
}
}
let has_token = {
let state = self.state.read().await;
state.token.is_some()
};
let new_token = if has_token {
self.authenticator.refresh().await?
} else {
self.authenticator.authenticate().await?
};
let arc_token = Arc::new(new_token);
let final_token = self.update_token_state(arc_token, !has_token).await?;
Ok((*final_token).clone())
}
pub async fn clear(&self) {
let mut state = self.state.write().await;
state.token = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auth::authenticator::Authenticator;
    use crate::auth::token::TokenResponse;
    use crate::test_support::Must;
    use async_trait::async_trait;
    use chrono::{Duration, Utc};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Instance URL shared by every fixture token.
    const INSTANCE_URL: &str = "https://test.salesforce.com";

    /// Builds a `TokenResponse` with the given access token and `issued_at`
    /// (epoch milliseconds). `expires_in` and `refresh_token` are left unset; use
    /// struct-update syntax to override them.
    fn token_response(access_token: &str, issued_at_ms: i64) -> TokenResponse {
        TokenResponse {
            access_token: access_token.to_string(),
            instance_url: INSTANCE_URL.to_string(),
            token_type: "Bearer".to_string(),
            issued_at: issued_at_ms.to_string(),
            signature: String::new(),
            expires_in: None,
            refresh_token: None,
        }
    }

    /// A token named `future_token` whose `issued_at` is one hour ahead, so it is
    /// newer than anything the mock authenticator issues during a test.
    fn future_token() -> AccessToken {
        let future_ts = (Utc::now() + Duration::hours(1)).timestamp_millis();
        AccessToken::from_response(token_response("future_token", future_ts))
    }

    /// A token built with `AccessToken::new` whose expiry is `now + offset`.
    /// A negative offset produces an already-expired token.
    fn token_expiring_in(value: &str, offset: Duration) -> AccessToken {
        AccessToken::new(
            value.to_string(),
            INSTANCE_URL.to_string(),
            Some(Utc::now() + offset),
        )
    }

    /// Overwrites the cached token directly, bypassing the authenticator.
    async fn set_cached_token<A: Authenticator>(manager: &TokenManager<A>, token: AccessToken) {
        manager.state.write().await.token = Some(Arc::new(token));
    }

    /// Authenticator double that counts calls and can be made to fail or sleep.
    #[derive(Debug)]
    struct MockAuthenticator {
        auth_count: AtomicUsize,
        refresh_count: AtomicUsize,
        should_fail: bool,
        /// Artificial latency applied to both `authenticate` and `refresh`.
        refresh_delay: Option<std::time::Duration>,
    }

    impl MockAuthenticator {
        fn new() -> Self {
            Self {
                auth_count: AtomicUsize::new(0),
                refresh_count: AtomicUsize::new(0),
                should_fail: false,
                refresh_delay: None,
            }
        }

        fn with_failure() -> Self {
            Self {
                should_fail: true,
                ..Self::new()
            }
        }

        fn with_delay(mut self, delay: std::time::Duration) -> Self {
            self.refresh_delay = Some(delay);
            self
        }

        fn auth_count(&self) -> usize {
            self.auth_count.load(Ordering::SeqCst)
        }

        fn refresh_count(&self) -> usize {
            self.refresh_count.load(Ordering::SeqCst)
        }

        async fn simulate_latency(&self) {
            if let Some(delay) = self.refresh_delay {
                tokio::time::sleep(delay).await;
            }
        }
    }

    #[async_trait]
    impl Authenticator for MockAuthenticator {
        async fn authenticate(&self) -> Result<AccessToken> {
            self.simulate_latency().await;
            // Number tokens from fetch_add's return value so overlapping calls
            // still get distinct suffixes.
            let call = self.auth_count.fetch_add(1, Ordering::SeqCst) + 1;
            if self.should_fail {
                return Err(crate::error::ForceError::Authentication(
                    crate::error::AuthenticationError::InvalidCredentials(
                        "mock auth failed".to_string(),
                    ),
                ));
            }
            Ok(token_expiring_in(
                &format!("auth_token_{call}"),
                Duration::hours(2),
            ))
        }

        async fn refresh(&self) -> Result<AccessToken> {
            self.simulate_latency().await;
            let call = self.refresh_count.fetch_add(1, Ordering::SeqCst) + 1;
            if self.should_fail {
                return Err(crate::error::ForceError::Authentication(
                    crate::error::AuthenticationError::TokenRefreshFailed(
                        "mock refresh failed".to_string(),
                    ),
                ));
            }
            Ok(token_expiring_in(
                &format!("refresh_token_{call}"),
                Duration::hours(2),
            ))
        }
    }

    #[tokio::test]
    async fn test_token_manager_initial_auth() {
        let manager = TokenManager::new(MockAuthenticator::new());
        let token = manager.token().await.must();
        assert_eq!(token.as_str(), "auth_token_1");
        assert_eq!(manager.authenticator.auth_count(), 1);
        assert_eq!(manager.authenticator.refresh_count(), 0);
    }

    #[tokio::test]
    async fn test_token_manager_reuses_valid_token() {
        let manager = TokenManager::new(MockAuthenticator::new());
        let token1 = manager.token().await.must();
        assert_eq!(manager.authenticator.auth_count(), 1);
        let token2 = manager.token().await.must();
        // A fresh cached token is served without another round-trip.
        assert_eq!(manager.authenticator.auth_count(), 1);
        assert_eq!(token1.as_str(), token2.as_str());
    }

    #[tokio::test]
    async fn test_token_manager_refreshes_expired_token() {
        let manager = TokenManager::new(MockAuthenticator::new());
        let _ = manager.token().await.must();
        assert_eq!(manager.authenticator.auth_count(), 1);
        set_cached_token(
            &manager,
            token_expiring_in("expired_token", -Duration::hours(1)),
        )
        .await;
        let token = manager.token().await.must();
        // An existing-but-expired token is refreshed, not re-authenticated.
        assert_eq!(manager.authenticator.auth_count(), 1);
        assert_eq!(manager.authenticator.refresh_count(), 1);
        assert_eq!(token.as_str(), "refresh_token_1");
    }

    #[tokio::test]
    async fn test_token_manager_force_refresh() {
        let manager = TokenManager::new(MockAuthenticator::new());
        let token1 = manager.token().await.must();
        assert_eq!(token1.as_str(), "auth_token_1");
        let token2 = manager.force_refresh().await.must();
        assert_eq!(token2.as_str(), "refresh_token_1");
        assert_eq!(manager.authenticator.refresh_count(), 1);
    }

    #[tokio::test]
    async fn test_token_manager_clear() {
        let manager = TokenManager::new(MockAuthenticator::new());
        let _ = manager.token().await.must();
        assert_eq!(manager.authenticator.auth_count(), 1);
        manager.clear().await;
        let _ = manager.token().await.must();
        // An empty cache forces a full re-authentication.
        assert_eq!(manager.authenticator.auth_count(), 2);
    }

    #[tokio::test]
    async fn test_token_manager_concurrent_access() {
        let manager = Arc::new(TokenManager::new(MockAuthenticator::new()));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let worker = Arc::clone(&manager);
                tokio::spawn(async move { worker.token().await })
            })
            .collect();
        for handle in handles {
            assert!(handle.await.must().is_ok());
        }
        assert_eq!(manager.authenticator.auth_count(), 1);
    }

    #[tokio::test]
    async fn test_token_manager_auth_failure() {
        let manager = TokenManager::new(MockAuthenticator::with_failure());
        match manager.token().await {
            Err(crate::error::ForceError::Authentication(
                crate::error::AuthenticationError::InvalidCredentials(msg),
            )) => assert_eq!(msg, "mock auth failed"),
            other => panic!("Expected InvalidCredentials error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_token_manager_refresh_failure() {
        let manager = TokenManager::new(MockAuthenticator::with_failure());
        set_cached_token(&manager, token_expiring_in("expired", -Duration::hours(1))).await;
        match manager.token().await {
            Err(crate::error::ForceError::Authentication(
                crate::error::AuthenticationError::TokenRefreshFailed(msg),
            )) => assert_eq!(msg, "mock refresh failed"),
            other => panic!("Expected TokenRefreshFailed error, got: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_token_manager_concurrent_refresh_only_one_request() {
        let manager = TokenManager::new(
            MockAuthenticator::new().with_delay(std::time::Duration::from_millis(50)),
        );
        let _ = manager.token().await.must();
        assert_eq!(manager.authenticator.auth_count(), 1);
        set_cached_token(
            &manager,
            token_expiring_in("expired_token", -Duration::hours(1)),
        )
        .await;
        let manager = Arc::new(manager);
        let handles: Vec<_> = (0..50)
            .map(|_| {
                let worker = Arc::clone(&manager);
                tokio::spawn(async move { worker.token().await.must() })
            })
            .collect();
        for handle in handles {
            assert_eq!(handle.await.must().as_str(), "refresh_token_1");
        }
        assert_eq!(
            manager.authenticator.refresh_count(),
            1,
            "Should have refreshed exactly once despite concurrent load"
        );
        assert_eq!(manager.authenticator.auth_count(), 1);
    }

    #[tokio::test]
    async fn test_token_manager_force_refresh_protects_against_overwrite() {
        let manager = TokenManager::new(MockAuthenticator::new());
        set_cached_token(&manager, future_token()).await;
        let result = manager.force_refresh().await.must();
        assert_eq!(result.as_str(), "future_token");
        let state_token = manager.token().await.must();
        assert_eq!(state_token.as_str(), "future_token");
    }

    #[tokio::test]
    async fn test_token_manager_hard_refresh_protects_against_overwrite() {
        let manager = Arc::new(TokenManager::new(
            MockAuthenticator::new().with_delay(std::time::Duration::from_millis(50)),
        ));
        let worker = Arc::clone(&manager);
        // The worker starts an initial authentication that sleeps for 50ms...
        let handle = tokio::spawn(async move { worker.token().await.must() });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        // ...and while it is in flight a newer token lands in the cache.
        set_cached_token(&manager, future_token()).await;
        assert_eq!(handle.await.must().as_str(), "future_token");
        assert_eq!(manager.token().await.must().as_str(), "future_token");
    }

    #[tokio::test]
    async fn test_token_manager_soft_refresh_protects_against_overwrite() {
        let manager = TokenManager::new(
            MockAuthenticator::new().with_delay(std::time::Duration::from_millis(50)),
        );
        set_cached_token(
            &manager,
            token_expiring_in("soft_token", Duration::seconds(30)),
        )
        .await;
        let manager = Arc::new(manager);
        let worker = Arc::clone(&manager);
        let handle = tokio::spawn(async move { worker.token().await.must() });
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        set_cached_token(&manager, future_token()).await;
        assert_eq!(handle.await.must().as_str(), "future_token");
        assert_eq!(manager.token().await.must().as_str(), "future_token");
    }

    #[tokio::test]
    async fn test_token_manager_update_token_state_cleared_token_rejects_non_initial() {
        let manager = TokenManager::new(MockAuthenticator::new());
        let _ = manager.token().await.must();
        manager.clear().await;
        let result = manager
            .update_token_state(
                Arc::new(token_expiring_in("dummy", Duration::hours(1))),
                false,
            )
            .await;
        assert!(
            matches!(
                result,
                Err(crate::error::ForceError::Authentication(
                    crate::error::AuthenticationError::InvalidToken
                ))
            ),
            "Expected InvalidToken error after clearing and update_token_state, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn test_token_manager_soft_expired_refresh_failure_returns_valid_token() {
        let manager = TokenManager::new(MockAuthenticator::with_failure());
        // Still valid for ~30s; expected to be inside the soft-expiry window.
        set_cached_token(
            &manager,
            token_expiring_in("still_valid_token", Duration::seconds(30)),
        )
        .await;
        let token = manager.token().await.must();
        assert_eq!(token.as_str(), "still_valid_token");
    }

    #[tokio::test]
    async fn test_token_manager_soft_expired_concurrent_returns_latest_token() {
        let manager = TokenManager::new(
            MockAuthenticator::new().with_delay(std::time::Duration::from_millis(200)),
        );
        set_cached_token(
            &manager,
            token_expiring_in("soft_valid_token", Duration::seconds(30)),
        )
        .await;
        let manager = Arc::new(manager);
        let worker = Arc::clone(&manager);
        // The worker wins the refresh lock and sleeps inside `refresh`.
        let handle = tokio::spawn(async move { worker.token().await.must() });
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;
        // A second caller cannot take the lock and gets the still-valid token.
        let token2 = manager.token().await.must();
        assert_eq!(token2.as_str(), "soft_valid_token");
        let token1 = handle.await.must();
        assert_eq!(token1.as_str(), "refresh_token_1");
        assert_eq!(manager.authenticator.refresh_count(), 1);
    }

    #[tokio::test]
    async fn test_token_manager_equality_overwrites() {
        /// Authenticator that always issues `new_token` with a fixed `issued_at`.
        #[derive(Debug)]
        struct EqAuth(i64);

        #[async_trait]
        impl Authenticator for EqAuth {
            async fn authenticate(&self) -> Result<AccessToken> {
                Ok(AccessToken::from_response(token_response(
                    "new_token",
                    self.0,
                )))
            }

            async fn refresh(&self) -> Result<AccessToken> {
                self.authenticate().await
            }
        }

        let fixed_ts = Utc::now().timestamp_millis();

        let eq_manager = TokenManager::new(EqAuth(fixed_ts));
        set_cached_token(
            &eq_manager,
            AccessToken::from_response(token_response("old_token", fixed_ts)),
        )
        .await;
        let result = eq_manager.force_refresh().await.must();
        assert_eq!(
            result.as_str(),
            "new_token",
            "Equality should trigger an overwrite in force_refresh"
        );

        let hard_eq_manager = TokenManager::new(EqAuth(fixed_ts));
        let hard_old = TokenResponse {
            expires_in: Some(0),
            ..token_response("hard_old_token", fixed_ts)
        };
        set_cached_token(&hard_eq_manager, AccessToken::from_response(hard_old)).await;
        let result = hard_eq_manager.token().await.must();
        assert_eq!(
            result.as_str(),
            "new_token",
            "Equality should trigger an overwrite in hard refresh"
        );

        let soft_eq_manager = TokenManager::new(EqAuth(fixed_ts));
        let soft_old = TokenResponse {
            expires_in: Some(30),
            ..token_response("soft_old_token", fixed_ts)
        };
        set_cached_token(&soft_eq_manager, AccessToken::from_response(soft_old)).await;
        let result = soft_eq_manager.token().await.must();
        assert_eq!(
            result.as_str(),
            "new_token",
            "Equality should trigger an overwrite in soft refresh"
        );
    }

    #[tokio::test]
    async fn test_token_manager_force_refresh_stampede() {
        let manager = Arc::new(TokenManager::new(
            MockAuthenticator::new().with_delay(std::time::Duration::from_millis(50)),
        ));
        let _ = manager.token().await.must();
        let handles: Vec<_> = (0..100)
            .map(|_| {
                let worker = Arc::clone(&manager);
                tokio::spawn(async move { worker.force_refresh().await.must() })
            })
            .collect();
        for handle in handles {
            let _ = handle.await.must();
        }
        let refresh_count = manager.authenticator.refresh_count();
        assert_eq!(
            refresh_count, 1,
            "👺 Havoc: force_refresh triggered a stampede!"
        );
    }
}