use anyhow::Context as _;
use redis::aio::ConnectionManager;
use redis::{AsyncCommands as _, Client};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::notifier::{
ActivityNotifier, EventStream, InMemoryNotifier, NotifierError, NotifyEvent,
};
/// Pub/sub channel name used when `RedisNotifier::build` is called with `channel: None`.
const DEFAULT_CHANNEL: &str = "greentic:webchat:notify";
/// Upper bound applied separately to opening the PUB connection and to the SUB probe
/// performed during `RedisNotifier::build`.
const BOOT_CONNECT_TIMEOUT: Duration = Duration::from_secs(2);
/// JSON envelope exchanged between instances over the Redis pub/sub channel.
///
/// The first three fields mirror the `NotifyEvent` being relayed; `version` and
/// `instance_id` are transport metadata consumed by `process_incoming`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Wire {
/// Copied from `NotifyEvent::tenant_id`.
pub tenant_id: String,
/// Copied from `NotifyEvent::conversation_id`.
pub conversation_id: String,
/// Copied from `NotifyEvent::new_watermark`.
pub new_watermark: u64,
/// Envelope version. Senders in this file always write `1`; receivers drop any other value.
pub version: u8,
/// Identifier of the sending notifier instance; receivers drop envelopes carrying their own id.
pub instance_id: Uuid,
}
/// Connection state of the background subscriber, shared between the notifier and its task.
///
/// The derives make the state printable in diagnostics and cheap to copy out of the
/// lock guard; they do not change how existing code constructs or matches the enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SubState {
    /// A subscription to the channel is currently established.
    Connected,
    /// The subscription is down; `attempt` counts consecutive failed reconnects and
    /// selects the backoff step.
    Reconnecting { attempt: u32 },
}
/// `ActivityNotifier` that delivers events locally through an `InMemoryNotifier` and
/// relays them to other instances over a Redis pub/sub channel.
pub struct RedisNotifier {
/// Local notifier; `publish` and `subscribe` both delegate to it, and envelopes received
/// from Redis are re-published into it.
inner: Arc<InMemoryNotifier>,
/// Random id generated in `build`, stamped on every outgoing `Wire` so this instance can
/// recognise and drop its own messages.
self_id: Uuid,
/// Name of the Redis channel used for publishing and subscribing.
channel: String,
/// Connection used for `PUBLISH`; cloned for every outgoing message.
pub_conn: ConnectionManager,
// NOTE(review): nothing in this file reads `self.sub_state` — the background task works on
// its own clone of the `Arc` — which is presumably why `dead_code` is silenced here.
#[allow(dead_code)]
sub_state: Arc<RwLock<SubState>>,
/// Handle of the spawned subscriber task. It is only stored: this file never awaits or
/// aborts it.
_sub_task: tokio::task::JoinHandle<()>,
}
#[async_trait::async_trait]
impl ActivityNotifier for RedisNotifier {
async fn publish(&self, event: NotifyEvent) {
self.inner.publish(event.clone()).await;
let payload = match serde_json::to_vec(&Wire {
tenant_id: event.tenant_id,
conversation_id: event.conversation_id,
new_watermark: event.new_watermark,
version: 1,
instance_id: self.self_id,
}) {
Ok(p) => p,
Err(err) => {
tracing::warn!(target: "notifier_redis", ?err, "redis_encode_err");
return;
}
};
let mut pub_conn = self.pub_conn.clone();
let channel = self.channel.clone();
tokio::spawn(async move {
if let Err(err) = pub_conn.publish::<_, _, ()>(&channel, payload).await {
tracing::debug!(target: "notifier_redis", ?err, "redis_publish_dropped");
}
});
}
async fn subscribe(
&self,
tenant_id: &str,
conversation_id: &str,
) -> Result<EventStream, NotifierError> {
self.inner.subscribe(tenant_id, conversation_id).await
}
}
impl RedisNotifier {
pub async fn build(
url: &str,
channel: Option<String>,
capacity: usize,
) -> anyhow::Result<Arc<Self>> {
let channel = channel.unwrap_or_else(|| DEFAULT_CHANNEL.to_string());
let inner = Arc::new(InMemoryNotifier::new(capacity));
let self_id = Uuid::new_v4();
let client = Client::open(url).with_context(|| format!("invalid redis url: {url}"))?;
let pub_conn =
tokio::time::timeout(BOOT_CONNECT_TIMEOUT, ConnectionManager::new(client.clone()))
.await
.with_context(|| format!("timed out opening redis PUB connection to {url}"))?
.with_context(|| format!("failed to open redis PUB connection to {url}"))?;
{
let probe =
tokio::time::timeout(BOOT_CONNECT_TIMEOUT, subscribe_once(&client, &channel))
.await
.with_context(|| format!("timed out opening redis SUB connection to {url}"))?
.with_context(|| format!("failed to open redis SUB connection to {url}"))?;
drop(probe);
}
let sub_state = Arc::new(RwLock::new(SubState::Connected));
let notifier = Arc::new_cyclic(|weak: &std::sync::Weak<Self>| {
let weak_clone = weak.clone();
let inner_clone = inner.clone();
let channel_clone = channel.clone();
let sub_state_clone = sub_state.clone();
let client_clone = client.clone();
let self_id_copy = self_id;
let task = tokio::spawn(async move {
loop {
let inv = std::panic::AssertUnwindSafe(background_sub_loop(
weak_clone.clone(),
inner_clone.clone(),
self_id_copy,
client_clone.clone(),
channel_clone.clone(),
sub_state_clone.clone(),
));
match futures_util::FutureExt::catch_unwind(inv).await {
Ok(()) => return, Err(_panic) => {
tracing::error!(
target: "notifier_redis",
"background loop panicked; restarting after 500ms"
);
*sub_state_clone.write().await = SubState::Reconnecting { attempt: 0 };
tokio::time::sleep(Duration::from_millis(500)).await;
if weak_clone.upgrade().is_none() {
return;
}
}
}
}
});
Self {
inner,
self_id,
channel,
pub_conn,
sub_state,
_sub_task: task,
}
});
Ok(notifier)
}
}
async fn subscribe_once(client: &Client, channel: &str) -> anyhow::Result<redis::aio::PubSub> {
let mut pubsub = client.get_async_pubsub().await?;
pubsub.subscribe(channel).await?;
Ok(pubsub)
}
/// Long-running subscriber: keeps a subscription to `channel` alive and feeds every
/// received payload to `process_incoming`.
///
/// * `notifier_weak` – liveness probe for the owning `RedisNotifier`; once it can no
///   longer be upgraded the function returns.
/// * `inner` – local notifier that receives events from other instances.
/// * `self_id` – this instance's id, used by `process_incoming` to drop self-echoes.
/// * `client` / `channel` – where to (re)subscribe.
/// * `sub_state` – shared connection state, updated on every transition.
///
/// Returns only when the owner is gone; connection loss is handled by reconnecting with
/// `backoff_with_jitter`.
async fn background_sub_loop(
    notifier_weak: std::sync::Weak<RedisNotifier>,
    inner: Arc<InMemoryNotifier>,
    self_id: Uuid,
    client: Client,
    channel: String,
    sub_state: Arc<RwLock<SubState>>,
) {
    use futures_util::StreamExt as _;
    loop {
        // Phase 1: (re)establish the subscription, backing off between failures.
        let mut sub = loop {
            if notifier_weak.upgrade().is_none() {
                return;
            }
            match subscribe_once(&client, &channel).await {
                Ok(s) => {
                    *sub_state.write().await = SubState::Connected;
                    tracing::info!(target: "notifier_redis", "redis_reconnect_ok");
                    break s;
                }
                Err(err) => {
                    // The current attempt number picks the delay; the stored state is
                    // advanced for the next failure.
                    let attempt = match *sub_state.read().await {
                        SubState::Reconnecting { attempt } => attempt,
                        SubState::Connected => 0,
                    };
                    tracing::debug!(
                        target: "notifier_redis",
                        ?err,
                        attempt,
                        "redis_reconnect_fail"
                    );
                    *sub_state.write().await = SubState::Reconnecting {
                        attempt: attempt + 1,
                    };
                    tokio::time::sleep(backoff_with_jitter(attempt)).await;
                }
            }
        };

        // Phase 2: pump messages until the stream ends (connection lost).
        while let Some(msg) = sub.on_message().next().await {
            // An unreadable payload becomes an empty buffer, which `process_incoming`
            // rejects as a decode error.
            let payload: Vec<u8> = msg.get_payload().unwrap_or_default();
            process_incoming(&payload, self_id, inner.as_ref()).await;

            // Without this check a healthy connection would keep the task (and its Redis
            // subscription) alive indefinitely after the owner was dropped, because
            // liveness was otherwise only tested on the reconnect path. An idle
            // connection is still only noticed on the next message or disconnect.
            if notifier_weak.upgrade().is_none() {
                return;
            }
        }

        *sub_state.write().await = SubState::Reconnecting { attempt: 0 };
        tracing::warn!(target: "notifier_redis", "redis_disconnected");
    }
}
/// Delay to wait before reconnect attempt number `attempt` (0-based).
///
/// The base delay steps through 100, 250, 500, 1000 and 2000 ms, then stays at 5000 ms.
/// A random whole-percent jitter in `-20..=20` is applied, and the result is never
/// below 1 ms.
fn backoff_with_jitter(attempt: u32) -> Duration {
    use rand::RngExt as _;

    /// Base delays for attempts 0 through 4; later attempts use `CAP_MS`.
    const LADDER_MS: &[u64] = &[100, 250, 500, 1_000, 2_000];
    const CAP_MS: u64 = 5_000;

    let base_ms = usize::try_from(attempt)
        .ok()
        .and_then(|step| LADDER_MS.get(step))
        .copied()
        .unwrap_or(CAP_MS);

    let percent: i32 = rand::rng().random_range(-20i32..=20i32);
    let jitter = f64::from(percent) / 100.0;
    let scaled_ms = (base_ms as f64) * (1.0 + jitter);

    Duration::from_millis(scaled_ms.max(1.0) as u64)
}
pub(crate) async fn process_incoming(payload: &[u8], self_id: Uuid, inner: &dyn ActivityNotifier) {
let wire: Wire = match serde_json::from_slice(payload) {
Ok(w) => w,
Err(err) => {
tracing::debug!(target: "notifier_redis", ?err, "redis_decode_err");
return;
}
};
if wire.instance_id == self_id {
return; }
if wire.version != 1 {
tracing::warn!(
target: "notifier_redis",
version = wire.version,
"redis_unknown_version"
);
return;
}
inner
.publish(NotifyEvent {
tenant_id: wire.tenant_id,
conversation_id: wire.conversation_id,
new_watermark: wire.new_watermark,
})
.await;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::notifier::{ActivityNotifier, EventStream, NotifierError, NotifyEvent};
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// Test double that records every event passed to `publish`.
    struct RecordingNotifier {
        published: Mutex<Vec<NotifyEvent>>,
    }

    impl RecordingNotifier {
        fn new() -> Self {
            Self {
                published: Mutex::new(Vec::new()),
            }
        }

        /// Number of events recorded so far.
        fn count(&self) -> usize {
            let recorded = self.published.lock().unwrap();
            recorded.len()
        }
    }

    #[async_trait]
    impl ActivityNotifier for RecordingNotifier {
        async fn publish(&self, event: NotifyEvent) {
            let mut recorded = self.published.lock().unwrap();
            recorded.push(event);
        }

        async fn subscribe(
            &self,
            _tenant: &str,
            _conv: &str,
        ) -> Result<EventStream, NotifierError> {
            unreachable!("not used in dispatch tests")
        }
    }

    /// Encodes a fixed event as a `Wire` payload with the given sender id and version.
    fn make_payload(instance_id: Uuid, version: u8) -> Vec<u8> {
        let envelope = Wire {
            tenant_id: "t".into(),
            conversation_id: "c".into(),
            new_watermark: 7,
            version,
            instance_id,
        };
        serde_json::to_vec(&envelope).unwrap()
    }

    /// Runs `payload` through `process_incoming` against a fresh recorder and returns how
    /// many events were forwarded.
    async fn forwarded_count(payload: &[u8], self_id: Uuid) -> usize {
        let recorder = RecordingNotifier::new();
        process_incoming(payload, self_id, &recorder).await;
        recorder.count()
    }

    #[test]
    fn wire_payload_roundtrip() {
        let original = Wire {
            tenant_id: "tenant-a".into(),
            conversation_id: "conv-1".into(),
            new_watermark: 42,
            version: 1,
            instance_id: Uuid::new_v4(),
        };
        let encoded = serde_json::to_vec(&original).expect("encode");
        let decoded: Wire = serde_json::from_slice(&encoded).expect("decode");
        assert_eq!(original, decoded);
    }

    #[tokio::test]
    async fn loop_suppression_drops_self_publish() {
        let self_id = Uuid::new_v4();
        let forwarded = forwarded_count(&make_payload(self_id, 1), self_id).await;
        assert_eq!(forwarded, 0, "self-echo must be dropped");
    }

    #[tokio::test]
    async fn loop_suppression_accepts_other_replica() {
        let self_id = Uuid::new_v4();
        let other = Uuid::new_v4();
        let forwarded = forwarded_count(&make_payload(other, 1), self_id).await;
        assert_eq!(forwarded, 1);
    }

    #[tokio::test]
    async fn dispatch_drops_unknown_version() {
        let self_id = Uuid::new_v4();
        let other = Uuid::new_v4();
        let forwarded = forwarded_count(&make_payload(other, 99), self_id).await;
        assert_eq!(forwarded, 0);
    }

    #[tokio::test]
    async fn dispatch_drops_malformed_payload() {
        let self_id = Uuid::new_v4();
        let forwarded = forwarded_count(b"not-json{{", self_id).await;
        assert_eq!(forwarded, 0);
    }
}