use either::Either;
use crate::sync::{thread, AtomicBool, Mutex};
use smallvec::SmallVec;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::atomic::Ordering;
use std::sync::{Arc, Weak};
use std::{fmt::Debug, mem};
use thread::ThreadId;
/// Error returned by [`SubscriberSet::retain`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SubscriptionError {
    /// `retain` was called for an emitter from inside a `retain` for the same emitter on the
    /// same thread.
    CannotEmitEventDueToRecursiveCall,
}

impl std::fmt::Display for SubscriptionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SubscriptionError::CannotEmitEventDueToRecursiveCall => {
                f.write_str("cannot emit event due to recursive call")
            }
        }
    }
}

impl std::error::Error for SubscriptionError {}
/// Registry of callbacks grouped by emitter key.
///
/// The state lives behind one `Arc<Mutex<_>>`, so every clone of a `SubscriberSet` observes
/// and mutates the same subscribers.
pub struct SubscriberSet<EmitterKey, Callback>(
    // Shared, mutex-protected registry state.
    Arc<Mutex<SubscriberSetState<EmitterKey, Callback>>>,
);
impl<EmitterKey, Callback> Clone for SubscriberSet<EmitterKey, Callback> {
fn clone(&self) -> Self {
SubscriberSet(self.0.clone())
}
}
/// Mutable state shared by every clone of a [`SubscriberSet`].
struct SubscriberSetState<EmitterKey, Callback> {
    /// Subscribers per emitter. `Left` is the idle map of subscribers keyed by id; `Right`
    /// means `retain` has taken the map out to invoke it, and records the invoking thread.
    subscribers: BTreeMap<EmitterKey, Either<BTreeMap<usize, Subscriber<Callback>>, ThreadId>>,
    /// Subscribers inserted for an emitter while that emitter's map was taken out by
    /// `retain`; merged back into the map when that `retain` finishes.
    pending_subscribers: BTreeMap<EmitterKey, BTreeMap<usize, Subscriber<Callback>>>,
    /// `(emitter, subscriber id)` pairs unsubscribed while the emitter's map was taken out by
    /// `retain`; removed from the map when that `retain` finishes.
    dropped_subscribers: BTreeSet<(EmitterKey, usize)>,
    /// Id assigned to the next inserted subscriber.
    next_subscriber_id: usize,
}

/// A registered callback plus the bookkeeping tying it to its `Subscription`.
struct Subscriber<Callback> {
    /// Set by the activation closure returned from `insert`. Inactive subscribers are kept but
    /// not invoked by `retain`, and are not returned by `remove`.
    active: Arc<AtomicBool>,
    callback: Callback,
    /// Owns the strong reference to the unsubscribe action; the `Subscription` holds only a
    /// weak one, so dropping this subscriber also disarms its `Subscription`.
    _sub: InnerSubscription,
}

impl<EmitterKey, Callback> SubscriberSet<EmitterKey, Callback>
where
    EmitterKey: 'static + Ord + Clone + Debug + Send + Sync,
    Callback: 'static + Send + Sync,
{
    /// Creates an empty subscriber set.
    pub fn new() -> Self {
        Self(Arc::new(Mutex::new(SubscriberSetState {
            subscribers: Default::default(),
            pending_subscribers: Default::default(),
            dropped_subscribers: Default::default(),
            next_subscriber_id: 0,
        })))
    }

    /// Registers `callback` for `emitter_key`.
    ///
    /// Returns the [`Subscription`] handle and an activation closure. The callback is not
    /// invoked by `retain` until the activation closure has been called. Dropping the
    /// `Subscription` removes the callback again.
    ///
    /// Inserting while `retain` is running for the same key (from one of its callbacks or from
    /// another thread) is allowed. The new subscriber is parked and merged in when that `retain`
    /// finishes, so the in-progress call does not invoke it.
    pub fn insert(
        &self,
        emitter_key: EmitterKey,
        callback: Callback,
    ) -> (Subscription, impl FnOnce()) {
        let active = Arc::new(AtomicBool::new(false));
        let mut lock = self.0.lock();
        let subscriber_id = post_inc(&mut lock.next_subscriber_id);
        let this = Arc::downgrade(&self.0);
        let emitter_key_1 = emitter_key.clone();
        let inner_sub = InnerSubscription {
            unsubscribe: Arc::new(Mutex::new(Some(Box::new(move || {
                let Some(this) = this.upgrade() else {
                    return;
                };
                let mut lock = this.lock();
                let Some(entry) = lock.subscribers.get_mut(&emitter_key) else {
                    return;
                };
                let removed = match entry {
                    Either::Left(subscribers) => {
                        let removed = subscribers.remove(&subscriber_id);
                        if subscribers.is_empty() {
                            lock.subscribers.remove(&emitter_key);
                        }
                        removed
                    }
                    Either::Right(_) => {
                        // `retain` currently owns this emitter's map (and any parked
                        // subscribers); record the removal so it is applied when it finishes.
                        lock.dropped_subscribers
                            .insert((emitter_key, subscriber_id));
                        None
                    }
                };
                // Release the state lock before dropping the subscriber: its callback may own
                // another `Subscription` to this set, whose drop re-locks the state.
                drop(lock);
                drop(removed);
            })))),
        };
        let subscription = Subscription {
            unsubscribe: Arc::downgrade(&inner_sub.unsubscribe),
        };
        let subscriber = Subscriber {
            active: active.clone(),
            callback,
            _sub: inner_sub,
        };
        let state = &mut *lock;
        match state.subscribers.get_mut(&emitter_key_1) {
            Some(Either::Left(subscribers)) => {
                subscribers.insert(subscriber_id, subscriber);
            }
            Some(Either::Right(_)) => {
                // The map is taken out by an in-progress `retain`; park the subscriber so that
                // `retain` merges it back when it finishes.
                state
                    .pending_subscribers
                    .entry(emitter_key_1)
                    .or_default()
                    .insert(subscriber_id, subscriber);
            }
            None => {
                let mut subscribers = BTreeMap::new();
                subscribers.insert(subscriber_id, subscriber);
                state
                    .subscribers
                    .insert(emitter_key_1, Either::Left(subscribers));
            }
        }
        (subscription, move || active.store(true, Ordering::Relaxed))
    }

    /// Removes every subscriber of `emitter` and yields the callbacks of the active ones.
    ///
    /// If `retain` is currently running for `emitter`, nothing is yielded. That `retain`
    /// re-inserts its subscribers when it finishes.
    #[allow(unused)]
    pub fn remove(&self, emitter: &EmitterKey) -> impl IntoIterator<Item = Callback> {
        let mut lock = self.0.lock();
        let subscribers = lock.subscribers.remove(emitter);
        subscribers
            .and_then(|x| x.left().map(|s| s.into_values()))
            .into_iter()
            .flatten()
            .filter_map(|subscriber| {
                if subscriber.active.load(Ordering::Relaxed) {
                    Some(subscriber.callback)
                } else {
                    None
                }
            })
    }

    /// Returns `true` if `retain` is currently running for `emitter` on the calling thread.
    pub fn is_recursive_calling(&self, emitter: &EmitterKey) -> bool {
        if let Some(Either::Right(thread_id)) = self.0.lock().subscribers.get(emitter) {
            *thread_id == thread::current().id()
        } else {
            false
        }
    }

    /// Calls `f` on every active subscriber of `emitter`; subscribers for which `f` returns
    /// `false` are removed. Inactive subscribers are kept and not passed to `f`.
    ///
    /// While `f` runs, the subscriber map is taken out of the shared state and the lock is not
    /// held. Callbacks may therefore insert or drop subscriptions. Those changes are applied
    /// when `f` is done. If another thread is already running `retain` for the same emitter,
    /// this call polls, sleeping between attempts, until that call finishes.
    ///
    /// NOTE(review): if `f` panics, the `Right` marker is never cleared. Later `retain` calls
    /// for this emitter from other threads would then poll forever. Consider a drop guard.
    ///
    /// # Errors
    ///
    /// Returns [`SubscriptionError::CannotEmitEventDueToRecursiveCall`] if called for
    /// `emitter` from inside a `retain` for the same emitter on the same thread.
    pub fn retain(
        &self,
        emitter: &EmitterKey,
        f: &mut dyn FnMut(&mut Callback) -> bool,
    ) -> Result<(), SubscriptionError> {
        // Take the emitter's map out, leaving a marker that records the current thread.
        let mut subscribers = loop {
            let mut state = self.0.lock();
            let Some(set) = state.subscribers.get_mut(emitter) else {
                return Ok(());
            };
            match set {
                Either::Left(_) => {
                    break mem::replace(set, Either::Right(thread::current().id())).unwrap_left();
                }
                Either::Right(lock_thread) => {
                    if thread::current().id() == *lock_thread {
                        return Err(SubscriptionError::CannotEmitEventDueToRecursiveCall);
                    }
                    // Another thread is emitting this key: release the lock and retry later.
                    drop(state);
                    #[cfg(loom)]
                    loom::thread::yield_now();
                    #[cfg(not(loom))]
                    std::thread::sleep(std::time::Duration::from_millis(10));
                }
            }
        };

        subscribers.retain(|_, subscriber| {
            if subscriber.active.load(Ordering::Relaxed) {
                f(&mut subscriber.callback)
            } else {
                true
            }
        });

        let mut lock = self.0.lock();
        let state = &mut *lock;
        // Normally the entry is still our `Right` marker. If it was removed (e.g. by `remove`)
        // and re-created by `insert` meanwhile, fold those subscribers in too.
        if let Some(Either::Left(new_subscribers)) = state.subscribers.remove(emitter) {
            subscribers.extend(new_subscribers);
        }
        // Subscribers inserted while the map was taken out.
        if let Some(parked) = state.pending_subscribers.remove(emitter) {
            subscribers.extend(parked);
        }
        // Apply unsubscriptions recorded while the map was taken out. Entries for other
        // emitters belong to their own in-progress `retain` calls and are kept.
        let mut unsubscribed = Vec::new();
        state
            .dropped_subscribers
            .retain(|(dropped_emitter, dropped_id)| {
                if dropped_emitter != emitter {
                    return true;
                }
                unsubscribed.extend(subscribers.remove(dropped_id));
                false
            });
        // Leave no empty map behind, so `is_empty` and `may_include` stay accurate.
        if !subscribers.is_empty() {
            state
                .subscribers
                .insert(emitter.clone(), Either::Left(subscribers));
        }
        // Drop removed subscribers only after releasing the lock: their callbacks may own
        // `Subscription`s to this set, whose drop re-locks the state.
        drop(lock);
        drop(unsubscribed);
        Ok(())
    }

    /// Returns `true` if no emitter has an entry (neither subscribers nor an in-progress
    /// `retain`).
    pub fn is_empty(&self) -> bool {
        self.0.lock().subscribers.is_empty()
    }

    /// Returns `true` if `emitter` has an entry. A `true` result does not guarantee that any
    /// subscriber is active.
    pub fn may_include(&self, emitter: &EmitterKey) -> bool {
        self.0.lock().subscribers.contains_key(emitter)
    }
}
impl<EmitterKey, Callback> Default for SubscriberSet<EmitterKey, Callback>
where
EmitterKey: 'static + Ord + Clone + Debug + Send + Sync,
Callback: 'static + Send + Sync,
{
fn default() -> Self {
Self::new()
}
}
impl<EmitterKey, Callback> std::fmt::Debug for SubscriberSet<EmitterKey, Callback>
where
    EmitterKey: 'static + Ord + Clone + Debug + Send + Sync,
    Callback: 'static + Send + Sync,
{
    /// Summarises the set by counters only; callbacks are never printed.
    ///
    /// Note that `subscriber_count` is the number of emitter keys with an entry, not the
    /// total number of subscribers.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let state = self.0.lock();
        let mut out = f.debug_struct("SubscriberSet");
        out.field("subscriber_count", &state.subscribers.len());
        out.field("dropped_subscribers_count", &state.dropped_subscribers.len());
        out.field("next_subscriber_id", &state.next_subscriber_id);
        out.finish()
    }
}
/// Post-increment: stores `*next_subscriber_id + 1` and returns the value it replaced.
fn post_inc(next_subscriber_id: &mut usize) -> usize {
    let incremented = *next_subscriber_id + 1;
    std::mem::replace(next_subscriber_id, incremented)
}
/// One-shot, thread-transferable action stored in a subscription slot.
type Callback = Box<dyn FnOnce() + Send + Sync + 'static>;
/// Handle returned by [`SubscriberSet::insert`]; dropping it unsubscribes the callback.
///
/// Call [`Subscription::detach`] to leave the callback registered instead.
#[must_use]
pub struct Subscription {
    // Weak, so the handle alone does not keep the unsubscribe action alive; the strong
    // reference is owned by the registered subscriber.
    unsubscribe: Weak<Mutex<Option<Callback>>>,
}
impl std::fmt::Debug for Subscription {
    /// Prints just the type name; the slot's contents are opaque.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Subscription")
    }
}
impl Subscription {
    /// Discards the unsubscribe action without running it.
    ///
    /// The subscriber then stays registered until something else removes it, such as `retain`
    /// or dropping the set.
    pub fn detach(self) {
        let Some(slot) = self.unsubscribe.upgrade() else {
            return;
        };
        slot.lock().take();
    }

    /// Unsubscribes now; identical to dropping the handle.
    #[inline]
    pub fn unsubscribe(self) {
        std::mem::drop(self)
    }
}
impl Drop for Subscription {
    /// Runs the unsubscribe action once, if the subscriber still exists and it has not been
    /// detached.
    fn drop(&mut self) {
        let Some(slot) = self.unsubscribe.upgrade() else {
            return;
        };
        // Move the action out first so the slot's mutex is released before it runs.
        let taken = slot.lock().take();
        if let Some(run_unsubscribe) = taken {
            run_unsubscribe();
        }
    }
}
/// Owning side of a subscription's unsubscribe slot, stored inside the registered subscriber.
/// [`Subscription`] only holds a weak reference to the same slot.
struct InnerSubscription {
    unsubscribe: Arc<Mutex<Option<Callback>>>,
}
impl Drop for InnerSubscription {
    /// Empties the slot, so a `Subscription` that still upgrades its weak pointer finds no
    /// action to run once the subscriber is gone.
    fn drop(&mut self) {
        drop(self.unsubscribe.lock().take());
    }
}
/// A [`SubscriberSet`] paired with a per-emitter queue of payloads.
///
/// The queue holds payloads whose delivery had to be deferred because they were emitted
/// re-entrantly.
pub struct SubscriberSetWithQueue<EmitterKey, Callback, Payload> {
    subscriber_set: SubscriberSet<EmitterKey, Callback>,
    queue: Arc<Mutex<BTreeMap<EmitterKey, Vec<Payload>>>>,
}

// Implemented by hand: `#[derive(Clone)]` would require `EmitterKey`, `Callback` and `Payload`
// to be `Clone` (boxed closures are not), although only the shared handles are cloned.
impl<EmitterKey, Callback, Payload> Clone
    for SubscriberSetWithQueue<EmitterKey, Callback, Payload>
{
    fn clone(&self) -> Self {
        Self {
            subscriber_set: self.subscriber_set.clone(),
            queue: Arc::clone(&self.queue),
        }
    }
}
impl<EmitterKey, Callback, Payload> Debug
    for SubscriberSetWithQueue<EmitterKey, Callback, Payload>
{
    /// Prints only the type name; neither subscribers nor queued payloads are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SubscriberSetWithQueue")
    }
}
impl<EmitterKey, Callback, Payload> Default
for SubscriberSetWithQueue<EmitterKey, Callback, Payload>
where
EmitterKey: 'static + Ord + Clone + Debug + Send + Sync,
Callback: 'static + Send + Sync + for<'a> FnMut(&'a Payload) -> bool,
Payload: Send + Sync + Debug,
{
fn default() -> Self {
Self::new()
}
}
/// Non-owning counterpart of [`SubscriberSetWithQueue`], produced by
/// [`SubscriberSetWithQueue::downgrade`]. Use [`WeakSubscriberSetWithQueue::upgrade`] to
/// obtain a strong handle again.
pub struct WeakSubscriberSetWithQueue<EmitterKey, Callback, Payload> {
    // Weak pointer to the subscriber state shared with `SubscriberSet`.
    subscriber_set: Weak<Mutex<SubscriberSetState<EmitterKey, Callback>>>,
    // Weak pointer to the deferred-payload queue.
    queue: Weak<Mutex<BTreeMap<EmitterKey, Vec<Payload>>>>,
}
impl<EmitterKey, Callback, Payload> Debug
    for WeakSubscriberSetWithQueue<EmitterKey, Callback, Payload>
{
    /// Prints only the type name; nothing is upgraded or inspected.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("WeakSubscriberSetWithQueue")
    }
}
impl<EmitterKey, Callback, Payload> Clone
for WeakSubscriberSetWithQueue<EmitterKey, Callback, Payload>
{
fn clone(&self) -> Self {
Self {
subscriber_set: self.subscriber_set.clone(),
queue: self.queue.clone(),
}
}
}
impl<EmitterKey, Callback, Payload> WeakSubscriberSetWithQueue<EmitterKey, Callback, Payload> {
    /// Returns a strong handle if both the subscriber state and the queue are still alive,
    /// otherwise `None`.
    pub fn upgrade(self) -> Option<SubscriberSetWithQueue<EmitterKey, Callback, Payload>> {
        let state = self.subscriber_set.upgrade()?;
        let queue = self.queue.upgrade()?;
        Some(SubscriberSetWithQueue {
            subscriber_set: SubscriberSet(state),
            queue,
        })
    }
}
impl<EmitterKey, Callback, Payload> SubscriberSetWithQueue<EmitterKey, Callback, Payload>
where
EmitterKey: 'static + Ord + Clone + Debug + Send + Sync,
Callback: 'static + Send + Sync + for<'a> FnMut(&'a Payload) -> bool,
Payload: Send + Sync + Debug,
{
pub fn new() -> Self {
Self {
subscriber_set: SubscriberSet::new(),
queue: Arc::new(Mutex::new(Default::default())),
}
}
pub fn downgrade(&self) -> WeakSubscriberSetWithQueue<EmitterKey, Callback, Payload> {
WeakSubscriberSetWithQueue {
subscriber_set: Arc::downgrade(&self.subscriber_set.0),
queue: Arc::downgrade(&self.queue),
}
}
pub fn inner(&self) -> &SubscriberSet<EmitterKey, Callback> {
&self.subscriber_set
}
pub fn emit(&self, key: &EmitterKey, payload: Payload) {
let mut pending_events: SmallVec<[Payload; 1]> = SmallVec::new();
pending_events.push(payload);
while let Some(payload) = pending_events.pop() {
let result = self
.subscriber_set
.retain(key, &mut |callback| (callback)(&payload));
match result {
Ok(_) => {
let mut queue = self.queue.lock();
if let Some(new_pending_events) = queue.remove(key) {
pending_events.extend(new_pending_events);
}
}
Err(SubscriptionError::CannotEmitEventDueToRecursiveCall) => {
let mut queue = self.queue.lock();
queue.entry(key.clone()).or_default().push(payload);
}
}
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    type BoxedCallback = Box<dyn Fn(&i32) -> bool + Send + Sync>;

    /// Dropping the whole set must release the strong side of every subscription slot.
    #[test]
    fn test_inner_subscription_drop() {
        let set = SubscriberSet::<i32, BoxedCallback>::new();
        let (subscription, activate) = set.insert(1, Box::new(|_: &i32| true));
        activate();
        drop(set);
        assert!(subscription.unsubscribe.upgrade().is_none());
    }

    /// A subscriber removed by `retain` must release the strong side of its slot.
    #[test]
    fn test_inner_subscription_drop_2() {
        let set = SubscriberSet::<i32, BoxedCallback>::new();
        let (subscription, activate) = set.insert(1, Box::new(|_: &i32| false));
        activate();
        let outcome = set.retain(&1, &mut |callback| callback(&1));
        outcome.unwrap();
        assert!(subscription.unsubscribe.upgrade().is_none());
    }
}