use core::cell::RefCell;
extern crate alloc;
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use super::runtime::{EffectTiming, NodeId, NodeType, Observer, try_with_runtime, with_runtime};
/// Boxed memo computation as stored (behind `dyn Any`) in `MEMO_FUNCTIONS`.
type MemoFn<T> = Box<dyn FnMut() -> T + 'static>;
/// Cached result of a memo together with a flag requesting recomputation
/// on the next read.
#[derive(Clone)]
struct MemoState<T>
where
    T: Clone,
{
    /// Most recently computed value.
    value: T,
    /// When `true`, the next read recomputes instead of returning `value`.
    dirty: bool,
}
// All three maps are const-initialized (`RefCell::new` and `BTreeMap::new` are
// `const fn`), which avoids the lazy-initialization check on every access.
thread_local! {
    /// Memo computations keyed by memo id. Each entry is a boxed `MemoFn<T>`
    /// erased to `dyn Any` so memos with different `T` can share one map.
    static MEMO_FUNCTIONS: RefCell<BTreeMap<NodeId, Box<dyn core::any::Any>>> =
        const { RefCell::new(BTreeMap::new()) };
    /// Cached memo values keyed by memo id. Each entry is a boxed `MemoState<T>`.
    static MEMO_VALUES: RefCell<BTreeMap<NodeId, Box<dyn core::any::Any>>> =
        const { RefCell::new(BTreeMap::new()) };
    /// Memos flagged dirty by id from outside their own handle
    /// (written by `mark_memo_dirty_by_id`).
    static MEMO_DIRTY: RefCell<BTreeMap<NodeId, bool>> = const { RefCell::new(BTreeMap::new()) };
}
/// Flags the memo identified by `memo_id` as dirty without needing its handle,
/// then calls the runtime's `notify_signal_change` for that id.
pub(crate) fn mark_memo_dirty_by_id(memo_id: NodeId) {
    MEMO_DIRTY.with_borrow_mut(|dirty_ids| {
        dirty_ids.insert(memo_id, true);
    });
    with_runtime(|rt| rt.notify_signal_change(memo_id));
}
/// Reports whether `memo_id` was flagged by `mark_memo_dirty_by_id` and has
/// not been cleared since; unknown ids are treated as clean.
fn is_memo_dirty_externally(memo_id: NodeId) -> bool {
    MEMO_DIRTY.with_borrow(|dirty_ids| dirty_ids.get(&memo_id) == Some(&true))
}
/// Forgets any external dirty flag recorded for `memo_id`.
fn clear_memo_dirty(memo_id: NodeId) {
    MEMO_DIRTY.with_borrow_mut(|dirty_ids| {
        dirty_ids.remove(&memo_id);
    });
}
/// Handle to a cached, lazily recomputed derived value.
///
/// Clones share the same `id` and the same `disposed` flag; the computation
/// and cached value themselves live in thread-local maps keyed by `id`.
#[derive(Clone)]
pub struct Memo<T: Clone + 'static> {
    /// Key of this memo in the runtime and in the thread-local maps.
    id: NodeId,
    /// Shared "disposed" flag; its `Rc` strong count equals the number of
    /// live handles.
    disposed: Rc<RefCell<bool>>,
    /// Effect that marks this memo dirty when one of its explicit deps changes
    /// (only set by `new_with_deps`). Never read; it is held so the effect
    /// presumably stays registered while any handle exists — TODO confirm
    /// against `Effect`'s drop semantics.
    #[allow(dead_code)]
    deps_notifier: Option<Rc<super::effect::Effect>>,
    /// Binds the handle to the value type `T` without storing one.
    _phantom: core::marker::PhantomData<T>,
}
impl<T: Clone + 'static> Memo<T> {
pub fn new<F>(f: F) -> Self
where
F: FnMut() -> T + 'static,
{
let id = NodeId::new();
let disposed = Rc::new(RefCell::new(false));
MEMO_FUNCTIONS.with(|storage| {
let mut storage = storage.borrow_mut();
let boxed: Box<dyn core::any::Any> = Box::new(Box::new(f) as MemoFn<T>);
storage.insert(id, boxed);
});
let initial_value = Self::compute_value(id);
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
let state = MemoState {
value: initial_value,
dirty: false,
};
let boxed: Box<dyn core::any::Any> = Box::new(state);
storage.insert(id, boxed);
});
Self {
id,
disposed,
_phantom: core::marker::PhantomData,
deps_notifier: None,
}
}
#[allow(dead_code)]
pub fn new_with_deps<F>(mut f: F, deps: super::deps::Deps) -> Self
where
F: FnMut() -> T + 'static,
{
let id = NodeId::new();
let disposed = Rc::new(RefCell::new(false));
let f_wrapped = move || super::runtime::run_without_observer(&mut f);
MEMO_FUNCTIONS.with(|storage| {
let boxed: Box<dyn core::any::Any> = Box::new(Box::new(f_wrapped) as MemoFn<T>);
storage.borrow_mut().insert(id, boxed);
});
let initial_value = Self::compute_value(id);
MEMO_VALUES.with(|storage| {
let state = MemoState {
value: initial_value,
dirty: false,
};
let boxed: Box<dyn core::any::Any> = Box::new(state);
storage.borrow_mut().insert(id, boxed);
});
let deps_notifier = if deps.as_slice().is_empty() {
None
} else {
let memo_id = id;
let notifier = super::effect::Effect::new_with_deps_and_timing::<_, fn()>(
move || {
mark_memo_dirty_by_id(memo_id);
None
},
deps,
EffectTiming::Layout,
);
Some(Rc::new(notifier))
};
clear_memo_dirty(id);
Self {
id,
disposed,
_phantom: core::marker::PhantomData,
deps_notifier,
}
}
fn compute_value(memo_id: NodeId) -> T {
with_runtime(|rt| {
rt.clear_dependencies(memo_id);
rt.push_observer(Observer {
id: memo_id,
node_type: NodeType::Memo,
timing: EffectTiming::default(), cleanup: None,
});
});
struct MemoFnGuard {
memo_id: NodeId,
memo_fn_box: Option<Box<dyn core::any::Any>>,
}
impl Drop for MemoFnGuard {
fn drop(&mut self) {
if let Some(f) = self.memo_fn_box.take() {
MEMO_FUNCTIONS.with(|storage| {
storage.borrow_mut().insert(self.memo_id, f);
});
}
}
}
let mut guard = MemoFnGuard {
memo_id,
memo_fn_box: MEMO_FUNCTIONS.with(|storage| storage.borrow_mut().remove(&memo_id)),
};
let result = if let Some(ref mut boxed) = guard.memo_fn_box
&& let Some(memo_fn) = boxed.downcast_mut::<MemoFn<T>>()
{
memo_fn()
} else {
panic!("Memo function not found - this should never happen")
};
with_runtime(|rt| {
rt.pop_observer();
});
result
}
pub fn get(&self) -> T {
if *self.disposed.borrow() {
panic!("Attempted to access a disposed Memo");
}
with_runtime(|rt| rt.track_dependency(self.id));
let needs_recompute = MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.dirty;
}
true
}) || is_memo_dirty_externally(self.id);
if needs_recompute {
let new_value = Self::compute_value(self.id);
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
if let Some(boxed) = storage.get_mut(&self.id)
&& let Some(state) = boxed.downcast_mut::<MemoState<T>>()
{
state.value = new_value.clone();
state.dirty = false;
}
});
clear_memo_dirty(self.id);
new_value
} else {
MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.value.clone();
}
panic!("Memo value not found - this should never happen");
})
}
}
pub fn get_untracked(&self) -> T {
if *self.disposed.borrow() {
panic!("Attempted to access a disposed Memo");
}
let needs_recompute = MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.dirty;
}
true
}) || is_memo_dirty_externally(self.id);
if needs_recompute {
let new_value = Self::compute_value(self.id);
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
if let Some(boxed) = storage.get_mut(&self.id)
&& let Some(state) = boxed.downcast_mut::<MemoState<T>>()
{
state.value = new_value.clone();
state.dirty = false;
}
});
clear_memo_dirty(self.id);
new_value
} else {
MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.value.clone();
}
panic!("Memo value not found - this should never happen");
})
}
}
pub fn mark_dirty(&self) {
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
if let Some(boxed) = storage.get_mut(&self.id)
&& let Some(state) = boxed.downcast_mut::<MemoState<T>>()
{
state.dirty = true;
}
});
with_runtime(|rt| rt.notify_signal_change(self.id));
}
pub fn id(&self) -> NodeId {
self.id
}
pub fn dispose(&self) {
*self.disposed.borrow_mut() = true;
let _ = try_with_runtime(|rt| rt.remove_node(self.id));
let _ = MEMO_FUNCTIONS.try_with(|storage| {
storage.borrow_mut().remove(&self.id);
});
let _ = MEMO_VALUES.try_with(|storage| {
storage.borrow_mut().remove(&self.id);
});
let _ = MEMO_DIRTY.try_with(|storage| {
storage.borrow_mut().remove(&self.id);
});
}
}
impl<T: Clone + 'static> Drop for Memo<T> {
    /// Disposes the shared memo storage when the last handle is dropped.
    fn drop(&mut self) {
        // Every clone shares `disposed`, so a count above one means another
        // handle still relies on the stored function and value.
        let shared_with_other_handles = Rc::strong_count(&self.disposed) > 1;
        if shared_with_other_handles {
            return;
        }
        self.dispose();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reactive::Signal;
    use serial_test::serial;

    /// A memo over a constant yields that constant.
    #[test]
    #[serial]
    fn test_memo_creation() {
        let memo = Memo::new(|| 42);
        assert_eq!(memo.get(), 42);
    }

    /// Reading a clean memo twice runs the computation only once.
    #[test]
    #[serial]
    fn test_memo_caching() {
        let compute_count = Rc::new(RefCell::new(0));
        let compute_count_clone = compute_count.clone();
        let memo = Memo::new(move || {
            *compute_count_clone.borrow_mut() += 1;
            42
        });
        assert_eq!(memo.get(), 42);
        assert_eq!(*compute_count.borrow(), 1);
        assert_eq!(memo.get(), 42);
        assert_eq!(*compute_count.borrow(), 1);
    }

    /// After the source signal changes and the memo is marked dirty, the next
    /// read reflects the new signal value.
    #[test]
    #[serial]
    fn test_memo_with_signal_dependency() {
        let signal = Signal::new(5);
        let signal_clone = signal.clone();
        let memo = Memo::new(move || signal_clone.get() * 2);
        assert_eq!(memo.get(), 10);
        signal.set(10);
        memo.mark_dirty();
        assert_eq!(memo.get(), 20);
    }

    /// Clones of a memo read the same value.
    #[test]
    #[serial]
    fn test_memo_clone() {
        let memo1 = Memo::new(|| 42);
        let memo2 = memo1.clone();
        assert_eq!(memo1.get(), 42);
        assert_eq!(memo2.get(), 42);
    }

    /// A tracked `get` under an observer subscribes that observer to the memo.
    #[test]
    #[serial]
    fn test_memo_dependency_tracking() {
        let signal = Signal::new(1);
        let signal_clone = signal.clone();
        let memo = Memo::new(move || signal_clone.get() + 10);
        with_runtime(|rt| {
            let observer_id = NodeId::new();
            rt.push_observer(Observer {
                id: observer_id,
                node_type: NodeType::Effect,
                timing: EffectTiming::default(),
                cleanup: None,
            });
            let _ = memo.get();
            rt.pop_observer();
            let graph = rt.dependency_graph.borrow();
            let memo_node = graph.get(&memo.id()).unwrap();
            assert!(memo_node.subscribers.contains(&observer_id));
        });
    }

    /// An effect created inside a memo computation still runs.
    #[test]
    #[serial]
    fn test_memo_creates_effect_during_computation() {
        let effect_ran = Rc::new(RefCell::new(false));
        let effect_ran_clone = effect_ran.clone();
        let memo = Memo::new(move || {
            use crate::reactive::Effect;
            let ran = effect_ran_clone.clone();
            let _effect = Effect::new(move || {
                *ran.borrow_mut() = true;
            });
            42
        });
        assert_eq!(memo.get(), 42);
        assert!(*effect_ran.borrow());
    }

    /// `get_untracked` under an observer does not subscribe that observer.
    #[test]
    #[serial]
    fn test_memo_get_untracked() {
        let signal = Signal::new(5);
        let signal_clone = signal.clone();
        let memo = Memo::new(move || signal_clone.get() * 2);
        with_runtime(|rt| {
            let observer_id = NodeId::new();
            rt.push_observer(Observer {
                id: observer_id,
                node_type: NodeType::Effect,
                timing: EffectTiming::default(),
                cleanup: None,
            });
            let _ = memo.get_untracked();
            rt.pop_observer();
            let graph = rt.dependency_graph.borrow();
            if let Some(memo_node) = graph.get(&memo.id()) {
                assert!(!memo_node.subscribers.contains(&observer_id));
            }
        });
    }

    /// A memo built with explicit deps ignores changes to signals it reads but
    /// did not list, and recomputes once when a listed dep changes.
    #[test]
    #[serial]
    fn new_with_deps_recomputes_only_on_listed_dep() {
        use core::cell::Cell;
        let listed = Signal::new(2_i32);
        let unlisted = Signal::new(100_i32);
        let computations = Rc::new(Cell::new(0_i32));
        let computations_for_memo = computations.clone();
        let listed_for_memo = listed.clone();
        let unlisted_for_memo = unlisted.clone();
        let deps = crate::reactive::deps::Deps::from_signals(&[listed.id()]);
        let memo = Memo::new_with_deps(
            move || {
                computations_for_memo.set(computations_for_memo.get() + 1);
                listed_for_memo.get() * 10 + unlisted_for_memo.get()
            },
            deps,
        );
        let _ = memo.get();
        let before_changes = computations.get();
        unlisted.set(200);
        let _ = memo.get();
        let after_unlisted = computations.get();
        listed.set(3);
        let _ = memo.get();
        let after_listed = computations.get();
        assert_eq!(
            after_unlisted, before_changes,
            "memo MUST NOT recompute when unlisted Signal changes (Option A)"
        );
        assert_eq!(
            after_listed,
            before_changes + 1,
            "memo MUST recompute exactly once when listed dep changes"
        );
    }
}