Skip to main content

aurora_db/reactive/
watcher.rs

1use super::{QueryUpdate, ReactiveQueryState};
2use crate::Aurora;
3use crate::pubsub::ChangeListener;
4use crate::types::Document;
5use std::sync::Arc;
6use tokio::sync::mpsc;
7
8/// Watches a query and emits updates when results change
9pub struct QueryWatcher {
10    /// Receiver for query updates
11    receiver: mpsc::UnboundedReceiver<QueryUpdate>,
12    /// Collection being watched
13    collection: String,
14    /// Reference to the database for resyncing
15    #[allow(dead_code)]
16    db: Arc<Aurora>,
17}
18
19impl QueryWatcher {
20    /// Create a new query watcher
21    pub fn new(
22        db: Arc<Aurora>,
23        collection: impl Into<String>,
24        mut listener: ChangeListener,
25        state: Arc<ReactiveQueryState>,
26        initial_results: Vec<Document>,
27        debounce_duration: Option<std::time::Duration>,
28    ) -> Self {
29        let collection = collection.into();
30        let (sender, receiver) = mpsc::unbounded_channel();
31
32        // Populate initial results
33        let init_state = Arc::clone(&state);
34        let init_sender = sender.clone();
35        tokio::spawn(async move {
36            for doc in initial_results {
37                if let Some(update) = init_state.add_if_matches(doc).await {
38                    let _ = init_sender.send(update);
39                }
40            }
41        });
42
43        // Spawn background task to listen for changes
44        let db_clone = Arc::clone(&db);
45        let coll_clone = collection.clone();
46        let state_clone = Arc::clone(&state);
47        let sender_clone = sender.clone();
48
49        tokio::spawn(async move {
50            let mut backoff_ms = 100; // Initial backoff
51
52            loop {
53                match listener.recv().await {
54                    Ok(event) => {
55                        // Successful receive, reset backoff gradually
56                        if backoff_ms > 100 { backoff_ms -= 50; }
57
58                        let update = match event.change_type {
59                            crate::pubsub::ChangeType::Insert => {
60                                if let Some(doc) = event.document {
61                                    state_clone.add_if_matches(doc).await
62                                } else {
63                                    None
64                                }
65                            }
66                            crate::pubsub::ChangeType::Update => {
67                                if let Some(new_doc) = event.document {
68                                    state_clone.update(&event.id, new_doc).await
69                                } else {
70                                    None
71                                }
72                            }
73                            crate::pubsub::ChangeType::Delete => state_clone.remove(&event.id).await,
74                        };
75
76                        if let Some(u) = update && sender_clone.send(u).is_err() {
77                            break;
78                        }
79                    }
80                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
81                        eprintln!(
82                            "WARNING: Watcher for '{}' lagged by {} events. Applying {}ms backoff...",
83                            coll_clone, skipped, backoff_ms
84                        );
85
86                        // 1. DRAIN: Empty everything currently in the channel (Ok and Lagged)
87                        loop {
88                            match listener.try_recv() {
89                                Ok(_) | Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => continue,
90                                _ => break,
91                            }
92                        }
93
94                        // 2. BACKOFF: Wait for the event storm to subside
95                        tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
96                        
97                        // Increase backoff for next time (max 2 seconds)
98                        backoff_ms = (backoff_ms * 2).min(2000);
99
100                        // 3. SNAPSHOT: Collect documents into a Vec inside the block to keep iterator usage local (not across await)
101                        let docs_snapshot: Vec<Document> = if let Ok(iter) = db_clone.stream_collection(&coll_clone) {
102                            iter.collect()
103                        } else {
104                            Vec::new()
105                        };
106
107                        // 4. SYNC: Synchronize ReactiveQueryState and emit deltas
108                        let updates = state_clone.sync_state(docs_snapshot).await;
109                        for u in updates {
110                            if sender_clone.send(u).is_err() {
111                                return;
112                            }
113                        }
114                    }
115                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
116                }
117            }
118        });
119
120        // Throttling logic (unchanged)
121        let final_receiver = if let Some(duration) = debounce_duration {
122            let (tx_throttled, rx_throttled) = mpsc::unbounded_channel();
123            let mut raw_rx = receiver;
124            tokio::spawn(async move {
125                use std::collections::HashMap;
126                let mut tick = tokio::time::interval(duration);
127                tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
128                let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
129                loop {
130                    tokio::select! {
131                        biased;
132                        maybe_update = raw_rx.recv() => {
133                            match maybe_update {
134                                Some(update) => { pending.insert(update.id().to_string(), update); }
135                                None => break,
136                            }
137                        }
138                        _ = tick.tick() => {
139                            if !pending.is_empty() {
140                                for (_, update) in pending.drain() {
141                                    if tx_throttled.send(update).is_err() { return; }
142                                }
143                            }
144                        }
145                    }
146                }
147            });
148            rx_throttled
149        } else {
150            receiver
151        };
152
153        Self {
154            receiver: final_receiver,
155            collection,
156            db,
157        }
158    }
159
160    pub async fn next(&mut self) -> Option<QueryUpdate> { self.receiver.recv().await }
161    pub fn collection(&self) -> &str { &self.collection }
162    pub fn try_next(&mut self) -> Option<QueryUpdate> { self.receiver.try_recv().ok() }
163    pub fn throttled(self, interval: std::time::Duration) -> ThrottledQueryWatcher {
164        ThrottledQueryWatcher::new(self.receiver, self.collection, interval)
165    }
166}
167
168pub struct ThrottledQueryWatcher {
169    receiver: mpsc::UnboundedReceiver<QueryUpdate>,
170    collection: String,
171}
172
173impl ThrottledQueryWatcher {
174    pub fn new(
175        mut raw_receiver: mpsc::UnboundedReceiver<QueryUpdate>,
176        collection: impl Into<String>,
177        interval: std::time::Duration,
178    ) -> Self {
179        let collection = collection.into();
180        let (tx, rx) = mpsc::unbounded_channel();
181        tokio::spawn(async move {
182            use std::collections::HashMap;
183            let mut tick = tokio::time::interval(interval);
184            tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
185            let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
186            loop {
187                tokio::select! {
188                    biased;
189                    maybe_update = raw_receiver.recv() => {
190                        match maybe_update {
191                            Some(update) => { pending.insert(update.id().to_string(), update); }
192                            None => break,
193                        }
194                    }
195                    _ = tick.tick() => {
196                        if !pending.is_empty() {
197                            for (_, update) in pending.drain() {
198                                if tx.send(update).is_err() { return; }
199                            }
200                        }
201                    }
202                }
203            }
204        });
205        Self { receiver: rx, collection }
206    }
207    pub async fn next(&mut self) -> Option<QueryUpdate> { self.receiver.recv().await }
208    pub fn collection(&self) -> &str { &self.collection }
209    pub fn try_next(&mut self) -> Option<QueryUpdate> { self.receiver.try_recv().ok() }
210}
211
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pubsub::ChangeEvent;
    use crate::types::Value;
    use std::collections::HashMap;

    /// Publishing an insert whose document satisfies the filter must surface
    /// as an `Added` update carrying that document's id.
    #[tokio::test]
    async fn test_query_watcher_insert() {
        // Fresh on-disk database in a scratch directory.
        let temp_dir = tempfile::tempdir().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let db = Arc::new(Aurora::open(db_path).await.unwrap());

        // Watch `users` for documents with `active == true`.
        let listener = db.pubsub.listen("users");
        let filters = vec![crate::query::Filter::Eq(
            "active".to_string(),
            Value::Bool(true),
        )];
        let state = Arc::new(ReactiveQueryState::new(filters));
        let mut watcher = QueryWatcher::new(db.clone(), "users", listener, state, vec![], None);

        // Publish a matching document.
        let data = HashMap::from([
            ("active".to_string(), Value::Bool(true)),
            ("name".to_string(), Value::String("Alice".into())),
        ]);
        let doc = Document {
            id: "1".to_string(),
            data,
        };
        db.pubsub
            .publish(ChangeEvent::insert("users", "1", doc))
            .unwrap();

        // The watcher reports it as an addition.
        let update = watcher.next().await.unwrap();
        assert!(matches!(update, QueryUpdate::Added(_)));
        assert_eq!(update.id(), "1");
    }
}