Skip to main content

greentic_start/notifier/
redis.rs

//! Redis pub/sub backplane for the WebChat WS notifier.
//!
//! See docs/superpowers/specs/2026-05-01-webchat-ws-redis-backplane-design.md.
4
5use anyhow::Context as _;
6use redis::aio::ConnectionManager;
7use redis::{AsyncCommands as _, Client};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::RwLock;
12use uuid::Uuid;
13
14use crate::notifier::{
15    ActivityNotifier, EventStream, InMemoryNotifier, NotifierError, NotifyEvent,
16};
17
/// Pub/sub channel used when `RedisNotifier::build` receives no override.
const DEFAULT_CHANNEL: &str = "greentic:webchat:notify";
/// Upper bound on each Redis connection attempt performed by `build`.
const BOOT_CONNECT_TIMEOUT: Duration = Duration::from_millis(2_000);
20
21/// Wire payload exchanged over the global pub/sub channel.
22///
23/// `instance_id` is the per-process UUID used for self-echo suppression.
24/// `version` allows future forward-compatible payload changes.
25#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
26pub struct Wire {
27    pub tenant_id: String,
28    pub conversation_id: String,
29    pub new_watermark: u64,
30    pub version: u8,
31    pub instance_id: Uuid,
32}
33
/// Health of the background SUB connection, as recorded by the SUB task.
enum SubState {
    /// A subscription is established and messages are being drained.
    Connected,
    /// The subscription is down. `attempt` counts consecutive failed
    /// resubscribe attempts and selects the next backoff delay.
    Reconnecting {
        attempt: u32,
    },
}
38
39/// Redis pub/sub backplane wrapping an `InMemoryNotifier` for local fan-out.
40///
41/// `build` fails fast if Redis is unreachable at startup. Once running,
42/// publish-to-Redis is fire-and-forget (local fan-out happens first). A
43/// background SUB task handles reconnects with exponential backoff and is
44/// supervised against panics.
45pub struct RedisNotifier {
46    inner: Arc<InMemoryNotifier>,
47    self_id: Uuid,
48    channel: String,
49    pub_conn: ConnectionManager,
50    #[allow(dead_code)]
51    sub_state: Arc<RwLock<SubState>>,
52    _sub_task: tokio::task::JoinHandle<()>,
53}
54
55#[async_trait::async_trait]
56impl ActivityNotifier for RedisNotifier {
57    async fn publish(&self, event: NotifyEvent) {
58        // Local first — never block on Redis health.
59        self.inner.publish(event.clone()).await;
60
61        // Mirror to Redis fire-and-forget.
62        let payload = match serde_json::to_vec(&Wire {
63            tenant_id: event.tenant_id,
64            conversation_id: event.conversation_id,
65            new_watermark: event.new_watermark,
66            version: 1,
67            instance_id: self.self_id,
68        }) {
69            Ok(p) => p,
70            Err(err) => {
71                tracing::warn!(target: "notifier_redis", ?err, "redis_encode_err");
72                return;
73            }
74        };
75        let mut pub_conn = self.pub_conn.clone();
76        let channel = self.channel.clone();
77        tokio::spawn(async move {
78            if let Err(err) = pub_conn.publish::<_, _, ()>(&channel, payload).await {
79                tracing::debug!(target: "notifier_redis", ?err, "redis_publish_dropped");
80            }
81        });
82    }
83
84    async fn subscribe(
85        &self,
86        tenant_id: &str,
87        conversation_id: &str,
88    ) -> Result<EventStream, NotifierError> {
89        // No Redis call per subscribe — delegate to the in-memory broadcast.
90        self.inner.subscribe(tenant_id, conversation_id).await
91    }
92}
93
94impl RedisNotifier {
95    /// Open PUB and SUB connections to Redis, verify connectivity, spawn the
96    /// background SUB loop, and return a reference-counted handle.
97    ///
98    /// Fails immediately if the URL is invalid or either connection cannot be
99    /// established (strict startup).
100    pub async fn build(
101        url: &str,
102        channel: Option<String>,
103        capacity: usize,
104    ) -> anyhow::Result<Arc<Self>> {
105        let channel = channel.unwrap_or_else(|| DEFAULT_CHANNEL.to_string());
106        let inner = Arc::new(InMemoryNotifier::new(capacity));
107        let self_id = Uuid::new_v4();
108
109        let client = Client::open(url).with_context(|| format!("invalid redis url: {url}"))?;
110
111        // Open the PUB connection (ConnectionManager auto-reconnects on use).
112        let pub_conn =
113            tokio::time::timeout(BOOT_CONNECT_TIMEOUT, ConnectionManager::new(client.clone()))
114                .await
115                .with_context(|| format!("timed out opening redis PUB connection to {url}"))?
116                .with_context(|| format!("failed to open redis PUB connection to {url}"))?;
117
118        // Verify SUB connectivity once at boot by opening and immediately
119        // dropping a probe connection.
120        {
121            let probe =
122                tokio::time::timeout(BOOT_CONNECT_TIMEOUT, subscribe_once(&client, &channel))
123                    .await
124                    .with_context(|| format!("timed out opening redis SUB connection to {url}"))?
125                    .with_context(|| format!("failed to open redis SUB connection to {url}"))?;
126            drop(probe);
127        }
128
129        let sub_state = Arc::new(RwLock::new(SubState::Connected));
130
131        // Arc::new_cyclic lets the background task hold a Weak<Self> so it
132        // can detect when the parent is dropped and exit cleanly.
133        let notifier = Arc::new_cyclic(|weak: &std::sync::Weak<Self>| {
134            let weak_clone = weak.clone();
135            let inner_clone = inner.clone();
136            let channel_clone = channel.clone();
137            let sub_state_clone = sub_state.clone();
138            let client_clone = client.clone();
139            let self_id_copy = self_id;
140
141            // Supervisor wrapper: catch panics inside the loop and restart.
142            // Without this, a panic in `background_sub_loop` would silently
143            // kill cross-replica delivery for the lifetime of the process.
144            let task = tokio::spawn(async move {
145                loop {
146                    let inv = std::panic::AssertUnwindSafe(background_sub_loop(
147                        weak_clone.clone(),
148                        inner_clone.clone(),
149                        self_id_copy,
150                        client_clone.clone(),
151                        channel_clone.clone(),
152                        sub_state_clone.clone(),
153                    ));
154                    match futures_util::FutureExt::catch_unwind(inv).await {
155                        Ok(()) => return, // clean exit (parent dropped)
156                        Err(_panic) => {
157                            tracing::error!(
158                                target: "notifier_redis",
159                                "background loop panicked; restarting after 500ms"
160                            );
161                            *sub_state_clone.write().await = SubState::Reconnecting { attempt: 0 };
162                            tokio::time::sleep(Duration::from_millis(500)).await;
163                            if weak_clone.upgrade().is_none() {
164                                return;
165                            }
166                        }
167                    }
168                }
169            });
170
171            Self {
172                inner,
173                self_id,
174                channel,
175                pub_conn,
176                sub_state,
177                _sub_task: task,
178            }
179        });
180
181        Ok(notifier)
182    }
183}
184
185/// Open a fresh async pub/sub connection and subscribe to `channel`.
186async fn subscribe_once(client: &Client, channel: &str) -> anyhow::Result<redis::aio::PubSub> {
187    let mut pubsub = client.get_async_pubsub().await?;
188    pubsub.subscribe(channel).await?;
189    Ok(pubsub)
190}
191
192/// Long-running background task: subscribe, drain messages, reconnect on drop.
193///
194/// Exits cleanly when the parent `Arc<RedisNotifier>` is dropped
195/// (`notifier_weak.upgrade()` returns `None`).
196async fn background_sub_loop(
197    notifier_weak: std::sync::Weak<RedisNotifier>,
198    inner: Arc<InMemoryNotifier>,
199    self_id: Uuid,
200    client: Client,
201    channel: String,
202    sub_state: Arc<RwLock<SubState>>,
203) {
204    use futures_util::StreamExt as _;
205
206    loop {
207        // (Re)subscribe with bounded backoff.
208        let mut sub = loop {
209            if notifier_weak.upgrade().is_none() {
210                return; // parent dropped
211            }
212            match subscribe_once(&client, &channel).await {
213                Ok(s) => {
214                    *sub_state.write().await = SubState::Connected;
215                    tracing::info!(target: "notifier_redis", "redis_reconnect_ok");
216                    break s;
217                }
218                Err(err) => {
219                    let attempt = match *sub_state.read().await {
220                        SubState::Reconnecting { attempt } => attempt,
221                        SubState::Connected => 0,
222                    };
223                    tracing::debug!(
224                        target: "notifier_redis",
225                        ?err,
226                        attempt,
227                        "redis_reconnect_fail"
228                    );
229                    *sub_state.write().await = SubState::Reconnecting {
230                        attempt: attempt + 1,
231                    };
232                    tokio::time::sleep(backoff_with_jitter(attempt)).await;
233                }
234            }
235        };
236
237        // Drain messages until the connection ends.
238        while let Some(msg) = sub.on_message().next().await {
239            let payload: Vec<u8> = msg.get_payload().unwrap_or_default();
240            process_incoming(&payload, self_id, inner.as_ref()).await;
241        }
242
243        // Stream ended = disconnect; go back to the (re)subscribe arm.
244        *sub_state.write().await = SubState::Reconnecting { attempt: 0 };
245        tracing::warn!(target: "notifier_redis", "redis_disconnected");
246    }
247}
248
249/// Exponential backoff with ±20% jitter.
250fn backoff_with_jitter(attempt: u32) -> Duration {
251    use rand::RngExt as _;
252
253    let base_ms: u64 = match attempt {
254        0 => 100,
255        1 => 250,
256        2 => 500,
257        3 => 1_000,
258        4 => 2_000,
259        _ => 5_000,
260    };
261    // rand 0.10: rand::rng() returns ThreadRng; random_range replaces gen_range.
262    let jitter: f64 = rand::rng().random_range(-20i32..=20i32) as f64 / 100.0;
263    let ms = (base_ms as f64) * (1.0 + jitter);
264    Duration::from_millis(ms.max(1.0) as u64)
265}
266
267/// Decode a payload received over the Redis SUB stream and dispatch it
268/// to the inner notifier, dropping self-echoes and unknown versions.
269///
270/// Extracted as a free function so unit tests can exercise it without
271/// spinning up a Redis connection.
272pub(crate) async fn process_incoming(payload: &[u8], self_id: Uuid, inner: &dyn ActivityNotifier) {
273    let wire: Wire = match serde_json::from_slice(payload) {
274        Ok(w) => w,
275        Err(err) => {
276            tracing::debug!(target: "notifier_redis", ?err, "redis_decode_err");
277            return;
278        }
279    };
280    if wire.instance_id == self_id {
281        return; // self-echo
282    }
283    if wire.version != 1 {
284        tracing::warn!(
285            target: "notifier_redis",
286            version = wire.version,
287            "redis_unknown_version"
288        );
289        return;
290    }
291    inner
292        .publish(NotifyEvent {
293            tenant_id: wire.tenant_id,
294            conversation_id: wire.conversation_id,
295            new_watermark: wire.new_watermark,
296        })
297        .await;
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use crate::notifier::{ActivityNotifier, EventStream, NotifierError, NotifyEvent};
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// A `Wire` value must survive a JSON encode/decode round trip unchanged.
    #[test]
    fn wire_payload_roundtrip() {
        let original = Wire {
            tenant_id: "tenant-a".into(),
            conversation_id: "conv-1".into(),
            new_watermark: 42,
            version: 1,
            instance_id: Uuid::new_v4(),
        };
        let bytes = serde_json::to_vec(&original).expect("encode");
        let decoded: Wire = serde_json::from_slice(&bytes).expect("decode");
        assert_eq!(original, decoded);
    }

    /// Backoff follows the documented schedule, stays within the ±20 % jitter
    /// band, and stays capped for very large attempt counts.
    #[test]
    fn backoff_follows_schedule_within_jitter_band() {
        // Attempts past the end of the schedule stay at the 5 s cap.
        const EXPECTED_BASE_MS: [u128; 7] = [100, 250, 500, 1_000, 2_000, 5_000, 5_000];
        for (attempt, base) in (0u32..).zip(EXPECTED_BASE_MS) {
            for _ in 0..32 {
                let ms = backoff_with_jitter(attempt).as_millis();
                // Allow 1 ms of slack below for float-to-integer truncation.
                assert!(
                    ms + 1 >= base * 8 / 10 && ms <= base * 12 / 10,
                    "attempt {attempt}: {ms} ms outside ±20% of {base} ms"
                );
            }
        }
        assert!(backoff_with_jitter(u32::MAX).as_millis() <= 6_000);
    }

    /// Test double that records every publish call.
    struct RecordingNotifier {
        published: Mutex<Vec<NotifyEvent>>,
    }

    impl RecordingNotifier {
        fn new() -> Self {
            Self {
                published: Mutex::new(vec![]),
            }
        }
        fn count(&self) -> usize {
            self.published.lock().unwrap().len()
        }
    }

    #[async_trait]
    impl ActivityNotifier for RecordingNotifier {
        async fn publish(&self, event: NotifyEvent) {
            self.published.lock().unwrap().push(event);
        }
        async fn subscribe(
            &self,
            _tenant: &str,
            _conv: &str,
        ) -> Result<EventStream, NotifierError> {
            unreachable!("not used in dispatch tests")
        }
    }

    fn make_payload(instance_id: Uuid, version: u8) -> Vec<u8> {
        serde_json::to_vec(&Wire {
            tenant_id: "t".into(),
            conversation_id: "c".into(),
            new_watermark: 7,
            version,
            instance_id,
        })
        .unwrap()
    }

    #[tokio::test]
    async fn loop_suppression_drops_self_publish() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        let payload = make_payload(self_id, 1);
        process_incoming(&payload, self_id, &inner).await;
        assert_eq!(inner.count(), 0, "self-echo must be dropped");
    }

    #[tokio::test]
    async fn loop_suppression_accepts_other_replica() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        let other = Uuid::new_v4();
        let payload = make_payload(other, 1);
        process_incoming(&payload, self_id, &inner).await;
        assert_eq!(inner.count(), 1);
    }

    /// A forwarded event carries the payload's tenant, conversation and watermark.
    #[tokio::test]
    async fn dispatch_forwards_event_fields() {
        let inner = RecordingNotifier::new();
        let payload = make_payload(Uuid::new_v4(), 1);
        process_incoming(&payload, Uuid::new_v4(), &inner).await;
        let published = inner.published.lock().unwrap();
        assert_eq!(published.len(), 1);
        let event = &published[0];
        assert_eq!(event.tenant_id, "t");
        assert_eq!(event.conversation_id, "c");
        assert_eq!(event.new_watermark, 7);
    }

    #[tokio::test]
    async fn dispatch_drops_unknown_version() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        let other = Uuid::new_v4();
        let payload = make_payload(other, 99);
        process_incoming(&payload, self_id, &inner).await;
        assert_eq!(inner.count(), 0);
    }

    #[tokio::test]
    async fn dispatch_drops_malformed_payload() {
        let inner = RecordingNotifier::new();
        let self_id = Uuid::new_v4();
        process_incoming(b"not-json{{", self_id, &inner).await;
        assert_eq!(inner.count(), 0);
    }
}