#![cfg(feature = "topic")]
use super::core::{SpmcTopicDispatcher, SubscriberList};
use super::mailbox;
use crate::error::{RecvError, SendError};
use crate::spmc::topic::async_impl::{AsyncTopicReceiver, AsyncTopicSender};
use crate::{CloseError, RecvErrorTimeout, TryRecvError};
use std::borrow::Borrow;
use std::collections::HashSet;
use std::hash::Hash;
use std::mem;
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Weak,
};
use std::time::Duration;
use papaya::Equivalent;
use parking_lot::Mutex;
/// Synchronous publishing handle of a topic channel.
///
/// Every handle shares one dispatcher and carries its own `closed` flag.
#[derive(Debug)]
pub struct TopicSender<K, T>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    T: Send + Clone + 'static,
{
    /// Shared dispatcher that owns the topic -> subscriber-list map and the
    /// receiver counter.
    pub(crate) dispatcher: Arc<SpmcTopicDispatcher<K, T>>,
    /// Set to `true` the first time this handle is closed.
    pub(crate) closed: AtomicBool,
}
impl<K, T> TopicSender<K, T>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    T: Send + Clone + 'static,
{
    /// Publishes `value` under `topic`.
    ///
    /// A clone of `(topic, value)` is delivered to every mailbox that is
    /// listed for `topic` and can still be upgraded. A topic without a
    /// subscriber list is not an error.
    ///
    /// # Errors
    ///
    /// Returns [`SendError::Closed`] when this handle's `closed` flag is set
    /// or when [`is_closed`](Self::is_closed) reports no receivers.
    pub fn send(&self, topic: K, value: T) -> Result<(), SendError> {
        if self.closed.load(Ordering::Relaxed) || self.is_closed() {
            return Err(SendError::Closed);
        }

        let map_guard = self.dispatcher.subscriptions.pin();
        if let Some(list) = map_guard.get(&topic) {
            // Copy the subscriber list and release the read guard before
            // delivering, so delivery does not run under the guard.
            let read_guard = list.reader.enter();
            let targets = read_guard.clone();
            drop(read_guard);

            for mailbox in targets.iter().filter_map(|weak| weak.upgrade()) {
                mailbox.deliver((topic.clone(), value.clone()));
            }
        }
        Ok(())
    }

    /// Returns `true` when the dispatcher's receiver counter is zero.
    pub fn is_closed(&self) -> bool {
        self.dispatcher.receiver_count.load(Ordering::Relaxed) == 0
    }

    /// Closes this handle and disconnects every mailbox known to the
    /// dispatcher.
    ///
    /// # Errors
    ///
    /// Returns [`CloseError`] if this handle was already closed.
    pub fn close(&self) -> Result<(), CloseError> {
        let first_close = self
            .closed
            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
            .is_ok();
        if !first_close {
            return Err(CloseError);
        }
        self.close_internal();
        Ok(())
    }

    /// Calls `disconnect` on every live mailbox of every topic.
    fn close_internal(&self) {
        let map_guard = self.dispatcher.subscriptions.pin();
        for (_, list) in map_guard.iter() {
            let snapshot = list.reader.enter();
            snapshot
                .iter()
                .filter_map(|weak| weak.upgrade())
                .for_each(|mailbox| {
                    mailbox.disconnect();
                });
        }
    }

    /// Converts this handle into its async counterpart without running
    /// `Drop`, so the channel is not closed by the conversion. The `closed`
    /// flag is carried over unchanged.
    pub fn to_async(self) -> AsyncTopicSender<K, T> {
        // `ManuallyDrop` suppresses the `Drop` impl (which would close the
        // channel) and lets the fields be moved out below.
        let this = mem::ManuallyDrop::new(self);
        // SAFETY: `this` is never dropped or used again, and each field is
        // read exactly once, so ownership moves to the locals without any
        // value being duplicated or dropped twice.
        let dispatcher = unsafe { std::ptr::read(&this.dispatcher) };
        // SAFETY: same argument as above.
        let closed = unsafe { std::ptr::read(&this.closed) };
        AsyncTopicSender { dispatcher, closed }
    }
}
impl<K, T> Clone for TopicSender<K, T>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
T: Send + Clone + 'static,
{
fn clone(&self) -> Self {
Self {
dispatcher: self.dispatcher.clone(),
closed: AtomicBool::new(false),
}
}
}
impl<K, T> Drop for TopicSender<K, T>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    T: Send + Clone + 'static,
{
    /// Closes the handle on drop unless it was already closed explicitly.
    fn drop(&mut self) {
        // `swap` returns the previous flag value; only the transition from
        // open to closed triggers the disconnect pass.
        let already_closed = self.closed.swap(true, Ordering::AcqRel);
        if !already_closed {
            self.close_internal();
        }
    }
}
/// Synchronous receiving handle of a topic channel.
///
/// Each receiver owns one mailbox and tracks the set of topics it has
/// subscribed to.
#[derive(Debug)]
pub struct TopicReceiver<K, T>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    T: Send + Clone + 'static,
{
    /// Weak link to the dispatcher; upgrading fails once it has been dropped.
    pub(crate) dispatcher: Weak<SpmcTopicDispatcher<K, T>>,
    /// Consuming end of this receiver's mailbox.
    pub(crate) consumer: mailbox::MailboxConsumer<(K, T)>,
    /// Producing end of the mailbox; weak references to it are registered in
    /// the dispatcher's per-topic subscriber lists.
    pub(crate) producer_mailbox: Arc<mailbox::MailboxProducer<(K, T)>>,
    /// Topics this receiver currently considers itself subscribed to.
    pub(crate) subscriptions: Arc<Mutex<HashSet<K>>>,
    /// Set to `true` the first time this handle is closed.
    pub(crate) closed: AtomicBool,
}
impl<K, T> TopicReceiver<K, T>
where
K: Eq + Hash + Clone + Send + Sync + 'static,
T: Send + Clone + 'static,
{
pub fn recv(&self) -> Result<(K, T), RecvError> {
self.consumer.recv_sync()
}
pub fn try_recv(&self) -> Result<(K, T), TryRecvError> {
self.consumer.try_recv()
}
pub fn recv_timeout(&self, timeout: Duration) -> Result<(K, T), RecvErrorTimeout> {
if self.closed.load(Ordering::Relaxed) {
return self
.consumer
.try_recv()
.map_err(|_| RecvErrorTimeout::Disconnected);
}
self.consumer.recv_timeout_sync(timeout)
}
pub fn subscribe(&self, topic: K) {
let mut subs = self.subscriptions.lock();
if !subs.insert(topic.clone()) {
return; }
drop(subs);
if let Some(dispatcher) = self.dispatcher.upgrade() {
let list_arc = dispatcher
.subscriptions
.pin()
.get_or_insert_with(topic, || Arc::new(SubscriberList::new()))
.clone();
list_arc.writer.modify(|list| {
list.retain(|w| w.upgrade().is_some());
if !list
.iter()
.any(|w| w.ptr_eq(&Arc::downgrade(&self.producer_mailbox)))
{
list.push(Arc::downgrade(&self.producer_mailbox));
}
});
}
}
pub fn unsubscribe<Q: ?Sized>(&self, topic: &Q)
where
K: Borrow<Q> + Equivalent<Q>,
Q: Hash + Eq,
{
let mut subs = self.subscriptions.lock();
if !subs.remove(topic) {
return; }
drop(subs);
if let Some(dispatcher) = self.dispatcher.upgrade() {
if let Some(list_arc) = dispatcher.subscriptions.pin().get(topic) {
list_arc.writer.modify(|list| {
list.retain(|w| {
w.upgrade()
.map_or(false, |s| !Arc::ptr_eq(&s, &self.producer_mailbox))
});
});
}
}
}
pub fn is_closed(&self) -> bool {
self.closed.load(Ordering::Relaxed)
|| (self.dispatcher.upgrade().is_none() && self.consumer.is_empty())
}
pub fn close(&self) -> Result<(), CloseError> {
if self
.closed
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
self.close_internal();
Ok(())
} else {
Err(CloseError)
}
}
fn close_internal(&self) {
if let Some(dispatcher) = self.dispatcher.upgrade() {
let topics_to_unsubscribe: Vec<K> = self.subscriptions.lock().drain().collect();
for topic in topics_to_unsubscribe {
self.unsubscribe(&topic);
}
dispatcher.receiver_count.fetch_sub(1, Ordering::Relaxed);
}
}
pub fn capacity(&self) -> usize {
return self.consumer.capacity();
}
pub fn is_empty(&self) -> bool {
return self.consumer.is_empty();
}
pub fn to_async(self) -> AsyncTopicReceiver<K, T> {
let dispatcher = unsafe { std::ptr::read(&self.dispatcher) };
let consumer = unsafe { std::ptr::read(&self.consumer) };
let producer_mailbox = unsafe { std::ptr::read(&self.producer_mailbox) };
let subscriptions = unsafe { std::ptr::read(&self.subscriptions) };
mem::forget(self);
AsyncTopicReceiver {
dispatcher,
consumer,
producer_mailbox,
subscriptions,
closed: AtomicBool::new(false),
}
}
}
impl<K, T> Clone for TopicReceiver<K, T>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    T: Send + Clone + 'static,
{
    /// Creates an independent receiver with its own mailbox.
    ///
    /// While the dispatcher is alive, the clone gets a fresh mailbox of the
    /// same capacity, bumps the receiver counter, and subscribes to every
    /// topic currently in this receiver's subscription set. If the dispatcher
    /// is gone, the clone is a detached, already-closed receiver backed by a
    /// zero-capacity mailbox.
    fn clone(&self) -> Self {
        let dispatcher = match self.dispatcher.upgrade() {
            Some(dispatcher) => dispatcher,
            None => {
                let (producer, consumer) = mailbox::channel(0);
                return Self {
                    dispatcher: Weak::new(),
                    consumer,
                    producer_mailbox: Arc::new(producer),
                    subscriptions: Arc::new(Mutex::new(HashSet::new())),
                    closed: AtomicBool::new(true),
                };
            }
        };

        dispatcher.receiver_count.fetch_add(1, Ordering::Relaxed);
        let (producer, consumer) = mailbox::channel(self.consumer.capacity());
        // Copy the topic set so the lock is released before subscribing.
        let topics: Vec<K> = self.subscriptions.lock().iter().cloned().collect();

        let replica = Self {
            dispatcher: self.dispatcher.clone(),
            consumer,
            producer_mailbox: Arc::new(producer),
            subscriptions: Arc::new(Mutex::new(HashSet::new())),
            closed: AtomicBool::new(false),
        };
        // The replica starts with an empty set, so each `subscribe` call
        // registers its new mailbox with the dispatcher.
        topics
            .into_iter()
            .for_each(|topic| replica.subscribe(topic));
        replica
    }
}
impl<K, T> Drop for TopicReceiver<K, T>
where
    K: Eq + Hash + Clone + Send + Sync + 'static,
    T: Send + Clone + 'static,
{
    /// Closes the receiver on drop unless it was already closed explicitly.
    fn drop(&mut self) {
        // `close` flips the flag and runs the teardown only on the first
        // call; an `Err` just means the handle was closed earlier.
        let _ = self.close();
    }
}