rs_mongo_stream/stream.rs
1//! Core functionality for MongoDB change stream processing.
2//!
3//! This module provides the main `MongoStream` type which allows for subscribing
4//! to MongoDB change events with custom callbacks.
5
6use mongodb::Database;
7use std::{
8 collections::HashMap,
9 future::Future,
10 pin::Pin,
11 sync::Arc,
12 task::{Context, Poll},
13};
14use tokio::sync::mpsc;
15use tokio_stream::{Stream, StreamExt};
16
17use crate::error::MongoStreamError;
18use crate::event::Event;
19
20/// Type alias for a callback function that processes MongoDB events.
21///
22/// Callbacks should be async functions that take an Event reference
23/// and return nothing.
24type CallbackFn = dyn Fn(&Event) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync;
25
26/// A collection of callbacks mapped to their respective event types.
27///
28/// This type maps Event types to their corresponding callback functions.
29pub type Callbacks = HashMap<Event, Arc<CallbackFn>>;
30
31/// The main wrapper around MongoDB change streams.
32///
33/// `MongoStream` provides an interface for subscribing to MongoDB change events
34/// with custom callbacks for different event types per collection.
35pub struct MongoStream {
36 /// The MongoDB database to monitor
37 db: Database,
38 /// Maps collection names to their registered callbacks
39 collection_callbacks: HashMap<String, Callbacks>,
40 /// Active streams that can be closed
41 active_streams: HashMap<String, tokio::sync::mpsc::Sender<()>>,
42}
43
44impl MongoStream {
45 /// Creates a new MongoStream instance for the given database.
46 ///
47 /// # Arguments
48 ///
49 /// * `db` - The MongoDB database to monitor.
50 ///
51 /// # Returns
52 ///
53 /// A new `MongoStream` instance configured to work with the provided database.
54 pub fn new(db: Database) -> Self {
55 Self {
56 db,
57 collection_callbacks: HashMap::new(),
58 active_streams: HashMap::new(),
59 }
60 }
61
62 /// Registers a callback for a specific event type on a collection.
63 ///
64 /// # Arguments
65 ///
66 /// * `collection_name` - The name of the MongoDB collection to monitor.
67 /// * `event` - The event type to subscribe to.
68 /// * `callback` - The async function to call when the event occurs.
69 ///
70 /// # Example
71 ///
72 /// ```rust,ignore
73 /// use MongoStream::Event;
74 ///
75 /// // load the mongodb database correctly
76 /// let db = Database{};
77 /// let mut mongo_stream = MongoStream::new(db: Database);
78 /// mongo_stream.add_callback("users", Event::Insert, |event| {
79 /// Box::pin(async move {
80 /// println!("New user inserted: {:?}", event);
81 /// })
82 /// });
83 /// ```
84 pub fn add_callback<F>(&mut self, collection_name: impl Into<String>, event: Event, callback: F)
85 where
86 F: Fn(&Event) -> Pin<Box<dyn Future<Output = ()> + Send>> + 'static + Send + Sync,
87 {
88 let collection_name = collection_name.into();
89 let callback_arc = Arc::new(callback);
90 let default_callbacks = self.create_default_callbacks();
91
92 self.collection_callbacks
93 .entry(collection_name)
94 .or_insert_with(|| default_callbacks)
95 .insert(event, callback_arc);
96 }
97
98 /// Creates a set of default callbacks for all event types.
99 ///
100 /// These callbacks simply log that an event was received.
101 ///
102 /// # Returns
103 ///
104 /// A HashMap mapping event types to their default callbacks.
105 fn create_default_callbacks(&self) -> Callbacks {
106 let mut callbacks: Callbacks = HashMap::new();
107
108 // Create default handlers for each event type
109 for event in [Event::Insert, Event::Update, Event::Delete] {
110 let event_name = event.event_type_str().to_string();
111
112 callbacks.insert(
113 event,
114 Arc::new(move |_event: &Event| {
115 let event_name = event_name.clone();
116 Box::pin(async move {
117 println!("{} event received", event_name);
118 })
119 }),
120 );
121 }
122
123 callbacks
124 }
125
126 /// Retrieves the callbacks for a specific collection.
127 ///
128 /// If no callbacks are registered for the collection, default callbacks are returned.
129 ///
130 /// # Arguments
131 ///
132 /// * `collection_name` - The name of the collection to get callbacks for.
133 ///
134 /// # Returns
135 ///
136 /// A HashMap mapping event types to their registered callbacks.
137 fn get_collection_callbacks(&self, collection_name: &str) -> Callbacks {
138 match self.collection_callbacks.get(collection_name) {
139 Some(callbacks) => {
140 // Clone the callbacks map
141 callbacks
142 .iter()
143 .map(|(event, callback)| (*event, Arc::clone(callback)))
144 .collect()
145 }
146 None => self.create_default_callbacks(),
147 }
148 }
149
150 /// Starts monitoring a collection for changes.
151 ///
152 /// This method will block and continuously process events from the MongoDB change stream
153 /// until an error occurs or the stream is closed.
154 ///
155 /// # Arguments
156 ///
157 /// * `collection_name` - The name of the collection to monitor.
158 ///
159 /// # Returns
160 ///
161 /// A Result that is Ok if the stream was closed normally, or an error if something went wrong.
162 ///
163 /// # Example
164 ///
165 /// ```rust,ignore
166 ///
167 /// async fn monitor() -> Result<(), MongoStreamError> {
168 /// let mongo_stream = MongoStream::new(db);
169 /// // ... register callbacks ...
170 /// mongo_stream.start_stream("users").await
171 /// }
172 /// ```
173 /// Starts monitoring a collection for changes.
174 ///
175 /// This method will spawn a new task to monitor the collection and return immediately.
176 /// To stop the stream, call `close_stream` with the same collection name.
177 ///
178 /// # Arguments
179 ///
180 /// * `collection_name` - The name of the collection to monitor.
181 ///
182 /// # Returns
183 ///
184 /// A Result that is Ok if the stream was started successfully, or an error if something went wrong.
185 ///
186 /// # Example
187 ///
188 /// ```rust,ignore
189 /// async fn start_monitoring() -> Result<(), MongoStreamError> {
190 /// let mongo_stream = MongoStream::new(db);
191 /// // ... register callbacks ...
192 /// mongo_stream.start_stream("users").await?;
193 /// // The stream is now running in the background
194 ///
195 /// // Later, when you want to stop it:
196 /// mongo_stream.close_stream("users").await;
197 /// Ok(())
198 /// }
199 /// ```
200 pub async fn start_stream(&mut self, collection_name: &str) -> Result<(), MongoStreamError> {
201 // Check if the stream is already running
202 if self.active_streams.contains_key(collection_name) {
203 return Err(MongoStreamError::new(format!(
204 "Stream for collection '{}' is already running",
205 collection_name
206 )));
207 }
208
209 let collection = self
210 .db
211 .collection::<mongodb::bson::Document>(collection_name);
212
213 // Create a channel to signal stream closure
214 let (tx, mut rx) = mpsc::channel::<()>(1);
215
216 // Store the sender in the active_streams map
217 self.active_streams.insert(collection_name.to_string(), tx);
218
219 // Get the callbacks
220 let callbacks = self.get_collection_callbacks(collection_name);
221 let collection_name = collection_name.to_string();
222 let db = self.db.clone();
223
224 // Spawn a new task to handle the stream
225 tokio::spawn(async move {
226 let mut stream = match collection.watch(None, None).await {
227 Ok(s) => s,
228 Err(e) => {
229 eprintln!(
230 "Failed to start stream for collection '{}': {}",
231 collection_name, e
232 );
233 return;
234 }
235 };
236
237 loop {
238 tokio::select! {
239 // Check if we've received a signal to close the stream
240 _ = rx.recv() => {
241 println!("Closing stream for collection '{}'", collection_name);
242 break;
243 }
244 // Process the next event from the stream
245 next_event = stream.next() => {
246 match next_event {
247 Some(result) => {
248 match result {
249 Ok(change_stream_event) => {
250 let event_type = Event::from(change_stream_event.operation_type);
251 // Find and execute the appropriate callback
252 if let Some(callback) = callbacks.get(&event_type) {
253 callback(&event_type).await;
254 }
255 }
256 Err(e) => {
257 eprintln!("Error in MongoDB stream for collection '{}': {}. Reconnecting...", collection_name, e);
258 // Graceful reconnection
259 match db.collection::<mongodb::bson::Document>(&collection_name).watch(None, None).await {
260 Ok(new_stream) => stream = new_stream,
261 Err(reconnect_err) => {
262 eprintln!("Failed to reconnect to collection '{}': {}", collection_name, reconnect_err);
263 break;
264 }
265 }
266 }
267 }
268 }
269 None => {
270 // Stream has ended normally
271 println!("Stream for collection '{}' has ended", collection_name);
272 break;
273 }
274 }
275 }
276 }
277 }
278 });
279
280 Ok(())
281 }
282
283 /// Closes a specific stream for a collection.
284 ///
285 /// This method sends a signal to stop the monitoring task for a specific collection.
286 ///
287 /// # Arguments
288 ///
289 /// * `collection_name` - The name of the collection whose stream should be closed.
290 ///
291 /// # Returns
292 ///
293 /// A boolean indicating whether a stream was found and closed. Returns `false` if
294 /// there was no active stream for the given collection.
295 ///
296 /// # Example
297 ///
298 /// ```rust,ignore
299 /// use rs_mongo_stream::MongoStream;
300 ///
301 /// async fn example() {
302 /// let mongo_stream = MongoStream::new(db);
303 /// // ... start streams ...
304 ///
305 /// // When done with a specific collection
306 /// let closed = mongo_stream.close_stream("users").await;
307 /// println!("Stream closed: {}", closed);
308 /// }
309 /// ```
310 pub async fn close_stream(&mut self, collection_name: &str) -> bool {
311 if let Some(tx) = self.active_streams.remove(collection_name) {
312 // Send signal to close the stream
313 // It's okay if this fails - the receiver might have been dropped already
314 let _ = tx.send(()).await;
315 true
316 } else {
317 false
318 }
319 }
320
321 /// Closes all active streams.
322 ///
323 /// This method sends signals to stop all monitoring tasks and clears the active streams list.
324 ///
325 /// # Returns
326 ///
327 /// The number of streams that were closed.
328 ///
329 /// # Example
330 ///
331 /// ```rust,ignore
332 /// async fn shutdown() {
333 /// let mongo_stream = MongoStream::new(db);
334 /// // ... start streams ...
335 ///
336 /// // When shutting down the application
337 /// let closed_count = mongo_stream.close_all_streams().await;
338 /// println!("Closed {} streams", closed_count);
339 /// }
340 /// ```
341 pub async fn close_all_streams(&mut self) -> usize {
342 let count = self.active_streams.len();
343
344 // Send close signal to all active streams
345 for (collection_name, tx) in self.active_streams.drain() {
346 let _ = tx.send(()).await;
347 println!(
348 "Sent close signal to stream for collection '{}'",
349 collection_name
350 );
351 }
352
353 count
354 }
355}
356
357/// Implementation of Stream trait for MongoStream to make it compatible with tokio_stream.
358///
359/// Note: This is a placeholder implementation and would need to be expanded
360/// in a real-world scenario to connect to actual MongoDB change streams.
361impl Stream for MongoStream {
362 type Item = Result<Event, MongoStreamError>;
363
364 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
365 // Placeholder implementation for the Stream trait
366 // In a real implementation, this would be connected to MongoDB change streams
367 Poll::Ready(Some(Ok(Event::Insert)))
368 }
369}
370
#[cfg(test)]
mod tests {
    use super::*;
    use mongodb::{options::ClientOptions, Client};
    use tokio::sync::mpsc;

    /// Builds a handle to the `test_db` database at `mongodb://localhost:27017`.
    ///
    /// NOTE(review): presumably a reachable MongoDB deployment is needed for the
    /// started streams to deliver events; the assertions below only inspect
    /// `MongoStream` bookkeeping — confirm against the CI setup.
    async fn setup_test_db() -> Database {
        let client_options = ClientOptions::parse("mongodb://localhost:27017")
            .await
            .expect("test connection string should parse");
        let client =
            Client::with_options(client_options).expect("client should build from options");
        client.database("test_db")
    }

    /// A fresh instance has no callbacks and no streams.
    #[tokio::test]
    async fn test_mongo_stream_creation() {
        let db = setup_test_db().await;
        let mongo_stream = MongoStream::new(db);

        assert!(mongo_stream.collection_callbacks.is_empty());
        assert!(mongo_stream.active_streams.is_empty());
    }

    /// A registered callback is stored alongside the seeded defaults and runs when invoked.
    #[tokio::test]
    async fn test_add_callback() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        let (tx, mut rx) = mpsc::channel(1);
        mongo_stream.add_callback("test_collection", Event::Insert, move |_event| {
            let tx = tx.clone();
            Box::pin(async move {
                let _ = tx.send(()).await;
            })
        });

        let callbacks = mongo_stream.get_collection_callbacks("test_collection");
        // The table holds the custom Insert callback plus the Update/Delete defaults.
        assert_eq!(callbacks.len(), 3);

        // Fail loudly if the callback is missing instead of skipping the checks below.
        let callback = callbacks
            .get(&Event::Insert)
            .expect("Insert callback should be registered");
        callback(&Event::Insert).await;
        assert!(rx.recv().await.is_some());
    }

    /// Starting records the stream; closing removes it again.
    #[tokio::test]
    async fn test_start_and_close_stream() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        assert!(mongo_stream.start_stream("test_collection").await.is_ok());
        assert!(mongo_stream.active_streams.contains_key("test_collection"));

        assert!(mongo_stream.close_stream("test_collection").await);
        assert!(!mongo_stream.active_streams.contains_key("test_collection"));
    }

    /// `close_all_streams` reports and clears every started stream.
    #[tokio::test]
    async fn test_close_all_streams() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        mongo_stream
            .start_stream("collection1")
            .await
            .expect("first stream should start");
        mongo_stream
            .start_stream("collection2")
            .await
            .expect("second stream should start");

        let closed_count = mongo_stream.close_all_streams().await;
        assert_eq!(closed_count, 2);
        assert!(mongo_stream.active_streams.is_empty());
    }

    /// A collection without registrations gets the three default callbacks.
    #[tokio::test]
    async fn test_default_callbacks() {
        let db = setup_test_db().await;
        let mongo_stream = MongoStream::new(db);

        let callbacks = mongo_stream.get_collection_callbacks("test_collection");
        assert_eq!(callbacks.len(), 3);
        assert!(callbacks.contains_key(&Event::Insert));
        assert!(callbacks.contains_key(&Event::Update));
        assert!(callbacks.contains_key(&Event::Delete));
    }

    /// Starting the same collection twice in a row is rejected.
    #[tokio::test]
    async fn test_double_start_error() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        mongo_stream
            .start_stream("test_collection")
            .await
            .expect("first start should succeed");
        let result = mongo_stream.start_stream("test_collection").await;

        assert!(matches!(result, Err(MongoStreamError { .. })));
    }
}