use super::tracker::{notify_dependents, track_read};
use super::SignalId;
use crate::utils::lock::{read_or_recover, write_or_recover};
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
/// Opaque key identifying one registered subscriber callback.
///
/// Ids come from a single process-wide counter, so every call to
/// `SubscriptionId::new` yields a value distinct from all earlier ones
/// (until the underlying `u64` wraps around).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SubscriptionId(u64);

impl SubscriptionId {
    /// Hands out the next unused id.
    fn new() -> Self {
        // Only the uniqueness of the fetched number matters here, not its
        // ordering relative to other memory operations, hence `Relaxed`.
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);
        let raw = NEXT_ID.fetch_add(1, AtomicOrdering::Relaxed);
        SubscriptionId(raw)
    }
}
/// Reference-counted, argument-less callback that is `Send + Sync`.
type SubscriberCallback = Arc<dyn Fn() + Send + Sync>;
/// Callback registry keyed by subscription id, stored behind an `Arc<RwLock<..>>`
/// so that signals and their `Subscription` handles can point at the same map.
type Subscribers = Arc<RwLock<HashMap<SubscriptionId, SubscriberCallback>>>;
/// Handle returned by `Signal::subscribe`.
///
/// The handle's `Drop` impl removes the entry stored under `id` from
/// `subscribers`, so keep the handle alive for as long as the callback should
/// stay registered.
pub struct Subscription {
/// Key of this subscription's callback inside `subscribers`.
id: SubscriptionId,
/// The registry the callback was inserted into (shared with the signal).
subscribers: Subscribers,
}
impl Drop for Subscription {
fn drop(&mut self) {
if let Ok(mut subs) = self.subscribers.write() {
subs.remove(&self.id);
}
}
}
/// A shared, lock-protected value that runs subscriber callbacks when it is
/// changed through one of its notifying methods.
///
/// The value and the subscriber map both live behind `Arc`s, so every clone of a
/// signal carries the same `id` and points at the same state.
pub struct Signal<T> {
/// Identifier handed to `track_read` and `notify_dependents`.
id: SignalId,
/// Current value.
value: Arc<RwLock<T>>,
/// Callbacks registered through `subscribe`.
subscribers: Subscribers,
}
impl<T: 'static> Signal<T> {
    /// Creates a signal holding `value`, with a freshly allocated id and no
    /// subscribers.
    pub fn new(value: T) -> Self {
        Self {
            id: SignalId::new(),
            value: Arc::new(RwLock::new(value)),
            subscribers: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    /// Returns a read guard for the current value.
    ///
    /// The access is reported through `track_read` before the lock is taken.
    #[inline]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        track_read(self.id);
        read_or_recover(&self.value)
    }

    /// Alias for [`Signal::read`].
    #[inline]
    pub fn borrow(&self) -> RwLockReadGuard<'_, T> {
        self.read()
    }

    /// Returns a write guard for the value.
    ///
    /// Nothing is notified when the guard is released; call
    /// [`Signal::notify_change`] afterwards if observers should run.
    #[inline]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        write_or_recover(&self.value)
    }

    /// Alias for [`Signal::write`].
    #[inline]
    pub fn borrow_mut(&self) -> RwLockWriteGuard<'_, T> {
        self.write()
    }

    /// Runs `f` on a shared reference to the value and returns its result.
    ///
    /// Like [`Signal::read`], the access is reported through `track_read`.
    #[inline]
    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        track_read(self.id);
        let guard = read_or_recover(&self.value);
        f(&*guard)
    }

    /// Runs `f` on a mutable reference to the value **without** notifying
    /// anyone; use [`Signal::update_with`] for the notifying variant.
    #[inline]
    pub fn with_mut<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        let mut guard = write_or_recover(&self.value);
        f(&mut *guard)
    }

    /// Mutates the value through `f`, then notifies subscribers and dependents,
    /// and finally returns whatever `f` returned.
    #[inline]
    pub fn update_with<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        // `with_mut` releases the write lock before it returns, so callbacks
        // run with the value unlocked.
        let result = self.with_mut(f);
        self.notify_change();
        result
    }

    /// Replaces the value, then notifies subscribers and dependents.
    pub fn set(&self, value: T) {
        let mut guard = write_or_recover(&self.value);
        let previous = std::mem::replace(&mut *guard, value);
        drop(guard);
        // The old value is destroyed only after the write lock is released, so a
        // `Drop` impl that reads this signal (or simply takes a while) never runs
        // underneath our own guard.
        drop(previous);
        self.notify_change();
    }

    /// Mutates the value in place through `f`, then notifies subscribers and
    /// dependents.
    pub fn update(&self, f: impl FnOnce(&mut T)) {
        self.update_with(f);
    }

    /// Registers `callback` to be invoked on every notification.
    ///
    /// The callback stays registered until the returned [`Subscription`] is
    /// dropped, so the handle must be kept alive.
    pub fn subscribe(&self, callback: impl Fn() + Send + Sync + 'static) -> Subscription {
        let id = SubscriptionId::new();
        {
            let mut subs = write_or_recover(&self.subscribers);
            subs.insert(id, Arc::new(callback));
        }
        Subscription {
            id,
            subscribers: Arc::clone(&self.subscribers),
        }
    }

    /// Notifies subscribers and then dependents without touching the value.
    ///
    /// Intended for use after a silent mutation made through [`Signal::write`]
    /// or [`Signal::with_mut`].
    pub fn notify_change(&self) {
        self.notify();
        notify_dependents(self.id);
    }

    /// Invokes every currently registered subscriber callback once.
    ///
    /// The callbacks are cloned out of the registry first and called only after
    /// the registry lock has been released, so a callback may subscribe to,
    /// unsubscribe from, or write to this signal. Invocation order follows the
    /// `HashMap` iteration order and is therefore unspecified.
    fn notify(&self) {
        let callbacks: Vec<_> = {
            let subs = read_or_recover(&self.subscribers);
            subs.values().cloned().collect()
        };
        for callback in callbacks {
            callback();
        }
    }

    /// Returns this signal's identifier.
    pub fn id(&self) -> SignalId {
        self.id
    }
}
impl<T: Clone + 'static> Signal<T> {
    /// Returns an owned copy of the current value.
    ///
    /// The access is reported through `track_read`, and the read lock is held
    /// only for the duration of the `clone`.
    #[inline]
    pub fn get(&self) -> T {
        track_read(self.id);
        let current = read_or_recover(&self.value);
        T::clone(&current)
    }
}
impl<T: 'static> Clone for Signal<T> {
    /// Produces another handle to the same signal.
    ///
    /// Only the `Arc`s are cloned: the copy keeps the same id and points at the
    /// same value and subscriber registry, and `T` itself is never cloned.
    fn clone(&self) -> Self {
        let value = Arc::clone(&self.value);
        let subscribers = Arc::clone(&self.subscribers);
        Self {
            id: self.id,
            value,
            subscribers,
        }
    }
}
impl<T: std::fmt::Debug + 'static> std::fmt::Debug for Signal<T> {
    /// Formats the signal as `Signal { id: .., value: .. }` without ever blocking.
    ///
    /// The value is shown whenever it can be read right now. That includes a
    /// poisoned lock: the data is still there, and the rest of this type also
    /// keeps working after poisoning. Only when the lock is currently held
    /// exclusively is the value replaced by the `"<locked>"` placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Signal");
        out.field("id", &self.id);
        match self.value.try_read() {
            Ok(guard) => out.field("value", &*guard),
            // Poisoned is not the same as busy: the error still carries a guard.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => {
                out.field("value", &**poisoned.get_ref())
            }
            Err(std::sync::TryLockError::WouldBlock) => out.field("value", &"<locked>"),
        };
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A callback may drop a subscription of the signal that is currently
    /// notifying, and write to that signal again, without deadlocking.
    #[test]
    fn test_callback_can_drop_other_subscription() {
        use std::sync::{Arc, Mutex};
        let signal = Signal::new(0);
        let sub_holder: Arc<Mutex<Option<Subscription>>> = Arc::new(Mutex::new(None));
        let sub = signal.subscribe({
            let signal_clone = signal.clone();
            let holder = Arc::clone(&sub_holder);
            move || {
                if let Ok(mut guard) = holder.lock() {
                    if let Some(s) = guard.take() {
                        drop(s);
                    }
                }
                // Safe from recursion: the only subscription was just removed.
                signal_clone.set(42);
            }
        });
        if let Ok(mut guard) = sub_holder.lock() {
            *guard = Some(sub);
        }
        signal.set(1);
        // The callback ran: it consumed the handle and performed the nested write.
        assert!(sub_holder.lock().unwrap().is_none());
        assert_eq!(*signal.read(), 42);
    }

    /// Dropping one subscription from inside another callback, while several
    /// subscribers are being notified, must not disturb the ongoing round and
    /// must silence the dropped subscriber from the next round on.
    #[test]
    fn test_multiple_subscriptions_drop_during_notify() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Mutex;
        let signal = Signal::new(0);
        let victim_calls = Arc::new(AtomicUsize::new(0));
        let victim = signal.subscribe({
            let calls = Arc::clone(&victim_calls);
            move || {
                calls.fetch_add(1, Ordering::Relaxed);
            }
        });
        let victim_slot = Arc::new(Mutex::new(Some(victim)));
        let mut subscriptions = Vec::new();
        for i in 0..5 {
            let slot = Arc::clone(&victim_slot);
            let sub = signal.subscribe(move || {
                if i == 2 {
                    // Unsubscribes `victim` in the middle of the notify loop.
                    drop(slot.lock().unwrap().take());
                }
            });
            subscriptions.push(sub);
        }
        signal.set(1);
        // `notify` works on a snapshot taken before any callback ran, so the
        // victim is still invoked exactly once in this round regardless of the
        // (unspecified) callback order.
        assert_eq!(victim_calls.load(Ordering::Relaxed), 1);
        assert!(victim_slot.lock().unwrap().is_none());
        signal.set(2);
        assert_eq!(victim_calls.load(Ordering::Relaxed), 1);
        assert_eq!(subscriptions.len(), 5);
    }

    /// A callback of one signal may write to another signal, whose own
    /// subscribers then run as part of the same call.
    #[test]
    fn test_nested_notifications() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let signal1 = Signal::new(0);
        let signal2 = Signal::new(0);
        let inner_calls = Arc::new(AtomicUsize::new(0));
        let _sub = signal1.subscribe({
            let signal2_clone = signal2.clone();
            move || {
                signal2_clone.set(1);
            }
        });
        let _sub2 = signal2.subscribe({
            let calls = Arc::clone(&inner_calls);
            move || {
                calls.fetch_add(1, Ordering::Relaxed);
            }
        });
        signal1.set(1);
        assert_eq!(inner_calls.load(Ordering::Relaxed), 1);
        assert_eq!(*signal1.read(), 1);
        assert_eq!(*signal2.read(), 1);
    }

    /// Once its `Subscription` is dropped, a callback is no longer invoked.
    #[test]
    fn test_subscription_cleanup() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        let signal = Signal::new(0);
        let call_count = Arc::new(AtomicUsize::new(0));
        let sub = signal.subscribe({
            let count = Arc::clone(&call_count);
            move || {
                count.fetch_add(1, Ordering::Relaxed);
            }
        });
        signal.set(1);
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
        drop(sub);
        signal.set(2);
        assert_eq!(call_count.load(Ordering::Relaxed), 1);
    }

    /// `set` and `update` are both visible through `read`.
    #[test]
    fn test_signal_basic() {
        let signal = Signal::new(42);
        assert_eq!(*signal.read(), 42);
        signal.set(100);
        assert_eq!(*signal.read(), 100);
        signal.update(|v| *v *= 2);
        assert_eq!(*signal.read(), 200);
    }

    /// `with` exposes the value by reference; `with_mut` mutates it in place.
    #[test]
    fn test_signal_with_methods() {
        let signal = Signal::new(vec![1, 2, 3]);
        let len = signal.with(|v| v.len());
        assert_eq!(len, 3);
        signal.with_mut(|v| v.push(4));
        assert_eq!(*signal.read(), vec![1, 2, 3, 4]);
    }
}