Skip to main content

aurora_db/reactive/
watcher.rs

1use super::{QueryUpdate, ReactiveQueryState};
2use crate::pubsub::ChangeListener;
3use crate::types::Document;
4use std::sync::Arc;
5use tokio::sync::mpsc;
6
7/// Watches a query and emits updates when results change
8pub struct QueryWatcher {
9    /// Receiver for query updates
10    receiver: mpsc::UnboundedReceiver<QueryUpdate>,
11    /// Collection being watched
12    collection: String,
13}
14
15impl QueryWatcher {
16    /// Create a new query watcher
17    ///
18    /// # Arguments
19    /// * `collection` - Collection to watch
20    /// * `listener` - Change listener for the collection
21    /// * `state` - Reactive query state
22    /// * `initial_results` - Initial query results to populate the state
23    pub fn new(
24        collection: impl Into<String>,
25        mut listener: ChangeListener,
26        state: Arc<ReactiveQueryState>,
27        initial_results: Vec<Document>,
28        debounce_duration: Option<std::time::Duration>,
29    ) -> Self {
30        let collection = collection.into();
31        let (sender, receiver) = mpsc::unbounded_channel();
32
33        // Populate initial results
34        let init_state = Arc::clone(&state);
35        let init_sender = sender.clone();
36        tokio::spawn(async move {
37            for doc in initial_results {
38                if let Some(update) = init_state.add_if_matches(doc).await {
39                    let _ = init_sender.send(update);
40                }
41            }
42        });
43
44        // Spawn background task to listen for changes
45        tokio::spawn(async move {
46            while let Ok(event) = listener.recv().await {
47                let update = match event.change_type {
48                    crate::pubsub::ChangeType::Insert => {
49                        if let Some(doc) = event.document {
50                            state.add_if_matches(doc).await
51                        } else {
52                            None
53                        }
54                    }
55                    crate::pubsub::ChangeType::Update => {
56                        if let Some(new_doc) = event.document {
57                            state.update(&event.id, new_doc).await
58                        } else {
59                            None
60                        }
61                    }
62                    crate::pubsub::ChangeType::Delete => state.remove(&event.id).await,
63                };
64
65                if let Some(u) = update
66                    && sender.send(u).is_err()
67                {
68                    // Receiver dropped, stop watching
69                    break;
70                }
71            }
72        });
73
74        // If debounce is requested, wrap the receiver in a throttling task
75        let final_receiver = if let Some(duration) = debounce_duration {
76            let (tx_throttled, rx_throttled) = mpsc::unbounded_channel();
77            let mut raw_rx = receiver;
78
79            tokio::spawn(async move {
80                use std::collections::HashMap;
81                use tokio::time::interval as tokio_interval;
82
83                let mut tick = tokio_interval(duration);
84                tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
85
86                let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
87
88                loop {
89                    tokio::select! {
90                        biased;
91                        maybe_update = raw_rx.recv() => {
92                            match maybe_update {
93                                Some(update) => {
94                                    pending.insert(update.id().to_string(), update);
95                                }
96                                None => break,
97                            }
98                        }
99                        _ = tick.tick() => {
100                            if !pending.is_empty() {
101                                for (_, update) in pending.drain() {
102                                    if tx_throttled.send(update).is_err() {
103                                        return;
104                                    }
105                                }
106                            }
107                        }
108                    }
109                }
110            });
111            rx_throttled
112        } else {
113            receiver
114        };
115
116        Self {
117            receiver: final_receiver,
118            collection,
119        }
120    }
121
122    /// Get the next query update
123    /// Returns None when the watcher is closed
124    pub async fn next(&mut self) -> Option<QueryUpdate> {
125        self.receiver.recv().await
126    }
127
128    /// Get the collection name being watched
129    pub fn collection(&self) -> &str {
130        &self.collection
131    }
132
133    /// Try to receive an update without blocking
134    pub fn try_next(&mut self) -> Option<QueryUpdate> {
135        self.receiver.try_recv().ok()
136    }
137
138    /// Convert to a throttled watcher for rate-limiting updates
139    ///
140    /// Events are buffered and emitted at most once per interval.
141    /// Deduplicates by document ID, keeping only the latest state.
142    pub fn throttled(self, interval: std::time::Duration) -> ThrottledQueryWatcher {
143        ThrottledQueryWatcher::new(self.receiver, self.collection, interval)
144    }
145}
146
147/// A throttled/debounced query watcher for rate-limiting reactive updates
148///
149/// Buffers incoming events and emits them at a fixed interval.
150/// Deduplicates by document ID, keeping only the latest state per ID.
151/// This prevents overwhelming the UI with high-frequency updates.
152pub struct ThrottledQueryWatcher {
153    receiver: mpsc::UnboundedReceiver<QueryUpdate>,
154    collection: String,
155}
156
157impl ThrottledQueryWatcher {
158    /// Create a new throttled watcher
159    pub fn new(
160        mut raw_receiver: mpsc::UnboundedReceiver<QueryUpdate>,
161        collection: impl Into<String>,
162        interval: std::time::Duration,
163    ) -> Self {
164        let collection = collection.into();
165        let (tx, rx) = mpsc::unbounded_channel();
166
167        tokio::spawn(async move {
168            use std::collections::HashMap;
169            use tokio::time::interval as tokio_interval;
170
171            let mut tick = tokio_interval(interval);
172            tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
173
174            // Buffer: doc_id -> latest update for that doc
175            let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
176
177            loop {
178                tokio::select! {
179                    biased;
180
181                    // Collect events as fast as they come
182                    maybe_update = raw_receiver.recv() => {
183                        match maybe_update {
184                            Some(update) => {
185                                // Dedupe by doc ID - keep latest state
186                                pending.insert(update.id().to_string(), update);
187                            }
188                            None => break, // Raw receiver closed
189                        }
190                    }
191
192                    // Every tick, flush the buffer
193                    _ = tick.tick() => {
194                        if !pending.is_empty() {
195                            for (_, update) in pending.drain() {
196                                if tx.send(update).is_err() {
197                                    return; // Receiver dropped
198                                }
199                            }
200                        }
201                    }
202                }
203            }
204        });
205
206        Self {
207            receiver: rx,
208            collection,
209        }
210    }
211
212    /// Get the next throttled update
213    pub async fn next(&mut self) -> Option<QueryUpdate> {
214        self.receiver.recv().await
215    }
216
217    /// Get the collection name
218    pub fn collection(&self) -> &str {
219        &self.collection
220    }
221
222    /// Try to receive without blocking
223    pub fn try_next(&mut self) -> Option<QueryUpdate> {
224        self.receiver.try_recv().ok()
225    }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pubsub::{ChangeEvent, PubSubSystem};
    use crate::types::Value;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Build a `Document` with the given ID from `(field, value)` pairs.
    fn make_doc(id: &str, fields: Vec<(&str, Value)>) -> Document {
        let data: HashMap<String, Value> = fields
            .into_iter()
            .map(|(field, value)| (field.to_string(), value))
            .collect();
        Document {
            id: id.to_string(),
            data,
        }
    }

    /// Reactive state whose filter accepts documents with `active == true`.
    fn active_users_state() -> Arc<ReactiveQueryState> {
        Arc::new(ReactiveQueryState::new(|doc: &Document| {
            doc.data.get("active") == Some(&Value::Bool(true))
        }))
    }

    /// A `Modified` update for `id` whose new version holds `value`.
    fn modified(id: &str, value: Value) -> QueryUpdate {
        QueryUpdate::Modified {
            old: make_doc(id, Vec::new()),
            new: make_doc(id, vec![("value", value)]),
        }
    }

    #[tokio::test]
    async fn test_query_watcher_insert() {
        let pubsub = PubSubSystem::new(100);
        let listener = pubsub.listen("users");
        let mut watcher =
            QueryWatcher::new("users", listener, active_users_state(), vec![], None);

        // An active user passes the filter...
        let alice = make_doc(
            "1",
            vec![
                ("active", Value::Bool(true)),
                ("name", Value::String("Alice".into())),
            ],
        );
        pubsub
            .publish(ChangeEvent::insert("users", "1", alice))
            .unwrap();

        // ...and surfaces as an `Added` update.
        let update = watcher.next().await.unwrap();
        assert!(matches!(update, QueryUpdate::Added(_)));
        assert_eq!(update.id(), "1");
    }

    #[tokio::test]
    async fn test_query_watcher_filter() {
        let pubsub = PubSubSystem::new(100);
        let listener = pubsub.listen("users");
        let mut watcher =
            QueryWatcher::new("users", listener, active_users_state(), vec![], None);

        // User "1" is inactive (filtered out); user "2" is active.
        for (id, active) in [("1", false), ("2", true)] {
            pubsub
                .publish(ChangeEvent::insert(
                    "users",
                    id,
                    make_doc(id, vec![("active", Value::Bool(active))]),
                ))
                .unwrap();
        }

        // Only the active user comes through.
        let update = watcher.next().await.unwrap();
        assert_eq!(update.id(), "2");
    }

    #[tokio::test]
    async fn test_debounced_watcher() {
        use tokio::sync::mpsc;

        // Simulated raw query updates feeding a 100ms throttle.
        let (tx, rx) = mpsc::unbounded_channel();
        let mut throttled = ThrottledQueryWatcher::new(rx, "test", Duration::from_millis(100));

        // Three rapid updates for the same document...
        tx.send(QueryUpdate::Added(make_doc(
            "doc1",
            vec![("value", Value::Int(1))],
        )))
        .unwrap();
        tx.send(modified("doc1", Value::Int(2))).unwrap();
        tx.send(modified("doc1", Value::Int(3))).unwrap();

        // ...collapse into a single update once the interval has elapsed.
        tokio::time::sleep(Duration::from_millis(150)).await;

        let latest = throttled
            .try_next()
            .expect("a deduplicated update should have been flushed");
        assert_eq!(latest.id(), "doc1");
    }

    #[tokio::test]
    async fn test_throttled_watcher_multiple_docs() {
        use tokio::sync::mpsc;

        let (tx, rx) = mpsc::unbounded_channel();
        let mut throttled = ThrottledQueryWatcher::new(rx, "test", Duration::from_millis(100));

        // Distinct document IDs, so nothing gets deduplicated.
        for i in 1..=3 {
            tx.send(QueryUpdate::Added(make_doc(
                &format!("doc{i}"),
                vec![("value", Value::Int(i))],
            )))
            .unwrap();
        }

        tokio::time::sleep(Duration::from_millis(150)).await;

        let received: Vec<String> = std::iter::from_fn(|| throttled.try_next())
            .map(|update| update.id().to_string())
            .collect();

        assert_eq!(received.len(), 3);
        for expected in ["doc1", "doc2", "doc3"] {
            assert!(received.iter().any(|id| id == expected));
        }
    }
}