1use std::collections::{HashMap, HashSet};
16use std::io;
17use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, TryRecvError, channel};
18use std::sync::{Arc, Mutex};
19use std::time::Duration;
20
21use kevy_store::glob_match;
22
23use crate::store::Inner;
24
/// A frame placed on a subscriber's queue.
///
/// The first four variants acknowledge a change to the subscriber's own
/// registrations; [`PubsubFrame::Message`] and [`PubsubFrame::Pmessage`] carry
/// published payloads.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubFrame {
    /// Acknowledgement of a channel subscribe request.
    Subscribe {
        /// Channel named in the request.
        channel: Vec<u8>,
        /// Registrations (channels plus patterns) the subscriber holds afterwards.
        count: usize,
    },
    /// Acknowledgement of a pattern subscribe request.
    Psubscribe {
        /// Pattern named in the request.
        pattern: Vec<u8>,
        /// Registrations (channels plus patterns) the subscriber holds afterwards.
        count: usize,
    },
    /// Acknowledgement of a channel unsubscribe request.
    Unsubscribe {
        /// Channel that was released; `None` when an "unsubscribe from
        /// everything" request found no channel to release.
        channel: Option<Vec<u8>>,
        /// Registrations (channels plus patterns) the subscriber holds afterwards.
        count: usize,
    },
    /// Acknowledgement of a pattern unsubscribe request.
    Punsubscribe {
        /// Pattern that was released; `None` when an "unsubscribe from
        /// everything" request found no pattern to release.
        pattern: Option<Vec<u8>>,
        /// Registrations (channels plus patterns) the subscriber holds afterwards.
        count: usize,
    },
    /// A payload published on a channel the subscriber registered by exact name.
    Message {
        /// Channel the payload was published on.
        channel: Vec<u8>,
        /// Published bytes.
        payload: Vec<u8>,
    },
    /// A payload published on a channel that matched one of the subscriber's
    /// patterns.
    Pmessage {
        /// The subscriber's pattern that matched.
        pattern: Vec<u8>,
        /// Channel the payload was published on.
        channel: Vec<u8>,
        /// Published bytes.
        payload: Vec<u8>,
    },
}
74
75struct BusEntry {
77 id: u64,
78 sender: Sender<PubsubFrame>,
79}
80
81pub(crate) struct PubsubBus {
83 next_id: u64,
84 channels: HashMap<Vec<u8>, Vec<BusEntry>>,
85 patterns: Vec<(Vec<u8>, BusEntry)>,
86}
87
88impl PubsubBus {
89 pub(crate) fn new() -> Self {
90 Self {
91 next_id: 1,
92 channels: HashMap::new(),
93 patterns: Vec::new(),
94 }
95 }
96
97 fn alloc_id(&mut self) -> u64 {
98 let id = self.next_id;
99 self.next_id = id.wrapping_add(1).max(1);
100 id
101 }
102
103 fn count_for(&self, id: u64) -> usize {
105 let chans = self
106 .channels
107 .values()
108 .filter(|v| v.iter().any(|e| e.id == id))
109 .count();
110 let pats = self.patterns.iter().filter(|(_, e)| e.id == id).count();
111 chans + pats
112 }
113
114 pub(crate) fn collect_delivery(
118 &self,
119 channel: &[u8],
120 payload: &[u8],
121 ) -> Vec<(PubsubFrame, Sender<PubsubFrame>)> {
122 let mut plans = Vec::new();
123 if let Some(subs) = self.channels.get(channel) {
124 for e in subs {
125 plans.push((
126 PubsubFrame::Message {
127 channel: channel.to_vec(),
128 payload: payload.to_vec(),
129 },
130 e.sender.clone(),
131 ));
132 }
133 }
134 for (pat, e) in &self.patterns {
135 if glob_match(pat, channel) {
136 plans.push((
137 PubsubFrame::Pmessage {
138 pattern: pat.clone(),
139 channel: channel.to_vec(),
140 payload: payload.to_vec(),
141 },
142 e.sender.clone(),
143 ));
144 }
145 }
146 plans
147 }
148
149 fn add_channel(&mut self, id: u64, sender: &Sender<PubsubFrame>, channel: Vec<u8>) -> bool {
150 let subs = self.channels.entry(channel).or_default();
151 if subs.iter().any(|e| e.id == id) {
152 return false;
153 }
154 subs.push(BusEntry {
155 id,
156 sender: sender.clone(),
157 });
158 true
159 }
160
161 fn add_pattern(&mut self, id: u64, sender: &Sender<PubsubFrame>, pattern: Vec<u8>) -> bool {
162 if self
163 .patterns
164 .iter()
165 .any(|(p, e)| p == &pattern && e.id == id)
166 {
167 return false;
168 }
169 self.patterns.push((
170 pattern,
171 BusEntry {
172 id,
173 sender: sender.clone(),
174 },
175 ));
176 true
177 }
178
179 fn remove_channel(&mut self, id: u64, channel: &[u8]) -> bool {
180 if let Some(subs) = self.channels.get_mut(channel) {
181 let before = subs.len();
182 subs.retain(|e| e.id != id);
183 let removed = subs.len() < before;
184 if subs.is_empty() {
185 self.channels.remove(channel);
186 }
187 removed
188 } else {
189 false
190 }
191 }
192
193 fn remove_pattern(&mut self, id: u64, pattern: &[u8]) -> bool {
194 let before = self.patterns.len();
195 self.patterns.retain(|(p, e)| !(p == pattern && e.id == id));
196 self.patterns.len() < before
197 }
198
199 fn remove_all_for(&mut self, id: u64) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
200 let mut chans = Vec::new();
201 let mut pats = Vec::new();
202 self.channels.retain(|name, subs| {
203 let had = subs.iter().any(|e| e.id == id);
204 if had {
205 chans.push(name.clone());
206 }
207 subs.retain(|e| e.id != id);
208 !subs.is_empty()
209 });
210 self.patterns.retain(|(p, e)| {
211 if e.id == id {
212 pats.push(p.clone());
213 false
214 } else {
215 true
216 }
217 });
218 (chans, pats)
219 }
220}
221
222#[allow(missing_debug_implementations)]
243pub struct Subscription {
244 inner: Arc<Mutex<Inner>>,
245 _guard: Arc<crate::store::DropGuard>,
249 receiver: Mutex<Receiver<PubsubFrame>>,
255 sender: Mutex<Sender<PubsubFrame>>,
259 id: u64,
260 channels: HashSet<Vec<u8>>,
261 patterns: HashSet<Vec<u8>>,
262}
263
264impl Subscription {
265 pub(crate) fn new(inner: Arc<Mutex<Inner>>, guard: Arc<crate::store::DropGuard>) -> Self {
266 let (sender, receiver) = channel();
267 let id = inner
268 .lock()
269 .unwrap_or_else(|p| p.into_inner())
270 .bus
271 .alloc_id();
272 Self {
273 inner,
274 _guard: guard,
275 receiver: Mutex::new(receiver),
276 sender: Mutex::new(sender),
277 id,
278 channels: HashSet::new(),
279 patterns: HashSet::new(),
280 }
281 }
282
283 fn sender_clone(&self) -> Sender<PubsubFrame> {
287 self.sender
288 .lock()
289 .unwrap_or_else(|p| p.into_inner())
290 .clone()
291 }
292
293 pub fn subscribe(&mut self, channels: &[&[u8]]) {
296 let s = self.sender_clone();
297 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
298 for ch in channels {
299 let owned = ch.to_vec();
300 let added = g.bus.add_channel(self.id, &s, owned.clone());
301 if added {
302 self.channels.insert(owned.clone());
303 }
304 let count = g.bus.count_for(self.id);
305 let _ = s.send(PubsubFrame::Subscribe {
306 channel: owned,
307 count,
308 });
309 }
310 }
311
312 pub fn psubscribe(&mut self, patterns: &[&[u8]]) {
315 let s = self.sender_clone();
316 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
317 for pat in patterns {
318 let owned = pat.to_vec();
319 let added = g.bus.add_pattern(self.id, &s, owned.clone());
320 if added {
321 self.patterns.insert(owned.clone());
322 }
323 let count = g.bus.count_for(self.id);
324 let _ = s.send(PubsubFrame::Psubscribe {
325 pattern: owned,
326 count,
327 });
328 }
329 }
330
331 pub fn unsubscribe(&mut self, channels: &[&[u8]]) {
336 if channels.is_empty() {
337 self.drain_channel_subs();
338 return;
339 }
340 let s = self.sender_clone();
341 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
342 for ch in channels {
343 let owned = ch.to_vec();
344 let _ = g.bus.remove_channel(self.id, &owned);
345 self.channels.remove(&owned);
346 let count = g.bus.count_for(self.id);
347 let _ = s.send(PubsubFrame::Unsubscribe {
348 channel: Some(owned),
349 count,
350 });
351 }
352 }
353
354 pub fn punsubscribe(&mut self, patterns: &[&[u8]]) {
356 if patterns.is_empty() {
357 self.drain_pattern_subs();
358 return;
359 }
360 let s = self.sender_clone();
361 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
362 for pat in patterns {
363 let owned = pat.to_vec();
364 let _ = g.bus.remove_pattern(self.id, &owned);
365 self.patterns.remove(&owned);
366 let count = g.bus.count_for(self.id);
367 let _ = s.send(PubsubFrame::Punsubscribe {
368 pattern: Some(owned),
369 count,
370 });
371 }
372 }
373
374 fn drain_channel_subs(&mut self) {
375 let s = self.sender_clone();
376 let owned: Vec<Vec<u8>> = self.channels.drain().collect();
377 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
378 if owned.is_empty() {
379 let count = g.bus.count_for(self.id);
380 let _ = s.send(PubsubFrame::Unsubscribe { channel: None, count });
381 return;
382 }
383 for ch in owned {
384 let _ = g.bus.remove_channel(self.id, &ch);
385 let count = g.bus.count_for(self.id);
386 let _ = s.send(PubsubFrame::Unsubscribe {
387 channel: Some(ch),
388 count,
389 });
390 }
391 }
392
393 fn drain_pattern_subs(&mut self) {
394 let s = self.sender_clone();
395 let owned: Vec<Vec<u8>> = self.patterns.drain().collect();
396 let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
397 if owned.is_empty() {
398 let count = g.bus.count_for(self.id);
399 let _ = s.send(PubsubFrame::Punsubscribe { pattern: None, count });
400 return;
401 }
402 for p in owned {
403 let _ = g.bus.remove_pattern(self.id, &p);
404 let count = g.bus.count_for(self.id);
405 let _ = s.send(PubsubFrame::Punsubscribe {
406 pattern: Some(p),
407 count,
408 });
409 }
410 }
411
412 pub fn recv(&self) -> io::Result<PubsubFrame> {
420 let g = self.receiver.lock().unwrap_or_else(|p| p.into_inner());
421 g.recv()
422 .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
423 }
424
425 pub fn recv_timeout(&self, dur: Duration) -> io::Result<PubsubFrame> {
428 let g = self.receiver.lock().unwrap_or_else(|p| p.into_inner());
429 g.recv_timeout(dur).map_err(|e| match e {
430 RecvTimeoutError::Timeout => io::Error::from(io::ErrorKind::TimedOut),
431 RecvTimeoutError::Disconnected => {
432 io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed")
433 }
434 })
435 }
436
437 pub fn try_recv(&self) -> io::Result<Option<PubsubFrame>> {
445 let g = match self.receiver.try_lock() {
446 Ok(g) => g,
447 Err(_) => return Ok(None),
448 };
449 match g.try_recv() {
450 Ok(f) => Ok(Some(f)),
451 Err(TryRecvError::Empty) => Ok(None),
452 Err(TryRecvError::Disconnected) => {
453 Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
454 }
455 }
456 }
457}
458
459impl std::fmt::Debug for Subscription {
460 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
461 f.debug_struct("Subscription")
462 .field("id", &self.id)
463 .field("channels", &self.channels.len())
464 .field("patterns", &self.patterns.len())
465 .finish_non_exhaustive()
466 }
467}
468
469impl Drop for Subscription {
470 fn drop(&mut self) {
471 if let Ok(mut g) = self.inner.lock() {
474 g.bus.remove_all_for(self.id);
475 } else if let Ok(mut g) = self.inner.clear_poison_and_lock() {
476 g.bus.remove_all_for(self.id);
481 }
482 }
483}
484
/// Extension for mutexes: clear the poison flag, then acquire the lock.
trait LockExt<'a, T> {
    /// Clears any poison on the mutex and locks it.
    ///
    /// The result can still be `Err` if the mutex is poisoned again between
    /// the two steps.
    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>>;
}

impl<'a, T> LockExt<'a, T> for Mutex<T> {
    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>> {
        Mutex::clear_poison(self);
        Mutex::lock(self)
    }
}
498
499#[cfg(test)]
500#[path = "pubsub_tests.rs"]
501mod tests;