use std::{
any::Any,
cell::RefCell,
collections::{HashMap, HashSet, VecDeque},
marker::PhantomData,
sync::{Arc, RwLock},
};
/// Type-erased value stored inside a node.
type AnyValue = Box<dyn Any + Send + Sync>;
/// Shared, type-erased closure that produces a computed node's value.
type ArcThunk = Arc<dyn Fn() -> AnyValue + Send + Sync>;

thread_local! {
    /// Ids of the computed nodes currently being evaluated on this thread,
    /// innermost evaluation last.
    static COMPUTATION_STACK: RefCell<Vec<NodeId>> = const { RefCell::new(Vec::new()) };
}

/// Records `id` as the innermost evaluation on this thread.
fn stack_push(id: NodeId) {
    COMPUTATION_STACK.with(|frames| {
        frames.borrow_mut().push(id);
    });
}

/// Drops the innermost evaluation; does nothing if the stack is empty.
fn stack_pop() {
    COMPUTATION_STACK.with(|frames| {
        let _ = frames.borrow_mut().pop();
    });
}

/// Returns the innermost evaluation on this thread, if there is one.
fn stack_top() -> Option<NodeId> {
    COMPUTATION_STACK.with(|frames| {
        let innermost = frames.borrow().last().copied();
        innermost
    })
}

/// Returns `true` if `id` is anywhere on this thread's evaluation stack.
fn stack_contains(id: NodeId) -> bool {
    COMPUTATION_STACK.with(|frames| frames.borrow().iter().any(|&frame| frame == id))
}
/// Errors reported by the reactive runtime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReactiveError {
    /// A computed node was read while it was already being evaluated on the
    /// current thread, i.e. it reads its own value directly or indirectly.
    Cycle,
    /// Recording a new dependency edge would have closed a cycle in the
    /// dependency graph; the edge was not inserted.
    DependencyCycle,
    /// The node is missing, is of the wrong kind, or holds a value whose type
    /// differs from the requested one.
    TypeMismatch,
}

impl std::fmt::Display for ReactiveError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            ReactiveError::Cycle => "reactive cycle: a computed node reads its own value",
            ReactiveError::DependencyCycle => {
                "reactive dependency cycle detected on edge insertion"
            }
            ReactiveError::TypeMismatch => {
                "reactive type mismatch: stored type does not match requested type"
            }
        };
        f.write_str(message)
    }
}

impl std::error::Error for ReactiveError {}
/// Opaque identifier of a node; ids are handed out sequentially by the
/// owning runtime.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(u64);
/// Payload stored for each node in the runtime's graph.
enum NodeKind {
    /// A source value, replaced by `Signal::set`.
    Signal { value: AnyValue },
    /// A derived value produced on demand by `thunk`.
    Computed {
        /// Type-erased closure that produces a fresh value.
        thunk: ArcThunk,
        /// Last value produced by `thunk`; `None` until it has run once.
        cached: Option<AnyValue>,
        /// When `true`, `cached` must be recomputed before it is read.
        dirty: bool,
    },
}
/// Lock-protected state behind a `ReactiveRuntime`: every node plus the
/// dependency edges between them.
struct RuntimeInner {
    /// All nodes created by the runtime, keyed by id.
    nodes: HashMap<NodeId, NodeKind>,
    /// Forward dependency edges: `deps[source]` lists the nodes that read
    /// `source` and must be marked dirty when it changes.
    deps: HashMap<NodeId, Vec<NodeId>>,
    /// Id that the next call to `alloc_id` hands out.
    next_id: u64,
}

impl RuntimeInner {
    /// Creates an empty graph whose first allocated id is `NodeId(0)`.
    fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            deps: HashMap::new(),
            next_id: 0,
        }
    }

    /// Returns a fresh id and advances the counter.
    fn alloc_id(&mut self) -> NodeId {
        let id = NodeId(self.next_id);
        self.next_id += 1;
        id
    }

    /// Iterates over the direct dependents of `id` (empty if nothing reads it).
    fn dependents_of(&self, id: NodeId) -> impl Iterator<Item = NodeId> + '_ {
        self.deps.get(&id).into_iter().flatten().copied()
    }

    /// Returns `true` if `target` can be reached from `start` by following
    /// dependency edges. A node trivially reaches itself.
    fn reachable(&self, start: NodeId, target: NodeId) -> bool {
        let mut visited = HashSet::new();
        let mut queue = VecDeque::new();
        queue.push_back(start);
        while let Some(current) = queue.pop_front() {
            if current == target {
                return true;
            }
            if visited.insert(current) {
                queue.extend(self.dependents_of(current));
            }
        }
        false
    }

    /// Records that `caller` reads `source`, so invalidating `source` also
    /// invalidates `caller`. Re-adding an existing edge is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`ReactiveError::DependencyCycle`] if `source` is already
    /// reachable from `caller` (including `source == caller`); the graph is
    /// left unchanged in that case.
    fn try_add_dependency(&mut self, source: NodeId, caller: NodeId) -> Result<(), ReactiveError> {
        // The new edge is source -> caller; it closes a cycle exactly when a
        // path caller -> ... -> source already exists.
        if self.reachable(caller, source) {
            return Err(ReactiveError::DependencyCycle);
        }
        let dependents = self.deps.entry(source).or_default();
        if !dependents.contains(&caller) {
            dependents.push(caller);
        }
        Ok(())
    }

    /// Marks every computed node that transitively depends on `id` as dirty.
    fn mark_dirty_transitive(&mut self, id: NodeId) {
        let mut visited = HashSet::new();
        let mut queue: VecDeque<NodeId> = self.dependents_of(id).collect();
        while let Some(current) = queue.pop_front() {
            if !visited.insert(current) {
                continue;
            }
            if let Some(NodeKind::Computed { dirty, .. }) = self.nodes.get_mut(&current) {
                *dirty = true;
            }
            queue.extend(self.dependents_of(current));
        }
    }
}
#[derive(Clone)]
pub struct ReactiveRuntime {
inner: Arc<RwLock<RuntimeInner>>,
}
impl Default for ReactiveRuntime {
fn default() -> Self {
Self::new()
}
}
impl ReactiveRuntime {
pub fn new() -> Self {
Self {
inner: Arc::new(RwLock::new(RuntimeInner::new())),
}
}
pub fn signal<T: Send + Sync + Clone + 'static>(&self, initial: T) -> Signal<T> {
let mut inner = self
.inner
.write()
.expect("ReactiveRuntime::signal: RwLock poisoned");
let id = inner.alloc_id();
inner.nodes.insert(
id,
NodeKind::Signal {
value: Box::new(initial),
},
);
drop(inner);
Signal {
runtime: Arc::clone(&self.inner),
id,
_phantom: PhantomData,
}
}
pub fn computed<T: Send + Sync + Clone + 'static>(
&self,
f: impl Fn() -> T + Send + Sync + 'static,
) -> Result<Computed<T>, ReactiveError> {
let thunk: ArcThunk = Arc::new(move || Box::new(f()) as AnyValue);
let mut inner = self
.inner
.write()
.expect("ReactiveRuntime::computed: RwLock poisoned");
let id = inner.alloc_id();
inner.nodes.insert(
id,
NodeKind::Computed {
thunk,
cached: None,
dirty: true,
},
);
drop(inner);
Ok(Computed {
runtime: Arc::clone(&self.inner),
id,
_phantom: PhantomData,
})
}
}
/// Handle to a mutable source node holding a `T`.
///
/// Clones are cheap and refer to the same node.
pub struct Signal<T: Send + Sync + Clone + 'static> {
    runtime: Arc<RwLock<RuntimeInner>>,
    id: NodeId,
    _phantom: PhantomData<T>,
}

impl<T: Send + Sync + Clone + 'static> Clone for Signal<T> {
    fn clone(&self) -> Self {
        let runtime = Arc::clone(&self.runtime);
        Signal {
            runtime,
            id: self.id,
            _phantom: PhantomData,
        }
    }
}

impl<T: Send + Sync + Clone + 'static> Signal<T> {
    /// Returns a clone of the current value.
    ///
    /// If a computed node is being evaluated on this thread, it is recorded
    /// as a dependent of this signal.
    ///
    /// # Panics
    ///
    /// Panics if the runtime lock is poisoned, if the node is missing or not
    /// a signal, or if the stored value is not a `T`.
    pub fn get(&self) -> T {
        // Hold the read lock only while cloning the value out: the write lock
        // below must not be requested while the read guard is still alive.
        let current: T = {
            let guard = self.runtime.read().expect("Signal::get: RwLock poisoned");
            if let Some(NodeKind::Signal { value }) = guard.nodes.get(&self.id) {
                let typed: &T = value
                    .downcast_ref()
                    .expect("Signal<T> type invariant: stored type must match T");
                typed.clone()
            } else {
                panic!("Signal<T> node not found or wrong kind")
            }
        };
        if let Some(reader) = stack_top() {
            let mut guard = self
                .runtime
                .write()
                .expect("Signal::get dep-reg: RwLock poisoned");
            // Only computed ids are ever pushed on the computation stack, so
            // a signal never appears as anyone's dependent and this insertion
            // cannot find a cycle; the result is deliberately ignored.
            let _ = guard.try_add_dependency(self.id, reader);
        }
        current
    }

    /// Replaces the stored value and marks every transitive dependent dirty.
    ///
    /// # Panics
    ///
    /// Panics if the runtime lock is poisoned or the node is missing or not a
    /// signal.
    pub fn set(&self, value: T) {
        let mut guard = self.runtime.write().expect("Signal::set: RwLock poisoned");
        let slot = match guard.nodes.get_mut(&self.id) {
            Some(NodeKind::Signal { value: slot }) => slot,
            _ => panic!("Signal<T> node not found or wrong kind"),
        };
        *slot = Box::new(value);
        guard.mark_dirty_transitive(self.id);
    }
}
/// Handle to a lazily evaluated, memoised node derived from other nodes.
///
/// Clones are cheap and refer to the same node.
pub struct Computed<T: Send + Sync + Clone + 'static> {
    runtime: Arc<RwLock<RuntimeInner>>,
    id: NodeId,
    _phantom: PhantomData<T>,
}

impl<T: Send + Sync + Clone + 'static> Clone for Computed<T> {
    fn clone(&self) -> Self {
        Self {
            runtime: Arc::clone(&self.runtime),
            id: self.id,
            _phantom: PhantomData,
        }
    }
}

impl<T: Send + Sync + Clone + 'static> Computed<T> {
    /// Returns the node's value, re-running its closure first if it is dirty.
    ///
    /// If another computed node is being evaluated on this thread, it is
    /// recorded as a dependent of this node.
    ///
    /// # Errors
    ///
    /// * [`ReactiveError::Cycle`] if this node is already being evaluated on
    ///   the current thread.
    /// * [`ReactiveError::DependencyCycle`] if recording the calling node as a
    ///   dependent would close a cycle in the dependency graph.
    /// * [`ReactiveError::TypeMismatch`] if the node is missing, is not a
    ///   computed node, or holds a value that is not a `T`.
    ///
    /// # Panics
    ///
    /// Panics if the runtime lock is poisoned, and propagates any panic raised
    /// by the node's closure.
    pub fn get(&self) -> Result<T, ReactiveError> {
        if stack_contains(self.id) {
            return Err(ReactiveError::Cycle);
        }
        if let Some(caller) = stack_top() {
            self.runtime
                .write()
                .expect("Computed::get dep-reg: RwLock poisoned")
                .try_add_dependency(self.id, caller)?;
        }

        // One read lock decides between the cached fast path and a recompute,
        // so the dirty flag and the cache are observed together.
        let thunk = {
            let inner = self
                .runtime
                .read()
                .expect("Computed::get dirty-check: RwLock poisoned");
            match inner.nodes.get(&self.id) {
                Some(NodeKind::Computed {
                    dirty: false,
                    cached: Some(v),
                    ..
                }) => {
                    return v
                        .downcast_ref::<T>()
                        .cloned()
                        .ok_or(ReactiveError::TypeMismatch);
                }
                Some(NodeKind::Computed {
                    dirty: true, thunk, ..
                }) => Arc::clone(thunk),
                _ => return Err(ReactiveError::TypeMismatch),
            }
        };

        // Run the closure without holding any lock so nested reads can lock
        // the runtime themselves. The guard pops this node off the computation
        // stack even if the closure panics; otherwise a caught panic would
        // leave this id on the thread's stack, so later reads would be
        // registered as its dependencies and `get` on it would report `Cycle`.
        let new_value = {
            struct StackFrame;
            impl Drop for StackFrame {
                fn drop(&mut self) {
                    stack_pop();
                }
            }
            stack_push(self.id);
            let _frame = StackFrame;
            thunk()
        };

        // Take the caller's copy before the value moves into the cache, which
        // saves re-locking just to read it back.
        let result = new_value
            .downcast_ref::<T>()
            .cloned()
            .ok_or(ReactiveError::TypeMismatch);
        let mut inner = self
            .runtime
            .write()
            .expect("Computed::get store: RwLock poisoned");
        match inner.nodes.get_mut(&self.id) {
            Some(NodeKind::Computed { cached, dirty, .. }) => {
                // NOTE(review): if another thread calls `Signal::set` on a
                // dependency while the closure runs, that invalidation is
                // overwritten here and the stale value stays cached until the
                // next change. A per-node version counter would close this gap.
                *cached = Some(new_value);
                *dirty = false;
                result
            }
            _ => Err(ReactiveError::TypeMismatch),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_signal_get_set() {
        let rt = ReactiveRuntime::new();
        let s = rt.signal(42i32);
        assert_eq!(s.get(), 42);
        s.set(99);
        assert_eq!(s.get(), 99);
    }

    #[test]
    fn test_computed_derives_from_signal() {
        let rt = ReactiveRuntime::new();
        let s = rt.signal(10i32);
        let sc = s.clone();
        let c = rt.computed(move || sc.get() * 2).expect("no cycle");
        assert_eq!(c.get(), Ok(20));
        s.set(5);
        assert_eq!(c.get(), Ok(10));
    }

    #[test]
    fn test_chain_propagation() {
        let rt = ReactiveRuntime::new();
        let a = rt.signal(1i32);
        let ac = a.clone();
        let b = rt.computed(move || ac.get() * 2).expect("b ok");
        let bc = b.clone();
        let c = rt.computed(move || bc.get().expect("b") + 1).expect("c ok");
        assert_eq!(c.get(), Ok(3));
        a.set(10);
        assert_eq!(c.get(), Ok(21));
    }

    #[test]
    fn test_cycle_detection_no_false_positive_diamond() {
        let rt = ReactiveRuntime::new();
        let a = rt.signal(2i32);
        let ac1 = a.clone();
        let ac2 = a.clone();
        let b = rt.computed(move || ac1.get() * 3).expect("b ok");
        let c = rt.computed(move || ac2.get() + 10).expect("c ok");
        let bc = b.clone();
        let cc = c.clone();
        let d = rt
            .computed(move || bc.get().expect("b") + cc.get().expect("c"))
            .expect("diamond: no cycle");
        assert_eq!(d.get(), Ok(18));
        a.set(5);
        assert_eq!(d.get(), Ok(30));
    }

    #[test]
    fn test_cycle_detection_dep_graph_dfs() {
        let rt = ReactiveRuntime::new();
        let a = rt.signal(1i32);
        let ac = a.clone();
        let b = rt.computed(move || ac.get() + 1).expect("b ok");
        let _ = b.get();
        let result = {
            let mut inner = rt.inner.write().unwrap();
            inner.try_add_dependency(b.id, a.id)
        };
        assert_eq!(result, Err(ReactiveError::DependencyCycle));
    }

    #[test]
    fn test_diamond_recomputes_correctly() {
        let rt = ReactiveRuntime::new();
        let a = rt.signal(1i32);
        let ac1 = a.clone();
        let ac2 = a.clone();
        let b = rt.computed(move || ac1.get() * 2).expect("b");
        let c = rt.computed(move || ac2.get() + 5).expect("c");
        let bc = b.clone();
        let cc = c.clone();
        let d = rt
            .computed(move || bc.get().expect("b") + cc.get().expect("c"))
            .expect("d");
        assert_eq!(d.get(), Ok(8));
        a.set(3);
        assert_eq!(d.get(), Ok(14));
        a.set(0);
        assert_eq!(d.get(), Ok(5));
    }

    #[test]
    fn test_computed_caches_until_invalidated() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;
        let rt = ReactiveRuntime::new();
        let s = rt.signal(1i32);
        let sc = s.clone();
        let runs = Arc::new(AtomicUsize::new(0));
        let runs_in = Arc::clone(&runs);
        let c = rt
            .computed(move || {
                runs_in.fetch_add(1, Ordering::SeqCst);
                sc.get() + 1
            })
            .expect("c ok");
        // Creating the node must not evaluate it.
        assert_eq!(runs.load(Ordering::SeqCst), 0);
        assert_eq!(c.get(), Ok(2));
        assert_eq!(c.get(), Ok(2));
        assert_eq!(runs.load(Ordering::SeqCst), 1);
        s.set(4);
        assert_eq!(c.get(), Ok(5));
        assert_eq!(runs.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_send_sync_bounds() {
        fn _assert_send<T: Send>() {}
        fn _assert_sync<T: Sync>() {}
        fn _check(rt: ReactiveRuntime) {
            let _: &dyn Send = &rt;
            let _: &dyn Sync = &rt;
        }
        _assert_send::<ReactiveRuntime>();
        _assert_sync::<ReactiveRuntime>();
        _assert_send::<Signal<i32>>();
        _assert_sync::<Signal<i32>>();
        _assert_send::<Computed<i32>>();
        _assert_sync::<Computed<i32>>();
    }

    #[test]
    fn test_no_deadlock_nested_computed() {
        use std::sync::mpsc;
        use std::thread;
        use std::time::Duration;
        let rt = ReactiveRuntime::new();
        let x = rt.signal(7i32);
        let xc = x.clone();
        let comp_a = rt.computed(move || xc.get() * 3).expect("a ok");
        let ac = comp_a.clone();
        let comp_b = rt.computed(move || ac.get().expect("a") + 1).expect("b ok");
        // Evaluate on a worker thread so a deadlock surfaces as a timeout
        // failure instead of hanging the whole test run.
        let (tx, rx) = mpsc::channel();
        thread::spawn(move || {
            let _ = tx.send(comp_b.get());
        });
        let result = rx
            .recv_timeout(Duration::from_secs(5))
            .expect("get() should not deadlock");
        assert_eq!(result, Ok(22));
    }
}