1use anyhow::Context as _;
6use redis::aio::ConnectionManager;
7use redis::{AsyncCommands as _, Client};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14use crate::notifier::{
15 ActivityNotifier, EventStream, InMemoryNotifier, NotifierError, NotifyEvent,
16};
17
/// Redis pub/sub channel used when `RedisNotifier::build` is not given one explicitly.
const DEFAULT_CHANNEL: &str = "greentic:webchat:notify";
/// Upper bound on how long `build` waits for each boot-time Redis connection (PUB and probe SUB).
const BOOT_CONNECT_TIMEOUT: Duration = Duration::from_millis(2_000);
20
/// JSON message broadcast on the Redis channel so peer replicas can replay a
/// locally published `NotifyEvent` into their own in-memory notifier.
///
/// The field names below are the serialized keys; peers decode messages with
/// this same struct, so renaming or retyping a field changes the wire format.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Wire {
    /// Tenant id copied from the originating `NotifyEvent`.
    pub tenant_id: String,
    /// Conversation id copied from the originating `NotifyEvent`.
    pub conversation_id: String,
    /// Watermark copied unchanged from the originating `NotifyEvent`.
    pub new_watermark: u64,
    /// Payload schema version; the sender writes `1` and receivers drop anything else.
    pub version: u8,
    /// Id of the publishing `RedisNotifier`; receivers drop messages carrying their own id.
    pub instance_id: Uuid,
}
33
/// Health of the background Redis subscription, as last recorded by the subscriber task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SubState {
    /// A subscription is established and messages are being consumed.
    Connected,
    /// The subscription is down. `attempt` counts consecutive failed reconnects
    /// since the last disconnect (or panic restart) and selects the backoff delay.
    Reconnecting { attempt: u32 },
}
38
39pub struct RedisNotifier {
46 inner: Arc<InMemoryNotifier>,
47 self_id: Uuid,
48 channel: String,
49 pub_conn: ConnectionManager,
50 #[allow(dead_code)]
51 sub_state: Arc<RwLock<SubState>>,
52 _sub_task: tokio::task::JoinHandle<()>,
53}
54
55#[async_trait::async_trait]
56impl ActivityNotifier for RedisNotifier {
57 async fn publish(&self, event: NotifyEvent) {
58 self.inner.publish(event.clone()).await;
60
61 let payload = match serde_json::to_vec(&Wire {
63 tenant_id: event.tenant_id,
64 conversation_id: event.conversation_id,
65 new_watermark: event.new_watermark,
66 version: 1,
67 instance_id: self.self_id,
68 }) {
69 Ok(p) => p,
70 Err(err) => {
71 tracing::warn!(target: "notifier_redis", ?err, "redis_encode_err");
72 return;
73 }
74 };
75 let mut pub_conn = self.pub_conn.clone();
76 let channel = self.channel.clone();
77 tokio::spawn(async move {
78 if let Err(err) = pub_conn.publish::<_, _, ()>(&channel, payload).await {
79 tracing::debug!(target: "notifier_redis", ?err, "redis_publish_dropped");
80 }
81 });
82 }
83
84 async fn subscribe(
85 &self,
86 tenant_id: &str,
87 conversation_id: &str,
88 ) -> Result<EventStream, NotifierError> {
89 self.inner.subscribe(tenant_id, conversation_id).await
91 }
92}
93
94impl RedisNotifier {
95 pub async fn build(
101 url: &str,
102 channel: Option<String>,
103 capacity: usize,
104 ) -> anyhow::Result<Arc<Self>> {
105 let channel = channel.unwrap_or_else(|| DEFAULT_CHANNEL.to_string());
106 let inner = Arc::new(InMemoryNotifier::new(capacity));
107 let self_id = Uuid::new_v4();
108
109 let client = Client::open(url).with_context(|| format!("invalid redis url: {url}"))?;
110
111 let pub_conn =
113 tokio::time::timeout(BOOT_CONNECT_TIMEOUT, ConnectionManager::new(client.clone()))
114 .await
115 .with_context(|| format!("timed out opening redis PUB connection to {url}"))?
116 .with_context(|| format!("failed to open redis PUB connection to {url}"))?;
117
118 {
121 let probe =
122 tokio::time::timeout(BOOT_CONNECT_TIMEOUT, subscribe_once(&client, &channel))
123 .await
124 .with_context(|| format!("timed out opening redis SUB connection to {url}"))?
125 .with_context(|| format!("failed to open redis SUB connection to {url}"))?;
126 drop(probe);
127 }
128
129 let sub_state = Arc::new(RwLock::new(SubState::Connected));
130
131 let notifier = Arc::new_cyclic(|weak: &std::sync::Weak<Self>| {
134 let weak_clone = weak.clone();
135 let inner_clone = inner.clone();
136 let channel_clone = channel.clone();
137 let sub_state_clone = sub_state.clone();
138 let client_clone = client.clone();
139 let self_id_copy = self_id;
140
141 let task = tokio::spawn(async move {
145 loop {
146 let inv = std::panic::AssertUnwindSafe(background_sub_loop(
147 weak_clone.clone(),
148 inner_clone.clone(),
149 self_id_copy,
150 client_clone.clone(),
151 channel_clone.clone(),
152 sub_state_clone.clone(),
153 ));
154 match futures_util::FutureExt::catch_unwind(inv).await {
155 Ok(()) => return, Err(_panic) => {
157 tracing::error!(
158 target: "notifier_redis",
159 "background loop panicked; restarting after 500ms"
160 );
161 *sub_state_clone.write().await = SubState::Reconnecting { attempt: 0 };
162 tokio::time::sleep(Duration::from_millis(500)).await;
163 if weak_clone.upgrade().is_none() {
164 return;
165 }
166 }
167 }
168 }
169 });
170
171 Self {
172 inner,
173 self_id,
174 channel,
175 pub_conn,
176 sub_state,
177 _sub_task: task,
178 }
179 });
180
181 Ok(notifier)
182 }
183}
184
185async fn subscribe_once(client: &Client, channel: &str) -> anyhow::Result<redis::aio::PubSub> {
187 let mut pubsub = client.get_async_pubsub().await?;
188 pubsub.subscribe(channel).await?;
189 Ok(pubsub)
190}
191
192async fn background_sub_loop(
197 notifier_weak: std::sync::Weak<RedisNotifier>,
198 inner: Arc<InMemoryNotifier>,
199 self_id: Uuid,
200 client: Client,
201 channel: String,
202 sub_state: Arc<RwLock<SubState>>,
203) {
204 use futures_util::StreamExt as _;
205
206 loop {
207 let mut sub = loop {
209 if notifier_weak.upgrade().is_none() {
210 return; }
212 match subscribe_once(&client, &channel).await {
213 Ok(s) => {
214 *sub_state.write().await = SubState::Connected;
215 tracing::info!(target: "notifier_redis", "redis_reconnect_ok");
216 break s;
217 }
218 Err(err) => {
219 let attempt = match *sub_state.read().await {
220 SubState::Reconnecting { attempt } => attempt,
221 SubState::Connected => 0,
222 };
223 tracing::debug!(
224 target: "notifier_redis",
225 ?err,
226 attempt,
227 "redis_reconnect_fail"
228 );
229 *sub_state.write().await = SubState::Reconnecting {
230 attempt: attempt + 1,
231 };
232 tokio::time::sleep(backoff_with_jitter(attempt)).await;
233 }
234 }
235 };
236
237 while let Some(msg) = sub.on_message().next().await {
239 let payload: Vec<u8> = msg.get_payload().unwrap_or_default();
240 process_incoming(&payload, self_id, inner.as_ref()).await;
241 }
242
243 *sub_state.write().await = SubState::Reconnecting { attempt: 0 };
245 tracing::warn!(target: "notifier_redis", "redis_disconnected");
246 }
247}
248
249fn backoff_with_jitter(attempt: u32) -> Duration {
251 use rand::RngExt as _;
252
253 let base_ms: u64 = match attempt {
254 0 => 100,
255 1 => 250,
256 2 => 500,
257 3 => 1_000,
258 4 => 2_000,
259 _ => 5_000,
260 };
261 let jitter: f64 = rand::rng().random_range(-20i32..=20i32) as f64 / 100.0;
263 let ms = (base_ms as f64) * (1.0 + jitter);
264 Duration::from_millis(ms.max(1.0) as u64)
265}
266
267pub(crate) async fn process_incoming(payload: &[u8], self_id: Uuid, inner: &dyn ActivityNotifier) {
273 let wire: Wire = match serde_json::from_slice(payload) {
274 Ok(w) => w,
275 Err(err) => {
276 tracing::debug!(target: "notifier_redis", ?err, "redis_decode_err");
277 return;
278 }
279 };
280 if wire.instance_id == self_id {
281 return; }
283 if wire.version != 1 {
284 tracing::warn!(
285 target: "notifier_redis",
286 version = wire.version,
287 "redis_unknown_version"
288 );
289 return;
290 }
291 inner
292 .publish(NotifyEvent {
293 tenant_id: wire.tenant_id,
294 conversation_id: wire.conversation_id,
295 new_watermark: wire.new_watermark,
296 })
297 .await;
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::notifier::{ActivityNotifier, EventStream, NotifierError, NotifyEvent};
    use async_trait::async_trait;
    use std::sync::Mutex;

    #[test]
    fn wire_payload_roundtrip() {
        let original = Wire {
            tenant_id: "tenant-a".into(),
            conversation_id: "conv-1".into(),
            new_watermark: 42,
            version: 1,
            instance_id: Uuid::new_v4(),
        };
        let bytes = serde_json::to_vec(&original).expect("encode");
        let decoded: Wire = serde_json::from_slice(&bytes).expect("decode");
        assert_eq!(original, decoded);
    }

    /// Peers decode by these JSON keys, so renaming a `Wire` field would break
    /// replicas still running the previous build.
    #[test]
    fn wire_payload_field_names_are_stable() {
        let value = serde_json::to_value(Wire {
            tenant_id: "t".into(),
            conversation_id: "c".into(),
            new_watermark: 1,
            version: 1,
            instance_id: Uuid::new_v4(),
        })
        .expect("encode");
        for key in [
            "tenant_id",
            "conversation_id",
            "new_watermark",
            "version",
            "instance_id",
        ] {
            assert!(value.get(key).is_some(), "missing wire field `{key}`");
        }
    }

    /// Jitter is ±20% around the base delay; allow a little slack for float rounding.
    #[test]
    fn backoff_stays_within_jitter_bounds() {
        for _ in 0..32 {
            let first = backoff_with_jitter(0).as_millis();
            assert!((75..=125).contains(&first), "attempt 0 gave {first}ms");
            let capped = backoff_with_jitter(u32::MAX).as_millis();
            assert!((3_900..=6_100).contains(&capped), "capped attempt gave {capped}ms");
        }
    }

    /// Test double that records every event published into it.
    struct RecordingNotifier {
        published: Mutex<Vec<NotifyEvent>>,
    }

    impl RecordingNotifier {
        fn new() -> Self {
            Self {
                published: Mutex::new(vec![]),
            }
        }
        fn count(&self) -> usize {
            self.published.lock().unwrap().len()
        }
        fn events(&self) -> Vec<NotifyEvent> {
            self.published.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl ActivityNotifier for RecordingNotifier {
        async fn publish(&self, event: NotifyEvent) {
            self.published.lock().unwrap().push(event);
        }
        async fn subscribe(
            &self,
            _tenant: &str,
            _conv: &str,
        ) -> Result<EventStream, NotifierError> {
            unreachable!("not used in dispatch tests")
        }
    }

    fn make_payload(instance_id: Uuid, version: u8) -> Vec<u8> {
        serde_json::to_vec(&Wire {
            tenant_id: "t".into(),
            conversation_id: "c".into(),
            new_watermark: 7,
            version,
            instance_id,
        })
        .unwrap()
    }

    #[tokio::test]
    async fn loop_suppression_drops_self_publish() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        let payload = make_payload(self_id, 1);
        process_incoming(&payload, self_id, &inner).await;
        assert_eq!(inner.count(), 0, "self-echo must be dropped");
    }

    #[tokio::test]
    async fn loop_suppression_accepts_other_replica() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        let other = Uuid::new_v4();
        let payload = make_payload(other, 1);
        process_incoming(&payload, self_id, &inner).await;

        // The forwarded event must carry the wire fields through unchanged.
        let events = inner.events();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0].tenant_id, "t");
        assert_eq!(events[0].conversation_id, "c");
        assert_eq!(events[0].new_watermark, 7);
    }

    #[tokio::test]
    async fn dispatch_drops_unknown_version() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        let other = Uuid::new_v4();
        let payload = make_payload(other, 99);
        process_incoming(&payload, self_id, &inner).await;
        assert_eq!(inner.count(), 0);
    }

    #[tokio::test]
    async fn dispatch_drops_malformed_payload() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        process_incoming(b"not-json{{", self_id, &inner).await;
        assert_eq!(inner.count(), 0);
    }
}