#![allow(dead_code)]
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use graphrefly_core::{
BindingBoundary, CleanupTrigger, Core, DepBatch, EqualsMode, FnId, FnResult, HandleId, Message,
NodeId, Sink, Subscription,
};
/// A value flowing through the test graph.
///
/// Primitive variants (`Int`, `Str`, `Null`) are compared and interned by value;
/// `Object` is compared and interned by `Arc` pointer identity.
#[derive(Debug, Clone)]
pub enum TestValue {
    /// A signed integer.
    Int(i64),
    /// An owned string.
    Str(String),
    /// A shared object; identity is the `Arc` allocation, not its contents.
    Object(Arc<TestObject>),
    /// An explicit null value (distinct from "no value").
    Null,
}
/// Payload for [`TestValue::Object`].
///
/// The registry never inspects these fields: objects are deduplicated by the
/// address of their `Arc` allocation, so two `TestObject`s with equal fields are
/// still distinct values.
#[derive(Debug)]
pub struct TestObject {
    /// Free-form label for test assertions / debug output.
    pub label: String,
    /// Free-form integer payload for test assertions.
    pub x: i64,
}
impl TestValue {
    /// Key used to deduplicate value-semantics variants in the registry.
    ///
    /// Returns `None` for `Object`, which is deduplicated by pointer instead
    /// (see [`TestValue::object_ptr`]). `Str` payloads are cloned into the key.
    fn primitive_key(&self) -> Option<PrimitiveKey> {
        let key = match self {
            Self::Object(_) => return None,
            Self::Null => PrimitiveKey::Null,
            Self::Str(text) => PrimitiveKey::Str(text.clone()),
            Self::Int(number) => PrimitiveKey::Int(*number),
        };
        Some(key)
    }

    /// Address of the shared allocation behind an `Object`, used as its identity
    /// key; `None` for every primitive variant.
    fn object_ptr(&self) -> Option<usize> {
        if let Self::Object(shared) = self {
            Some(Arc::as_ptr(shared) as usize)
        } else {
            None
        }
    }
}
/// Hashable by-value identity of a primitive [`TestValue`], used as the key of
/// the registry's primitive dedup index.
#[derive(Hash, PartialEq, Eq, Clone)]
enum PrimitiveKey {
    /// Key for [`TestValue::Int`].
    Int(i64),
    /// Key for [`TestValue::Str`].
    Str(String),
    /// Key for [`TestValue::Null`].
    Null,
}
/// Plain fn: receives the latest dep values; `Some(v)` is interned and reported
/// as data, `None` is reported as a no-op.
type FnImpl = Arc<dyn Fn(&[TestValue]) -> Option<TestValue> + Sync + Send>;
/// Dynamic fn: like [`FnImpl`], plus an optional `Vec<usize>` that is forwarded
/// to Core as the result's `tracked` field (presumably the dep indices the fn
/// depends on — confirm against Core).
type DynamicFnImpl =
    Arc<dyn Fn(&[TestValue]) -> (Option<TestValue>, Option<Vec<usize>>) + Sync + Send>;
/// Raw fn: receives the per-dep `DepBatch`es directly and builds the `FnResult` itself.
type RawFnImpl = Arc<dyn Fn(&[DepBatch]) -> FnResult + Sync + Send>;
/// Custom equality predicate between two dereferenced values.
type EqualsImpl = Arc<dyn Fn(&TestValue, &TestValue) -> bool + Sync + Send>;
/// A per-node cleanup hook (see `TestNodeFnCleanup`).
pub type CleanupClosure = Arc<dyn Fn() + Sync + Send>;
/// Cleanup hooks a test can install for one node via `TestBinding::register_cleanup`.
///
/// `cleanup_for` selects the hook whose field matches the trigger; any hook may
/// be absent. A deactivation cleanup also clears the installed spec.
#[derive(Default, Clone)]
pub struct TestNodeFnCleanup {
    /// Run for `CleanupTrigger::OnRerun`.
    pub on_rerun: Option<CleanupClosure>,
    /// Run for `CleanupTrigger::OnDeactivation`.
    pub on_deactivation: Option<CleanupClosure>,
    /// Run for `CleanupTrigger::OnInvalidate`.
    pub on_invalidate: Option<CleanupClosure>,
}
/// Per-node context owned by the test binding.
///
/// Created lazily by `store_set` / `register_cleanup` and removed wholesale by
/// `wipe_ctx`.
#[derive(Default)]
pub struct NodeCtxState {
    /// Free-form key/value scratch space (`store_get` / `store_set`).
    pub store: HashMap<String, TestValue>,
    /// Cleanup hooks consulted by `cleanup_for`; cleared on deactivation.
    pub current_cleanup: Option<TestNodeFnCleanup>,
}
/// Mutable state behind `TestBinding::inner`: the value registry plus every
/// registered fn, all guarded by one mutex.
struct RegistryInner {
    /// Next handle id to allocate (starts at 1).
    next_handle: u64,
    /// Next fn id to allocate; shared by all four fn registries (starts at 1).
    next_fn_id: u64,
    /// Live values by handle.
    values: HashMap<HandleId, TestValue>,
    /// Outstanding references per handle; an entry is removed when it reaches 0.
    refcounts: HashMap<HandleId, u64>,
    /// Dedup index for `Int` / `Str` / `Null` values.
    primitive_index: HashMap<PrimitiveKey, HandleId>,
    /// Dedup index for `Object` values, keyed by `Arc` allocation address.
    object_index: HashMap<usize, HandleId>,
    /// Plain fns (`register_fn`).
    fns: HashMap<FnId, FnImpl>,
    /// Dynamic fns (`register_dynamic_fn`).
    dynamic_fns: HashMap<FnId, DynamicFnImpl>,
    /// Raw fns (`register_raw_fn`).
    raw_fns: HashMap<FnId, RawFnImpl>,
    /// Custom equality predicates (`register_custom_equals`).
    custom_equals: HashMap<FnId, EqualsImpl>,
}
/// In-process `BindingBoundary` for tests: an interning handle registry, a fn
/// registry, per-node context, and logs of the cleanup/wipe calls Core makes.
pub struct TestBinding {
    /// Value and fn registries.
    inner: Mutex<RegistryInner>,
    /// Per-node store and cleanup hooks.
    node_ctx: Mutex<HashMap<NodeId, NodeCtxState>>,
    /// Every `cleanup_for` call, in order.
    cleanup_calls: Mutex<Vec<(NodeId, CleanupTrigger)>>,
    /// Every `wipe_ctx` call, in order.
    wipe_calls: Mutex<Vec<NodeId>>,
    /// When true, cleanup-closure panics propagate instead of being caught.
    propagate_cleanup_panics: Mutex<bool>,
    /// Cleanup-closure panics that were caught (only when not propagating).
    cleanup_panics: Mutex<Vec<(NodeId, CleanupTrigger)>>,
}
impl TestBinding {
pub fn new() -> Arc<Self> {
Arc::new(Self {
inner: Mutex::new(RegistryInner {
next_handle: 1,
next_fn_id: 1,
values: HashMap::new(),
refcounts: HashMap::new(),
primitive_index: HashMap::new(),
object_index: HashMap::new(),
fns: HashMap::new(),
dynamic_fns: HashMap::new(),
raw_fns: HashMap::new(),
custom_equals: HashMap::new(),
}),
node_ctx: Mutex::new(HashMap::new()),
cleanup_calls: Mutex::new(Vec::new()),
wipe_calls: Mutex::new(Vec::new()),
propagate_cleanup_panics: Mutex::new(false),
cleanup_panics: Mutex::new(Vec::new()),
})
}
pub fn intern(&self, value: TestValue) -> HandleId {
let mut inner = self.inner.lock().expect("registry lock");
let existing = if let Some(key) = value.primitive_key() {
inner.primitive_index.get(&key).copied()
} else if let Some(ptr) = value.object_ptr() {
inner.object_index.get(&ptr).copied()
} else {
None
};
if let Some(h) = existing {
*inner.refcounts.entry(h).or_insert(0) += 1;
return h;
}
let h = HandleId::new(inner.next_handle);
inner.next_handle += 1;
if let Some(key) = value.primitive_key() {
inner.primitive_index.insert(key, h);
}
if let Some(ptr) = value.object_ptr() {
inner.object_index.insert(ptr, h);
}
inner.values.insert(h, value);
inner.refcounts.insert(h, 1);
h
}
pub fn deref(&self, handle: HandleId) -> TestValue {
self.inner
.lock()
.expect("registry lock")
.values
.get(&handle)
.cloned()
.unwrap_or_else(|| panic!("handle {handle:?} not in registry"))
}
pub fn live_handles(&self) -> usize {
self.inner
.lock()
.expect("registry lock")
.refcounts
.values()
.filter(|&&n| n > 0)
.count()
}
pub fn refcount_of(&self, handle: HandleId) -> u64 {
self.inner
.lock()
.expect("registry lock")
.refcounts
.get(&handle)
.copied()
.unwrap_or(0)
}
pub fn register_fn<F>(&self, f: F) -> FnId
where
F: Fn(&[TestValue]) -> Option<TestValue> + Send + Sync + 'static,
{
let mut inner = self.inner.lock().expect("registry lock");
let id = FnId::new(inner.next_fn_id);
inner.next_fn_id += 1;
inner.fns.insert(id, Arc::new(f));
id
}
pub fn register_dynamic_fn<F>(&self, f: F) -> FnId
where
F: Fn(&[TestValue]) -> (Option<TestValue>, Option<Vec<usize>>) + Send + Sync + 'static,
{
let mut inner = self.inner.lock().expect("registry lock");
let id = FnId::new(inner.next_fn_id);
inner.next_fn_id += 1;
inner.dynamic_fns.insert(id, Arc::new(f));
id
}
pub fn register_custom_equals<F>(&self, f: F) -> FnId
where
F: Fn(&TestValue, &TestValue) -> bool + Send + Sync + 'static,
{
let mut inner = self.inner.lock().expect("registry lock");
let id = FnId::new(inner.next_fn_id);
inner.next_fn_id += 1;
inner.custom_equals.insert(id, Arc::new(f));
id
}
pub fn register_raw_fn<F>(&self, f: F) -> FnId
where
F: Fn(&[DepBatch]) -> FnResult + Send + Sync + 'static,
{
let mut inner = self.inner.lock().expect("registry lock");
let id = FnId::new(inner.next_fn_id);
inner.next_fn_id += 1;
inner.raw_fns.insert(id, Arc::new(f));
id
}
pub fn register_cleanup(&self, node_id: NodeId, spec: TestNodeFnCleanup) {
let mut ctx = self.node_ctx.lock().expect("node_ctx lock");
let entry = ctx.entry(node_id).or_default();
entry.current_cleanup = Some(spec);
}
pub fn store_get(&self, node_id: NodeId, key: &str) -> Option<TestValue> {
self.node_ctx
.lock()
.expect("node_ctx lock")
.get(&node_id)
.and_then(|s| s.store.get(key).cloned())
}
pub fn store_set(&self, node_id: NodeId, key: &str, value: TestValue) {
self.node_ctx
.lock()
.expect("node_ctx lock")
.entry(node_id)
.or_default()
.store
.insert(key.to_string(), value);
}
pub fn has_ctx(&self, node_id: NodeId) -> bool {
self.node_ctx
.lock()
.expect("node_ctx lock")
.contains_key(&node_id)
}
pub fn cleanup_calls(&self) -> Vec<(NodeId, CleanupTrigger)> {
self.cleanup_calls
.lock()
.expect("cleanup_calls lock")
.clone()
}
pub fn cleanup_calls_for(&self, trigger: CleanupTrigger) -> Vec<NodeId> {
self.cleanup_calls()
.into_iter()
.filter_map(|(n, t)| (t == trigger).then_some(n))
.collect()
}
pub fn wipe_calls(&self) -> Vec<NodeId> {
self.wipe_calls.lock().expect("wipe_calls lock").clone()
}
pub fn set_propagate_cleanup_panics(&self, on: bool) {
*self
.propagate_cleanup_panics
.lock()
.expect("propagate_cleanup_panics lock") = on;
}
pub fn cleanup_panics(&self) -> Vec<(NodeId, CleanupTrigger)> {
self.cleanup_panics
.lock()
.expect("cleanup_panics lock")
.clone()
}
}
impl BindingBoundary for TestBinding {
fn invoke_fn(&self, _node_id: NodeId, fn_id: FnId, dep_data: &[DepBatch]) -> FnResult {
let raw_f = self
.inner
.lock()
.expect("registry lock")
.raw_fns
.get(&fn_id)
.cloned();
if let Some(f) = raw_f {
return f(dep_data);
}
let dep_values: Vec<TestValue> =
dep_data.iter().map(|db| self.deref(db.latest())).collect();
let dyn_f = self
.inner
.lock()
.expect("registry lock")
.dynamic_fns
.get(&fn_id)
.cloned();
if let Some(f) = dyn_f {
let (value, tracked) = f(&dep_values);
return match value {
Some(v) => {
let handle = self.intern(v);
FnResult::Data { handle, tracked }
}
None => FnResult::Noop { tracked },
};
}
let f = self
.inner
.lock()
.expect("registry lock")
.fns
.get(&fn_id)
.cloned()
.unwrap_or_else(|| panic!("unknown fn_id {fn_id:?}"));
match f(&dep_values) {
Some(value) => {
let handle = self.intern(value);
FnResult::Data {
handle,
tracked: None,
}
}
None => FnResult::Noop { tracked: None },
}
}
fn custom_equals(&self, equals_handle: FnId, a: HandleId, b: HandleId) -> bool {
let f = self
.inner
.lock()
.expect("registry lock")
.custom_equals
.get(&equals_handle)
.cloned()
.unwrap_or_else(|| panic!("unknown custom-equals fn_id {equals_handle:?}"));
let av = self.deref(a);
let bv = self.deref(b);
f(&av, &bv)
}
fn release_handle(&self, handle: HandleId) {
let mut inner = self.inner.lock().expect("registry lock");
let count = inner.refcounts.entry(handle).or_insert(0);
if *count > 0 {
*count -= 1;
}
if *count == 0 {
if let Some(value) = inner.values.remove(&handle) {
if let Some(key) = value.primitive_key() {
if inner.primitive_index.get(&key).copied() == Some(handle) {
inner.primitive_index.remove(&key);
}
}
if let Some(ptr) = value.object_ptr() {
if inner.object_index.get(&ptr).copied() == Some(handle) {
inner.object_index.remove(&ptr);
}
}
}
inner.refcounts.remove(&handle);
}
}
fn retain_handle(&self, handle: HandleId) {
let mut inner = self.inner.lock().expect("registry lock");
*inner.refcounts.entry(handle).or_insert(0) += 1;
}
fn cleanup_for(&self, node_id: NodeId, trigger: CleanupTrigger) {
self.cleanup_calls
.lock()
.expect("cleanup_calls lock")
.push((node_id, trigger));
let closure: Option<CleanupClosure> = {
let mut ctx = self.node_ctx.lock().expect("node_ctx lock");
let entry = ctx.get_mut(&node_id);
entry.and_then(|state| {
let cleanup = state.current_cleanup.as_ref()?;
let c = match trigger {
CleanupTrigger::OnRerun => cleanup.on_rerun.clone(),
CleanupTrigger::OnDeactivation => cleanup.on_deactivation.clone(),
CleanupTrigger::OnInvalidate => cleanup.on_invalidate.clone(),
};
if matches!(trigger, CleanupTrigger::OnDeactivation) {
state.current_cleanup = None;
}
c
})
};
let Some(closure) = closure else {
return;
};
let propagate = *self
.propagate_cleanup_panics
.lock()
.expect("propagate_cleanup_panics lock");
if propagate {
closure();
} else {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| closure()));
if result.is_err() {
self.cleanup_panics
.lock()
.expect("cleanup_panics lock")
.push((node_id, trigger));
}
}
}
fn wipe_ctx(&self, node_id: NodeId) {
self.wipe_calls
.lock()
.expect("wipe_calls lock")
.push(node_id);
self.node_ctx
.lock()
.expect("node_ctx lock")
.remove(&node_id);
}
}
/// A `Core` wired to a fresh [`TestBinding`], plus helpers for building nodes.
pub struct TestRuntime {
    /// The binding the core calls back into.
    pub binding: Arc<TestBinding>,
    /// The core under test.
    pub core: Core,
}
impl TestRuntime {
    /// Creates a fresh binding and a `Core` that calls back into it.
    pub fn new() -> Self {
        let binding = TestBinding::new();
        let boundary: Arc<dyn BindingBoundary> = binding.clone();
        let core = Core::new(boundary);
        Self { binding, core }
    }

    /// Registers a state node, optionally seeded with `initial`.
    ///
    /// `HandleId::new(0)` is passed when there is no initial value.
    pub fn state(&self, initial: Option<TestValue>) -> StateHandle {
        let initial_handle =
            initial.map_or_else(|| HandleId::new(0), |value| self.binding.intern(value));
        let id = self.core.register_state(initial_handle, false).unwrap();
        StateHandle {
            id,
            binding: Arc::clone(&self.binding),
            core: self.core.clone(),
        }
    }

    /// Registers a derived node computed by `f`, using identity equality.
    pub fn derived<F>(&self, deps: &[NodeId], f: F) -> NodeId
    where
        F: Fn(&[TestValue]) -> Option<TestValue> + Send + Sync + 'static,
    {
        let compute = self.binding.register_fn(f);
        self.core
            .register_derived(deps, compute, EqualsMode::Identity, false)
            .unwrap()
    }

    /// Registers a derived node computed by `f`, deduplicating outputs with `equals`.
    pub fn derived_with_equals<F, E>(&self, deps: &[NodeId], f: F, equals: E) -> NodeId
    where
        F: Fn(&[TestValue]) -> Option<TestValue> + Send + Sync + 'static,
        E: Fn(&TestValue, &TestValue) -> bool + Send + Sync + 'static,
    {
        // The compute fn is registered before the predicate, so fn ids are assigned in that order.
        let compute = self.binding.register_fn(f);
        let equality = EqualsMode::Custom(self.binding.register_custom_equals(equals));
        self.core
            .register_derived(deps, compute, equality, false)
            .unwrap()
    }

    /// Registers a dynamic node computed by `f`, using identity equality.
    pub fn dynamic<F>(&self, deps: &[NodeId], f: F) -> NodeId
    where
        F: Fn(&[TestValue]) -> (Option<TestValue>, Option<Vec<usize>>) + Send + Sync + 'static,
    {
        let compute = self.binding.register_dynamic_fn(f);
        self.core
            .register_dynamic(deps, compute, EqualsMode::Identity, false)
            .unwrap()
    }

    /// Subscribes a new [`Recorder`] to `node_id`.
    ///
    /// The recorder keeps the subscription alive until it is dropped or
    /// `unsubscribe` is called.
    pub fn subscribe_recorder(&self, node_id: NodeId) -> Recorder {
        let recorder = Recorder::new();
        let subscription = self
            .core
            .subscribe(node_id, recorder.sink(Arc::clone(&self.binding)));
        recorder.attach(subscription);
        recorder
    }

    /// The node's cached value; `None` when the cache holds `HandleId::new(0)`.
    pub fn cache_value(&self, node_id: NodeId) -> Option<TestValue> {
        let handle = self.core.cache_of(node_id);
        (handle != HandleId::new(0)).then(|| self.binding.deref(handle))
    }
}
/// Handle to a state node created by `TestRuntime::state`.
pub struct StateHandle {
    /// The state node's id.
    pub id: NodeId,
    /// Binding used to intern and deref values.
    pub binding: Arc<TestBinding>,
    /// Core the node lives in (a clone of the runtime's core).
    pub core: Core,
}
impl StateHandle {
    /// Interns `value` and emits it from this state node.
    pub fn set(&self, value: TestValue) {
        let interned = self.binding.intern(value);
        self.core.emit(self.id, interned);
    }

    /// The node's cached value; `None` when the cache holds `HandleId::new(0)`.
    pub fn current(&self) -> Option<TestValue> {
        match self.core.cache_of(self.id) {
            empty if empty == HandleId::new(0) => None,
            cached => Some(self.binding.deref(cached)),
        }
    }
}
/// A delivered `Message`, with handle payloads already dereferenced to values.
#[derive(Debug, Clone, PartialEq)]
pub enum RecordedEvent {
    /// `Message::Start`.
    Start,
    /// `Message::Dirty`.
    Dirty,
    /// `Message::Data`, holding the dereferenced value.
    Data(TestValue),
    /// `Message::Resolved`.
    Resolved,
    /// `Message::Invalidate`.
    Invalidate,
    /// `Message::Pause`.
    Pause(graphrefly_core::LockId),
    /// `Message::Resume`.
    Resume(graphrefly_core::LockId),
    /// `Message::Complete`.
    Complete,
    /// `Message::Error`, holding the dereferenced error value.
    Error(TestValue),
    /// `Message::Teardown`.
    Teardown,
}
/// Collects every message delivered to a subscription.
pub struct Recorder {
    /// All recorded events, flattened across sink calls.
    events: Arc<Mutex<Vec<RecordedEvent>>>,
    /// Number of messages in each sink call, in call order.
    call_boundaries: Arc<Mutex<Vec<usize>>>,
    /// Number of sink calls.
    call_count: Arc<AtomicU64>,
    /// Subscription kept alive by this recorder; cleared by `unsubscribe`.
    subscription: Mutex<Option<Subscription>>,
}
impl Recorder {
pub fn new() -> Self {
Self {
events: Arc::new(Mutex::new(Vec::new())),
call_boundaries: Arc::new(Mutex::new(Vec::new())),
call_count: Arc::new(AtomicU64::new(0)),
subscription: Mutex::new(None),
}
}
pub fn sink(&self, binding: Arc<TestBinding>) -> Sink {
let events = self.events.clone();
let call_boundaries = self.call_boundaries.clone();
let call_count = self.call_count.clone();
Arc::new(move |msgs: &[Message]| {
let mut guard = events.lock().expect("recorder lock");
for msg in msgs {
let recorded = match msg {
Message::Start => RecordedEvent::Start,
Message::Dirty => RecordedEvent::Dirty,
Message::Resolved => RecordedEvent::Resolved,
Message::Data(h) => RecordedEvent::Data(binding.deref(*h)),
Message::Invalidate => RecordedEvent::Invalidate,
Message::Pause(l) => RecordedEvent::Pause(*l),
Message::Resume(l) => RecordedEvent::Resume(*l),
Message::Complete => RecordedEvent::Complete,
Message::Error(h) => RecordedEvent::Error(binding.deref(*h)),
Message::Teardown => RecordedEvent::Teardown,
};
guard.push(recorded);
}
call_boundaries
.lock()
.expect("recorder lock")
.push(msgs.len());
call_count.fetch_add(1, Ordering::SeqCst);
})
}
#[must_use]
pub fn call_count(&self) -> u64 {
self.call_count.load(Ordering::SeqCst)
}
#[must_use]
pub fn call_boundaries(&self) -> Vec<usize> {
self.call_boundaries.lock().expect("recorder lock").clone()
}
pub fn attach(&self, sub: Subscription) {
*self.subscription.lock().expect("recorder lock") = Some(sub);
}
pub fn unsubscribe(&self) {
*self.subscription.lock().expect("recorder lock") = None;
}
pub fn snapshot(&self) -> Vec<RecordedEvent> {
self.events.lock().expect("recorder lock").clone()
}
pub fn data_values(&self) -> Vec<TestValue> {
self.snapshot()
.into_iter()
.filter_map(|e| match e {
RecordedEvent::Data(v) => Some(v),
_ => None,
})
.collect()
}
pub fn count(&self, predicate: impl Fn(&RecordedEvent) -> bool) -> usize {
self.snapshot().iter().filter(|e| predicate(e)).count()
}
}
impl PartialEq for TestValue {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::Int(a), Self::Int(b)) => a == b,
(Self::Str(a), Self::Str(b)) => a == b,
(Self::Object(a), Self::Object(b)) => Arc::ptr_eq(a, b),
(Self::Null, Self::Null) => true,
_ => false,
}
}
}