use crate::reactive::{ReactiveGraph, Signal, SignalId, State};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock, RwLock};
/// Process-wide context slot, filled exactly once by `BlincContextState::init` or
/// `BlincContextState::init_with_callback` and read by `BlincContextState::get` / `try_get`.
static CONTEXT_STATE: OnceLock<BlincContextState> = OnceLock::new();
/// Reactive graph shared (via `Arc`) between the context and the `State` handles it creates.
pub type SharedReactiveGraph = Arc<Mutex<ReactiveGraph>>;
/// Flag that `BlincContextState::request_rebuild` sets to `true`.
/// NOTE(review): presumably polled and cleared by the app loop to trigger a rebuild — the
/// consumer is not visible in this file.
pub type DirtyFlag = Arc<AtomicBool>;
/// Identity of a keyed hook: a 64-bit hash of the caller's key paired with the value type.
///
/// Including the [`TypeId`] keeps two hooks that share a key but store different value types
/// from aliasing each other. NOTE: equality is on the hash only, so two distinct keys whose
/// 64-bit hashes collide (for the same `T`) would map to the same slot.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct StateKey {
    /// `DefaultHasher` digest of the caller-supplied key.
    key_hash: u64,
    /// Type of the value stored under this key.
    type_id: TypeId,
}

impl StateKey {
    /// Builds the key for a hook that stores values of type `T`, identified by any hashable `key`.
    pub fn new<T: 'static, K: Hash>(key: &K) -> Self {
        let key_hash = {
            let mut state = std::collections::hash_map::DefaultHasher::new();
            key.hash(&mut state);
            state.finish()
        };
        StateKey {
            key_hash,
            type_id: TypeId::of::<T>(),
        }
    }

    /// Builds the key for a hook that stores values of type `T`, identified by a string.
    pub fn from_string<T: 'static>(key: &str) -> Self {
        Self::new::<T, &str>(&key)
    }
}
/// Table mapping keyed hook identities ([`StateKey`]) to the raw IDs of their backing signals.
#[derive(Default)]
pub struct HookState {
    /// Raw signal IDs (the `u64` form produced by `SignalId::to_raw`), indexed by hook key.
    signals: HashMap<StateKey, u64>,
}

impl HookState {
    /// Creates an empty hook table (equivalent to `HookState::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the raw signal ID registered under `key`, or `None` if the key is unknown.
    pub fn get(&self, key: &StateKey) -> Option<u64> {
        self.signals.get(key).copied()
    }

    /// Registers `signal_id` under `key`, replacing any previous entry for that key.
    pub fn insert(&mut self, key: StateKey, signal_id: u64) {
        let _replaced = self.signals.insert(key, signal_id);
    }
}
/// Hook table shared (via `Arc`) between the context and whoever created it.
pub type SharedHookState = Arc<Mutex<HookState>>;
/// Callback passed to `BlincContextState::init_with_callback`; invoked by
/// `BlincContextState::notify_stateful_deps` with a slice of signal IDs.
pub type StatefulCallback = Arc<dyn Fn(&[SignalId]) + Send + Sync>;
/// Lookup backing `BlincContextState::query`: maps an ID string to an optional `u64`.
/// NOTE(review): the meaning of the returned `u64` is defined by whoever installs the callback.
pub type QueryCallback = Arc<dyn Fn(&str) -> Option<u64> + Send + Sync>;
/// Axis-aligned rectangle given by its top-left corner (`x`, `y`) and its size.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Bounds {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
}

impl Bounds {
    /// Creates a rectangle from its origin and size.
    pub fn new(x: f32, y: f32, width: f32, height: f32) -> Self {
        Bounds {
            x,
            y,
            width,
            height,
        }
    }

    /// Returns `true` if the point lies inside the rectangle.
    ///
    /// The left/top edges are inclusive and the right/bottom edges exclusive, so adjacent
    /// rectangles never both contain a shared edge point. NaN coordinates are never contained.
    pub fn contains(&self, px: f32, py: f32) -> bool {
        let horizontal = self.x..self.x + self.width;
        let vertical = self.y..self.y + self.height;
        horizontal.contains(&px) && vertical.contains(&py)
    }

    /// Returns `true` if the two rectangles overlap with positive area.
    ///
    /// Rectangles that merely touch along an edge do not intersect.
    pub fn intersects(&self, other: &Bounds) -> bool {
        let (right, bottom) = (self.x + self.width, self.y + self.height);
        let (other_right, other_bottom) = (other.x + other.width, other.y + other.height);
        self.x < other_right && other.x < right && self.y < other_bottom && other.y < bottom
    }
}
/// Lookup backing `BlincContextState::get_bounds`: element ID string to its [`Bounds`], if known.
pub type BoundsCallback = Arc<dyn Fn(&str) -> Option<Bounds> + Send + Sync>;
/// Notified by `BlincContextState::set_focus` with the newly focused element ID (`None` clears).
pub type FocusCallback = Arc<dyn Fn(Option<&str>) + Send + Sync>;
/// Handler backing `BlincContextState::scroll_element_into_view`, called with the element ID.
pub type ScrollCallback = Arc<dyn Fn(&str) + Send + Sync>;
/// Lifecycle state of a motion-animated element, as reported by the motion state callback.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MotionAnimationState {
    Suspended,
    Waiting,
    Entering { progress: f32 },
    Visible,
    Exiting { progress: f32 },
    Removed,
    NotFound,
}

impl MotionAnimationState {
    /// Returns `true` while an animation is pending or running
    /// (`Suspended`, `Waiting`, `Entering` or `Exiting`).
    pub fn is_animating(&self) -> bool {
        !matches!(self, Self::Visible | Self::Removed | Self::NotFound)
    }

    /// Returns `true` only for the fully shown `Visible` state.
    pub fn is_settled(&self) -> bool {
        *self == Self::Visible
    }

    /// Returns `true` for `Suspended`.
    pub fn is_suspended(&self) -> bool {
        *self == Self::Suspended
    }

    /// Returns `true` for any `Entering` state, regardless of progress.
    pub fn is_entering(&self) -> bool {
        matches!(self, Self::Entering { .. })
    }

    /// Returns `true` for any `Exiting` state, regardless of progress.
    pub fn is_exiting(&self) -> bool {
        matches!(self, Self::Exiting { .. })
    }

    /// Returns the animation progress.
    ///
    /// Yields the stored value for `Entering`/`Exiting`, `0.0` before an animation starts
    /// (`Suspended`, `Waiting`), and `1.0` for the terminal states (`Visible`, `Removed`,
    /// `NotFound`).
    pub fn progress(&self) -> f32 {
        match *self {
            Self::Entering { progress } | Self::Exiting { progress } => progress,
            Self::Suspended | Self::Waiting => 0.0,
            Self::Visible | Self::Removed | Self::NotFound => 1.0,
        }
    }
}
/// Lookup backing `BlincContextState::query_motion`: motion key to its current state.
pub type MotionStateCallback = Arc<dyn Fn(&str) -> MotionAnimationState + Send + Sync>;
/// Handler backing `BlincContextState::cancel_motion_exit`, called with the motion key.
pub type MotionCancelExitCallback = Arc<dyn Fn(&str) + Send + Sync>;
/// Type-erased recorded event passed through `BlincContextState::record_event`.
pub type RecordedEventAny = Box<dyn Any + Send>;
/// Sink installed by `BlincContextState::set_recorder_event_callback`.
pub type RecorderEventCallback = Arc<dyn Fn(RecordedEventAny) + Send + Sync>;
/// Type-erased tree snapshot passed through `BlincContextState::record_snapshot`.
pub type TreeSnapshotAny = Box<dyn Any + Send>;
/// Sink installed by `BlincContextState::set_recorder_snapshot_callback`.
pub type RecorderSnapshotCallback = Arc<dyn Fn(TreeSnapshotAny) + Send + Sync>;
/// Category attached to each element update reported via `BlincContextState::record_update`.
/// NOTE(review): variant semantics (e.g. paint-only vs. relayout vs. tree change) are inferred
/// from the names only — confirm against the code that reports updates.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UpdateCategory {
Visual,
Layout,
Structural,
}
/// Sink installed by `BlincContextState::set_recorder_update_callback`; receives the element ID
/// and its update category.
pub type RecorderUpdateCallback = Arc<dyn Fn(&str, UpdateCategory) + Send + Sync>;
/// Type-erased element registry; retrieved with a concrete type via
/// `BlincContextState::element_registry::<T>()`, which downcasts it.
pub type AnyElementRegistry = Arc<dyn Any + Send + Sync>;
/// Process-wide Blinc context: the shared reactive graph, the keyed hook table, the rebuild
/// flag, and late-bound callback slots filled in through the `set_*` methods. A single instance
/// is stored in `CONTEXT_STATE` and obtained via `BlincContextState::get()` / `try_get()`.
pub struct BlincContextState {
/// Reactive graph in which this context creates and reads signals.
reactive: SharedReactiveGraph,
/// Keyed hook table mapping `StateKey`s to raw signal IDs.
hooks: SharedHookState,
/// Set to `true` by `request_rebuild`.
dirty_flag: DirtyFlag,
/// Fixed at init time; forwarded to `State`s created by `use_state_keyed` and invoked by
/// `notify_stateful_deps`.
stateful_callback: Option<StatefulCallback>,
/// Backs `query`.
query_callback: RwLock<Option<QueryCallback>>,
/// Backs `get_bounds`.
bounds_callback: RwLock<Option<BoundsCallback>>,
/// Notified by `set_focus`.
focus_callback: RwLock<Option<FocusCallback>>,
/// Backs `scroll_element_into_view`.
scroll_callback: RwLock<Option<ScrollCallback>>,
/// Last value passed to `set_viewport_size` as (width, height); starts at (0.0, 0.0).
viewport_size: RwLock<(f32, f32)>,
/// ID of the element most recently passed to `set_focus`, if any.
focused_element: RwLock<Option<String>>,
/// Type-erased registry installed by `set_element_registry`.
element_registry: RwLock<Option<AnyElementRegistry>>,
/// Backs `query_motion`.
motion_state_callback: RwLock<Option<MotionStateCallback>>,
/// Backs `cancel_motion_exit`.
motion_cancel_exit_callback: RwLock<Option<MotionCancelExitCallback>>,
/// Receives events from `record_event` while installed.
recorder_event_callback: RwLock<Option<RecorderEventCallback>>,
/// Receives snapshots from `record_snapshot` while installed.
recorder_snapshot_callback: RwLock<Option<RecorderSnapshotCallback>>,
/// Receives updates from `record_update` while installed.
recorder_update_callback: RwLock<Option<RecorderUpdateCallback>>,
/// Passes queued by `register_custom_pass` until taken by `drain_custom_passes`.
pending_custom_passes: Mutex<Vec<Box<dyn std::any::Any + Send>>>,
}
/// Carrier that lets a non-`Send` custom pass sit in the `Send`-bounded pending queue on wasm32.
///
/// It lives at module level, not inside `register_custom_pass_nosend`, so that
/// `drain_custom_passes` can name it and unwrap it again. Otherwise callers would receive the
/// shim itself and could never downcast to their original pass type.
#[cfg(target_arch = "wasm32")]
struct WasmSendShim(Box<dyn Any>);

// SAFETY: relies on wasm32 executing this code on a single thread, so the wrapped value never
// actually crosses a thread boundary. NOTE(review): this no longer holds if the crate is built
// for wasm with shared-memory threads (atomics); confirm the supported build targets.
#[cfg(target_arch = "wasm32")]
unsafe impl Send for WasmSendShim {}

impl BlincContextState {
    /// Builds a context with every late-bound callback slot empty.
    fn from_parts(
        reactive: SharedReactiveGraph,
        hooks: SharedHookState,
        dirty_flag: DirtyFlag,
        stateful_callback: Option<StatefulCallback>,
    ) -> Self {
        BlincContextState {
            reactive,
            hooks,
            dirty_flag,
            stateful_callback,
            query_callback: RwLock::new(None),
            bounds_callback: RwLock::new(None),
            focus_callback: RwLock::new(None),
            scroll_callback: RwLock::new(None),
            viewport_size: RwLock::new((0.0, 0.0)),
            focused_element: RwLock::new(None),
            element_registry: RwLock::new(None),
            motion_state_callback: RwLock::new(None),
            motion_cancel_exit_callback: RwLock::new(None),
            recorder_event_callback: RwLock::new(None),
            recorder_snapshot_callback: RwLock::new(None),
            recorder_update_callback: RwLock::new(None),
            pending_custom_passes: Mutex::new(Vec::new()),
        }
    }

    /// Stores `state` as the process-wide context.
    ///
    /// # Panics
    /// Panics if a context has already been installed.
    fn install(state: Self) {
        if CONTEXT_STATE.set(state).is_err() {
            panic!("BlincContextState::init() called more than once");
        }
    }

    /// Clones the callback currently stored in `slot`, releasing the lock before returning.
    ///
    /// Callbacks are invoked through this rather than under the read guard. That way a
    /// callback can call back into the context, including replacing or clearing its own slot,
    /// without deadlocking on the `RwLock`. A callback replaced mid-call finishes its current
    /// invocation.
    fn load_callback<C: Clone>(slot: &RwLock<Option<C>>) -> Option<C> {
        slot.read().unwrap().clone()
    }

    /// Returns the signal registered under `key` for type `T`, creating it from `init()` if absent.
    ///
    /// `init` runs with no lock held, so it may read signals or call other hooks. If another
    /// caller registers the same key while `init` runs, that signal wins and this call's
    /// initial value is dropped, so a key always maps to a single signal.
    fn get_or_create_keyed_signal<T, F>(&self, key: &str, init: F) -> Signal<T>
    where
        T: Clone + Send + 'static,
        F: FnOnce() -> T,
    {
        let state_key = StateKey::from_string::<T>(key);
        // The hooks guard is a temporary here, released at the end of this statement.
        let existing = self.hooks.lock().unwrap().get(&state_key);
        if let Some(raw_id) = existing {
            return Signal::from_id(SignalId::from_raw(raw_id));
        }

        let initial = init();

        let mut hooks = self.hooks.lock().unwrap();
        if let Some(raw_id) = hooks.get(&state_key) {
            return Signal::from_id(SignalId::from_raw(raw_id));
        }
        let signal = self.reactive.lock().unwrap().create_signal(initial);
        hooks.insert(state_key, signal.id().to_raw());
        signal
    }

    /// Installs the process-wide context without a stateful callback.
    ///
    /// # Panics
    /// Panics if the context was already initialized (by this or `init_with_callback`).
    pub fn init(reactive: SharedReactiveGraph, hooks: SharedHookState, dirty_flag: DirtyFlag) {
        Self::install(Self::from_parts(reactive, hooks, dirty_flag, None));
    }

    /// Installs the process-wide context with `callback` as the stateful callback. It is
    /// attached to every `State` from `use_state_keyed` and invoked by `notify_stateful_deps`.
    ///
    /// # Panics
    /// Panics if the context was already initialized (by this or `init`).
    pub fn init_with_callback(
        reactive: SharedReactiveGraph,
        hooks: SharedHookState,
        dirty_flag: DirtyFlag,
        callback: StatefulCallback,
    ) {
        Self::install(Self::from_parts(reactive, hooks, dirty_flag, Some(callback)));
    }

    /// Returns the process-wide context.
    ///
    /// # Panics
    /// Panics if neither `init` nor `init_with_callback` has run yet.
    pub fn get() -> &'static BlincContextState {
        CONTEXT_STATE.get().expect(
            "BlincContextState not initialized. Call BlincContextState::init() at app startup.",
        )
    }

    /// Returns the process-wide context, or `None` if it has not been initialized.
    pub fn try_get() -> Option<&'static BlincContextState> {
        CONTEXT_STATE.get()
    }

    /// Returns `true` once the context has been initialized.
    pub fn is_initialized() -> bool {
        CONTEXT_STATE.get().is_some()
    }

    /// Returns a [`State`] for the keyed signal `key`, creating the signal from `init()` on
    /// first use. The same key and type always yield a handle to the same signal.
    ///
    /// `init` runs with no context lock held, so it may read signals or call other hooks.
    pub fn use_state_keyed<T, F>(&self, key: &str, init: F) -> State<T>
    where
        T: Clone + Send + 'static,
        F: FnOnce() -> T,
    {
        let signal = self.get_or_create_keyed_signal(key, init);
        match &self.stateful_callback {
            Some(callback) => State::with_stateful_callback(
                signal,
                Arc::clone(&self.reactive),
                Arc::clone(&self.dirty_flag),
                Arc::clone(callback),
            ),
            None => State::new(
                signal,
                Arc::clone(&self.reactive),
                Arc::clone(&self.dirty_flag),
            ),
        }
    }

    /// Returns the keyed signal `key`, creating it from `init()` on first use.
    ///
    /// `init` runs with no context lock held, so it may read signals or call other hooks.
    pub fn use_signal_keyed<T, F>(&self, key: &str, init: F) -> Signal<T>
    where
        T: Clone + Send + 'static,
        F: FnOnce() -> T,
    {
        self.get_or_create_keyed_signal(key, init)
    }

    /// Creates a new, unkeyed signal holding `initial`. Every call creates a distinct signal.
    pub fn use_signal<T: Send + 'static>(&self, initial: T) -> Signal<T> {
        self.reactive.lock().unwrap().create_signal(initial)
    }

    /// Reads `signal` through the reactive graph's `get`.
    pub fn get_signal<T: Clone + 'static>(&self, signal: Signal<T>) -> Option<T> {
        self.reactive.lock().unwrap().get(signal)
    }

    /// Writes `value` into `signal` through the reactive graph's `set`.
    pub fn set_signal<T: Send + 'static>(&self, signal: Signal<T>, value: T) {
        self.reactive.lock().unwrap().set(signal, value);
    }

    /// Replaces `signal`'s value with `f(current)` as a single locked read-modify-write.
    /// Does nothing if the graph returns no current value.
    ///
    /// # Deadlocks
    /// `f` runs while the reactive graph is locked. It must not read or write signals through
    /// this context, nor (presumably) through `State` handles sharing the same graph.
    pub fn update<T: Clone + Send + 'static, F: FnOnce(T) -> T>(&self, signal: Signal<T>, f: F) {
        let mut graph = self.reactive.lock().unwrap();
        if let Some(current) = graph.get(signal) {
            graph.set(signal, f(current));
        }
    }

    /// Shared reactive graph.
    pub fn reactive(&self) -> &SharedReactiveGraph {
        &self.reactive
    }

    /// Shared keyed hook table.
    pub fn hooks(&self) -> &SharedHookState {
        &self.hooks
    }

    /// Shared rebuild flag.
    pub fn dirty_flag(&self) -> &DirtyFlag {
        &self.dirty_flag
    }

    /// Sets the dirty flag to `true`.
    pub fn request_rebuild(&self) {
        self.dirty_flag.store(true, Ordering::SeqCst);
    }

    /// Forwards `signal_ids` to the stateful callback given at init; a no-op if there is none.
    pub fn notify_stateful_deps(&self, signal_ids: &[SignalId]) {
        if let Some(callback) = &self.stateful_callback {
            callback(signal_ids);
        }
    }

    /// Installs (or replaces) the callback backing [`Self::query`].
    pub fn set_query_callback(&self, callback: QueryCallback) {
        *self.query_callback.write().unwrap() = Some(callback);
    }

    /// Resolves `id` via the query callback; `None` if no callback is installed.
    pub fn query(&self, id: &str) -> Option<u64> {
        Self::load_callback(&self.query_callback).and_then(|cb| cb(id))
    }

    /// Installs (or replaces) the callback backing [`Self::get_bounds`].
    pub fn set_bounds_callback(&self, callback: BoundsCallback) {
        *self.bounds_callback.write().unwrap() = Some(callback);
    }

    /// Looks up the bounds of element `id`; `None` if unknown or no callback is installed.
    pub fn get_bounds(&self, id: &str) -> Option<Bounds> {
        Self::load_callback(&self.bounds_callback).and_then(|cb| cb(id))
    }

    /// Records the viewport size as (width, height).
    pub fn set_viewport_size(&self, width: f32, height: f32) {
        *self.viewport_size.write().unwrap() = (width, height);
    }

    /// Last recorded viewport size as (width, height); (0.0, 0.0) until first set.
    pub fn viewport_size(&self) -> (f32, f32) {
        *self.viewport_size.read().unwrap()
    }

    /// Installs (or replaces) the callback notified by [`Self::set_focus`].
    pub fn set_focus_callback(&self, callback: FocusCallback) {
        *self.focus_callback.write().unwrap() = Some(callback);
    }

    /// Records `id` as the focused element (`None` clears focus), then notifies the focus
    /// callback. The new focus is already visible to the callback.
    pub fn set_focus(&self, id: Option<&str>) {
        *self.focused_element.write().unwrap() = id.map(str::to_owned);
        if let Some(cb) = Self::load_callback(&self.focus_callback) {
            cb(id);
        }
    }

    /// ID of the currently focused element, if any.
    pub fn focused_element(&self) -> Option<String> {
        self.focused_element.read().unwrap().clone()
    }

    /// Returns `true` if `id` is the currently focused element.
    pub fn is_focused(&self, id: &str) -> bool {
        self.focused_element.read().unwrap().as_deref() == Some(id)
    }

    /// Installs (or replaces) the callback backing [`Self::scroll_element_into_view`].
    pub fn set_scroll_callback(&self, callback: ScrollCallback) {
        *self.scroll_callback.write().unwrap() = Some(callback);
    }

    /// Asks the scroll callback to bring `id` into view; a no-op if none is installed.
    pub fn scroll_element_into_view(&self, id: &str) {
        if let Some(cb) = Self::load_callback(&self.scroll_callback) {
            cb(id);
        }
    }

    /// Installs (or replaces) the type-erased element registry.
    pub fn set_element_registry(&self, registry: AnyElementRegistry) {
        *self.element_registry.write().unwrap() = Some(registry);
    }

    /// Type-erased element registry, if one is installed.
    pub fn element_registry_any(&self) -> Option<AnyElementRegistry> {
        self.element_registry.read().unwrap().clone()
    }

    /// Element registry downcast to `T`; `None` if none is installed or it is not a `T`.
    pub fn element_registry<T: Send + Sync + 'static>(&self) -> Option<Arc<T>> {
        self.element_registry_any()
            .and_then(|registry| registry.downcast::<T>().ok())
    }

    /// Installs (or replaces) the callback backing [`Self::query_motion`].
    pub fn set_motion_state_callback(&self, callback: MotionStateCallback) {
        *self.motion_state_callback.write().unwrap() = Some(callback);
    }

    /// Current motion state for `key`; `NotFound` if no callback is installed.
    pub fn query_motion(&self, key: &str) -> MotionAnimationState {
        Self::load_callback(&self.motion_state_callback)
            .map_or(MotionAnimationState::NotFound, |cb| cb(key))
    }

    /// Installs (or replaces) the callback backing [`Self::cancel_motion_exit`].
    pub fn set_motion_cancel_exit_callback(&self, callback: MotionCancelExitCallback) {
        *self.motion_cancel_exit_callback.write().unwrap() = Some(callback);
    }

    /// Forwards `key` to the cancel-exit callback; a no-op if none is installed.
    pub fn cancel_motion_exit(&self, key: &str) {
        if let Some(cb) = Self::load_callback(&self.motion_cancel_exit_callback) {
            cb(key);
        }
    }

    /// Installs (or replaces) the recorded-event sink.
    pub fn set_recorder_event_callback(&self, callback: RecorderEventCallback) {
        *self.recorder_event_callback.write().unwrap() = Some(callback);
    }

    /// Removes the recorded-event sink.
    pub fn clear_recorder_event_callback(&self) {
        *self.recorder_event_callback.write().unwrap() = None;
    }

    /// Passes `event` to the recorded-event sink; the event is dropped if none is installed.
    pub fn record_event(&self, event: RecordedEventAny) {
        if let Some(cb) = Self::load_callback(&self.recorder_event_callback) {
            cb(event);
        }
    }

    /// Returns `true` while a recorded-event sink is installed.
    pub fn is_recording_events(&self) -> bool {
        self.recorder_event_callback.read().unwrap().is_some()
    }

    /// Installs (or replaces) the snapshot sink.
    pub fn set_recorder_snapshot_callback(&self, callback: RecorderSnapshotCallback) {
        *self.recorder_snapshot_callback.write().unwrap() = Some(callback);
    }

    /// Removes the snapshot sink.
    pub fn clear_recorder_snapshot_callback(&self) {
        *self.recorder_snapshot_callback.write().unwrap() = None;
    }

    /// Passes `snapshot` to the snapshot sink; dropped if none is installed.
    pub fn record_snapshot(&self, snapshot: TreeSnapshotAny) {
        if let Some(cb) = Self::load_callback(&self.recorder_snapshot_callback) {
            cb(snapshot);
        }
    }

    /// Returns `true` while a snapshot sink is installed.
    pub fn is_recording_snapshots(&self) -> bool {
        self.recorder_snapshot_callback.read().unwrap().is_some()
    }

    /// Installs (or replaces) the element-update sink.
    pub fn set_recorder_update_callback(&self, callback: RecorderUpdateCallback) {
        *self.recorder_update_callback.write().unwrap() = Some(callback);
    }

    /// Removes the element-update sink.
    pub fn clear_recorder_update_callback(&self) {
        *self.recorder_update_callback.write().unwrap() = None;
    }

    /// Reports an update of `element_id` to the update sink; a no-op if none is installed.
    pub fn record_update(&self, element_id: &str, category: UpdateCategory) {
        if let Some(cb) = Self::load_callback(&self.recorder_update_callback) {
            cb(element_id, category);
        }
    }

    /// Returns `true` while an element-update sink is installed.
    pub fn is_recording_updates(&self) -> bool {
        self.recorder_update_callback.read().unwrap().is_some()
    }

    /// Queues a type-erased custom pass until the next [`Self::drain_custom_passes`].
    pub fn register_custom_pass(&self, pass: Box<dyn std::any::Any + Send>) {
        self.pending_custom_passes.lock().unwrap().push(pass);
    }

    /// wasm32-only variant of [`Self::register_custom_pass`] for passes that are not `Send`.
    /// The pass comes back from [`Self::drain_custom_passes`] exactly as boxed here.
    #[cfg(target_arch = "wasm32")]
    pub fn register_custom_pass_nosend(&self, pass: Box<dyn std::any::Any>) {
        self.pending_custom_passes
            .lock()
            .unwrap()
            .push(Box::new(WasmSendShim(pass)));
    }

    /// Takes every queued custom pass, in registration order, leaving the queue empty.
    pub fn drain_custom_passes(&self) -> Vec<Box<dyn std::any::Any>> {
        let pending = std::mem::take(&mut *self.pending_custom_passes.lock().unwrap());
        pending.into_iter().map(Self::unwrap_custom_pass).collect()
    }

    /// Strips the wasm `Send` shim (if present) so callers can downcast to their own pass type.
    #[cfg(target_arch = "wasm32")]
    fn unwrap_custom_pass(pass: Box<dyn Any + Send>) -> Box<dyn Any> {
        match pass.downcast::<WasmSendShim>() {
            Ok(shim) => {
                let WasmSendShim(inner) = *shim;
                inner
            }
            Err(pass) => pass as Box<dyn Any>,
        }
    }

    /// Drops the `Send` bound; no shim exists off wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    fn unwrap_custom_pass(pass: Box<dyn Any + Send>) -> Box<dyn Any> {
        pass as Box<dyn Any>
    }

    /// Returns the keyed signal `key` and its current value (read untracked), creating the
    /// signal from `create()` if the key is new.
    ///
    /// If the key exists but the graph has no value for it, `create()` supplies the returned
    /// value; the signal itself is not updated. `create` always runs with no context lock held.
    pub fn get_or_create_persisted<T, F>(&self, key: &str, create: F) -> (SignalId, T)
    where
        T: Clone + Send + 'static,
        F: FnOnce() -> T,
    {
        let state_key = StateKey::from_string::<T>(key);
        // The hooks guard is a temporary here, released at the end of this statement.
        let existing = self.hooks.lock().unwrap().get(&state_key);
        if let Some(raw_id) = existing {
            let signal_id = SignalId::from_raw(raw_id);
            let stored = self
                .reactive
                .lock()
                .unwrap()
                .get_untracked(Signal::<T>::from_id(signal_id));
            return (signal_id, stored.unwrap_or_else(create));
        }

        let new_value = create();

        let mut hooks = self.hooks.lock().unwrap();
        // Another caller may have registered the key while `create` ran; keep its signal.
        if let Some(raw_id) = hooks.get(&state_key) {
            let signal_id = SignalId::from_raw(raw_id);
            let stored = self
                .reactive
                .lock()
                .unwrap()
                .get_untracked(Signal::<T>::from_id(signal_id));
            return (signal_id, stored.unwrap_or(new_value));
        }
        let signal = self
            .reactive
            .lock()
            .unwrap()
            .create_signal(new_value.clone());
        hooks.insert(state_key, signal.id().to_raw());
        (signal.id(), new_value)
    }
}
/// Keyed state hook on the global context; delegates to
/// [`BlincContextState::use_state_keyed`].
///
/// # Panics
/// Panics if the global context has not been initialized.
pub fn use_state_keyed<T, F>(key: &str, init: F) -> State<T>
where
    T: Clone + Send + 'static,
    F: FnOnce() -> T,
{
    let ctx = BlincContextState::get();
    ctx.use_state_keyed(key, init)
}
/// Keyed signal hook on the global context; delegates to
/// [`BlincContextState::use_signal_keyed`].
///
/// # Panics
/// Panics if the global context has not been initialized.
pub fn use_signal_keyed<T, F>(key: &str, init: F) -> Signal<T>
where
    T: Clone + Send + 'static,
    F: FnOnce() -> T,
{
    let ctx = BlincContextState::get();
    ctx.use_signal_keyed(key, init)
}
/// Sets the global context's dirty flag; delegates to [`BlincContextState::request_rebuild`].
///
/// # Panics
/// Panics if the global context has not been initialized.
pub fn request_rebuild() {
    let ctx = BlincContextState::get();
    ctx.request_rebuild();
}
/// Resolves `id` through the global context's query callback; delegates to
/// [`BlincContextState::query`].
///
/// # Panics
/// Panics if the global context has not been initialized.
pub fn query(id: &str) -> Option<u64> {
    let ctx = BlincContextState::get();
    ctx.query(id)
}
/// Looks up the motion state for `key` on the global context; delegates to
/// [`BlincContextState::query_motion`].
///
/// # Panics
/// Panics if the global context has not been initialized.
pub fn query_motion(key: &str) -> MotionAnimationState {
    let ctx = BlincContextState::get();
    ctx.query_motion(key)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys are equal only when both the key string and the value type match.
    #[test]
    fn test_state_key() {
        let key1 = StateKey::from_string::<i32>("counter");
        let key2 = StateKey::from_string::<i32>("counter");
        let key3 = StateKey::from_string::<String>("counter");
        let key4 = StateKey::from_string::<i32>("other");
        assert_eq!(key1, key2);
        // Same string, different value type: must not alias.
        assert_ne!(key1, key3);
        assert_ne!(key1, key4);
    }

    /// Lookups miss until inserted; re-inserting a key replaces its signal id.
    #[test]
    fn test_hook_state() {
        let mut hooks = HookState::new();
        let key = StateKey::from_string::<i32>("test");
        assert!(hooks.get(&key).is_none());
        hooks.insert(key.clone(), 42);
        assert_eq!(hooks.get(&key), Some(42));
        hooks.insert(key.clone(), 7);
        assert_eq!(hooks.get(&key), Some(7));
    }

    /// `contains` is half-open and edge-touching rectangles do not intersect.
    #[test]
    fn test_bounds_edges() {
        let b = Bounds::new(0.0, 0.0, 10.0, 10.0);
        assert!(b.contains(0.0, 0.0));
        assert!(!b.contains(10.0, 5.0));
        assert!(!b.contains(5.0, 10.0));
        assert!(!b.intersects(&Bounds::new(10.0, 0.0, 5.0, 5.0)));
        assert!(b.intersects(&Bounds::new(9.0, 9.0, 5.0, 5.0)));
    }

    /// Helper predicates and `progress` agree with each variant.
    #[test]
    fn test_motion_state_helpers() {
        assert!(MotionAnimationState::Suspended.is_animating());
        assert!(!MotionAnimationState::Visible.is_animating());
        assert!(MotionAnimationState::Visible.is_settled());
        assert_eq!(MotionAnimationState::Waiting.progress(), 0.0);
        assert_eq!(MotionAnimationState::Entering { progress: 0.5 }.progress(), 0.5);
        assert_eq!(MotionAnimationState::NotFound.progress(), 1.0);
        assert!(MotionAnimationState::Exiting { progress: 0.1 }.is_exiting());
    }
}