aura_effects/reactive/
graph.rs1use aura_core::effects::reactive::{ReactiveError, SignalId};
7use std::any::Any;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{broadcast, RwLock};
11
/// Type-erased signal value.
///
/// The payload lives behind an [`Arc`], so cloning an `AnyValue` shares the
/// same allocation rather than copying the underlying data.
#[derive(Clone)]
pub struct AnyValue(
    /// Shared payload; recover the concrete type with `downcast_ref::<T>()`.
    pub(crate) Arc<dyn Any + Send + Sync>,
);
19
20struct SignalSlot {
25 value: AnyValue,
27 sender: broadcast::Sender<AnyValue>,
29 type_name: &'static str,
31}
32
33impl SignalSlot {
34 fn new<T: Clone + Send + Sync + 'static>(initial: T) -> Self {
36 let (sender, _) = broadcast::channel(256); Self {
38 value: AnyValue(Arc::new(initial)),
39 sender,
40 type_name: std::any::type_name::<T>(),
41 }
42 }
43
44 fn read<T: Clone + Send + Sync + 'static>(&self) -> Result<T, ReactiveError> {
46 self.value
47 .0
48 .downcast_ref::<T>()
49 .cloned()
50 .ok_or_else(|| ReactiveError::TypeMismatch {
51 id: "unknown".to_string(),
52 expected: std::any::type_name::<T>().to_string(),
53 actual: self.type_name.to_string(),
54 })
55 }
56
57 fn emit<T: Clone + Send + Sync + 'static>(&mut self, value: T) -> Result<(), ReactiveError> {
59 if self.type_name != std::any::type_name::<T>() {
61 return Err(ReactiveError::TypeMismatch {
62 id: "unknown".to_string(),
63 expected: self.type_name.to_string(),
64 actual: std::any::type_name::<T>().to_string(),
65 });
66 }
67
68 let wrapped = AnyValue(Arc::new(value));
70 self.value = wrapped.clone();
71
72 let _ = self.sender.send(wrapped);
74
75 Ok(())
76 }
77
78 fn subscribe(&self) -> broadcast::Receiver<AnyValue> {
80 self.sender.subscribe()
81 }
82}
83
84pub struct SignalGraph {
96 signals: RwLock<HashMap<SignalId, SignalSlot>>,
98}
99
100impl SignalGraph {
101 pub fn new() -> Self {
103 Self {
104 signals: RwLock::new(HashMap::new()),
105 }
106 }
107
108 pub async fn register<T: Clone + Send + Sync + 'static>(
110 &self,
111 id: SignalId,
112 initial: T,
113 ) -> Result<(), ReactiveError> {
114 let mut signals = self.signals.write().await;
115
116 if signals.contains_key(&id) {
117 return Err(ReactiveError::Internal {
118 reason: format!("Signal '{id}' already registered"),
119 });
120 }
121
122 signals.insert(id, SignalSlot::new(initial));
123 Ok(())
124 }
125
126 pub async fn is_registered(&self, id: &SignalId) -> bool {
128 self.signals.read().await.contains_key(id)
129 }
130
131 pub async fn read<T: Clone + Send + Sync + 'static>(
133 &self,
134 id: &SignalId,
135 ) -> Result<T, ReactiveError> {
136 let signals = self.signals.read().await;
137
138 let slot = signals
139 .get(id)
140 .ok_or_else(|| ReactiveError::SignalNotFound { id: id.to_string() })?;
141
142 slot.read::<T>().map_err(|e| match e {
143 ReactiveError::TypeMismatch {
144 expected, actual, ..
145 } => ReactiveError::TypeMismatch {
146 id: id.to_string(),
147 expected,
148 actual,
149 },
150 other => other,
151 })
152 }
153
154 pub async fn emit<T: Clone + Send + Sync + 'static>(
156 &self,
157 id: &SignalId,
158 value: T,
159 ) -> Result<(), ReactiveError> {
160 let mut signals = self.signals.write().await;
161
162 let slot = signals
163 .get_mut(id)
164 .ok_or_else(|| ReactiveError::SignalNotFound { id: id.to_string() })?;
165
166 slot.emit(value).map_err(|e| match e {
167 ReactiveError::TypeMismatch {
168 expected, actual, ..
169 } => ReactiveError::TypeMismatch {
170 id: id.to_string(),
171 expected,
172 actual,
173 },
174 other => other,
175 })
176 }
177
178 pub async fn subscribe(
183 &self,
184 id: &SignalId,
185 ) -> Result<broadcast::Receiver<AnyValue>, ReactiveError> {
186 let signals = self.signals.read().await;
187
188 let slot = signals
189 .get(id)
190 .ok_or_else(|| ReactiveError::SignalNotFound { id: id.to_string() })?;
191
192 Ok(slot.subscribe())
193 }
194
195 pub async fn stats(&self) -> SignalGraphStats {
197 let signals = self.signals.read().await;
198 SignalGraphStats {
199 signal_count: signals.len(),
200 }
201 }
202}
203
204impl Default for SignalGraph {
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
/// Summary statistics for a signal graph.
///
/// Derives `PartialEq`, `Eq` and `Default` in addition to `Debug` and `Clone`
/// so that snapshots can be compared (for example with `assert_eq!`) and
/// default-constructed.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SignalGraphStats {
    /// Number of signals counted when the snapshot was taken.
    pub signal_count: usize,
}
216
217pub struct TypedSignalReceiver<T> {
225 receiver: broadcast::Receiver<AnyValue>,
226 signal_id: SignalId,
227 _phantom: std::marker::PhantomData<T>,
228}
229
230impl<T: Clone + Send + Sync + 'static> TypedSignalReceiver<T> {
231 pub fn new(receiver: broadcast::Receiver<AnyValue>, signal_id: SignalId) -> Self {
233 Self {
234 receiver,
235 signal_id,
236 _phantom: std::marker::PhantomData,
237 }
238 }
239
240 pub fn try_recv(&mut self) -> Option<T> {
242 loop {
243 match self.receiver.try_recv() {
244 Ok(any_value) => {
245 if let Some(value) = any_value.0.downcast_ref::<T>() {
246 return Some(value.clone());
247 }
248 continue;
250 }
251 Err(_) => return None,
252 }
253 }
254 }
255
256 pub async fn recv(&mut self) -> Result<T, ReactiveError> {
258 loop {
259 match self.receiver.recv().await {
260 Ok(any_value) => {
261 if let Some(value) = any_value.0.downcast_ref::<T>() {
262 return Ok(value.clone());
263 }
264 continue;
266 }
267 Err(broadcast::error::RecvError::Closed) => {
268 return Err(ReactiveError::SubscriptionClosed {
269 id: self.signal_id.to_string(),
270 });
271 }
272 Err(broadcast::error::RecvError::Lagged(_)) => {
273 continue;
275 }
276 }
277 }
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// A signal is reported as registered only after `register` succeeds.
    #[tokio::test]
    async fn test_signal_registration() {
        let signal = SignalId::new("test");
        let graph = SignalGraph::new();

        let before = graph.is_registered(&signal).await;
        assert!(!before);

        graph.register(signal.clone(), 42u32).await.unwrap();

        let after = graph.is_registered(&signal).await;
        assert!(after);
    }

    /// Reads return the initial value first and the emitted value afterwards.
    #[tokio::test]
    async fn test_signal_read_write() {
        let counter = SignalId::new("counter");
        let graph = SignalGraph::new();

        graph.register(counter.clone(), 0u32).await.unwrap();
        assert_eq!(graph.read::<u32>(&counter).await.unwrap(), 0);

        graph.emit(&counter, 42u32).await.unwrap();
        assert_eq!(graph.read::<u32>(&counter).await.unwrap(), 42);
    }

    /// Reading an unregistered signal yields `SignalNotFound`.
    #[tokio::test]
    async fn test_signal_not_found() {
        let missing = SignalId::new("nonexistent");
        let graph = SignalGraph::new();

        let outcome = graph.read::<u32>(&missing).await;
        assert!(matches!(
            outcome,
            Err(ReactiveError::SignalNotFound { .. })
        ));
    }

    /// Reading a signal with the wrong type yields `TypeMismatch`.
    #[tokio::test]
    async fn test_type_mismatch() {
        let typed = SignalId::new("typed");
        let graph = SignalGraph::new();

        graph.register(typed.clone(), 42u32).await.unwrap();

        let outcome = graph.read::<String>(&typed).await;
        assert!(matches!(outcome, Err(ReactiveError::TypeMismatch { .. })));
    }

    /// A subscriber receives a value emitted after it subscribed.
    #[tokio::test]
    async fn test_subscription() {
        let observable = SignalId::new("observable");
        let graph = Arc::new(SignalGraph::new());

        graph
            .register(observable.clone(), "initial".to_string())
            .await
            .unwrap();

        let raw = graph.subscribe(&observable).await.unwrap();
        let mut typed = TypedSignalReceiver::<String>::new(raw, observable.clone());

        // Emit from a separate task, after a short delay, while this task
        // waits on the receiver.
        let emitter_graph = graph.clone();
        let emitter_id = observable.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            emitter_graph
                .emit(&emitter_id, "updated".to_string())
                .await
                .unwrap();
        });

        let received = typed.recv().await.unwrap();
        assert_eq!(received, "updated");
    }
}