mod common;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread;
use std::time::Duration;
use common::{RecordedEvent, Recorder, TestBinding, TestRuntime, TestValue};
use graphrefly_core::{
BindingBoundary, Core, DepBatch, EqualsMode, FnId, FnResult, HandleId, Message, NodeId, Sink,
};
/// Checks that a derived node's fn may call `core.emit` on a different state
/// node while that fn is being invoked.
///
/// The fn passes its dependency through unchanged and, when the dependency is
/// an `Int(n)`, also emits `Int(n * 10)` into the side state. After the source
/// is set to 7 the derived recorder must hold `[0, 7]` and the side recorder
/// `[0, 70]`.
#[test]
fn fn_can_reenter_core_emit_during_invoke_fn_runs_nested_wave() {
    let rt = TestRuntime::new();
    let source = rt.state(Some(TestValue::Int(0)));
    let side = rt.state(Some(TestValue::Int(0)));

    // Everything the fn needs is cloned/copied so the closure can own it.
    let core = rt.core.clone();
    let side_id = side.id;
    let binding = rt.binding.clone();
    let derived = rt.derived(&[source.id], move |deps| {
        // Re-entrant emit into an unrelated node from inside the fn body.
        if let TestValue::Int(n) = deps[0] {
            let scaled = binding.intern(TestValue::Int(n * 10));
            core.emit(side_id, scaled);
        }
        Some(deps[0].clone())
    });

    let derived_recorder = rt.subscribe_recorder(derived);
    let side_recorder = rt.subscribe_recorder(side.id);

    source.set(TestValue::Int(7));

    let expected_derived = vec![TestValue::Int(0), TestValue::Int(7)];
    assert_eq!(derived_recorder.data_values(), expected_derived);

    let expected_side = vec![TestValue::Int(0), TestValue::Int(70)];
    assert_eq!(side_recorder.data_values(), expected_side);
}
/// Checks that a derived node's fn may call `core.pause` and `core.resume`
/// on another node while that fn is being invoked.
///
/// Both calls must return `Ok`. If `resume` yields a report, the report must
/// show zero replayed and zero dropped entries.
#[test]
fn fn_can_reenter_core_pause_resume_during_invoke_fn() {
    let rt = TestRuntime::new();
    let source = rt.state(Some(TestValue::Int(0)));
    let other = rt.state(Some(TestValue::Int(100)));

    let core = rt.core.clone();
    let other_id = other.id;
    // The lock id is allocated once, outside the fn, and reused on every run.
    let lock_id = core.alloc_lock_id();

    let derived = rt.derived(&[source.id], move |_deps| {
        core.pause(other_id, lock_id).expect("pause");
        if let Some(report) = core.resume(other_id, lock_id).expect("resume") {
            assert_eq!(report.replayed, 0);
            assert_eq!(report.dropped, 0);
        }
        None
    });

    // Keep the recorder alive for the rest of the test.
    let _derived_recorder = rt.subscribe_recorder(derived);
    source.set(TestValue::Int(1));
}
/// Checks that a derived node's fn may call `core.invalidate` on another node
/// while that fn is being invoked.
///
/// A recorder attached to the other node must contain an `Invalidate` event
/// by the end of the test.
#[test]
fn fn_can_reenter_core_invalidate_during_invoke_fn() {
    let rt = TestRuntime::new();
    let source = rt.state(Some(TestValue::Int(0)));
    let other = rt.state(Some(TestValue::Int(100)));
    let other_id = other.id;

    // The recorder on the invalidated node is attached before the derived
    // node is created.
    let other_recorder = rt.subscribe_recorder(other_id);

    let core = rt.core.clone();
    let derived = rt.derived(&[source.id], move |deps| {
        core.invalidate(other_id);
        Some(deps[0].clone())
    });
    let _derived_recorder = rt.subscribe_recorder(derived);

    source.set(TestValue::Int(1));

    let observed = other_recorder.snapshot();
    assert!(observed.contains(&RecordedEvent::Invalidate));
}
/// Checks that a custom equality callback may call back into the core
/// (`core.cache_of`) while it is being evaluated.
///
/// The callback always answers `false`, and the recorder on the derived node
/// is expected to hold the data values 0, 1 and 2 in that order.
#[test]
fn custom_equals_can_reenter_core_during_emission() {
    let rt = TestRuntime::new();
    let source = rt.state(Some(TestValue::Int(0)));
    let probe = rt.state(Some(TestValue::Int(99)));

    let core = rt.core.clone();
    let probe_id = probe.id;

    let derived = rt.derived_with_equals(
        &[source.id],
        // Pass the dependency straight through.
        |deps| Some(deps[0].clone()),
        // Equality callback: touch the core, then report "not equal".
        move |_lhs, _rhs| {
            let _ = core.cache_of(probe_id);
            false
        },
    );
    let recorder = rt.subscribe_recorder(derived);

    source.set(TestValue::Int(1));
    source.set(TestValue::Int(2));

    let expected = vec![TestValue::Int(0), TestValue::Int(1), TestValue::Int(2)];
    assert_eq!(recorder.data_values(), expected);
}
/// Subscribing to a state created without an initial value must produce
/// exactly one sink call whose only recorded event is `Start`.
#[test]
fn handshake_tier_split_sentinel_state_one_call() {
    let rt = TestRuntime::new();
    let state = rt.state(None);
    let recorder = rt.subscribe_recorder(state.id);

    assert_eq!(recorder.call_count(), 1, "sentinel handshake = 1 sink call");
    assert_eq!(recorder.call_boundaries(), vec![1]);

    let expected_events = vec![RecordedEvent::Start];
    assert_eq!(recorder.snapshot(), expected_events);
}
/// Subscribing to a state that already holds a value must produce two
/// separate sink calls: `Start`, then `Data` carrying the held value.
#[test]
fn handshake_tier_split_cached_state_two_calls() {
    let rt = TestRuntime::new();
    let state = rt.state(Some(TestValue::Int(42)));
    let recorder = rt.subscribe_recorder(state.id);

    assert_eq!(recorder.call_count(), 2, "cached handshake = 2 sink calls");
    assert_eq!(recorder.call_boundaries(), vec![1, 1]);

    let expected_events = vec![
        RecordedEvent::Start,
        RecordedEvent::Data(TestValue::Int(42)),
    ];
    assert_eq!(recorder.snapshot(), expected_events);
}
/// Subscribing to a state that holds a value and was completed beforehand
/// must produce three separate sink calls: `Start`, `Data`, `Complete`.
#[test]
fn handshake_tier_split_terminated_state_three_calls() {
    let rt = TestRuntime::new();
    let state = rt.state(Some(TestValue::Int(7)));
    // Complete the node before anyone subscribes.
    rt.core.complete(state.id);
    let recorder = rt.subscribe_recorder(state.id);

    assert_eq!(
        recorder.call_count(),
        3,
        "cached + complete handshake = 3 sink calls"
    );
    assert_eq!(recorder.call_boundaries(), vec![1, 1, 1]);

    let expected_events = vec![
        RecordedEvent::Start,
        RecordedEvent::Data(TestValue::Int(7)),
        RecordedEvent::Complete,
    ];
    assert_eq!(recorder.snapshot(), expected_events);
}
/// Subscribing to a state that holds a value and was torn down beforehand
/// must produce four separate sink calls: `Start`, `Data`, `Complete`,
/// `Teardown`.
#[test]
fn handshake_tier_split_torndown_node_four_calls() {
    let rt = TestRuntime::new();
    let state = rt.state(Some(TestValue::Int(5)));
    // Tear the node down before anyone subscribes.
    rt.core.teardown(state.id);
    let recorder = rt.subscribe_recorder(state.id);

    assert_eq!(recorder.call_count(), 4, "torn-down handshake = 4 sink calls");
    assert_eq!(recorder.call_boundaries(), vec![1, 1, 1, 1]);

    let expected_events = vec![
        RecordedEvent::Start,
        RecordedEvent::Data(TestValue::Int(5)),
        RecordedEvent::Complete,
        RecordedEvent::Teardown,
    ];
    assert_eq!(recorder.snapshot(), expected_events);
}
/// Subscribing to a state that holds a value and was errored beforehand must
/// produce three separate sink calls: `Start`, `Data`, then `Error` carrying
/// the error value.
#[test]
fn handshake_tier_split_error_terminated_three_calls() {
    let rt = TestRuntime::new();
    let state = rt.state(Some(TestValue::Int(1)));
    // Put the node into the error state before anyone subscribes.
    let error_handle = rt.binding.intern(TestValue::Str("boom".into()));
    rt.core.error(state.id, error_handle);
    let recorder = rt.subscribe_recorder(state.id);

    assert_eq!(recorder.call_count(), 3);
    assert_eq!(recorder.call_boundaries(), vec![1, 1, 1]);

    let expected_events = vec![
        RecordedEvent::Start,
        RecordedEvent::Data(TestValue::Int(1)),
        RecordedEvent::Error(TestValue::Str("boom".into())),
    ];
    assert_eq!(recorder.snapshot(), expected_events);
}
/// Checks that a subscriber installed from inside a derived fn, right after
/// that fn emitted into the same node, sees the emitted value exactly once.
///
/// The derived fn (driven by `trigger`) emits `Int(n * 100)` into the target
/// state and then subscribes a pre-built "late" sink to it. Expectations:
/// the pre-existing subscriber records data `[0, 700]`; the late subscriber
/// records exactly `[Start, Data(700)]` across two sink calls.
#[test]
fn late_subscriber_installed_after_first_queue_notify_does_not_double_receive_data() {
    let rt = TestRuntime::new();
    let target = rt.state(Some(TestValue::Int(0)));
    let existing_recorder = rt.subscribe_recorder(target.id);

    // The late recorder and its sink are built up front. Each sits in a
    // shared `Option` slot so the fn can `take()` the sink (subscribing at
    // most once) and the test body can inspect the recorder afterwards.
    let pending_recorder = Recorder::new();
    let pending_sink = pending_recorder.sink(rt.binding.clone());
    let late_recorder: Arc<Mutex<Option<Recorder>>> =
        Arc::new(Mutex::new(Some(pending_recorder)));

    let core = rt.core.clone();
    let target_id = target.id;
    let binding = rt.binding.clone();
    let recorder_in_fn = Arc::clone(&late_recorder);
    let late_sink_holder: Arc<Mutex<Option<Sink>>> = Arc::new(Mutex::new(Some(pending_sink)));
    let sink_in_fn = Arc::clone(&late_sink_holder);

    let trigger = rt.state(Some(TestValue::Int(0)));
    let derived = rt.derived(&[trigger.id], move |deps| {
        match deps[0] {
            // Only positive trigger values emit and install the late sink.
            TestValue::Int(n) if n > 0 => {
                let handle = binding.intern(TestValue::Int(n * 100));
                core.emit(target_id, handle);
                if let Some(sink) = sink_in_fn.lock().unwrap().take() {
                    let sub = core.subscribe(target_id, sink);
                    if let Some(rec) = recorder_in_fn.lock().unwrap().as_ref() {
                        rec.attach(sub);
                    }
                }
            }
            _ => {}
        }
        None
    });
    let _derived_recorder = rt.subscribe_recorder(derived);

    trigger.set(TestValue::Int(7));

    // Collect just the integer payloads seen by the pre-existing subscriber.
    let existing_events = existing_recorder.snapshot();
    let mut existing_data: Vec<i64> = Vec::new();
    for event in existing_events.iter() {
        if let RecordedEvent::Data(TestValue::Int(n)) = event {
            existing_data.push(*n);
        }
    }
    assert_eq!(
        existing_data,
        vec![0, 700],
        "existing subscriber: handshake Data(0) + wave Data(700)"
    );

    let late_slot = late_recorder.lock().unwrap();
    let late = late_slot.as_ref().expect("late recorder");
    let late_events = late.snapshot();
    assert_eq!(
        late_events,
        vec![
            RecordedEvent::Start,
            RecordedEvent::Data(TestValue::Int(700)),
        ],
        "late subscriber duplicated Data: {late_events:?}"
    );
    assert_eq!(
        late.call_count(),
        2,
        "late subscriber: handshake split into [Start] + [Data(700)] = 2 calls"
    );
}
/// Checks that an emit from a second thread does not complete while another
/// thread's wave is still running a derived fn.
///
/// Thread A emits into `s_a`; the derived fn on `s_a` then parks on a channel.
/// Once the fn is known to have been entered, thread B emits into `s_b`.
/// After a 100 ms grace period thread B must still be running. The fn is then
/// released and both threads must finish within five seconds each.
///
/// All time limits are measured against wall-clock time (`Instant`), so they
/// hold even when individual sleeps overshoot.
#[test]
fn concurrent_emit_blocks_until_in_flight_wave_completes() {
    let rt = TestRuntime::new();
    let s_a = rt.state(None);
    let s_b = rt.state(Some(TestValue::Int(0)));

    // The receiver sits in a shared `Option` so that only the first fn
    // invocation takes it and blocks; later invocations pass straight through.
    let (tx, rx) = std::sync::mpsc::channel::<()>();
    let rx = Arc::new(Mutex::new(Some(rx)));
    let rx_for_fn = rx.clone();
    let fn_entered = Arc::new(AtomicU64::new(0));
    let fn_entered_for_fn = fn_entered.clone();

    let d = rt.derived(&[s_a.id], move |deps| {
        fn_entered_for_fn.fetch_add(1, Ordering::SeqCst);
        // Take the receiver in its own statement so the mutex guard is
        // released before parking on `recv`.
        let recv = rx_for_fn.lock().unwrap().take();
        if let Some(rx) = recv {
            let _ = rx.recv();
        }
        Some(deps[0].clone())
    });
    let _rec_d = rt.subscribe_recorder(d);

    let core_a = rt.core.clone();
    let s_a_id = s_a.id;
    let binding_a = rt.binding.clone();
    let thread_a = thread::spawn(move || {
        let h = binding_a.intern(TestValue::Int(1));
        core_a.emit(s_a_id, h);
    });

    // Wait (bounded by real elapsed time) until thread A is inside the fn.
    let entry_deadline = std::time::Instant::now() + Duration::from_secs(5);
    while fn_entered.load(Ordering::SeqCst) == 0 {
        assert!(
            std::time::Instant::now() < entry_deadline,
            "thread A's fn never entered — emit may have deadlocked at wave_owner"
        );
        thread::sleep(Duration::from_millis(1));
    }

    let core_b = rt.core.clone();
    let s_b_id = s_b.id;
    let binding_b = rt.binding.clone();
    let thread_b = thread::spawn(move || {
        let h = binding_b.intern(TestValue::Int(2));
        core_b.emit(s_b_id, h);
    });

    // Give thread B ample time to finish if it were (incorrectly) not blocked.
    thread::sleep(Duration::from_millis(100));
    assert!(
        !thread_b.is_finished(),
        "Thread B's cross-thread emit should be blocked on wave_owner \
         while Thread A's wave is in flight; instead it finished early"
    );

    tx.send(()).expect("send to unblock fn");

    // Poll-join with a hard limit. The limit is compared as a `Duration` so
    // that sub-second remainders are not truncated away.
    let join_with_timeout = |handle: thread::JoinHandle<()>, secs: u64| {
        let limit = Duration::from_secs(secs);
        let start = std::time::Instant::now();
        while !handle.is_finished() {
            if start.elapsed() > limit {
                panic!("thread did not finish within {secs}s — likely deadlock");
            }
            thread::sleep(Duration::from_millis(5));
        }
        handle.join().expect("thread panicked");
    };
    join_with_timeout(thread_a, 5);
    join_with_timeout(thread_b, 5);
}
/// Checks the binding's live-handle count after a short burst of emissions.
///
/// A state feeds a pass-through derived node with one recorder attached. The
/// state is set ten times, then the recorder and the state are dropped. At
/// that point the binding must report at most two live handles.
#[test]
fn lock_released_refactor_does_not_leak_handles_under_basic_emit() {
    let rt = TestRuntime::new();
    let source = rt.state(Some(TestValue::Int(0)));
    let derived = rt.derived(&[source.id], |deps| Some(deps[0].clone()));
    let recorder = rt.subscribe_recorder(derived);

    for value in (1..=10).map(TestValue::Int) {
        source.set(value);
    }

    // Release the subscriber first, then the state, before counting.
    drop(recorder);
    drop(source);

    let live_now = rt.binding.live_handles();
    assert!(
        live_now <= 2,
        "expected <= 2 live handles after drop (derived cache + maybe a transient), got {}",
        live_now
    );
    drop(rt);
}
/// Test-only `BindingBoundary` implementation that wraps a `TestBinding`.
///
/// Its `invoke_fn` calls `core.cache_of(probe)` before delegating, but only
/// once both `core_slot` and `probe_node` have been filled in. The other
/// trait methods forward straight to `inner`.
struct ReentrantBinding {
// Wrapped binding that every trait method delegates to.
inner: Arc<TestBinding>,
// Core to call back into from `invoke_fn`; starts as `None` and is set by
// the test after the core has been constructed.
core_slot: Mutex<Option<Core>>,
// Node whose cache `invoke_fn` reads; starts as `None` and is set by the
// test after the probe node has been registered.
probe_node: Mutex<Option<NodeId>>,
}
impl ReentrantBinding {
    /// Builds a shared binding around a fresh `TestBinding`, with neither the
    /// core nor the probe node set yet.
    fn new() -> Arc<Self> {
        let unset = Self {
            inner: TestBinding::new(),
            core_slot: Mutex::new(None),
            probe_node: Mutex::new(None),
        };
        Arc::new(unset)
    }
}
impl BindingBoundary for ReentrantBinding {
    /// Calls `core.cache_of(probe)` when both the core and the probe node
    /// have been set, then delegates the invocation to the inner binding.
    ///
    /// The core and probe id are copied out of their mutexes first, so
    /// neither lock is held while calling back into the core. If that call
    /// ever led back into this method on the same thread, it could not
    /// self-deadlock on `core_slot` or `probe_node`.
    fn invoke_fn(&self, node_id: NodeId, fn_id: FnId, dep_data: &[DepBatch]) -> FnResult {
        // Each guard is a temporary that is dropped at the end of its `let`.
        let core = self.core_slot.lock().unwrap().clone();
        let probe = *self.probe_node.lock().unwrap();
        if let (Some(core), Some(probe)) = (core, probe) {
            // Result intentionally ignored: only the re-entrant call matters.
            let _ = core.cache_of(probe);
        }
        self.inner.invoke_fn(node_id, fn_id, dep_data)
    }

    /// Forwards to the inner binding.
    fn custom_equals(&self, equals_handle: FnId, a: HandleId, b: HandleId) -> bool {
        self.inner.custom_equals(equals_handle, a, b)
    }

    /// Forwards to the inner binding.
    fn release_handle(&self, handle: HandleId) {
        self.inner.release_handle(handle);
    }

    /// Forwards to the inner binding.
    fn retain_handle(&self, handle: HandleId) {
        self.inner.retain_handle(handle);
    }
}
/// Checks that a binding whose `invoke_fn` itself calls `core.cache_of` can
/// drive a derived node through its subscription handshake.
///
/// A `ReentrantBinding` is wired to its own core and a probe node, a
/// pass-through derived node is registered over an input state, and a
/// collecting sink is subscribed. The sink must have received at least one
/// `Start` and at least one `Data` message.
#[test]
fn invoke_fn_can_call_core_cache_of_directly() {
    let binding = ReentrantBinding::new();
    let core = Core::new(binding.clone() as Arc<dyn BindingBoundary>);
    *binding.core_slot.lock().unwrap() = Some(core.clone());

    // Node that `invoke_fn` reads the cache of.
    let probe_handle = binding.inner.intern(TestValue::Int(42));
    let probe_id = core.register_state(probe_handle, false).unwrap();
    *binding.probe_node.lock().unwrap() = Some(probe_id);

    // Input state plus a derived node that passes it through.
    let input_handle = binding.inner.intern(TestValue::Int(1));
    let input_id = core.register_state(input_handle, false).unwrap();
    let passthrough_fn = binding.inner.register_fn(|deps| Some(deps[0].clone()));
    let derived = core
        .register_derived(&[input_id], passthrough_fn, EqualsMode::Identity, false)
        .unwrap();

    // Sink that appends every delivered message to a shared vector.
    let received: Arc<Mutex<Vec<Message>>> = Arc::new(Mutex::new(Vec::new()));
    let received_for_sink = Arc::clone(&received);
    let sink: Sink = Arc::new(move |msgs: &[Message]| {
        received_for_sink.lock().unwrap().extend_from_slice(msgs);
    });
    let _sub = core.subscribe(derived, sink);

    let events = received.lock().unwrap();
    assert!(events.iter().any(|msg| matches!(msg, Message::Start)));
    assert!(events.iter().any(|msg| matches!(msg, Message::Data(_))));
}