1use std::collections::{HashMap, HashSet};
16use std::io;
17use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, TryRecvError, channel};
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use kevy_store::glob_match;
22
23use crate::store::Inner;
24
/// A single frame delivered to a [`Subscription`]'s queue.
///
/// Acknowledgements for (un)subscribe calls and published messages share the
/// same queue, so a reader sees them in the order they were enqueued.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum PubsubFrame {
    /// Ack for subscribing to one channel.
    Subscribe {
        /// The channel named in the request.
        channel: Vec<u8>,
        /// Channels plus patterns held by this subscriber after the request.
        count: usize,
    },
    /// Ack for registering one glob pattern.
    Psubscribe {
        /// The pattern named in the request.
        pattern: Vec<u8>,
        /// Channels plus patterns held by this subscriber after the request.
        count: usize,
    },
    /// Ack for leaving a channel.
    Unsubscribe {
        /// The channel left, or `None` when an unsubscribe-all found nothing held.
        channel: Option<Vec<u8>>,
        /// Channels plus patterns still held after the request.
        count: usize,
    },
    /// Ack for dropping a glob pattern.
    Punsubscribe {
        /// The pattern dropped, or `None` when a punsubscribe-all found nothing held.
        pattern: Option<Vec<u8>>,
        /// Channels plus patterns still held after the request.
        count: usize,
    },
    /// A message published on a channel this subscriber joined directly.
    Message {
        /// Channel the message was published on.
        channel: Vec<u8>,
        /// Published bytes.
        payload: Vec<u8>,
    },
    /// A message published on a channel matched by one of this subscriber's patterns.
    Pmessage {
        /// The registered pattern that matched.
        pattern: Vec<u8>,
        /// Channel the message was published on.
        channel: Vec<u8>,
        /// Published bytes.
        payload: Vec<u8>,
    },
}
74
75struct BusEntry {
77 id: u64,
78 sender: Sender<PubsubFrame>,
79}
80
81pub(crate) struct PubsubBus {
83 next_id: u64,
84 channels: HashMap<Vec<u8>, Vec<BusEntry>>,
85 patterns: Vec<(Vec<u8>, BusEntry)>,
86}
87
88impl PubsubBus {
89 pub(crate) fn new() -> Self {
90 Self {
91 next_id: 1,
92 channels: HashMap::new(),
93 patterns: Vec::new(),
94 }
95 }
96
97 fn alloc_id(&mut self) -> u64 {
98 let id = self.next_id;
99 self.next_id = id.wrapping_add(1).max(1);
100 id
101 }
102
103 fn count_for(&self, id: u64) -> usize {
105 let chans = self
106 .channels
107 .values()
108 .filter(|v| v.iter().any(|e| e.id == id))
109 .count();
110 let pats = self.patterns.iter().filter(|(_, e)| e.id == id).count();
111 chans + pats
112 }
113
114 pub(crate) fn collect_delivery(
118 &self,
119 channel: &[u8],
120 payload: &[u8],
121 ) -> Vec<(PubsubFrame, Sender<PubsubFrame>)> {
122 let mut plans = Vec::new();
123 if let Some(subs) = self.channels.get(channel) {
124 for e in subs {
125 plans.push((
126 PubsubFrame::Message {
127 channel: channel.to_vec(),
128 payload: payload.to_vec(),
129 },
130 e.sender.clone(),
131 ));
132 }
133 }
134 for (pat, e) in &self.patterns {
135 if glob_match(pat, channel) {
136 plans.push((
137 PubsubFrame::Pmessage {
138 pattern: pat.clone(),
139 channel: channel.to_vec(),
140 payload: payload.to_vec(),
141 },
142 e.sender.clone(),
143 ));
144 }
145 }
146 plans
147 }
148
149 fn add_channel(&mut self, id: u64, sender: &Sender<PubsubFrame>, channel: Vec<u8>) -> bool {
150 let subs = self.channels.entry(channel).or_default();
151 if subs.iter().any(|e| e.id == id) {
152 return false;
153 }
154 subs.push(BusEntry {
155 id,
156 sender: sender.clone(),
157 });
158 true
159 }
160
161 fn add_pattern(&mut self, id: u64, sender: &Sender<PubsubFrame>, pattern: Vec<u8>) -> bool {
162 if self
163 .patterns
164 .iter()
165 .any(|(p, e)| p == &pattern && e.id == id)
166 {
167 return false;
168 }
169 self.patterns.push((
170 pattern,
171 BusEntry {
172 id,
173 sender: sender.clone(),
174 },
175 ));
176 true
177 }
178
179 fn remove_channel(&mut self, id: u64, channel: &[u8]) -> bool {
180 if let Some(subs) = self.channels.get_mut(channel) {
181 let before = subs.len();
182 subs.retain(|e| e.id != id);
183 let removed = subs.len() < before;
184 if subs.is_empty() {
185 self.channels.remove(channel);
186 }
187 removed
188 } else {
189 false
190 }
191 }
192
193 fn remove_pattern(&mut self, id: u64, pattern: &[u8]) -> bool {
194 let before = self.patterns.len();
195 self.patterns.retain(|(p, e)| !(p == pattern && e.id == id));
196 self.patterns.len() < before
197 }
198
199 fn remove_all_for(&mut self, id: u64) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
200 let mut chans = Vec::new();
201 let mut pats = Vec::new();
202 self.channels.retain(|name, subs| {
203 let had = subs.iter().any(|e| e.id == id);
204 if had {
205 chans.push(name.clone());
206 }
207 subs.retain(|e| e.id != id);
208 !subs.is_empty()
209 });
210 self.patterns.retain(|(p, e)| {
211 if e.id == id {
212 pats.push(p.clone());
213 false
214 } else {
215 true
216 }
217 });
218 (chans, pats)
219 }
220}
221
222#[allow(missing_debug_implementations)]
229pub struct Subscription {
230 inner: Arc<Mutex<Inner>>,
231 _guard: Arc<crate::store::DropGuard>,
235 receiver: Receiver<PubsubFrame>,
236 sender: Sender<PubsubFrame>,
237 id: u64,
238 channels: HashSet<Vec<u8>>,
239 patterns: HashSet<Vec<u8>>,
240}
241
242impl Subscription {
243 pub(crate) fn new(inner: Arc<Mutex<Inner>>, guard: Arc<crate::store::DropGuard>) -> Self {
244 let (sender, receiver) = channel();
245 let id = inner
246 .lock()
247 .unwrap_or_else(|p| p.into_inner())
248 .bus
249 .alloc_id();
250 Self {
251 inner,
252 _guard: guard,
253 receiver,
254 sender,
255 id,
256 channels: HashSet::new(),
257 patterns: HashSet::new(),
258 }
259 }
260
261 pub fn subscribe(&mut self, channels: &[&[u8]]) {
264 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
265 for ch in channels {
266 let owned = ch.to_vec();
267 let added = g.bus.add_channel(self.id, &self.sender, owned.clone());
268 if added {
269 self.channels.insert(owned.clone());
270 }
271 let count = g.bus.count_for(self.id);
272 let _ = self.sender.send(PubsubFrame::Subscribe {
273 channel: owned,
274 count,
275 });
276 }
277 }
278
279 pub fn psubscribe(&mut self, patterns: &[&[u8]]) {
282 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
283 for pat in patterns {
284 let owned = pat.to_vec();
285 let added = g.bus.add_pattern(self.id, &self.sender, owned.clone());
286 if added {
287 self.patterns.insert(owned.clone());
288 }
289 let count = g.bus.count_for(self.id);
290 let _ = self.sender.send(PubsubFrame::Psubscribe {
291 pattern: owned,
292 count,
293 });
294 }
295 }
296
297 pub fn unsubscribe(&mut self, channels: &[&[u8]]) {
302 if channels.is_empty() {
303 self.drain_channel_subs();
304 return;
305 }
306 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
307 for ch in channels {
308 let owned = ch.to_vec();
309 let _ = g.bus.remove_channel(self.id, &owned);
310 self.channels.remove(&owned);
311 let count = g.bus.count_for(self.id);
312 let _ = self.sender.send(PubsubFrame::Unsubscribe {
313 channel: Some(owned),
314 count,
315 });
316 }
317 }
318
319 pub fn punsubscribe(&mut self, patterns: &[&[u8]]) {
321 if patterns.is_empty() {
322 self.drain_pattern_subs();
323 return;
324 }
325 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
326 for pat in patterns {
327 let owned = pat.to_vec();
328 let _ = g.bus.remove_pattern(self.id, &owned);
329 self.patterns.remove(&owned);
330 let count = g.bus.count_for(self.id);
331 let _ = self.sender.send(PubsubFrame::Punsubscribe {
332 pattern: Some(owned),
333 count,
334 });
335 }
336 }
337
338 fn drain_channel_subs(&mut self) {
339 let owned: Vec<Vec<u8>> = self.channels.drain().collect();
340 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
341 if owned.is_empty() {
342 let count = g.bus.count_for(self.id);
343 let _ = self
344 .sender
345 .send(PubsubFrame::Unsubscribe { channel: None, count });
346 return;
347 }
348 for ch in owned {
349 let _ = g.bus.remove_channel(self.id, &ch);
350 let count = g.bus.count_for(self.id);
351 let _ = self.sender.send(PubsubFrame::Unsubscribe {
352 channel: Some(ch),
353 count,
354 });
355 }
356 }
357
358 fn drain_pattern_subs(&mut self) {
359 let owned: Vec<Vec<u8>> = self.patterns.drain().collect();
360 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
361 if owned.is_empty() {
362 let count = g.bus.count_for(self.id);
363 let _ = self
364 .sender
365 .send(PubsubFrame::Punsubscribe { pattern: None, count });
366 return;
367 }
368 for p in owned {
369 let _ = g.bus.remove_pattern(self.id, &p);
370 let count = g.bus.count_for(self.id);
371 let _ = self.sender.send(PubsubFrame::Punsubscribe {
372 pattern: Some(p),
373 count,
374 });
375 }
376 }
377
378 pub fn recv(&self) -> io::Result<PubsubFrame> {
381 self.receiver
382 .recv()
383 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
384 }
385
386 pub fn recv_timeout(&self, dur: Duration) -> io::Result<PubsubFrame> {
389 self.receiver.recv_timeout(dur).map_err(|e| match e {
390 RecvTimeoutError::Timeout => io::Error::from(io::ErrorKind::TimedOut),
391 RecvTimeoutError::Disconnected => {
392 io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed")
393 }
394 })
395 }
396
397 pub fn try_recv(&self) -> io::Result<Option<PubsubFrame>> {
400 match self.receiver.try_recv() {
401 Ok(f) => Ok(Some(f)),
402 Err(TryRecvError::Empty) => Ok(None),
403 Err(TryRecvError::Disconnected) => {
404 Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
405 }
406 }
407 }
408}
409
410impl std::fmt::Debug for Subscription {
411 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412 f.debug_struct("Subscription")
413 .field("id", &self.id)
414 .field("channels", &self.channels.len())
415 .field("patterns", &self.patterns.len())
416 .finish_non_exhaustive()
417 }
418}
419
420impl Drop for Subscription {
421 fn drop(&mut self) {
422 if let Ok(mut g) = self.inner.lock() {
425 g.bus.remove_all_for(self.id);
426 } else if let Ok(mut g) = self.inner.clear_poison_and_lock() {
427 g.bus.remove_all_for(self.id);
432 }
433 }
434}
435
436trait LockExt<'a, T> {
440 fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>>;
441}
442
443impl<'a, T> LockExt<'a, T> for Mutex<T> {
444 fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>> {
445 self.clear_poison();
446 self.lock()
447 }
448}
449
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Config, Store};

    /// Opens a fresh store with the TTL reaper under manual control.
    fn store() -> Store {
        Store::open(Config::default().with_ttl_reaper_manual()).unwrap()
    }

    #[test]
    fn publish_to_no_subscribers_returns_zero() {
        let s = store();
        assert_eq!(s.publish(b"chan", b"hi"), 0);
    }

    #[test]
    fn subscribe_ack_then_message_delivered() {
        let s = store();
        let sub = s.subscribe(&[b"news"]);
        // The subscribe ack is queued before any published message.
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Subscribe {
                channel: b"news".to_vec(),
                count: 1,
            }
        );
        assert_eq!(s.publish(b"news", b"hello"), 1);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Message {
                channel: b"news".to_vec(),
                payload: b"hello".to_vec(),
            }
        );
    }

    #[test]
    fn store_clone_publishes_reach_other_clones_subscribers() {
        let s1 = store();
        let s2 = s1.clone();
        let sub = s1.subscribe(&[b"x"]);
        let _ = sub.recv().unwrap(); // subscribe ack
        assert_eq!(s2.publish(b"x", b"v"), 1);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Message {
                channel: b"x".to_vec(),
                payload: b"v".to_vec(),
            }
        );
    }

    #[test]
    fn psubscribe_glob_match_delivers_pmessage() {
        let s = store();
        let sub = s.psubscribe(&[b"news.*"]);
        let _ = sub.recv().unwrap(); // psubscribe ack
        assert_eq!(s.publish(b"news.tech", b"breaking"), 1);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Pmessage {
                pattern: b"news.*".to_vec(),
                channel: b"news.tech".to_vec(),
                payload: b"breaking".to_vec(),
            }
        );
        // A non-matching channel reaches nobody and leaves the queue empty.
        assert_eq!(s.publish(b"weather", b"sunny"), 0);
        assert!(sub.try_recv().unwrap().is_none());
    }

    #[test]
    fn duplicate_subscribe_does_not_duplicate_delivery() {
        let s = store();
        let mut sub = s.subscribe(&[b"x"]);
        sub.subscribe(&[b"x"]);
        // Both calls are acked, but the subscription count stays at one.
        let a1 = sub.recv().unwrap();
        let a2 = sub.recv().unwrap();
        assert!(matches!(a1, PubsubFrame::Subscribe { count: 1, .. }));
        assert!(matches!(a2, PubsubFrame::Subscribe { count: 1, .. }));
        assert_eq!(s.publish(b"x", b"v"), 1);
        let _ = sub.recv().unwrap();
        assert!(sub.try_recv().unwrap().is_none());
    }

    #[test]
    fn unsubscribe_removes_then_no_more_messages() {
        let s = store();
        let mut sub = s.subscribe(&[b"x"]);
        let _ = sub.recv().unwrap();
        sub.unsubscribe(&[b"x"]);
        assert!(matches!(
            sub.recv().unwrap(),
            PubsubFrame::Unsubscribe {
                channel: Some(_),
                count: 0
            }
        ));
        assert_eq!(s.publish(b"x", b"v"), 0);
    }

    #[test]
    fn unsubscribe_all_with_empty_args_drains_every_channel() {
        let s = store();
        let mut sub = s.subscribe(&[b"a", b"b"]);
        let _ = sub.recv().unwrap();
        let _ = sub.recv().unwrap();
        sub.unsubscribe(&[]);
        // One ack per previously held channel, in unspecified order.
        for _ in 0..2 {
            assert!(matches!(
                sub.recv().unwrap(),
                PubsubFrame::Unsubscribe {
                    channel: Some(_),
                    ..
                }
            ));
        }
        assert_eq!(s.publish(b"a", b"x"), 0);
        assert_eq!(s.publish(b"b", b"x"), 0);
    }

    #[test]
    fn unsubscribe_when_no_subs_held_emits_nil_channel_ack() {
        let s = store();
        let mut sub = s.subscribe(&[]);
        sub.unsubscribe(&[]);
        assert!(matches!(
            sub.recv().unwrap(),
            PubsubFrame::Unsubscribe {
                channel: None,
                count: 0
            }
        ));
    }

    #[test]
    fn psubscribe_then_punsubscribe_reports_combined_counts() {
        let s = store();
        let mut sub = s.subscribe(&[b"a"]);
        let _ = sub.recv().unwrap(); // subscribe ack, count 1
        sub.psubscribe(&[b"a*"]);
        // Counts cover channels and patterns together.
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Psubscribe {
                pattern: b"a*".to_vec(),
                count: 2,
            }
        );
        sub.punsubscribe(&[b"a*"]);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Punsubscribe {
                pattern: Some(b"a*".to_vec()),
                count: 1,
            }
        );
        // No patterns left: punsubscribe-all yields a single `None` ack.
        sub.punsubscribe(&[]);
        assert_eq!(
            sub.recv().unwrap(),
            PubsubFrame::Punsubscribe {
                pattern: None,
                count: 1,
            }
        );
    }

    #[test]
    fn drop_subscriber_unregisters() {
        let s = store();
        let sub = s.subscribe(&[b"x"]);
        let _ = sub.recv().unwrap();
        assert_eq!(s.publish(b"x", b"v"), 1);
        let _ = sub.recv().unwrap();
        drop(sub);
        assert_eq!(s.publish(b"x", b"v"), 0);
    }

    #[test]
    fn recv_timeout_returns_timeout_when_empty() {
        let s = store();
        let sub = s.subscribe(&[b"x"]);
        // Consume the subscribe ack first.
        let _ = sub.recv_timeout(Duration::from_millis(100)).unwrap();
        let err = sub.recv_timeout(Duration::from_millis(50)).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::TimedOut);
    }
}