use std::sync::Arc;
use fred::clients::{Client, Pool};
use fred::interfaces::{ClientLike, EventInterface, PubsubInterface, StreamsInterface};
#[cfg(feature = "credential-provider")]
use fred::types::config::CredentialProvider;
#[cfg(any(
feature = "tls-rustls",
feature = "tls-rustls-ring",
feature = "tls-native-tls"
))]
use fred::types::config::TlsConfig;
use fred::types::config::{Config, ServerConfig};
use ruststream::{Broker, DescribeServer, ServerSpec, Subscribe};
use tokio::sync::OnceCell;
use crate::{
error::RedisError,
list::{RedisList, RedisListPublisher, RedisListSubscriber},
publisher::RedisPublisher,
pubsub::{PubSubMode, RedisPubSub, RedisPubSubPublisher, RedisPubSubSubscriber},
stream::RedisStream,
subscriber::RedisSubscriber,
};
/// Number of pooled connections opened by [`Broker::connect`] when
/// [`RedisBroker::pool`] has not been used to pick a different size.
const DEFAULT_POOL_SIZE: usize = 4;
/// Describes how the broker reaches its Redis deployment.
#[derive(Debug, Clone)]
enum Topology {
    /// A single server, addressed by a connection URL.
    Standalone(String),
    /// A Redis Cluster, addressed by one or more node addresses.
    Cluster(Vec<String>),
    /// A sentinel-managed deployment: the monitored service name plus the
    /// addresses of the sentinel processes themselves.
    Sentinel { service: String, hosts: Vec<String> },
    /// The broker adopted a pool built elsewhere (`RedisBroker::from_pool`),
    /// so it carries no address of its own.
    Preconnected,
}
/// Parses one cluster or sentinel node address into a `(host, port)` pair.
///
/// Accepts `host` or `host:port`, optionally prefixed by a single `redis://` or
/// `rediss://` scheme. `default_port` is used when no port is given. The scheme
/// is only stripped here: TLS has to be configured on the broker itself.
///
/// # Errors
///
/// Returns [`RedisError::Connect`] in these cases:
/// - the address embeds `user:password@` credentials;
/// - the port is not a valid `u16`;
/// - the host is empty.
fn parse_server(addr: &str, default_port: u16) -> Result<(String, u16), RedisError> {
    let raw = addr.trim();
    let trimmed = raw
        .strip_prefix("rediss://")
        .or_else(|| raw.strip_prefix("redis://"))
        .unwrap_or(raw);
    // A hostname can never contain `@`, so this is URL userinfo. Reject it rather
    // than passing `user:password@host` on as the host. The message deliberately
    // omits `addr` so the password is not copied into errors or logs.
    if trimmed.contains('@') {
        return Err(RedisError::Connect(
            "redis node addresses must not embed credentials; configure them on the \
             broker (e.g. RedisBroker::credentials) instead"
                .into(),
        ));
    }
    // Split on the *last* colon so only the trailing segment is treated as the port.
    let (host, port) = match trimmed.rsplit_once(':') {
        Some((host, port)) => {
            let port = port.parse::<u16>().map_err(|_| {
                RedisError::Connect(format!("invalid port in redis address `{addr}`").into())
            })?;
            (host, port)
        }
        None => (trimmed, default_port),
    };
    if host.is_empty() {
        return Err(RedisError::Connect(
            format!("missing host in redis address `{addr}`").into(),
        ));
    }
    Ok((host.to_owned(), port))
}
/// Parses every address in `addrs` with [`parse_server`], using `default_port`
/// for entries that omit a port.
///
/// # Errors
///
/// Returns [`RedisError::Connect`] when `addrs` is empty. Otherwise it returns
/// the first error produced by [`parse_server`].
fn parse_servers(addrs: &[String], default_port: u16) -> Result<Vec<(String, u16)>, RedisError> {
    if addrs.is_empty() {
        return Err(RedisError::Connect("no redis addresses provided".into()));
    }
    let mut servers = Vec::with_capacity(addrs.len());
    for addr in addrs {
        servers.push(parse_server(addr, default_port)?);
    }
    Ok(servers)
}
/// Authentication and transport-security settings that the broker layers onto
/// the fred configuration it builds. A field left as `None` means "not
/// configured on the broker".
///
/// `Debug` is implemented by hand so secrets are never printed.
#[derive(Clone, Default)]
struct AuthConfig {
    /// ACL username for data connections.
    username: Option<String>,
    /// Password for data connections.
    password: Option<String>,
    /// ACL username used when talking to sentinel processes.
    #[cfg(feature = "sentinel-auth")]
    sentinel_username: Option<String>,
    /// Password used when talking to sentinel processes.
    #[cfg(feature = "sentinel-auth")]
    sentinel_password: Option<String>,
    /// TLS settings applied to the generated config.
    #[cfg(any(
        feature = "tls-rustls",
        feature = "tls-rustls-ring",
        feature = "tls-native-tls"
    ))]
    tls: Option<TlsConfig>,
    /// Dynamic credential source applied to the generated config.
    #[cfg(feature = "credential-provider")]
    credential_provider: Option<Arc<dyn CredentialProvider>>,
}
impl std::fmt::Debug for AuthConfig {
    /// Renders the auth settings without exposing secrets.
    ///
    /// Usernames are shown as-is and passwords appear only as `"<redacted>"`.
    /// TLS and credential-provider settings only report that they are
    /// `"<configured>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "<redacted>";
        let mut out = f.debug_struct("AuthConfig");
        out.field("username", &self.username)
            .field("password", &self.password.as_ref().map(|_| REDACTED));
        #[cfg(feature = "sentinel-auth")]
        out.field("sentinel_username", &self.sentinel_username)
            .field(
                "sentinel_password",
                &self.sentinel_password.as_ref().map(|_| REDACTED),
            );
        #[cfg(any(
            feature = "tls-rustls",
            feature = "tls-rustls-ring",
            feature = "tls-native-tls"
        ))]
        out.field("tls", &self.tls.as_ref().map(|_| "<configured>"));
        #[cfg(feature = "credential-provider")]
        out.field(
            "credential_provider",
            &self.credential_provider.as_ref().map(|_| "<configured>"),
        );
        out.finish()
    }
}
/// Redis-backed [`Broker`] offering stream, pub/sub and list messaging.
///
/// Clones share one connection pool because the pool sits behind an `Arc`.
#[derive(Clone)]
pub struct RedisBroker {
    /// Shared pool. It is filled by [`Broker::connect`], or up front by
    /// [`RedisBroker::from_pool`].
    pool: Arc<OnceCell<Pool>>,
    /// Where the Redis deployment lives.
    topology: Topology,
    /// Connection count for a pool created by `connect`.
    pool_size: usize,
    /// Consumer group used for bare-string subscriptions.
    default_group: Option<String>,
    /// Credentials and TLS settings overlaid on the generated config.
    auth: AuthConfig,
}
impl std::fmt::Debug for RedisBroker {
    /// Formats the broker's configuration.
    ///
    /// The pool handle is left out and rendered as `..`. Auth settings go
    /// through `AuthConfig`'s `Debug`, which redacts secrets.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("RedisBroker");
        builder.field("topology", &self.topology);
        builder.field("pool_size", &self.pool_size);
        builder.field("default_group", &self.default_group);
        builder.field("auth", &self.auth);
        builder.finish_non_exhaustive()
    }
}
impl RedisBroker {
    /// Creates a broker for a single Redis server identified by a connection URL.
    ///
    /// The URL is parsed with fred's `Config::from_url` when the config is built.
    /// No connection is made until [`Broker::connect`] is called.
    #[must_use]
    pub fn standalone(url: impl Into<String>) -> Self {
        Self::with_topology(Topology::Standalone(url.into()))
    }

    /// Creates a broker for a Redis Cluster reachable through the given nodes.
    ///
    /// Each node is `host[:port]`, optionally prefixed with `redis://` or
    /// `rediss://`. The port defaults to 6379. Addresses are validated when the
    /// config is built, not here.
    #[must_use]
    pub fn cluster(nodes: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self::with_topology(Topology::Cluster(
            nodes.into_iter().map(Into::into).collect(),
        ))
    }

    /// Creates a broker for a sentinel-managed deployment.
    ///
    /// `service` is the monitored service name. `sentinels` are the sentinel
    /// addresses, in the same format as [`RedisBroker::cluster`]; their port
    /// defaults to 26379.
    #[must_use]
    pub fn sentinel(
        service: impl Into<String>,
        sentinels: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        Self::with_topology(Topology::Sentinel {
            service: service.into(),
            hosts: sentinels.into_iter().map(Into::into).collect(),
        })
    }

    /// Shared constructor: an empty pool cell plus default settings.
    fn with_topology(topology: Topology) -> Self {
        Self {
            pool: Arc::new(OnceCell::new()),
            topology,
            pool_size: DEFAULT_POOL_SIZE,
            default_group: None,
            auth: AuthConfig::default(),
        }
    }

    /// Sets how many connections the pool opens on [`Broker::connect`].
    /// The default is 4.
    #[must_use]
    pub const fn pool(mut self, size: usize) -> Self {
        self.pool_size = size;
        self
    }

    /// Sets the consumer group used by bare-string [`Subscribe::subscribe`] calls.
    #[must_use]
    pub fn default_group(mut self, group: impl Into<String>) -> Self {
        self.default_group = Some(group.into());
        self
    }

    /// Sets the ACL username and password for data connections.
    ///
    /// These values override any credentials embedded in a standalone URL.
    #[must_use]
    pub fn credentials(mut self, username: impl Into<String>, password: impl Into<String>) -> Self {
        self.auth.username = Some(username.into());
        self.auth.password = Some(password.into());
        self
    }

    /// Sets only the password for data connections, leaving any username as it was.
    #[must_use]
    pub fn password(mut self, password: impl Into<String>) -> Self {
        self.auth.password = Some(password.into());
        self
    }

    /// Sets the TLS configuration, overriding whatever the URL implied.
    #[cfg(any(
        feature = "tls-rustls",
        feature = "tls-rustls-ring",
        feature = "tls-native-tls"
    ))]
    #[must_use]
    pub fn tls(mut self, tls: impl Into<TlsConfig>) -> Self {
        self.auth.tls = Some(tls.into());
        self
    }

    /// Sets the credentials used to talk to sentinel processes.
    ///
    /// They are only applied when the topology is sentinel.
    #[cfg(feature = "sentinel-auth")]
    #[must_use]
    pub fn sentinel_credentials(
        mut self,
        username: impl Into<String>,
        password: impl Into<String>,
    ) -> Self {
        self.auth.sentinel_username = Some(username.into());
        self.auth.sentinel_password = Some(password.into());
        self
    }

    /// Sets only the sentinel password, leaving any sentinel username as it was.
    #[cfg(feature = "sentinel-auth")]
    #[must_use]
    pub fn sentinel_password(mut self, password: impl Into<String>) -> Self {
        self.auth.sentinel_password = Some(password.into());
        self
    }

    /// Installs a dynamic credential provider on the generated config.
    #[cfg(feature = "credential-provider")]
    #[must_use]
    pub fn credential_provider(mut self, provider: Arc<dyn CredentialProvider>) -> Self {
        self.auth.credential_provider = Some(provider);
        self
    }

    /// Convenience constructor: builds a [`RedisBroker::standalone`] broker and
    /// connects it immediately.
    ///
    /// # Errors
    ///
    /// Returns the same errors as [`Broker::connect`].
    pub async fn connect(url: impl Into<String>) -> Result<Self, RedisError> {
        let broker = Self::standalone(url);
        Broker::connect(&broker).await?;
        Ok(broker)
    }

    /// Wraps a pool that was created (and presumably initialised) elsewhere.
    ///
    /// The broker reports the pool's real size. Dedicated pub/sub connections
    /// reuse the pool's client configuration. Auth builders called afterwards
    /// only affect those dedicated connections, never the adopted pool.
    #[must_use]
    pub fn from_pool(pool: Pool) -> Self {
        let pool_size = pool.size();
        Self {
            pool: Arc::new(OnceCell::new_with(Some(pool))),
            topology: Topology::Preconnected,
            pool_size,
            default_group: None,
            auth: AuthConfig::default(),
        }
    }

    /// Builds the fred config for this broker's topology and overlays the
    /// configured auth settings.
    ///
    /// # Errors
    ///
    /// - [`RedisError::Connect`] if the URL or node addresses are invalid.
    /// - [`RedisError::NotConnected`] if a pre-connected broker has no pool.
    fn build_config(&self) -> Result<Config, RedisError> {
        let mut config = match &self.topology {
            Topology::Standalone(url) => {
                Config::from_url(url).map_err(|err| RedisError::Connect(Box::new(err)))?
            }
            Topology::Cluster(nodes) => {
                let hosts = parse_servers(nodes, 6379)?;
                Config {
                    server: ServerConfig::new_clustered(hosts),
                    ..Config::default()
                }
            }
            Topology::Sentinel { service, hosts } => {
                let hosts = parse_servers(hosts, 26379)?;
                Config {
                    server: ServerConfig::new_sentinel(hosts, service.clone()),
                    ..Config::default()
                }
            }
            // An adopted pool has no address of its own. Reuse the config its
            // clients were built with, so dedicated (pub/sub) connections reach
            // the same servers instead of failing with `NotConnected`.
            Topology::Preconnected => self.connected()?.next().client_config(),
        };
        self.apply_auth(&mut config);
        Ok(config)
    }

    /// Overlays the explicitly configured auth and TLS settings onto `config`.
    ///
    /// Unset (`None`) settings leave the URL- or default-derived values in
    /// place. Sentinel credentials only touch sentinel server configs.
    fn apply_auth(&self, config: &mut Config) {
        if self.auth.username.is_some() {
            config.username.clone_from(&self.auth.username);
        }
        if self.auth.password.is_some() {
            config.password.clone_from(&self.auth.password);
        }
        #[cfg(any(
            feature = "tls-rustls",
            feature = "tls-rustls-ring",
            feature = "tls-native-tls"
        ))]
        if self.auth.tls.is_some() {
            config.tls.clone_from(&self.auth.tls);
        }
        #[cfg(feature = "credential-provider")]
        if self.auth.credential_provider.is_some() {
            config
                .credential_provider
                .clone_from(&self.auth.credential_provider);
        }
        #[cfg(feature = "sentinel-auth")]
        if let ServerConfig::Sentinel {
            username, password, ..
        } = &mut config.server
        {
            if self.auth.sentinel_username.is_some() {
                username.clone_from(&self.auth.sentinel_username);
            }
            if self.auth.sentinel_password.is_some() {
                password.clone_from(&self.auth.sentinel_password);
            }
        }
    }

    /// Returns a handle to the pool, or [`RedisError::NotConnected`] if it does
    /// not exist yet.
    fn connected(&self) -> Result<Pool, RedisError> {
        self.pool.get().cloned().ok_or(RedisError::NotConnected)
    }

    /// Returns a handle to the underlying fred pool.
    ///
    /// # Panics
    ///
    /// Panics if the pool does not exist yet, i.e. before [`Broker::connect`]
    /// on a broker not built with [`RedisBroker::from_pool`].
    #[must_use]
    pub fn pool_handle(&self) -> Pool {
        self.pool
            .get()
            .cloned()
            .expect("RedisBroker::pool_handle() called before connect()")
    }

    /// Subscribes to a Redis stream through a consumer group.
    ///
    /// The group is created first (stream included) if it does not already exist.
    ///
    /// # Errors
    ///
    /// - [`RedisError::NotConnected`] before [`Broker::connect`].
    /// - An error from `def.group_or_err()`, presumably when no group was set.
    /// - Any `XGROUP CREATE` failure other than "group already exists".
    pub async fn subscribe(&self, def: RedisStream) -> Result<RedisSubscriber, RedisError> {
        let pool = self.connected()?;
        let group = def.group_or_err()?.to_owned();
        let consumer = def.consumer_or_auto();
        ensure_group(&pool, def.key(), &group, def.start().as_id()).await?;
        Ok(RedisSubscriber::new(
            pool,
            def.key().to_owned(),
            group,
            consumer,
            def.count_or_default(),
            def.block_or_default(),
            def.mode(),
            def.poison_policy(),
            def.delay_config(),
        ))
    }

    /// Creates a stream publisher that shares this broker's pool.
    ///
    /// It can be created before [`Broker::connect`]. Publishing then fails with
    /// [`RedisError::NotConnected`] until the pool exists.
    #[must_use]
    pub fn publisher(&self) -> RedisPublisher {
        RedisPublisher::new(Arc::clone(&self.pool), self.supports_transactions())
    }

    /// Whether the publisher may use transactions. This is false for cluster
    /// topologies, presumably because a batch could span hash slots.
    const fn supports_transactions(&self) -> bool {
        !matches!(self.topology, Topology::Cluster(_))
    }

    /// Opens and initialises a dedicated (non-pooled) client with this broker's
    /// configuration.
    async fn new_client(&self) -> Result<Client, RedisError> {
        let config = self.build_config()?;
        let client = Client::new(config, None, None, None);
        client
            .init()
            .await
            .map_err(|err| RedisError::Connect(Box::new(err)))?;
        Ok(client)
    }

    /// Subscribes to a pub/sub channel or pattern on a dedicated connection.
    ///
    /// The channel is subscribed as classic, pattern (`PSUBSCRIBE`) or sharded
    /// (`SSUBSCRIBE`), depending on `def`.
    ///
    /// # Errors
    ///
    /// - Validation errors from `def`.
    /// - [`RedisError::Connect`] if the dedicated client cannot be set up.
    /// - A subscribe error if the server rejects the subscription. In that case
    ///   the dedicated client is shut down before returning.
    pub async fn subscribe_pubsub(
        &self,
        def: RedisPubSub,
    ) -> Result<RedisPubSubSubscriber, RedisError> {
        def.validate()?;
        let codec = def.codec_handle();
        let client = self.new_client().await?;
        // Take the receiver *before* subscribing. A broadcast receiver only sees
        // messages sent after it was created, so grabbing it once the subscribe
        // call has returned could drop messages delivered in between.
        let rx = client.message_rx();
        let channel = def.channel().to_owned();
        let result = match (def.delivery_mode(), def.is_pattern()) {
            (PubSubMode::Classic, true) => client.psubscribe(channel).await,
            (PubSubMode::Classic, false) => client.subscribe(channel).await,
            (PubSubMode::Sharded, _) => client.ssubscribe(channel).await,
        };
        if let Err(err) = result {
            // The client never reaches a subscriber that could close it, so shut
            // it down here instead of leaking the connection. This is best-effort:
            // the subscribe error is the one worth reporting.
            let _ = client.quit().await;
            return Err(RedisError::subscribe(err));
        }
        Ok(RedisPubSubSubscriber::new(client, rx, codec))
    }

    /// Subscribes to a Redis list consumer described by `def`, using the shared pool.
    ///
    /// # Errors
    ///
    /// - [`RedisError::NotConnected`] before [`Broker::connect`].
    /// - Errors from `def.recovery_config()`.
    #[allow(
        clippy::unused_async,
        reason = "async for parity with the other subscribe methods and the SubscriptionSource shape"
    )]
    pub async fn subscribe_list(&self, def: RedisList) -> Result<RedisListSubscriber, RedisError> {
        let pool = self.connected()?;
        let recovery = def.recovery_config()?;
        Ok(RedisListSubscriber::new(
            pool,
            def.key().to_owned(),
            def.is_reliable(),
            def.processing_or_default(),
            def.block_or_default(),
            def.codec_handle(),
            def.poison_policy(),
            recovery,
        ))
    }

    /// Creates a classic (non-sharded) pub/sub publisher sharing this broker's pool.
    #[must_use]
    pub fn pubsub_publisher(&self) -> RedisPubSubPublisher {
        RedisPubSubPublisher::new(Arc::clone(&self.pool), PubSubMode::Classic)
    }

    /// Creates a list publisher sharing this broker's pool.
    #[must_use]
    pub fn list_publisher(&self) -> RedisListPublisher {
        RedisListPublisher::new(Arc::clone(&self.pool))
    }

    /// Quits the pool if it exists.
    ///
    /// This is best-effort: errors from `QUIT` are ignored.
    pub async fn shutdown_pool(&self) {
        if let Some(pool) = self.pool.get() {
            let _ = pool.quit().await;
        }
    }
}
async fn ensure_group(
pool: &Pool,
key: &str,
group: &str,
start_id: &str,
) -> Result<(), RedisError> {
let result: Result<String, fred::error::Error> =
pool.xgroup_create(key, group, start_id, true).await;
match result {
Ok(_) => Ok(()),
Err(err) if err.details().contains("BUSYGROUP") => Ok(()),
Err(err) => Err(RedisError::subscribe(err)),
}
}
impl Broker for RedisBroker {
type Error = RedisError;
async fn connect(&self) -> Result<(), Self::Error> {
self.pool
.get_or_try_init(|| async {
let config = self.build_config()?;
let pool = Pool::new(config, None, None, None, self.pool_size)
.map_err(|err| RedisError::Connect(Box::new(err)))?;
pool.init()
.await
.map_err(|err| RedisError::Connect(Box::new(err)))?;
Ok(pool)
})
.await?;
Ok(())
}
async fn shutdown(&self) -> Result<(), Self::Error> {
self.shutdown_pool().await;
Ok(())
}
}
#[allow(
    clippy::use_self,
    reason = "`RedisBroker::subscribe` spells out the inherent stream method so the call visibly does not recurse into this trait method"
)]
impl Subscribe for RedisBroker {
    type Subscriber = RedisSubscriber;

    /// Subscribes to stream `name` using the broker-wide default consumer group.
    ///
    /// # Errors
    ///
    /// Returns [`RedisError::InvalidOptions`] when no default group was set via
    /// [`RedisBroker::default_group`]. Otherwise it fails exactly like the
    /// inherent [`RedisBroker::subscribe`].
    async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
        let Some(group) = self.default_group.clone() else {
            return Err(RedisError::InvalidOptions(format!(
                "bare-string subscription on `{name}` needs a broker-wide default group: \
                 call RedisBroker::default_group(name), or subscribe with \
                 RedisStream::new(name).group(group)"
            )));
        };
        let stream = RedisStream::new(name).group(group);
        RedisBroker::subscribe(self, stream).await
    }
}
impl DescribeServer for RedisBroker {
    /// Describes this broker as a `redis` server.
    ///
    /// The reported host is reduced to its `host[:port]` part. The following are
    /// removed:
    /// - a leading `redis://` or `rediss://` scheme;
    /// - any `user:password@` credentials;
    /// - everything from the first `/`, `?` or `#` onward (database index, query
    ///   options).
    ///
    /// This keeps secrets embedded in a connection URL out of the description.
    /// Cluster and sentinel brokers report their first configured node. A broker
    /// built with [`RedisBroker::from_pool`] reports an empty host.
    fn describe_server(&self) -> ServerSpec {
        /// Strips scheme, credentials and path/query from a Redis address.
        fn public_host(addr: &str) -> String {
            let addr = addr.trim();
            let rest = addr
                .strip_prefix("rediss://")
                .or_else(|| addr.strip_prefix("redis://"))
                .unwrap_or(addr);
            // Per URL syntax, the authority ends at the first `/`, `?` or `#`.
            let authority = rest
                .find(|c: char| matches!(c, '/' | '?' | '#'))
                .map_or(rest, |end| &rest[..end]);
            // Userinfo (`user:password@`) precedes the last `@` of the authority.
            authority
                .rsplit_once('@')
                .map_or(authority, |(_, host)| host)
                .to_owned()
        }

        let addr = match &self.topology {
            Topology::Standalone(url) => url.as_str(),
            Topology::Cluster(nodes) => nodes.first().map_or("", String::as_str),
            Topology::Sentinel { hosts, .. } => hosts.first().map_or("", String::as_str),
            Topology::Preconnected => "",
        };
        ServerSpec::new(public_host(addr), "redis")
    }
}
// Unit tests for `RedisBroker`. None of them talk to a live Redis server: they
// exercise builder/config logic, Debug output, and the not-yet-connected paths.
#[cfg(test)]
mod tests {
    use ruststream::{OutgoingMessage, Publisher};

    use super::*;

    // Constructing a broker must not open connections: publishing and stream
    // subscription both fail with `NotConnected` until `connect` is called.
    #[tokio::test]
    async fn standalone_does_not_connect() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");
        let publish_err = broker
            .publisher()
            .publish(OutgoingMessage::new("orders", b"{}".as_slice()))
            .await
            .unwrap_err();
        assert!(matches!(publish_err, RedisError::NotConnected));
        let subscribe_err = broker
            .subscribe(RedisStream::new("orders").group("g"))
            .await
            .unwrap_err();
        assert!(matches!(subscribe_err, RedisError::NotConnected));
    }

    // The trait-level `Subscribe::subscribe(name)` has no group to use unless a
    // broker-wide default group was configured; the error must say so.
    #[tokio::test]
    async fn bare_string_subscription_needs_default_group() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");
        let err = Subscribe::subscribe(&broker, "orders").await.unwrap_err();
        assert!(matches!(err, RedisError::InvalidOptions(msg) if msg.contains("default group")));
    }

    // The server description uses the `redis` protocol and drops the URL scheme.
    #[test]
    fn describe_server_reports_redis() {
        let broker = RedisBroker::standalone("redis://localhost:6379");
        let spec = broker.describe_server();
        assert_eq!(spec.protocol, "redis");
        assert_eq!(spec.host, "localhost:6379");
    }

    // `credentials()` must reach the generated config for every topology kind.
    #[test]
    fn credentials_apply_to_all_topologies() {
        let brokers = [
            RedisBroker::standalone("redis://localhost:6379").credentials("alice", "s3cr3t"),
            RedisBroker::cluster(["127.0.0.1:7000"]).credentials("alice", "s3cr3t"),
            RedisBroker::sentinel("mymaster", ["127.0.0.1:26379"]).credentials("alice", "s3cr3t"),
        ];
        for broker in brokers {
            let config = broker.build_config().expect("config builds");
            assert_eq!(config.username.as_deref(), Some("alice"));
            assert_eq!(config.password.as_deref(), Some("s3cr3t"));
        }
    }

    // `password()` alone sets the password and leaves the username unset.
    #[test]
    fn password_only_sets_password_without_username() {
        let config = RedisBroker::cluster(["127.0.0.1:7000"])
            .password("requirepass")
            .build_config()
            .expect("config builds");
        assert_eq!(config.username, None);
        assert_eq!(config.password.as_deref(), Some("requirepass"));
    }

    // Builder-supplied credentials win over credentials embedded in the URL.
    #[test]
    fn programmatic_credentials_override_standalone_url() {
        let config = RedisBroker::standalone("redis://urluser:urlpass@localhost:6379")
            .credentials("acluser", "aclpass")
            .build_config()
            .expect("config builds");
        assert_eq!(config.username.as_deref(), Some("acluser"));
        assert_eq!(config.password.as_deref(), Some("aclpass"));
    }

    // Without builder credentials, the URL's credentials are kept as parsed.
    #[test]
    fn url_credentials_preserved_without_override() {
        let config = RedisBroker::standalone("redis://urluser:urlpass@localhost:6379")
            .build_config()
            .expect("config builds");
        assert_eq!(config.username.as_deref(), Some("urluser"));
        assert_eq!(config.password.as_deref(), Some("urlpass"));
    }

    // The broker's Debug output shows the username but never the password.
    #[test]
    fn debug_redacts_password() {
        let broker =
            RedisBroker::standalone("redis://localhost:6379").credentials("alice", "s3cr3t");
        let rendered = format!("{broker:?}");
        assert!(
            !rendered.contains("s3cr3t"),
            "password must not appear in Debug output: {rendered}"
        );
        assert!(
            rendered.contains("alice"),
            "expected username in: {rendered}"
        );
    }

    // Sentinel credentials land on the sentinel server config, separately from
    // the data-connection credentials.
    #[cfg(feature = "sentinel-auth")]
    #[test]
    fn sentinel_credentials_apply_to_sentinel_server() {
        let config = RedisBroker::sentinel("mymaster", ["127.0.0.1:26379"])
            .credentials("datauser", "datapass")
            .sentinel_credentials("sentineluser", "sentinelpass")
            .build_config()
            .expect("config builds");
        assert_eq!(config.username.as_deref(), Some("datauser"));
        let ServerConfig::Sentinel {
            username, password, ..
        } = &config.server
        else {
            panic!("expected a sentinel server config");
        };
        assert_eq!(username.as_deref(), Some("sentineluser"));
        assert_eq!(password.as_deref(), Some("sentinelpass"));
    }

    // Test-only provider that always returns the same fixed credentials.
    #[cfg(feature = "credential-provider")]
    #[derive(Debug)]
    struct StaticCredentials;

    #[cfg(feature = "credential-provider")]
    #[async_trait::async_trait]
    impl CredentialProvider for StaticCredentials {
        async fn fetch(
            &self,
            _server: Option<&fred::types::config::Server>,
        ) -> Result<(Option<String>, Option<String>), fred::error::Error> {
            Ok((Some("rotating".into()), Some("token".into())))
        }
    }

    // A configured credential provider is copied into the generated config.
    #[cfg(feature = "credential-provider")]
    #[test]
    fn credential_provider_is_applied() {
        let provider: Arc<dyn CredentialProvider> = Arc::new(StaticCredentials);
        let config = RedisBroker::cluster(["127.0.0.1:7000"])
            .credential_provider(provider)
            .build_config()
            .expect("config builds");
        assert!(config.credential_provider.is_some());
    }
}