use std::any::Any;
use std::sync::{Arc, RwLock};
use crate::cmd::Cmd;
/// Shared, thread-safe callback taking no arguments and returning nothing.
///
/// A `HookContext` stores one of these and invokes it from
/// `HookContext::request_render`.
pub type RenderCallback = Arc<dyn Fn() + Send + Sync>;
/// One-shot effect body.
///
/// When called it may return a one-shot cleanup closure; `HookContext::run_effects`
/// keeps that cleanup in the effect's slot and calls it before the next effect
/// registered for the same slot runs, or when the context is dropped.
pub type EffectCallback = Box<dyn FnOnce() -> Option<Box<dyn FnOnce() + Send>> + Send>;
/// Shared, type-erased cell holding the value of a single hook.
///
/// Cloning a `HookStorage` clones the inner [`Arc`], so every clone observes
/// and mutates the same value.
#[derive(Clone)]
pub struct HookStorage {
    /// The boxed value behind a reader/writer lock; its concrete type is
    /// recovered by downcasting in [`HookStorage::get`].
    pub value: Arc<RwLock<Box<dyn Any + Send + Sync>>>,
}

impl HookStorage {
    /// Wraps `value` in a new, independently shared storage cell.
    pub fn new<T: Send + Sync + 'static>(value: T) -> Self {
        Self {
            value: Arc::new(RwLock::new(Box::new(value))),
        }
    }

    /// Returns a clone of the stored value, or `None` when the value currently
    /// stored is not of type `T`.
    ///
    /// A poisoned lock is recovered rather than reported as `None`: the slot is
    /// only ever replaced wholesale by [`HookStorage::set`], so the last value
    /// written is still a complete value.
    pub fn get<T: Clone + Send + Sync + 'static>(&self) -> Option<T> {
        let guard = self
            .value
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.downcast_ref::<T>().cloned()
    }

    /// Replaces the stored value with `value`; the new value may be of a
    /// different type than the previous one.
    ///
    /// A poisoned lock is recovered so the update is never silently dropped.
    pub fn set<T: Send + Sync + 'static>(&self, value: T) {
        let mut guard = self
            .value
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *guard = Box::new(value);
    }
}
/// An effect queued during a render, to be executed by `HookContext::run_effects`.
pub struct Effect {
/// The effect body; called once, and may return a cleanup closure.
pub callback: EffectCallback,
/// Index into the context's cleanup table where this effect's cleanup is stored.
pub slot: usize,
}
/// Per-component hook state: stored hook values, queued effects and their
/// cleanups, queued commands, and the render callback.
pub struct HookContext {
// One storage cell per hook, in the order the hooks were first called.
hooks: Vec<HookStorage>,
// Position of the next hook call within the current render; reset by `begin_render`.
hook_index: usize,
// Effects queued by `add_effect`; cleared by `begin_render`, drained by `run_effects`.
effects: Vec<Effect>,
// Cleanup closures returned by effects, indexed by `Effect::slot`.
cleanups: Vec<Option<Box<dyn FnOnce() + Send>>>,
// Callback invoked by `request_render`, if one has been installed.
render_callback: Option<RenderCallback>,
// Set by `begin_render`, cleared by `end_render`.
is_rendering: bool,
// Commands pushed by `queue_cmd` and handed out by `take_cmds`.
cmd_queue: Vec<Cmd>,
// `TypeId` recorded for each hook when it was created; parallel to `hooks`.
hook_types: Vec<std::any::TypeId>,
// Set by `end_render`; enables the hook type check in `use_hook_with_index`.
first_render_complete: bool,
}
impl HookContext {
pub fn new() -> Self {
Self {
hooks: Vec::new(),
hook_index: 0,
effects: Vec::new(),
cleanups: Vec::new(),
render_callback: None,
is_rendering: false,
cmd_queue: Vec::new(),
hook_types: Vec::new(),
first_render_complete: false,
}
}
pub fn set_render_callback(&mut self, callback: RenderCallback) {
self.render_callback = Some(callback);
}
pub fn get_render_callback(&self) -> Option<RenderCallback> {
self.render_callback.clone()
}
pub fn begin_render(&mut self) {
self.hook_index = 0;
self.effects.clear();
self.is_rendering = true;
}
pub fn end_render(&mut self) {
self.is_rendering = false;
self.first_render_complete = true;
}
pub fn use_hook<T: Clone + Send + Sync + 'static, F: FnOnce() -> T>(
&mut self,
init: F,
) -> HookStorage {
self.use_hook_with_index(init).0
}
pub fn use_hook_with_index<T: Clone + Send + Sync + 'static, F: FnOnce() -> T>(
&mut self,
init: F,
) -> (HookStorage, usize) {
let index = self.hook_index;
self.hook_index += 1;
let storage = if index >= self.hooks.len() {
self.hook_types.push(std::any::TypeId::of::<T>());
let storage = HookStorage::new(init());
self.hooks.push(storage.clone());
storage
} else {
if self.first_render_complete {
let expected = self.hook_types[index];
let actual = std::any::TypeId::of::<T>();
if expected != actual {
panic!(
"Hook order violation at index {}! \
Hooks must be called in the same order on every render. \
This usually happens when hooks are called conditionally. \
Move conditional logic inside the hook or use separate components.",
index
);
}
}
self.hooks[index].clone()
};
(storage, index)
}
pub fn add_effect(&mut self, effect: Effect) {
self.effects.push(effect);
}
pub fn run_effects(&mut self) {
let effects = std::mem::take(&mut self.effects);
for effect in effects {
if effect.slot >= self.cleanups.len() {
self.cleanups.resize_with(effect.slot + 1, || None);
}
if let Some(cleanup_fn) = self.cleanups[effect.slot].take() {
cleanup_fn();
}
self.cleanups[effect.slot] = (effect.callback)();
}
}
pub fn request_render(&self) {
if let Some(callback) = &self.render_callback {
callback();
}
}
pub fn queue_cmd(&mut self, cmd: Cmd) {
self.cmd_queue.push(cmd);
}
pub fn take_cmds(&mut self) -> Vec<Cmd> {
std::mem::take(&mut self.cmd_queue)
}
}
impl Default for HookContext {
fn default() -> Self {
Self::new()
}
}
impl Drop for HookContext {
    /// Runs every cleanup still registered, in slot order, before the context
    /// goes away.
    fn drop(&mut self) {
        let remaining = std::mem::take(&mut self.cleanups);
        for cleanup in remaining.into_iter().flatten() {
            cleanup();
        }
    }
}
// Per-thread slot holding the hook context that is currently active on this
// thread. It starts out as `None`; `with_hooks` writes to it and
// `current_context` reads it.
thread_local! {
static CURRENT_CONTEXT: std::cell::RefCell<Option<Arc<RwLock<HookContext>>>> = const { std::cell::RefCell::new(None) };
}
/// Returns a handle to the hook context currently installed on this thread,
/// or `None` when no context is installed.
pub fn current_context() -> Option<Arc<RwLock<HookContext>>> {
    CURRENT_CONTEXT.with(|slot| slot.borrow().as_ref().map(Arc::clone))
}
/// Runs `f` as one render pass of `ctx` and returns its result.
///
/// While `f` runs, `ctx` is the context returned by [`current_context`] on
/// this thread. The render is bracketed by `begin_render` / `end_render`, and
/// the effects queued during the pass are run afterwards.
///
/// The context that was installed before the call is restored on exit, so
/// nested `with_hooks` calls do not clobber the enclosing caller's context.
/// It is also restored when `f` panics, so no stale context is left behind on
/// the thread.
///
/// Note that effects run while the write lock on `ctx` is held; an effect that
/// locks the same context again on this thread may deadlock or panic.
/// If the lock on `ctx` is poisoned, the begin/end bookkeeping and the effects
/// are skipped, but `f` still runs.
pub fn with_hooks<F, R>(ctx: Arc<RwLock<HookContext>>, f: F) -> R
where
    F: FnOnce() -> R,
{
    /// Puts the previously installed context back when dropped.
    struct RestoreContext(Option<Arc<RwLock<HookContext>>>);

    impl Drop for RestoreContext {
        fn drop(&mut self) {
            let previous = self.0.take();
            // `try_with` avoids a second panic if the thread-local is already
            // being torn down. The replaced value is returned out of the
            // closure so it is dropped only after the `RefCell` borrow ends.
            let _ = CURRENT_CONTEXT.try_with(|slot| slot.replace(previous));
        }
    }

    let previous = CURRENT_CONTEXT.with(|slot| slot.replace(Some(Arc::clone(&ctx))));
    // Declared after `ctx`, so it is dropped first and `ctx` is still alive
    // when the thread-local is restored.
    let _restore = RestoreContext(previous);

    if let Ok(mut guard) = ctx.write() {
        guard.begin_render();
    }

    // The lock is released here so `f` can lock the context itself.
    let result = f();

    if let Ok(mut guard) = ctx.write() {
        guard.end_render();
        guard.run_effects();
    }

    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built context has no hooks and its cursor at zero.
    #[test]
    fn test_hook_context_creation() {
        let context = HookContext::new();
        assert!(context.hooks.is_empty());
        assert_eq!(context.hook_index, 0);
    }

    /// Each hook call yields its own storage and advances the cursor.
    #[test]
    fn test_use_hook() {
        let mut context = HookContext::new();
        context.begin_render();

        let number = context.use_hook(|| 42i32);
        let text = context.use_hook(|| "hello".to_string());

        assert_eq!(number.get::<i32>(), Some(42));
        assert_eq!(text.get::<String>(), Some("hello".to_string()));
        assert_eq!(context.hook_index, 2);
    }

    /// A value written during one render is still there on the next one, and
    /// the initializer is not consulted again.
    #[test]
    fn test_hook_persistence() {
        let mut context = HookContext::new();

        context.begin_render();
        let first_pass = context.use_hook(|| 1i32);
        assert_eq!(first_pass.get::<i32>(), Some(1));
        first_pass.set(2i32);
        context.end_render();

        context.begin_render();
        let second_pass = context.use_hook(|| 999i32);
        assert_eq!(second_pass.get::<i32>(), Some(2));
        context.end_render();
    }

    /// `with_hooks` exposes the context through `current_context` and returns
    /// the closure's result.
    #[test]
    fn test_with_hooks() {
        let shared = Arc::new(RwLock::new(HookContext::new()));

        let value = with_hooks(Arc::clone(&shared), || {
            let active = current_context().unwrap();
            // The write guard is a temporary and is released at the end of
            // this statement.
            let storage = active.write().unwrap().use_hook(|| 42i32);
            storage.get::<i32>().unwrap()
        });

        assert_eq!(value, 42);
    }

    /// Calling a hook with a different type at an existing position on a later
    /// render must panic.
    #[test]
    #[should_panic(expected = "Hook order violation")]
    fn test_hook_order_violation() {
        let shared = Arc::new(RwLock::new(HookContext::new()));

        // First render: an `i32` hook followed by a `String` hook.
        with_hooks(Arc::clone(&shared), || {
            let active = current_context().unwrap();
            let mut context = active.write().unwrap();
            let _ = context.use_hook(|| 42i32);
            let _ = context.use_hook(|| "hello".to_string());
        });

        // Second render: position 0 is now requested as a `String`.
        with_hooks(Arc::clone(&shared), || {
            let active = current_context().unwrap();
            let mut context = active.write().unwrap();
            let _ = context.use_hook(|| "wrong type".to_string());
        });
    }
}