use std::cell::RefCell;
use std::fmt;
use std::sync::atomic::{AtomicU64, Ordering};
/// Opaque identifier for a signal.
///
/// Fresh values come from `next_signal_id`; `synthetic_signal_id` derives
/// ids with the top bit set from subscriber ids.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct SignalId(u64);

impl fmt::Display for SignalId {
    /// Renders the id as `Signal(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        write!(f, "Signal({})", raw)
    }
}
/// Opaque identifier for a subscriber (the consumer of tracked signals).
///
/// Fresh values come from `next_subscriber_id`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct SubscriberId(u64);

impl fmt::Display for SubscriberId {
    /// Renders the id as `Subscriber(<n>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(raw) = self;
        write!(f, "Subscriber({})", raw)
    }
}
/// Source of [`SignalId`] values; the first id handed out is 1.
static SIGNAL_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Source of [`SubscriberId`] values; the first id handed out is 1.
static SUBSCRIBER_COUNTER: AtomicU64 = AtomicU64::new(1);

/// Returns a process-wide unique [`SignalId`] (unique until the `u64`
/// counter wraps).
///
/// Uniqueness relies only on the atomicity of `fetch_add`, so `Relaxed`
/// ordering suffices: no other memory is published through this counter.
pub fn next_signal_id() -> SignalId {
    let issued = SIGNAL_COUNTER.fetch_add(1, Ordering::Relaxed);
    SignalId(issued)
}

/// Returns a process-wide unique [`SubscriberId`] (unique until the `u64`
/// counter wraps).
///
/// As with `next_signal_id`, `Relaxed` ordering is enough because only the
/// atomic read-modify-write itself matters.
pub fn next_subscriber_id() -> SubscriberId {
    let issued = SUBSCRIBER_COUNTER.fetch_add(1, Ordering::Relaxed);
    SubscriberId(issued)
}
/// Per-thread state for an in-progress dependency-tracking pass.
///
/// Created by `start_tracking` and consumed by `stop_tracking`.
struct TrackingScope {
    /// Subscriber the scope was opened for. Stored but not read by any code
    /// visible in this file (hence the leading underscore).
    _subscriber_id: SubscriberId,
    /// Signals passed to `record_read` while this scope was open, in call
    /// order; repeated reads of the same signal are kept.
    dependencies: Vec<SignalId>,
}
thread_local! {
    /// The open tracking scope for the current thread, or `None` when no
    /// tracking pass is running. At most one scope exists per thread: a new
    /// `start_tracking` call replaces (and discards) any existing one.
    static TRACKING: RefCell<Option<TrackingScope>> = const { RefCell::new(None) };
}
/// Opens a dependency-tracking scope for `id` on the current thread.
///
/// Scopes do not nest: any scope already open on this thread is replaced and
/// the dependencies it had recorded are discarded.
pub fn start_tracking(id: SubscriberId) {
    let fresh = TrackingScope {
        _subscriber_id: id,
        dependencies: Vec::new(),
    };
    TRACKING.with(move |slot| {
        let _discarded = slot.replace(Some(fresh));
    });
}
pub fn stop_tracking() -> Vec<SignalId> {
TRACKING.with(|t| {
t.borrow_mut()
.take()
.map(|scope| scope.dependencies)
.unwrap_or_default()
})
}
pub fn record_read(signal_id: SignalId) {
TRACKING.with(|t| {
if let Some(scope) = t.borrow_mut().as_mut() {
scope.dependencies.push(signal_id);
}
});
}
/// Reports whether a tracking scope is currently open on this thread.
pub fn is_tracking() -> bool {
    TRACKING.with(|slot| {
        let current = slot.borrow();
        current.is_some()
    })
}
/// Derives a [`SignalId`] for a subscriber that is disjoint from ids issued
/// by `next_signal_id`.
///
/// The result is the subscriber's raw id with the top bit (`1 << 63`) set.
/// Counter-issued signal ids start at 1 and count upward, so the two id
/// spaces only meet if that counter ever reaches `2^63`. Distinct subscriber
/// ids below `2^63` map to distinct synthetic ids.
///
/// # Panics
///
/// In debug builds, panics if `sub_id` already has the top bit set. Such an
/// id would alias the synthetic id of the same subscriber id with that bit
/// cleared, silently merging two subscribers.
pub fn synthetic_signal_id(sub_id: SubscriberId) -> SignalId {
    // Marker bit that separates synthetic ids from counter-issued ones.
    const SYNTHETIC_BIT: u64 = 1 << 63;
    debug_assert!(
        sub_id.0 & SYNTHETIC_BIT == 0,
        "subscriber id {} already has the synthetic bit set",
        sub_id.0
    );
    SignalId(sub_id.0 | SYNTHETIC_BIT)
}
#[cfg(test)]
// Keeps test code exempt from the `clippy::unwrap_used` restriction lint.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    // NOTE: these tests assume each test starts with no open tracking scope on
    // its thread; libtest normally runs every test on its own thread.

    #[test]
    fn signal_id_uniqueness() {
        let a = next_signal_id();
        let b = next_signal_id();
        assert_ne!(a, b);
    }

    #[test]
    fn subscriber_id_uniqueness() {
        let a = next_subscriber_id();
        let b = next_subscriber_id();
        assert_ne!(a, b);
    }

    #[test]
    fn signal_id_display() {
        let id = SignalId(42);
        assert_eq!(format!("{id}"), "Signal(42)");
    }

    #[test]
    fn subscriber_id_display() {
        let id = SubscriberId(99);
        assert_eq!(format!("{id}"), "Subscriber(99)");
    }

    #[test]
    fn start_stop_tracking_roundtrip() {
        let sub = next_subscriber_id();
        start_tracking(sub);
        assert!(is_tracking());
        let deps = stop_tracking();
        assert!(!is_tracking());
        assert!(deps.is_empty());
    }

    #[test]
    fn record_read_outside_tracking_is_noop() {
        record_read(SignalId(1));
        assert!(!is_tracking());
        // The stray read must not leak into a scope opened afterwards.
        start_tracking(next_subscriber_id());
        let deps = stop_tracking();
        assert!(deps.is_empty());
    }

    #[test]
    fn record_read_inside_tracking() {
        let sub = next_subscriber_id();
        start_tracking(sub);
        let sig_a = SignalId(100);
        let sig_b = SignalId(200);
        record_read(sig_a);
        record_read(sig_b);
        let deps = stop_tracking();
        assert_eq!(deps.len(), 2);
        assert_eq!(deps[0], sig_a);
        assert_eq!(deps[1], sig_b);
    }

    #[test]
    fn nested_start_overwrites_previous() {
        let sub_a = next_subscriber_id();
        let sub_b = next_subscriber_id();
        start_tracking(sub_a);
        record_read(SignalId(1));
        start_tracking(sub_b);
        record_read(SignalId(2));
        let deps = stop_tracking();
        assert_eq!(deps.len(), 1);
        assert_eq!(deps[0], SignalId(2));
    }

    #[test]
    fn stop_tracking_when_not_active() {
        let deps = stop_tracking();
        assert!(deps.is_empty());
    }

    #[test]
    fn stop_tracking_consumes_scope() {
        start_tracking(next_subscriber_id());
        record_read(SignalId(7));
        assert_eq!(stop_tracking(), vec![SignalId(7)]);
        // A second stop sees no scope and must not replay the old reads.
        assert!(stop_tracking().is_empty());
    }

    #[test]
    fn synthetic_signal_id_sets_top_bit_and_keeps_low_bits() {
        let synthetic = synthetic_signal_id(SubscriberId(7));
        assert_eq!(synthetic, SignalId(7 | (1 << 63)));
    }

    #[test]
    fn synthetic_signal_id_is_distinct_per_subscriber() {
        let a = synthetic_signal_id(next_subscriber_id());
        let b = synthetic_signal_id(next_subscriber_id());
        assert_ne!(a, b);
    }

    #[test]
    fn synthetic_signal_id_differs_from_counter_ids() {
        let real = next_signal_id();
        let synthetic = synthetic_signal_id(SubscriberId(real.0));
        assert_ne!(real, synthetic);
    }
}