1use std::cell::RefCell;
2use std::rc::Rc;
3
4use crate::effect::flush_effects;
5use crate::runtime::RUNTIME;
6
/// Identifier under which an effect is recorded in a signal's subscriber list.
type SubscriberId = usize;
8
9struct SignalInner<T> {
10 value: T,
11 subscribers: Vec<SubscriberId>,
12 version: u64,
13}
14
15pub struct ReadSignal<T> {
16 inner: Rc<RefCell<SignalInner<T>>>,
17}
18
19impl<T> Clone for ReadSignal<T> {
20 fn clone(&self) -> Self {
21 Self {
22 inner: self.inner.clone(),
23 }
24 }
25}
26
27pub struct WriteSignal<T> {
28 inner: Rc<RefCell<SignalInner<T>>>,
29}
30
31impl<T> Clone for WriteSignal<T> {
32 fn clone(&self) -> Self {
33 Self {
34 inner: self.inner.clone(),
35 }
36 }
37}
38
39pub fn create_signal<T>(value: T) -> (ReadSignal<T>, WriteSignal<T>) {
41 let inner = Rc::new(RefCell::new(SignalInner {
42 value,
43 subscribers: Vec::new(),
44 version: 0,
45 }));
46
47 (
48 ReadSignal {
49 inner: inner.clone(),
50 },
51 WriteSignal { inner },
52 )
53}
54
55impl<T: Clone> ReadSignal<T> {
56 pub fn get(&self) -> T {
58 self.track();
59 self.inner.borrow().value.clone()
60 }
61
62 pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
63 self.track();
64 f(&self.inner.borrow().value)
65 }
66
67 pub fn get_untracked(&self) -> T {
68 self.inner.borrow().value.clone()
69 }
70
71 fn track(&self) {
72 RUNTIME.with(|rt| {
73 if let Some(effect_id) = rt.borrow().current_effect() {
74 let mut inner = self.inner.borrow_mut();
75 if !inner.subscribers.contains(&effect_id) {
76 inner.subscribers.push(effect_id);
77 }
78 }
79 });
80 }
81}
82
83impl<T> WriteSignal<T> {
84 pub fn set(&self, value: T) {
86 {
87 let mut inner = self.inner.borrow_mut();
88 inner.value = value;
89 inner.version += 1;
90 }
91 self.notify_subscribers();
92 }
93
94 pub fn update(&self, f: impl FnOnce(&mut T)) {
95 {
96 let mut inner = self.inner.borrow_mut();
97 f(&mut inner.value);
98 inner.version += 1;
99 }
100 self.notify_subscribers();
101 }
102
103 fn notify_subscribers(&self) {
104 let inner = self.inner.borrow();
105 let should_flush = RUNTIME.with(|rt| {
106 let mut rt = rt.borrow_mut();
107 for &subscriber_id in &inner.subscribers {
108 rt.schedule_effect(subscriber_id);
109 }
110 !rt.is_batching()
111 });
112 drop(inner);
113
114 if should_flush {
115 flush_effects();
116 }
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    /// A value written through the write handle is visible through the read
    /// handle.
    #[test]
    fn test_signal_read_write() {
        let (reader, writer) = create_signal(0);
        let before = reader.get_untracked();
        writer.set(5);
        let after = reader.get_untracked();
        assert_eq!((before, after), (0, 5));
    }

    /// `update` mutates the stored value in place.
    #[test]
    fn test_signal_update() {
        let (reader, writer) = create_signal(vec![1, 2]);
        writer.update(|items| items.push(3));
        assert_eq!(reader.get_untracked(), [1, 2, 3]);
    }

    /// `with` hands the closure a reference to the stored value and returns
    /// the closure's result.
    #[test]
    fn test_signal_with() {
        let (reader, _writer) = create_signal(String::from("hello"));
        assert_eq!(reader.with(String::len), 5);
    }
}