#![cfg(feature = "token-based-authentication")]
mod support;
use futures_util::{Stream, StreamExt};
use redis::auth::{BasicAuth, StreamingCredentialsProvider};
use redis::{ErrorKind, RedisError, RedisResult};
use std::pin::Pin;
use std::sync::Once;
use support::*;
/// Guards `init_logger` so the logger is configured at most once per test binary.
static INIT_LOGGER: Once = Once::new();
/// Installs an `env_logger` logger configured for tests (`is_test(true)`).
///
/// When `RUST_LOG` is not set (or not readable as a string), logging defaults to the
/// `Debug` level; otherwise the filters the builder derived from `RUST_LOG` are kept.
///
/// `try_init` is used instead of `init` so that a global logger already installed
/// elsewhere in the same process does not make this panic; in that case the existing
/// logger simply stays in place.
fn init_logger() {
    INIT_LOGGER.call_once(|| {
        let mut builder = env_logger::builder();
        builder.is_test(true);
        if std::env::var("RUST_LOG").is_err() {
            builder.filter_level(log::LevelFilter::Debug);
        }
        // Failing to install only means another logger is active; that is fine for tests.
        let _ = builder.try_init();
    });
}
/// Credentials provider whose subscription yields exactly one
/// `ErrorKind::AuthenticationFailed` error and then ends, so even the very first
/// credentials request fails.
#[derive(Clone)]
struct ImmediatelyFailingCredentialsProvider;
impl StreamingCredentialsProvider for ImmediatelyFailingCredentialsProvider {
    fn subscribe(&self) -> Pin<Box<dyn Stream<Item = RedisResult<BasicAuth>> + Send + 'static>> {
        let failure: RedisResult<BasicAuth> = Err(RedisError::from((
            ErrorKind::AuthenticationFailed,
            "Unable to fetch token from credentials provider",
        )));
        futures_util::stream::iter(std::iter::once(failure)).boxed()
    }
}
/// Credentials provider whose subscription ends immediately without yielding
/// any credentials or errors.
#[derive(Clone)]
struct EmptyStreamCredentialsProvider;
impl StreamingCredentialsProvider for EmptyStreamCredentialsProvider {
    fn subscribe(&self) -> Pin<Box<dyn Stream<Item = RedisResult<BasicAuth>> + Send + 'static>> {
        let no_items = std::iter::empty::<RedisResult<BasicAuth>>();
        futures_util::stream::iter(no_items).boxed()
    }
}
/// Credentials provider that yields a single set of credentials
/// (user `default`, empty password) and then ends its subscription.
#[derive(Clone)]
struct OneTimeCredentialsProvider;
impl StreamingCredentialsProvider for OneTimeCredentialsProvider {
    fn subscribe(&self) -> Pin<Box<dyn Stream<Item = RedisResult<BasicAuth>> + Send + 'static>> {
        let credentials: RedisResult<BasicAuth> =
            Ok(BasicAuth::new(String::from("default"), String::new()));
        futures_util::stream::iter(std::iter::once(credentials)).boxed()
    }
}
/// Credentials provider that first yields valid credentials (user `default`,
/// empty password), then an `ErrorKind::AuthenticationFailed` error describing a
/// failed token refresh, and then ends its subscription.
#[derive(Clone)]
struct DelayedFailureCredentialsProvider;
impl StreamingCredentialsProvider for DelayedFailureCredentialsProvider {
    fn subscribe(&self) -> Pin<Box<dyn Stream<Item = RedisResult<BasicAuth>> + Send + 'static>> {
        let initial: RedisResult<BasicAuth> =
            Ok(BasicAuth::new(String::from("default"), String::new()));
        let refresh_failure: RedisResult<BasicAuth> = Err(RedisError::from((
            ErrorKind::AuthenticationFailed,
            "Token refresh failed after max retries",
        )));
        futures_util::stream::iter([initial, refresh_failure]).boxed()
    }
}
/// Tests for how connections behave when a streaming credentials provider fails,
/// yields nothing, or ends its subscription after the connection is established.
#[cfg(test)]
mod credentials_provider_failures_tests {
    use super::*;
    use futures_time::task::sleep;
    use test_macros::async_test;

    /// Upper bound on how long the re-authentication tests wait for a connection to
    /// start rejecting commands after its credentials subscription has ended.
    const REAUTH_FAILURE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(5);

    /// Delay between successive `PING` probes while waiting for that rejection.
    const REAUTH_FAILURE_POLL_INTERVAL: std::time::Duration =
        std::time::Duration::from_millis(20);

    /// Asserts that establishing a multiplexed connection with `config` fails with
    /// `ErrorKind::AuthenticationFailed`.
    ///
    /// Panics with `expectation` if the connection is unexpectedly established.
    async fn assert_initial_connection_fails(
        config: &redis::AsyncConnectionConfig,
        expectation: &str,
    ) {
        let ctx = TestContext::new();
        let result = ctx
            .client
            .get_multiplexed_async_connection_with_config(config)
            .await;
        let err = match result {
            Ok(_) => panic!("{expectation}"),
            Err(err) => err,
        };
        assert_eq!(err.kind(), ErrorKind::AuthenticationFailed);
    }

    /// Connects with `config`, checks that an initial `PING` succeeds, then keeps
    /// sending `PING` until the connection starts rejecting commands. It asserts that
    /// the rejection is an `AuthenticationFailed` error mentioning
    /// "re-authentication failure".
    ///
    /// Polling up to `REAUTH_FAILURE_TIMEOUT` replaces a single fixed sleep. With a
    /// fixed sleep, the outcome depended on how quickly the background
    /// re-authentication handling ran.
    ///
    /// Panics with `expectation` if commands keep succeeding until the timeout.
    async fn assert_connection_becomes_unusable(
        config: &redis::AsyncConnectionConfig,
        expectation: &str,
    ) {
        let ctx = TestContext::new();
        let mut con = ctx
            .client
            .get_multiplexed_async_connection_with_config(config)
            .await
            .expect("Initial connection should succeed.");
        let result: RedisResult<String> = redis::cmd("PING").query_async(&mut con).await;
        assert!(result.is_ok(), "PING should succeed.");

        let deadline = std::time::Instant::now() + REAUTH_FAILURE_TIMEOUT;
        let err = loop {
            sleep(REAUTH_FAILURE_POLL_INTERVAL.into()).await;
            let result: RedisResult<String> = redis::cmd("PING").query_async(&mut con).await;
            match result {
                Err(err) => break err,
                // Still usable; keep probing until the deadline passes.
                Ok(_) if std::time::Instant::now() < deadline => {}
                Ok(_) => panic!("{expectation}"),
            }
        };
        assert_eq!(err.kind(), ErrorKind::AuthenticationFailed);
        assert!(
            err.to_string().contains("re-authentication failure"),
            "Error message should mention re-authentication failure: {err}"
        );
    }

    #[async_test]
    async fn test_connection_fails_when_initial_credentials_request_returns_error() {
        init_logger();
        let config = redis::AsyncConnectionConfig::new()
            .set_credentials_provider(ImmediatelyFailingCredentialsProvider);
        assert_initial_connection_fails(
            &config,
            "Connection should fail when the initial credentials request fails.",
        )
        .await;
    }

    #[async_test]
    async fn test_connection_fails_when_credentials_stream_closes() {
        init_logger();
        let config = redis::AsyncConnectionConfig::new()
            .set_credentials_provider(EmptyStreamCredentialsProvider);
        assert_initial_connection_fails(
            &config,
            "Connection should fail when the credentials stream closes.",
        )
        .await;
    }

    #[async_test]
    async fn test_connection_renders_unusable_when_the_subscription_stream_closes() {
        init_logger();
        let config = redis::AsyncConnectionConfig::new()
            .set_credentials_provider(OneTimeCredentialsProvider);
        assert_connection_becomes_unusable(
            &config,
            "Commands should fail after the subscription stream closes unexpectedly.",
        )
        .await;
    }

    #[async_test]
    async fn test_connection_renders_unusable_when_the_subscription_stream_closes_after_an_error() {
        init_logger();
        let config = redis::AsyncConnectionConfig::new()
            .set_credentials_provider(DelayedFailureCredentialsProvider);
        assert_connection_becomes_unusable(
            &config,
            "Commands should fail after the subscription stream returns error.",
        )
        .await;
    }

    #[cfg(feature = "cluster-async")]
    mod cluster {
        use super::*;
        use redis::cluster::ClusterClientBuilder;

        #[async_test]
        async fn test_cluster_connection_fails_when_credentials_provider_returns_error() {
            init_logger();
            let cluster = TestClusterContext::new_with_cluster_client_builder(
                |builder: ClusterClientBuilder| {
                    builder.set_credentials_provider(ImmediatelyFailingCredentialsProvider)
                },
            );
            // Matched rather than `unwrap_err`ed so the success type needs no `Debug` impl.
            let err = match cluster.client.get_async_connection().await {
                Ok(_) => panic!(
                    "Cluster connection should fail when the credentials provider returns an error."
                ),
                Err(err) => err,
            };
            // The cluster client is expected to surface this failure as an I/O error.
            assert_eq!(err.kind(), ErrorKind::Io);
        }

        #[async_test]
        async fn test_cluster_connection_fails_when_credentials_stream_is_empty() {
            init_logger();
            let cluster = TestClusterContext::new_with_cluster_client_builder(
                |builder: ClusterClientBuilder| {
                    builder.set_credentials_provider(EmptyStreamCredentialsProvider)
                },
            );
            let err = match cluster.client.get_async_connection().await {
                Ok(_) => panic!(
                    "Cluster connection should fail when the credentials stream closes without yielding."
                ),
                Err(err) => err,
            };
            // The cluster client is expected to surface this failure as an I/O error.
            assert_eq!(err.kind(), ErrorKind::Io);
        }
    }
}