1use super::{QueryUpdate, ReactiveQueryState};
2use crate::pubsub::ChangeListener;
3use crate::types::Document;
4use std::sync::Arc;
5use tokio::sync::mpsc;
6
7pub struct QueryWatcher {
9 receiver: mpsc::UnboundedReceiver<QueryUpdate>,
11 collection: String,
13}
14
15impl QueryWatcher {
16 pub fn new(
24 collection: impl Into<String>,
25 mut listener: ChangeListener,
26 state: Arc<ReactiveQueryState>,
27 initial_results: Vec<Document>,
28 debounce_duration: Option<std::time::Duration>,
29 ) -> Self {
30 let collection = collection.into();
31 let (sender, receiver) = mpsc::unbounded_channel();
32
33 let init_state = Arc::clone(&state);
35 let init_sender = sender.clone();
36 tokio::spawn(async move {
37 for doc in initial_results {
38 if let Some(update) = init_state.add_if_matches(doc).await {
39 let _ = init_sender.send(update);
40 }
41 }
42 });
43
44 tokio::spawn(async move {
46 while let Ok(event) = listener.recv().await {
47 let update = match event.change_type {
48 crate::pubsub::ChangeType::Insert => {
49 if let Some(doc) = event.document {
50 state.add_if_matches(doc).await
51 } else {
52 None
53 }
54 }
55 crate::pubsub::ChangeType::Update => {
56 if let Some(new_doc) = event.document {
57 state.update(&event.id, new_doc).await
58 } else {
59 None
60 }
61 }
62 crate::pubsub::ChangeType::Delete => state.remove(&event.id).await,
63 };
64
65 if let Some(u) = update
66 && sender.send(u).is_err()
67 {
68 break;
70 }
71 }
72 });
73
74 let final_receiver = if let Some(duration) = debounce_duration {
76 let (tx_throttled, rx_throttled) = mpsc::unbounded_channel();
77 let mut raw_rx = receiver;
78
79 tokio::spawn(async move {
80 use std::collections::HashMap;
81 use tokio::time::interval as tokio_interval;
82
83 let mut tick = tokio_interval(duration);
84 tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
85
86 let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
87
88 loop {
89 tokio::select! {
90 biased;
91 maybe_update = raw_rx.recv() => {
92 match maybe_update {
93 Some(update) => {
94 pending.insert(update.id().to_string(), update);
95 }
96 None => break,
97 }
98 }
99 _ = tick.tick() => {
100 if !pending.is_empty() {
101 for (_, update) in pending.drain() {
102 if tx_throttled.send(update).is_err() {
103 return;
104 }
105 }
106 }
107 }
108 }
109 }
110 });
111 rx_throttled
112 } else {
113 receiver
114 };
115
116 Self {
117 receiver: final_receiver,
118 collection,
119 }
120 }
121
122 pub async fn next(&mut self) -> Option<QueryUpdate> {
125 self.receiver.recv().await
126 }
127
128 pub fn collection(&self) -> &str {
130 &self.collection
131 }
132
133 pub fn try_next(&mut self) -> Option<QueryUpdate> {
135 self.receiver.try_recv().ok()
136 }
137
138 pub fn throttled(self, interval: std::time::Duration) -> ThrottledQueryWatcher {
143 ThrottledQueryWatcher::new(self.receiver, self.collection, interval)
144 }
145}
146
147pub struct ThrottledQueryWatcher {
153 receiver: mpsc::UnboundedReceiver<QueryUpdate>,
154 collection: String,
155}
156
157impl ThrottledQueryWatcher {
158 pub fn new(
160 mut raw_receiver: mpsc::UnboundedReceiver<QueryUpdate>,
161 collection: impl Into<String>,
162 interval: std::time::Duration,
163 ) -> Self {
164 let collection = collection.into();
165 let (tx, rx) = mpsc::unbounded_channel();
166
167 tokio::spawn(async move {
168 use std::collections::HashMap;
169 use tokio::time::interval as tokio_interval;
170
171 let mut tick = tokio_interval(interval);
172 tick.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
173
174 let mut pending: HashMap<String, QueryUpdate> = HashMap::new();
176
177 loop {
178 tokio::select! {
179 biased;
180
181 maybe_update = raw_receiver.recv() => {
183 match maybe_update {
184 Some(update) => {
185 pending.insert(update.id().to_string(), update);
187 }
188 None => break, }
190 }
191
192 _ = tick.tick() => {
194 if !pending.is_empty() {
195 for (_, update) in pending.drain() {
196 if tx.send(update).is_err() {
197 return; }
199 }
200 }
201 }
202 }
203 }
204 });
205
206 Self {
207 receiver: rx,
208 collection,
209 }
210 }
211
212 pub async fn next(&mut self) -> Option<QueryUpdate> {
214 self.receiver.recv().await
215 }
216
217 pub fn collection(&self) -> &str {
219 &self.collection
220 }
221
222 pub fn try_next(&mut self) -> Option<QueryUpdate> {
224 self.receiver.try_recv().ok()
225 }
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pubsub::{ChangeEvent, PubSubSystem};
    use crate::types::Value;
    use std::collections::HashMap;

    /// A matching insert published after the watcher starts is surfaced as `Added`.
    #[tokio::test]
    async fn test_query_watcher_insert() {
        let pubsub = PubSubSystem::new(100);
        let listener = pubsub.listen("users");

        let state = Arc::new(ReactiveQueryState::new(|doc: &Document| {
            doc.data.get("active") == Some(&Value::Bool(true))
        }));

        let mut watcher = QueryWatcher::new("users", listener, state, vec![], None);
        assert_eq!(watcher.collection(), "users");

        let mut data = HashMap::new();
        data.insert("active".to_string(), Value::Bool(true));
        data.insert("name".to_string(), Value::String("Alice".into()));

        let doc = Document {
            id: "1".to_string(),
            data,
        };

        pubsub
            .publish(ChangeEvent::insert("users", "1", doc))
            .unwrap();

        let update = watcher.next().await.unwrap();
        assert!(matches!(update, QueryUpdate::Added(_)));
        assert_eq!(update.id(), "1");
    }

    /// Inserts that fail the predicate are filtered out; only the matching one is surfaced.
    #[tokio::test]
    async fn test_query_watcher_filter() {
        let pubsub = PubSubSystem::new(100);
        let listener = pubsub.listen("users");

        let state = Arc::new(ReactiveQueryState::new(|doc: &Document| {
            doc.data.get("active") == Some(&Value::Bool(true))
        }));

        let mut watcher = QueryWatcher::new("users", listener, state, vec![], None);

        let mut inactive_data = HashMap::new();
        inactive_data.insert("active".to_string(), Value::Bool(false));

        pubsub
            .publish(ChangeEvent::insert(
                "users",
                "1",
                Document {
                    id: "1".to_string(),
                    data: inactive_data,
                },
            ))
            .unwrap();

        let mut active_data = HashMap::new();
        active_data.insert("active".to_string(), Value::Bool(true));

        pubsub
            .publish(ChangeEvent::insert(
                "users",
                "2",
                Document {
                    id: "2".to_string(),
                    data: active_data,
                },
            ))
            .unwrap();

        let update = watcher.next().await.unwrap();
        assert_eq!(update.id(), "2");
    }

    /// Documents passed as `initial_results` go through the same predicate as live inserts.
    #[tokio::test]
    async fn test_query_watcher_initial_results() {
        let pubsub = PubSubSystem::new(100);
        let listener = pubsub.listen("users");

        let state = Arc::new(ReactiveQueryState::new(|doc: &Document| {
            doc.data.get("active") == Some(&Value::Bool(true))
        }));

        let mut inactive_data = HashMap::new();
        inactive_data.insert("active".to_string(), Value::Bool(false));
        let mut active_data = HashMap::new();
        active_data.insert("active".to_string(), Value::Bool(true));

        let initial = vec![
            Document {
                id: "1".to_string(),
                data: inactive_data,
            },
            Document {
                id: "2".to_string(),
                data: active_data,
            },
        ];

        let mut watcher = QueryWatcher::new("users", listener, state, initial, None);

        let update = watcher.next().await.unwrap();
        assert!(matches!(update, QueryUpdate::Added(_)));
        assert_eq!(update.id(), "2");
    }

    /// Several updates for one id inside a single window collapse into the latest one.
    #[tokio::test]
    async fn test_debounced_watcher() {
        use std::time::Duration;
        use tokio::sync::mpsc;

        let (tx, rx) = mpsc::unbounded_channel();

        let mut throttled = ThrottledQueryWatcher::new(rx, "test", Duration::from_millis(100));

        let mut data1 = HashMap::new();
        data1.insert("value".to_string(), Value::Int(1));
        tx.send(QueryUpdate::Added(Document {
            id: "doc1".to_string(),
            data: data1,
        }))
        .unwrap();

        let mut data2 = HashMap::new();
        data2.insert("value".to_string(), Value::Int(2));
        tx.send(QueryUpdate::Modified {
            old: Document {
                id: "doc1".to_string(),
                data: HashMap::new(),
            },
            new: Document {
                id: "doc1".to_string(),
                data: data2,
            },
        })
        .unwrap();

        let mut data3 = HashMap::new();
        data3.insert("value".to_string(), Value::Int(3));
        tx.send(QueryUpdate::Modified {
            old: Document {
                id: "doc1".to_string(),
                data: HashMap::new(),
            },
            new: Document {
                id: "doc1".to_string(),
                data: data3.clone(),
            },
        })
        .unwrap();

        // Comfortably past one tick so the coalesced batch has been released.
        tokio::time::sleep(Duration::from_millis(250)).await;

        let update = throttled
            .try_next()
            .expect("coalesced update should be released after one tick");
        assert_eq!(update.id(), "doc1");
        assert!(
            matches!(&update, QueryUpdate::Modified { new, .. } if new.data == data3),
            "only the latest update for doc1 should survive coalescing"
        );
        assert!(
            throttled.try_next().is_none(),
            "updates for the same id must be coalesced into one"
        );
    }

    /// Updates for distinct ids inside one window are all delivered.
    #[tokio::test]
    async fn test_throttled_watcher_multiple_docs() {
        use std::time::Duration;
        use tokio::sync::mpsc;

        let (tx, rx) = mpsc::unbounded_channel();
        let mut throttled = ThrottledQueryWatcher::new(rx, "test", Duration::from_millis(100));

        for i in 1..=3 {
            let mut data = HashMap::new();
            data.insert("value".to_string(), Value::Int(i));
            tx.send(QueryUpdate::Added(Document {
                id: format!("doc{}", i),
                data,
            }))
            .unwrap();
        }

        // Comfortably past one tick so the batch has been released.
        tokio::time::sleep(Duration::from_millis(250)).await;

        let mut received = Vec::new();
        while let Some(update) = throttled.try_next() {
            received.push(update.id().to_string());
        }

        assert_eq!(received.len(), 3);
        assert!(received.contains(&"doc1".to_string()));
        assert!(received.contains(&"doc2".to_string()));
        assert!(received.contains(&"doc3".to_string()));
    }
}