use std::sync::Arc;
use fred::clients::{Client, Pool};
use fred::interfaces::{ClientLike, EventInterface, PubsubInterface, StreamsInterface};
use fred::types::config::{Config, ServerConfig};
use ruststream::{Broker, DescribeServer, ServerSpec, Subscribe};
use tokio::sync::OnceCell;
use crate::{
error::RedisError,
list::{RedisList, RedisListPublisher, RedisListSubscriber},
publisher::RedisPublisher,
pubsub::{PubSubMode, RedisPubSub, RedisPubSubPublisher, RedisPubSubSubscriber},
stream::RedisStream,
subscriber::RedisSubscriber,
};
/// Pool size a new [`RedisBroker`] starts with; [`RedisBroker::pool`] overrides it.
const DEFAULT_POOL_SIZE: usize = 4;
/// How the broker reaches Redis; decides how `RedisBroker::build_config` builds
/// the fred `Config`.
#[derive(Debug, Clone)]
enum Topology {
/// A single URL, handed to `Config::from_url` unchanged.
Standalone(String),
/// Cluster node addresses; each is parsed by `parse_server` with default port 6379.
Cluster(Vec<String>),
/// Sentinel service name plus sentinel addresses (default port 26379).
Sentinel { service: String, hosts: Vec<String> },
/// The broker was built around an existing pool via `RedisBroker::from_pool`,
/// so no address information is stored.
Preconnected,
}
/// Splits one `host[:port]` address into its host and port.
///
/// Surrounding whitespace and a leading `redis://` / `rediss://` scheme are
/// removed first. Only the host and port are returned, so a `rediss://`
/// prefix does not by itself select TLS here.
///
/// A bracketed IPv6 literal without a port (for example `[::1]`) is accepted
/// and gets `default_port`; the brackets stay part of the returned host, just
/// as they do for `[::1]:6379`.
///
/// # Errors
///
/// Returns [`RedisError::Connect`] when the text after the last `:` is not a
/// valid `u16`, or when the host part is empty.
fn parse_server(addr: &str, default_port: u16) -> Result<(String, u16), RedisError> {
    let trimmed = addr
        .trim()
        .trim_start_matches("rediss://")
        .trim_start_matches("redis://");
    // In `[::1]` every colon belongs to the IPv6 literal, so splitting on the
    // last one would misread `1]` as a port.
    let bracketed_without_port = trimmed.starts_with('[') && trimmed.ends_with(']');
    let (host, port) = match trimmed.rsplit_once(':') {
        Some((host, port)) if !bracketed_without_port => {
            let port = port.parse::<u16>().map_err(|_| {
                RedisError::Connect(format!("invalid port in redis address `{addr}`").into())
            })?;
            (host, port)
        }
        _ => (trimmed, default_port),
    };
    if host.is_empty() {
        return Err(RedisError::Connect(
            format!("missing host in redis address `{addr}`").into(),
        ));
    }
    Ok((host.to_owned(), port))
}
/// Parses every address in `addrs` with [`parse_server`], stopping at the
/// first one that fails.
///
/// # Errors
///
/// Returns [`RedisError::Connect`] when `addrs` is empty, or the error of the
/// first address that does not parse.
fn parse_servers(addrs: &[String], default_port: u16) -> Result<Vec<(String, u16)>, RedisError> {
    if addrs.is_empty() {
        return Err(RedisError::Connect("no redis addresses provided".into()));
    }
    let mut servers = Vec::with_capacity(addrs.len());
    for addr in addrs {
        servers.push(parse_server(addr, default_port)?);
    }
    Ok(servers)
}
/// Redis broker handle built around a lazily initialised fred connection pool.
///
/// The pool lives in an `Arc<OnceCell<_>>`, so clones of the broker and the
/// publishers it hands out all share the same cell.
#[derive(Clone)]
pub struct RedisBroker {
// Empty until `Broker::connect` succeeds; pre-filled by `from_pool`.
pool: Arc<OnceCell<Pool>>,
// Where to connect; consulted by `build_config`.
topology: Topology,
// Size handed to `Pool::new` when the pool is created.
pool_size: usize,
// Consumer group used for bare-string subscriptions through the `Subscribe` trait.
default_group: Option<String>,
}
/// Debug output lists the configuration fields and leaves out the pool cell.
impl std::fmt::Debug for RedisBroker {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            pool: _,
            topology,
            pool_size,
            default_group,
        } = self;
        let mut out = f.debug_struct("RedisBroker");
        out.field("topology", topology);
        out.field("pool_size", pool_size);
        out.field("default_group", default_group);
        // The pool is intentionally skipped, hence "non exhaustive".
        out.finish_non_exhaustive()
    }
}
impl RedisBroker {
/// Creates an unconnected broker for a single Redis server identified by `url`.
///
/// The URL is stored as given; it is only parsed when a connection is made.
#[must_use]
pub fn standalone(url: impl Into<String>) -> Self {
    let url: String = url.into();
    Self::with_topology(Topology::Standalone(url))
}
/// Creates an unconnected broker for a Redis cluster reachable through `nodes`.
///
/// Addresses are stored as given and validated when a connection is made.
#[must_use]
pub fn cluster(nodes: impl IntoIterator<Item = impl Into<String>>) -> Self {
    let nodes: Vec<String> = nodes.into_iter().map(Into::into).collect();
    Self::with_topology(Topology::Cluster(nodes))
}
/// Creates an unconnected broker that locates `service` through the given
/// sentinel addresses.
///
/// Addresses are stored as given and validated when a connection is made.
#[must_use]
pub fn sentinel(
    service: impl Into<String>,
    sentinels: impl IntoIterator<Item = impl Into<String>>,
) -> Self {
    let service: String = service.into();
    let hosts: Vec<String> = sentinels.into_iter().map(Into::into).collect();
    Self::with_topology(Topology::Sentinel { service, hosts })
}
/// Shared constructor: empty pool cell, default pool size, no default group.
fn with_topology(topology: Topology) -> Self {
    let pool = Arc::new(OnceCell::new());
    Self {
        pool,
        topology,
        pool_size: DEFAULT_POOL_SIZE,
        default_group: None,
    }
}
/// Sets the size passed to `Pool::new` when `Broker::connect` creates the pool.
///
/// The value is stored as given and only read when the pool is created.
#[must_use]
pub const fn pool(mut self, size: usize) -> Self {
self.pool_size = size;
self
}
#[must_use]
pub fn default_group(mut self, group: impl Into<String>) -> Self {
self.default_group = Some(group.into());
self
}
pub async fn connect(url: impl Into<String>) -> Result<Self, RedisError> {
let broker = Self::standalone(url);
Broker::connect(&broker).await?;
Ok(broker)
}
/// Wraps an existing fred pool; the returned broker's pool cell is already filled.
///
/// The broker is tagged [`Topology::Preconnected`] because no address
/// information is available for it.
#[must_use]
pub fn from_pool(pool: Pool) -> Self {
    let cell = OnceCell::new_with(Some(pool));
    Self {
        pool: Arc::new(cell),
        topology: Topology::Preconnected,
        pool_size: DEFAULT_POOL_SIZE,
        default_group: None,
    }
}
/// Produces the fred `Config` for this broker's topology.
///
/// * `Standalone`: the URL is parsed by `Config::from_url`.
/// * `Cluster` / `Sentinel`: the addresses are parsed by `parse_servers`
///   (default ports 6379 and 26379); every other setting is `Config::default()`.
/// * `Preconnected`: the configuration of the wrapped pool is reused, so a
///   broker built with `from_pool` can still open the dedicated connection
///   that pub/sub subscriptions need.
///
/// # Errors
///
/// Returns [`RedisError::Connect`] when the URL or an address is invalid, and
/// [`RedisError::NotConnected`] for a `Preconnected` broker whose pool cell is
/// empty.
fn build_config(&self) -> Result<Config, RedisError> {
    match &self.topology {
        Topology::Standalone(url) => {
            Config::from_url(url).map_err(|err| RedisError::Connect(Box::new(err)))
        }
        Topology::Cluster(nodes) => {
            let hosts = parse_servers(nodes, 6379)?;
            Ok(Config {
                server: ServerConfig::new_clustered(hosts),
                ..Config::default()
            })
        }
        Topology::Sentinel { service, hosts } => {
            let hosts = parse_servers(hosts, 26379)?;
            Ok(Config {
                server: ServerConfig::new_sentinel(hosts, service.clone()),
                ..Config::default()
            })
        }
        // No address was ever given to us; the pool itself is the only
        // source of connection settings.
        Topology::Preconnected => self
            .pool
            .get()
            .map(|pool| pool.client_config())
            .ok_or(RedisError::NotConnected),
    }
}
/// Returns a clone of the pool, or [`RedisError::NotConnected`] while the
/// pool cell is still empty.
fn connected(&self) -> Result<Pool, RedisError> {
    match self.pool.get() {
        Some(pool) => Ok(pool.clone()),
        None => Err(RedisError::NotConnected),
    }
}
/// Returns a clone of the underlying fred pool.
///
/// # Panics
///
/// Panics when the pool cell is still empty, i.e. the broker has not been
/// connected yet.
#[must_use]
pub fn pool_handle(&self) -> Pool {
    match self.pool.get() {
        Some(pool) => pool.clone(),
        None => panic!("RedisBroker::pool_handle() called before connect()"),
    }
}
/// Creates a stream subscriber for `def`, making sure its consumer group exists.
///
/// # Errors
///
/// Returns [`RedisError::NotConnected`] before the broker is connected, the
/// error from `def.group_or_err()` when no group is configured, or the error
/// from creating the consumer group.
pub async fn subscribe(&self, def: RedisStream) -> Result<RedisSubscriber, RedisError> {
    let pool = self.connected()?;
    let group = def.group_or_err()?.to_owned();
    let consumer = def.consumer_or_auto();

    // The group must exist before the subscriber starts reading from it.
    ensure_group(&pool, def.key(), &group, def.start().as_id()).await?;

    let key = def.key().to_owned();
    let count = def.count_or_default();
    let block = def.block_or_default();
    let mode = def.mode();
    Ok(RedisSubscriber::new(
        pool, key, group, consumer, count, block, mode,
    ))
}
/// Returns a stream publisher that shares this broker's pool cell.
///
/// The publisher is also told whether the topology supports transactions.
#[must_use]
pub fn publisher(&self) -> RedisPublisher {
    let pool = Arc::clone(&self.pool);
    let transactional = self.supports_transactions();
    RedisPublisher::new(pool, transactional)
}
/// `false` only for the cluster topology; every other topology reports `true`.
const fn supports_transactions(&self) -> bool {
    match self.topology {
        Topology::Cluster(_) => false,
        Topology::Standalone(_) | Topology::Sentinel { .. } | Topology::Preconnected => true,
    }
}
/// Creates and initialises a dedicated client outside the pool, using the
/// configuration from `build_config`.
///
/// # Errors
///
/// Returns the `build_config` error, or [`RedisError::Connect`] when the
/// client fails to initialise.
async fn new_client(&self) -> Result<Client, RedisError> {
    let client = Client::new(self.build_config()?, None, None, None);
    let started = client.init().await;
    match started {
        Ok(_) => Ok(client),
        Err(err) => Err(RedisError::Connect(Box::new(err))),
    }
}
/// Opens a dedicated connection and subscribes it to the channel described by `def`.
///
/// Classic mode uses `PSUBSCRIBE` for patterns and `SUBSCRIBE` otherwise;
/// sharded mode uses `SSUBSCRIBE`.
///
/// # Errors
///
/// Returns the error from `def.validate()`, the error from creating the
/// dedicated client, or the subscribe error reported by Redis.
pub async fn subscribe_pubsub(
    &self,
    def: RedisPubSub,
) -> Result<RedisPubSubSubscriber, RedisError> {
    def.validate()?;
    let codec = def.codec_handle();
    let client = self.new_client().await?;
    // Take the receiver BEFORE subscribing: messages are fanned out over a
    // broadcast channel, and a receiver only sees what is sent after it was
    // created. Taking it afterwards would drop anything delivered in between.
    let rx = client.message_rx();
    let channel = def.channel().to_owned();
    let result = match (def.delivery_mode(), def.is_pattern()) {
        (PubSubMode::Classic, true) => client.psubscribe(channel).await,
        (PubSubMode::Classic, false) => client.subscribe(channel).await,
        (PubSubMode::Sharded, _) => client.ssubscribe(channel).await,
    };
    result.map_err(RedisError::subscribe)?;
    Ok(RedisPubSubSubscriber::new(client, rx, codec))
}
/// Creates a list subscriber for `def` on top of the shared pool.
///
/// # Errors
///
/// Returns [`RedisError::NotConnected`] before the broker is connected.
#[allow(
    clippy::unused_async,
    reason = "async for parity with the other subscribe methods and the SubscriptionSource shape"
)]
pub async fn subscribe_list(&self, def: RedisList) -> Result<RedisListSubscriber, RedisError> {
    let pool = self.connected()?;
    let key = def.key().to_owned();
    let reliable = def.is_reliable();
    let processing = def.processing_or_default();
    let block = def.block_or_default();
    let codec = def.codec_handle();
    Ok(RedisListSubscriber::new(
        pool, key, reliable, processing, block, codec,
    ))
}
/// Returns a classic-mode pub/sub publisher that shares this broker's pool cell.
#[must_use]
pub fn pubsub_publisher(&self) -> RedisPubSubPublisher {
    let pool = Arc::clone(&self.pool);
    RedisPubSubPublisher::new(pool, PubSubMode::Classic)
}
/// Returns a list publisher that shares this broker's pool cell.
#[must_use]
pub fn list_publisher(&self) -> RedisListPublisher {
    let pool = Arc::clone(&self.pool);
    RedisListPublisher::new(pool)
}
/// Sends `quit` to the pool if one exists; does nothing when never connected.
///
/// The result of `quit` is deliberately ignored.
pub async fn shutdown_pool(&self) {
    let Some(pool) = self.pool.get() else {
        return;
    };
    // Best effort: there is nothing useful to do with a failed quit here.
    let _ = pool.quit().await;
}
}
async fn ensure_group(
pool: &Pool,
key: &str,
group: &str,
start_id: &str,
) -> Result<(), RedisError> {
let result: Result<String, fred::error::Error> =
pool.xgroup_create(key, group, start_id, true).await;
match result {
Ok(_) => Ok(()),
Err(err) if err.details().contains("BUSYGROUP") => Ok(()),
Err(err) => Err(RedisError::subscribe(err)),
}
}
/// Connection lifecycle for the broker.
impl Broker for RedisBroker {
    type Error = RedisError;

    /// Creates and initialises the pool on first use; later calls reuse the
    /// pool already stored in the cell.
    ///
    /// Failures while building the config, creating the pool or initialising
    /// it are reported as errors and leave the cell empty.
    async fn connect(&self) -> Result<(), Self::Error> {
        self.pool
            .get_or_try_init(|| async {
                let pool = Pool::new(self.build_config()?, None, None, None, self.pool_size)
                    .map_err(|err| RedisError::Connect(Box::new(err)))?;
                let started = pool.init().await;
                if let Err(err) = started {
                    return Err(RedisError::Connect(Box::new(err)));
                }
                Ok(pool)
            })
            .await
            .map(|_| ())
    }

    /// Shuts the pool down via `shutdown_pool`; always returns `Ok`.
    async fn shutdown(&self) -> Result<(), Self::Error> {
        self.shutdown_pool().await;
        Ok(())
    }
}
/// Bare-string subscriptions: the stream name plus the broker-wide default group.
#[allow(clippy::use_self)]
impl Subscribe for RedisBroker {
    type Subscriber = RedisSubscriber;

    /// Subscribes to stream `name` using the configured default group.
    ///
    /// Fails with [`RedisError::InvalidOptions`] when no default group was set;
    /// otherwise behaves like the inherent `RedisBroker::subscribe`.
    async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
        let Some(group) = self.default_group.clone() else {
            return Err(RedisError::InvalidOptions(format!(
                "bare-string subscription on `{name}` needs a broker-wide default group: \
                 call RedisBroker::default_group(name), or subscribe with \
                 RedisStream::new(name).group(group)"
            )));
        };
        let def = RedisStream::new(name).group(group);
        RedisBroker::subscribe(self, def).await
    }
}
/// Server description used for documentation output.
impl DescribeServer for RedisBroker {
    /// Reports protocol `redis` together with a display host.
    ///
    /// * `Standalone`: the URL reduced to `host[:port]`. The scheme, any
    ///   `user:password@` prefix and any path/query/fragment are dropped so
    ///   that credentials never end up in the description.
    /// * `Cluster` / `Sentinel`: the first configured address as given, or an
    ///   empty string when none was configured.
    /// * `Preconnected`: an empty string.
    fn describe_server(&self) -> ServerSpec {
        let host = match &self.topology {
            Topology::Standalone(url) => standalone_display_host(url).to_owned(),
            Topology::Cluster(nodes) => nodes.first().cloned().unwrap_or_default(),
            Topology::Sentinel { hosts, .. } => hosts.first().cloned().unwrap_or_default(),
            Topology::Preconnected => String::new(),
        };
        ServerSpec::new(host, "redis")
    }
}

/// Reduces a `redis://` / `rediss://` URL to its `host[:port]` part.
fn standalone_display_host(url: &str) -> &str {
    let rest = url
        .trim_start_matches("rediss://")
        .trim_start_matches("redis://");
    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest.split(['/', '?', '#']).next().unwrap_or(rest);
    // Everything up to the last `@` is userinfo (user and password).
    authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host)
}
#[cfg(test)]
mod tests {
    use ruststream::{OutgoingMessage, Publisher};

    use super::*;

    /// Building a standalone broker must not connect: publishing and
    /// subscribing both report `NotConnected`.
    #[tokio::test]
    async fn standalone_does_not_connect() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");

        let message = OutgoingMessage::new("orders", b"{}".as_slice());
        let publish_err = broker.publisher().publish(message).await.unwrap_err();
        assert!(matches!(publish_err, RedisError::NotConnected));

        let stream = RedisStream::new("orders").group("g");
        let subscribe_err = broker.subscribe(stream).await.unwrap_err();
        assert!(matches!(subscribe_err, RedisError::NotConnected));
    }

    /// Subscribing by bare name without a default group is rejected with an
    /// `InvalidOptions` error that mentions the default group.
    #[tokio::test]
    async fn bare_string_subscription_needs_default_group() {
        let broker = RedisBroker::standalone("redis://127.0.0.1:6379");
        let err = Subscribe::subscribe(&broker, "orders").await.unwrap_err();
        let RedisError::InvalidOptions(msg) = err else {
            panic!("expected RedisError::InvalidOptions");
        };
        assert!(msg.contains("default group"));
    }

    /// The server description carries the `redis` protocol and the URL's host part.
    #[test]
    fn describe_server_reports_redis() {
        let spec = RedisBroker::standalone("redis://localhost:6379").describe_server();
        assert_eq!(spec.protocol, "redis");
        assert_eq!(spec.host, "localhost:6379");
    }
}