use core::cell::RefCell;
extern crate alloc;
use alloc::boxed::Box;
use alloc::collections::BTreeMap;
use alloc::rc::Rc;
use super::runtime::{EffectTiming, NodeId, NodeType, Observer, try_with_runtime, with_runtime};
/// Boxed compute function of a `Memo<T>`.
///
/// Stored behind `dyn Any` and recovered by downcasting to exactly this type. The elided
/// object lifetime of `Box<dyn …>` defaults to `'static`, which `dyn Any` storage requires.
type MemoFn<T> = Box<dyn FnMut() -> T>;
/// Cached result of a memo together with its staleness flag.
///
/// Trait bounds live on the impls that need them (`derive(Clone)` bounds its own impl on
/// `T: Clone`), not on the struct definition.
#[derive(Clone, Debug)]
struct MemoState<T> {
    /// Last computed value.
    value: T,
    /// `true` when `value` is stale and must be recomputed on the next read.
    dirty: bool,
}
thread_local! {
static MEMO_FUNCTIONS: RefCell<BTreeMap<NodeId, Box<dyn core::any::Any>>> = RefCell::new(BTreeMap::new());
}
thread_local! {
static MEMO_VALUES: RefCell<BTreeMap<NodeId, Box<dyn core::any::Any>>> = RefCell::new(BTreeMap::new());
}
#[derive(Clone)]
pub struct Memo<T: Clone + 'static> {
id: NodeId,
disposed: Rc<RefCell<bool>>,
_phantom: core::marker::PhantomData<T>,
}
impl<T: Clone + 'static> Memo<T> {
pub fn new<F>(mut f: F) -> Self
where
F: FnMut() -> T + 'static,
{
let id = NodeId::new();
let disposed = Rc::new(RefCell::new(false));
let disposed_clone = disposed.clone();
MEMO_FUNCTIONS.with(|storage| {
let mut storage = storage.borrow_mut();
let boxed: Box<dyn core::any::Any> = Box::new(Box::new(move || {
if !*disposed_clone.borrow() {
f()
} else {
panic!("Attempted to compute a disposed Memo");
}
}) as MemoFn<T>);
storage.insert(id, boxed);
});
let initial_value = Self::compute_value(id);
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
let state = MemoState {
value: initial_value,
dirty: false,
};
let boxed: Box<dyn core::any::Any> = Box::new(state);
storage.insert(id, boxed);
});
Self {
id,
disposed,
_phantom: core::marker::PhantomData,
}
}
fn compute_value(memo_id: NodeId) -> T {
with_runtime(|rt| {
rt.clear_dependencies(memo_id);
rt.push_observer(Observer {
id: memo_id,
node_type: NodeType::Memo,
timing: EffectTiming::default(), cleanup: None,
});
});
struct MemoFnGuard {
memo_id: NodeId,
memo_fn_box: Option<Box<dyn core::any::Any>>,
}
impl Drop for MemoFnGuard {
fn drop(&mut self) {
if let Some(f) = self.memo_fn_box.take() {
MEMO_FUNCTIONS.with(|storage| {
storage.borrow_mut().insert(self.memo_id, f);
});
}
}
}
let mut guard = MemoFnGuard {
memo_id,
memo_fn_box: MEMO_FUNCTIONS.with(|storage| storage.borrow_mut().remove(&memo_id)),
};
let result = if let Some(ref mut boxed) = guard.memo_fn_box
&& let Some(memo_fn) = boxed.downcast_mut::<MemoFn<T>>()
{
memo_fn()
} else {
panic!("Memo function not found - this should never happen")
};
with_runtime(|rt| {
rt.pop_observer();
});
result
}
pub fn get(&self) -> T {
if *self.disposed.borrow() {
panic!("Attempted to access a disposed Memo");
}
with_runtime(|rt| rt.track_dependency(self.id));
let needs_recompute = MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.dirty;
}
true
});
if needs_recompute {
let new_value = Self::compute_value(self.id);
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
if let Some(boxed) = storage.get_mut(&self.id)
&& let Some(state) = boxed.downcast_mut::<MemoState<T>>()
{
state.value = new_value.clone();
state.dirty = false;
}
});
new_value
} else {
MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.value.clone();
}
panic!("Memo value not found - this should never happen");
})
}
}
pub fn get_untracked(&self) -> T {
if *self.disposed.borrow() {
panic!("Attempted to access a disposed Memo");
}
let needs_recompute = MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.dirty;
}
true
});
if needs_recompute {
let new_value = Self::compute_value(self.id);
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
if let Some(boxed) = storage.get_mut(&self.id)
&& let Some(state) = boxed.downcast_mut::<MemoState<T>>()
{
state.value = new_value.clone();
state.dirty = false;
}
});
new_value
} else {
MEMO_VALUES.with(|storage| {
let storage = storage.borrow();
if let Some(boxed) = storage.get(&self.id)
&& let Some(state) = boxed.downcast_ref::<MemoState<T>>()
{
return state.value.clone();
}
panic!("Memo value not found - this should never happen");
})
}
}
pub fn mark_dirty(&self) {
MEMO_VALUES.with(|storage| {
let mut storage = storage.borrow_mut();
if let Some(boxed) = storage.get_mut(&self.id)
&& let Some(state) = boxed.downcast_mut::<MemoState<T>>()
{
state.dirty = true;
}
});
with_runtime(|rt| rt.notify_signal_change(self.id));
}
pub fn id(&self) -> NodeId {
self.id
}
pub fn dispose(&self) {
*self.disposed.borrow_mut() = true;
let _ = try_with_runtime(|rt| rt.remove_node(self.id));
let _ = MEMO_FUNCTIONS.try_with(|storage| {
storage.borrow_mut().remove(&self.id);
});
let _ = MEMO_VALUES.try_with(|storage| {
storage.borrow_mut().remove(&self.id);
});
}
}
impl<T: Clone + 'static> Drop for Memo<T> {
fn drop(&mut self) {
self.dispose();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reactive::Signal;
    use serial_test::serial;

    #[test]
    #[serial]
    fn test_memo_creation() {
        let memo = Memo::new(|| 42);
        assert_eq!(memo.get(), 42);
    }

    #[test]
    #[serial]
    fn test_memo_caching() {
        let compute_count = Rc::new(RefCell::new(0));
        let compute_count_clone = compute_count.clone();
        let memo = Memo::new(move || {
            *compute_count_clone.borrow_mut() += 1;
            42
        });
        // `new` computes eagerly; subsequent clean reads must be served from the cache.
        assert_eq!(memo.get(), 42);
        assert_eq!(*compute_count.borrow(), 1);
        assert_eq!(memo.get(), 42);
        assert_eq!(*compute_count.borrow(), 1);
    }

    #[test]
    #[serial]
    fn test_memo_with_signal_dependency() {
        let signal = Signal::new(5);
        let signal_clone = signal.clone();
        let memo = Memo::new(move || signal_clone.get() * 2);
        assert_eq!(memo.get(), 10);
        signal.set(10);
        memo.mark_dirty();
        assert_eq!(memo.get(), 20);
    }

    #[test]
    #[serial]
    fn test_memo_clone() {
        let memo1 = Memo::new(|| 42);
        let memo2 = memo1.clone();
        assert_eq!(memo1.get(), 42);
        assert_eq!(memo2.get(), 42);
    }

    #[test]
    #[serial]
    fn test_memo_dependency_tracking() {
        let signal = Signal::new(1);
        let signal_clone = signal.clone();
        let memo = Memo::new(move || signal_clone.get() + 10);
        with_runtime(|rt| {
            let observer_id = NodeId::new();
            rt.push_observer(Observer {
                id: observer_id,
                node_type: NodeType::Effect,
                timing: EffectTiming::default(),
                cleanup: None,
            });
            let _ = memo.get();
            rt.pop_observer();
            let graph = rt.dependency_graph.borrow();
            let memo_node = graph.get(&memo.id()).unwrap();
            assert!(memo_node.subscribers.contains(&observer_id));
        });
    }

    #[test]
    #[serial]
    fn test_memo_creates_effect_during_computation() {
        let effect_ran = Rc::new(RefCell::new(false));
        let effect_ran_clone = effect_ran.clone();
        let memo = Memo::new(move || {
            use crate::reactive::Effect;
            let ran = effect_ran_clone.clone();
            let _effect = Effect::new(move || {
                *ran.borrow_mut() = true;
            });
            42
        });
        assert_eq!(memo.get(), 42);
        assert!(*effect_ran.borrow());
    }

    #[test]
    #[serial]
    fn test_memo_get_untracked() {
        let signal = Signal::new(5);
        let signal_clone = signal.clone();
        let memo = Memo::new(move || signal_clone.get() * 2);
        with_runtime(|rt| {
            let observer_id = NodeId::new();
            rt.push_observer(Observer {
                id: observer_id,
                node_type: NodeType::Effect,
                timing: EffectTiming::default(),
                cleanup: None,
            });
            let _ = memo.get_untracked();
            rt.pop_observer();
            let graph = rt.dependency_graph.borrow();
            if let Some(memo_node) = graph.get(&memo.id()) {
                assert!(!memo_node.subscribers.contains(&observer_id));
            }
        });
    }

    #[test]
    #[serial]
    fn test_memo_dispose_releases_storage() {
        let memo = Memo::new(|| 42);
        let id = memo.id();
        memo.dispose();
        assert!(!MEMO_VALUES.with(|storage| storage.borrow().contains_key(&id)));
        assert!(!MEMO_FUNCTIONS.with(|storage| storage.borrow().contains_key(&id)));
    }

    #[test]
    #[serial]
    fn test_memo_dispose_is_shared_by_clones() {
        let memo1 = Memo::new(|| 42);
        let memo2 = memo1.clone();
        memo1.dispose();
        assert!(*memo2.disposed.borrow());
    }
}