use crate::RetryConfig;
use crate::auth::BasicAuth;
use crate::auth::StreamingCredentialsProvider;
use crate::auth_management::credentials_management_utils;
use crate::errors::{ErrorKind, RedisError};
use crate::types::RedisResult;
use azure_core::credentials::{AccessToken, Secret, TokenCredential};
use azure_core::time::OffsetDateTime;
use azure_identity::{
ClientCertificateCredential, ClientCertificateCredentialOptions, ClientSecretCredential,
ClientSecretCredentialOptions, DeveloperToolsCredential, DeveloperToolsCredentialOptions,
ManagedIdentityCredential, ManagedIdentityCredentialOptions,
};
use backon::{ExponentialBuilder, Retryable};
use futures_util::{Stream, StreamExt};
use log::{debug, error, warn};
use std::pin::Pin;
use std::sync::{Arc, Mutex, RwLock};
use tokio::sync::mpsc::Sender;
pub const REDIS_SCOPE_DEFAULT: &str = "https://redis.azure.com/.default";
const TOKEN_REFRESH_BUFFER_SECS: u64 = 240;
#[derive(Debug, Clone)]
pub struct ClientCertificate {
pub base64_pkcs12: String,
pub password: Option<String>,
}
type Subscriptions = Vec<Sender<RedisResult<BasicAuth>>>;
type SharedSubscriptions = Arc<Mutex<Subscriptions>>;
#[derive(Clone)]
pub struct EntraIdCredentialsProvider {
credential_provider: Arc<dyn TokenCredential + Send + Sync>,
scopes: Vec<String>,
background_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
subscribers: SharedSubscriptions,
current_credentials: Arc<RwLock<Option<BasicAuth>>>,
}
impl EntraIdCredentialsProvider {
fn validate_scopes(scopes: &[String]) -> RedisResult<()> {
if scopes.is_empty() {
return Err(RedisError::from((
ErrorKind::InvalidClientConfig,
"Scopes cannot be empty for Entra ID authentication",
)));
}
for scope in scopes {
if scope.trim().is_empty() {
return Err(RedisError::from((
ErrorKind::InvalidClientConfig,
"Scope cannot be empty or whitespace-only",
)));
}
}
Ok(())
}
fn convert_credentials(username: String, access_token: &AccessToken) -> BasicAuth {
BasicAuth {
username,
password: access_token.token.secret().to_string(),
}
}
fn convert_error(err: azure_core::Error) -> RedisError {
RedisError::from((
ErrorKind::AuthenticationFailed,
"Entra ID authentication failed",
format!("{err}"),
))
}
fn convert_error_ref(err: &azure_core::Error) -> RedisError {
RedisError::from((
ErrorKind::AuthenticationFailed,
"Entra ID authentication failed",
format!("{err}"),
))
}
fn unwrap_credential<T: TokenCredential>(credential: std::sync::Arc<T>) -> RedisResult<T> {
std::sync::Arc::try_unwrap(credential).map_err(|_| {
RedisError::from((
ErrorKind::AuthenticationFailed,
"[azure_identity]: Unexpected shared ownership of credentials provider.",
))
})
}
async fn notify_subscribers(
subscribers_arc: &SharedSubscriptions,
username: &str,
token_response: Result<AccessToken, azure_core::Error>,
) {
let subscribers = {
let mut guard = subscribers_arc
.lock()
.expect("could not acquire lock for subscribers");
guard.retain(|sender| !sender.is_closed());
guard.clone()
};
futures_util::future::join_all(subscribers.iter().map(|sender| {
let token_response = token_response.as_ref();
let response = match token_response {
Ok(access_token) => {
Ok(Self::convert_credentials(username.to_owned(), access_token))
}
Err(error) => Err(Self::convert_error_ref(error)),
};
sender.send(response)
}))
.await;
}
fn start_refresh_service<F>(
&mut self,
retry_config: RetryConfig,
compute_sleep_duration_on_success: F,
) where
F: Fn(&AccessToken) -> std::time::Duration + Send + Sync + 'static,
{
if self.background_handle.lock().unwrap().is_some() {
return;
}
let subscribers_arc = Arc::clone(&self.subscribers);
let current_credentials_arc = Arc::clone(&self.current_credentials);
let credential_provider_arc = Arc::clone(&self.credential_provider);
let scopes = self.scopes.clone();
*self.background_handle.lock().unwrap() = Some(tokio::spawn(async move {
let scopes: Vec<&str> = scopes.iter().map(|s| s.as_str()).collect();
let mut next_sleep_duration;
let mut username = "default".to_string();
let RetryConfig {
exponent_base,
min_delay,
max_delay,
number_of_retries,
} = retry_config;
let mut strategy = ExponentialBuilder::default()
.with_factor(exponent_base)
.with_min_delay(min_delay)
.with_max_times(number_of_retries)
.with_jitter();
if let Some(mex_delay) = max_delay {
strategy = strategy.with_max_delay(mex_delay);
}
loop {
debug!("Refreshing token.");
let get_auth = || async { credential_provider_arc.get_token(&scopes, None).await };
let token_response = get_auth
.retry(strategy)
.sleep(|duration| async move { tokio::time::sleep(duration).await })
.notify(|err, duration| warn!("An error `{err}` occurred while refreshing the token. Sleeping for {duration:?}"))
.await;
if let Ok(ref access_token) = token_response {
username = match credentials_management_utils::extract_oid_from_jwt(
access_token.token.secret(),
) {
Ok(object_id) => object_id,
Err(error) => {
warn!("Failed to extract OID: {error}");
"default".to_string()
}
};
*current_credentials_arc.write().unwrap() =
Some(Self::convert_credentials(username.clone(), access_token));
next_sleep_duration = compute_sleep_duration_on_success(access_token);
} else {
error!("Maximum token refresh attempts reached. Stopping token refresh.");
Self::notify_subscribers(&subscribers_arc, &username, token_response).await;
break;
}
Self::notify_subscribers(&subscribers_arc, &username, token_response).await;
tokio::time::sleep(std::time::Duration::from_millis(
next_sleep_duration.as_millis() as u64,
))
.await;
}
}));
}
fn stop(&mut self) {
if let Some(handle) = self.background_handle.lock().unwrap().take() {
handle.abort();
}
}
pub fn new_developer_tools() -> RedisResult<Self> {
Self::new_developer_tools_with_scopes(vec![REDIS_SCOPE_DEFAULT.to_string()], None)
}
pub fn new_developer_tools_with_scopes(
scopes: Vec<String>,
options: Option<DeveloperToolsCredentialOptions>,
) -> RedisResult<Self> {
Self::validate_scopes(&scopes)?;
let credential_provider =
DeveloperToolsCredential::new(options).map_err(Self::convert_error)?;
Ok(Self {
credential_provider: Arc::new(Self::unwrap_credential(credential_provider)?),
scopes,
background_handle: Default::default(),
subscribers: Default::default(),
current_credentials: Default::default(),
})
}
pub fn new_client_secret(
tenant_id: String,
client_id: String,
client_secret: String,
) -> RedisResult<Self> {
Self::new_client_secret_with_scopes(
tenant_id,
client_id,
client_secret,
vec![REDIS_SCOPE_DEFAULT.to_string()],
None,
)
}
pub fn new_client_secret_with_scopes(
tenant_id: String,
client_id: String,
client_secret: String,
scopes: Vec<String>,
options: Option<ClientSecretCredentialOptions>,
) -> RedisResult<Self> {
Self::validate_scopes(&scopes)?;
let credential_provider =
ClientSecretCredential::new(&tenant_id, client_id, client_secret.into(), options)
.map_err(Self::convert_error)?;
Ok(Self {
credential_provider: Arc::new(Self::unwrap_credential(credential_provider)?),
scopes,
background_handle: Default::default(),
subscribers: Default::default(),
current_credentials: Default::default(),
})
}
pub fn new_client_certificate(
tenant_id: String,
client_id: String,
client_certificate: ClientCertificate,
) -> RedisResult<Self> {
Self::new_client_certificate_with_scopes(
tenant_id,
client_id,
client_certificate,
vec![REDIS_SCOPE_DEFAULT.to_string()],
None,
)
}
pub fn new_client_certificate_with_scopes(
tenant_id: String,
client_id: String,
client_certificate: ClientCertificate,
scopes: Vec<String>,
mut options: Option<ClientCertificateCredentialOptions>,
) -> RedisResult<Self> {
Self::validate_scopes(&scopes)?;
if let Some(password) = client_certificate.password {
if let Some(ref mut opts) = options {
opts.password = Some(Secret::new(password));
} else {
options = Some(ClientCertificateCredentialOptions {
password: Some(Secret::new(password)),
..Default::default()
});
}
}
let credential_provider = ClientCertificateCredential::new(
tenant_id,
client_id,
client_certificate.base64_pkcs12.into(),
options,
)
.map_err(Self::convert_error)?;
Ok(Self {
credential_provider: Arc::new(Self::unwrap_credential(credential_provider)?),
scopes,
background_handle: Default::default(),
subscribers: Default::default(),
current_credentials: Default::default(),
})
}
pub fn new_system_assigned_managed_identity() -> RedisResult<Self> {
Self::new_system_assigned_managed_identity_with_scopes(
vec![REDIS_SCOPE_DEFAULT.to_string()],
None,
)
}
pub fn new_system_assigned_managed_identity_with_scopes(
scopes: Vec<String>,
options: Option<ManagedIdentityCredentialOptions>,
) -> RedisResult<Self> {
Self::validate_scopes(&scopes)?;
let credential_provider =
ManagedIdentityCredential::new(options).map_err(Self::convert_error)?;
Ok(Self {
credential_provider: Arc::new(Self::unwrap_credential(credential_provider)?),
scopes,
background_handle: Default::default(),
subscribers: Default::default(),
current_credentials: Default::default(),
})
}
pub fn new_user_assigned_managed_identity() -> RedisResult<Self> {
Self::new_user_assigned_managed_identity_with_scopes(
vec![REDIS_SCOPE_DEFAULT.to_string()],
None,
)
}
pub fn new_user_assigned_managed_identity_with_scopes(
scopes: Vec<String>,
options: Option<ManagedIdentityCredentialOptions>,
) -> RedisResult<Self> {
Self::validate_scopes(&scopes)?;
let credential_provider =
ManagedIdentityCredential::new(options).map_err(Self::convert_error)?;
Ok(Self {
credential_provider: Arc::new(Self::unwrap_credential(credential_provider)?),
scopes,
background_handle: Default::default(),
subscribers: Default::default(),
current_credentials: Default::default(),
})
}
pub fn new_with_credential(
credential_provider: Arc<dyn TokenCredential + Send + Sync>,
scopes: Vec<String>,
) -> RedisResult<Self> {
Self::validate_scopes(&scopes)?;
Ok(Self {
credential_provider,
scopes,
background_handle: Default::default(),
subscribers: Default::default(),
current_credentials: Default::default(),
})
}
pub fn start(&mut self, retry_config: RetryConfig) {
self.start_refresh_service(retry_config, |access_token| {
let remaining = access_token.expires_on - OffsetDateTime::now_utc();
let remaining_duration = match remaining.try_into() {
Ok(duration) => duration,
Err(_) => std::time::Duration::from_secs(0),
};
remaining_duration
.checked_sub(std::time::Duration::from_secs(TOKEN_REFRESH_BUFFER_SECS))
.unwrap_or_else(|| {
warn!("The token is about to expire! Refreshing...");
std::time::Duration::from_secs(0)
})
});
}
}
impl StreamingCredentialsProvider for EntraIdCredentialsProvider {
fn subscribe(&self) -> Pin<Box<dyn Stream<Item = RedisResult<BasicAuth>> + Send + 'static>> {
let (tx, rx) = tokio::sync::mpsc::channel::<RedisResult<BasicAuth>>(1);
self.subscribers
.lock()
.expect("could not acquire lock for subscribers")
.push(tx);
let stream = futures_util::stream::unfold(rx, |mut rx| async move {
rx.recv().await.map(|item| (item, rx))
});
if let Some(creds) = self.current_credentials.read().unwrap().clone() {
futures_util::stream::once(async move { Ok(creds) })
.chain(stream)
.boxed()
} else {
stream.boxed()
}
}
}
impl std::fmt::Debug for EntraIdCredentialsProvider {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("EntraIdCredentialsProvider")
.field("scopes", &self.scopes)
.field("credential", &"<TokenCredential>")
.finish()
}
}
impl Drop for EntraIdCredentialsProvider {
fn drop(&mut self) {
self.stop();
}
}
#[cfg(all(feature = "entra-id", test))]
mod tests {
use super::*;
#[test]
fn test_entra_id_provider_creation() {
let _default_provider = EntraIdCredentialsProvider::new_developer_tools();
assert!(_default_provider.is_ok());
let _client_secret_provider = EntraIdCredentialsProvider::new_client_secret(
"tenant".to_string(),
"client".to_string(),
"secret".to_string(),
);
assert!(_client_secret_provider.is_ok());
let _managed_identity_provider =
EntraIdCredentialsProvider::new_system_assigned_managed_identity();
assert!(_managed_identity_provider.is_ok());
}
#[test]
fn test_scope_validation() {
let result = EntraIdCredentialsProvider::new_developer_tools_with_scopes(vec![], None);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Scopes cannot be empty")
);
let result =
EntraIdCredentialsProvider::new_developer_tools_with_scopes(vec!["".to_string()], None);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Scope cannot be empty")
);
let result = EntraIdCredentialsProvider::new_developer_tools_with_scopes(
vec![" ".to_string()],
None,
);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Scope cannot be empty")
);
}
#[test]
fn test_custom_scopes() {
let custom_scopes = vec!["https://custom.scope/.default".to_string()];
let provider = EntraIdCredentialsProvider::new_developer_tools_with_scopes(
custom_scopes.clone(),
None,
)
.unwrap();
assert_eq!(provider.scopes, custom_scopes);
}
}
#[cfg(all(feature = "entra-id", test))]
mod entra_id_mock_tests {
use crate::{
EntraIdCredentialsProvider, REDIS_SCOPE_DEFAULT, RetryConfig, StreamingCredentialsProvider,
};
use azure_core::Error as AzureError;
use azure_core::credentials::{AccessToken, Secret, TokenCredential};
use azure_core::time::{Duration, OffsetDateTime};
use futures_util::StreamExt;
use std::collections::VecDeque;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, Once};
use tokio::sync::Mutex;
static INIT_LOGGER: Once = Once::new();
fn init_logger() {
INIT_LOGGER.call_once(|| {
let mut builder = env_logger::builder();
builder.is_test(true);
if std::env::var("RUST_LOG").is_err() {
builder.filter_level(log::LevelFilter::Debug);
}
builder.init();
});
}
const TOKEN_PAYLOAD: &str = "eyJvaWQiOiIxMjM0NTY3OC05YWJjLWRlZi0xMjM0LTU2Nzg5YWJjZGVmMCJ9"; const TOKEN_SIGNATURE: &str = "signature";
const OID_CLAIM_VALUE: &str = "12345678-9abc-def-1234-56789abcdef0";
static MOCKED_TOKEN: LazyLock<String> =
LazyLock::new(|| format!("mock_jwt_token.{}.{}", TOKEN_PAYLOAD, TOKEN_SIGNATURE));
static MOCKED_TOKEN_1: LazyLock<String> =
LazyLock::new(|| format!("mock_jwt_token1.{}.{}", TOKEN_PAYLOAD, TOKEN_SIGNATURE));
static MOCKED_TOKEN_2: LazyLock<String> =
LazyLock::new(|| format!("mock_jwt_token2.{}.{}", TOKEN_PAYLOAD, TOKEN_SIGNATURE));
static MOCKED_TOKEN_3: LazyLock<String> =
LazyLock::new(|| format!("mock_jwt_token3.{}.{}", TOKEN_PAYLOAD, TOKEN_SIGNATURE));
#[derive(Debug)]
struct MockTokenCredential {
call_count: Arc<AtomicUsize>,
responses: Arc<Mutex<VecDeque<Result<AccessToken, AzureError>>>>,
}
impl MockTokenCredential {
fn success() -> Self {
let token = AccessToken {
token: Secret::new(MOCKED_TOKEN.as_str()),
expires_on: OffsetDateTime::now_utc() + Duration::hours(1),
};
Self {
call_count: Arc::new(AtomicUsize::new(0)),
responses: Arc::new(Mutex::new(VecDeque::from(vec![Ok(token)]))),
}
}
fn failure() -> Self {
let error = AzureError::new(
azure_core::error::ErrorKind::Credential,
"Authentication failed due to invalid credentials".to_string(),
);
Self {
call_count: Arc::new(AtomicUsize::new(0)),
responses: Arc::new(Mutex::new(VecDeque::from(vec![Err(error)]))),
}
}
fn alternating_fail_success() -> Self {
let error = AzureError::new(
azure_core::error::ErrorKind::Credential,
"Temporary failure".to_string(),
);
let token = AccessToken {
token: Secret::new(MOCKED_TOKEN.as_str()),
expires_on: OffsetDateTime::now_utc() + Duration::hours(1),
};
Self {
call_count: Arc::new(AtomicUsize::new(0)),
responses: Arc::new(Mutex::new(VecDeque::from(vec![Err(error), Ok(token)]))),
}
}
fn multiple_tokens() -> Self {
let time_now = OffsetDateTime::now_utc();
let tokens = vec![
Ok(AccessToken {
token: Secret::new(MOCKED_TOKEN_1.as_str()),
expires_on: time_now + Duration::seconds(1),
}),
Ok(AccessToken {
token: Secret::new(MOCKED_TOKEN_2.as_str()),
expires_on: time_now + Duration::seconds(2),
}),
Ok(AccessToken {
token: Secret::new(MOCKED_TOKEN_3.as_str()),
expires_on: time_now + Duration::seconds(3),
}),
];
Self {
call_count: Arc::new(AtomicUsize::new(0)),
responses: Arc::new(Mutex::new(VecDeque::from(tokens))),
}
}
}
#[async_trait::async_trait]
impl TokenCredential for MockTokenCredential {
async fn get_token(
&self,
_scopes: &[&str],
_options: Option<azure_core::credentials::TokenRequestOptions<'_>>,
) -> azure_core::Result<AccessToken> {
self.call_count.fetch_add(1, Ordering::SeqCst);
let mut responses = self.responses.lock().await;
match responses.pop_front() {
Some(Ok(token)) => {
responses.push_back(Ok(token.clone()));
Ok(token)
}
Some(Err(err)) => {
responses.push_back(Err(AzureError::new(
azure_core::error::ErrorKind::Credential,
"Mock authentication failed",
)));
Err(err)
}
None => {
unreachable!("No more responses");
}
}
}
}
fn create_mock_entra_id_credentials_provider(
mock_credential: MockTokenCredential,
scopes: Vec<String>,
) -> EntraIdCredentialsProvider {
EntraIdCredentialsProvider::new_with_credential(Arc::new(mock_credential), scopes).unwrap()
}
#[tokio::test]
async fn test_mock_successful_authentication() {
init_logger();
let mock_credential = MockTokenCredential::success();
let call_count_ref = mock_credential.call_count.clone();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(RetryConfig::default());
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
assert!(call_count_ref.load(Ordering::SeqCst) > 0);
let mut stream = provider.subscribe();
let credentials = stream.next().await.unwrap().unwrap();
assert!(!credentials.username.is_empty());
assert!(!credentials.password.is_empty());
assert_eq!(credentials.username, OID_CLAIM_VALUE);
assert_eq!(credentials.password, MOCKED_TOKEN.as_str());
}
#[tokio::test]
async fn test_mock_authentication_failure() {
init_logger();
let mock_credential = MockTokenCredential::failure();
let call_count_ref = mock_credential.call_count.clone();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(
RetryConfig::default()
.set_number_of_retries(1) .set_min_delay(std::time::Duration::from_millis(10))
.set_max_delay(std::time::Duration::from_millis(100))
.set_exponent_base(2.0),
);
let mut stream = provider.subscribe();
if let Some(result) = stream.next().await {
assert!(call_count_ref.load(Ordering::SeqCst) > 0);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("authentication failed")
);
}
}
#[tokio::test]
async fn test_mock_retry_mechanism() {
init_logger();
let mock_credential = MockTokenCredential::alternating_fail_success();
let call_count_ref = mock_credential.call_count.clone();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(RetryConfig::default());
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
assert!(call_count_ref.load(Ordering::SeqCst) >= 2);
let mut stream = provider.subscribe();
let credentials = stream.next().await.unwrap().unwrap();
assert_eq!(credentials.username, OID_CLAIM_VALUE);
assert_eq!(credentials.password, MOCKED_TOKEN.as_str());
}
#[tokio::test]
async fn test_mock_multiple_subscribers() {
init_logger();
let mock_credential = MockTokenCredential::multiple_tokens();
let call_count_ref = mock_credential.call_count.clone();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(RetryConfig::default());
let mut stream1 = provider.subscribe();
let mut stream2 = provider.subscribe();
let mut stream3 = provider.subscribe();
let credentials1 = stream1.next().await.unwrap().unwrap();
let credentials2 = stream2.next().await.unwrap().unwrap();
let credentials3 = stream3.next().await.unwrap().unwrap();
assert_eq!(credentials1.password, credentials2.password);
assert_eq!(credentials2.password, credentials3.password);
assert_eq!(credentials1.password, MOCKED_TOKEN_1.as_str());
assert_eq!(call_count_ref.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_mock_multiple_tokens_over_time() {
init_logger();
let mock_credential = MockTokenCredential::multiple_tokens();
let call_count_ref = mock_credential.call_count.clone();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(RetryConfig::default());
let mut stream = provider.subscribe();
let credentials = stream.next().await.unwrap().unwrap();
assert!(call_count_ref.load(Ordering::SeqCst) >= 1);
assert_eq!(credentials.username, OID_CLAIM_VALUE);
assert_eq!(credentials.password, MOCKED_TOKEN_1.as_str());
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let credentials = stream.next().await.unwrap().unwrap();
assert!(call_count_ref.load(Ordering::SeqCst) >= 2);
assert_eq!(credentials.username, OID_CLAIM_VALUE);
assert_eq!(credentials.password, MOCKED_TOKEN_2.as_str());
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
let credentials = stream.next().await.unwrap().unwrap();
assert!(call_count_ref.load(Ordering::SeqCst) >= 3);
assert_eq!(credentials.username, OID_CLAIM_VALUE);
assert_eq!(credentials.password, MOCKED_TOKEN_3.as_str());
}
#[test]
fn test_mock_scope_validation() {
use std::panic;
let result = panic::catch_unwind(|| {
create_mock_entra_id_credentials_provider(MockTokenCredential::success(), Vec::new());
});
assert!(
result.is_err(),
"Expected `create_mock_entra_id_credentials_provider` to panic, but it did not."
);
assert!(
result
.unwrap_err()
.downcast_ref::<String>()
.unwrap()
.contains("Scopes cannot be empty")
);
}
#[tokio::test]
async fn test_mock_subscriber_cleanup() {
init_logger();
let mock_credential = MockTokenCredential::multiple_tokens();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(RetryConfig::default());
let mut stream1 = provider.subscribe();
let mut stream2 = provider.subscribe();
let mut stream3 = provider.subscribe();
assert_eq!(provider.subscribers.lock().unwrap().len(), 3);
let credentials1 = stream1.next().await.unwrap().unwrap();
let credentials2 = stream2.next().await.unwrap().unwrap();
let credentials3 = stream3.next().await.unwrap().unwrap();
assert_eq!(credentials1.password, MOCKED_TOKEN_1.as_str());
assert_eq!(credentials1.password, credentials2.password);
assert_eq!(credentials2.password, credentials3.password);
drop(stream3);
let credentials1 = stream1.next().await.unwrap().unwrap();
let credentials2 = stream2.next().await.unwrap().unwrap();
assert_eq!(credentials1.password, MOCKED_TOKEN_2.as_str());
assert_eq!(credentials1.password, credentials2.password);
assert_eq!(
provider.subscribers.lock().unwrap().len(),
2,
"Dropped subscriber should be removed"
);
drop(stream2);
assert_eq!(
stream1.next().await.unwrap().unwrap().password,
MOCKED_TOKEN_3.as_str()
);
assert_eq!(
provider.subscribers.lock().unwrap().len(),
1,
"Only one subscriber should remain"
);
}
#[tokio::test]
async fn test_mock_provider_cleanup() {
init_logger();
let mock_credential = MockTokenCredential::success();
let mut provider = create_mock_entra_id_credentials_provider(
mock_credential,
vec![REDIS_SCOPE_DEFAULT.to_string()],
);
provider.start(RetryConfig::default());
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
drop(provider);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
}
}