use std::cell::{Cell, RefCell};
use std::collections::HashSet;
use std::fmt;
use std::rc::Rc;
use crate::observer::{ObserverState, OBSERVER};
use crate::signal::Signal;
/// One-shot closure handed over when a subscription is registered; invoking it undoes that
/// subscription.
type CleanupFn = Box<dyn FnOnce() + 'static>;
/// Shared, mutable list of `(source key, cleanup)` pairs describing every subscription a
/// memo currently holds.
type SubscriptionList = Rc<RefCell<Vec<(SignalKey, CleanupFn)>>>;
/// Hashable identity for a source signal, used to diff a memo's dependency set between runs.
///
/// Built from `Signal::state_addr` (presumably the address of the signal's shared state, so
/// clones of one signal map to the same key — confirm against `Signal`).
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub(crate) struct SignalKey {
    addr: usize,
}

impl SignalKey {
    /// Derives the key identifying `sig`.
    pub(crate) fn new<T>(sig: &Signal<T>) -> Self {
        let addr = sig.state_addr();
        Self { addr }
    }
}
/// A lazily recomputed value derived from the signals its compute closure reads.
///
/// Cloning a `Memo` produces another handle onto the same shared state (the `Rc` fields are
/// shared between clones).
pub struct Memo<T> {
    /// Holds the cached value; `read`, `with` and `changed` all go through this signal.
    signal: Signal<T>,
    /// Set when a dependency reports a change; cleared after a successful recompute.
    dirty: Rc<Cell<bool>>,
    /// User-supplied closure that produces the value.
    compute: Rc<dyn Fn() -> T>,
    /// `(source key, cleanup)` pair for every signal this memo is subscribed to.
    subscriptions: SubscriptionList,
    /// True while a recompute is running; suppresses re-entrant recomputes and ignores
    /// dirty notifications during that window.
    computing: Rc<Cell<bool>>,
    /// Number of completed computations (the initial one plus each successful recompute).
    compute_count: Rc<Cell<u64>>,
}
impl<T: Clone + 'static> Memo<T> {
    /// Creates a memo and eagerly runs `compute` once to produce its initial value.
    ///
    /// Signals read inside `compute` become dependencies. When one of them changes, the memo
    /// is marked dirty and recomputed lazily on the next `read`/`with`.
    ///
    /// # Panics
    ///
    /// Propagates a panic raised by `compute`. Any subscriptions registered before the panic
    /// are cleaned up first, so source signals are not left holding them.
    #[must_use]
    pub fn new(compute: impl Fn() -> T + 'static) -> Self {
        let compute: Rc<dyn Fn() -> T> = Rc::new(compute);
        let dirty = Rc::new(Cell::new(true));
        let subscriptions: SubscriptionList = Rc::new(RefCell::new(Vec::new()));
        let computing = Rc::new(Cell::new(false));
        // The memo's own signal cannot exist before its first value does, so the dirty
        // callback reaches it through this slot, which is filled once the value is known.
        let holder: Rc<RefCell<Option<Signal<T>>>> = Rc::new(RefCell::new(None));
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            run_compute(
                &compute,
                &dirty,
                &subscriptions,
                &holder,
                &computing,
                &HashSet::new(),
            )
        }));
        let (value, _, _) = match result {
            Ok(parts) => parts,
            Err(payload) => {
                // Same policy as `recompute`: never abandon subscriptions without running
                // their cleanups.
                let orphaned = std::mem::take(&mut *subscriptions.borrow_mut());
                for (_, cleanup) in orphaned {
                    cleanup();
                }
                std::panic::resume_unwind(payload);
            }
        };
        let signal = Signal::new(value);
        *holder.borrow_mut() = Some(signal.clone());
        let memo = Self {
            signal,
            dirty,
            compute,
            subscriptions,
            computing,
            compute_count: Rc::new(Cell::new(1)),
        };
        // The value was just computed, so the memo starts clean.
        memo.dirty.set(false);
        memo
    }

    /// Returns the current value, recomputing first if a dependency changed since the last
    /// computation.
    ///
    /// # Panics
    ///
    /// Propagates a panic raised by the compute closure during that recompute.
    #[must_use]
    pub fn read(&self) -> T {
        self.refresh_if_dirty();
        self.signal.read()
    }

    /// Calls `f` with a reference to the current value, recomputing first if needed.
    ///
    /// # Panics
    ///
    /// Propagates a panic raised by the compute closure during that recompute.
    #[must_use]
    pub fn with<U>(&self, f: impl FnOnce(&T) -> U) -> U {
        self.refresh_if_dirty();
        self.signal.with(f)
    }

    /// Waits for the memo's underlying signal to change, then returns the current value
    /// (recomputing it if it is dirty).
    ///
    /// The underlying signal is version-bumped when a dependency invalidates the memo, and
    /// set when a recompute stores a new value.
    pub async fn changed(&self) -> T {
        self.signal.changed().await;
        self.read()
    }

    /// Returns whether a dependency has changed since the value was last computed.
    #[must_use]
    pub fn is_dirty(&self) -> bool {
        self.dirty.get()
    }

    /// Returns how many times the value has been computed (including the initial compute).
    #[must_use]
    pub fn compute_count(&self) -> u64 {
        self.compute_count.get()
    }

    /// Recomputes the value if a dependency has invalidated it.
    fn refresh_if_dirty(&self) {
        if self.dirty.get() {
            self.recompute();
        }
    }

    /// Re-runs the compute closure, reconciles subscriptions against what it read, and
    /// stores the new value.
    ///
    /// A call made while a recompute is already in progress (for example the compute
    /// closure reading this same memo) returns immediately and leaves the cached value in
    /// place.
    fn recompute(&self) {
        /// Clears the `computing` flag on every exit path. That includes a panic from the
        /// compute closure, from a subscription cleanup, or from downstream notification in
        /// `signal.set`. Otherwise the memo would stay stuck returning stale values.
        struct ResetOnDrop<'a>(&'a Cell<bool>);

        impl Drop for ResetOnDrop<'_> {
            fn drop(&mut self) {
                self.0.set(false);
            }
        }

        if self.computing.get() {
            return;
        }
        self.computing.set(true);
        let _reset = ResetOnDrop(&self.computing);

        let old_keys: HashSet<SignalKey> = self
            .subscriptions
            .borrow()
            .iter()
            .map(|(key, _)| *key)
            .collect();
        let new_subs: SubscriptionList = Rc::new(RefCell::new(Vec::new()));
        let holder = Rc::new(RefCell::new(Some(self.signal.clone())));
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            run_compute(
                &self.compute,
                &self.dirty,
                &new_subs,
                &holder,
                &self.computing,
                &old_keys,
            )
        }));
        let (new_value, _, re_read_keys) = match result {
            Ok(parts) => parts,
            Err(payload) => {
                // Subscriptions made by the aborted run are recorded nowhere else.
                let orphaned = std::mem::take(&mut *new_subs.borrow_mut());
                for (_, cleanup) in orphaned {
                    cleanup();
                }
                std::panic::resume_unwind(payload);
            }
        };

        let fresh = std::mem::take(&mut *new_subs.borrow_mut());
        let new_keys: HashSet<SignalKey> = fresh.iter().map(|(key, _)| *key).collect();
        // Decide what to keep while holding the borrow. Run cleanups only after the updated
        // list is stored and the borrow is released. A panicking cleanup then cannot wipe the
        // kept subscriptions, and no cleanup ever observes the list mid-update.
        let mut stale: Vec<CleanupFn> = Vec::new();
        {
            let mut current = self.subscriptions.borrow_mut();
            let previous = std::mem::take(&mut *current);
            let mut keep = Vec::with_capacity(previous.len() + fresh.len());
            for (key, cleanup) in previous {
                // Still read this run, either as a brand-new read or as a re-read of an
                // already-subscribed key.
                if new_keys.contains(&key) || re_read_keys.contains(&key) {
                    keep.push((key, cleanup));
                } else {
                    stale.push(cleanup);
                }
            }
            for (key, cleanup) in fresh {
                // An existing subscription already covers this key: drop the duplicate.
                if old_keys.contains(&key) {
                    stale.push(cleanup);
                } else {
                    keep.push((key, cleanup));
                }
            }
            *current = keep;
        }
        for cleanup in stale {
            cleanup();
        }

        self.signal.set(new_value);
        self.dirty.set(false);
        self.compute_count
            .set(self.compute_count.get().wrapping_add(1));
    }
}
/// Puts the previously active observer back into `OBSERVER` when dropped.
///
/// A compute closure that panics therefore does not leave its own observer installed.
struct ObserverGuard {
    prev: Option<ObserverState>,
}

impl Drop for ObserverGuard {
    fn drop(&mut self) {
        let restored = self.prev.take();
        OBSERVER.with(move |slot| {
            let mut active = slot.borrow_mut();
            *active = restored;
        });
    }
}
fn run_compute<T: Clone + 'static>(
compute: &Rc<dyn Fn() -> T>,
dirty: &Rc<Cell<bool>>,
subscriptions: &SubscriptionList,
signal_holder: &Rc<RefCell<Option<Signal<T>>>>,
computing: &Rc<Cell<bool>>,
pre_seen: &HashSet<SignalKey>,
) -> (T, HashSet<SignalKey>, HashSet<SignalKey>) {
let dirty2 = Rc::clone(dirty);
let subs = Rc::clone(subscriptions);
let holder = Rc::clone(signal_holder);
let computing2 = Rc::clone(computing);
let seen: Rc<RefCell<HashSet<SignalKey>>> = Rc::new(RefCell::new(pre_seen.clone()));
let seen2 = Rc::clone(&seen);
let re_read: Rc<RefCell<HashSet<SignalKey>>> = Rc::new(RefCell::new(HashSet::new()));
let re_read2 = Rc::clone(&re_read);
let observer = ObserverState {
dirty_callback: Rc::new(move || {
if !computing2.get() {
dirty2.set(true);
if let Some(ref sig) = *holder.borrow() {
sig.bump_version();
}
}
}),
on_subscribe: Rc::new(move |key: SignalKey, cleanup: Box<dyn FnOnce()>| {
subs.borrow_mut().push((key, cleanup));
}),
seen: seen2,
re_read: re_read2,
};
let prev = OBSERVER.with(|cell| cell.borrow_mut().take());
let _guard = ObserverGuard { prev };
OBSERVER.with(|cell| {
*cell.borrow_mut() = Some(observer);
});
let value = compute();
let read_keys = seen.borrow().clone();
let re_read_keys = re_read.borrow().clone();
(value, read_keys, re_read_keys)
}
impl<T> Drop for Memo<T> {
    /// Runs every subscription cleanup, but only when this is the last handle sharing the
    /// subscription list. Clones share it, so dropping one clone leaves siblings subscribed.
    fn drop(&mut self) {
        let still_shared = Rc::strong_count(&self.subscriptions) != 1;
        if still_shared {
            return;
        }
        let mut subs = self.subscriptions.borrow_mut();
        for (_, cleanup) in subs.drain(..) {
            cleanup();
        }
    }
}
impl<T> Clone for Memo<T> {
    /// Returns another handle onto the same memo. Every `Rc` field is shared rather than
    /// copied, and the signal is duplicated via `Signal::clone`.
    fn clone(&self) -> Self {
        let Self {
            signal,
            dirty,
            compute,
            subscriptions,
            computing,
            compute_count,
        } = self;
        Self {
            signal: signal.clone(),
            dirty: Rc::clone(dirty),
            compute: Rc::clone(compute),
            subscriptions: Rc::clone(subscriptions),
            computing: Rc::clone(computing),
            compute_count: Rc::clone(compute_count),
        }
    }
}
impl<T: fmt::Debug + 'static> fmt::Debug for Memo<T> {
    /// Formats the cached value (without forcing a recompute), the dirty flag and the
    /// number of live subscriptions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let sub_count = self.subscriptions.borrow().len();
        let is_dirty = self.dirty.get();
        self.signal.with(|current| {
            let mut out = f.debug_struct("Memo");
            out.field("value", current);
            out.field("dirty", &is_dirty);
            out.field("subs", &sub_count);
            out.finish_non_exhaustive()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Signal;

    #[test]
    fn memo_initial_value() {
        let a = Signal::new(2);
        let b = Signal::new(3);
        let sum = Memo::new(move || a.read() + b.read());
        assert_eq!(sum.read(), 5);
    }

    #[test]
    fn memo_lazy_recompute() {
        let a = Signal::new(1);
        let a2 = a.clone();
        let compute_count = Rc::new(Cell::new(0u32));
        let cc = Rc::clone(&compute_count);
        let doubled = Memo::new(move || {
            cc.set(cc.get() + 1);
            a2.read() * 2
        });
        assert_eq!(compute_count.get(), 1);
        assert_eq!(doubled.read(), 2);
        assert_eq!(compute_count.get(), 1);
        a.set(10);
        // Invalidation alone must not run the closure; only the next read does.
        assert_eq!(compute_count.get(), 1);
        assert_eq!(doubled.read(), 20);
        assert_eq!(compute_count.get(), 2);
    }

    #[test]
    fn memo_no_recompute_when_clean() {
        let a = Signal::new(5);
        let memo = Memo::new(move || a.read() * 2);
        assert_eq!(memo.read(), 10);
        assert_eq!(memo.read(), 10);
        assert_eq!(memo.read(), 10);
        // Repeated clean reads must be served from the cache, not recomputed.
        assert_eq!(memo.compute_count(), 1);
        assert!(!memo.is_dirty());
    }

    #[test]
    fn memo_multiple_sources() {
        let x = Signal::new(1);
        let y = Signal::new(2);
        let z = Signal::new(3);
        let x2 = x.clone();
        let y2 = y.clone();
        let z2 = z.clone();
        let sum = Memo::new(move || x2.read() + y2.read() + z2.read());
        assert_eq!(sum.read(), 6);
        x.set(10);
        assert_eq!(sum.read(), 15);
        y.set(20);
        assert_eq!(sum.read(), 33);
        z.set(30);
        assert_eq!(sum.read(), 60);
    }

    #[test]
    fn memo_changed_future() {
        let a = Signal::new(1);
        let a2 = a.clone();
        let doubled = Memo::new(move || a2.read() * 2);
        // NOTE: the future is never polled, so this only checks that creating it does not
        // disturb subsequent reads; it does not test that `changed` resolves.
        let _fut = doubled.changed();
        a.set(5);
        assert_eq!(doubled.read(), 10);
    }

    #[test]
    fn memo_drop_cleans_up() {
        let sig = Signal::new(0i32);
        {
            let s = sig.clone();
            let _memo = Memo::new(move || s.read() * 2);
            assert!(sig.debug_count_waiters() > 0);
        }
        assert_eq!(sig.debug_count_waiters(), 0);
    }

    #[test]
    fn memo_with_tracking() {
        let a = Signal::new(7);
        let a2 = a.clone();
        let memo = Memo::new(move || a2.read() * 3);
        assert_eq!(memo.with(|v| *v), 21);
        // `with` must honour invalidation just like `read`.
        a.set(8);
        assert!(memo.is_dirty());
        assert_eq!(memo.with(|v| *v), 24);
        assert_eq!(memo.compute_count(), 2);
    }

    #[test]
    fn memo_dependency_set_changes() {
        let use_b = Signal::new(false);
        let a = Signal::new(1);
        let b = Signal::new(100);
        let a2 = a.clone();
        let b2 = b.clone();
        let use_b2 = use_b.clone();
        let memo = Memo::new(move || if use_b2.read() { b2.read() } else { a2.read() });
        assert_eq!(memo.read(), 1);
        b.set(200);
        assert_eq!(memo.read(), 1);
        use_b.set(true);
        assert_eq!(memo.read(), 200);
        b.set(300);
        assert_eq!(memo.read(), 300);
        // `a` is no longer read, so its subscription must have been dropped.
        a.set(999);
        assert!(!memo.is_dirty());
        assert_eq!(memo.read(), 300);
    }

    #[test]
    fn memo_nested() {
        let base = Signal::new(2);
        let doubled = Memo::new({
            let b = base.clone();
            move || b.read() * 2
        });
        let quadrupled = Memo::new({
            let d = doubled.clone();
            move || d.read() * 2
        });
        assert_eq!(quadrupled.read(), 8);
        base.set(5);
        assert_eq!(quadrupled.read(), 20);
    }

    #[test]
    fn memo_stress_many_changes() {
        let sig = Signal::new(0i32);
        let memo = Memo::new({
            let s = sig.clone();
            move || s.read() * 2
        });
        for i in 1..=1000 {
            sig.set(i);
            assert_eq!(memo.read(), i * 2);
        }
    }

    #[test]
    fn memo_clone_shares_state() {
        let a = Signal::new(10);
        let a2 = a.clone();
        let m1 = Memo::new(move || a2.read() * 2);
        let m2 = m1.clone();
        assert_eq!(m2.read(), 20);
        a.set(15);
        assert_eq!(m1.read(), 30);
        assert_eq!(m2.read(), 30);
    }

    #[test]
    fn memo_is_dirty_flag() {
        let a = Signal::new(1);
        let a2 = a.clone();
        let memo = Memo::new(move || a2.read() * 2);
        assert!(!memo.is_dirty());
        a.set(5);
        assert!(memo.is_dirty());
        assert_eq!(memo.read(), 10);
        assert!(!memo.is_dirty());
    }

    #[test]
    fn memo_compute_count_increments() {
        let a = Signal::new(1);
        let a2 = a.clone();
        let memo = Memo::new(move || a2.read() * 2);
        assert_eq!(memo.compute_count(), 1);
        let _ = memo.read();
        assert_eq!(memo.compute_count(), 1);
        a.set(10);
        let _ = memo.read();
        assert_eq!(memo.compute_count(), 2);
        let clone = memo.clone();
        assert_eq!(clone.compute_count(), 2);
        a.set(20);
        let _ = clone.read();
        assert_eq!(memo.compute_count(), 3);
    }

    #[test]
    fn memo_panics_during_compute_then_recovers() {
        let sig = Signal::new(1);
        let should_panic = Rc::new(Cell::new(false));
        let sp = Rc::clone(&should_panic);
        let s = sig.clone();
        let memo = Memo::new(move || {
            assert!(!sp.get(), "intentional memo panic");
            s.read() * 2
        });
        assert_eq!(memo.read(), 2);
        should_panic.set(true);
        sig.set(99);
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| memo.read()));
        assert!(result.is_err());
        should_panic.set(false);
        sig.set(5);
        assert_eq!(memo.read(), 10);
    }

    #[test]
    fn memo_self_read_in_compute_does_not_panic() {
        // NOTE: despite the name, the compute closure never reads this memo; this only
        // exercises `with` from outside the closure. A true self-read would need the memo
        // reachable from its own compute closure.
        let sig = Signal::new(1);
        let s = sig.clone();
        let memo = Memo::new(move || s.read() * 2);
        let result = memo.with(|v| *v);
        assert_eq!(result, 2);
    }

    #[test]
    fn memo_with_signalmap_tracks_dependency() {
        let source = Signal::new(42);
        let sm = source.map(|v: &i32| *v);
        let sm2 = sm.clone();
        let memo = Memo::new(move || sm2.with(|v| *v));
        assert_eq!(memo.read(), 42);
        source.set(99);
        assert!(memo.is_dirty());
        assert_eq!(memo.read(), 99);
    }

    #[test]
    fn memo_clone_drop_does_not_disconnect_siblings() {
        let sig = Signal::new(0i32);
        let s = sig.clone();
        let m1 = Memo::new(move || s.read() * 2);
        let m2 = m1.clone();
        assert_eq!(m1.read(), 0);
        assert_eq!(m2.read(), 0);
        drop(m2);
        sig.set(10);
        assert!(
            m1.is_dirty(),
            "m1 should still be subscribed after clone drop"
        );
        assert_eq!(m1.read(), 20);
    }
}