1#![allow(clippy::disallowed_types)]
7
8use async_trait::async_trait;
9use aura_core::effects::reactive::{
10 ReactiveEffects, ReactiveError, Signal, SignalId, SignalStream,
11};
12use aura_core::query::{FactPredicate, Query};
13use std::collections::{HashMap, HashSet};
14use std::future::Future;
15use std::sync::{Arc, RwLock};
16use tokio::sync::broadcast;
17use tokio::sync::watch;
18#[cfg(not(target_arch = "wasm32"))]
19use tokio::task::JoinHandle;
20#[cfg(target_arch = "wasm32")]
21use wasm_bindgen_futures::spawn_local;
22
23use super::graph::SignalGraph;
24
/// Capacity of the per-subscription `broadcast` channel that
/// `ReactiveHandler::subscribe` creates for each returned stream.
const REACTIVE_SUBSCRIPTION_BUFFER_CAPACITY: usize = 256;
26
/// Handler that implements [`ReactiveEffects`] on top of a [`SignalGraph`].
///
/// Every field sits behind an `Arc`, so clones of a handler share the same
/// graph, registration set, query-dependency map and task registry.
pub struct ReactiveHandler {
    /// Signal graph that read/emit/register/subscribe calls are delegated to.
    graph: Arc<SignalGraph>,
    /// IDs of the signals recorded as registered; consulted by `is_registered`.
    /// May be supplied by the caller through `with_graph_and_registry`.
    registered_ids: Arc<RwLock<HashSet<SignalId>>>,
    /// Fact predicates each query-bound signal depends on, keyed by signal ID.
    query_deps: Arc<RwLock<HashMap<SignalId, Vec<FactPredicate>>>>,
    /// Registry of the background forwarding tasks spawned by `subscribe`.
    tasks: Arc<ReactiveTaskRegistry>,
}
49
50impl ReactiveHandler {
51 pub fn new() -> Self {
53 Self {
54 graph: Arc::new(SignalGraph::new()),
55 registered_ids: Arc::new(RwLock::new(HashSet::new())),
56 query_deps: Arc::new(RwLock::new(HashMap::new())),
57 tasks: Arc::new(ReactiveTaskRegistry::new()),
58 }
59 }
60
61 pub fn with_graph(graph: Arc<SignalGraph>) -> Self {
65 Self {
66 graph,
67 registered_ids: Arc::new(RwLock::new(HashSet::new())),
68 query_deps: Arc::new(RwLock::new(HashMap::new())),
69 tasks: Arc::new(ReactiveTaskRegistry::new()),
70 }
71 }
72
73 pub fn with_graph_and_registry(
75 graph: Arc<SignalGraph>,
76 registered_ids: Arc<RwLock<HashSet<SignalId>>>,
77 ) -> Self {
78 Self {
79 graph,
80 registered_ids,
81 query_deps: Arc::new(RwLock::new(HashMap::new())),
82 tasks: Arc::new(ReactiveTaskRegistry::new()),
83 }
84 }
85
86 pub fn graph(&self) -> &Arc<SignalGraph> {
88 &self.graph
89 }
90
91 pub async fn stats(&self) -> super::graph::SignalGraphStats {
93 self.graph.stats().await
94 }
95
96 fn signals_for_predicate(&self, predicate: &FactPredicate) -> Vec<SignalId> {
100 self.query_deps
101 .read()
102 .map(|deps| {
103 deps.iter()
104 .filter_map(|(signal_id, predicates)| {
105 if predicates.iter().any(|p| p.matches(predicate)) {
106 Some(signal_id.clone())
107 } else {
108 None
109 }
110 })
111 .collect()
112 })
113 .unwrap_or_default()
114 }
115}
116
117impl Default for ReactiveHandler {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl Clone for ReactiveHandler {
124 fn clone(&self) -> Self {
125 Self {
126 graph: self.graph.clone(),
127 registered_ids: self.registered_ids.clone(),
128 query_deps: self.query_deps.clone(),
129 tasks: self.tasks.clone(),
130 }
131 }
132}
133
134impl Drop for ReactiveHandler {
135 fn drop(&mut self) {
136 if Arc::strong_count(&self.tasks) == 1 {
137 self.tasks.shutdown();
138 }
139 }
140}
141
/// Tracks the background tasks spawned through `spawn_cancellable` so that
/// they can all be stopped together by `shutdown`.
#[derive(Debug)]
struct ReactiveTaskRegistry {
    /// Sender side of the shutdown flag; every spawned task holds a receiver
    /// and stops once it observes a change.
    shutdown_tx: watch::Sender<bool>,
    /// Join handles of the spawned tasks, aborted by `shutdown`.
    /// Only present on non-wasm targets.
    #[cfg(not(target_arch = "wasm32"))]
    handles: std::sync::Mutex<Vec<JoinHandle<()>>>,
}
152
153impl ReactiveTaskRegistry {
154 fn new() -> Self {
155 let (shutdown_tx, _shutdown_rx) = watch::channel(false);
156 Self {
157 shutdown_tx,
158 #[cfg(not(target_arch = "wasm32"))]
159 handles: std::sync::Mutex::new(Vec::new()),
160 }
161 }
162
163 fn spawn_cancellable<F>(&self, fut: F)
164 where
165 F: Future<Output = ()> + Send + 'static,
166 {
167 let mut shutdown_rx = self.shutdown_tx.subscribe();
168 #[cfg(target_arch = "wasm32")]
169 spawn_local(async move {
170 tokio::select! {
171 _ = shutdown_rx.changed() => {}
172 _ = fut => {}
173 }
174 });
175
176 #[cfg(not(target_arch = "wasm32"))]
177 let handle = tokio::spawn(async move {
178 tokio::select! {
179 _ = shutdown_rx.changed() => {}
180 _ = fut => {}
181 }
182 });
183 #[cfg(not(target_arch = "wasm32"))]
184 if let Ok(mut handles) = self.handles.lock() {
185 handles.push(handle);
186 }
187 }
188
189 fn shutdown(&self) {
190 let _ = self.shutdown_tx.send(true);
191 #[cfg(not(target_arch = "wasm32"))]
192 if let Ok(mut handles) = self.handles.lock() {
193 for handle in handles.drain(..) {
194 handle.abort();
195 }
196 }
197 }
198}
199
200#[async_trait]
205impl ReactiveEffects for ReactiveHandler {
206 async fn read<T>(&self, signal: &Signal<T>) -> Result<T, ReactiveError>
207 where
208 T: Clone + Send + Sync + 'static,
209 {
210 self.graph.read(signal.id()).await
211 }
212
213 async fn emit<T>(&self, signal: &Signal<T>, value: T) -> Result<(), ReactiveError>
214 where
215 T: Clone + Send + Sync + 'static,
216 {
217 self.graph.emit(signal.id(), value).await
218 }
219
220 fn subscribe<T>(&self, signal: &Signal<T>) -> Result<SignalStream<T>, ReactiveError>
221 where
222 T: Clone + Send + Sync + 'static,
223 {
224 if !self.is_registered(signal.id()) {
225 return Err(ReactiveError::SignalNotFound {
226 id: signal.id().to_string(),
227 });
228 }
229
230 let (tx, rx) = broadcast::channel::<T>(REACTIVE_SUBSCRIPTION_BUFFER_CAPACITY);
236
237 let graph = self.graph.clone();
239 let signal_id = signal.id().clone();
240
241 self.tasks.spawn_cancellable(async move {
242 match graph.subscribe(&signal_id).await {
243 Ok(mut receiver) => loop {
244 match receiver.recv().await {
245 Ok(any_value) => {
246 if let Some(value) = any_value.0.downcast_ref::<T>() {
247 if tx.send(value.clone()).is_err() {
248 break;
250 }
251 }
252 }
253 Err(broadcast::error::RecvError::Closed) => break,
254 Err(broadcast::error::RecvError::Lagged(skipped)) => {
255 tracing::warn!(
256 signal_id = %signal_id,
257 skipped,
258 "reactive subscription lagged; updates were dropped"
259 );
260 continue;
261 }
262 }
263 },
264 Err(error) => {
265 tracing::warn!(
266 signal_id = %signal_id,
267 error = %error,
268 "reactive subscription forwarding task exited before attaching"
269 );
270 }
271 }
272 });
273
274 Ok(SignalStream::new(rx, signal.id().clone()))
275 }
276
277 async fn register<T>(&self, signal: &Signal<T>, initial: T) -> Result<(), ReactiveError>
278 where
279 T: Clone + Send + Sync + 'static,
280 {
281 self.graph.register(signal.id().clone(), initial).await?;
283
284 if let Ok(mut ids) = self.registered_ids.write() {
286 ids.insert(signal.id().clone());
287 }
288
289 Ok(())
290 }
291
292 fn is_registered(&self, signal_id: &SignalId) -> bool {
293 self.registered_ids
296 .read()
297 .map(|ids| ids.contains(signal_id))
298 .unwrap_or(false)
299 }
300
301 async fn register_query<Q: Query>(
302 &self,
303 signal: &Signal<Q::Result>,
304 query: Q,
305 ) -> Result<(), ReactiveError> {
306 let deps = query.dependencies();
308
309 let initial: Q::Result = Default::default();
313 self.register(signal, initial).await?;
314
315 if let Ok(mut deps_map) = self.query_deps.write() {
317 deps_map.insert(signal.id().clone(), deps);
318 }
319
320 Ok(())
321 }
322
323 fn query_dependencies(&self, signal_id: &SignalId) -> Option<Vec<FactPredicate>> {
324 self.query_deps
325 .read()
326 .ok()
327 .and_then(|deps| deps.get(signal_id).cloned())
328 }
329
330 async fn invalidate_queries(&self, changed: &FactPredicate) {
331 let affected_signals = self.signals_for_predicate(changed);
333
334 for signal_id in affected_signals {
338 tracing::debug!(
339 signal_id = %signal_id,
340 predicate = ?changed,
341 "Signal invalidated due to fact change"
342 );
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built handler reports zero signals.
    #[tokio::test]
    async fn test_handler_creation() {
        let fresh = ReactiveHandler::new();
        let graph_stats = fresh.stats().await;
        assert_eq!(graph_stats.signal_count, 0);
    }

    /// The initial value given to `register` is what `read` returns.
    #[tokio::test]
    async fn test_handler_register_and_read() {
        let reactive = ReactiveHandler::new();
        let counter: Signal<u32> = Signal::new("counter");

        reactive.register(&counter, 42).await.unwrap();

        let current = reactive.read(&counter).await.unwrap();
        assert_eq!(current, 42);
    }

    /// `emit` replaces the value observed by a later `read`.
    #[tokio::test]
    async fn test_handler_emit() {
        let reactive = ReactiveHandler::new();
        let message: Signal<String> = Signal::new("message");

        reactive
            .register(&message, "hello".to_string())
            .await
            .unwrap();

        reactive.emit(&message, "world".to_string()).await.unwrap();

        let current = reactive.read(&message).await.unwrap();
        assert_eq!(current, "world");
    }

    /// Two handlers built over the same graph see each other's values.
    #[tokio::test]
    async fn test_shared_graph() {
        let shared_graph = Arc::new(SignalGraph::new());
        let first = ReactiveHandler::with_graph(shared_graph.clone());
        let second = ReactiveHandler::with_graph(shared_graph);

        let shared: Signal<i32> = Signal::new("shared");

        // Registered through the first handler ...
        first.register(&shared, 100).await.unwrap();

        // ... and readable through the second.
        let seen_by_second: i32 = second.read(&shared).await.unwrap();
        assert_eq!(seen_by_second, 100);

        // Emitted through the second handler ...
        second.emit(&shared, 200).await.unwrap();

        // ... and visible through the first.
        let seen_by_first: i32 = first.read(&shared).await.unwrap();
        assert_eq!(seen_by_first, 200);
    }

    /// `is_registered` flips from false to true across `register`.
    #[tokio::test]
    async fn test_is_registered() {
        let reactive = ReactiveHandler::new();
        let flag: Signal<bool> = Signal::new("flag");

        assert!(!reactive.is_registered(flag.id()));

        reactive.register(&flag, true).await.unwrap();

        assert!(reactive.is_registered(flag.id()));
    }

    /// Empty strings are valid both as initial and as emitted values.
    #[tokio::test]
    async fn test_empty_string_signal() {
        let reactive = ReactiveHandler::new();
        let empty: Signal<String> = Signal::new("empty");

        reactive.register(&empty, String::new()).await.unwrap();

        let after_register = reactive.read(&empty).await.unwrap();
        assert_eq!(after_register, "");

        reactive.emit(&empty, String::new()).await.unwrap();
        let after_emit = reactive.read(&empty).await.unwrap();
        assert_eq!(after_emit, "");
    }

    /// Zero is stored and read back like any other value.
    #[tokio::test]
    async fn test_zero_value_signal() {
        let reactive = ReactiveHandler::new();
        let zero: Signal<i64> = Signal::new("zero");

        reactive.register(&zero, 0).await.unwrap();

        let current = reactive.read(&zero).await.unwrap();
        assert_eq!(current, 0);
    }

    /// After many consecutive emits the last emitted value is the one read.
    #[tokio::test]
    async fn test_rapid_updates() {
        let reactive = ReactiveHandler::new();
        let counter: Signal<u32> = Signal::new("counter");

        reactive.register(&counter, 0).await.unwrap();

        for step in 1..=100 {
            reactive.emit(&counter, step).await.unwrap();
        }

        let last = reactive.read(&counter).await.unwrap();
        assert_eq!(last, 100);
    }

    /// Reading a signal that was never registered yields `SignalNotFound`.
    #[tokio::test]
    async fn test_read_unregistered_signal() {
        let reactive = ReactiveHandler::new();
        let missing: Signal<u32> = Signal::new("never_registered");

        let outcome = reactive.read(&missing).await;
        assert!(matches!(outcome, Err(ReactiveError::SignalNotFound { .. })));
    }

    /// Emitting to a signal that was never registered yields `SignalNotFound`.
    #[tokio::test]
    async fn test_emit_unregistered_signal() {
        let reactive = ReactiveHandler::new();
        let missing: Signal<u32> = Signal::new("never_registered");

        let outcome = reactive.emit(&missing, 42).await;
        assert!(matches!(outcome, Err(ReactiveError::SignalNotFound { .. })));
    }

    /// Subscribing to an unregistered signal fails synchronously.
    #[tokio::test]
    async fn test_subscribe_unregistered_signal_fails_fast() {
        let reactive = ReactiveHandler::new();
        let missing: Signal<u32> = Signal::new("never_registered");

        let outcome = reactive.subscribe(&missing);
        assert!(matches!(outcome, Err(ReactiveError::SignalNotFound { .. })));
    }

    /// Emitting more values than the subscription buffer holds still lets the
    /// subscriber receive a newer value, and `read` returns the final one.
    #[tokio::test]
    async fn test_subscription_lag_returns_newer_snapshot_after_drops() {
        let reactive = ReactiveHandler::new();
        let lagged: Signal<u32> = Signal::new("lagged");

        reactive.register(&lagged, 0).await.unwrap();
        let mut updates = reactive.subscribe(&lagged).unwrap();

        for emitted in 1..=(REACTIVE_SUBSCRIPTION_BUFFER_CAPACITY as u32 + 32) {
            reactive.emit(&lagged, emitted).await.unwrap();
        }

        let first_seen = updates.recv().await.unwrap();
        assert!(first_seen > 1);
        assert_eq!(
            reactive.read(&lagged).await.unwrap(),
            REACTIVE_SUBSCRIPTION_BUFFER_CAPACITY as u32 + 32
        );
    }

    /// A second `register` of the same signal fails and keeps the first value.
    #[tokio::test]
    async fn test_duplicate_registration() {
        let reactive = ReactiveHandler::new();
        let duplicate: Signal<u32> = Signal::new("duplicate");

        // First registration succeeds.
        reactive.register(&duplicate, 1).await.unwrap();

        // Second registration is rejected.
        let second_attempt = reactive.register(&duplicate, 2).await;
        assert!(matches!(
            second_attempt,
            Err(ReactiveError::Internal { .. })
        ));

        // The originally registered value is untouched.
        let current = reactive.read(&duplicate).await.unwrap();
        assert_eq!(current, 1);
    }

    /// A cloned handler observes and mutates the same state as its source.
    #[tokio::test]
    async fn test_clone_handler_shares_state() {
        let original = ReactiveHandler::new();
        let cloned_signal: Signal<u32> = Signal::new("cloned");

        original.register(&cloned_signal, 10).await.unwrap();

        let copy = original.clone();

        // Both handlers agree on the registered value.
        let from_original: u32 = original.read(&cloned_signal).await.unwrap();
        let from_copy: u32 = copy.read(&cloned_signal).await.unwrap();
        assert_eq!(from_original, from_copy);

        // An emit through the clone ...
        copy.emit(&cloned_signal, 20).await.unwrap();

        // ... is visible through both handlers.
        let from_original: u32 = original.read(&cloned_signal).await.unwrap();
        let from_copy: u32 = copy.read(&cloned_signal).await.unwrap();
        assert_eq!(from_original, 20);
        assert_eq!(from_copy, 20);
    }

    /// Signals can carry user-defined struct values.
    #[tokio::test]
    async fn test_complex_type_signal() {
        #[derive(Clone, Debug, PartialEq)]
        struct ComplexState {
            count: u32,
            label: String,
            values: Vec<i32>,
        }

        let reactive = ReactiveHandler::new();
        let complex: Signal<ComplexState> = Signal::new("complex");

        let before = ComplexState {
            count: 0,
            label: "initial".to_string(),
            values: vec![1, 2, 3],
        };

        reactive.register(&complex, before.clone()).await.unwrap();

        let stored = reactive.read(&complex).await.unwrap();
        assert_eq!(stored, before);

        let after = ComplexState {
            count: 42,
            label: "updated".to_string(),
            values: vec![4, 5, 6, 7],
        };

        reactive.emit(&complex, after.clone()).await.unwrap();

        let stored = reactive.read(&complex).await.unwrap();
        assert_eq!(stored, after);
    }

    /// `Option` values round-trip through `None`, `Some` and back to `None`.
    #[tokio::test]
    async fn test_option_type_signal() {
        let reactive = ReactiveHandler::new();
        let optional: Signal<Option<String>> = Signal::new("optional");

        reactive.register(&optional, None).await.unwrap();

        let current = reactive.read(&optional).await.unwrap();
        assert_eq!(current, None);

        reactive
            .emit(&optional, Some("value".to_string()))
            .await
            .unwrap();

        let current = reactive.read(&optional).await.unwrap();
        assert_eq!(current, Some("value".to_string()));

        reactive.emit(&optional, None).await.unwrap();

        let current = reactive.read(&optional).await.unwrap();
        assert_eq!(current, None);
    }
}