#![allow(clippy::type_complexity)]
use slotmap::{new_key_type, SlotMap};
use smallvec::SmallVec;
use std::any::Any;
use std::cell::{Cell, RefCell};
use std::collections::VecDeque;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
new_key_type! {
pub struct SignalId;
pub struct DerivedId;
pub struct EffectId;
}
/// A node that is notified when a signal it read during its last computation
/// is written.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum SubscriberId {
    /// A derived value; marking it dirty forces a recompute on the next read.
    Derived(DerivedId),
    /// An effect; marking it dirty queues it for the next flush.
    Effect(EffectId),
}
/// Typed, copyable handle to a signal stored in a [`ReactiveGraph`].
///
/// The handle carries only the signal's [`SignalId`]; the value itself lives
/// in the graph. `T` records the value type so reads can downcast to it.
pub struct Signal<T> {
    id: SignalId,
    _marker: std::marker::PhantomData<T>,
}

// Clone/Copy/Debug are written by hand so that every `Signal<T>` gets them:
// `#[derive]` would add `T: Clone` / `T: Debug` bounds that a handle which
// never stores a `T` does not need.
impl<T> Clone for Signal<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Signal<T> {}

impl<T> std::fmt::Debug for Signal<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field layout the previous `#[derive(Debug)]` produced.
        f.debug_struct("Signal")
            .field("id", &self.id)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T> Signal<T> {
    /// Returns the untyped key of this signal.
    pub fn id(&self) -> SignalId {
        self.id
    }

    /// Wraps an existing key as a typed handle.
    ///
    /// No type check happens here; reads through a handle whose `T` does not
    /// match the stored value return `None`.
    pub fn from_id(id: SignalId) -> Self {
        Signal {
            id,
            _marker: std::marker::PhantomData,
        }
    }
}
impl SignalId {
    /// Returns this key's `u64` encoding, as produced by slotmap's
    /// `KeyData::as_ffi`.
    pub fn to_raw(&self) -> u64 {
        let data = slotmap::Key::data(self);
        data.as_ffi()
    }

    /// Rebuilds a key from a value returned by [`SignalId::to_raw`].
    pub fn from_raw(raw: u64) -> Self {
        let data = slotmap::KeyData::from_ffi(raw);
        Self::from(data)
    }
}
/// Typed, copyable handle to a derived (memoized) value in a [`ReactiveGraph`].
///
/// Read it with [`ReactiveGraph::get_derived`]; `T` is the type returned by
/// the compute closure.
pub struct Derived<T> {
    id: DerivedId,
    _marker: std::marker::PhantomData<T>,
}

// Clone/Copy/Debug are written by hand so that every `Derived<T>` gets them:
// `#[derive]` would add `T: Clone` / `T: Debug` bounds that a handle which
// never stores a `T` does not need.
impl<T> Clone for Derived<T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for Derived<T> {}

impl<T> std::fmt::Debug for Derived<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same field layout the previous `#[derive(Debug)]` produced.
        f.debug_struct("Derived")
            .field("id", &self.id)
            .field("_marker", &self._marker)
            .finish()
    }
}

impl<T> Derived<T> {
    /// Returns the untyped key of this derived value.
    pub fn id(&self) -> DerivedId {
        self.id
    }
}
/// Copyable handle to an effect registered with [`ReactiveGraph::create_effect`].
#[derive(Clone, Copy, Debug)]
pub struct Effect {
    id: EffectId,
}

impl Effect {
    /// Returns the untyped key of this effect.
    pub fn id(&self) -> EffectId {
        let Effect { id } = *self;
        id
    }
}
/// Graph storage for one signal.
struct SignalNode {
    /// Nodes that read this signal during their last computation.
    subscribers: SmallVec<[SubscriberId; 4]>,
    /// Number of writes since creation.
    version: u64,
    /// Current value, type-erased; reads downcast it back to the handle's `T`.
    value: Box<dyn Any + Send>,
}
/// Graph storage for one memoized value computed from signals.
struct DerivedNode {
    /// Produces a fresh, type-erased value from the graph.
    compute: Box<dyn Fn(&ReactiveGraph) -> Box<dyn Any + Send> + Send>,
    /// Last computed value; `None` until the node is first read.
    value: Option<Box<dyn Any + Send>>,
    /// `true` when `value` must be recomputed before it can be served.
    dirty: Cell<bool>,
    /// Signals read during the last computation.
    dependencies: SmallVec<[SignalId; 4]>,
    /// Nodes to invalidate when this value becomes dirty.
    subscribers: SmallVec<[SubscriberId; 4]>,
    /// Graph `global_version` seen at the last computation (written but not
    /// read anywhere in this file).
    cached_version: u64,
    /// Depth above the signal layer (written but not read anywhere in this file).
    depth: u32,
}
/// Graph storage for one side-effecting closure.
struct EffectNode {
    /// Signals read during the last run.
    dependencies: SmallVec<[SignalId; 4]>,
    /// `true` while a run is owed.
    dirty: Cell<bool>,
    /// Sort key used when flushing (lower first); effects are created at 0.
    depth: u32,
    /// Effect body; receives the graph so it can read (and thereby track) signals.
    run: Box<dyn FnMut(&ReactiveGraph) + Send>,
}
/// Owner of all signals, derived values and effects, plus the bookkeeping that
/// links them.
///
/// Dependency tracking and dirty-marking go through `Cell`/`RefCell`, so reads
/// via `&self` can record dependencies. Because of those cells the graph is not
/// `Sync`; share it behind a lock such as [`SharedReactiveGraph`].
pub struct ReactiveGraph {
    /// Signals read so far by the node currently being computed, if any.
    tracking: RefCell<Option<Vec<SignalId>>>,
    /// Effects marked dirty and waiting for the next flush.
    pending_effects: RefCell<VecDeque<EffectId>>,
    /// Number of currently open batches; effects flush only when it is 0.
    batch_depth: Cell<u32>,
    /// Incremented on every signal write.
    global_version: Cell<u64>,
    signals: SlotMap<SignalId, SignalNode>,
    derived: SlotMap<DerivedId, DerivedNode>,
    effects: SlotMap<EffectId, EffectNode>,
}
impl ReactiveGraph {
pub fn new() -> Self {
Self {
signals: SlotMap::with_key(),
derived: SlotMap::with_key(),
effects: SlotMap::with_key(),
pending_effects: RefCell::new(VecDeque::new()),
batch_depth: Cell::new(0),
tracking: RefCell::new(None),
global_version: Cell::new(0),
}
}
pub fn create_signal<T: Send + 'static>(&mut self, initial: T) -> Signal<T> {
let id = self.signals.insert(SignalNode {
value: Box::new(initial),
version: 0,
subscribers: SmallVec::new(),
});
Signal {
id,
_marker: std::marker::PhantomData,
}
}
pub fn get<T: Clone + 'static>(&self, signal: Signal<T>) -> Option<T> {
if let Some(ref mut deps) = *self.tracking.borrow_mut() {
if !deps.contains(&signal.id) {
deps.push(signal.id);
}
}
self.signals
.get(signal.id)
.and_then(|node| node.value.downcast_ref::<T>().cloned())
}
pub fn get_untracked<T: Clone + 'static>(&self, signal: Signal<T>) -> Option<T> {
self.signals
.get(signal.id)
.and_then(|node| node.value.downcast_ref::<T>().cloned())
}
pub fn set<T: Send + 'static>(&mut self, signal: Signal<T>, value: T) {
if let Some(node) = self.signals.get_mut(signal.id) {
node.value = Box::new(value);
node.version += 1;
self.global_version.set(self.global_version.get() + 1);
let subscribers: SmallVec<[SubscriberId; 4]> = node.subscribers.clone();
for sub in subscribers {
self.mark_dirty(sub);
}
if self.batch_depth.get() == 0 {
self.flush_effects();
}
}
}
pub fn update<T: Clone + Send + 'static, F: FnOnce(T) -> T>(
&mut self,
signal: Signal<T>,
f: F,
) {
if let Some(current) = self.get_untracked(signal) {
self.set(signal, f(current));
}
}
pub fn signal_version(&self, id: SignalId) -> Option<u64> {
self.signals.get(id).map(|n| n.version)
}
pub fn create_derived<T, F>(&mut self, compute: F) -> Derived<T>
where
T: Clone + Send + 'static,
F: Fn(&ReactiveGraph) -> T + Send + 'static,
{
let compute_boxed =
move |graph: &ReactiveGraph| -> Box<dyn Any + Send> { Box::new(compute(graph)) };
let id = self.derived.insert(DerivedNode {
value: None,
cached_version: 0,
compute: Box::new(compute_boxed),
dependencies: SmallVec::new(),
subscribers: SmallVec::new(),
dirty: Cell::new(true), depth: 0,
});
Derived {
id,
_marker: std::marker::PhantomData,
}
}
pub fn get_derived<T: Clone + 'static>(&mut self, derived: Derived<T>) -> Option<T> {
let node = self.derived.get(derived.id)?;
if !node.dirty.get() {
if let Some(ref cached) = node.value {
return cached.downcast_ref::<T>().cloned();
}
}
self.tracking.replace(Some(Vec::new()));
let compute: *const Box<dyn Fn(&ReactiveGraph) -> Box<dyn Any + Send> + Send> = {
let node = self.derived.get(derived.id)?;
node.dirty.set(false);
&node.compute as *const _
};
let value = unsafe { (*compute)(self) };
let deps = self.tracking.take().unwrap_or_default();
if let Some(node) = self.derived.get_mut(derived.id) {
for &dep_id in &node.dependencies {
if let Some(sig) = self.signals.get_mut(dep_id) {
sig.subscribers
.retain(|s| *s != SubscriberId::Derived(derived.id));
}
}
for &dep_id in &deps {
if let Some(sig) = self.signals.get_mut(dep_id) {
let sub = SubscriberId::Derived(derived.id);
if !sig.subscribers.contains(&sub) {
sig.subscribers.push(sub);
}
}
}
let max_dep_depth = deps
.iter()
.filter_map(|&id| self.signals.get(id))
.map(|_| 0u32) .max()
.unwrap_or(0);
node.dependencies = deps.into_iter().collect();
node.depth = max_dep_depth + 1;
node.cached_version = self.global_version.get();
let result = value.downcast_ref::<T>().cloned();
node.value = Some(value);
result
} else {
None
}
}
pub fn create_effect<F>(&mut self, run: F) -> Effect
where
F: FnMut(&ReactiveGraph) + Send + 'static,
{
let id = self.effects.insert(EffectNode {
run: Box::new(run),
dependencies: SmallVec::new(),
dirty: Cell::new(true), depth: 0,
});
self.pending_effects.borrow_mut().push_back(id);
if self.batch_depth.get() == 0 {
self.flush_effects();
}
Effect { id }
}
pub fn dispose_effect(&mut self, effect: Effect) {
if let Some(node) = self.effects.remove(effect.id) {
for &dep_id in &node.dependencies {
if let Some(sig) = self.signals.get_mut(dep_id) {
sig.subscribers
.retain(|s| *s != SubscriberId::Effect(effect.id));
}
}
}
}
pub fn batch_start(&self) {
self.batch_depth.set(self.batch_depth.get() + 1);
}
pub fn batch_end(&mut self) {
let depth = self.batch_depth.get();
if depth > 0 {
self.batch_depth.set(depth - 1);
if depth == 1 {
self.flush_effects();
}
}
}
pub fn batch<F, R>(&mut self, f: F) -> R
where
F: FnOnce(&mut Self) -> R,
{
self.batch_start();
let result = f(self);
self.batch_end();
result
}
fn mark_dirty(&mut self, sub: SubscriberId) {
match sub {
SubscriberId::Derived(id) => {
if let Some(node) = self.derived.get(id) {
if !node.dirty.get() {
node.dirty.set(true);
let subscribers: SmallVec<[SubscriberId; 4]> = node.subscribers.clone();
for sub in subscribers {
self.mark_dirty(sub);
}
}
}
}
SubscriberId::Effect(id) => {
if let Some(node) = self.effects.get(id) {
if !node.dirty.get() {
node.dirty.set(true);
self.pending_effects.borrow_mut().push_back(id);
}
}
}
}
}
fn flush_effects(&mut self) {
let mut effects: Vec<EffectId> = self.pending_effects.borrow_mut().drain(..).collect();
effects.sort_by_key(|id| self.effects.get(*id).map(|n| n.depth).unwrap_or(0));
for effect_id in effects {
self.run_effect(effect_id);
}
}
fn run_effect(&mut self, effect_id: EffectId) {
let should_run = self
.effects
.get(effect_id)
.map(|n| n.dirty.get())
.unwrap_or(false);
if !should_run {
return;
}
self.tracking.replace(Some(Vec::new()));
let run_ptr: *mut Box<dyn FnMut(&ReactiveGraph) + Send> = {
if let Some(node) = self.effects.get_mut(effect_id) {
node.dirty.set(false);
&mut node.run as *mut _
} else {
return;
}
};
unsafe {
(*run_ptr)(self);
}
let deps = self.tracking.take().unwrap_or_default();
if let Some(node) = self.effects.get_mut(effect_id) {
for &dep_id in &node.dependencies {
if let Some(sig) = self.signals.get_mut(dep_id) {
sig.subscribers
.retain(|s| *s != SubscriberId::Effect(effect_id));
}
}
for &dep_id in &deps {
if let Some(sig) = self.signals.get_mut(dep_id) {
let sub = SubscriberId::Effect(effect_id);
if !sig.subscribers.contains(&sub) {
sig.subscribers.push(sub);
}
}
}
node.dependencies = deps.into_iter().collect();
}
}
pub fn stats(&self) -> ReactiveStats {
ReactiveStats {
signal_count: self.signals.len(),
derived_count: self.derived.len(),
effect_count: self.effects.len(),
pending_effects: self.pending_effects.borrow().len(),
global_version: self.global_version.get(),
}
}
}
impl Default for ReactiveGraph {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of graph sizes returned by [`ReactiveGraph::stats`].
#[derive(Clone, Debug)]
pub struct ReactiveStats {
    /// Live signals.
    pub signal_count: usize,
    /// Live derived values.
    pub derived_count: usize,
    /// Live effects.
    pub effect_count: usize,
    /// Effect runs queued for the next flush.
    pub pending_effects: usize,
    /// Total signal writes performed on the graph.
    pub global_version: u64,
}
/// A [`ReactiveGraph`] shared behind an `Arc<Mutex<_>>`, as used by [`State`].
pub type SharedReactiveGraph = Arc<Mutex<ReactiveGraph>>;
/// Flag that [`State::set_rebuild`] and [`State::update_rebuild`] set to `true`
/// after writing.
pub type DirtyFlag = Arc<AtomicBool>;
/// Callback invoked by [`State::set`] and [`State::update`] with the ids of the
/// signals just written.
pub type StatefulDepsCallback = Arc<dyn Fn(&[SignalId]) + Send + Sync>;
/// Thread-shareable handle to one signal plus the hooks fired when it is
/// written.
#[derive(Clone)]
pub struct State<T> {
    /// Graph that owns `signal`.
    reactive: SharedReactiveGraph,
    /// The wrapped signal.
    signal: Signal<T>,
    /// Called by `set`/`update` with the written signal's id, if present.
    stateful_deps_callback: Option<StatefulDepsCallback>,
    /// Raised by `set_rebuild`/`update_rebuild`.
    dirty_flag: DirtyFlag,
}
impl<T: Clone + Send + 'static> State<T> {
pub fn new(signal: Signal<T>, reactive: SharedReactiveGraph, dirty_flag: DirtyFlag) -> Self {
Self {
signal,
reactive,
dirty_flag,
stateful_deps_callback: None,
}
}
pub fn with_stateful_callback(
signal: Signal<T>,
reactive: SharedReactiveGraph,
dirty_flag: DirtyFlag,
callback: StatefulDepsCallback,
) -> Self {
Self {
signal,
reactive,
dirty_flag,
stateful_deps_callback: Some(callback),
}
}
pub fn get(&self) -> T
where
T: Default,
{
self.reactive
.lock()
.unwrap()
.get(self.signal)
.unwrap_or_default()
}
pub fn try_get(&self) -> Option<T> {
self.reactive.lock().unwrap().get(self.signal)
}
pub fn set(&self, value: T) {
self.reactive.lock().unwrap().set(self.signal, value);
if let Some(ref callback) = self.stateful_deps_callback {
callback(&[self.signal.id()]);
}
}
pub fn set_rebuild(&self, value: T) {
self.reactive.lock().unwrap().set(self.signal, value);
self.dirty_flag.store(true, Ordering::SeqCst);
}
pub fn update(&self, f: impl FnOnce(T) -> T) {
self.reactive.lock().unwrap().update(self.signal, f);
if let Some(ref callback) = self.stateful_deps_callback {
callback(&[self.signal.id()]);
}
}
pub fn update_rebuild(&self, f: impl FnOnce(T) -> T) {
self.reactive.lock().unwrap().update(self.signal, f);
self.dirty_flag.store(true, Ordering::SeqCst);
}
pub fn signal(&self) -> Signal<T> {
self.signal
}
pub fn signal_id(&self) -> SignalId {
self.signal.id()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    #[test]
    fn test_signal_create_get_set() {
        let mut graph = ReactiveGraph::new();
        let count = graph.create_signal(0i32);
        assert_eq!(graph.get(count), Some(0));
        graph.set(count, 42);
        assert_eq!(graph.get(count), Some(42));
    }

    #[test]
    fn test_signal_update() {
        let mut graph = ReactiveGraph::new();
        let count = graph.create_signal(10i32);
        graph.update(count, |x| x + 5);
        assert_eq!(graph.get(count), Some(15));
    }

    #[test]
    fn test_signal_version_increments() {
        let mut graph = ReactiveGraph::new();
        let s = graph.create_signal(0i32);
        assert_eq!(graph.signal_version(s.id()), Some(0));
        graph.set(s, 1);
        graph.update(s, |x| x + 1);
        assert_eq!(graph.signal_version(s.id()), Some(2));
        assert_eq!(graph.get(s), Some(2));
    }

    #[test]
    fn test_get_with_wrong_type_returns_none() {
        let mut graph = ReactiveGraph::new();
        let count = graph.create_signal(1i32);
        let wrong: Signal<u64> = Signal::from_id(count.id());
        assert_eq!(graph.get(wrong), None);
        assert_eq!(graph.get_untracked(wrong), None);
    }

    #[test]
    fn test_signal_id_raw_round_trip() {
        let mut graph = ReactiveGraph::new();
        let s = graph.create_signal(String::from("hi"));
        let raw = s.id().to_raw();
        let back: Signal<String> = Signal::from_id(SignalId::from_raw(raw));
        assert_eq!(graph.get(back), Some(String::from("hi")));
    }

    #[test]
    fn test_derived_basic() {
        let mut graph = ReactiveGraph::new();
        let count = graph.create_signal(5i32);
        let doubled = graph.create_derived(move |g| g.get(count).unwrap_or(0) * 2);
        assert_eq!(graph.get_derived(doubled), Some(10));
        graph.set(count, 7);
        assert_eq!(graph.get_derived(doubled), Some(14));
    }

    #[test]
    fn test_derived_caching() {
        let mut graph = ReactiveGraph::new();
        let compute_count = Arc::new(Mutex::new(0));
        let count = graph.create_signal(5i32);
        let compute_count_clone = compute_count.clone();
        let doubled = graph.create_derived(move |g| {
            *compute_count_clone.lock().unwrap() += 1;
            g.get(count).unwrap_or(0) * 2
        });
        assert_eq!(graph.get_derived(doubled), Some(10));
        assert_eq!(*compute_count.lock().unwrap(), 1);
        assert_eq!(graph.get_derived(doubled), Some(10));
        assert_eq!(*compute_count.lock().unwrap(), 1);
        graph.set(count, 7);
        assert_eq!(graph.get_derived(doubled), Some(14));
        assert_eq!(*compute_count.lock().unwrap(), 2);
    }

    /// A derived value must follow the signals it read most recently and stop
    /// reacting to signals it no longer reads.
    #[test]
    fn test_derived_drops_stale_dependencies() {
        let mut graph = ReactiveGraph::new();
        let computes = Arc::new(Mutex::new(0));
        let use_a = graph.create_signal(true);
        let a = graph.create_signal(1i32);
        let b = graph.create_signal(100i32);
        let computes_clone = computes.clone();
        let pick = graph.create_derived(move |g| {
            *computes_clone.lock().unwrap() += 1;
            if g.get(use_a).unwrap_or(false) {
                g.get(a).unwrap_or(0)
            } else {
                g.get(b).unwrap_or(0)
            }
        });
        assert_eq!(graph.get_derived(pick), Some(1));
        graph.set(use_a, false);
        assert_eq!(graph.get_derived(pick), Some(100));
        assert_eq!(*computes.lock().unwrap(), 2);
        // `a` is no longer read, so writing it must not invalidate the cache.
        graph.set(a, 5);
        assert_eq!(graph.get_derived(pick), Some(100));
        assert_eq!(*computes.lock().unwrap(), 2);
        graph.set(b, 200);
        assert_eq!(graph.get_derived(pick), Some(200));
        assert_eq!(*computes.lock().unwrap(), 3);
    }

    #[test]
    fn test_effect_runs_on_change() {
        let mut graph = ReactiveGraph::new();
        let effect_runs = Arc::new(Mutex::new(Vec::new()));
        let count = graph.create_signal(0i32);
        let effect_runs_clone = effect_runs.clone();
        let _effect = graph.create_effect(move |g| {
            let val = g.get(count).unwrap_or(0);
            effect_runs_clone.lock().unwrap().push(val);
        });
        assert_eq!(*effect_runs.lock().unwrap(), vec![0]);
        graph.set(count, 1);
        assert_eq!(*effect_runs.lock().unwrap(), vec![0, 1]);
        graph.set(count, 2);
        assert_eq!(*effect_runs.lock().unwrap(), vec![0, 1, 2]);
    }

    /// Effects re-track their dependencies on every run.
    #[test]
    fn test_effect_dynamic_dependencies() {
        let mut graph = ReactiveGraph::new();
        let flag = graph.create_signal(true);
        let a = graph.create_signal(1i32);
        let b = graph.create_signal(2i32);
        let seen = Arc::new(Mutex::new(Vec::new()));
        let seen_clone = seen.clone();
        let _effect = graph.create_effect(move |g| {
            let v = if g.get(flag).unwrap_or(false) {
                g.get(a)
            } else {
                g.get(b)
            };
            seen_clone.lock().unwrap().push(v.unwrap_or(0));
        });
        assert_eq!(*seen.lock().unwrap(), vec![1]);
        graph.set(b, 20); // not read yet
        assert_eq!(*seen.lock().unwrap(), vec![1]);
        graph.set(flag, false);
        assert_eq!(*seen.lock().unwrap(), vec![1, 20]);
        graph.set(a, 10); // no longer read
        assert_eq!(*seen.lock().unwrap(), vec![1, 20]);
        graph.set(b, 30);
        assert_eq!(*seen.lock().unwrap(), vec![1, 20, 30]);
    }

    #[test]
    fn test_batching() {
        let mut graph = ReactiveGraph::new();
        let effect_runs = Arc::new(Mutex::new(0));
        let a = graph.create_signal(1i32);
        let b = graph.create_signal(2i32);
        let effect_runs_clone = effect_runs.clone();
        let _effect = graph.create_effect(move |g| {
            let _a = g.get(a);
            let _b = g.get(b);
            *effect_runs_clone.lock().unwrap() += 1;
        });
        assert_eq!(*effect_runs.lock().unwrap(), 1);
        *effect_runs.lock().unwrap() = 0;
        graph.set(a, 10);
        graph.set(b, 20);
        assert_eq!(*effect_runs.lock().unwrap(), 2);
        *effect_runs.lock().unwrap() = 0;
        graph.batch(|g| {
            g.set(a, 100);
            g.set(b, 200);
        });
        assert_eq!(*effect_runs.lock().unwrap(), 1);
    }

    /// Only closing the outermost batch flushes effects.
    #[test]
    fn test_nested_batches_flush_once() {
        let mut graph = ReactiveGraph::new();
        let runs = Arc::new(Mutex::new(0));
        let s = graph.create_signal(0i32);
        let runs_clone = runs.clone();
        let _effect = graph.create_effect(move |g| {
            let _ = g.get(s);
            *runs_clone.lock().unwrap() += 1;
        });
        assert_eq!(*runs.lock().unwrap(), 1);
        graph.batch(|g| {
            g.set(s, 1);
            g.batch(|g| g.set(s, 2));
            assert_eq!(*runs.lock().unwrap(), 1);
            g.set(s, 3);
        });
        assert_eq!(*runs.lock().unwrap(), 2);
        assert_eq!(graph.get(s), Some(3));
    }

    #[test]
    fn test_dispose_effect() {
        let mut graph = ReactiveGraph::new();
        let effect_runs = Arc::new(Mutex::new(0));
        let count = graph.create_signal(0i32);
        let effect_runs_clone = effect_runs.clone();
        let effect = graph.create_effect(move |g| {
            let _val = g.get(count);
            *effect_runs_clone.lock().unwrap() += 1;
        });
        assert_eq!(*effect_runs.lock().unwrap(), 1);
        graph.set(count, 1);
        assert_eq!(*effect_runs.lock().unwrap(), 2);
        graph.dispose_effect(effect);
        graph.set(count, 2);
        assert_eq!(*effect_runs.lock().unwrap(), 2);
    }

    /// An effect disposed while queued inside a batch must not run at flush.
    #[test]
    fn test_dispose_effect_inside_batch() {
        let mut graph = ReactiveGraph::new();
        let runs = Arc::new(Mutex::new(0));
        let s = graph.create_signal(0i32);
        let runs_clone = runs.clone();
        let effect = graph.create_effect(move |g| {
            let _ = g.get(s);
            *runs_clone.lock().unwrap() += 1;
        });
        assert_eq!(*runs.lock().unwrap(), 1);
        graph.batch(|g| {
            g.set(s, 1);
            g.dispose_effect(effect);
        });
        assert_eq!(*runs.lock().unwrap(), 1);
        assert_eq!(graph.stats().effect_count, 0);
    }

    #[test]
    fn test_multiple_signals() {
        let mut graph = ReactiveGraph::new();
        let a = graph.create_signal(1i32);
        let b = graph.create_signal(2i32);
        let c = graph.create_signal(3i32);
        let sum = graph.create_derived(move |g| {
            g.get(a).unwrap_or(0) + g.get(b).unwrap_or(0) + g.get(c).unwrap_or(0)
        });
        assert_eq!(graph.get_derived(sum), Some(6));
        graph.set(b, 10);
        assert_eq!(graph.get_derived(sum), Some(14));
    }

    #[test]
    fn test_stats() {
        let mut graph = ReactiveGraph::new();
        let _s1 = graph.create_signal(1);
        let _s2 = graph.create_signal(2);
        let _d1 = graph.create_derived(|_| 0);
        let stats = graph.stats();
        assert_eq!(stats.signal_count, 2);
        assert_eq!(stats.derived_count, 1);
    }
}