Skip to main content

kevy_embedded/
pubsub.rs

1//! In-process pub/sub bus for embedded `Store`.
2//!
3//! Mirrors the Redis/kevy server pub/sub semantics inside a single process:
4//! `Store::publish` walks the channel + pattern subscriber tables and
5//! enqueues a [`PubsubFrame`] onto each matching [`Subscription`]'s
6//! `std::sync::mpsc` channel. Each `Subscription` drains its own queue via
7//! [`Subscription::recv`] / [`Subscription::recv_timeout`] /
8//! [`Subscription::try_recv`].
9//!
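//! The bus lives inside `Inner` and is reached only under the embedded
//! mutex; per-publish we clone the matching senders out, drop the lock,
//! then `send()`, so delivery never extends the store's critical section.
//! Each subscription queue is an unbounded `mpsc` channel, so `send()`
//! never blocks: a slow receiver accumulates a backlog instead of stalling
//! publishes on unrelated channels.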
14
15use std::collections::{HashMap, HashSet};
16use std::io;
17use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, TryRecvError, channel};
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use kevy_store::glob_match;
22
23use crate::store::Inner;
24
/// A single pub/sub event as observed by a [`Subscription`]: either the
/// acknowledgement of a (un)subscribe request or a delivered message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PubsubFrame {
    /// Acknowledges a `SUBSCRIBE` for one channel.
    Subscribe {
        /// The channel that is now being listened to.
        channel: Vec<u8>,
        /// Number of channels plus patterns held once the operation applied.
        count: usize,
    },
    /// Acknowledges a `PSUBSCRIBE` for one pattern.
    Psubscribe {
        /// The glob pattern that is now being listened to.
        pattern: Vec<u8>,
        /// Number of channels plus patterns held once the operation applied.
        count: usize,
    },
    /// Acknowledges an `UNSUBSCRIBE`; `channel` is `None` for the
    /// "unsubscribe from all" form when nothing was held.
    Unsubscribe {
        /// The channel that was dropped, or `None` for "all".
        channel: Option<Vec<u8>>,
        /// Number of channels plus patterns remaining afterwards.
        count: usize,
    },
    /// Acknowledges a `PUNSUBSCRIBE`; `pattern` is `None` for the
    /// "unsubscribe from all" form when nothing was held.
    Punsubscribe {
        /// The pattern that was dropped, or `None` for "all".
        pattern: Option<Vec<u8>>,
        /// Number of channels plus patterns remaining afterwards.
        count: usize,
    },
    /// A `PUBLISH` to a channel this subscription listens to by exact name.
    Message {
        /// Target channel of the publish.
        channel: Vec<u8>,
        /// Payload bytes, untouched.
        payload: Vec<u8>,
    },
    /// A `PUBLISH` to a channel that matched one of this subscription's
    /// glob patterns.
    Pmessage {
        /// The subscribed pattern that matched.
        pattern: Vec<u8>,
        /// Target channel of the publish.
        channel: Vec<u8>,
        /// Payload bytes, untouched.
        payload: Vec<u8>,
    },
}
74
75/// Internal entry in the bus tables.
76struct BusEntry {
77    id: u64,
78    sender: Sender<PubsubFrame>,
79}
80
81/// The pub/sub registry, owned by `Inner`.
82pub(crate) struct PubsubBus {
83    next_id: u64,
84    channels: HashMap<Vec<u8>, Vec<BusEntry>>,
85    patterns: Vec<(Vec<u8>, BusEntry)>,
86}
87
88impl PubsubBus {
89    pub(crate) fn new() -> Self {
90        Self {
91            next_id: 1,
92            channels: HashMap::new(),
93            patterns: Vec::new(),
94        }
95    }
96
97    fn alloc_id(&mut self) -> u64 {
98        let id = self.next_id;
99        self.next_id = id.wrapping_add(1).max(1);
100        id
101    }
102
103    /// Total channels + patterns the given subscription id is bound to.
104    fn count_for(&self, id: u64) -> usize {
105        let chans = self
106            .channels
107            .values()
108            .filter(|v| v.iter().any(|e| e.id == id))
109            .count();
110        let pats = self.patterns.iter().filter(|(_, e)| e.id == id).count();
111        chans + pats
112    }
113
114    /// Build the per-publish delivery plan: a list of (frame, sender)
115    /// pairs. Caller drops the bus lock before invoking `send()` so a
116    /// slow receiver can't stall unrelated traffic.
117    pub(crate) fn collect_delivery(
118        &self,
119        channel: &[u8],
120        payload: &[u8],
121    ) -> Vec<(PubsubFrame, Sender<PubsubFrame>)> {
122        let mut plans = Vec::new();
123        if let Some(subs) = self.channels.get(channel) {
124            for e in subs {
125                plans.push((
126                    PubsubFrame::Message {
127                        channel: channel.to_vec(),
128                        payload: payload.to_vec(),
129                    },
130                    e.sender.clone(),
131                ));
132            }
133        }
134        for (pat, e) in &self.patterns {
135            if glob_match(pat, channel) {
136                plans.push((
137                    PubsubFrame::Pmessage {
138                        pattern: pat.clone(),
139                        channel: channel.to_vec(),
140                        payload: payload.to_vec(),
141                    },
142                    e.sender.clone(),
143                ));
144            }
145        }
146        plans
147    }
148
149    fn add_channel(&mut self, id: u64, sender: &Sender<PubsubFrame>, channel: Vec<u8>) -> bool {
150        let subs = self.channels.entry(channel).or_default();
151        if subs.iter().any(|e| e.id == id) {
152            return false;
153        }
154        subs.push(BusEntry {
155            id,
156            sender: sender.clone(),
157        });
158        true
159    }
160
161    fn add_pattern(&mut self, id: u64, sender: &Sender<PubsubFrame>, pattern: Vec<u8>) -> bool {
162        if self
163            .patterns
164            .iter()
165            .any(|(p, e)| p == &pattern && e.id == id)
166        {
167            return false;
168        }
169        self.patterns.push((
170            pattern,
171            BusEntry {
172                id,
173                sender: sender.clone(),
174            },
175        ));
176        true
177    }
178
179    fn remove_channel(&mut self, id: u64, channel: &[u8]) -> bool {
180        if let Some(subs) = self.channels.get_mut(channel) {
181            let before = subs.len();
182            subs.retain(|e| e.id != id);
183            let removed = subs.len() < before;
184            if subs.is_empty() {
185                self.channels.remove(channel);
186            }
187            removed
188        } else {
189            false
190        }
191    }
192
193    fn remove_pattern(&mut self, id: u64, pattern: &[u8]) -> bool {
194        let before = self.patterns.len();
195        self.patterns.retain(|(p, e)| !(p == pattern && e.id == id));
196        self.patterns.len() < before
197    }
198
199    fn remove_all_for(&mut self, id: u64) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
200        let mut chans = Vec::new();
201        let mut pats = Vec::new();
202        self.channels.retain(|name, subs| {
203            let had = subs.iter().any(|e| e.id == id);
204            if had {
205                chans.push(name.clone());
206            }
207            subs.retain(|e| e.id != id);
208            !subs.is_empty()
209        });
210        self.patterns.retain(|(p, e)| {
211            if e.id == id {
212                pats.push(p.clone());
213                false
214            } else {
215                true
216            }
217        });
218        (chans, pats)
219    }
220}
221
222/// A handle to one subscription — owns the receive end of the bus channel.
223///
224/// Drop unsubscribes from everything automatically. While the handle is
225/// alive, [`recv`](Self::recv) / [`recv_timeout`](Self::recv_timeout) /
226/// [`try_recv`](Self::try_recv) drain queued [`PubsubFrame`]s in arrival
227/// order.
228///
229/// **Threading.** `Subscription` is `Send + Sync` —
230/// `Arc<Subscription>` works, so multiple async tasks (or
231/// `spawn_blocking` jobs) can share one subscription and call `recv`
232/// concurrently. The underlying `std::sync::mpsc::Receiver` is
233/// !Sync, so we wrap it (and the matching ack `Sender`) in a `Mutex`;
234/// concurrent `recv` callers serialise on that lock, with each call
235/// receiving a *different* frame in arrival order (single-consumer
236/// semantics — NOT broadcast fanout). `try_recv` is non-blocking even
237/// under contention: if the lock is held by a blocking `recv`,
238/// `try_recv` returns `Ok(None)` rather than waiting.
239///
240/// If you need broadcast fanout (every subscriber sees every message),
241/// open a separate `Subscription` per consumer — they're cheap.
242#[allow(missing_debug_implementations)]
243pub struct Subscription {
244    inner: Arc<Mutex<Inner>>,
245    // Keeps the AOF/reaper alive as long as a Subscription does — so
246    // dropping every `Store` clone while a subscriber is still active
247    // leaves the keyspace intact until the subscriber also goes away.
248    _guard: Arc<crate::store::DropGuard>,
249    // `Receiver<T>` is `Send + !Sync`; wrap so `Subscription: Sync`.
250    // Hot path (recv) acquires + holds the lock during the blocking
251    // wait — single consumer at a time; concurrent recv callers
252    // serialise and each get a different frame. See type-level
253    // doc-comment for the trade-off.
254    receiver: Mutex<Receiver<PubsubFrame>>,
255    // `Sender<T>` is also !Sync (Send + Clone but cannot be shared by
256    // reference across threads). Wrap so the ack-frame path (called
257    // from subscribe/unsubscribe / Drop) can run from any thread.
258    sender: Mutex<Sender<PubsubFrame>>,
259    id: u64,
260    channels: HashSet<Vec<u8>>,
261    patterns: HashSet<Vec<u8>>,
262}
263
264impl Subscription {
265    pub(crate) fn new(inner: Arc<Mutex<Inner>>, guard: Arc<crate::store::DropGuard>) -> Self {
266        let (sender, receiver) = channel();
267        let id = inner
268            .lock()
269            .unwrap_or_else(|p| p.into_inner())
270            .bus
271            .alloc_id();
272        Self {
273            inner,
274            _guard: guard,
275            receiver: Mutex::new(receiver),
276            sender: Mutex::new(sender),
277            id,
278            channels: HashSet::new(),
279            patterns: HashSet::new(),
280        }
281    }
282
283    /// Clone of the inbound `Sender`. Used both for ack frames (Subscribe /
284    /// Unsubscribe / ...) and to register a sender clone inside
285    /// `PubsubBus`. Calling this acquires the sender lock briefly (~20 ns).
286    fn sender_clone(&self) -> Sender<PubsubFrame> {
287        self.sender
288            .lock()
289            .unwrap_or_else(|p| p.into_inner())
290            .clone()
291    }
292
293    /// `SUBSCRIBE channel [channel ...]`. Per-channel `Subscribe` acks are
294    /// enqueued onto the receive queue in order.
295    pub fn subscribe(&mut self, channels: &[&[u8]]) {
296        let s = self.sender_clone();
297        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
298        for ch in channels {
299            let owned = ch.to_vec();
300            let added = g.bus.add_channel(self.id, &s, owned.clone());
301            if added {
302                self.channels.insert(owned.clone());
303            }
304            let count = g.bus.count_for(self.id);
305            let _ = s.send(PubsubFrame::Subscribe {
306                channel: owned,
307                count,
308            });
309        }
310    }
311
312    /// `PSUBSCRIBE pattern [pattern ...]`. Patterns use Redis glob syntax
313    /// (`*`, `?`, `[abc]`).
314    pub fn psubscribe(&mut self, patterns: &[&[u8]]) {
315        let s = self.sender_clone();
316        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
317        for pat in patterns {
318            let owned = pat.to_vec();
319            let added = g.bus.add_pattern(self.id, &s, owned.clone());
320            if added {
321                self.patterns.insert(owned.clone());
322            }
323            let count = g.bus.count_for(self.id);
324            let _ = s.send(PubsubFrame::Psubscribe {
325                pattern: owned,
326                count,
327            });
328        }
329    }
330
331    /// `UNSUBSCRIBE [channel ...]`. Empty `channels` removes every channel
332    /// subscription this handle holds (matching the Redis wire shape:
333    /// individual ack frames for each channel that was actually removed,
334    /// or a single `Unsubscribe { channel: None }` if none were held).
335    pub fn unsubscribe(&mut self, channels: &[&[u8]]) {
336        if channels.is_empty() {
337            self.drain_channel_subs();
338            return;
339        }
340        let s = self.sender_clone();
341        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
342        for ch in channels {
343            let owned = ch.to_vec();
344            let _ = g.bus.remove_channel(self.id, &owned);
345            self.channels.remove(&owned);
346            let count = g.bus.count_for(self.id);
347            let _ = s.send(PubsubFrame::Unsubscribe {
348                channel: Some(owned),
349                count,
350            });
351        }
352    }
353
354    /// `PUNSUBSCRIBE [pattern ...]`. Empty `patterns` removes every pattern.
355    pub fn punsubscribe(&mut self, patterns: &[&[u8]]) {
356        if patterns.is_empty() {
357            self.drain_pattern_subs();
358            return;
359        }
360        let s = self.sender_clone();
361        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
362        for pat in patterns {
363            let owned = pat.to_vec();
364            let _ = g.bus.remove_pattern(self.id, &owned);
365            self.patterns.remove(&owned);
366            let count = g.bus.count_for(self.id);
367            let _ = s.send(PubsubFrame::Punsubscribe {
368                pattern: Some(owned),
369                count,
370            });
371        }
372    }
373
374    fn drain_channel_subs(&mut self) {
375        let s = self.sender_clone();
376        let owned: Vec<Vec<u8>> = self.channels.drain().collect();
377        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
378        if owned.is_empty() {
379            let count = g.bus.count_for(self.id);
380            let _ = s.send(PubsubFrame::Unsubscribe { channel: None, count });
381            return;
382        }
383        for ch in owned {
384            let _ = g.bus.remove_channel(self.id, &ch);
385            let count = g.bus.count_for(self.id);
386            let _ = s.send(PubsubFrame::Unsubscribe {
387                channel: Some(ch),
388                count,
389            });
390        }
391    }
392
393    fn drain_pattern_subs(&mut self) {
394        let s = self.sender_clone();
395        let owned: Vec<Vec<u8>> = self.patterns.drain().collect();
396        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
397        if owned.is_empty() {
398            let count = g.bus.count_for(self.id);
399            let _ = s.send(PubsubFrame::Punsubscribe { pattern: None, count });
400            return;
401        }
402        for p in owned {
403            let _ = g.bus.remove_pattern(self.id, &p);
404            let count = g.bus.count_for(self.id);
405            let _ = s.send(PubsubFrame::Punsubscribe {
406                pattern: Some(p),
407                count,
408            });
409        }
410    }
411
412    /// Block until one frame is queued. `Err(io::ErrorKind::UnexpectedEof)`
413    /// once the underlying bus tears down (last `Store` clone dropped).
414    ///
415    /// Acquires the receiver mutex for the entire blocking wait — other
416    /// `recv`/`recv_timeout` callers serialise behind this one. Concurrent
417    /// `try_recv` calls return `Ok(None)` while a `recv` is blocked (no
418    /// wait on the lock); see the type-level doc for the trade-off.
419    pub fn recv(&self) -> io::Result<PubsubFrame> {
420        let g = self.receiver.lock().unwrap_or_else(|p| p.into_inner());
421        g.recv()
422            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
423    }
424
425    /// Bounded blocking recv. `Err(io::ErrorKind::TimedOut)` when `dur`
426    /// elapses; `Err(io::ErrorKind::UnexpectedEof)` when the bus is gone.
427    pub fn recv_timeout(&self, dur: Duration) -> io::Result<PubsubFrame> {
428        let g = self.receiver.lock().unwrap_or_else(|p| p.into_inner());
429        g.recv_timeout(dur).map_err(|e| match e {
430            RecvTimeoutError::Timeout => io::Error::from(io::ErrorKind::TimedOut),
431            RecvTimeoutError::Disconnected => {
432                io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed")
433            }
434        })
435    }
436
437    /// Non-blocking recv. `Ok(None)` if the queue is empty;
438    /// `Err(UnexpectedEof)` when the bus is gone.
439    ///
440    /// Uses `try_lock` so a concurrent blocking `recv` doesn't make
441    /// `try_recv` itself block — lock contention is reported as `Ok(None)`
442    /// (semantically: "no frame available right now"). Same shape callers
443    /// already handle for an empty queue.
444    pub fn try_recv(&self) -> io::Result<Option<PubsubFrame>> {
445        let g = match self.receiver.try_lock() {
446            Ok(g) => g,
447            Err(_) => return Ok(None),
448        };
449        match g.try_recv() {
450            Ok(f) => Ok(Some(f)),
451            Err(TryRecvError::Empty) => Ok(None),
452            Err(TryRecvError::Disconnected) => {
453                Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
454            }
455        }
456    }
457}
458
459impl std::fmt::Debug for Subscription {
460    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461        f.debug_struct("Subscription")
462            .field("id", &self.id)
463            .field("channels", &self.channels.len())
464            .field("patterns", &self.patterns.len())
465            .finish_non_exhaustive()
466    }
467}
468
469impl Drop for Subscription {
470    fn drop(&mut self) {
471        // Best-effort cleanup. If the underlying Inner is poisoned we still
472        // remove our entries; the AtomicBool / send stuff doesn't care.
473        if let Ok(mut g) = self.inner.lock() {
474            g.bus.remove_all_for(self.id);
475        } else if let Ok(mut g) = self.inner.clear_poison_and_lock() {
476            // Mutex::clear_poison + reacquire is stable since Rust 1.77; we
477            // pin rust-version=1.95 so this is available. The `else` branch
478            // above is unreachable in practice given we always recover from
479            // poison ourselves; left here so the cleanup is total.
480            g.bus.remove_all_for(self.id);
481        }
482    }
483}
484
485/// Tiny helper trait so `Drop` can recover from poison without
486/// pulling in the explicit `poison.into_inner()` dance. Local to the
487/// module; not part of the public API.
488trait LockExt<'a, T> {
489    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>>;
490}
491
492impl<'a, T> LockExt<'a, T> for Mutex<T> {
493    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>> {
494        self.clear_poison();
495        self.lock()
496    }
497}
498
499#[cfg(test)]
500#[path = "pubsub_tests.rs"]
501mod tests;