Skip to main content

kevy_embedded/
pubsub.rs

1//! In-process pub/sub bus for embedded `Store`.
2//!
3//! Mirrors the Redis/kevy server pub/sub semantics inside a single process:
4//! `Store::publish` walks the channel + pattern subscriber tables and
5//! enqueues a [`PubsubFrame`] onto each matching [`Subscription`]'s
6//! `std::sync::mpsc` channel. Each `Subscription` drains its own queue via
7//! [`Subscription::recv`] / [`Subscription::recv_timeout`] /
8//! [`Subscription::try_recv`].
9//!
10//! The bus lives inside `Inner` and is reached only under the embedded
11//! mutex; per-publish we clone the matching senders out, drop the lock,
12//! then `send()` — so a slow receiver can't stall publishes on unrelated
13//! channels.
14
15use std::collections::{HashMap, HashSet};
16use std::io;
17use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, TryRecvError, channel};
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use kevy_store::glob_match;
22
23use crate::store::Inner;
24
/// One pub/sub event delivered to a [`Subscription`].
///
/// The four ack variants (`Subscribe`, `Psubscribe`, `Unsubscribe`,
/// `Punsubscribe`) are enqueued by the [`Subscription`] methods themselves,
/// one per step of the request. `Message` and `Pmessage` are produced by
/// publishes. In every ack, `count` is the number of channels plus patterns
/// the subscription is bound to right after that single step.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubFrame {
    /// Ack for one channel of a `SUBSCRIBE`.
    ///
    /// Emitted even when the subscription already held `channel`; the count
    /// is then unchanged.
    Subscribe {
        /// Channel named in the request.
        channel: Vec<u8>,
        /// Total channels + patterns this subscription holds after the op.
        count: usize,
    },
    /// Ack for one pattern of a `PSUBSCRIBE`.
    ///
    /// Emitted even when the subscription already held `pattern`.
    Psubscribe {
        /// Pattern named in the request.
        pattern: Vec<u8>,
        /// Total channels + patterns this subscription holds after the op.
        count: usize,
    },
    /// Ack for an `UNSUBSCRIBE`.
    ///
    /// `Some(channel)` is sent once per channel named in the request, whether
    /// or not it was held. It is also sent once per channel actually removed by
    /// an argument-less unsubscribe. `None` is sent only when an argument-less
    /// unsubscribe finds no channels to remove.
    Unsubscribe {
        /// Channel this ack refers to (`None`: nothing was held).
        channel: Option<Vec<u8>>,
        /// Total channels + patterns still held after the op.
        count: usize,
    },
    /// Ack for a `PUNSUBSCRIBE`.
    ///
    /// Same shape as [`PubsubFrame::Unsubscribe`]. `Some(pattern)` is sent per
    /// named or actually-removed pattern. `None` is sent only when an
    /// argument-less punsubscribe finds no patterns to remove.
    Punsubscribe {
        /// Pattern this ack refers to (`None`: nothing was held).
        pattern: Option<Vec<u8>>,
        /// Total channels + patterns still held after the op.
        count: usize,
    },
    /// A `PUBLISH` reached a channel this subscription holds directly.
    Message {
        /// Channel the publish was made to.
        channel: Vec<u8>,
        /// Raw payload bytes.
        payload: Vec<u8>,
    },
    /// A `PUBLISH` reached a channel matching one of this subscription's
    /// patterns.
    ///
    /// A subscription that holds both the channel and a matching pattern
    /// receives both a `Message` and a `Pmessage` for the same publish.
    Pmessage {
        /// Pattern the channel matched.
        pattern: Vec<u8>,
        /// Channel the publish was made to.
        channel: Vec<u8>,
        /// Raw payload bytes.
        payload: Vec<u8>,
    },
}
74
75/// Internal entry in the bus tables.
76struct BusEntry {
77    id: u64,
78    sender: Sender<PubsubFrame>,
79}
80
81/// The pub/sub registry, owned by `Inner`.
82pub(crate) struct PubsubBus {
83    next_id: u64,
84    channels: HashMap<Vec<u8>, Vec<BusEntry>>,
85    patterns: Vec<(Vec<u8>, BusEntry)>,
86}
87
88impl PubsubBus {
89    pub(crate) fn new() -> Self {
90        Self {
91            next_id: 1,
92            channels: HashMap::new(),
93            patterns: Vec::new(),
94        }
95    }
96
97    fn alloc_id(&mut self) -> u64 {
98        let id = self.next_id;
99        self.next_id = id.wrapping_add(1).max(1);
100        id
101    }
102
103    /// Total channels + patterns the given subscription id is bound to.
104    fn count_for(&self, id: u64) -> usize {
105        let chans = self
106            .channels
107            .values()
108            .filter(|v| v.iter().any(|e| e.id == id))
109            .count();
110        let pats = self.patterns.iter().filter(|(_, e)| e.id == id).count();
111        chans + pats
112    }
113
114    /// Build the per-publish delivery plan: a list of (frame, sender)
115    /// pairs. Caller drops the bus lock before invoking `send()` so a
116    /// slow receiver can't stall unrelated traffic.
117    pub(crate) fn collect_delivery(
118        &self,
119        channel: &[u8],
120        payload: &[u8],
121    ) -> Vec<(PubsubFrame, Sender<PubsubFrame>)> {
122        let mut plans = Vec::new();
123        if let Some(subs) = self.channels.get(channel) {
124            for e in subs {
125                plans.push((
126                    PubsubFrame::Message {
127                        channel: channel.to_vec(),
128                        payload: payload.to_vec(),
129                    },
130                    e.sender.clone(),
131                ));
132            }
133        }
134        for (pat, e) in &self.patterns {
135            if glob_match(pat, channel) {
136                plans.push((
137                    PubsubFrame::Pmessage {
138                        pattern: pat.clone(),
139                        channel: channel.to_vec(),
140                        payload: payload.to_vec(),
141                    },
142                    e.sender.clone(),
143                ));
144            }
145        }
146        plans
147    }
148
149    fn add_channel(&mut self, id: u64, sender: &Sender<PubsubFrame>, channel: Vec<u8>) -> bool {
150        let subs = self.channels.entry(channel).or_default();
151        if subs.iter().any(|e| e.id == id) {
152            return false;
153        }
154        subs.push(BusEntry {
155            id,
156            sender: sender.clone(),
157        });
158        true
159    }
160
161    fn add_pattern(&mut self, id: u64, sender: &Sender<PubsubFrame>, pattern: Vec<u8>) -> bool {
162        if self
163            .patterns
164            .iter()
165            .any(|(p, e)| p == &pattern && e.id == id)
166        {
167            return false;
168        }
169        self.patterns.push((
170            pattern,
171            BusEntry {
172                id,
173                sender: sender.clone(),
174            },
175        ));
176        true
177    }
178
179    fn remove_channel(&mut self, id: u64, channel: &[u8]) -> bool {
180        if let Some(subs) = self.channels.get_mut(channel) {
181            let before = subs.len();
182            subs.retain(|e| e.id != id);
183            let removed = subs.len() < before;
184            if subs.is_empty() {
185                self.channels.remove(channel);
186            }
187            removed
188        } else {
189            false
190        }
191    }
192
193    fn remove_pattern(&mut self, id: u64, pattern: &[u8]) -> bool {
194        let before = self.patterns.len();
195        self.patterns.retain(|(p, e)| !(p == pattern && e.id == id));
196        self.patterns.len() < before
197    }
198
199    fn remove_all_for(&mut self, id: u64) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
200        let mut chans = Vec::new();
201        let mut pats = Vec::new();
202        self.channels.retain(|name, subs| {
203            let had = subs.iter().any(|e| e.id == id);
204            if had {
205                chans.push(name.clone());
206            }
207            subs.retain(|e| e.id != id);
208            !subs.is_empty()
209        });
210        self.patterns.retain(|(p, e)| {
211            if e.id == id {
212                pats.push(p.clone());
213                false
214            } else {
215                true
216            }
217        });
218        (chans, pats)
219    }
220}
221
222/// A handle to one subscription — owns the receive end of the bus channel.
223///
224/// Drop unsubscribes from everything automatically. While the handle is
225/// alive, [`recv`](Self::recv) / [`recv_timeout`](Self::recv_timeout) /
226/// [`try_recv`](Self::try_recv) drain queued [`PubsubFrame`]s in arrival
227/// order.
228#[allow(missing_debug_implementations)]
229pub struct Subscription {
230    inner: Arc<Mutex<Inner>>,
231    // Keeps the AOF/reaper alive as long as a Subscription does — so
232    // dropping every `Store` clone while a subscriber is still active
233    // leaves the keyspace intact until the subscriber also goes away.
234    _guard: Arc<crate::store::DropGuard>,
235    receiver: Receiver<PubsubFrame>,
236    sender: Sender<PubsubFrame>,
237    id: u64,
238    channels: HashSet<Vec<u8>>,
239    patterns: HashSet<Vec<u8>>,
240}
241
242impl Subscription {
243    pub(crate) fn new(inner: Arc<Mutex<Inner>>, guard: Arc<crate::store::DropGuard>) -> Self {
244        let (sender, receiver) = channel();
245        let id = inner
246            .lock()
247            .unwrap_or_else(|p| p.into_inner())
248            .bus
249            .alloc_id();
250        Self {
251            inner,
252            _guard: guard,
253            receiver,
254            sender,
255            id,
256            channels: HashSet::new(),
257            patterns: HashSet::new(),
258        }
259    }
260
261    /// `SUBSCRIBE channel [channel ...]`. Per-channel `Subscribe` acks are
262    /// enqueued onto the receive queue in order.
263    pub fn subscribe(&mut self, channels: &[&[u8]]) {
264        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
265        for ch in channels {
266            let owned = ch.to_vec();
267            let added = g.bus.add_channel(self.id, &self.sender, owned.clone());
268            if added {
269                self.channels.insert(owned.clone());
270            }
271            let count = g.bus.count_for(self.id);
272            let _ = self.sender.send(PubsubFrame::Subscribe {
273                channel: owned,
274                count,
275            });
276        }
277    }
278
279    /// `PSUBSCRIBE pattern [pattern ...]`. Patterns use Redis glob syntax
280    /// (`*`, `?`, `[abc]`).
281    pub fn psubscribe(&mut self, patterns: &[&[u8]]) {
282        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
283        for pat in patterns {
284            let owned = pat.to_vec();
285            let added = g.bus.add_pattern(self.id, &self.sender, owned.clone());
286            if added {
287                self.patterns.insert(owned.clone());
288            }
289            let count = g.bus.count_for(self.id);
290            let _ = self.sender.send(PubsubFrame::Psubscribe {
291                pattern: owned,
292                count,
293            });
294        }
295    }
296
297    /// `UNSUBSCRIBE [channel ...]`. Empty `channels` removes every channel
298    /// subscription this handle holds (matching the Redis wire shape:
299    /// individual ack frames for each channel that was actually removed,
300    /// or a single `Unsubscribe { channel: None }` if none were held).
301    pub fn unsubscribe(&mut self, channels: &[&[u8]]) {
302        if channels.is_empty() {
303            self.drain_channel_subs();
304            return;
305        }
306        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
307        for ch in channels {
308            let owned = ch.to_vec();
309            let _ = g.bus.remove_channel(self.id, &owned);
310            self.channels.remove(&owned);
311            let count = g.bus.count_for(self.id);
312            let _ = self.sender.send(PubsubFrame::Unsubscribe {
313                channel: Some(owned),
314                count,
315            });
316        }
317    }
318
319    /// `PUNSUBSCRIBE [pattern ...]`. Empty `patterns` removes every pattern.
320    pub fn punsubscribe(&mut self, patterns: &[&[u8]]) {
321        if patterns.is_empty() {
322            self.drain_pattern_subs();
323            return;
324        }
325        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
326        for pat in patterns {
327            let owned = pat.to_vec();
328            let _ = g.bus.remove_pattern(self.id, &owned);
329            self.patterns.remove(&owned);
330            let count = g.bus.count_for(self.id);
331            let _ = self.sender.send(PubsubFrame::Punsubscribe {
332                pattern: Some(owned),
333                count,
334            });
335        }
336    }
337
338    fn drain_channel_subs(&mut self) {
339        let owned: Vec<Vec<u8>> = self.channels.drain().collect();
340        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
341        if owned.is_empty() {
342            let count = g.bus.count_for(self.id);
343            let _ = self
344                .sender
345                .send(PubsubFrame::Unsubscribe { channel: None, count });
346            return;
347        }
348        for ch in owned {
349            let _ = g.bus.remove_channel(self.id, &ch);
350            let count = g.bus.count_for(self.id);
351            let _ = self.sender.send(PubsubFrame::Unsubscribe {
352                channel: Some(ch),
353                count,
354            });
355        }
356    }
357
358    fn drain_pattern_subs(&mut self) {
359        let owned: Vec<Vec<u8>> = self.patterns.drain().collect();
360        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
361        if owned.is_empty() {
362            let count = g.bus.count_for(self.id);
363            let _ = self
364                .sender
365                .send(PubsubFrame::Punsubscribe { pattern: None, count });
366            return;
367        }
368        for p in owned {
369            let _ = g.bus.remove_pattern(self.id, &p);
370            let count = g.bus.count_for(self.id);
371            let _ = self.sender.send(PubsubFrame::Punsubscribe {
372                pattern: Some(p),
373                count,
374            });
375        }
376    }
377
378    /// Block until one frame is queued. `Err(io::ErrorKind::UnexpectedEof)`
379    /// once the underlying bus tears down (last `Store` clone dropped).
380    pub fn recv(&self) -> io::Result<PubsubFrame> {
381        self.receiver
382            .recv()
383            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
384    }
385
386    /// Bounded blocking recv. `Err(io::ErrorKind::TimedOut)` when `dur`
387    /// elapses; `Err(io::ErrorKind::UnexpectedEof)` when the bus is gone.
388    pub fn recv_timeout(&self, dur: Duration) -> io::Result<PubsubFrame> {
389        self.receiver.recv_timeout(dur).map_err(|e| match e {
390            RecvTimeoutError::Timeout => io::Error::from(io::ErrorKind::TimedOut),
391            RecvTimeoutError::Disconnected => {
392                io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed")
393            }
394        })
395    }
396
397    /// Non-blocking recv. `Ok(None)` if the queue is empty;
398    /// `Err(UnexpectedEof)` when the bus is gone.
399    pub fn try_recv(&self) -> io::Result<Option<PubsubFrame>> {
400        match self.receiver.try_recv() {
401            Ok(f) => Ok(Some(f)),
402            Err(TryRecvError::Empty) => Ok(None),
403            Err(TryRecvError::Disconnected) => {
404                Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
405            }
406        }
407    }
408}
409
410impl std::fmt::Debug for Subscription {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        f.debug_struct("Subscription")
413            .field("id", &self.id)
414            .field("channels", &self.channels.len())
415            .field("patterns", &self.patterns.len())
416            .finish_non_exhaustive()
417    }
418}
419
420impl Drop for Subscription {
421    fn drop(&mut self) {
422        // Best-effort cleanup. If the underlying Inner is poisoned we still
423        // remove our entries; the AtomicBool / send stuff doesn't care.
424        if let Ok(mut g) = self.inner.lock() {
425            g.bus.remove_all_for(self.id);
426        } else if let Ok(mut g) = self.inner.clear_poison_and_lock() {
427            // Mutex::clear_poison + reacquire is stable since Rust 1.77; we
428            // pin rust-version=1.95 so this is available. The `else` branch
429            // above is unreachable in practice given we always recover from
430            // poison ourselves; left here so the cleanup is total.
431            g.bus.remove_all_for(self.id);
432        }
433    }
434}
435
436/// Tiny helper trait so `Drop` can recover from poison without
437/// pulling in the explicit `poison.into_inner()` dance. Local to the
438/// module; not part of the public API.
439trait LockExt<'a, T> {
440    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>>;
441}
442
443impl<'a, T> LockExt<'a, T> for Mutex<T> {
444    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>> {
445        self.clear_poison();
446        self.lock()
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Config, Store};

    /// Fresh store from `Config::default()` with the TTL reaper in manual mode.
    fn store() -> Store {
        Store::open(Config::default().with_ttl_reaper_manual()).unwrap()
    }

    #[test]
    fn publish_to_no_subscribers_returns_zero() {
        let s = store();
        assert_eq!(s.publish(b"chan", b"hi"), 0);
    }

    #[test]
    fn subscribe_ack_then_message_delivered() {
        let s = store();
        let sub = s.subscribe(&[b"news"]);
        // Drain the SUBSCRIBE ack.
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Subscribe {
                channel: b"news".to_vec(),
                count: 1,
            }
        );
        // Same store handle (or a clone) can publish.
        assert_eq!(s.publish(b"news", b"hello"), 1);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Message {
                channel: b"news".to_vec(),
                payload: b"hello".to_vec(),
            }
        );
    }

    #[test]
    fn store_clone_publishes_reach_other_clones_subscribers() {
        let s1 = store();
        let s2 = s1.clone();
        let sub = s1.subscribe(&[b"x"]);
        let _ = sub.recv().unwrap(); // ack
        assert_eq!(s2.publish(b"x", b"v"), 1);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Message {
                channel: b"x".to_vec(),
                payload: b"v".to_vec(),
            }
        );
    }

    #[test]
    fn psubscribe_glob_match_delivers_pmessage() {
        let s = store();
        let sub = s.psubscribe(&[b"news.*"]);
        let _ = sub.recv().unwrap(); // psubscribe ack
        assert_eq!(s.publish(b"news.tech", b"breaking"), 1);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Pmessage {
                pattern: b"news.*".to_vec(),
                channel: b"news.tech".to_vec(),
                payload: b"breaking".to_vec(),
            }
        );
        // Non-matching publish does not reach the subscriber.
        assert_eq!(s.publish(b"weather", b"sunny"), 0);
        assert!(sub.try_recv().unwrap().is_none());
    }

    #[test]
    fn duplicate_subscribe_does_not_duplicate_delivery() {
        let s = store();
        let mut sub = s.subscribe(&[b"x"]);
        // Already held: the bus is unchanged, but an ack is still emitted.
        sub.subscribe(&[b"x"]);
        // Drain the two acks (one from subscribe(), one from the second call).
        let a1 = sub.recv().unwrap();
        let a2 = sub.recv().unwrap();
        assert!(matches!(a1, PubsubFrame::Subscribe { count: 1, .. }));
        assert!(matches!(a2, PubsubFrame::Subscribe { count: 1, .. }));
        // Single delivery, despite "double subscribe".
        assert_eq!(s.publish(b"x", b"v"), 1);
        let _ = sub.recv().unwrap();
        assert!(sub.try_recv().unwrap().is_none());
    }

    #[test]
    fn channel_and_pattern_match_each_deliver_a_frame() {
        let s = store();
        let mut sub = s.subscribe(&[b"news.tech"]);
        sub.psubscribe(&[b"news.*"]);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Subscribe {
                channel: b"news.tech".to_vec(),
                count: 1,
            }
        );
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Psubscribe {
                pattern: b"news.*".to_vec(),
                count: 2,
            }
        );
        let _ = s.publish(b"news.tech", b"hi");
        let first = sub.recv_timeout(Duration::from_millis(100)).unwrap();
        let second = sub.recv_timeout(Duration::from_millis(100)).unwrap();
        let direct = PubsubFrame::Message {
            channel: b"news.tech".to_vec(),
            payload: b"hi".to_vec(),
        };
        let via_pattern = PubsubFrame::Pmessage {
            pattern: b"news.*".to_vec(),
            channel: b"news.tech".to_vec(),
            payload: b"hi".to_vec(),
        };
        assert!(
            (first == direct && second == via_pattern)
                || (first == via_pattern && second == direct)
        );
        assert!(sub.try_recv().unwrap().is_none());
    }

    #[test]
    fn unsubscribe_removes_then_no_more_messages() {
        let s = store();
        let mut sub = s.subscribe(&[b"x"]);
        let _ = sub.recv().unwrap();
        sub.unsubscribe(&[b"x"]);
        // Drain the unsubscribe ack.
        assert!(matches!(
            sub.recv().unwrap(),
            PubsubFrame::Unsubscribe {
                channel: Some(_),
                count: 0
            }
        ));
        // Publishes no longer reach us.
        assert_eq!(s.publish(b"x", b"v"), 0);
    }

    #[test]
    fn unsubscribe_all_with_empty_args_drains_every_channel() {
        let s = store();
        let mut sub = s.subscribe(&[b"a", b"b"]);
        let _ = sub.recv().unwrap();
        let _ = sub.recv().unwrap();
        sub.unsubscribe(&[]);
        // Two unsubscribe acks, one per removed channel.
        for _ in 0..2 {
            assert!(matches!(
                sub.recv().unwrap(),
                PubsubFrame::Unsubscribe {
                    channel: Some(_),
                    ..
                }
            ));
        }
        // Publishes go nowhere now.
        assert_eq!(s.publish(b"a", b"x"), 0);
        assert_eq!(s.publish(b"b", b"x"), 0);
    }

    #[test]
    fn unsubscribe_when_no_subs_held_emits_nil_channel_ack() {
        let s = store();
        let mut sub = s.subscribe(&[]); // empty start
        sub.unsubscribe(&[]);
        assert!(matches!(
            sub.recv().unwrap(),
            PubsubFrame::Unsubscribe {
                channel: None,
                count: 0
            }
        ));
    }

    #[test]
    fn punsubscribe_stops_pattern_delivery() {
        let s = store();
        let mut sub = s.psubscribe(&[b"a.*"]);
        let _ = sub.recv().unwrap(); // psubscribe ack
        sub.punsubscribe(&[b"a.*"]);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Punsubscribe {
                pattern: Some(b"a.*".to_vec()),
                count: 0,
            }
        );
        assert_eq!(s.publish(b"a.b", b"x"), 0);
        assert!(sub.try_recv().unwrap().is_none());
        // Nothing left to remove: a single `None` ack.
        sub.punsubscribe(&[]);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Punsubscribe {
                pattern: None,
                count: 0,
            }
        );
    }

    #[test]
    fn drop_subscriber_unregisters() {
        let s = store();
        let sub = s.subscribe(&[b"x"]);
        let _ = sub.recv().unwrap();
        assert_eq!(s.publish(b"x", b"v"), 1);
        let _ = sub.recv().unwrap();
        drop(sub);
        assert_eq!(s.publish(b"x", b"v"), 0);
    }

    #[test]
    fn recv_timeout_returns_timeout_when_empty() {
        let s = store();
        let sub = s.subscribe(&[b"x"]);
        // Drain the ack first.
        let _ = sub.recv_timeout(Duration::from_millis(100)).unwrap();
        let err = sub
            .recv_timeout(Duration::from_millis(50))
            .unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }
}