use crate::hooks::context::current_context;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, RwLock};
/// Reduces `deps` to a `u64` fingerprint, used to tell whether a hook's
/// dependencies differ from the ones seen on the previous render.
///
/// Uses a freshly constructed `DefaultHasher`, so equal inputs hash equally
/// within one build. The std docs do not guarantee the algorithm is stable
/// across Rust releases, so the value must not be persisted.
fn compute_deps_hash<D: Hash>(deps: &D) -> u64 {
    use std::collections::hash_map::DefaultHasher;

    let mut state = DefaultHasher::new();
    Hash::hash(deps, &mut state);
    state.finish()
}
/// Hook-slot payload for `use_memo`: the cached value behind a shared lock.
///
/// The derived `Clone` clones the `Arc`, so every clone reads and writes the
/// same underlying value.
#[derive(Clone)]
struct MemoStorage<T> {
    // Lock-protected cached result of the memo's compute closure.
    value: Arc<RwLock<T>>,
}
/// Returns a cached value that is recomputed only when `deps` change.
///
/// On the first call for this hook slot, `compute` runs and its result is
/// stored. On later calls, the stored value is cloned and returned as long as
/// the hash of `deps` equals the hash recorded by the previous call. When the
/// hash differs, `compute` runs again and its result replaces the stored value.
/// Change detection is hash-based, so a hash collision would miss an update.
///
/// Two hook slots are used per call (the dependency hash, then the value), so
/// hooks must be called in the same order on every render.
///
/// `compute` is invoked while this function holds the context write lock.
/// Calling other hooks from inside it would presumably try to re-lock the same
/// context, so avoid doing that.
///
/// # Panics
///
/// - If there is no current component context.
/// - If the context lock or the value lock is poisoned.
/// - If the value slot does not hold a `MemoStorage<T>` (for instance if the
///   hook call order changed between renders).
pub fn use_memo<T, D, F>(compute: F, deps: D) -> T
where
    T: Clone + Send + Sync + 'static,
    D: Hash,
    F: FnOnce() -> T,
{
    let ctx = current_context().expect("use_memo must be called within a component");
    let mut ctx_ref = ctx.write().unwrap();
    let new_hash = compute_deps_hash(&deps);
    let hash_storage = ctx_ref.use_hook(|| new_hash);
    let stored_hash = hash_storage.get::<u64>().unwrap_or(0);

    // `compute` is `FnOnce`: keep it in an `Option` so that exactly one of the
    // first-render initializer or the deps-changed recompute below consumes it.
    let mut pending = Some(compute);
    let storage = ctx_ref.use_hook(|| {
        let init = pending
            .take()
            .expect("use_memo: compute closure already consumed");
        MemoStorage {
            value: Arc::new(RwLock::new(init())),
        }
    });

    if let Some(memo) = storage.get::<MemoStorage<T>>() {
        if stored_hash != new_hash {
            // Deps changed since the last render: refresh the cached value.
            // `pending` is still `Some` unless the initializer ran this call,
            // in which case the value is already fresh.
            if let Some(recompute) = pending.take() {
                // Compute before taking the value lock so `compute` never runs
                // while that lock is held.
                let fresh = recompute();
                // `memo.value` is an `Arc`, so this writes into the shared slot.
                *memo.value.write().unwrap() = fresh;
            }
            hash_storage.set(new_hash);
        }
        memo.value.read().unwrap().clone()
    } else {
        panic!("use_memo: storage type mismatch")
    }
}
/// A cheaply clonable, reference-counted handle to a memoized callback.
///
/// Returned by `use_callback`. Clones share one callback allocation, so
/// cloning never copies the callback itself.
pub struct MemoizedCallback<F> {
    callback: Arc<F>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `F: Clone` bound, but cloning only bumps the `Arc` refcount, so any `F` works.
impl<F> Clone for MemoizedCallback<F> {
    fn clone(&self) -> Self {
        Self {
            callback: Arc::clone(&self.callback),
        }
    }
}

impl<F> MemoizedCallback<F> {
    /// Wraps `callback` in a new reference-counted allocation.
    fn new(callback: F) -> Self {
        Self {
            callback: Arc::new(callback),
        }
    }

    /// Borrows the wrapped callback so it can be invoked, e.g. `(cb.get())(x)`.
    pub fn get(&self) -> &F {
        &self.callback
    }

    /// Returns `true` if both handles point at the same callback allocation.
    ///
    /// Useful for checking whether a memoized callback was preserved across
    /// renders or replaced.
    pub fn ptr_eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.callback, &other.callback)
    }
}
/// Hook-slot payload for `use_callback`: the memoized callback handle plus the
/// dependency hash it was created for.
#[derive(Clone)]
struct CallbackStorage<F> {
    // Handle handed back to callers while the deps are unchanged.
    callback: MemoizedCallback<F>,
    // Hash of the deps this callback was built for.
    // NOTE(review): in the code visible here this field is written but never
    // read; change detection uses the separate hash hook slot instead.
    deps_hash: u64,
}
/// Returns a `MemoizedCallback` that stays the same handle until `deps` change.
///
/// The first call for this hook slot wraps `callback`. Later calls return the
/// stored handle while the hash of `deps` matches the hash from the previous
/// call. If the hash differs, or the slot does not yield a
/// `CallbackStorage<F>`, the slot is overwritten with a handle around the
/// `callback` passed in now.
///
/// Two hook slots are used per call (the dependency hash, then the callback),
/// so hooks must be called in the same order on every render.
///
/// # Panics
///
/// If there is no current component context, or if the context lock is poisoned.
pub fn use_callback<F, D>(callback: F, deps: D) -> MemoizedCallback<F>
where
    F: Clone + Send + Sync + 'static,
    D: Hash,
{
    let ctx = current_context().expect("use_callback must be called within a component");
    let mut hooks = ctx.write().unwrap();

    let deps_hash = compute_deps_hash(&deps);
    let hash_slot = hooks.use_hook(|| deps_hash);
    let previous_hash = hash_slot.get::<u64>().unwrap_or(0);
    let callback_slot = hooks.use_hook(|| CallbackStorage {
        callback: MemoizedCallback::new(callback.clone()),
        deps_hash,
    });

    // Cache hit: the slot holds the expected type and the deps are unchanged.
    let hit = callback_slot
        .get::<CallbackStorage<F>>()
        .filter(|_| previous_hash == deps_hash);
    if let Some(entry) = hit {
        return entry.callback;
    }

    // Deps changed, or the slot held something else: replace the entry.
    let replacement = CallbackStorage {
        callback: MemoizedCallback::new(callback),
        deps_hash,
    };
    callback_slot.set(replacement.clone());
    hash_slot.set(deps_hash);
    replacement.callback
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hooks::context::{HookContext, with_hooks};

    /// Builds an empty hook context shared by the successive "renders" of one test.
    fn fresh_context() -> Arc<RwLock<HookContext>> {
        Arc::new(RwLock::new(HookContext::new()))
    }

    #[test]
    fn test_use_memo_basic() {
        let hooks = fresh_context();
        // First render computes the value.
        let first = with_hooks(hooks.clone(), || use_memo(|| 42, "deps1"));
        assert_eq!(first, 42);
        // Same deps: the cached value wins over the new closure.
        let second = with_hooks(hooks.clone(), || use_memo(|| 99, "deps1"));
        assert_eq!(second, 42);
    }

    #[test]
    fn test_use_memo_with_tuple_deps() {
        let hooks = fresh_context();
        let deps = (1, "a", true);
        let first = with_hooks(hooks.clone(), || use_memo(|| vec![1, 2, 3], deps));
        assert_eq!(first, vec![1, 2, 3]);
        let second = with_hooks(hooks.clone(), || use_memo(|| vec![4, 5, 6], deps));
        assert_eq!(second, vec![1, 2, 3]);
    }

    #[test]
    fn test_use_callback_basic() {
        let hooks = fresh_context();
        let doubler = |x: i32| x * 2;
        let first = with_hooks(hooks.clone(), || use_callback(doubler, "cb_deps1"));
        assert_eq!((first.get())(5), 10);
        let second = with_hooks(hooks.clone(), || use_callback(doubler, "cb_deps1"));
        assert_eq!((second.get())(5), 10);
    }

    #[test]
    fn test_use_callback_deps_change() {
        let hooks = fresh_context();
        let doubler = |x: i32| x * 2;
        let tripler = |x: i32| x * 3;
        let first = with_hooks(hooks.clone(), || use_callback(doubler, "change_deps1"));
        assert_eq!((first.get())(5), 10);
        // New deps (and a new closure type) replace the stored callback.
        let second = with_hooks(hooks.clone(), || use_callback(tripler, "change_deps2"));
        assert_eq!((second.get())(5), 15);
    }

    #[test]
    fn test_use_callback_with_closure() {
        let hooks = fresh_context();
        let factor = 10;
        let scaled = with_hooks(hooks.clone(), || {
            use_callback(move |x: i32| x * factor, factor)
        });
        assert_eq!((scaled.get())(5), 50);
    }

    #[test]
    fn test_use_callback_same_fn_type() {
        let hooks = fresh_context();
        fn double(x: i32) -> i32 {
            x * 2
        }
        fn triple(x: i32) -> i32 {
            x * 3
        }
        let first = with_hooks(hooks.clone(), || use_callback(double, "fn_deps1"));
        assert_eq!((first.get())(5), 10);
        let second = with_hooks(hooks.clone(), || use_callback(double, "fn_deps1"));
        assert_eq!((second.get())(5), 10);
        let third = with_hooks(hooks.clone(), || use_callback(triple, "fn_deps2"));
        assert_eq!((third.get())(5), 15);
    }
}