use std::{
fmt::Debug,
marker::PhantomData,
panic::{AssertUnwindSafe, Location, catch_unwind},
sync::{
Arc, Weak,
atomic::{AtomicBool, Ordering},
},
time::Duration,
};
use arc_swap::ArcSwap;
use dashmap::DashMap;
use uuid::Uuid;
use crate::{
metrics::CellMetrics,
signal::Signal,
subscription::SubscriptionGuard,
traits::{CellValue, DepNode, Gettable, Mutable, Watchable},
};
/// Report passed to a slow-subscriber hook (see `Cell::on_slow_subscriber`) when one
/// subscriber callback runs longer than the configured threshold during `notify`.
#[derive(Clone, Debug)]
pub struct SlowSubscriberAlert {
    /// Id under which the offending subscriber was registered by `subscribe`.
    pub subscriber_id: Uuid,
    /// Measured duration of that subscriber's callback, in nanoseconds.
    pub duration_ns: u64,
    /// Threshold that was exceeded, in nanoseconds.
    pub threshold_ns: u64,
}

/// Shareable, thread-safe slow-subscriber hook.
type SlowSubscriberCallback = Arc<dyn Fn(SlowSubscriberAlert) + Send + Sync>;
/// Type-level marker selecting the writable `Cell` API (`set`, `complete`, `fail`, ...).
#[derive(Clone, Debug)]
pub struct CellMutable;

/// Type-level marker for a read-only `Cell`, produced by `Cell::lock`.
#[derive(Clone, Debug)]
pub struct CellImmutable;
/// Shared state behind every `Cell` handle; `WeakCell` holds a `Weak` to it.
///
/// NOTE(review): Rust drops struct fields in declaration order, so reordering these
/// fields changes when `subscribers` are released relative to the `owned` guards.
pub(crate) struct CellInner<T> {
/// Unique identity of this cell; used as the key for the inspector registry and trace backend.
pub(crate) id: Uuid,
/// Registered subscriber callbacks, keyed by the id generated in `subscribe`.
pub(crate) subscribers: DashMap<Uuid, Box<Subscriber<T>>>,
/// Current value; replaced atomically whenever `notify` receives `Signal::Value`.
pub(crate) value: ArcSwap<T>,
/// Optional human-readable name set via `with_name`.
pub(crate) name: ArcSwap<Option<Arc<str>>>,
/// Subscription guards kept alive by this cell (`own` uses random keys, `own_keyed` caller keys).
pub(crate) owned: DashMap<Uuid, SubscriptionGuard>,
/// Set when `Signal::Complete` is processed; `notify` ignores all later signals.
pub(crate) completed: AtomicBool,
/// Set when `Signal::Error` is processed; `notify` ignores all later signals.
pub(crate) errored: AtomicBool,
/// The error carried by the processed `Signal::Error`, if any.
pub(crate) error: ArcSwap<Option<Arc<anyhow::Error>>>,
/// Optional metrics collector; when `None`, nothing is recorded and `is_backed_up*` report `false`.
pub(crate) metrics: Option<Arc<CellMetrics>>,
/// Slow-subscriber threshold in nanoseconds; `None` until `on_slow_subscriber` is called.
pub(crate) slow_subscriber_threshold_ns: ArcSwap<Option<u64>>,
/// Hook receiving `SlowSubscriberAlert`s; `None` until `on_slow_subscriber` is called.
pub(crate) slow_subscriber_callback: ArcSwap<Option<SlowSubscriberCallback>>,
// Source location of the constructor call. Only read by the cfg-gated inspector
// `DepNode` impl, so it would be flagged as dead code when that feature is off.
#[allow(dead_code)]
pub(crate) caller: &'static Location<'static>,
}
/// Reactive value container.
///
/// Cloning a `Cell` is cheap: every clone is another strong handle to the same
/// shared state. The zero-sized marker `M` (`CellMutable` or `CellImmutable`)
/// decides which API is available on the handle.
pub struct Cell<T, M> {
    /// Shared state; all clones point at the same allocation.
    pub(crate) inner: Arc<CellInner<T>>,
    /// Mutability marker; carries no data.
    _marker: PhantomData<M>,
}
/// Non-owning handle to a `Cell`'s shared state.
///
/// It does not keep the state alive; call `upgrade` to obtain a strong `Cell`
/// while the state still exists.
pub struct WeakCell<T, M> {
    /// Weak reference to the shared state.
    inner: Weak<CellInner<T>>,
    /// Mutability marker carried over from the originating `Cell`.
    _marker: PhantomData<M>,
}
impl<T, M> WeakCell<T, M> {
    /// Tries to turn this weak handle back into a strong `Cell`.
    ///
    /// Returns `None` once every strong handle is gone and the shared state has been freed.
    pub fn upgrade(&self) -> Option<Cell<T, M>> {
        let strong = self.inner.upgrade()?;
        Some(Cell {
            inner: strong,
            _marker: PhantomData,
        })
    }
}
impl<T, M> Clone for WeakCell<T, M> {
fn clone(&self) -> Self {
WeakCell {
inner: self.inner.clone(),
_marker: PhantomData,
}
}
}
/// Type-erased subscriber callback.
///
/// Reference-counted so `notify` can clone the callbacks into a local list
/// before invoking them.
pub(crate) type SubscriberCallback<T> = Arc<dyn Fn(&Signal<T>) + Send + Sync>;

/// A single registered subscriber of a cell.
pub(crate) struct Subscriber<T> {
    /// Function invoked with every signal delivered to this subscriber.
    pub(crate) callback: SubscriberCallback<T>,
}

impl<T> Subscriber<T> {
    /// Boxes `callback` behind an `Arc` so it can live in a cell's subscriber map.
    pub(crate) fn new(callback: impl Fn(&Signal<T>) + Send + Sync + 'static) -> Self {
        let callback: SubscriberCallback<T> = Arc::new(callback);
        Self { callback }
    }
}
impl<T: CellValue> Cell<T, CellMutable> {
#[track_caller]
pub fn new(initial_value: T) -> Self {
let inner = Arc::new(CellInner {
id: Uuid::new_v4(),
subscribers: DashMap::new(),
value: ArcSwap::from_pointee(initial_value),
name: ArcSwap::from_pointee(None),
owned: DashMap::new(),
completed: AtomicBool::new(false),
errored: AtomicBool::new(false),
error: ArcSwap::from_pointee(None),
metrics: default_metrics(),
slow_subscriber_threshold_ns: ArcSwap::from_pointee(None),
slow_subscriber_callback: ArcSwap::from_pointee(None),
caller: Location::caller(),
});
#[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
crate::registry::registry().register(inner.id, Arc::downgrade(&inner) as Weak<dyn DepNode>);
#[cfg(feature = "trace")]
crate::tracing::register_cell(inner.id, Some(Location::caller().to_string()));
Self {
inner,
_marker: PhantomData,
}
}
#[track_caller]
pub fn with_metrics(initial_value: T) -> Self {
let inner = Arc::new(CellInner {
id: Uuid::new_v4(),
subscribers: DashMap::new(),
value: ArcSwap::from_pointee(initial_value),
name: ArcSwap::from_pointee(None),
owned: DashMap::new(),
completed: AtomicBool::new(false),
errored: AtomicBool::new(false),
error: ArcSwap::from_pointee(None),
metrics: Some(Arc::new(CellMetrics::new())),
slow_subscriber_threshold_ns: ArcSwap::from_pointee(None),
slow_subscriber_callback: ArcSwap::from_pointee(None),
caller: Location::caller(),
});
#[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
crate::registry::registry().register(inner.id, Arc::downgrade(&inner) as Weak<dyn DepNode>);
#[cfg(feature = "trace")]
crate::tracing::register_cell(inner.id, Some(Location::caller().to_string()));
Self {
inner,
_marker: PhantomData,
}
}
pub fn on_slow_subscriber<F>(&self, threshold: Duration, callback: F)
where
F: Fn(SlowSubscriberAlert) + Send + Sync + 'static,
{
self.inner
.slow_subscriber_threshold_ns
.store(Arc::new(Some(threshold.as_nanos() as u64)));
self.inner
.slow_subscriber_callback
.store(Arc::new(Some(Arc::new(callback))));
}
pub fn lock(self) -> Cell<T, CellImmutable> {
Cell {
inner: self.inner,
_marker: PhantomData,
}
}
pub fn with_name(self, name: impl Into<Arc<str>>) -> Self {
let name = name.into();
self.inner.name.store(Arc::new(Some(name.clone())));
#[cfg(feature = "trace")]
crate::tracing::update_name(self.inner.id, name.to_string());
self
}
pub fn is_backed_up(&self) -> bool {
self.is_backed_up_threshold(std::time::Duration::from_millis(1))
}
pub fn is_backed_up_threshold(&self, threshold: std::time::Duration) -> bool {
self.inner
.metrics
.as_ref()
.map(|m| m.last_notify_time_ns() > threshold.as_nanos() as u64)
.unwrap_or(false)
}
pub fn try_set(&self, value: T) -> Result<(), T> {
if self.is_backed_up() {
Err(value)
} else {
self.set(value);
Ok(())
}
}
pub fn try_set_threshold(&self, value: T, threshold: std::time::Duration) -> Result<(), T> {
if self.is_backed_up_threshold(threshold) {
Err(value)
} else {
self.set(value);
Ok(())
}
}
}
impl<T, M> Clone for Cell<T, M> {
fn clone(&self) -> Self {
Cell {
inner: Arc::clone(&self.inner),
_marker: PhantomData,
}
}
}
impl<T, M> Cell<T, M> {
    /// Returns a `WeakCell` pointing at the same shared state without keeping it alive.
    pub fn downgrade(&self) -> WeakCell<T, M> {
        let weak = Arc::downgrade(&self.inner);
        WeakCell {
            inner: weak,
            _marker: PhantomData,
        }
    }

    /// Borrows this cell's metrics collector, or `None` if it was created without one.
    pub fn metrics(&self) -> Option<&CellMetrics> {
        self.inner.metrics.as_deref()
    }

    /// Stores `guard` in this cell's owned set under a fresh random key.
    ///
    /// The guard is dropped together with the cell's shared state.
    pub fn own(&self, guard: SubscriptionGuard) {
        #[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
        crate::registry::registry().mark_owned(guard.source().id(), self.inner.id);
        let slot = Uuid::new_v4();
        self.inner.owned.insert(slot, guard);
        #[cfg(feature = "trace")]
        crate::tracing::update_owned_count(self.inner.id, self.inner.owned.len());
    }

    /// Stores `guard` under `key`, dropping any guard previously held under that key.
    ///
    /// Dropping a guard presumably ends its subscription; see `SubscriptionGuard`.
    pub fn own_keyed(&self, key: Uuid, guard: SubscriptionGuard) {
        #[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
        {
            let previous = self.inner.owned.remove(&key);
            if let Some((_, old_guard)) = previous {
                crate::registry::registry().unmark_owned(old_guard.source().id());
            }
            crate::registry::registry().mark_owned(guard.source().id(), self.inner.id);
        }
        self.inner.owned.insert(key, guard);
        #[cfg(feature = "trace")]
        crate::tracing::update_owned_count(self.inner.id, self.inner.owned.len());
    }
}
impl<T: Send + Sync, M: Send + Sync> DepNode for Cell<T, M> {
    /// Identity of the shared cell state.
    fn id(&self) -> Uuid {
        self.inner.id
    }

    /// Current name, if one was set via `with_name`.
    fn name(&self) -> Option<String> {
        (**self.inner.name.load()).as_deref().map(str::to_owned)
    }

    /// Sources of the guards this cell owns, de-duplicated by id (first occurrence wins).
    fn deps(&self) -> Vec<Arc<dyn DepNode>> {
        let mut visited = std::collections::HashSet::new();
        let mut unique = Vec::new();
        for entry in self.inner.owned.iter() {
            let source = entry.value().source();
            if visited.insert(source.id()) {
                unique.push(Arc::clone(source));
            }
        }
        unique
    }

    /// Number of currently registered subscribers.
    fn subscriber_count(&self) -> usize {
        self.inner.subscribers.len()
    }

    /// Number of subscription guards this cell owns.
    fn owned_count(&self) -> usize {
        self.inner.owned.len()
    }
}
impl<T: CellValue> Cell<T, CellImmutable> {
    /// Names this read-only cell and returns it for chaining.
    ///
    /// With the `trace` feature, the name is also forwarded to the tracing backend.
    pub fn with_name(self, name: impl Into<Arc<str>>) -> Self {
        let label: Arc<str> = name.into();
        let stored = Some(Arc::clone(&label));
        self.inner.name.store(Arc::new(stored));
        #[cfg(feature = "trace")]
        crate::tracing::update_name(self.inner.id, label.to_string());
        self
    }
}
impl<T: CellValue, M: Send + Sync + 'static> Cell<T, M> {
    /// Applies `signal` to the cell's state and delivers it to every current subscriber.
    ///
    /// * `Signal::Value` replaces the stored value.
    /// * `Signal::Complete` / `Signal::Error` mark the cell terminal. Once either has
    ///   been processed, every later signal is ignored.
    ///
    /// Callbacks are snapshotted from the subscriber map first and invoked afterwards,
    /// so no map shard lock is held while they run. With `panic = "unwind"`, a panicking
    /// subscriber or slow-subscriber hook is caught and the remaining subscribers still run.
    ///
    /// Each subscriber callback is timed whenever that measurement is consumed: by the
    /// cell's metrics, or by a slow-subscriber hook installed with `on_slow_subscriber`.
    /// The hook therefore fires even on cells created without metrics.
    #[doc(hidden)]
    pub fn notify(&self, signal: Signal<T>) {
        let inner = &self.inner;
        // Terminal cells ignore every further signal.
        if inner.completed.load(Ordering::SeqCst) || inner.errored.load(Ordering::SeqCst) {
            return;
        }
        let metrics = inner.metrics.as_ref();
        let notify_start = metrics.map(|_| std::time::Instant::now());

        match &signal {
            Signal::Value(arc_value) => {
                inner.value.store(arc_value.clone());
            }
            Signal::Complete => {
                inner.completed.store(true, Ordering::SeqCst);
            }
            Signal::Error(err) => {
                inner.errored.store(true, Ordering::SeqCst);
                inner.error.store(Arc::new(Some(err.clone())));
            }
        }

        // Snapshot first so callbacks can (un)subscribe without contending on map locks.
        let callbacks: Vec<_> = inner
            .subscribers
            .iter()
            .map(|entry| (*entry.key(), Arc::clone(&entry.value().callback)))
            .collect();

        // The slow-subscriber hook is only usable once both threshold and callback are set.
        let slow_hook = (**inner.slow_subscriber_threshold_ns.load())
            .zip((**inner.slow_subscriber_callback.load()).clone());
        let time_subscribers = metrics.is_some() || slow_hook.is_some();

        for (subscriber_id, callback) in &callbacks {
            let sub_start = time_subscribers.then(std::time::Instant::now);
            #[cfg(panic = "unwind")]
            let _ = catch_unwind(AssertUnwindSafe(|| {
                callback(&signal);
            }));
            #[cfg(panic = "abort")]
            callback(&signal);

            let Some(start) = sub_start else {
                continue;
            };
            // Saturate rather than truncate the u128 nanosecond count.
            let elapsed = u64::try_from(start.elapsed().as_nanos()).unwrap_or(u64::MAX);
            if let Some(m) = metrics {
                m.update_slowest_subscriber(elapsed);
            }
            if let Some((threshold, hook)) = &slow_hook
                && elapsed > *threshold
            {
                let alert = SlowSubscriberAlert {
                    subscriber_id: *subscriber_id,
                    duration_ns: elapsed,
                    threshold_ns: *threshold,
                };
                #[cfg(panic = "unwind")]
                let _ = catch_unwind(AssertUnwindSafe(|| {
                    hook(alert);
                }));
                #[cfg(panic = "abort")]
                hook(alert);
            }
        }

        if let (Some(m), Some(start)) = (metrics, notify_start) {
            let duration_ns = u64::try_from(start.elapsed().as_nanos()).unwrap_or(u64::MAX);
            m.record_notify(duration_ns);
            #[cfg(feature = "trace")]
            crate::tracing::record_notify(
                inner.id,
                duration_ns,
                inner.subscribers.len(),
                inner.owned.len(),
                m.slowest_subscriber_ns(),
            );
        }
    }
}
impl<T: CellValue, U: Send + Sync + 'static> Gettable<T> for Cell<T, U> {
    /// Returns a clone of the value most recently stored in the cell.
    fn get(&self) -> T {
        let guard = self.inner.value.load();
        let current: &T = &guard;
        current.clone()
    }
}
impl<T: CellValue, U: Send + Sync + 'static> Watchable<T> for Cell<T, U> {
fn subscribe(
&self,
callback: impl Fn(&Signal<T>) + Send + Sync + 'static,
) -> SubscriptionGuard {
callback(&Signal::Value(self.inner.value.load_full()));
if self.is_complete() {
callback(&Signal::Complete);
} else if self.is_error()
&& let Some(err) = self.error()
{
callback(&Signal::Error(err));
}
let id = Uuid::new_v4();
self.inner
.subscribers
.insert(id, Box::new(Subscriber::new(callback)));
if let Some(metrics) = &self.inner.metrics {
metrics.record_subscriber_added();
}
#[cfg(feature = "trace")]
crate::tracing::update_subscriber_count(self.inner.id, self.inner.subscribers.len());
let source: Arc<dyn DepNode> = Arc::new(self.clone());
let cell = self.clone();
let metrics = self.inner.metrics.clone();
SubscriptionGuard::new(id, source, move || {
cell.inner.subscribers.remove(&id);
if let Some(m) = &metrics {
m.record_subscriber_removed();
}
#[cfg(feature = "trace")]
crate::tracing::update_subscriber_count(cell.inner.id, cell.inner.subscribers.len());
})
}
fn unsubscribe(&self, id: Uuid) {
if self.inner.subscribers.remove(&id).is_some() {
if let Some(metrics) = &self.inner.metrics {
metrics.record_subscriber_removed();
}
#[cfg(feature = "trace")]
crate::tracing::update_subscriber_count(self.inner.id, self.inner.subscribers.len());
}
}
fn is_complete(&self) -> bool {
self.inner.completed.load(Ordering::SeqCst)
}
fn is_error(&self) -> bool {
self.inner.errored.load(Ordering::SeqCst)
}
fn error(&self) -> Option<Arc<anyhow::Error>> {
(**self.inner.error.load()).clone()
}
}
impl<T: CellValue> Mutable<T> for Cell<T, CellMutable> {
fn set(&self, value: T) {
self.notify(Signal::value(value)); }
fn complete(&self) {
self.notify(Signal::Complete);
}
fn fail(&self, error: impl Into<anyhow::Error>) {
self.notify(Signal::error(error));
}
}
#[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
impl<T: CellValue> DepNode for CellInner<T> {
    /// Identity of this cell state.
    fn id(&self) -> Uuid {
        self.id
    }

    /// Current name, if one was set.
    fn name(&self) -> Option<String> {
        (**self.name.load()).as_deref().map(str::to_owned)
    }

    /// Sources of the owned guards, de-duplicated by id (first occurrence wins).
    fn deps(&self) -> Vec<Arc<dyn DepNode>> {
        let mut visited = std::collections::HashSet::new();
        let mut unique = Vec::new();
        for entry in self.owned.iter() {
            let source = entry.value().source();
            if visited.insert(source.id()) {
                unique.push(Arc::clone(source));
            }
        }
        unique
    }

    /// Number of registered subscribers.
    fn subscriber_count(&self) -> usize {
        self.subscribers.len()
    }

    /// Number of owned subscription guards.
    fn owned_count(&self) -> usize {
        self.owned.len()
    }

    /// `Debug` rendering of the current value.
    fn value_debug(&self) -> Option<String> {
        let current = self.value.load();
        Some(format!("{:?}", &**current))
    }

    /// Source location where the cell was constructed.
    fn caller(&self) -> Option<&'static Location<'static>> {
        Some(self.caller)
    }
}
impl<T> Drop for CellInner<T> {
    /// Deregisters this cell from the trace backend and the inspector registry,
    /// when those features are compiled in.
    fn drop(&mut self) {
        #[cfg(feature = "trace")]
        {
            crate::tracing::deregister_cell(&self.id);
        }
        #[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
        {
            crate::registry::registry().deregister(&self.id);
        }
    }
}
#[cfg(feature = "trace")]
fn default_metrics() -> Option<Arc<CellMetrics>> {
Some(Arc::new(CellMetrics::new()))
}
#[cfg(not(feature = "trace"))]
fn default_metrics() -> Option<Arc<CellMetrics>> {
None
}