use std::collections::HashSet;
use std::io;
use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, TryRecvError, channel};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use crate::store::Inner;
/// One frame delivered to a [`Subscription`]'s receiver.
///
/// The four confirmation variants are queued by the subscription's own
/// (un)subscribe methods. `Message` and `Pmessage` are not constructed in this
/// module; presumably the bus emits them on publish — TODO confirm.
///
/// Where this module builds a confirmation, `count` is the bus's `count_for`
/// value for the subscriber at the moment the frame is queued (presumably
/// channels + patterns, Redis-style — confirm against `PubsubBus`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PubsubFrame {
    /// Confirmation for subscribing to `channel`.
    Subscribe { channel: Vec<u8>, count: usize },
    /// Confirmation for subscribing to `pattern`.
    Psubscribe { pattern: Vec<u8>, count: usize },
    /// Confirmation for unsubscribing; this module only emits `None` when an
    /// unsubscribe-all found no channel subscriptions to remove.
    Unsubscribe { channel: Option<Vec<u8>>, count: usize },
    /// Confirmation for pattern-unsubscribing; this module only emits `None`
    /// when a punsubscribe-all found no pattern subscriptions to remove.
    Punsubscribe { pattern: Option<Vec<u8>>, count: usize },
    /// Presumably a payload published directly to `channel`.
    Message { channel: Vec<u8>, payload: Vec<u8> },
    /// Presumably a payload published to `channel` and delivered because it
    /// matched `pattern`.
    Pmessage { pattern: Vec<u8>, channel: Vec<u8>, payload: Vec<u8> },
}
pub(crate) use crate::pubsub_bus::PubsubBus;
/// A handle for receiving pub/sub frames from the store's bus.
///
/// Dropping the handle deregisters its subscriber id from the bus.
pub struct Subscription {
    /// Shared store state; its `bus` is where this subscriber is registered.
    inner: Arc<Mutex<Inner>>,
    /// Never read; holding this `Arc` keeps the shared `DropGuard` alive for
    /// at least as long as this subscription.
    _guard: Arc<crate::store::DropGuard>,
    /// Receiving half of this subscriber's frame channel. `Receiver` is not
    /// `Sync`, so the `Mutex` lets the `recv*` methods be used through `&self`.
    receiver: Mutex<Receiver<PubsubFrame>>,
    /// Sending half; clones are handed to the bus on (p)subscribe and used to
    /// queue confirmation frames. Holding it here keeps the channel connected.
    sender: Mutex<Sender<PubsubFrame>>,
    /// Subscriber id allocated by the bus.
    id: u64,
    /// Channels this handle has subscribed to (drives unsubscribe-all).
    channels: HashSet<Vec<u8>>,
    /// Patterns this handle has subscribed to (drives punsubscribe-all).
    patterns: HashSet<Vec<u8>>,
}
impl Subscription {
    /// Creates an empty subscription bound to the shared store `inner`.
    ///
    /// A subscriber id is taken from the bus via `alloc_id`, and a private
    /// mpsc channel is created. The receiving half backs [`recv`](Self::recv),
    /// [`recv_timeout`](Self::recv_timeout) and [`try_recv`](Self::try_recv).
    /// The sending half is kept here, and clones of it are passed to the bus on
    /// every (p)subscribe. `guard` is only stored, never read.
    pub(crate) fn new(inner: Arc<Mutex<Inner>>, guard: Arc<crate::store::DropGuard>) -> Self {
        let (sender, receiver) = channel();
        // Poisoning of the store mutex is tolerated here, as everywhere in this type.
        let id = inner
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .bus
            .alloc_id();
        Self {
            inner,
            _guard: guard,
            receiver: Mutex::new(receiver),
            sender: Mutex::new(sender),
            id,
            channels: HashSet::new(),
            patterns: HashSet::new(),
        }
    }

    /// Returns a clone of this subscription's frame sender.
    fn sender_clone(&self) -> Sender<PubsubFrame> {
        self.sender
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .clone()
    }

    /// Subscribes to every channel in `channels`, in order.
    ///
    /// For each name, a [`PubsubFrame::Subscribe`] confirmation is queued on
    /// this subscription's own receiver. It carries the bus's `count_for` value
    /// after the add. A confirmation is queued even when `add_channel` returns
    /// `false`. The store lock is held for the whole batch. An empty slice is
    /// a no-op.
    pub fn subscribe(&mut self, channels: &[&[u8]]) {
        let s = self.sender_clone();
        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for ch in channels {
            let owned = ch.to_vec();
            let added = g.bus.add_channel(self.id, &s, owned.clone());
            if added {
                self.channels.insert(owned.clone());
            }
            let count = g.bus.count_for(self.id);
            // Cannot fail: `self.receiver` keeps the receiving half alive.
            let _ = s.send(PubsubFrame::Subscribe {
                channel: owned,
                count,
            });
        }
    }

    /// Pattern counterpart of [`subscribe`](Self::subscribe).
    ///
    /// Queues one [`PubsubFrame::Psubscribe`] per entry in `patterns`.
    pub fn psubscribe(&mut self, patterns: &[&[u8]]) {
        let s = self.sender_clone();
        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for pat in patterns {
            let owned = pat.to_vec();
            let added = g.bus.add_pattern(self.id, &s, owned.clone());
            if added {
                self.patterns.insert(owned.clone());
            }
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Psubscribe {
                pattern: owned,
                count,
            });
        }
    }

    /// Unsubscribes from every channel in `channels`, in order.
    ///
    /// An empty slice means "every channel this handle subscribed to". In
    /// that case, when none are subscribed, a single
    /// `Unsubscribe { channel: None, .. }` is queued. Otherwise one
    /// `Unsubscribe { channel: Some(name), .. }` is queued per entry, whether
    /// or not it was subscribed.
    pub fn unsubscribe(&mut self, channels: &[&[u8]]) {
        if channels.is_empty() {
            self.drain_channel_subs();
            return;
        }
        let s = self.sender_clone();
        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for ch in channels {
            let owned = ch.to_vec();
            let _ = g.bus.remove_channel(self.id, &owned);
            self.channels.remove(&owned);
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Unsubscribe {
                channel: Some(owned),
                count,
            });
        }
    }

    /// Pattern counterpart of [`unsubscribe`](Self::unsubscribe).
    ///
    /// An empty slice removes every pattern this handle subscribed to.
    pub fn punsubscribe(&mut self, patterns: &[&[u8]]) {
        if patterns.is_empty() {
            self.drain_pattern_subs();
            return;
        }
        let s = self.sender_clone();
        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for pat in patterns {
            let owned = pat.to_vec();
            let _ = g.bus.remove_pattern(self.id, &owned);
            self.patterns.remove(&owned);
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Punsubscribe {
                pattern: Some(owned),
                count,
            });
        }
    }

    /// Removes every channel recorded in `self.channels` from the bus.
    ///
    /// Queues one `Unsubscribe { channel: Some(..) }` per channel, in `HashSet`
    /// iteration order (unspecified). When nothing was recorded, a single
    /// `Unsubscribe { channel: None, .. }` is queued instead, so the call
    /// always produces at least one reply.
    fn drain_channel_subs(&mut self) {
        let s = self.sender_clone();
        let owned: Vec<Vec<u8>> = self.channels.drain().collect();
        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        if owned.is_empty() {
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Unsubscribe { channel: None, count });
            return;
        }
        for ch in owned {
            let _ = g.bus.remove_channel(self.id, &ch);
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Unsubscribe {
                channel: Some(ch),
                count,
            });
        }
    }

    /// Pattern counterpart of `drain_channel_subs`: removes every pattern in
    /// `self.patterns` and queues `Punsubscribe` confirmations. A single
    /// `pattern: None` frame is queued when there were none.
    fn drain_pattern_subs(&mut self) {
        let s = self.sender_clone();
        let owned: Vec<Vec<u8>> = self.patterns.drain().collect();
        let mut g = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        if owned.is_empty() {
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Punsubscribe { pattern: None, count });
            return;
        }
        for p in owned {
            let _ = g.bus.remove_pattern(self.id, &p);
            let count = g.bus.count_for(self.id);
            let _ = s.send(PubsubFrame::Punsubscribe {
                pattern: Some(p),
                count,
            });
        }
    }

    /// Blocks until the next frame arrives and returns it.
    ///
    /// The receiver lock is held while waiting, so concurrent callers take
    /// turns.
    ///
    /// # Errors
    /// `UnexpectedEof` ("bus closed") if every sender has been dropped.
    ///
    /// NOTE(review): `self.sender` keeps one sender alive as long as `self`
    /// exists, so that error cannot occur through this handle. With nothing
    /// subscribed, this call may block indefinitely.
    pub fn recv(&self) -> io::Result<PubsubFrame> {
        let g = self.receiver.lock().unwrap_or_else(|p| p.into_inner());
        g.recv()
            .map_err(|_| io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
    }

    /// Like [`recv`](Self::recv) but gives up after `dur`.
    ///
    /// # Errors
    /// `TimedOut` when no frame arrives within `dur`. `UnexpectedEof`
    /// ("bus closed") if every sender has been dropped (see the note on `recv`).
    pub fn recv_timeout(&self, dur: Duration) -> io::Result<PubsubFrame> {
        let g = self.receiver.lock().unwrap_or_else(|p| p.into_inner());
        g.recv_timeout(dur).map_err(|e| match e {
            RecvTimeoutError::Timeout => io::Error::from(io::ErrorKind::TimedOut),
            RecvTimeoutError::Disconnected => {
                io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed")
            }
        })
    }

    /// Returns the next frame if one is immediately available.
    ///
    /// Returns `Ok(None)` when the queue is empty or when another thread is
    /// currently holding the receiver (e.g. blocked in `recv`). A poisoned
    /// receiver mutex is recovered from, matching `recv`/`recv_timeout`.
    ///
    /// # Errors
    /// `UnexpectedEof` ("bus closed") if every sender has been dropped.
    pub fn try_recv(&self) -> io::Result<Option<PubsubFrame>> {
        let g = match self.receiver.try_lock() {
            Ok(g) => g,
            // Treating poison as "busy" would report `None` forever and stall
            // delivery; recover the guard instead, as the blocking variants do.
            Err(std::sync::TryLockError::Poisoned(p)) => p.into_inner(),
            Err(std::sync::TryLockError::WouldBlock) => return Ok(None),
        };
        match g.try_recv() {
            Ok(f) => Ok(Some(f)),
            Err(TryRecvError::Empty) => Ok(None),
            Err(TryRecvError::Disconnected) => {
                Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
            }
        }
    }
}
impl std::fmt::Debug for Subscription {
    /// Shows the subscriber id and how many channels/patterns are held,
    /// without listing their names or touching the shared store.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let channel_total = self.channels.len();
        let pattern_total = self.patterns.len();
        let mut builder = f.debug_struct("Subscription");
        builder.field("id", &self.id);
        builder.field("channels", &channel_total);
        builder.field("patterns", &pattern_total);
        builder.finish_non_exhaustive()
    }
}
impl Drop for Subscription {
    /// Deregisters this subscriber from the bus via `remove_all_for`.
    ///
    /// If the store mutex is poisoned, the poison flag is cleared and the
    /// deregistration still happens. It is never skipped.
    fn drop(&mut self) {
        let mut guard = match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => {
                // The `Err` still owns the (poisoned) guard. It must be
                // released before locking again on this thread, or the
                // re-lock would deadlock or panic inside `drop`.
                std::mem::drop(poisoned);
                self.inner
                    .clear_poison_and_lock()
                    .unwrap_or_else(|p| p.into_inner())
            }
        };
        guard.bus.remove_all_for(self.id);
    }
}
/// Locking helper that resets a mutex's poison flag before acquiring it.
trait LockExt<'a, T> {
    /// Clears poisoning on `self`, then locks it.
    ///
    /// The result can still be `Err` if another thread panics while holding
    /// the lock between the two steps. Must not be called while this thread
    /// already holds the lock: the inner `lock` would then deadlock or panic.
    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>>;
}

impl<'a, T> LockExt<'a, T> for Mutex<T> {
    fn clear_poison_and_lock(&'a self) -> std::sync::LockResult<std::sync::MutexGuard<'a, T>> {
        Mutex::clear_poison(self);
        Mutex::lock(self)
    }
}
#[cfg(test)]
#[path = "pubsub_tests.rs"]
mod tests;