use super::{PubSubBackend, Subscription};
use crate::error::PubSubError;
use async_trait::async_trait;
use redis::Client;
use tokio::sync::oneshot;
use tokio_stream::StreamExt;
/// Default base reconnect delay, in milliseconds.
const BASE_DELAY_MS: u64 = 100;
/// Default upper bound for a single reconnect delay, in milliseconds.
const MAX_DELAY_MS: u64 = 5_000;
/// Default cap on the exponent applied when doubling the base delay.
const MAX_EXPONENT: u32 = 6;
/// Default modulus for the attempt-derived jitter, in milliseconds.
const JITTER_RANGE: u64 = 50;
/// Default number of consecutive failed attempts after which the subscriber gives up.
const MAX_RETRIES: u32 = 10;

/// Tuning parameters for the reconnect back-off of the Redis subscriber loop.
///
/// All fields are public so callers can build a custom configuration and pass it to
/// `RedisPubSub::with_config`; [`Default`] yields the module-level default values.
#[derive(Clone, Debug)]
pub struct BackoffConfig {
    /// Base delay in milliseconds; it is multiplied by `2^min(attempt, max_exponent)`.
    pub base_delay_ms: u64,
    /// Hard cap, in milliseconds, applied to the final computed delay.
    pub max_delay_ms: u64,
    /// Largest exponent used for the power-of-two growth of the delay.
    pub max_exponent: u32,
    /// Modulus applied to the attempt counter to derive the extra jitter milliseconds.
    pub jitter_range: u64,
    /// The subscriber stops retrying once its failure counter reaches this value.
    pub max_retries: u32,
}

impl Default for BackoffConfig {
    /// Builds the configuration from the module-level default constants.
    fn default() -> Self {
        BackoffConfig {
            base_delay_ms: BASE_DELAY_MS,
            max_delay_ms: MAX_DELAY_MS,
            max_exponent: MAX_EXPONENT,
            jitter_range: JITTER_RANGE,
            max_retries: MAX_RETRIES,
        }
    }
}
/// Pub/sub backend backed by a Redis server.
///
/// Stores the `redis::Client` used to open connections and the [`BackoffConfig`]
/// that governs reconnect delays in the subscriber loop.
#[derive(Clone)]
pub struct RedisPubSub {
    client: Client,
    backoff_config: BackoffConfig,
}

impl RedisPubSub {
    /// Creates a backend for `addr` with the default back-off configuration.
    ///
    /// # Errors
    ///
    /// Returns an error when `Client::open` rejects `addr`.
    pub fn new(addr: &str) -> Result<Self, PubSubError> {
        Self::with_config(addr, BackoffConfig::default())
    }

    /// Creates a backend for `addr` using the caller-supplied back-off `config`.
    ///
    /// # Errors
    ///
    /// Returns an error when `Client::open` rejects `addr`.
    pub fn with_config(addr: &str, config: BackoffConfig) -> Result<Self, PubSubError> {
        let client = Client::open(addr)?;
        Ok(Self {
            client,
            backoff_config: config,
        })
    }
}
#[async_trait]
impl PubSubBackend for RedisPubSub {
async fn publish_bytes(&self, topic: &str, payload: Vec<u8>) -> Result<(), PubSubError> {
let mut conn: redis::aio::MultiplexedConnection =
self.client.get_multiplexed_tokio_connection().await?;
redis::AsyncCommands::publish::<&str, Vec<u8>, ()>(&mut conn, topic, payload).await?;
Ok(())
}
async fn subscribe_bytes(
&self,
topic: &str,
handler: Box<dyn Fn(Vec<u8>) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> + Send>,
) -> Result<Subscription, PubSubError> {
let client = self.client.clone();
let topic = topic.to_string();
let cfg = self.backoff_config.clone();
let (stop_tx, mut stop_rx) = oneshot::channel::<()>();
let handle = tokio::spawn(async move {
let mut attempt: u32 = 0;
loop {
if stop_rx.try_recv().is_ok() {
tracing::debug!("Redis subscription for '{topic}' stopped gracefully");
break;
}
match client.get_multiplexed_tokio_connection().await {
Ok(mut conn) => {
let ping_result: redis::RedisResult<String> =
redis::AsyncCommands::ping(&mut conn).await;
if ping_result.is_err() {
tracing::error!("Redis health check failed: {:?}", ping_result.err());
attempt += 1;
} else {
match client.get_async_pubsub().await {
Ok(mut pubsub) => {
if let Err(e) = pubsub.subscribe(&topic).await {
tracing::error!("Redis subscribe failed for '{topic}': {e}");
attempt += 1;
} else {
attempt = 0; tracing::info!("Successfully subscribed to Redis topic '{topic}'");
let mut stream = pubsub.on_message();
let should_stop = loop {
tokio::select! {
msg_opt = stream.next() => {
match msg_opt {
Some(msg) => {
match msg.get_payload::<Vec<u8>>() {
Ok(payload) => handler(payload).await,
Err(e) => tracing::error!("Redis payload decode error on '{topic}': {e}"),
}
}
None => {
tracing::warn!("Redis pubsub stream ended for '{topic}', will reconnect…");
break false;
}
}
}
_ = &mut stop_rx => {
tracing::debug!("Gracefully stopping Redis subscription for '{topic}'");
break true;
}
}
};
drop(stream);
if should_stop {
let _ = pubsub.unsubscribe(&topic).await;
return;
}
}
}
Err(e) => {
tracing::error!("Redis pubsub connection error: {e}. Retrying…");
attempt += 1;
}
}
}
}
Err(e) => {
tracing::error!("Redis connection error: {e}. Retrying…");
attempt += 1;
}
}
if attempt >= cfg.max_retries {
tracing::error!(
"Redis pubsub: exceeded max retries ({}) for topic '{}', giving up.",
cfg.max_retries,
topic
);
break;
}
let exp = attempt.min(cfg.max_exponent);
let base_ms = cfg.base_delay_ms.saturating_mul(2u64.saturating_pow(exp));
let jitter_ms = (attempt as u64 % cfg.jitter_range) + 1;
let delay = std::time::Duration::from_millis((base_ms + jitter_ms).min(cfg.max_delay_ms));
tokio::time::sleep(delay).await;
}
});
Ok(Subscription::new(handle, stop_tx))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pubsub::PubSubExt;
    use serde::{Deserialize, Serialize};
    use std::time::Duration;

    /// Minimal serializable payload used for the round-trip test.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)]
    struct TestMsg {
        val: String,
    }

    /// Round-trips one message through a Redis server at `redis://127.0.0.1/`.
    ///
    /// Ignored by default because it needs a running Redis instance.
    #[tokio::test]
    #[ignore]
    async fn redis_pubsub_works() {
        let backend = RedisPubSub::new("redis://127.0.0.1/").unwrap();
        let (tx, mut rx) = tokio::sync::mpsc::channel(1);

        // Keep the handle bound for the whole test so the subscription is not
        // released before the message arrives.
        let _subscription = backend
            .subscribe("topic", move |data| {
                let tx = tx.clone();
                async move {
                    let msg: TestMsg = serde_json::from_slice(&data).unwrap();
                    tx.send(msg).await.ok();
                }
            })
            .await
            .unwrap();

        // The subscriber connects in a background task, so a single publish can be
        // issued before SUBSCRIBE completes and be dropped by Redis. Republish until
        // a message is observed, bounded by an overall deadline.
        let received = tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                backend
                    .publish(
                        "topic",
                        &TestMsg {
                            val: "hello".into(),
                        },
                    )
                    .await
                    .unwrap();
                if let Ok(Some(msg)) =
                    tokio::time::timeout(Duration::from_millis(100), rx.recv()).await
                {
                    break msg;
                }
            }
        })
        .await
        .expect("no message received from Redis within 5s");

        assert_eq!(
            received,
            TestMsg {
                val: "hello".into()
            }
        );
    }
}