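//! Reactive query watchers (`aurora_db/reactive/watcher.rs`).
//!
//! [`QueryWatcher`] applies collection change events to a [`ReactiveQueryState`] and
//! streams the resulting [`QueryUpdate`]s; [`ThrottledQueryWatcher`] coalesces such a
//! stream per document id on a fixed interval.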

use std::sync::Arc;

use tokio::sync::mpsc;

use super::{QueryUpdate, ReactiveQueryState};
use crate::Aurora;
use crate::pubsub::ChangeListener;
use crate::types::Document;

8/// Watches a query and emits updates when results change
9pub struct QueryWatcher {
10    /// Receiver for query updates
11    receiver: mpsc::UnboundedReceiver<QueryUpdate>,
12    /// Collection being watched
13    collection: String,
14    /// Reference to the database for resyncing
15    #[allow(dead_code)]
16    db: Arc<Aurora>,
17}
18
19impl QueryWatcher {
20    /// Create a new query watcher
21    pub fn new(
22        db: Arc<Aurora>,
23        collection: impl Into<String>,
24        mut listener: ChangeListener,
25        state: Arc<ReactiveQueryState>,
26        initial_results: Vec<Document>,
27        debounce_duration: Option<std::time::Duration>,
28    ) -> Self {
29        let collection = collection.into();
30        let (sender, receiver) = mpsc::unbounded_channel();
31
32        // Populate initial results
33        let init_state = Arc::clone(&state);
34        let init_sender = sender.clone();
35        tokio::spawn(async move {
36            for doc in initial_results {
37                if let Some(update) = init_state.add_if_matches(doc).await {
38                    let _ = init_sender.send(update);
39                }
40            }
41        });
42
43        // Spawn background task to listen for changes
44        let db_clone = Arc::clone(&db);
45        let coll_clone = collection.clone();
46        let state_clone = Arc::clone(&state);
47        let sender_clone = sender.clone();
48
49        tokio::spawn(async move {
50            let mut backoff_ms = 100; // Initial backoff
51
52            loop {
53                match listener.recv().await {
54                    Ok(event) => {
55                        // Successful receive, reset backoff gradually
56                        if backoff_ms > 100 {
57                            backoff_ms -= 50;
58                        }
59
60                        let update = match event.change_type {
61                            crate::pubsub::ChangeType::Insert => {
62                                if let Some(doc) = event.document {
63                                    state_clone.add_if_matches(doc).await
64                                } else {
65                                    None
66                                }
67                            }
68                            crate::pubsub::ChangeType::Update => {
69                                if let Some(new_doc) = event.document {
70                                    state_clone.update(&event._sid, new_doc).await
71                                } else {
72                                    None
73                                }
74                            }
75                            crate::pubsub::ChangeType::Delete => {
76                                state_clone.remove(&event._sid).await
77                            }
78                        };
79
80                        if let Some(u) = update
81                            && sender_clone.send(u).is_err()
82                        {
83                            break;
84                        }
85                    }
86                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
87                        eprintln!(
88                            "WARNING: Watcher for '{}' lagged by {} events. Applying {}ms backoff...",
89                            coll_clone, skipped, backoff_ms
90                        );
91
92                        // 1. DRAIN: Empty everything currently in the channel (Ok and Lagged)
93                        loop {
94                            match listener.try_recv() {
95                                Ok(_)
96                                | Err(tokio::sync::broadcast::error::TryRecvError::Lagged(_)) => {
97                                    continue;
98                                }
99                                _ => break,
100                            }
101                        }
102
103                        // 2. BACKOFF: Wait for the event storm to subside
104                        tokio::time::sleep(tokio::time::Duration::from_millis(backoff_ms)).await;
105
106                        // Increase backoff for next time (max 2 seconds)
107                        backoff_ms = (backoff_ms * 2).min(2000);
108
109                        // 3. SNAPSHOT: Collect documents into a Vec inside the block to keep iterator usage local (not across await)
110                        let docs_snapshot: Vec<Document> =
111                            if let Ok(iter) = db_clone.stream_collection(&coll_clone) {
112                                iter.collect()
113                            } else {
114                                Vec::new()
115                            };
116
117                        // 4. SYNC: Synchronize ReactiveQueryState and emit deltas
118                        let updates = state_clone.sync_state(docs_snapshot).await;
119                        for u in updates {
120                            if sender_clone.send(u).is_err() {
121                                return;
122                            }
123                        }
124                    }
125                    Err(tokio::sync::broadcast::error::RecvError::Closed) => break,
126                }
127            }
128        });
129
130        // Throttling logic (unchanged)
131        let final_receiver = if let Some(duration) = debounce_duration {
132            let (tx_throttled, rx_throttled) = mpsc::unbounded_channel();
133            let mut raw_rx = receiver;
134            tokio::spawn(async move {
135                use std::collections::HashMap;
136                let mut tick = tokio::time::interval(duration);
137                tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
138                let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
139                loop {
140                    tokio::select! {
141                        biased;
142                        maybe_update = raw_rx.recv() => {
143                            match maybe_update {
144                                Some(update) => { pending.insert(update.id().to_string(), update); }
145                                None => break,
146                            }
147                        }
148                        _ = tick.tick() => {
149                            if !pending.is_empty() {
150                                for (_, update) in pending.drain() {
151                                    if tx_throttled.send(update).is_err() { return; }
152                                }
153                            }
154                        }
155                    }
156                }
157            });
158            rx_throttled
159        } else {
160            receiver
161        };
162
163        Self {
164            receiver: final_receiver,
165            collection,
166            db,
167        }
168    }
169
170    pub async fn next(&mut self) -> Option<QueryUpdate> {
171        self.receiver.recv().await
172    }
173    pub fn collection(&self) -> &str {
174        &self.collection
175    }
176    pub fn try_next(&mut self) -> Option<QueryUpdate> {
177        self.receiver.try_recv().ok()
178    }
179    pub fn throttled(self, interval: std::time::Duration) -> ThrottledQueryWatcher {
180        ThrottledQueryWatcher::new(self.receiver, self.collection, interval)
181    }
182}
183
184pub struct ThrottledQueryWatcher {
185    receiver: mpsc::UnboundedReceiver<QueryUpdate>,
186    collection: String,
187}
188
189impl ThrottledQueryWatcher {
190    pub fn new(
191        mut raw_receiver: mpsc::UnboundedReceiver<QueryUpdate>,
192        collection: impl Into<String>,
193        interval: std::time::Duration,
194    ) -> Self {
195        let collection = collection.into();
196        let (tx, rx) = mpsc::unbounded_channel();
197        tokio::spawn(async move {
198            use std::collections::HashMap;
199            let mut tick = tokio::time::interval(interval);
200            tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
201            let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
202            loop {
203                tokio::select! {
204                    biased;
205                    maybe_update = raw_receiver.recv() => {
206                        match maybe_update {
207                            Some(update) => { pending.insert(update.id().to_string(), update); }
208                            None => break,
209                        }
210                    }
211                    _ = tick.tick() => {
212                        if !pending.is_empty() {
213                            for (_, update) in pending.drain() {
214                                if tx.send(update).is_err() { return; }
215                            }
216                        }
217                    }
218                }
219            }
220        });
221        Self {
222            receiver: rx,
223            collection,
224        }
225    }
226    pub async fn next(&mut self) -> Option<QueryUpdate> {
227        self.receiver.recv().await
228    }
229    pub fn collection(&self) -> &str {
230        &self.collection
231    }
232    pub fn try_next(&mut self) -> Option<QueryUpdate> {
233        self.receiver.try_recv().ok()
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pubsub::ChangeEvent;
    use crate::types::Value;
    use std::collections::HashMap;
    use std::time::Duration;

    /// Upper bound on how long a test waits for an update before failing instead of hanging.
    const UPDATE_TIMEOUT: Duration = Duration::from_secs(5);

    /// Build a `users` document with the given id, name and `active` flag.
    fn user_doc(id: &str, name: &str, active: bool) -> Document {
        let mut data = HashMap::new();
        data.insert("active".to_string(), Value::Bool(active));
        data.insert("name".to_string(), Value::String(name.into()));
        Document {
            _sid: id.to_string(),
            data,
        }
    }

    /// Query state matching documents whose `active` field is `true`.
    fn active_users_state() -> Arc<ReactiveQueryState> {
        Arc::new(ReactiveQueryState::new(vec![crate::query::Filter::Eq(
            "active".to_string(),
            Value::Bool(true),
        )]))
    }

    /// Await the watcher's next update, failing the test on timeout or early close.
    async fn next_update(watcher: &mut QueryWatcher) -> QueryUpdate {
        tokio::time::timeout(UPDATE_TIMEOUT, watcher.next())
            .await
            .expect("timed out waiting for a query update")
            .expect("watcher closed before emitting an update")
    }

    #[tokio::test]
    async fn test_query_watcher_insert() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db = Arc::new(Aurora::open(temp_dir.path().join("test.db")).await.unwrap());
        let listener = db.pubsub.listen("users");
        let mut watcher =
            QueryWatcher::new(db.clone(), "users", listener, active_users_state(), vec![], None);

        db.pubsub
            .publish(ChangeEvent::insert("users", "1", user_doc("1", "Alice", true)))
            .unwrap();

        let update = next_update(&mut watcher).await;
        assert!(matches!(update, QueryUpdate::Added(_)));
        assert_eq!(update.id(), "1");
    }

    #[tokio::test]
    async fn test_query_watcher_initial_results() {
        let temp_dir = tempfile::tempdir().unwrap();
        let db = Arc::new(Aurora::open(temp_dir.path().join("test.db")).await.unwrap());
        let listener = db.pubsub.listen("users");
        let initial = vec![user_doc("1", "Alice", true)];
        let mut watcher =
            QueryWatcher::new(db.clone(), "users", listener, active_users_state(), initial, None);

        let update = next_update(&mut watcher).await;
        assert!(matches!(update, QueryUpdate::Added(_)));
        assert_eq!(update.id(), "1");
    }
}