Skip to main content

aura_effects/reactive/
graph.rs

1//! Signal Graph - Reactive State Management
2//!
3//! The signal graph manages signal storage, dependency tracking, and change propagation.
4//! It provides the foundation for the reactive effect system.
5
6use aura_core::effects::reactive::{ReactiveError, SignalId};
7use std::any::Any;
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::{broadcast, RwLock};
11
12// ─────────────────────────────────────────────────────────────────────────────
13// Signal Storage
14// ─────────────────────────────────────────────────────────────────────────────
15
/// Type-erased, cheaply clonable value used for signal storage and broadcast payloads.
///
/// Cloning copies only the inner `Arc` pointer (a reference-count bump), never the
/// underlying value. Recover the concrete value with `value.0.downcast_ref::<T>()`.
#[derive(Clone, Debug)]
pub struct AnyValue(pub(crate) Arc<dyn Any + Send + Sync>);
19
20/// Type-erased signal value storage.
21///
22/// This allows storing values of any type in the graph while maintaining
23/// type safety through the Signal<T> phantom type at the API level.
24struct SignalSlot {
25    /// The current value (type-erased)
26    value: AnyValue,
27    /// Broadcast channel for notifying subscribers
28    sender: broadcast::Sender<AnyValue>,
29    /// Type name for debugging
30    type_name: &'static str,
31}
32
33impl SignalSlot {
34    /// Create a new signal slot with an initial value.
35    fn new<T: Clone + Send + Sync + 'static>(initial: T) -> Self {
36        let (sender, _) = broadcast::channel(256); // Buffer size for updates
37        Self {
38            value: AnyValue(Arc::new(initial)),
39            sender,
40            type_name: std::any::type_name::<T>(),
41        }
42    }
43
44    /// Read the current value.
45    fn read<T: Clone + Send + Sync + 'static>(&self) -> Result<T, ReactiveError> {
46        self.value
47            .0
48            .downcast_ref::<T>()
49            .cloned()
50            .ok_or_else(|| ReactiveError::TypeMismatch {
51                id: "unknown".to_string(),
52                expected: std::any::type_name::<T>().to_string(),
53                actual: self.type_name.to_string(),
54            })
55    }
56
57    /// Update the value and notify subscribers.
58    fn emit<T: Clone + Send + Sync + 'static>(&mut self, value: T) -> Result<(), ReactiveError> {
59        // Verify type matches
60        if self.type_name != std::any::type_name::<T>() {
61            return Err(ReactiveError::TypeMismatch {
62                id: "unknown".to_string(),
63                expected: self.type_name.to_string(),
64                actual: std::any::type_name::<T>().to_string(),
65            });
66        }
67
68        // Update value
69        let wrapped = AnyValue(Arc::new(value));
70        self.value = wrapped.clone();
71
72        // Notify subscribers (ignore send errors - means no subscribers)
73        let _ = self.sender.send(wrapped);
74
75        Ok(())
76    }
77
78    /// Subscribe to changes.
79    fn subscribe(&self) -> broadcast::Receiver<AnyValue> {
80        self.sender.subscribe()
81    }
82}
83
84// ─────────────────────────────────────────────────────────────────────────────
85// Signal Graph
86// ─────────────────────────────────────────────────────────────────────────────
87
88/// The signal graph manages reactive state.
89///
90/// It provides:
91/// - Signal registration and storage
92/// - Type-safe read/emit operations
93/// - Subscription management
94/// - (Future) Derived signal computation and dependency tracking
95pub struct SignalGraph {
96    /// Signal storage, keyed by SignalId
97    signals: RwLock<HashMap<SignalId, SignalSlot>>,
98}
99
100impl SignalGraph {
101    /// Create a new empty signal graph.
102    pub fn new() -> Self {
103        Self {
104            signals: RwLock::new(HashMap::new()),
105        }
106    }
107
108    /// Register a signal with an initial value.
109    pub async fn register<T: Clone + Send + Sync + 'static>(
110        &self,
111        id: SignalId,
112        initial: T,
113    ) -> Result<(), ReactiveError> {
114        let mut signals = self.signals.write().await;
115
116        if signals.contains_key(&id) {
117            return Err(ReactiveError::Internal {
118                reason: format!("Signal '{id}' already registered"),
119            });
120        }
121
122        signals.insert(id, SignalSlot::new(initial));
123        Ok(())
124    }
125
126    /// Check if a signal is registered.
127    pub async fn is_registered(&self, id: &SignalId) -> bool {
128        self.signals.read().await.contains_key(id)
129    }
130
131    /// Read the current value of a signal.
132    pub async fn read<T: Clone + Send + Sync + 'static>(
133        &self,
134        id: &SignalId,
135    ) -> Result<T, ReactiveError> {
136        let signals = self.signals.read().await;
137
138        let slot = signals
139            .get(id)
140            .ok_or_else(|| ReactiveError::SignalNotFound { id: id.to_string() })?;
141
142        slot.read::<T>().map_err(|e| match e {
143            ReactiveError::TypeMismatch {
144                expected, actual, ..
145            } => ReactiveError::TypeMismatch {
146                id: id.to_string(),
147                expected,
148                actual,
149            },
150            other => other,
151        })
152    }
153
154    /// Emit a new value to a signal.
155    pub async fn emit<T: Clone + Send + Sync + 'static>(
156        &self,
157        id: &SignalId,
158        value: T,
159    ) -> Result<(), ReactiveError> {
160        let mut signals = self.signals.write().await;
161
162        let slot = signals
163            .get_mut(id)
164            .ok_or_else(|| ReactiveError::SignalNotFound { id: id.to_string() })?;
165
166        slot.emit(value).map_err(|e| match e {
167            ReactiveError::TypeMismatch {
168                expected, actual, ..
169            } => ReactiveError::TypeMismatch {
170                id: id.to_string(),
171                expected,
172                actual,
173            },
174            other => other,
175        })
176    }
177
178    /// Subscribe to a signal's changes.
179    ///
180    /// Returns a broadcast receiver that yields type-erased values.
181    /// The caller is responsible for downcasting.
182    pub async fn subscribe(
183        &self,
184        id: &SignalId,
185    ) -> Result<broadcast::Receiver<AnyValue>, ReactiveError> {
186        let signals = self.signals.read().await;
187
188        let slot = signals
189            .get(id)
190            .ok_or_else(|| ReactiveError::SignalNotFound { id: id.to_string() })?;
191
192        Ok(slot.subscribe())
193    }
194
195    /// Get statistics about the signal graph.
196    pub async fn stats(&self) -> SignalGraphStats {
197        let signals = self.signals.read().await;
198        SignalGraphStats {
199            signal_count: signals.len(),
200        }
201    }
202}
203
204impl Default for SignalGraph {
205    fn default() -> Self {
206        Self::new()
207    }
208}
209
/// Snapshot of statistics about a signal graph, as returned by `SignalGraph::stats`.
///
/// `Default` yields the statistics of an empty graph (zero signals). Equality is
/// provided so snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct SignalGraphStats {
    /// Number of registered signals
    pub signal_count: usize,
}
216
217// ─────────────────────────────────────────────────────────────────────────────
218// Typed Signal Receiver
219// ─────────────────────────────────────────────────────────────────────────────
220
221/// A typed receiver for signal updates.
222///
223/// Wraps a broadcast receiver and provides type-safe access to values.
224pub struct TypedSignalReceiver<T> {
225    receiver: broadcast::Receiver<AnyValue>,
226    signal_id: SignalId,
227    _phantom: std::marker::PhantomData<T>,
228}
229
230impl<T: Clone + Send + Sync + 'static> TypedSignalReceiver<T> {
231    /// Create a new typed receiver.
232    pub fn new(receiver: broadcast::Receiver<AnyValue>, signal_id: SignalId) -> Self {
233        Self {
234            receiver,
235            signal_id,
236            _phantom: std::marker::PhantomData,
237        }
238    }
239
240    /// Try to receive the next value without blocking.
241    pub fn try_recv(&mut self) -> Option<T> {
242        loop {
243            match self.receiver.try_recv() {
244                Ok(any_value) => {
245                    if let Some(value) = any_value.0.downcast_ref::<T>() {
246                        return Some(value.clone());
247                    }
248                    // Type mismatch, skip this value
249                    continue;
250                }
251                Err(_) => return None,
252            }
253        }
254    }
255
256    /// Receive the next value, waiting if necessary.
257    pub async fn recv(&mut self) -> Result<T, ReactiveError> {
258        loop {
259            match self.receiver.recv().await {
260                Ok(any_value) => {
261                    if let Some(value) = any_value.0.downcast_ref::<T>() {
262                        return Ok(value.clone());
263                    }
264                    // Type mismatch, skip this value
265                    continue;
266                }
267                Err(broadcast::error::RecvError::Closed) => {
268                    return Err(ReactiveError::SubscriptionClosed {
269                        id: self.signal_id.to_string(),
270                    });
271                }
272                Err(broadcast::error::RecvError::Lagged(_)) => {
273                    // Missed some values, continue receiving
274                    continue;
275                }
276            }
277        }
278    }
279}
280
281// ─────────────────────────────────────────────────────────────────────────────
282// Tests
283// ─────────────────────────────────────────────────────────────────────────────
284
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long a test waits for a broadcast update, so a lost
    /// update fails the test instead of hanging it.
    const RECV_TIMEOUT: Duration = Duration::from_secs(5);

    #[tokio::test]
    async fn test_signal_registration() {
        let graph = SignalGraph::new();
        let id = SignalId::new("test");

        assert!(!graph.is_registered(&id).await);

        graph.register(id.clone(), 42u32).await.unwrap();

        assert!(graph.is_registered(&id).await);
    }

    #[tokio::test]
    async fn test_duplicate_registration_rejected() {
        let graph = SignalGraph::new();
        let id = SignalId::new("dup");

        graph.register(id.clone(), 1u32).await.unwrap();
        let result = graph.register(id.clone(), 2u32).await;
        assert!(matches!(result, Err(ReactiveError::Internal { .. })));

        // The rejected registration must not clobber the existing signal.
        let value: u32 = graph.read(&id).await.unwrap();
        assert_eq!(value, 1);
        assert_eq!(graph.stats().await.signal_count, 1);
    }

    #[tokio::test]
    async fn test_signal_read_write() {
        let graph = SignalGraph::new();
        let id = SignalId::new("counter");

        graph.register(id.clone(), 0u32).await.unwrap();

        // Read initial value
        let value: u32 = graph.read(&id).await.unwrap();
        assert_eq!(value, 0);

        // Emit new value
        graph.emit(&id, 42u32).await.unwrap();

        // Read updated value
        let value: u32 = graph.read(&id).await.unwrap();
        assert_eq!(value, 42);
    }

    #[tokio::test]
    async fn test_signal_not_found() {
        let graph = SignalGraph::new();
        let id = SignalId::new("nonexistent");

        let result: Result<u32, _> = graph.read(&id).await;
        assert!(matches!(result, Err(ReactiveError::SignalNotFound { .. })));

        let result = graph.emit(&id, 1u32).await;
        assert!(matches!(result, Err(ReactiveError::SignalNotFound { .. })));

        assert!(matches!(
            graph.subscribe(&id).await,
            Err(ReactiveError::SignalNotFound { .. })
        ));
    }

    #[tokio::test]
    async fn test_type_mismatch() {
        let graph = SignalGraph::new();
        let id = SignalId::new("typed");

        graph.register(id.clone(), 42u32).await.unwrap();

        // Try to read as wrong type
        let result: Result<String, _> = graph.read(&id).await;
        assert!(matches!(result, Err(ReactiveError::TypeMismatch { .. })));
    }

    #[tokio::test]
    async fn test_emit_type_mismatch() {
        let graph = SignalGraph::new();
        let id = SignalId::new("typed_emit");

        graph.register(id.clone(), 7u32).await.unwrap();

        let result = graph.emit(&id, "oops".to_string()).await;
        assert!(matches!(result, Err(ReactiveError::TypeMismatch { .. })));

        // A rejected emit leaves the stored value intact.
        let value: u32 = graph.read(&id).await.unwrap();
        assert_eq!(value, 7);
    }

    #[tokio::test]
    async fn test_stats_counts_signals() {
        let graph = SignalGraph::new();
        assert_eq!(graph.stats().await.signal_count, 0);

        graph.register(SignalId::new("a"), 1u8).await.unwrap();
        graph.register(SignalId::new("b"), "x".to_string()).await.unwrap();

        assert_eq!(graph.stats().await.signal_count, 2);
    }

    #[tokio::test]
    async fn test_subscription() {
        let graph = Arc::new(SignalGraph::new());
        let id = SignalId::new("observable");

        graph
            .register(id.clone(), "initial".to_string())
            .await
            .unwrap();

        // Subscribe before emitting so the update cannot be missed.
        let receiver = graph.subscribe(&id).await.unwrap();
        let mut typed_receiver = TypedSignalReceiver::<String>::new(receiver, id.clone());

        // Emit in background
        let graph_clone = graph.clone();
        let id_clone = id.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            graph_clone
                .emit(&id_clone, "updated".to_string())
                .await
                .unwrap();
        });

        // Receive update, failing rather than hanging if it never arrives.
        let value = tokio::time::timeout(RECV_TIMEOUT, typed_receiver.recv())
            .await
            .expect("timed out waiting for signal update")
            .unwrap();
        assert_eq!(value, "updated");
    }
}