use super::SignalId;
use std::cell::{Cell, RefCell};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
/// Maximum nesting depth of `notify_dependents` calls on a single thread.
///
/// When callbacks keep triggering further notifications past this depth,
/// `notify_dependents` panics, on the assumption that the reactive graph
/// contains a cycle.
const MAX_NOTIFY_DEPTH: usize = 100;
/// Process-wide identifier for a subscriber.
///
/// Values come from a single global atomic counter, so each call to
/// [`SubscriberId::new`] (or `default`) yields a distinct id (barring `u64`
/// wraparound).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SubscriberId(u64);

impl SubscriberId {
    /// Allocates the next id from the global counter.
    pub fn new() -> Self {
        use std::sync::atomic::{AtomicU64, Ordering::Relaxed};

        // Relaxed is enough: only uniqueness matters, not ordering relative
        // to other memory operations.
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);

        let raw = NEXT_ID.fetch_add(1, Relaxed);
        SubscriberId(raw)
    }
}

impl Default for SubscriberId {
    /// Same as [`SubscriberId::new`]: every default value is a fresh id.
    fn default() -> Self {
        SubscriberId::new()
    }
}
/// Shared, thread-safe callback run when a tracked signal changes.
pub type SubscriberCallback = Arc<dyn Fn() + Send + Sync>;

/// A reactive observer: a unique id paired with the callback to invoke.
#[derive(Clone)]
pub struct Subscriber {
    /// Identity used to register and later dispose this subscriber.
    pub id: SubscriberId,
    /// The function executed by [`Subscriber::notify`].
    pub callback: SubscriberCallback,
}

impl Subscriber {
    /// Wraps `callback` in a subscriber with a freshly allocated id.
    pub fn new(callback: impl Fn() + Send + Sync + 'static) -> Self {
        let id = SubscriberId::new();
        let callback: SubscriberCallback = Arc::new(callback);
        Subscriber { id, callback }
    }

    /// Invokes the callback once, synchronously.
    pub fn notify(&self) {
        let run = &self.callback;
        run();
    }
}
/// Records which subscribers read which signals so that writes can be
/// propagated to their dependents.
///
/// The graph is stored in both directions: `dependencies` maps a signal to
/// the subscribers that read it, and `subscriber_deps` maps a subscriber back
/// to the signals it read. The reverse index lets a subscriber's edges be
/// dropped without scanning every signal.
pub struct DependencyTracker {
    /// Subscribers whose tracking scope is currently open, innermost last.
    subscriber_stack: Vec<Subscriber>,
    /// signal -> subscribers that read it during their most recent tracking run.
    dependencies: HashMap<SignalId, HashSet<SubscriberId>>,
    /// Registered callbacks; an id absent here was disposed or never tracked.
    subscribers: HashMap<SubscriberId, SubscriberCallback>,
    /// subscriber -> signals it read (reverse index of `dependencies`).
    subscriber_deps: HashMap<SubscriberId, HashSet<SignalId>>,
}

impl DependencyTracker {
    /// Creates an empty tracker with no subscribers and no recorded reads.
    pub fn new() -> Self {
        Self {
            subscriber_stack: Vec::new(),
            dependencies: HashMap::new(),
            subscribers: HashMap::new(),
            subscriber_deps: HashMap::new(),
        }
    }

    /// Opens a tracking scope for `subscriber`.
    ///
    /// Dependencies recorded by a previous run of the same subscriber id are
    /// discarded first, so the reads of this run fully replace them. The
    /// callback is (re-)registered so notifications can reach it.
    pub fn start_tracking(&mut self, subscriber: Subscriber) {
        self.clear_subscriber_deps(subscriber.id);
        self.subscribers
            .insert(subscriber.id, subscriber.callback.clone());
        self.subscriber_stack.push(subscriber);
    }

    /// Closes the innermost tracking scope and returns its subscriber, or
    /// `None` if no scope was open. Recorded dependencies are kept.
    pub fn stop_tracking(&mut self) -> Option<Subscriber> {
        self.subscriber_stack.pop()
    }

    /// Returns the subscriber of the innermost open tracking scope, if any.
    pub fn current_subscriber(&self) -> Option<&Subscriber> {
        self.subscriber_stack.last()
    }

    /// Records that the innermost tracked subscriber read `signal_id`.
    ///
    /// Does nothing when no tracking scope is open.
    pub fn track_read(&mut self, signal_id: SignalId) {
        if let Some(subscriber) = self.subscriber_stack.last() {
            let sub_id = subscriber.id;
            self.dependencies
                .entry(signal_id)
                .or_default()
                .insert(sub_id);
            self.subscriber_deps
                .entry(sub_id)
                .or_default()
                .insert(signal_id);
        }
    }

    /// Invokes the callback of every registered subscriber of `signal_id`,
    /// in unspecified (hash-set) order.
    ///
    /// Callbacks run while `self` is borrowed, so they cannot mutate this
    /// tracker. If this tracker is the thread-local one reached through
    /// `with_tracker`, a callback that calls `with_tracker` again will panic
    /// on the `RefCell` double borrow.
    pub fn notify_subscribers(&self, signal_id: SignalId) {
        if let Some(subscriber_ids) = self.dependencies.get(&signal_id) {
            // `&self` guarantees the set cannot change while we iterate, so no
            // defensive copy of the ids is needed.
            for callback in subscriber_ids
                .iter()
                .filter_map(|id| self.subscribers.get(id))
            {
                callback();
            }
        }
    }

    /// Removes every recorded edge for `subscriber_id`, in both directions.
    ///
    /// Signal entries left without any subscriber are removed entirely, so
    /// `dependencies` does not accumulate empty sets as subscribers re-track
    /// or are disposed.
    fn clear_subscriber_deps(&mut self, subscriber_id: SubscriberId) {
        if let Some(signal_ids) = self.subscriber_deps.remove(&subscriber_id) {
            for signal_id in signal_ids {
                if let Some(readers) = self.dependencies.get_mut(&signal_id) {
                    readers.remove(&subscriber_id);
                    if readers.is_empty() {
                        self.dependencies.remove(&signal_id);
                    }
                }
            }
        }
    }

    /// Unregisters `subscriber_id`: drops its dependency edges and callback.
    ///
    /// It does not pop the subscriber from the tracking stack if it is
    /// currently being tracked.
    pub fn dispose_subscriber(&mut self, subscriber_id: SubscriberId) {
        self.clear_subscriber_deps(subscriber_id);
        self.subscribers.remove(&subscriber_id);
    }

    /// Returns `true` while at least one tracking scope is open.
    pub fn is_tracking(&self) -> bool {
        !self.subscriber_stack.is_empty()
    }

    /// Number of subscribers currently recorded as reading `signal_id`.
    pub fn dependent_count(&self, signal_id: SignalId) -> usize {
        self.dependencies.get(&signal_id).map_or(0, |s| s.len())
    }
}

impl Default for DependencyTracker {
    fn default() -> Self {
        Self::new()
    }
}
thread_local! {
    /// Number of nested `notify_dependents` calls in progress on this thread.
    static NOTIFY_DEPTH: Cell<usize> = const { Cell::new(0) };
    /// This thread's dependency graph; each thread owns an independent tracker.
    static TRACKER: RefCell<DependencyTracker> = RefCell::new(DependencyTracker::new());
}

/// Runs `f` with exclusive, mutable access to this thread's tracker.
///
/// # Panics
///
/// Panics if called re-entrantly from inside another `with_tracker` closure
/// on the same thread, because the tracker's `RefCell` is already mutably
/// borrowed.
pub fn with_tracker<R>(f: impl FnOnce(&mut DependencyTracker) -> R) -> R {
    TRACKER.with(|cell| {
        let mut tracker = cell.borrow_mut();
        f(&mut tracker)
    })
}
pub fn start_tracking(subscriber: Subscriber) {
with_tracker(|t| t.start_tracking(subscriber));
}
pub fn stop_tracking() -> Option<Subscriber> {
with_tracker(|t| t.stop_tracking())
}
pub fn track_read(signal_id: SignalId) {
with_tracker(|t| t.track_read(signal_id));
}
pub fn notify_dependents(signal_id: SignalId) {
let depth = NOTIFY_DEPTH.with(|d| {
let current = d.get();
let new_depth = current + 1;
d.set(new_depth);
new_depth
});
struct DepthGuard;
impl Drop for DepthGuard {
fn drop(&mut self) {
NOTIFY_DEPTH.with(|d| d.set(d.get().saturating_sub(1)));
}
}
let _guard = DepthGuard;
if depth > MAX_NOTIFY_DEPTH {
panic!(
"Maximum reactive update depth ({}) exceeded. \
This usually indicates a circular dependency in your reactive graph.",
MAX_NOTIFY_DEPTH
);
}
let callbacks: Vec<SubscriberCallback> = with_tracker(|t| {
t.dependencies
.get(&signal_id)
.map(|subscriber_ids| {
subscriber_ids
.iter()
.filter_map(|id| t.subscribers.get(id).cloned())
.collect()
})
.unwrap_or_default()
});
for callback in callbacks {
callback();
}
}
pub fn dispose_subscriber(subscriber_id: SubscriberId) {
with_tracker(|t| t.dispose_subscriber(subscriber_id));
}
pub fn is_tracking() -> bool {
with_tracker(|t| t.is_tracking())
}
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Subscriber with the given id whose callback does nothing.
    fn noop_subscriber(id: SubscriberId) -> Subscriber {
        Subscriber {
            id,
            callback: Arc::new(|| {}),
        }
    }

    /// Subscriber with the given id plus a counter bumped on every notification.
    fn counting_subscriber(id: SubscriberId) -> (Subscriber, Arc<AtomicUsize>) {
        let hits = Arc::new(AtomicUsize::new(0));
        let sink = Arc::clone(&hits);
        let subscriber = Subscriber {
            id,
            callback: Arc::new(move || {
                sink.fetch_add(1, Ordering::SeqCst);
            }),
        };
        (subscriber, hits)
    }

    /// Registers `subscriber` on the thread-local tracker as a reader of `signal`.
    fn track_globally(subscriber: Subscriber, signal: SignalId) {
        with_tracker(|t| {
            t.start_tracking(subscriber);
            t.track_read(signal);
            t.stop_tracking();
        });
    }

    #[test]
    fn test_subscriber_id_unique() {
        assert_ne!(SubscriberId::new(), SubscriberId::new());
    }

    #[test]
    fn test_tracker_basic_tracking() {
        let mut tracker = DependencyTracker::new();
        let signal = SignalId::new();
        let (subscriber, hits) = counting_subscriber(SubscriberId::new());

        tracker.start_tracking(subscriber);
        tracker.track_read(signal);
        tracker.stop_tracking();
        assert_eq!(tracker.dependent_count(signal), 1);

        tracker.notify_subscribers(signal);
        assert_eq!(hits.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_tracker_nested_tracking() {
        let mut tracker = DependencyTracker::new();
        let (outer_signal, inner_signal) = (SignalId::new(), SignalId::new());

        tracker.start_tracking(noop_subscriber(SubscriberId::new()));
        tracker.track_read(outer_signal);
        tracker.start_tracking(noop_subscriber(SubscriberId::new()));
        tracker.track_read(inner_signal);
        tracker.stop_tracking();
        // Back in the outer scope: this read belongs to the outer subscriber.
        tracker.track_read(outer_signal);
        tracker.stop_tracking();

        assert_eq!(tracker.dependent_count(outer_signal), 1);
        assert_eq!(tracker.dependent_count(inner_signal), 1);
    }

    #[test]
    fn test_tracker_retracking_clears_old_deps() {
        let mut tracker = DependencyTracker::new();
        let (first, second) = (SignalId::new(), SignalId::new());
        let subscriber = noop_subscriber(SubscriberId::new());

        tracker.start_tracking(subscriber.clone());
        tracker.track_read(first);
        tracker.stop_tracking();
        assert_eq!(tracker.dependent_count(first), 1);
        assert_eq!(tracker.dependent_count(second), 0);

        // Re-running the same subscriber replaces its old dependency set.
        tracker.start_tracking(subscriber);
        tracker.track_read(second);
        tracker.stop_tracking();
        assert_eq!(tracker.dependent_count(first), 0);
        assert_eq!(tracker.dependent_count(second), 1);
    }

    #[test]
    fn test_tracker_dispose_subscriber() {
        let mut tracker = DependencyTracker::new();
        let signal = SignalId::new();
        let id = SubscriberId::new();

        tracker.start_tracking(noop_subscriber(id));
        tracker.track_read(signal);
        tracker.stop_tracking();
        assert_eq!(tracker.dependent_count(signal), 1);

        tracker.dispose_subscriber(id);
        assert_eq!(tracker.dependent_count(signal), 0);
    }

    #[test]
    #[serial]
    fn test_notify_depth_resets_after_normal_notify() {
        let signal = SignalId::new();
        let id = SubscriberId::new();
        let (subscriber, hits) = counting_subscriber(id);
        track_globally(subscriber, signal);

        // A second notification succeeding proves the depth counter was reset.
        for expected in 1..=2 {
            notify_dependents(signal);
            assert_eq!(hits.load(Ordering::SeqCst), expected);
        }

        dispose_subscriber(id);
    }

    #[test]
    #[serial]
    #[should_panic(expected = "Maximum reactive update depth")]
    fn test_notify_depth_guard_panics_on_circular_dependency() {
        let signal = SignalId::new();
        // Re-notifying the signal it depends on makes this subscriber its own dependent.
        let looping = Subscriber {
            id: SubscriberId::new(),
            callback: Arc::new(move || {
                notify_dependents(signal);
            }),
        };
        track_globally(looping, signal);

        notify_dependents(signal);
    }
}