use std::{
marker::PhantomData,
panic::Location,
sync::{
Arc, Weak,
atomic::{AtomicBool, Ordering},
},
};
use dashmap::DashMap;
use uuid::Uuid;
use crate::{
cell::{Cell, CellImmutable, CellMutable, Subscriber, SubscriberCallback},
pipeline::{Empty, MaterializeEmpty, Pipeline, PipelineInstall},
signal::Signal,
subscription::SubscriptionGuard,
traits::{CellValue, DepNode, Gettable, Mutable, Watchable},
};
/// Shared state behind a [`Source`] and any [`WeakSource`] handles to it.
pub(crate) struct SourceInner<T> {
    /// Random id assigned in `Source::new`; returned by `DepNode::id` and used as
    /// the key for the inspector registry and the tracing layer.
    pub(crate) id: Uuid,
    /// Current subscribers as `(subscription id, subscriber)` pairs.
    ///
    /// The `Vec` is never mutated in place: `subscribe` and the unsubscribe closure
    /// build a new `Vec` and swap in a new `Arc` while holding the lock, whereas
    /// `emit`/`complete` clone the `Arc` and iterate that snapshot after the lock
    /// has been released.
    pub(crate) subscribers: parking_lot::Mutex<Arc<Vec<(Uuid, Arc<Subscriber<T>>)>>>,
    /// Guards handed to `Source::own`, each stored under a fresh random key.
    /// Nothing in this file removes entries, so they live as long as this state.
    pub(crate) owned: DashMap<Uuid, SubscriptionGuard>,
    /// Flipped to `true` by the first `Source::complete`; `emit` checks it and
    /// becomes a no-op afterwards.
    pub(crate) completed: AtomicBool,
    /// Call site of `Source::new`, captured through `#[track_caller]`.
    // NOTE(review): not read anywhere in this file (hence `dead_code`); presumably
    // kept for debugging/inspection — confirm before removing.
    #[allow(dead_code)]
    pub(crate) caller: &'static Location<'static>,
}
/// A push-based event source: values passed to `emit` are delivered to the
/// subscribers registered at that moment; nothing is replayed to later subscribers.
///
/// Cloning produces another strong handle to the same shared state.
pub struct Source<T> {
    /// Shared state; every clone of this handle points at the same allocation.
    pub(crate) inner: Arc<SourceInner<T>>,
    // Zero-sized marker for `T`. Note that `PhantomData<T>` also makes the
    // handle's auto traits (`Send`/`Sync`) depend on `T`'s.
    _marker: PhantomData<T>,
}
/// A non-owning handle to a [`Source`]; it does not keep the shared state alive.
/// Obtain one with `Source::downgrade` and turn it back with `upgrade`.
pub struct WeakSource<T> {
    /// Weak reference to the shared state of the originating `Source`.
    inner: Weak<SourceInner<T>>,
    // Zero-sized marker for `T`, mirroring `Source`.
    _marker: PhantomData<T>,
}
impl<T> WeakSource<T> {
    /// Returns a strong [`Source`] handle while the shared state is still alive,
    /// or `None` once every strong handle has been dropped.
    pub fn upgrade(&self) -> Option<Source<T>> {
        let inner = self.inner.upgrade()?;
        Some(Source {
            inner,
            _marker: PhantomData,
        })
    }
}
impl<T> Clone for WeakSource<T> {
fn clone(&self) -> Self {
WeakSource {
inner: self.inner.clone(),
_marker: PhantomData,
}
}
}
impl<T> Clone for Source<T> {
fn clone(&self) -> Self {
Source {
inner: Arc::clone(&self.inner),
_marker: PhantomData,
}
}
}
impl<T: CellValue> Source<T> {
    /// Creates a source that has no subscribers and has not completed.
    ///
    /// A fresh random id is assigned and the caller's location is recorded
    /// (propagated through `#[track_caller]`). With the `inspector` feature on
    /// non-wasm targets, the shared state is registered in the crate registry by
    /// weak reference. With the `trace` feature, it is registered with the
    /// tracing layer under its call site.
    #[track_caller]
    pub fn new() -> Self {
        let inner = Arc::new(SourceInner {
            id: Uuid::new_v4(),
            subscribers: parking_lot::Mutex::new(Arc::new(Vec::new())),
            owned: DashMap::new(),
            completed: AtomicBool::new(false),
            caller: Location::caller(),
        });
        #[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
        crate::registry::registry()
            .register(inner.id, Arc::downgrade(&inner) as Weak<dyn DepNode>);
        #[cfg(feature = "trace")]
        crate::tracing::register_cell(inner.id, Some(Location::caller().to_string()));
        Self {
            inner,
            _marker: PhantomData,
        }
    }

    /// Returns a [`WeakSource`] that refers to the same state without keeping it alive.
    pub fn downgrade(&self) -> WeakSource<T> {
        WeakSource {
            inner: Arc::downgrade(&self.inner),
            _marker: PhantomData,
        }
    }

    /// Stores `guard` inside this source so it stays alive as long as the source's
    /// shared state does. Nothing in this module removes stored guards.
    ///
    /// With the `inspector` feature, the registry is also informed via
    /// `mark_owned(guard's source id, this source's id)`.
    pub fn own(&self, guard: SubscriptionGuard) {
        #[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
        crate::registry::registry().mark_owned(guard.source().id(), self.inner.id);
        self.inner.owned.insert(Uuid::new_v4(), guard);
    }

    /// Returns `true` once [`Source::complete`] has been called.
    pub fn is_complete(&self) -> bool {
        self.inner.completed.load(Ordering::SeqCst)
    }

    /// Attaches a display name for tracing and returns `self` for chaining.
    ///
    /// The name is only forwarded to the tracing layer (feature `trace`); without
    /// that feature it is ignored.
    #[cfg_attr(not(feature = "trace"), allow(unused_variables))]
    pub fn with_name(self, name: impl Into<std::sync::Arc<str>>) -> Self {
        #[cfg(feature = "trace")]
        crate::tracing::update_name(self.inner.id, name.into().to_string());
        self
    }

    /// Registers `callback` to receive every future signal from this source.
    ///
    /// No past values are replayed. Dropping the returned guard removes the
    /// subscription. The guard holds strong references to this source, so it keeps
    /// the source alive.
    #[track_caller]
    pub fn subscribe(
        &self,
        callback: impl Fn(&Signal<T>) + Send + Sync + 'static,
    ) -> SubscriptionGuard {
        let id = Uuid::new_v4();
        let callback: SubscriberCallback<T> = Arc::new(callback);
        let subscriber = Arc::new(Subscriber { callback });

        // The list is never mutated in place. Build a new one (sized so the push
        // does not reallocate) and swap it in, so concurrent `emit`s keep iterating
        // their own snapshot.
        let previous = {
            let mut subs = self.inner.subscribers.lock();
            let mut next: Vec<(Uuid, Arc<Subscriber<T>>)> = Vec::with_capacity(subs.len() + 1);
            next.extend(subs.iter().cloned());
            next.push((id, subscriber));
            std::mem::replace(&mut *subs, Arc::new(next))
        };
        // Release the old snapshot only after the lock guard is gone.
        drop(previous);

        let source: Arc<dyn DepNode> = Arc::new(self.clone());
        let inner = Arc::clone(&self.inner);
        SubscriptionGuard::new(id, source, move || {
            let removed = {
                let mut subs = inner.subscribers.lock();
                // Skip the copy entirely if this subscription is already gone.
                if subs.iter().any(|(sub_id, _)| *sub_id == id) {
                    let next: Vec<(Uuid, Arc<Subscriber<T>>)> = subs
                        .iter()
                        .filter(|(sub_id, _)| *sub_id != id)
                        .cloned()
                        .collect();
                    Some(std::mem::replace(&mut *subs, Arc::new(next)))
                } else {
                    None
                }
            };
            // Dropping the old snapshot may drop the last reference to the removed
            // subscriber and everything its callback captured (possibly other guards
            // on this same source). The `parking_lot` mutex is not reentrant, so this
            // must happen after the lock above has been released.
            drop(removed);
        })
    }

    /// Sends `value` to every current subscriber as `Signal::Value`.
    ///
    /// This is a no-op once the source has completed. The value is wrapped in one
    /// `Arc` shared by all subscribers. Callbacks run on the calling thread against
    /// a snapshot of the subscriber list, after the lock has been released.
    /// Callbacks may therefore subscribe or unsubscribe on this source without
    /// deadlocking, but those changes apply only to later emissions.
    pub fn emit(&self, value: T) {
        if self.inner.completed.load(Ordering::SeqCst) {
            return;
        }
        let arc_value = Arc::new(value);
        let signal = Signal::Value(arc_value);
        // The lock guard is a temporary and is released at the end of this statement.
        let subs = Arc::clone(&*self.inner.subscribers.lock());
        for (_id, sub) in subs.iter() {
            (sub.callback)(&signal);
        }
    }

    /// Marks the source complete and sends `Signal::Complete` to every current
    /// subscriber.
    ///
    /// Only the first call has any effect, and later `emit` calls are ignored.
    /// Subscribers are not removed from the list by completion.
    pub fn complete(&self) {
        if self.inner.completed.swap(true, Ordering::SeqCst) {
            return;
        }
        let signal: Signal<T> = Signal::Complete;
        let subs = Arc::clone(&*self.inner.subscribers.lock());
        for (_id, sub) in subs.iter() {
            (sub.callback)(&signal);
        }
    }
}
impl<T: CellValue> Default for Source<T> {
    /// Equivalent to [`Source::new`].
    ///
    /// `#[track_caller]` propagates the call site into `Source::new`. The location
    /// recorded for diagnostics is therefore the caller of `default()`, not this
    /// impl.
    #[track_caller]
    fn default() -> Self {
        Self::new()
    }
}
impl<T: CellValue> PipelineInstall<T> for Source<T> {
    /// Connects a pipeline stage's callback as an ordinary subscriber of this source.
    fn install(&self, callback: Arc<dyn Fn(&Signal<T>) + Send + Sync>) -> SubscriptionGuard {
        // `Arc<dyn Fn>` does not itself implement `Fn`, so forward through a closure.
        let forward = move |signal: &Signal<T>| (*callback)(signal);
        self.subscribe(forward)
    }
}
/// A bare `Source` can be materialized as the empty (no-operator) pipeline.
impl<T> MaterializeEmpty<T> for Source<T> where T: CellValue {}
/// A bare `Source` is the starting point of a pipeline with no operators applied.
// NOTE(review): the lint allowance suggests `Pipeline` carries a crate-private
// bound — confirm in `crate::pipeline`.
#[allow(private_bounds)]
impl<T> Pipeline<T, Empty> for Source<T> where T: CellValue {}
/// Extension trait: re-read a readable, watchable value each time a [`Source`] fires.
pub trait SampleOnSourceExt<T: CellValue>: Watchable<T> + Gettable<T> {
    /// Returns a read-only cell that starts with the current value of `self` and
    /// re-reads `self` every time `notifier` emits `Signal::Value`.
    ///
    /// Other signals from `notifier` (for example `Complete`) are ignored. The
    /// subscription guard is handed to the cell via `own` before the cell is locked.
    #[track_caller]
    fn sample_on<U>(&self, notifier: &Source<U>) -> Cell<T, CellImmutable>
    where
        U: CellValue,
    {
        let sampled = self.clone();
        let output = Cell::<T, CellMutable>::new(sampled.get());
        let writer = output.clone();
        let subscription = notifier.subscribe(move |signal| {
            if matches!(signal, Signal::Value(_)) {
                writer.set(sampled.get());
            }
        });
        output.own(subscription);
        output.lock()
    }
}
/// Every type that is both watchable and gettable gets `sample_on` for free.
impl<T, W> SampleOnSourceExt<T> for W
where
    T: CellValue,
    W: Watchable<T> + Gettable<T>,
{
}
impl<T: Send + Sync> DepNode for Source<T> {
    /// The id assigned when the source was created.
    fn id(&self) -> Uuid {
        self.inner.id
    }

    /// Always `None`: a name given via `with_name` is only forwarded to the tracing
    /// layer and is not stored on the source.
    fn name(&self) -> Option<String> {
        None
    }

    /// Reports no upstream dependencies.
    fn deps(&self) -> Vec<Arc<dyn DepNode>> {
        Vec::new()
    }

    /// Number of currently registered subscriptions.
    fn subscriber_count(&self) -> usize {
        // Read the length through the guard; cloning the snapshot `Arc` first would
        // only add a pointless refcount increment/decrement.
        self.inner.subscribers.lock().len()
    }

    /// Number of guards stored via `Source::own`.
    fn owned_count(&self) -> usize {
        self.inner.owned.len()
    }
}
// Needed so the inspector registry can hold a `Weak<dyn DepNode>` that points
// directly at the shared state (see the registration in `Source::new`).
#[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
impl<T: Send + Sync> DepNode for SourceInner<T> {
    /// The id assigned when the source was created.
    fn id(&self) -> Uuid {
        self.id
    }

    /// Always `None`; names are not stored on the source state.
    fn name(&self) -> Option<String> {
        None
    }

    /// Reports no upstream dependencies.
    fn deps(&self) -> Vec<Arc<dyn DepNode>> {
        Vec::new()
    }

    /// Number of currently registered subscriptions.
    fn subscriber_count(&self) -> usize {
        // Read the length through the guard instead of cloning the snapshot `Arc`.
        self.subscribers.lock().len()
    }

    /// Number of guards stored via `Source::own`.
    fn owned_count(&self) -> usize {
        self.owned.len()
    }
}
impl<T> Drop for SourceInner<T> {
    /// Removes this source's id from the tracing layer (feature `trace`) and then
    /// from the inspector registry (feature `inspector`, non-wasm targets).
    fn drop(&mut self) {
        #[cfg(feature = "trace")]
        {
            crate::tracing::deregister_cell(&self.id);
        }
        #[cfg(all(feature = "inspector", not(target_arch = "wasm32")))]
        {
            crate::registry::registry().deregister(&self.id);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{
        Arc,
        atomic::{AtomicBool, AtomicU64, Ordering},
    };

    use super::*;

    /// Builds a fresh `u64` source plus a shared accumulator starting at zero.
    fn source_with_total() -> (Source<u64>, Arc<AtomicU64>) {
        (Source::new(), Arc::new(AtomicU64::new(0)))
    }

    #[test]
    fn emit_fires_subscribers() {
        let (src, total) = source_with_total();
        let sink = Arc::clone(&total);
        let _guard = src.subscribe(move |signal| {
            if let Signal::Value(v) = signal {
                sink.fetch_add(**v, Ordering::SeqCst);
            }
        });
        for value in 1..=3 {
            src.emit(value);
        }
        assert_eq!(total.load(Ordering::SeqCst), 6);
    }

    #[test]
    fn no_initial_emission_on_subscribe() {
        let src: Source<u64> = Source::new();
        src.emit(42);
        let seen = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&seen);
        let _guard = src.subscribe(move |signal| {
            if matches!(signal, Signal::Value(_)) {
                flag.store(true, Ordering::SeqCst);
            }
        });
        assert!(
            !seen.load(Ordering::SeqCst),
            "subscribing must not replay earlier values"
        );
        src.emit(7);
        assert!(seen.load(Ordering::SeqCst));
    }

    #[test]
    fn unsubscribe_via_guard_drop() {
        let (src, total) = source_with_total();
        let sink = Arc::clone(&total);
        let guard = src.subscribe(move |signal| {
            if let Signal::Value(v) = signal {
                sink.fetch_add(**v, Ordering::SeqCst);
            }
        });
        src.emit(5);
        drop(guard);
        src.emit(10);
        assert_eq!(total.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn complete_blocks_further_emits() {
        let (src, total) = source_with_total();
        let sink = Arc::clone(&total);
        let _guard = src.subscribe(move |signal| {
            let delta = match signal {
                Signal::Value(v) => **v,
                Signal::Complete => 1000,
                _ => 0,
            };
            sink.fetch_add(delta, Ordering::SeqCst);
        });
        src.emit(1);
        src.complete();
        src.emit(99);
        assert_eq!(total.load(Ordering::SeqCst), 1 + 1000);
    }
}