use std::any::Any;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use std::thread::ThreadId;
use crate::fiber::FiberId;
use crate::fiber_tree::FiberTree;
// Each thread owns its own `StateBatch`. `begin_batch`, `end_batch`, `queue_update` and the
// `with_state_batch*` helpers all operate on the calling thread's instance.
thread_local! {
static STATE_BATCH: RefCell<StateBatch> = RefCell::new(StateBatch::new());
}
/// Process-wide FIFO of updates that were not queued directly into a thread-local batch.
/// `queue_update` routes an update here when it is called off the registered main thread, or
/// when the local batch is already borrowed. `drain_cross_thread_updates` moves the contents
/// into the calling thread's `STATE_BATCH`.
static CROSS_THREAD_UPDATES: Mutex<Vec<CrossThreadUpdate>> = Mutex::new(Vec::new());
/// Thread registered by `init_main_thread`. `None` makes `is_main_thread` treat every thread
/// as the main thread.
static MAIN_THREAD_ID: std::sync::RwLock<Option<ThreadId>> = std::sync::RwLock::new(None);
/// An update parked in the process-wide cross-thread queue.
///
/// `drain_cross_thread_updates` converts it back into a `StateUpdate` and queues it on the
/// draining thread's `STATE_BATCH`.
pub struct CrossThreadUpdate {
/// Fiber whose hook should change.
pub fiber_id: FiberId,
/// Index into that fiber's hook list.
pub hook_index: usize,
/// How the hook value should change.
pub update: CrossThreadUpdateKind,
}
/// Queue-storable counterpart of `StateUpdateKind`.
///
/// Updater closures are held in a shared `Option` slot and taken exactly once when applied.
/// The `*IfChanged` variants carry no equality closure. On drain, the comparison is rebuilt
/// with `reconstruct_equality_check`, which recognises only a fixed set of primitive and
/// string types and reports "changed" for anything else.
pub enum CrossThreadUpdateKind {
/// Replace the hook value unconditionally.
Value(Box<dyn Any + Send>),
/// Replace the hook value with the updater's result computed from the current value.
Updater(Arc<Mutex<Option<StateUpdaterFn>>>),
/// Replace the hook value only when it is not equal to `value`.
ValueIfChanged {
value: Box<dyn Any + Send>,
/// Expected concrete type of `value`. `queue_update_to_cross_thread` fills it from
/// `(*value).type_id()`.
type_id: std::any::TypeId,
},
/// Run the updater and keep its result only when it differs from the current value.
UpdaterIfChanged {
updater: Arc<Mutex<Option<StateUpdaterFn>>>,
},
}
/// Functional update: receives the current hook value and returns its replacement.
pub type StateUpdaterFn = Box<dyn FnOnce(&dyn Any) -> Box<dyn Any + Send> + Send>;
/// Equality predicate called as `(current, candidate)`. Returning `true` means "equal", and the
/// write is skipped.
pub type EqualityCheckFn = Box<dyn FnOnce(&dyn Any, &dyn Any) -> bool + Send>;
/// One hook update waiting in a `StateBatch`.
pub struct StateUpdate {
/// Index into the target fiber's hook list.
pub hook_index: usize,
/// How the hook value should change.
pub update: StateUpdateKind,
}
/// How a queued update changes a hook value when `StateBatch::end_batch` applies it.
pub enum StateUpdateKind {
/// Replace the value unconditionally (always counts as a change).
Value(Box<dyn Any + Send>),
/// Replace the value with the updater's result (always counts as a change).
Updater(StateUpdaterFn),
/// Replace the value only when `eq_check(current, value)` returns `false`.
ValueIfChanged {
value: Box<dyn Any + Send>,
eq_check: EqualityCheckFn,
},
/// Run the updater, then keep its result only when `eq_check(current, result)` is `false`.
UpdaterIfChanged {
updater: StateUpdaterFn,
eq_check: EqualityCheckFn,
},
}
/// Pending hook updates grouped by fiber. One instance lives in each thread's `STATE_BATCH`.
pub struct StateBatch {
/// Pending updates per fiber, kept in the order they were queued.
updates: HashMap<FiberId, Vec<StateUpdate>>,
/// Set by `begin_batch` and cleared by `end_batch`/`clear`. Queuing does not consult it.
batching: bool,
/// Fibers that had an update queued since the last `end_batch`/`clear`
/// (`take_updates` does not prune it).
dirty_fibers: HashSet<FiberId>,
}
impl StateBatch {
pub fn new() -> Self {
Self {
updates: HashMap::new(),
batching: false,
dirty_fibers: HashSet::new(),
}
}
pub fn begin_batch(&mut self) {
self.batching = true;
}
pub fn end_batch(&mut self, tree: &mut FiberTree) -> HashSet<FiberId> {
self.batching = false;
let mut actually_dirty: HashSet<FiberId> = HashSet::new();
for (fiber_id, updates) in self.updates.drain() {
if let Some(fiber) = tree.get_mut(fiber_id) {
let mut fiber_changed = false;
for update in updates {
if update.hook_index >= fiber.hooks.len() {
fiber
.hooks
.resize_with(update.hook_index + 1, || Box::new(()));
}
match update.update {
StateUpdateKind::Value(value) => {
fiber.hooks[update.hook_index] = value;
fiber_changed = true;
}
StateUpdateKind::Updater(updater) => {
let current = &*fiber.hooks[update.hook_index];
let new_value = updater(current);
fiber.hooks[update.hook_index] = new_value;
fiber_changed = true;
}
StateUpdateKind::ValueIfChanged { value, eq_check } => {
let current = &*fiber.hooks[update.hook_index];
if !eq_check(current, &*value) {
fiber.hooks[update.hook_index] = value;
fiber_changed = true;
}
}
StateUpdateKind::UpdaterIfChanged { updater, eq_check } => {
let current = &*fiber.hooks[update.hook_index];
let new_value = updater(current);
if !eq_check(&*fiber.hooks[update.hook_index], &*new_value) {
fiber.hooks[update.hook_index] = new_value;
fiber_changed = true;
}
}
}
}
if fiber_changed {
fiber.dirty = true;
actually_dirty.insert(fiber_id);
}
}
}
self.dirty_fibers.clear();
actually_dirty
}
pub fn queue_update(&mut self, fiber_id: FiberId, update: StateUpdate) {
self.updates.entry(fiber_id).or_default().push(update);
self.dirty_fibers.insert(fiber_id);
}
pub fn is_batching(&self) -> bool {
self.batching
}
pub fn take_updates(&mut self, fiber_id: FiberId) -> Vec<StateUpdate> {
self.updates.remove(&fiber_id).unwrap_or_default()
}
pub fn has_pending_updates(&self) -> bool {
!self.updates.is_empty()
}
pub fn dirty_fiber_count(&self) -> usize {
self.dirty_fibers.len()
}
pub fn is_fiber_dirty(&self, fiber_id: FiberId) -> bool {
self.dirty_fibers.contains(&fiber_id)
}
pub fn clear(&mut self) {
self.updates.clear();
self.dirty_fibers.clear();
self.batching = false;
}
}
impl Default for StateBatch {
fn default() -> Self {
Self::new()
}
}
/// Registers the calling thread as the main thread.
///
/// `is_main_thread` compares against the registered id from then on. A later call overwrites
/// the previous registration. If the lock is poisoned, registration is silently skipped.
pub fn init_main_thread() {
    let current = std::thread::current().id();
    if let Ok(mut registered) = MAIN_THREAD_ID.write() {
        *registered = Some(current);
    }
}
/// Reports whether the calling thread is the thread registered by `init_main_thread`.
///
/// Returns `true` when no thread has been registered, and also when the lock is poisoned.
/// In both cases every thread is treated as the main thread.
pub fn is_main_thread() -> bool {
    let registered = match MAIN_THREAD_ID.read() {
        Ok(guard) => *guard,
        Err(_) => return true,
    };
    registered.map_or(true, |main_id| main_id == std::thread::current().id())
}
/// Test helper: forgets the registered main thread, so `is_main_thread` treats every
/// thread as the main thread again. Does nothing if the lock is poisoned.
#[cfg(test)]
pub fn reset_main_thread() {
    if let Ok(mut registered) = MAIN_THREAD_ID.write() {
        registered.take();
    }
}
/// Appends `update` to the process-wide cross-thread queue.
///
/// If the queue mutex is poisoned, the update is dropped and an error is logged.
pub fn queue_cross_thread_update(update: CrossThreadUpdate) {
    match CROSS_THREAD_UPDATES.lock() {
        Ok(mut pending) => pending.push(update),
        Err(_) => {
            tracing::error!("Cross-thread update queue mutex is poisoned, update dropped");
        }
    }
}
pub fn drain_cross_thread_updates() {
if let Ok(mut queue) = CROSS_THREAD_UPDATES.lock() {
for update in queue.drain(..) {
let state_update = StateUpdate {
hook_index: update.hook_index,
update: match update.update {
CrossThreadUpdateKind::Value(v) => StateUpdateKind::Value(v),
CrossThreadUpdateKind::Updater(arc_mutex) => {
StateUpdateKind::Updater(Box::new(move |any| {
if let Ok(mut guard) = arc_mutex.lock() {
if let Some(f) = guard.take() {
f(any)
} else {
panic!("Cross-thread updater called more than once");
}
} else {
panic!("Cross-thread updater mutex poisoned");
}
}))
}
CrossThreadUpdateKind::ValueIfChanged { value, type_id } => {
StateUpdateKind::ValueIfChanged {
value,
eq_check: Box::new(move |old, new| {
reconstruct_equality_check(old, new, type_id)
}),
}
}
CrossThreadUpdateKind::UpdaterIfChanged { updater } => {
StateUpdateKind::UpdaterIfChanged {
updater: Box::new(move |any| {
if let Ok(mut guard) = updater.lock() {
if let Some(f) = guard.take() {
f(any)
} else {
panic!("Cross-thread updater called more than once");
}
} else {
panic!("Cross-thread updater mutex poisoned");
}
}),
eq_check: Box::new(move |old, new| {
let type_id = old.type_id();
reconstruct_equality_check(old, new, type_id)
}),
}
}
},
};
STATE_BATCH.with(|batch| {
batch
.borrow_mut()
.queue_update(update.fiber_id, state_update);
});
}
} else {
tracing::error!("Cross-thread update queue mutex is poisoned, updates not drained");
}
}
/// Best-effort equality used when an update's own `eq_check` was lost in the cross-thread
/// queue.
///
/// Returns `true` only when all of the following hold:
/// - `old` and `new` both have the concrete type `type_id`.
/// - That type is one of `i32`, `i64`, `u32`, `u64`, `usize`, `f32`, `f64`, `bool`,
///   `String` or `&'static str`.
/// - The two values compare equal with `PartialEq`.
///
/// Anything else, including type mismatches and unsupported types, yields `false`. Callers
/// treat `false` as "changed", so the update is applied. Floats follow IEEE semantics, so
/// `NaN` never compares equal.
fn reconstruct_equality_check(old: &dyn Any, new: &dyn Any, type_id: std::any::TypeId) -> bool {
    /// `Some(old == new)` when `type_id` is `T` and both values downcast to it, else `None`.
    fn compare_as<T: PartialEq + 'static>(
        old: &dyn Any,
        new: &dyn Any,
        type_id: std::any::TypeId,
    ) -> Option<bool> {
        if type_id != std::any::TypeId::of::<T>() {
            return None;
        }
        match (old.downcast_ref::<T>(), new.downcast_ref::<T>()) {
            (Some(lhs), Some(rhs)) => Some(lhs == rhs),
            _ => None,
        }
    }

    // Both sides must really be of the advertised type before any comparison is tried.
    let both_match = old.type_id() == type_id && new.type_id() == type_id;
    if !both_match {
        return false;
    }

    compare_as::<i32>(old, new, type_id)
        .or_else(|| compare_as::<i64>(old, new, type_id))
        .or_else(|| compare_as::<u32>(old, new, type_id))
        .or_else(|| compare_as::<u64>(old, new, type_id))
        .or_else(|| compare_as::<usize>(old, new, type_id))
        .or_else(|| compare_as::<f32>(old, new, type_id))
        .or_else(|| compare_as::<f64>(old, new, type_id))
        .or_else(|| compare_as::<bool>(old, new, type_id))
        .or_else(|| compare_as::<String>(old, new, type_id))
        .or_else(|| compare_as::<&str>(old, new, type_id))
        .unwrap_or(false)
}
/// Returns `true` when at least one cross-thread update is waiting to be drained.
///
/// A poisoned queue mutex is reported as "no updates".
pub fn has_cross_thread_updates() -> bool {
    match CROSS_THREAD_UPDATES.lock() {
        Ok(pending) => !pending.is_empty(),
        Err(_) => false,
    }
}
/// Test helper: discards every queued cross-thread update. Does nothing if the mutex is
/// poisoned.
#[cfg(test)]
pub fn clear_cross_thread_updates() {
    if let Ok(mut pending) = CROSS_THREAD_UPDATES.lock() {
        pending.truncate(0);
    }
}
/// Test helper: queues a `Value` update while this thread's `STATE_BATCH` is mutably
/// borrowed, forcing `queue_update` onto its borrowed-batch fallback path.
#[cfg(test)]
pub fn test_simulate_reentrant_update(
    fiber_id: FiberId,
    hook_index: usize,
    value: Box<dyn Any + Send>,
) {
    let update = StateUpdate {
        hook_index,
        update: StateUpdateKind::Value(value),
    };
    STATE_BATCH.with(|cell| {
        // Keep the batch borrowed for the whole call so `try_borrow_mut` inside
        // `queue_update` fails.
        let _held = cell.borrow_mut();
        queue_update(fiber_id, update);
    });
}
/// Sets the `batching` flag on this thread's `STATE_BATCH`.
///
/// # Panics
/// Panics if the batch is already borrowed on this thread.
pub fn begin_batch() {
    STATE_BATCH.with(|cell| cell.borrow_mut().begin_batch());
}
/// Ends this thread's batch, applying its queued updates to `tree`.
///
/// Returns the ids of fibers whose state actually changed (see `StateBatch::end_batch`).
///
/// # Panics
/// Panics if this thread's `STATE_BATCH` is already borrowed.
pub fn end_batch_with_tree(tree: &mut FiberTree) -> HashSet<FiberId> {
    STATE_BATCH.with(move |cell| {
        let mut batch = cell.borrow_mut();
        batch.end_batch(tree)
    })
}
pub fn end_batch() -> HashSet<FiberId> {
crate::fiber_tree::with_fiber_tree_mut(|tree| {
STATE_BATCH.with(|batch| batch.borrow_mut().end_batch(tree))
})
.unwrap_or_default()
}
/// Queues a hook update for `fiber_id`.
///
/// Routing:
/// - Off the registered main thread (see `is_main_thread`), the update goes to the
///   process-wide cross-thread queue.
/// - On the main thread it is appended to this thread's `STATE_BATCH`.
/// - If that batch is already borrowed, as in a re-entrant call from inside
///   `with_state_batch_mut`, the update is routed to the cross-thread queue instead of
///   panicking. It is applied after the next `drain_cross_thread_updates`.
///
/// In debug builds a warning is logged if the update is queued during the render phase.
pub fn queue_update(fiber_id: FiberId, update: StateUpdate) {
    #[cfg(debug_assertions)]
    {
        if crate::runtime::is_in_render_phase() {
            tracing::warn!(
                fiber_id = ?fiber_id,
                hook_index = update.hook_index,
                "State update queued during render phase! State updates should be triggered by events or effects, not during render. \
                This can lead to infinite render loops and performance issues."
            );
        }
    }
    if !is_main_thread() {
        queue_update_to_cross_thread(fiber_id, update);
        return;
    }
    // `LocalKey::with` takes an `FnOnce` and returns its result, so a rejected update is
    // simply handed back instead of being smuggled out through a `Cell`.
    let rejected = STATE_BATCH.with(|cell| match cell.try_borrow_mut() {
        Ok(mut batch) => {
            batch.queue_update(fiber_id, update);
            None
        }
        Err(_) => Some(update),
    });
    if let Some(update) = rejected {
        queue_update_to_cross_thread(fiber_id, update);
    }
}
fn queue_update_to_cross_thread(fiber_id: FiberId, update: StateUpdate) {
let cross_update = CrossThreadUpdate {
fiber_id,
hook_index: update.hook_index,
update: match update.update {
StateUpdateKind::Value(v) => CrossThreadUpdateKind::Value(v),
StateUpdateKind::Updater(f) => {
CrossThreadUpdateKind::Updater(Arc::new(Mutex::new(Some(f))))
}
StateUpdateKind::ValueIfChanged { value, eq_check: _ } => {
let type_id = (*value).type_id();
CrossThreadUpdateKind::ValueIfChanged { value, type_id }
}
StateUpdateKind::UpdaterIfChanged {
updater,
eq_check: _,
} => {
CrossThreadUpdateKind::UpdaterIfChanged {
updater: Arc::new(Mutex::new(Some(updater))),
}
}
},
};
queue_cross_thread_update(cross_update);
}
/// Returns the `batching` flag of this thread's `STATE_BATCH`.
///
/// # Panics
/// Panics if the batch is currently mutably borrowed on this thread.
pub fn is_batching() -> bool {
    STATE_BATCH.with(|cell| {
        let batch = cell.borrow();
        batch.is_batching()
    })
}
/// Runs `f` with a shared borrow of this thread's `STATE_BATCH` and returns its result.
///
/// # Panics
/// Panics if the batch is currently mutably borrowed on this thread.
pub fn with_state_batch<R, F: FnOnce(&StateBatch) -> R>(f: F) -> R {
    STATE_BATCH.with(|cell| {
        let batch = cell.borrow();
        f(&batch)
    })
}
/// Runs `f` with an exclusive borrow of this thread's `STATE_BATCH` and returns its result.
///
/// While `f` runs, `queue_update` on this thread falls back to the cross-thread queue.
///
/// # Panics
/// Panics if the batch is already borrowed on this thread.
pub fn with_state_batch_mut<R, F: FnOnce(&mut StateBatch) -> R>(f: F) -> R {
    STATE_BATCH.with(|cell| {
        let mut batch = cell.borrow_mut();
        f(&mut batch)
    })
}
/// Resets this thread's `STATE_BATCH` (see `StateBatch::clear`).
///
/// # Panics
/// Panics if the batch is already borrowed on this thread.
pub fn clear_state_batch() {
    STATE_BATCH.with(|cell| cell.borrow_mut().clear());
}
/// Unit tests for state batching and the cross-thread update queue.
///
/// Some tests read or write process-global state:
/// - `MAIN_THREAD_ID`, via `reset_main_thread`, `init_main_thread` and `queue_update`.
/// - `CROSS_THREAD_UPDATES`.
///
/// Those tests are marked `#[serial]`. Without it, libtest's parallel runner lets one test's
/// `clear_cross_thread_updates`/`drain_cross_thread_updates` wipe or steal another test's
/// queue. It also lets `init_main_thread` reroute another test's `queue_update` calls to the
/// cross-thread queue.
///
/// Tests that only use a locally built `StateBatch` or this thread's `STATE_BATCH` stay
/// parallel. NOTE(review): other test modules touching these globals must also be `#[serial]`.
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;

    #[test]
    fn test_state_batch_creation() {
        let batch = StateBatch::new();
        assert!(!batch.is_batching());
        assert!(!batch.has_pending_updates());
        assert_eq!(batch.dirty_fiber_count(), 0);
    }

    #[test]
    fn test_begin_and_end_batch() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        assert!(!batch.is_batching());
        batch.begin_batch();
        assert!(batch.is_batching());
        let dirty = batch.end_batch(&mut tree);
        assert!(!batch.is_batching());
        assert!(dirty.is_empty());
    }

    #[test]
    fn test_queue_update_marks_fiber_dirty() {
        let mut batch = StateBatch::new();
        let fiber_id = FiberId(1);
        assert!(!batch.is_fiber_dirty(fiber_id));
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        assert!(batch.is_fiber_dirty(fiber_id));
        assert!(batch.has_pending_updates());
        assert_eq!(batch.dirty_fiber_count(), 1);
    }

    #[test]
    fn test_multiple_updates_same_fiber() {
        let mut batch = StateBatch::new();
        let fiber_id = FiberId(1);
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(1i32)),
            },
        );
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(2i32)),
            },
        );
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 1,
                update: StateUpdateKind::Value(Box::new("hello".to_string())),
            },
        );
        assert_eq!(batch.dirty_fiber_count(), 1);
        assert!(batch.has_pending_updates());
    }

    #[test]
    fn test_multiple_fibers_dirty() {
        let mut batch = StateBatch::new();
        let fiber1 = FiberId(1);
        let fiber2 = FiberId(2);
        let fiber3 = FiberId(3);
        batch.queue_update(
            fiber1,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(1i32)),
            },
        );
        batch.queue_update(
            fiber2,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(2i32)),
            },
        );
        batch.queue_update(
            fiber3,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(3i32)),
            },
        );
        assert_eq!(batch.dirty_fiber_count(), 3);
        assert!(batch.is_fiber_dirty(fiber1));
        assert!(batch.is_fiber_dirty(fiber2));
        assert!(batch.is_fiber_dirty(fiber3));
    }

    #[test]
    fn test_end_batch_applies_value_updates() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(42));
    }

    #[test]
    fn test_end_batch_applies_functional_updates() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 10i32);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Updater(Box::new(|current| {
                    let val = current.downcast_ref::<i32>().unwrap();
                    Box::new(val + 5)
                })),
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(15));
    }

    #[test]
    fn test_chained_functional_updates() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        batch.begin_batch();
        for _ in 0..5 {
            batch.queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Updater(Box::new(|current| {
                        let val = current.downcast_ref::<i32>().unwrap();
                        Box::new(val + 1)
                    })),
                },
            );
        }
        batch.end_batch(&mut tree);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(5));
    }

    #[test]
    fn test_take_updates() {
        let mut batch = StateBatch::new();
        let fiber_id = FiberId(1);
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(1i32)),
            },
        );
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 1,
                update: StateUpdateKind::Value(Box::new(2i32)),
            },
        );
        let updates = batch.take_updates(fiber_id);
        assert_eq!(updates.len(), 2);
        let updates_again = batch.take_updates(fiber_id);
        assert!(updates_again.is_empty());
    }

    #[test]
    fn test_clear_batch() {
        let mut batch = StateBatch::new();
        let fiber_id = FiberId(1);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        assert!(batch.is_batching());
        assert!(batch.has_pending_updates());
        batch.clear();
        assert!(!batch.is_batching());
        assert!(!batch.has_pending_updates());
        assert_eq!(batch.dirty_fiber_count(), 0);
    }

    #[test]
    fn test_thread_local_begin_batch() {
        clear_state_batch();
        assert!(!is_batching());
        begin_batch();
        assert!(is_batching());
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_thread_local_queue_and_end_batch() {
        clear_state_batch();
        clear_cross_thread_updates();
        reset_main_thread();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        begin_batch();
        with_state_batch_mut(|batch| {
            batch.queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Value(Box::new(100i32)),
                },
            );
        });
        let dirty = end_batch_with_tree(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(100));
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_with_state_batch() {
        clear_state_batch();
        clear_cross_thread_updates();
        reset_main_thread();
        let fiber_id = FiberId(1);
        with_state_batch_mut(|batch| {
            batch.queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Value(Box::new(42i32)),
                },
            );
        });
        let has_updates = with_state_batch(|batch| batch.has_pending_updates());
        assert!(has_updates);
        let is_dirty = with_state_batch(|batch| batch.is_fiber_dirty(fiber_id));
        assert!(is_dirty);
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_with_state_batch_mut() {
        clear_state_batch();
        with_state_batch_mut(|batch| {
            batch.begin_batch();
        });
        assert!(is_batching());
        clear_state_batch();
    }

    #[test]
    fn test_end_batch_marks_fiber_dirty_in_tree() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.mark_clean(fiber_id);
        assert!(!tree.get(fiber_id).unwrap().dirty);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        batch.end_batch(&mut tree);
        assert!(tree.get(fiber_id).unwrap().dirty);
    }

    #[test]
    fn test_update_nonexistent_fiber() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let nonexistent_fiber = FiberId(999);
        batch.begin_batch();
        batch.queue_update(
            nonexistent_fiber,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(!dirty.contains(&nonexistent_fiber));
    }

    #[test]
    fn test_default_impl() {
        let batch: StateBatch = Default::default();
        assert!(!batch.is_batching());
        assert!(!batch.has_pending_updates());
    }

    #[test]
    fn test_value_if_changed_skips_equal_values() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 42i32);
        tree.mark_clean(fiber_id);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::ValueIfChanged {
                    value: Box::new(42i32),
                    eq_check: Box::new(|old, new| {
                        let old = old.downcast_ref::<i32>().unwrap();
                        let new = new.downcast_ref::<i32>().unwrap();
                        old == new
                    }),
                },
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(!dirty.contains(&fiber_id));
        assert!(!tree.get(fiber_id).unwrap().dirty);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(42));
    }

    #[test]
    fn test_value_if_changed_updates_different_values() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 42i32);
        tree.mark_clean(fiber_id);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::ValueIfChanged {
                    value: Box::new(100i32),
                    eq_check: Box::new(|old, new| {
                        let old = old.downcast_ref::<i32>().unwrap();
                        let new = new.downcast_ref::<i32>().unwrap();
                        old == new
                    }),
                },
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert!(tree.get(fiber_id).unwrap().dirty);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(100));
    }

    #[test]
    fn test_updater_if_changed_skips_equal_results() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 5i32);
        tree.mark_clean(fiber_id);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::UpdaterIfChanged {
                    updater: Box::new(|any| {
                        let n = any.downcast_ref::<i32>().unwrap();
                        Box::new((*n).max(3))
                    }),
                    eq_check: Box::new(|old, new| {
                        let old = old.downcast_ref::<i32>().unwrap();
                        let new = new.downcast_ref::<i32>().unwrap();
                        old == new
                    }),
                },
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(!dirty.contains(&fiber_id));
        assert!(!tree.get(fiber_id).unwrap().dirty);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(5));
    }

    #[test]
    fn test_updater_if_changed_updates_different_results() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 5i32);
        tree.mark_clean(fiber_id);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::UpdaterIfChanged {
                    updater: Box::new(|any| {
                        let n = any.downcast_ref::<i32>().unwrap();
                        Box::new((*n).max(10))
                    }),
                    eq_check: Box::new(|old, new| {
                        let old = old.downcast_ref::<i32>().unwrap();
                        let new = new.downcast_ref::<i32>().unwrap();
                        old == new
                    }),
                },
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert!(tree.get(fiber_id).unwrap().dirty);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(10));
    }

    #[test]
    fn test_mixed_updates_with_equality_check() {
        let mut batch = StateBatch::new();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 10i32);
        tree.mark_clean(fiber_id);
        batch.begin_batch();
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(20i32)),
            },
        );
        batch.queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::ValueIfChanged {
                    value: Box::new(20i32),
                    eq_check: Box::new(|old, new| {
                        let old = old.downcast_ref::<i32>().unwrap();
                        let new = new.downcast_ref::<i32>().unwrap();
                        old == new
                    }),
                },
            },
        );
        let dirty = batch.end_batch(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(20));
    }

    #[test]
    #[serial]
    fn test_is_main_thread_without_init() {
        reset_main_thread();
        assert!(is_main_thread());
    }

    #[test]
    #[serial]
    fn test_init_main_thread_and_is_main_thread() {
        reset_main_thread();
        init_main_thread();
        assert!(is_main_thread());
        reset_main_thread();
    }

    #[test]
    #[serial]
    fn test_queue_cross_thread_update_adds_to_queue() {
        reset_main_thread();
        clear_cross_thread_updates();
        let fiber_id = FiberId(42);
        let update = CrossThreadUpdate {
            fiber_id,
            hook_index: 0,
            update: CrossThreadUpdateKind::Value(Box::new(123i32)),
        };
        queue_cross_thread_update(update);
        assert!(has_cross_thread_updates());
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_drain_cross_thread_updates_moves_to_local_batch() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let fiber_id = FiberId(42);
        let update = CrossThreadUpdate {
            fiber_id,
            hook_index: 0,
            update: CrossThreadUpdateKind::Value(Box::new(999i32)),
        };
        queue_cross_thread_update(update);
        assert!(has_cross_thread_updates());
        drain_cross_thread_updates();
        assert!(!has_cross_thread_updates());
        let has_updates = with_state_batch(|batch| batch.has_pending_updates());
        assert!(has_updates);
        let is_dirty = with_state_batch(|batch| batch.is_fiber_dirty(fiber_id));
        assert!(is_dirty);
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_drain_cross_thread_updates_applies_value() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        let update = CrossThreadUpdate {
            fiber_id,
            hook_index: 0,
            update: CrossThreadUpdateKind::Value(Box::new(42i32)),
        };
        queue_cross_thread_update(update);
        begin_batch();
        drain_cross_thread_updates();
        let dirty = end_batch_with_tree(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(42));
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_drain_cross_thread_updates_applies_updater() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 10i32);
        let updater: StateUpdaterFn = Box::new(|any| {
            let n = any.downcast_ref::<i32>().unwrap();
            Box::new(n + 5)
        });
        let update = CrossThreadUpdate {
            fiber_id,
            hook_index: 0,
            update: CrossThreadUpdateKind::Updater(Arc::new(Mutex::new(Some(updater)))),
        };
        queue_cross_thread_update(update);
        begin_batch();
        drain_cross_thread_updates();
        let dirty = end_batch_with_tree(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(15));
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_cross_thread_updates_preserve_order() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        for i in 1..=5 {
            let updater: StateUpdaterFn = Box::new(move |any| {
                let n = any.downcast_ref::<i32>().unwrap();
                Box::new(n + i)
            });
            let update = CrossThreadUpdate {
                fiber_id,
                hook_index: 0,
                update: CrossThreadUpdateKind::Updater(Arc::new(Mutex::new(Some(updater)))),
            };
            queue_cross_thread_update(update);
        }
        begin_batch();
        drain_cross_thread_updates();
        end_batch_with_tree(&mut tree);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(15));
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_queue_update_routes_to_local_on_main_thread() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let fiber_id = FiberId(1);
        queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        let has_local = with_state_batch(|batch| batch.has_pending_updates());
        assert!(has_local);
        assert!(!has_cross_thread_updates());
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_queue_update_fallback_to_cross_thread_on_reentrant_call() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let fiber_id = FiberId(1);
        STATE_BATCH.with(|batch| {
            let _guard = batch.borrow_mut();
            queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Value(Box::new(42i32)),
                },
            );
            assert!(has_cross_thread_updates());
        });
        let has_local = with_state_batch(|batch| batch.has_pending_updates());
        assert!(!has_local);
        assert!(has_cross_thread_updates());
        clear_state_batch();
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_queue_update_fallback_applies_correctly_after_drain() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        STATE_BATCH.with(|batch| {
            let _guard = batch.borrow_mut();
            queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Value(Box::new(100i32)),
                },
            );
        });
        begin_batch();
        drain_cross_thread_updates();
        let dirty = end_batch_with_tree(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(100));
        clear_state_batch();
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_queue_update_fallback_with_functional_updater() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 10i32);
        STATE_BATCH.with(|batch| {
            let _guard = batch.borrow_mut();
            queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Updater(Box::new(|any| {
                        let n = any.downcast_ref::<i32>().unwrap();
                        Box::new(n * 2)
                    })),
                },
            );
        });
        begin_batch();
        drain_cross_thread_updates();
        let dirty = end_batch_with_tree(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(20));
        clear_state_batch();
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_queue_update_fallback_preserves_order_with_multiple_updates() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        STATE_BATCH.with(|batch| {
            let _guard = batch.borrow_mut();
            for i in 1..=3 {
                queue_update(
                    fiber_id,
                    StateUpdate {
                        hook_index: 0,
                        update: StateUpdateKind::Updater(Box::new(move |any| {
                            let n = any.downcast_ref::<i32>().unwrap();
                            Box::new(n + i)
                        })),
                    },
                );
            }
        });
        begin_batch();
        drain_cross_thread_updates();
        end_batch_with_tree(&mut tree);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(6));
        clear_state_batch();
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_queue_update_mixed_local_and_fallback() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(10i32)),
            },
        );
        STATE_BATCH.with(|batch| {
            let _guard = batch.borrow_mut();
            queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Updater(Box::new(|any| {
                        let n = any.downcast_ref::<i32>().unwrap();
                        Box::new(n + 5)
                    })),
                },
            );
        });
        let has_local = with_state_batch(|batch| batch.has_pending_updates());
        assert!(has_local);
        assert!(has_cross_thread_updates());
        begin_batch();
        drain_cross_thread_updates();
        let dirty = end_batch_with_tree(&mut tree);
        assert!(dirty.contains(&fiber_id));
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(15));
        clear_state_batch();
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_queue_update_fallback_with_value_if_changed() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 42i32);
        tree.mark_clean(fiber_id);
        STATE_BATCH.with(|batch| {
            let _guard = batch.borrow_mut();
            queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::ValueIfChanged {
                        value: Box::new(42i32),
                        eq_check: Box::new(|old, new| {
                            let old = old.downcast_ref::<i32>().unwrap();
                            let new = new.downcast_ref::<i32>().unwrap();
                            old == new
                        }),
                    },
                },
            );
        });
        begin_batch();
        drain_cross_thread_updates();
        let dirty = end_batch_with_tree(&mut tree);
        assert!(!dirty.contains(&fiber_id));
        assert!(!tree.get(fiber_id).unwrap().dirty);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(42));
        clear_state_batch();
        clear_cross_thread_updates();
    }

    #[test]
    #[serial]
    fn test_try_borrow_mut_success_path() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        let fiber_id = FiberId(1);
        queue_update(
            fiber_id,
            StateUpdate {
                hook_index: 0,
                update: StateUpdateKind::Value(Box::new(42i32)),
            },
        );
        let has_local = with_state_batch(|batch| batch.has_pending_updates());
        assert!(has_local);
        assert!(!has_cross_thread_updates());
        clear_state_batch();
    }

    #[test]
    #[serial]
    fn test_background_thread_always_uses_cross_thread_queue() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        init_main_thread();
        let fiber_id = FiberId(1);
        let handle = std::thread::spawn(move || {
            queue_update(
                fiber_id,
                StateUpdate {
                    hook_index: 0,
                    update: StateUpdateKind::Value(Box::new(99i32)),
                },
            );
        });
        handle.join().unwrap();
        assert!(has_cross_thread_updates());
        let has_local = with_state_batch(|batch| batch.has_pending_updates());
        assert!(!has_local);
        clear_state_batch();
        clear_cross_thread_updates();
        reset_main_thread();
    }

    #[test]
    #[serial]
    fn test_multiple_background_threads_queue_updates() {
        reset_main_thread();
        clear_state_batch();
        clear_cross_thread_updates();
        init_main_thread();
        let mut tree = FiberTree::new();
        let fiber_id = tree.mount(None, None);
        tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
        let handles: Vec<_> = (1..=5)
            .map(|i| {
                std::thread::spawn(move || {
                    queue_update(
                        fiber_id,
                        StateUpdate {
                            hook_index: 0,
                            update: StateUpdateKind::Updater(Box::new(move |any| {
                                let n = any.downcast_ref::<i32>().unwrap();
                                Box::new(n + i)
                            })),
                        },
                    );
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        assert!(has_cross_thread_updates());
        begin_batch();
        drain_cross_thread_updates();
        end_batch_with_tree(&mut tree);
        assert_eq!(tree.get(fiber_id).unwrap().get_hook::<i32>(0), Some(15));
        clear_state_batch();
        clear_cross_thread_updates();
        reset_main_thread();
    }
}
#[cfg(test)]
mod property_tests {
    //! Property-based tests for state batching and the cross-thread update
    //! queue.
    //!
    //! `CROSS_THREAD_UPDATES` and `MAIN_THREAD_ID` are process-wide statics,
    //! while the test harness runs `#[test]` functions on parallel threads.
    //! Every property below therefore takes `serial_guard()` first, so one
    //! case's `reset_main_thread` / `init_main_thread` /
    //! `clear_cross_thread_updates` / `drain_cross_thread_updates` cannot
    //! interleave with another case's assertions.
    //!
    //! NOTE(review): tests elsewhere in this file that touch the same statics
    //! do not take this lock, so they can still interleave with these
    //! properties unless the suite runs with `--test-threads=1`.
    use super::*;
    use proptest::prelude::*;

    /// Serialises the property tests in this module that touch process-wide
    /// state.
    static GLOBAL_STATE_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires [`GLOBAL_STATE_LOCK`] for the duration of one test case.
    ///
    /// Poisoning is deliberately ignored. A panicking case has already been
    /// reported as a failure, and every case resets the shared state on
    /// entry, so propagating the poison would only cascade into spurious
    /// failures.
    fn serial_guard() -> std::sync::MutexGuard<'static, ()> {
        GLOBAL_STATE_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(100))]

        // Value updates queued concurrently from background threads reach the
        // fiber once the queue is drained inside a batch.
        #[test]
        fn prop_cross_thread_updates_applied_after_drain(
            values in prop::collection::vec(any::<i32>(), 1..10)
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_id = tree.mount(None, None);
            tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
            init_main_thread();
            // The barrier releases all producers together to maximise
            // contention on the shared queue.
            let barrier = Arc::new(std::sync::Barrier::new(values.len()));
            let handles: Vec<_> = values.iter().map(|&val| {
                let barrier = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    barrier.wait();
                    queue_cross_thread_update(CrossThreadUpdate {
                        fiber_id,
                        hook_index: 0,
                        update: CrossThreadUpdateKind::Value(Box::new(val)),
                    });
                })
            }).collect();
            for handle in handles {
                handle.join().unwrap();
            }
            prop_assert!(has_cross_thread_updates(),
                "Cross-thread queue should have updates after background thread queuing");
            begin_batch();
            drain_cross_thread_updates();
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert!(dirty.contains(&fiber_id),
                "Fiber should be marked dirty after cross-thread updates");
            prop_assert!(!has_cross_thread_updates(),
                "Cross-thread queue should be empty after drain");
            // The producers race, so which value is applied last is
            // unspecified. It must still be one of the values actually queued.
            // `is_some()` alone would pass trivially because the hook was
            // seeded with 0.
            let final_value = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
            prop_assert!(matches!(final_value, Some(v) if values.contains(&v)),
                "Final value {:?} should be one of the queued values {:?}",
                final_value, values);
            clear_state_batch();
            clear_cross_thread_updates();
            reset_main_thread();
        }

        // Updater closures queued in one batch are applied in queue order;
        // their summed increments must all show up.
        #[test]
        fn prop_update_ordering_preserved(
            increments in prop::collection::vec(1i32..10, 1..20)
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_id = tree.mount(None, None);
            tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
            begin_batch();
            for inc in &increments {
                let inc_val = *inc;
                with_state_batch_mut(|batch| {
                    batch.queue_update(
                        fiber_id,
                        StateUpdate {
                            hook_index: 0,
                            update: StateUpdateKind::Updater(Box::new(move |any| {
                                let n = any.downcast_ref::<i32>().unwrap();
                                Box::new(n + inc_val)
                            })),
                        },
                    );
                });
            }
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert!(dirty.contains(&fiber_id),
                "Fiber should be marked dirty after updates");
            let expected: i32 = increments.iter().sum();
            let actual = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
            prop_assert_eq!(actual, Some(expected),
                "Final value should be sum of all increments: expected {}, got {:?}", expected, actual);
            clear_state_batch();
            clear_cross_thread_updates();
        }

        // With no main thread registered, `queue_update` from the test thread
        // goes to the local batch, and repeated updates dirty the fiber once.
        #[test]
        fn prop_main_thread_uses_local_batch(
            values in prop::collection::vec(any::<i32>(), 1..10)
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let fiber_id = FiberId(1);
            for &val in &values {
                queue_update(
                    fiber_id,
                    StateUpdate {
                        hook_index: 0,
                        update: StateUpdateKind::Value(Box::new(val)),
                    },
                );
            }
            let has_local = with_state_batch(|batch| batch.has_pending_updates());
            prop_assert!(has_local,
                "Updates from main thread should be in local batch");
            prop_assert!(!has_cross_thread_updates(),
                "Cross-thread queue should be empty for main thread updates");
            let dirty_count = with_state_batch(|batch| batch.dirty_fiber_count());
            prop_assert_eq!(dirty_count, 1,
                "Only one fiber should be dirty regardless of update count");
            clear_state_batch();
            clear_cross_thread_updates();
        }

        // Many producer threads each queue several +1 updaters; none may be
        // lost.
        #[test]
        fn prop_concurrent_access_safe(
            thread_count in 2usize..10,
            updates_per_thread in 1usize..5
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_id = tree.mount(None, None);
            tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
            init_main_thread();
            let barrier = Arc::new(std::sync::Barrier::new(thread_count));
            let handles: Vec<_> = (0..thread_count).map(|_| {
                let barrier = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..updates_per_thread {
                        queue_cross_thread_update(CrossThreadUpdate {
                            fiber_id,
                            hook_index: 0,
                            update: CrossThreadUpdateKind::Updater(Arc::new(Mutex::new(Some(
                                Box::new(move |any| {
                                    let n = any.downcast_ref::<i32>().unwrap();
                                    Box::new(n + 1)
                                })
                            )))),
                        });
                    }
                })
            }).collect();
            for handle in handles {
                handle.join().expect("Thread should not panic");
            }
            begin_batch();
            drain_cross_thread_updates();
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert!(dirty.contains(&fiber_id),
                "Fiber should be marked dirty after concurrent updates");
            let expected = (thread_count * updates_per_thread) as i32;
            let actual = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
            prop_assert_eq!(actual, Some(expected),
                "All {} updates should be applied, got {:?}", expected, actual);
            clear_state_batch();
            clear_cross_thread_updates();
            reset_main_thread();
        }

        // Nothing is visible in the tree until `end_batch_with_tree`; after it,
        // every fiber has all of its updates applied.
        #[test]
        fn prop_state_batch_atomicity(
            fiber_count in 1usize..10,
            updates_per_fiber in 1usize..10
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_ids: Vec<_> = (0..fiber_count)
                .map(|_| {
                    let id = tree.mount(None, None);
                    tree.get_mut(id).unwrap().set_hook(0, 0i32);
                    id
                })
                .collect();
            begin_batch();
            for &fiber_id in &fiber_ids {
                for i in 0..updates_per_fiber {
                    with_state_batch_mut(|batch| {
                        batch.queue_update(
                            fiber_id,
                            StateUpdate {
                                hook_index: 0,
                                update: StateUpdateKind::Updater(Box::new(move |any| {
                                    let n = any.downcast_ref::<i32>().unwrap();
                                    Box::new(n + (i as i32 + 1))
                                })),
                            },
                        );
                    });
                }
            }
            for &fiber_id in &fiber_ids {
                let value = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
                prop_assert_eq!(value, Some(0),
                    "Updates should not be applied until end_batch");
            }
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert_eq!(dirty.len(), fiber_count,
                "All {} fibers should be marked dirty", fiber_count);
            for &fiber_id in &fiber_ids {
                prop_assert!(dirty.contains(&fiber_id),
                    "Fiber {:?} should be in dirty set", fiber_id);
            }
            let expected_sum: i32 = (1..=updates_per_fiber as i32).sum();
            for &fiber_id in &fiber_ids {
                let value = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
                prop_assert_eq!(value, Some(expected_sum),
                    "All updates should be applied atomically, expected {}, got {:?}",
                    expected_sum, value);
            }
            clear_state_batch();
            clear_cross_thread_updates();
        }

        // A mix of non-commutative updaters (add / multiply) must compose in
        // queue order, matching a sequential fold over the same operations.
        #[test]
        fn prop_functional_updater_chaining(
            operations in prop::collection::vec((any::<bool>(), 1i32..10), 1..20)
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_id = tree.mount(None, None);
            tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
            begin_batch();
            for &(is_add, value) in &operations {
                queue_update(
                    fiber_id,
                    StateUpdate {
                        hook_index: 0,
                        update: StateUpdateKind::Updater(Box::new(move |any| {
                            let n = any.downcast_ref::<i32>().unwrap();
                            if is_add {
                                Box::new(n.saturating_add(value))
                            } else {
                                Box::new(n.saturating_mul(value))
                            }
                        })),
                    },
                );
            }
            end_batch_with_tree(&mut tree);
            let mut expected = 0i32;
            for &(is_add, value) in &operations {
                if is_add {
                    expected = expected.saturating_add(value);
                } else {
                    expected = expected.saturating_mul(value);
                }
            }
            let actual = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
            prop_assert_eq!(actual, Some(expected),
                "Functional updaters should chain correctly: expected {}, got {:?}",
                expected, actual);
            clear_state_batch();
            clear_cross_thread_updates();
        }

        // `ValueIfChanged` leaves the fiber clean when the value is unchanged,
        // and dirties it (and stores the new value) when it differs.
        #[test]
        fn prop_equality_check_optimization(
            initial_value in any::<i32>(),
            same_value_updates in 1usize..10,
            different_value in any::<i32>()
        ) {
            prop_assume!(initial_value != different_value);
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_id = tree.mount(None, None);
            tree.get_mut(fiber_id).unwrap().set_hook(0, initial_value);
            tree.mark_clean(fiber_id);
            begin_batch();
            for _ in 0..same_value_updates {
                with_state_batch_mut(|batch| {
                    batch.queue_update(
                        fiber_id,
                        StateUpdate {
                            hook_index: 0,
                            update: StateUpdateKind::ValueIfChanged {
                                value: Box::new(initial_value),
                                eq_check: Box::new(|old, new| {
                                    let old = old.downcast_ref::<i32>().unwrap();
                                    let new = new.downcast_ref::<i32>().unwrap();
                                    old == new
                                }),
                            },
                        },
                    );
                });
            }
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert!(!dirty.contains(&fiber_id),
                "Fiber should not be dirty when ValueIfChanged updates have equal values");
            prop_assert!(!tree.get(fiber_id).unwrap().dirty,
                "Fiber dirty flag should be false");
            tree.mark_clean(fiber_id);
            begin_batch();
            with_state_batch_mut(|batch| {
                batch.queue_update(
                    fiber_id,
                    StateUpdate {
                        hook_index: 0,
                        update: StateUpdateKind::ValueIfChanged {
                            value: Box::new(different_value),
                            eq_check: Box::new(|old, new| {
                                let old = old.downcast_ref::<i32>().unwrap();
                                let new = new.downcast_ref::<i32>().unwrap();
                                old == new
                            }),
                        },
                    },
                );
            });
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert!(dirty.contains(&fiber_id),
                "Fiber should be dirty when ValueIfChanged has different value");
            prop_assert_eq!(
                tree.get(fiber_id).unwrap().get_hook::<i32>(0),
                Some(different_value),
                "Value should be updated to different_value"
            );
            clear_state_batch();
            clear_cross_thread_updates();
        }

        // While the thread-local batch is already mutably borrowed,
        // `queue_update` must fall back to the cross-thread queue instead of
        // panicking on a double borrow.
        #[test]
        fn prop_reentrant_fallback_safe(
            values in prop::collection::vec(any::<i32>(), 1..5)
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let fiber_id = FiberId(1);
            STATE_BATCH.with(|batch| {
                // Hold the RefCell borrow to simulate re-entrancy.
                let _guard = batch.borrow_mut();
                for &val in &values {
                    queue_update(
                        fiber_id,
                        StateUpdate {
                            hook_index: 0,
                            update: StateUpdateKind::Value(Box::new(val)),
                        },
                    );
                }
            });
            prop_assert!(has_cross_thread_updates(),
                "Re-entrant updates should fall back to cross-thread queue");
            let has_local = with_state_batch(|batch| batch.has_pending_updates());
            prop_assert!(!has_local,
                "Local batch should be empty after re-entrant fallback");
            clear_state_batch();
            clear_cross_thread_updates();
        }

        // Every updater queued by every background thread is delivered exactly
        // once and the queue is left empty.
        #[test]
        fn prop_cross_thread_update_delivery(
            update_count in 1usize..20,
            thread_count in 1usize..5
        ) {
            let _serial = serial_guard();
            reset_main_thread();
            clear_state_batch();
            clear_cross_thread_updates();
            let mut tree = FiberTree::new();
            let fiber_id = tree.mount(None, None);
            tree.get_mut(fiber_id).unwrap().set_hook(0, 0i32);
            init_main_thread();
            let barrier = Arc::new(std::sync::Barrier::new(thread_count));
            let handles: Vec<_> = (0..thread_count).map(|_| {
                let barrier = Arc::clone(&barrier);
                std::thread::spawn(move || {
                    barrier.wait();
                    for _ in 0..update_count {
                        queue_cross_thread_update(CrossThreadUpdate {
                            fiber_id,
                            hook_index: 0,
                            update: CrossThreadUpdateKind::Updater(Arc::new(Mutex::new(Some(
                                Box::new(move |any| {
                                    let n = any.downcast_ref::<i32>().unwrap();
                                    Box::new(n + 1)
                                })
                            )))),
                        });
                    }
                })
            }).collect();
            for handle in handles {
                handle.join().expect("Background thread should not panic");
            }
            prop_assert!(has_cross_thread_updates(),
                "Cross-thread queue should have updates after background threads queue them");
            begin_batch();
            drain_cross_thread_updates();
            let dirty = end_batch_with_tree(&mut tree);
            prop_assert!(dirty.contains(&fiber_id),
                "Fiber should be marked dirty after cross-thread updates are drained");
            let expected = (update_count * thread_count) as i32;
            let actual = tree.get(fiber_id).unwrap().get_hook::<i32>(0);
            prop_assert_eq!(actual, Some(expected),
                "All {} updates should be applied after drain, got {:?}", expected, actual);
            prop_assert!(!has_cross_thread_updates(),
                "Cross-thread queue should be empty after drain");
            clear_state_batch();
            clear_cross_thread_updates();
            reset_main_thread();
        }
    }
}