use core::fmt;
#[cfg(target_arch = "wasm32")]
use core::cell::RefCell;
#[cfg(target_arch = "wasm32")]
extern crate alloc;
#[cfg(target_arch = "wasm32")]
use alloc::rc::Rc;
#[cfg(not(target_arch = "wasm32"))]
use std::sync::{Arc, RwLock};
use super::runtime::{NodeId, try_with_runtime, with_runtime};
// Platform-specific storage for a signal's value, plus the two helpers the
// rest of this file uses to create and inspect it.
//
// Native targets use `Arc<RwLock<T>>` (shareable across threads); wasm32
// uses the single-threaded `Rc<RefCell<T>>`.

/// Shared, interiorly-mutable storage cell (native flavour).
#[cfg(not(target_arch = "wasm32"))]
type SignalInner<T> = Arc<RwLock<T>>;

/// Wraps `value` in a fresh storage cell with exactly one strong handle.
#[cfg(not(target_arch = "wasm32"))]
fn new_inner<T>(value: T) -> SignalInner<T> {
    let cell = RwLock::new(value);
    Arc::new(cell)
}

/// Number of strong handles currently sharing `inner`.
#[cfg(not(target_arch = "wasm32"))]
fn strong_count<T>(inner: &SignalInner<T>) -> usize {
    Arc::strong_count(inner)
}

/// Shared, interiorly-mutable storage cell (wasm32 flavour).
#[cfg(target_arch = "wasm32")]
type SignalInner<T> = Rc<RefCell<T>>;

/// Wraps `value` in a fresh storage cell with exactly one strong handle.
#[cfg(target_arch = "wasm32")]
fn new_inner<T>(value: T) -> SignalInner<T> {
    let cell = RefCell::new(value);
    Rc::new(cell)
}

/// Number of strong handles currently sharing `inner`.
#[cfg(target_arch = "wasm32")]
fn strong_count<T>(inner: &SignalInner<T>) -> usize {
    Rc::strong_count(inner)
}
/// A reactive value cell.
///
/// A `Signal` is a cheap handle: clones share the same underlying storage and
/// the same [`NodeId`], so a `set`/`update` through one handle is visible
/// through every other clone.
pub struct Signal<T: 'static> {
    /// Identity of this signal in the reactive runtime (used for dependency
    /// tracking, change notification and node removal).
    id: NodeId,
    /// Shared storage; every clone of the handle points at the same cell.
    value: SignalInner<T>,
}

// Implemented by hand rather than derived: `#[derive(Clone)]` would add a
// `T: Clone` bound, but cloning a handle only bumps a reference count and never
// clones `T`, so it must work for non-`Clone` payloads too.
impl<T: 'static> Clone for Signal<T> {
    fn clone(&self) -> Self {
        Self {
            id: self.id,
            // `Arc::clone` / `Rc::clone` on the storage cell, not `T::clone`.
            value: self.value.clone(),
        }
    }
}
impl<T: 'static> Signal<T> {
    /// Creates a signal holding `value`, identified by a freshly created [`NodeId`].
    pub fn new(value: T) -> Self {
        let id = NodeId::new();
        let value = new_inner(value);
        Self { id, value }
    }

    /// Returns a clone of the current value and registers this signal as a
    /// dependency with the runtime (via `track_dependency`).
    ///
    /// # Panics
    ///
    /// Same conditions as [`Signal::get_untracked`].
    pub fn get(&self) -> T
    where
        T: Clone,
    {
        with_runtime(|runtime| runtime.track_dependency(self.id));
        self.get_untracked()
    }

    /// Returns a clone of the current value without registering a dependency.
    ///
    /// # Panics
    ///
    /// On native targets, panics if the lock is poisoned (and may deadlock or
    /// panic if called while this thread holds the write lock, e.g. inside
    /// [`Signal::update`]). On wasm32, panics if the value is mutably borrowed.
    pub fn get_untracked(&self) -> T
    where
        T: Clone,
    {
        self.with_untracked(T::clone)
    }

    /// Runs `f` with a shared reference to the current value and returns its
    /// result, without cloning the value and without registering a dependency.
    ///
    /// The read lock (native) or shared borrow (wasm32) is held while `f` runs.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Signal::get_untracked`].
    pub fn with_untracked<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        #[cfg(not(target_arch = "wasm32"))]
        let guard = self.value.read().expect("Signal lock poisoned");
        #[cfg(target_arch = "wasm32")]
        let guard = self.value.borrow();
        f(&guard)
    }

    /// Replaces the current value and notifies the runtime that this signal
    /// changed. The previous value is dropped while the write access is held.
    pub fn set(&self, value: T) {
        self.update(move |slot| *slot = value);
    }

    /// Mutates the value in place with `f`, then notifies the runtime that this
    /// signal changed.
    ///
    /// Write access is held while `f` runs and released before the runtime is
    /// notified, so `f` must not read or write this same signal.
    ///
    /// # Panics
    ///
    /// On native targets, panics if the lock is poisoned. On wasm32, panics if
    /// the value is already borrowed.
    pub fn update<F>(&self, f: F)
    where
        F: FnOnce(&mut T),
    {
        {
            #[cfg(not(target_arch = "wasm32"))]
            let mut guard = self.value.write().expect("Signal lock poisoned");
            #[cfg(target_arch = "wasm32")]
            let mut guard = self.value.borrow_mut();
            f(&mut guard);
        } // write access released here, before subscribers are notified
        with_runtime(|runtime| runtime.notify_signal_change(self.id));
    }

    /// Returns this signal's runtime node id (shared by all of its clones).
    pub fn id(&self) -> NodeId {
        self.id
    }
}
impl<T: 'static> Drop for Signal<T> {
    /// Removes this signal's node from the runtime when the last handle is dropped.
    ///
    /// All clones share one storage cell, so a strong count of 1 at this point
    /// means `self` is the final handle. Removal is best effort: whatever
    /// `try_with_runtime` returns is discarded.
    fn drop(&mut self) {
        let is_last_handle = strong_count(&self.value) == 1;
        if !is_last_handle {
            return;
        }
        // NOTE(review): `strong_count` is only a snapshot. On native targets two
        // clones dropped concurrently on different threads can both observe a
        // count of 2, in which case neither removes the node. Confirm whether
        // that leak matters for the runtime.
        let _ = try_with_runtime(|runtime| runtime.remove_node(self.id));
    }
}
/// Formats as `Signal { id: .., value: .. }`.
///
/// The value is formatted by reference through [`Signal::with_untracked`].
/// No clone is made, `T: Clone` is not required, and no dependency is tracked.
///
/// # Panics
///
/// Under the same conditions as [`Signal::with_untracked`]:
/// - on native targets, a poisoned lock;
/// - on wasm32, an outstanding mutable borrow, e.g. formatting the signal from
///   inside its own `update`.
impl<T: fmt::Debug + 'static> fmt::Debug for Signal<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.with_untracked(|value| {
            f.debug_struct("Signal")
                .field("id", &self.id)
                .field("value", value)
                .finish()
        })
    }
}
// Unit tests for `Signal`.
//
// The runtime-touching tests are `#[serial]` (from `serial_test`). Presumably this
// is because the reactive runtime is shared state that concurrently running tests
// would interfere with — TODO confirm against the runtime implementation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::reactive::runtime::NodeType;
    use serial_test::serial;
    // Compile-time check: on native targets `Signal<T>` must be `Send + Sync`
    // for these payload types. The body only has to type-check.
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn signal_is_send_sync_on_native() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<Signal<String>>();
        assert_send_sync::<Signal<Option<String>>>();
        assert_send_sync::<Signal<i32>>();
    }
    // A new signal reads back its initial value.
    #[test]
    #[serial]
    fn test_signal_creation() {
        let signal = Signal::new(42);
        assert_eq!(signal.get_untracked(), 42);
    }
    // `set` replaces the stored value.
    #[test]
    #[serial]
    fn test_signal_set() {
        let signal = Signal::new(0);
        assert_eq!(signal.get_untracked(), 0);
        signal.set(100);
        assert_eq!(signal.get_untracked(), 100);
    }
    // `update` mutates the stored value in place; successive updates compose.
    #[test]
    #[serial]
    fn test_signal_update() {
        let signal = Signal::new(0);
        signal.update(|n| *n += 1);
        assert_eq!(signal.get_untracked(), 1);
        signal.update(|n| *n *= 2);
        assert_eq!(signal.get_untracked(), 2);
    }
    // Clones share storage: a write through one handle is visible through the other.
    #[test]
    #[serial]
    fn test_signal_clone() {
        let signal1 = Signal::new(42);
        let signal2 = signal1.clone();
        assert_eq!(signal1.get_untracked(), 42);
        assert_eq!(signal2.get_untracked(), 42);
        signal1.set(100);
        assert_eq!(signal1.get_untracked(), 100);
        assert_eq!(signal2.get_untracked(), 100);
    }
    // Independent signals (including different payload types) do not affect each other.
    #[test]
    #[serial]
    fn test_multiple_signals() {
        let signal1 = Signal::new(10);
        let signal2 = Signal::new(20);
        let signal3 = Signal::new("hello");
        assert_eq!(signal1.get_untracked(), 10);
        assert_eq!(signal2.get_untracked(), 20);
        assert_eq!(signal3.get_untracked(), "hello");
        signal1.set(30);
        signal2.set(40);
        signal3.set("world");
        assert_eq!(signal1.get_untracked(), 30);
        assert_eq!(signal2.get_untracked(), 40);
        assert_eq!(signal3.get_untracked(), "world");
    }
    // A tracked `get` performed while an observer is pushed must record that
    // observer as a subscriber of the signal's graph node.
    #[test]
    #[serial]
    fn test_signal_dependency_tracking() {
        let signal = Signal::new(42);
        // A tracked read returns the current value.
        assert_eq!(signal.get(), 42);
        with_runtime(|rt| {
            // Simulate a running effect by pushing an observer by hand.
            let observer_id = NodeId::new();
            rt.push_observer(crate::reactive::runtime::Observer {
                id: observer_id,
                node_type: NodeType::Effect,
                timing: crate::reactive::runtime::EffectTiming::default(),
                cleanup: None,
            });
            // `get` calls `with_runtime` again from inside this closure, so the
            // test requires nested `with_runtime` calls to work.
            let _ = signal.get();
            rt.pop_observer();
            let graph = rt.dependency_graph.borrow();
            let signal_node = graph.get(&signal.id()).unwrap();
            assert!(signal_node.subscribers.contains(&observer_id));
        });
    }
    // After `set`, every subscriber of the signal should appear in the
    // runtime's `pending_updates`.
    #[test]
    #[serial]
    fn test_signal_change_notification() {
        let signal = Signal::new(0);
        with_runtime(|rt| {
            let effect_id = NodeId::new();
            // Subscribe a fake effect directly in the graph. The inner scope ends
            // the `borrow_mut` before `set` re-enters the runtime, presumably
            // because `notify_signal_change` reads the graph.
            {
                let mut graph = rt.dependency_graph.borrow_mut();
                graph
                    .entry(signal.id())
                    .or_default()
                    .subscribers
                    .push(effect_id);
            }
            signal.set(42);
            let pending = rt.pending_updates.borrow();
            assert!(pending.contains(&effect_id));
        });
    }
}