Skip to main content

react_rs_core/
signal.rs

1use std::cell::RefCell;
2use std::rc::Rc;
3
4use crate::effect::flush_effects;
5use crate::runtime::RUNTIME;
6
/// Identifier recorded in a signal's subscriber list for an effect that read it.
type SubscriberId = usize;

/// State shared between the read and write handles of a single signal.
struct SignalInner<T> {
    /// The current value held by the signal.
    value: T,
    /// Counter bumped by the write handle on every `set`/`update`.
    version: u64,
    /// Ids of effects that have read this signal while running.
    subscribers: Vec<SubscriberId>,
}
14
15pub struct ReadSignal<T> {
16    inner: Rc<RefCell<SignalInner<T>>>,
17}
18
19impl<T> Clone for ReadSignal<T> {
20    fn clone(&self) -> Self {
21        Self {
22            inner: self.inner.clone(),
23        }
24    }
25}
26
27pub struct WriteSignal<T> {
28    inner: Rc<RefCell<SignalInner<T>>>,
29}
30
31impl<T> Clone for WriteSignal<T> {
32    fn clone(&self) -> Self {
33        Self {
34            inner: self.inner.clone(),
35        }
36    }
37}
38
39/// Creates a reactive signal with the given initial value. Returns a (read, write) pair.
40pub fn create_signal<T>(value: T) -> (ReadSignal<T>, WriteSignal<T>) {
41    let inner = Rc::new(RefCell::new(SignalInner {
42        value,
43        subscribers: Vec::new(),
44        version: 0,
45    }));
46
47    (
48        ReadSignal {
49            inner: inner.clone(),
50        },
51        WriteSignal { inner },
52    )
53}
54
55impl<T: Clone> ReadSignal<T> {
56    /// Reads the current value. Subscribes the current effect to this signal.
57    pub fn get(&self) -> T {
58        self.track();
59        self.inner.borrow().value.clone()
60    }
61
62    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
63        self.track();
64        f(&self.inner.borrow().value)
65    }
66
67    pub fn get_untracked(&self) -> T {
68        self.inner.borrow().value.clone()
69    }
70
71    fn track(&self) {
72        RUNTIME.with(|rt| {
73            let rt_ref = rt.borrow();
74            if let Some(effect_id) = rt_ref.current_effect() {
75                if !rt_ref.is_effect_disposed(effect_id) {
76                    drop(rt_ref);
77                    let mut inner = self.inner.borrow_mut();
78                    if !inner.subscribers.contains(&effect_id) {
79                        inner.subscribers.push(effect_id);
80                    }
81                }
82            }
83        });
84    }
85}
86
87impl<T> WriteSignal<T> {
88    /// Replaces the signal value and notifies subscribers.
89    pub fn set(&self, value: T) {
90        {
91            let mut inner = self.inner.borrow_mut();
92            inner.value = value;
93            inner.version += 1;
94        }
95        self.notify_subscribers();
96    }
97
98    pub fn update(&self, f: impl FnOnce(&mut T)) {
99        {
100            let mut inner = self.inner.borrow_mut();
101            f(&mut inner.value);
102            inner.version += 1;
103        }
104        self.notify_subscribers();
105    }
106
107    pub fn set_if_changed(&self, value: T)
108    where
109        T: PartialEq,
110    {
111        let changed = {
112            let inner = self.inner.borrow();
113            inner.value != value
114        };
115        if changed {
116            self.set(value);
117        }
118    }
119
120    fn notify_subscribers(&self) {
121        let inner = self.inner.borrow();
122        let should_flush = RUNTIME.with(|rt| {
123            let mut rt = rt.borrow_mut();
124            for &subscriber_id in &inner.subscribers {
125                if !rt.is_effect_disposed(subscriber_id) {
126                    rt.schedule_effect(subscriber_id);
127                }
128            }
129            !rt.is_batching()
130        });
131        drop(inner);
132
133        if should_flush {
134            flush_effects();
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;

    /// A write is visible through the read handle.
    #[test]
    fn test_signal_read_write() {
        let (reader, writer) = create_signal(0);
        assert_eq!(reader.get_untracked(), 0);

        writer.set(5);
        assert_eq!(reader.get_untracked(), 5);
    }

    /// `update` mutates the stored value in place.
    #[test]
    fn test_signal_update() {
        let (reader, writer) = create_signal(vec![1, 2]);
        writer.update(|items| items.push(3));

        let expected = vec![1, 2, 3];
        assert_eq!(reader.get_untracked(), expected);
    }

    /// `with` hands the closure a borrow of the value and returns its result.
    #[test]
    fn test_signal_with() {
        let (reader, _writer) = create_signal(String::from("hello"));
        assert_eq!(reader.with(String::len), 5);
    }
}