use std::collections::{HashMap, HashSet};
use std::io;
use std::sync::mpsc::{Receiver, RecvTimeoutError, Sender, TryRecvError, channel};
use std::sync::{Arc, Mutex};
use std::time::Duration;
use kevy_store::glob_match;
use crate::store::Inner;
/// A frame delivered to a [`Subscription`].
///
/// The confirmation variants (`Subscribe`, `Psubscribe`, `Unsubscribe`,
/// `Punsubscribe`) carry `count`: the subscriber's total number of channel
/// and pattern registrations immediately after the operation.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum PubsubFrame {
    /// Confirmation for one requested channel (queued even if it was
    /// already subscribed).
    Subscribe {
        /// The channel named in the request.
        channel: Vec<u8>,
        /// Total registrations after the request.
        count: usize,
    },
    /// Confirmation for one requested glob pattern (queued even if it was
    /// already subscribed).
    Psubscribe {
        /// The pattern named in the request.
        pattern: Vec<u8>,
        /// Total registrations after the request.
        count: usize,
    },
    /// Confirmation of a channel unsubscribe; `channel` is `None` when an
    /// "unsubscribe from everything" request found nothing to remove.
    Unsubscribe {
        /// The channel that was removed, if any.
        channel: Option<Vec<u8>>,
        /// Total registrations after the request.
        count: usize,
    },
    /// Confirmation of a pattern unsubscribe; `pattern` is `None` when an
    /// "unsubscribe from everything" request found nothing to remove.
    Punsubscribe {
        /// The pattern that was removed, if any.
        pattern: Option<Vec<u8>>,
        /// Total registrations after the request.
        count: usize,
    },
    /// A message published to a channel this subscriber named exactly.
    Message {
        /// Channel the message was published to.
        channel: Vec<u8>,
        /// Published bytes.
        payload: Vec<u8>,
    },
    /// A message published to a channel matched by one of this
    /// subscriber's patterns.
    Pmessage {
        /// The subscribed pattern that matched.
        pattern: Vec<u8>,
        /// Channel the message was published to.
        channel: Vec<u8>,
        /// Published bytes.
        payload: Vec<u8>,
    },
}
/// One subscriber's registration on a channel or a pattern.
struct BusEntry {
    /// Subscriber id (the value a `Subscription` obtained from `PubsubBus::alloc_id`).
    id: u64,
    /// Destination for frames addressed to this subscriber.
    sender: Sender<PubsubFrame>,
}
/// Registry of live subscribers, indexed by exact channel name and by glob
/// pattern.
pub(crate) struct PubsubBus {
    /// Id that the next `alloc_id` call returns; starts at 1 and skips 0.
    next_id: u64,
    /// Exact-channel subscribers. A channel's key is removed as soon as its
    /// subscriber list becomes empty.
    channels: HashMap<Vec<u8>, Vec<BusEntry>>,
    /// Pattern subscribers in registration order, matched with `glob_match`.
    patterns: Vec<(Vec<u8>, BusEntry)>,
}
impl PubsubBus {
pub(crate) fn new() -> Self {
Self {
next_id: 1,
channels: HashMap::new(),
patterns: Vec::new(),
}
}
fn alloc_id(&mut self) -> u64 {
let id = self.next_id;
self.next_id = id.wrapping_add(1).max(1);
id
}
fn count_for(&self, id: u64) -> usize {
let chans = self
.channels
.values()
.filter(|v| v.iter().any(|e| e.id == id))
.count();
let pats = self.patterns.iter().filter(|(_, e)| e.id == id).count();
chans + pats
}
pub(crate) fn collect_delivery(
&self,
channel: &[u8],
payload: &[u8],
) -> Vec<(PubsubFrame, Sender<PubsubFrame>)> {
let mut plans = Vec::new();
if let Some(subs) = self.channels.get(channel) {
for e in subs {
plans.push((
PubsubFrame::Message {
channel: channel.to_vec(),
payload: payload.to_vec(),
},
e.sender.clone(),
));
}
}
for (pat, e) in &self.patterns {
if glob_match(pat, channel) {
plans.push((
PubsubFrame::Pmessage {
pattern: pat.clone(),
channel: channel.to_vec(),
payload: payload.to_vec(),
},
e.sender.clone(),
));
}
}
plans
}
fn add_channel(&mut self, id: u64, sender: &Sender<PubsubFrame>, channel: Vec<u8>) -> bool {
let subs = self.channels.entry(channel).or_default();
if subs.iter().any(|e| e.id == id) {
return false;
}
subs.push(BusEntry {
id,
sender: sender.clone(),
});
true
}
fn add_pattern(&mut self, id: u64, sender: &Sender<PubsubFrame>, pattern: Vec<u8>) -> bool {
if self
.patterns
.iter()
.any(|(p, e)| p == &pattern && e.id == id)
{
return false;
}
self.patterns.push((
pattern,
BusEntry {
id,
sender: sender.clone(),
},
));
true
}
fn remove_channel(&mut self, id: u64, channel: &[u8]) -> bool {
if let Some(subs) = self.channels.get_mut(channel) {
let before = subs.len();
subs.retain(|e| e.id != id);
let removed = subs.len() < before;
if subs.is_empty() {
self.channels.remove(channel);
}
removed
} else {
false
}
}
fn remove_pattern(&mut self, id: u64, pattern: &[u8]) -> bool {
let before = self.patterns.len();
self.patterns.retain(|(p, e)| !(p == pattern && e.id == id));
self.patterns.len() < before
}
fn remove_all_for(&mut self, id: u64) -> (Vec<Vec<u8>>, Vec<Vec<u8>>) {
let mut chans = Vec::new();
let mut pats = Vec::new();
self.channels.retain(|name, subs| {
let had = subs.iter().any(|e| e.id == id);
if had {
chans.push(name.clone());
}
subs.retain(|e| e.id != id);
!subs.is_empty()
});
self.patterns.retain(|(p, e)| {
if e.id == id {
pats.push(p.clone());
false
} else {
true
}
});
(chans, pats)
}
}
/// A subscriber handle on the store's pub/sub bus.
///
/// Confirmation frames and delivered messages are queued on a private
/// `mpsc` channel and read with [`Subscription::recv`],
/// [`Subscription::recv_timeout`] or [`Subscription::try_recv`]. Dropping the
/// handle removes all of its bus registrations.
pub struct Subscription {
    /// Shared store state; its `bus` field holds this handle's registrations.
    inner: Arc<Mutex<Inner>>,
    /// Reference to the store's `DropGuard`, held for this handle's lifetime
    /// (its exact role is defined in `crate::store`).
    _guard: Arc<crate::store::DropGuard>,
    /// Receiving end of this handle's frame channel.
    receiver: Receiver<PubsubFrame>,
    /// Sending end. Cloned into every bus entry this handle registers and
    /// also used directly for confirmation frames.
    sender: Sender<PubsubFrame>,
    /// Bus id obtained from `PubsubBus::alloc_id` at construction.
    id: u64,
    /// Channels this handle currently has registered on the bus.
    channels: HashSet<Vec<u8>>,
    /// Patterns this handle currently has registered on the bus.
    patterns: HashSet<Vec<u8>>,
}
impl Subscription {
    /// Creates an empty subscription bound to the given store.
    ///
    /// A private frame channel is created and a fresh bus id is allocated
    /// under the store lock. No channels or patterns are registered yet.
    pub(crate) fn new(inner: Arc<Mutex<Inner>>, guard: Arc<crate::store::DropGuard>) -> Self {
        let (sender, receiver) = channel();
        let id = {
            let mut store = inner.lock().unwrap_or_else(|p| p.into_inner());
            store.bus.alloc_id()
        };
        Self {
            id,
            sender,
            receiver,
            inner,
            _guard: guard,
            channels: HashSet::new(),
            patterns: HashSet::new(),
        }
    }

    /// Registers this handle on each of `channels`.
    ///
    /// One `Subscribe` frame is queued per entry, including entries that were
    /// already subscribed. Each frame carries the handle's total registration
    /// count at that point. The store lock is held for the whole batch.
    pub fn subscribe(&mut self, channels: &[&[u8]]) {
        let mut store = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for &name in channels {
            let name = name.to_vec();
            if store.bus.add_channel(self.id, &self.sender, name.clone()) {
                self.channels.insert(name.clone());
            }
            let frame = PubsubFrame::Subscribe {
                count: store.bus.count_for(self.id),
                channel: name,
            };
            let _ = self.sender.send(frame);
        }
    }

    /// Registers this handle on each glob pattern in `patterns`.
    ///
    /// One `Psubscribe` frame is queued per entry, including duplicates. Each
    /// frame carries the handle's total registration count at that point.
    pub fn psubscribe(&mut self, patterns: &[&[u8]]) {
        let mut store = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for &pat in patterns {
            let pat = pat.to_vec();
            if store.bus.add_pattern(self.id, &self.sender, pat.clone()) {
                self.patterns.insert(pat.clone());
            }
            let frame = PubsubFrame::Psubscribe {
                count: store.bus.count_for(self.id),
                pattern: pat,
            };
            let _ = self.sender.send(frame);
        }
    }

    /// Removes this handle from each of `channels`. An empty slice means
    /// every channel currently subscribed.
    ///
    /// One `Unsubscribe` frame is queued per named entry, even for channels
    /// that were not subscribed.
    pub fn unsubscribe(&mut self, channels: &[&[u8]]) {
        if channels.is_empty() {
            self.drain_channel_subs();
            return;
        }
        let mut store = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for &name in channels {
            let name = name.to_vec();
            let _ = store.bus.remove_channel(self.id, &name);
            self.channels.remove(&name);
            let frame = PubsubFrame::Unsubscribe {
                count: store.bus.count_for(self.id),
                channel: Some(name),
            };
            let _ = self.sender.send(frame);
        }
    }

    /// Removes this handle from each pattern in `patterns`. An empty slice
    /// means every pattern currently subscribed.
    ///
    /// One `Punsubscribe` frame is queued per named entry, even for patterns
    /// that were not subscribed.
    pub fn punsubscribe(&mut self, patterns: &[&[u8]]) {
        if patterns.is_empty() {
            self.drain_pattern_subs();
            return;
        }
        let mut store = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        for &pat in patterns {
            let pat = pat.to_vec();
            let _ = store.bus.remove_pattern(self.id, &pat);
            self.patterns.remove(&pat);
            let frame = PubsubFrame::Punsubscribe {
                count: store.bus.count_for(self.id),
                pattern: Some(pat),
            };
            let _ = self.sender.send(frame);
        }
    }

    /// Unsubscribes from every channel this handle holds.
    ///
    /// Queues one `Unsubscribe { channel: Some(..) }` per drained channel, in
    /// `HashSet` iteration order. If nothing was subscribed, it queues a single
    /// `Unsubscribe { channel: None }` instead.
    fn drain_channel_subs(&mut self) {
        let drained: Vec<Vec<u8>> = self.channels.drain().collect();
        let mut store = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        if drained.is_empty() {
            let count = store.bus.count_for(self.id);
            let _ = self
                .sender
                .send(PubsubFrame::Unsubscribe { channel: None, count });
        } else {
            for name in drained {
                let _ = store.bus.remove_channel(self.id, &name);
                let count = store.bus.count_for(self.id);
                let _ = self.sender.send(PubsubFrame::Unsubscribe {
                    channel: Some(name),
                    count,
                });
            }
        }
    }

    /// Unsubscribes from every pattern this handle holds.
    ///
    /// Queues one `Punsubscribe { pattern: Some(..) }` per drained pattern. If
    /// nothing was subscribed, it queues a single
    /// `Punsubscribe { pattern: None }` instead.
    fn drain_pattern_subs(&mut self) {
        let drained: Vec<Vec<u8>> = self.patterns.drain().collect();
        let mut store = self.inner.lock().unwrap_or_else(|p| p.into_inner());
        if drained.is_empty() {
            let count = store.bus.count_for(self.id);
            let _ = self
                .sender
                .send(PubsubFrame::Punsubscribe { pattern: None, count });
        } else {
            for pat in drained {
                let _ = store.bus.remove_pattern(self.id, &pat);
                let count = store.bus.count_for(self.id);
                let _ = self.sender.send(PubsubFrame::Punsubscribe {
                    pattern: Some(pat),
                    count,
                });
            }
        }
    }

    /// Blocks until the next frame arrives.
    ///
    /// # Errors
    /// `UnexpectedEof` ("bus closed") if the channel is disconnected. The
    /// handle itself owns a `Sender` for this channel, so that cannot happen
    /// while `self` is alive.
    pub fn recv(&self) -> io::Result<PubsubFrame> {
        match self.receiver.recv() {
            Ok(frame) => Ok(frame),
            Err(_) => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed")),
        }
    }

    /// Waits up to `dur` for the next frame.
    ///
    /// # Errors
    /// `TimedOut` if no frame arrived in time, or `UnexpectedEof` ("bus
    /// closed") if the channel is disconnected.
    pub fn recv_timeout(&self, dur: Duration) -> io::Result<PubsubFrame> {
        match self.receiver.recv_timeout(dur) {
            Ok(frame) => Ok(frame),
            Err(RecvTimeoutError::Timeout) => Err(io::ErrorKind::TimedOut.into()),
            Err(RecvTimeoutError::Disconnected) => {
                Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
            }
        }
    }

    /// Returns the next queued frame without blocking, or `Ok(None)` if the
    /// queue is empty.
    ///
    /// # Errors
    /// `UnexpectedEof` ("bus closed") if the channel is disconnected.
    pub fn try_recv(&self) -> io::Result<Option<PubsubFrame>> {
        self.receiver.try_recv().map(Some).or_else(|err| match err {
            TryRecvError::Empty => Ok(None),
            TryRecvError::Disconnected => {
                Err(io::Error::new(io::ErrorKind::UnexpectedEof, "bus closed"))
            }
        })
    }
}
impl std::fmt::Debug for Subscription {
    /// Shows the bus id and the number of subscribed channels and patterns,
    /// without printing their contents or the shared store.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Subscription");
        out.field("id", &self.id);
        out.field("channels", &self.channels.len());
        out.field("patterns", &self.patterns.len());
        out.finish_non_exhaustive()
    }
}
impl Drop for Subscription {
    /// Removes every channel and pattern registration this handle owns from
    /// the shared bus.
    ///
    /// A poisoned store lock is recovered rather than skipped, so a panic
    /// elsewhere cannot leave this handle's senders behind in the bus.
    fn drop(&mut self) {
        let mut store = match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => {
                // The error still owns a live guard. Release it before
                // re-locking. (Under edition 2021, an `if let … else`
                // scrutinee keeps that guard alive into the else branch, and
                // the second `lock()` deadlocks on this thread's own guard.)
                std::mem::drop(poisoned);
                self.inner
                    .clear_poison_and_lock()
                    .unwrap_or_else(|p| p.into_inner())
            }
        };
        store.bus.remove_all_for(self.id);
    }
}
/// Recovery helper for a poisoned [`Mutex`]: clear the poison flag, then lock.
trait LockExt<'m, T> {
    /// Clears any poison on the mutex, then acquires it.
    ///
    /// The result can still be `Err` if another thread panics while holding
    /// the lock between the clear and the acquire.
    fn clear_poison_and_lock(&'m self) -> std::sync::LockResult<std::sync::MutexGuard<'m, T>>;
}
impl<'m, T> LockExt<'m, T> for Mutex<T> {
    fn clear_poison_and_lock(&'m self) -> std::sync::LockResult<std::sync::MutexGuard<'m, T>> {
        Mutex::clear_poison(self);
        Mutex::lock(self)
    }
}
#[cfg(test)]
#[path = "pubsub_tests.rs"]
mod tests;