use std::fmt::{Debug, Formatter};
use std::sync::Arc;
use bytes::Bytes;
use fred::clients::Client;
use fred::interfaces::{ClientLike, PubsubInterface};
use fred::types::Message;
use futures::Stream;
use futures::stream::unfold;
use ruststream::codec::Codec;
use ruststream::{
AckError, Headers, IncomingMessage, OutgoingMessage, Partitioned, Publisher, SubscriptionSource,
};
use tokio::sync::OnceCell;
use tokio::sync::broadcast::{Receiver, error::RecvError};
use crate::envelope::{SharedEnvelope, frame, unframe};
use crate::{RedisBroker, error::RedisError, message::PARTITION_KEY_HEADER};
/// Selects which Redis pub/sub flavour a channel uses.
///
/// * `Classic` — the default; published with `PUBLISH`.
/// * `Sharded` — published with `SPUBLISH`; pattern subscriptions are not allowed.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum PubSubMode {
    /// Classic pub/sub; chosen when no mode is set explicitly.
    #[default]
    Classic,
    /// Sharded pub/sub.
    Sharded,
}
/// Subscription source describing one Redis pub/sub channel or channel pattern.
///
/// Construct it with [`RedisPubSub::new`], then refine it with the builder methods
/// [`mode`](RedisPubSub::mode), [`pattern`](RedisPubSub::pattern) and
/// [`codec`](RedisPubSub::codec).
#[must_use]
#[derive(Clone)]
pub struct RedisPubSub {
    /// Channel name; interpreted as a pattern when `pattern` is set.
    channel: String,
    /// Classic or sharded delivery.
    mode: PubSubMode,
    /// Whether `channel` should be subscribed to as a pattern.
    pattern: bool,
    /// Optional codec handed to the subscriber for unframing payloads.
    codec: Option<SharedEnvelope>,
}
impl Debug for RedisPubSub {
    /// Formats the subscription settings; the codec is reported only as present/absent.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Self {
            channel,
            mode,
            pattern,
            codec,
        } = self;
        let mut out = f.debug_struct("RedisPubSub");
        out.field("channel", channel);
        out.field("mode", mode);
        out.field("pattern", pattern);
        out.field("codec", &codec.is_some());
        out.finish()
    }
}
impl RedisPubSub {
    /// Creates a subscription to `channel` using the default mode, no pattern matching and
    /// no codec.
    pub fn new(channel: impl Into<String>) -> Self {
        let channel: String = channel.into();
        Self {
            channel,
            codec: None,
            pattern: false,
            mode: PubSubMode::default(),
        }
    }

    /// Chooses classic or sharded delivery.
    pub const fn mode(mut self, mode: PubSubMode) -> Self {
        self.mode = mode;
        self
    }

    /// Treats the channel as a pattern. Only valid in classic mode (see `validate`).
    pub const fn pattern(mut self) -> Self {
        self.pattern = true;
        self
    }

    /// Installs a codec used to unframe received payloads.
    pub fn codec(mut self, codec: impl Codec + 'static) -> Self {
        let envelope: SharedEnvelope = Arc::new(codec);
        self.codec = Some(envelope);
        self
    }

    /// Channel name (or pattern) this source subscribes to.
    #[must_use]
    pub fn channel(&self) -> &str {
        self.channel.as_str()
    }

    /// Configured delivery mode.
    pub(crate) const fn delivery_mode(&self) -> PubSubMode {
        self.mode
    }

    /// Whether the channel is a pattern subscription.
    pub(crate) const fn is_pattern(&self) -> bool {
        self.pattern
    }

    /// Returns a clone of the (shared) codec handle, if one is configured.
    pub(crate) fn codec_handle(&self) -> Option<SharedEnvelope> {
        self.codec.clone()
    }

    /// Checks that the option combination is supported.
    ///
    /// # Errors
    ///
    /// Returns `RedisError::InvalidOptions` when a pattern subscription is combined with
    /// sharded mode.
    pub(crate) fn validate(&self) -> Result<(), RedisError> {
        match (self.pattern, self.mode) {
            (true, PubSubMode::Sharded) => Err(RedisError::InvalidOptions(
                "pattern subscriptions are classic-only; sharded pub/sub has no PSUBSCRIBE"
                    .to_owned(),
            )),
            (false, _) | (true, PubSubMode::Classic) => Ok(()),
        }
    }
}
impl SubscriptionSource<RedisBroker> for RedisPubSub {
    type Subscriber = RedisPubSubSubscriber;

    /// The source is named after its channel (or pattern).
    fn name(&self) -> &str {
        &self.channel
    }

    /// Delegates subscription setup to `RedisBroker::subscribe_pubsub`.
    async fn subscribe(self, broker: &RedisBroker) -> Result<Self::Subscriber, RedisError> {
        let subscriber = broker.subscribe_pubsub(self).await?;
        Ok(subscriber)
    }
}
/// Active pub/sub subscription, produced by subscribing a [`RedisPubSub`].
pub struct RedisPubSubSubscriber {
    /// Client owned by this subscription; `quit()` is requested when the subscriber drops.
    client: Client,
    /// Broadcast receiver that delivers the raw fred messages.
    rx: Receiver<Message>,
    /// Optional codec used to unframe each message payload.
    codec: Option<SharedEnvelope>,
}
impl Debug for RedisPubSubSubscriber {
    /// Formats the subscriber, reporting whether a codec is configured (consistent with
    /// the other pub/sub types' `Debug` output). The client and receiver are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisPubSubSubscriber")
            .field("codec", &self.codec.is_some())
            .finish_non_exhaustive()
    }
}
impl RedisPubSubSubscriber {
    /// Bundles a client, the receiver its messages arrive on, and the optional codec.
    pub(crate) fn new(
        client: Client,
        rx: Receiver<Message>,
        codec: Option<SharedEnvelope>,
    ) -> Self {
        Self {
            codec,
            rx,
            client,
        }
    }
}
impl Drop for RedisPubSubSubscriber {
    /// Best-effort shutdown: schedules `quit()` on a clone of the client.
    ///
    /// `tokio::spawn` panics when no tokio runtime is current (for example when the
    /// subscriber is dropped on a plain thread or after the runtime has shut down), and a
    /// panic inside `drop` can abort the process. So the QUIT is only scheduled when a
    /// runtime handle is available; otherwise it is skipped. The spawned task's result
    /// is ignored, and the task may never run if the runtime is shutting down.
    fn drop(&mut self) {
        let Ok(runtime) = tokio::runtime::Handle::try_current() else {
            return;
        };
        let client = self.client.clone();
        runtime.spawn(async move {
            let _ = client.quit().await;
        });
    }
}
/// Converts a raw fred pub/sub message into a [`RedisPubSubMessage`].
///
/// The value is read as bytes; a value with no byte representation becomes an empty
/// slice. The bytes are passed to `unframe` with the optional codec, which yields the
/// payload and headers.
fn to_message(msg: &Message, codec: Option<&SharedEnvelope>) -> RedisPubSubMessage {
    let channel = msg.channel.to_string();
    let framed: &[u8] = msg.value.as_bytes().unwrap_or_default();
    let (payload, headers) = unframe(codec, framed);
    RedisPubSubMessage {
        channel,
        payload,
        headers,
    }
}
impl ruststream::Subscriber for RedisPubSubSubscriber {
    type Message = RedisPubSubMessage;
    type Error = RedisError;

    /// Streams converted messages until the broadcast channel is closed.
    ///
    /// If this receiver falls behind (`RecvError::Lagged`), the lost messages are skipped
    /// without being reported and receiving simply continues. The stream never yields an
    /// `Err` item; it ends with `None` on `RecvError::Closed`.
    fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
        let envelope = self.codec.clone();
        let receiver = &mut self.rx;
        unfold((receiver, envelope), |(receiver, envelope)| async move {
            let raw = loop {
                match receiver.recv().await {
                    Ok(raw) => break raw,
                    // Dropped messages are not surfaced; keep receiving.
                    Err(RecvError::Lagged(_)) => {}
                    Err(RecvError::Closed) => return None,
                }
            };
            let decoded = to_message(&raw, envelope.as_ref());
            Some((Ok(decoded), (receiver, envelope)))
        })
    }
}
/// A message received on a pub/sub channel, already unframed.
pub struct RedisPubSubMessage {
    /// Channel the message was delivered on.
    channel: String,
    /// Unframed message body.
    payload: Bytes,
    /// Headers recovered while unframing (may include the partition key).
    headers: Headers,
}
impl Debug for RedisPubSubMessage {
    /// Formats the channel and payload length; the payload bytes and headers are omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let payload_len = self.payload.len();
        f.debug_struct("RedisPubSubMessage")
            .field("channel", &self.channel)
            .field("payload_len", &payload_len)
            .finish_non_exhaustive()
    }
}
impl RedisPubSubMessage {
    /// Name of the channel this message arrived on.
    #[must_use]
    pub fn channel(&self) -> &str {
        self.channel.as_str()
    }
}
impl IncomingMessage for RedisPubSubMessage {
    /// Unframed payload bytes.
    fn payload(&self) -> &[u8] {
        &self.payload[..]
    }

    /// Headers recovered from the envelope.
    fn headers(&self) -> &Headers {
        &self.headers
    }

    /// Acknowledgement is not supported for pub/sub; always `AckError::Unsupported`.
    async fn ack(self) -> Result<(), AckError> {
        Err(AckError::Unsupported)
    }

    /// Negative acknowledgement is not supported either; `requeue` is ignored and
    /// `AckError::Unsupported` is returned.
    async fn nack(self, _requeue: bool) -> Result<(), AckError> {
        Err(AckError::Unsupported)
    }
}
impl Partitioned for RedisPubSubMessage {
    /// Partition key taken from the `PARTITION_KEY_HEADER` header, if present.
    fn partition_key(&self) -> Option<&[u8]> {
        let headers = &self.headers;
        headers.get(PARTITION_KEY_HEADER)
    }
}
/// Publishes messages to Redis pub/sub channels with `PUBLISH` or `SPUBLISH`.
#[derive(Clone)]
pub struct RedisPubSubPublisher {
    /// Shared connection pool, set once; publishing fails with `NotConnected` while unset.
    pool: Arc<OnceCell<fred::clients::Pool>>,
    /// Classic (`PUBLISH`) or sharded (`SPUBLISH`) delivery.
    mode: PubSubMode,
    /// Optional codec used to frame outgoing payloads and headers.
    codec: Option<SharedEnvelope>,
}
impl Debug for RedisPubSubPublisher {
    /// Formats the mode and whether a codec is configured; the pool is omitted.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let has_codec = self.codec.is_some();
        let mut out = f.debug_struct("RedisPubSubPublisher");
        out.field("mode", &self.mode);
        out.field("codec", &has_codec);
        out.finish_non_exhaustive()
    }
}
impl RedisPubSubPublisher {
pub(crate) fn new(pool: Arc<OnceCell<fred::clients::Pool>>, mode: PubSubMode) -> Self {
Self {
pool,
mode,
codec: None,
}
}
#[must_use]
pub const fn mode(mut self, mode: PubSubMode) -> Self {
self.mode = mode;
self
}
#[must_use]
pub fn codec(mut self, codec: impl Codec + 'static) -> Self {
self.codec = Some(Arc::new(codec));
self
}
}
impl Publisher for RedisPubSubPublisher {
    type Error = RedisError;

    /// Frames `msg` with the configured codec and publishes it to the channel named by
    /// `msg.name()`, using `PUBLISH` (classic) or `SPUBLISH` (sharded).
    ///
    /// # Errors
    ///
    /// * `RedisError::NotConnected` if the connection pool has not been set yet.
    /// * The error built by `RedisError::publish` if the Redis command fails.
    async fn publish(&self, msg: OutgoingMessage<'_>) -> Result<(), Self::Error> {
        // Borrow the pool directly from the `OnceCell`. `&self` keeps it alive for the
        // whole call, so cloning the pool handle on every publish is unnecessary.
        let pool = self.pool.get().ok_or(RedisError::NotConnected)?;
        let client = pool.next();
        let channel = msg.name().to_owned();
        let body = frame(self.codec.as_ref(), msg.payload(), msg.headers());
        // The integer reply (number of receiving clients) is not exposed to callers.
        let _: i64 = match self.mode {
            PubSubMode::Classic => client.publish(channel, body).await,
            PubSubMode::Sharded => client.spublish(channel, body).await,
        }
        .map_err(RedisError::publish)?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sharded mode has no pattern subscriptions, so the combination is rejected.
    #[test]
    fn pattern_with_sharded_is_rejected() {
        let err = RedisPubSub::new("e.*")
            .mode(PubSubMode::Sharded)
            .pattern()
            .validate()
            .unwrap_err();
        assert!(matches!(err, RedisError::InvalidOptions(msg) if msg.contains("classic-only")));
    }

    #[test]
    fn classic_pattern_validates() {
        RedisPubSub::new("e.*").pattern().validate().expect("ok");
    }

    /// Covers the accepting branch for sharded mode: a plain channel is fine.
    #[test]
    fn sharded_without_pattern_validates() {
        RedisPubSub::new("events")
            .mode(PubSubMode::Sharded)
            .validate()
            .expect("sharded, non-pattern subscriptions are valid");
    }

    /// `new` yields a classic, literal-channel source with no codec.
    #[test]
    fn builder_defaults() {
        let source = RedisPubSub::new("events");
        assert_eq!(source.channel(), "events");
        assert_eq!(source.delivery_mode(), PubSubMode::Classic);
        assert!(!source.is_pattern());
        assert!(source.codec_handle().is_none());
        source.validate().expect("defaults are valid");
    }
}