use crate::hooks::context::current_context;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
/// Reduces a dependency value to a single `u64` fingerprint.
///
/// The hooks in this file store the fingerprint next to their cached value and
/// compare it on later calls to decide whether the dependencies changed.
/// Only the 64-bit hash is compared, so two distinct dependency values that
/// happen to collide are treated as unchanged.
///
/// `D` may be unsized, which lets `str` and slices be hashed through a plain
/// reference (`compute_deps_hash("key")`) without wrapping them first.
fn compute_deps_hash<D: Hash + ?Sized>(deps: &D) -> u64 {
    // A freshly constructed `DefaultHasher` always starts from the same state,
    // so equal inputs produce equal fingerprints from one call to the next.
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    deps.hash(&mut hasher);
    hasher.finish()
}
/// Hook-slot payload used by [`use_memo`]: a cached value together with the
/// fingerprint of the dependencies it was computed from.
#[derive(Clone)]
struct MemoStorage<T> {
    /// Hash of the dependencies that produced `value`.
    deps_hash: u64,
    /// The cached result. It sits behind `Arc<RwLock<_>>`, so cloning the
    /// storage shares the same cell instead of copying `T`.
    value: Arc<RwLock<T>>,
}
/// Returns a value cached in the current hook context, calling `compute`
/// again only when the hash of `deps` differs from the one stored with the
/// cached value.
///
/// # Behavior
///
/// * No current context, or its write lock cannot be acquired: `compute` is
///   called and its result returned without caching.
/// * Slot still empty: `compute` is called, the result is stored with the
///   hash of `deps`, and returned.
/// * Stored hash equals the hash of `deps`: a clone of the cached value is
///   returned and `compute` is not called (unless the value's lock is
///   poisoned, in which case `compute` is used as a fallback).
/// * Stored hash differs: `compute` is called, the cached value and hash are
///   overwritten, and the new value is returned.
///
/// Dependencies are compared by 64-bit hash only.
///
/// NOTE(review): `compute` runs while the context write guard is held. A
/// `compute` closure that itself calls a hook on the same context would try to
/// take that lock again — confirm callers never do this.
pub fn use_memo<T, D, F>(compute: F, deps: D) -> T
where
    T: Clone + Send + Sync + 'static,
    D: Hash,
    F: FnOnce() -> T,
{
    // Nowhere to cache without a usable context: fall back to plain evaluation.
    let Some(ctx) = current_context() else {
        return compute();
    };
    let Ok(mut ctx_ref) = ctx.write() else {
        return compute();
    };

    let new_hash = compute_deps_hash(&deps);
    let storage = ctx_ref.use_hook(|| Option::<MemoStorage<T>>::None);

    match storage.get::<Option<MemoStorage<T>>>() {
        // `get` produced nothing for the requested type (presumably the slot
        // holds something else — confirm against the hook context); do not cache.
        None => compute(),

        // Slot exists but nothing has been memoized yet.
        Some(None) => {
            let initial = compute();
            let entry = MemoStorage {
                value: Arc::new(RwLock::new(initial.clone())),
                deps_hash: new_hash,
            };
            storage.set(Some(entry));
            initial
        }

        // Dependencies unchanged: hand back the cached value.
        Some(Some(memo)) if memo.deps_hash == new_hash => {
            // Bind the clone first so the read guard is released before any
            // fallback call to `compute`.
            let cached = memo.value.read().ok().map(|guard| T::clone(&guard));
            cached.unwrap_or_else(compute)
        }

        // Dependencies changed: recompute and overwrite the cache.
        Some(Some(mut memo)) => {
            let fresh = compute();

            // Prefer updating the shared cell in place. The lock result is
            // fully dropped at the end of this statement, so the poisoned
            // case below is free to replace the cell.
            let updated_in_place = memo
                .value
                .write()
                .map(|mut slot| *slot = fresh.clone())
                .is_ok();
            if !updated_in_place {
                // Poisoned lock: swap in a brand-new cell.
                memo.value = Arc::new(RwLock::new(fresh.clone()));
            }

            memo.deps_hash = new_hash;
            storage.set(Some(memo));
            fresh
        }
    }
}
/// Shared handle to a callback cached by [`use_callback`].
///
/// The callback lives behind an [`Arc`], so cloning the handle only bumps a
/// reference count; every clone refers to the same callback instance.
pub struct MemoizedCallback<F> {
    callback: Arc<F>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `F: Clone` bound even though only the `Arc` is ever cloned.
impl<F> Clone for MemoizedCallback<F> {
    fn clone(&self) -> Self {
        Self {
            callback: Arc::clone(&self.callback),
        }
    }
}

impl<F> MemoizedCallback<F> {
    /// Wraps `callback` in a new shared handle.
    fn new(callback: F) -> Self {
        Self {
            callback: Arc::new(callback),
        }
    }

    /// Borrows the wrapped callback, e.g. `(cb.get())(arg)` to invoke it.
    pub fn get(&self) -> &F {
        &self.callback
    }
}
/// Hook-slot payload used by [`use_callback`]: the cached callback handle
/// together with the fingerprint of the dependencies it was stored under.
#[derive(Clone)]
struct CallbackStorage<F> {
    /// Hash of the dependencies recorded when `callback` was stored.
    ///
    /// NOTE(review): in the visible code this field is written but never read;
    /// `use_callback` decides staleness from a separate `u64` hook slot.
    deps_hash: u64,
    /// The memoized callback handle returned to callers.
    callback: MemoizedCallback<F>,
}
/// Returns a callback handle cached in the current hook context, replacing
/// the cached callback only when the hash of `deps` changes.
///
/// # Behavior
///
/// * No current context, or its write lock cannot be acquired: `callback` is
///   wrapped in a fresh handle and returned without caching.
/// * Stored hash equals the hash of `deps`: the previously stored handle is
///   returned and the `callback` argument is discarded.
/// * Otherwise: `callback` is stored with the new hash and its handle returned.
///
/// Two hook slots are consumed per call: one for the dependency hash and one
/// for the callback storage. If the hash slot yields no `u64`, the stored hash
/// is treated as `0`.
pub fn use_callback<F, D>(callback: F, deps: D) -> MemoizedCallback<F>
where
    F: Clone + Send + Sync + 'static,
    D: Hash,
{
    // Without a usable context there is nothing to memoize against.
    let Some(ctx) = current_context() else {
        return MemoizedCallback::new(callback);
    };
    let Ok(mut ctx_ref) = ctx.write() else {
        return MemoizedCallback::new(callback);
    };

    let new_hash = compute_deps_hash(&deps);

    // First slot: the dependency fingerprint from the previous call.
    let hash_storage = ctx_ref.use_hook(|| new_hash);
    let stored_hash = hash_storage.get::<u64>().unwrap_or(0);

    // Second slot: the callback itself, seeded from a clone of `callback`.
    let storage = ctx_ref.use_hook(|| CallbackStorage {
        callback: MemoizedCallback::new(callback.clone()),
        deps_hash: new_hash,
    });

    let refreshed = match storage.get::<CallbackStorage<F>>() {
        // Dependencies unchanged: reuse the stored handle as-is.
        Some(cached) if stored_hash == new_hash => return cached.callback,
        // Dependencies changed: overwrite the stale entry.
        Some(mut stale) => {
            stale.callback = MemoizedCallback::new(callback);
            stale.deps_hash = new_hash;
            stale
        }
        // Nothing readable in the slot: build a fresh entry.
        None => CallbackStorage {
            callback: MemoizedCallback::new(callback),
            deps_hash: new_hash,
        },
    };

    // Persist the entry first, then the fingerprint it was stored under.
    storage.set(refreshed.clone());
    hash_storage.set(new_hash);
    refreshed.callback
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::context::{HookContext, with_hooks};
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Same deps on the second render: the first result is reused.
    #[test]
    fn test_use_memo_basic() {
        let ctx = Arc::new(RwLock::new(HookContext::new()));

        let result = with_hooks(ctx.clone(), || use_memo(|| 42, "deps1"));
        assert_eq!(result, 42);

        let result = with_hooks(ctx.clone(), || use_memo(|| 99, "deps1"));
        assert_eq!(result, 42);
    }

    /// Tuples work as dependency values.
    #[test]
    fn test_use_memo_with_tuple_deps() {
        let ctx = Arc::new(RwLock::new(HookContext::new()));

        let result = with_hooks(ctx.clone(), || use_memo(|| vec![1, 2, 3], (1, "a", true)));
        assert_eq!(result, vec![1, 2, 3]);

        let result = with_hooks(ctx.clone(), || use_memo(|| vec![4, 5, 6], (1, "a", true)));
        assert_eq!(result, vec![1, 2, 3]);
    }

    /// `compute` runs once per distinct dependency value, not once per render.
    #[test]
    fn test_use_memo_recomputes_when_deps_change() {
        let ctx = Arc::new(RwLock::new(HookContext::new()));
        let compute_calls = Arc::new(AtomicUsize::new(0));

        let calls = compute_calls.clone();
        let first = with_hooks(ctx.clone(), || {
            use_memo(
                || {
                    calls.fetch_add(1, Ordering::SeqCst);
                    10
                },
                1,
            )
        });
        assert_eq!(first, 10);

        let calls = compute_calls.clone();
        let second = with_hooks(ctx.clone(), || {
            use_memo(
                || {
                    calls.fetch_add(1, Ordering::SeqCst);
                    99
                },
                1,
            )
        });
        assert_eq!(second, 10);

        let calls = compute_calls.clone();
        let third = with_hooks(ctx.clone(), || {
            use_memo(
                || {
                    calls.fetch_add(1, Ordering::SeqCst);
                    99
                },
                2,
            )
        });
        assert_eq!(third, 99);
        assert_eq!(compute_calls.load(Ordering::SeqCst), 2);
    }

    /// Outside `with_hooks` the value is simply computed.
    #[test]
    fn test_use_memo_without_context_does_not_panic() {
        let value = use_memo(|| 7, "no_ctx");
        assert_eq!(value, 7);
    }

    #[test]
    fn test_use_callback_basic() {
        let ctx = Arc::new(RwLock::new(HookContext::new()));
        let multiply_by_2 = |x: i32| x * 2;

        let cb = with_hooks(ctx.clone(), || use_callback(multiply_by_2, "cb_deps1"));
        assert_eq!((cb.get())(5), 10);

        let cb2 = with_hooks(ctx.clone(), || use_callback(multiply_by_2, "cb_deps1"));
        assert_eq!((cb2.get())(5), 10);
    }

    /// Uses two distinguishable functions of the same `fn` type so the test
    /// can tell a memoized callback apart from a replaced one.
    #[test]
    fn test_use_callback_deps_change() {
        fn double(x: i32) -> i32 {
            x * 2
        }
        fn triple(x: i32) -> i32 {
            x * 3
        }

        let ctx = Arc::new(RwLock::new(HookContext::new()));

        let cb = with_hooks(ctx.clone(), || use_callback(double as fn(i32) -> i32, 1));
        assert_eq!((cb.get())(5), 10);

        // Same deps: the stored `double` wins even though `triple` was passed.
        let cb2 = with_hooks(ctx.clone(), || use_callback(triple as fn(i32) -> i32, 1));
        assert_eq!((cb2.get())(5), 10);

        // Changed deps: the newly passed `triple` replaces the stored callback.
        let cb3 = with_hooks(ctx.clone(), || use_callback(triple as fn(i32) -> i32, 2));
        assert_eq!((cb3.get())(5), 15);
    }

    #[test]
    fn test_use_callback_with_closure() {
        let ctx = Arc::new(RwLock::new(HookContext::new()));
        let multiplier = 10;

        let cb = with_hooks(ctx.clone(), || {
            use_callback(move |x: i32| x * multiplier, multiplier)
        });
        assert_eq!((cb.get())(5), 50);
    }

    #[test]
    fn test_use_callback_same_fn_type() {
        fn double(x: i32) -> i32 {
            x * 2
        }

        let ctx = Arc::new(RwLock::new(HookContext::new()));

        let cb = with_hooks(ctx.clone(), || {
            use_callback(double as fn(i32) -> i32, "fn_deps1")
        });
        assert_eq!((cb.get())(5), 10);

        let cb2 = with_hooks(ctx.clone(), || {
            use_callback(double as fn(i32) -> i32, "fn_deps1")
        });
        assert_eq!((cb2.get())(5), 10);

        let cb3 = with_hooks(ctx.clone(), || {
            use_callback(double as fn(i32) -> i32, "fn_deps2")
        });
        assert_eq!((cb3.get())(5), 10);
    }

    #[test]
    fn test_use_callback_without_context_does_not_panic() {
        let cb = use_callback(|x: i32| x + 1, "no_ctx");
        assert_eq!((cb.get())(41), 42);
    }
}