rs_mongo_stream/
stream.rs

//! Core functionality for MongoDB change stream processing.
//!
//! This module provides the main `MongoStream` type, which lets callers subscribe
//! to MongoDB change events with custom callbacks.
5
6use mongodb::Database;
7use std::{
8    collections::HashMap,
9    future::Future,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13};
14use tokio::sync::mpsc;
15use tokio_stream::{Stream, StreamExt};
16
17use crate::error::MongoStreamError;
18use crate::event::Event;
19use mongodb::bson::Document;
20
21/// Type alias for a callback function that processes MongoDB documents.
22///
23/// Callbacks should be async functions that take a Document reference
24/// and return nothing.
25type CallbackFn = dyn Fn(&Document) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync;
26
27/// A collection of callbacks mapped to their respective event types.
28///
29/// This type maps Event types to their corresponding callback functions.
30pub type Callbacks = HashMap<Event, Arc<CallbackFn>>;
31
32/// The main wrapper around MongoDB change streams.
33///
34/// `MongoStream` provides an interface for subscribing to MongoDB change events
35/// with custom callbacks for different event types per collection.
36pub struct MongoStream {
37    /// The MongoDB database to monitor
38    db: Database,
39    /// Maps collection names to their registered callbacks
40    collection_callbacks: HashMap<String, Callbacks>,
41    /// Active streams that can be closed
42    active_streams: HashMap<String, tokio::sync::mpsc::Sender<()>>,
43}
44
45impl MongoStream {
46    /// Creates a new MongoStream instance for the given database.
47    ///
48    /// # Arguments
49    ///
50    /// * `db` - The MongoDB database to monitor.
51    ///
52    /// # Returns
53    ///
54    /// A new `MongoStream` instance configured to work with the provided database.
55    pub fn new(db: Database) -> Self {
56        Self {
57            db,
58            collection_callbacks: HashMap::new(),
59            active_streams: HashMap::new(),
60        }
61    }
62
63    /// Registers a callback for a specific event type on a collection.
64    ///
65    /// # Arguments
66    ///
67    /// * `collection_name` - The name of the MongoDB collection to monitor.
68    /// * `event` - The event type to subscribe to.
69    /// * `callback` - The async function to call when the event occurs.
70    ///
71    /// # Example
72    ///
73    /// ```rust,ignore
74    /// use MongoStream::Event;
75    ///
76    /// // load the mongodb database correctly
77    /// let db = Database{};
78    /// let mut mongo_stream = MongoStream::new(db: Database);
79    /// mongo_stream.add_callback("users", Event::Insert, |event| {
80    ///     Box::pin(async move {
81    ///         println!("New user inserted: {:?}", event);
82    ///     })
83    /// });
84    /// ```
85    pub fn add_callback<F>(&mut self, collection_name: impl Into<String>, event: Event, callback: F)
86    where
87        F: Fn(&Document) -> Pin<Box<dyn Future<Output = ()> + Send>> + 'static + Send + Sync,
88    {
89        let collection_name = collection_name.into();
90        let callback_arc = Arc::new(callback);
91        let default_callbacks = self.create_default_callbacks();
92
93        self.collection_callbacks
94            .entry(collection_name)
95            .or_insert_with(|| default_callbacks)
96            .insert(event, callback_arc);
97    }
98
99    /// Creates a set of default callbacks for all event types.
100    ///
101    /// These callbacks simply log that an event was received.
102    ///
103    /// # Returns
104    ///
105    /// A HashMap mapping event types to their default callbacks.
106    fn create_default_callbacks(&self) -> Callbacks {
107        let mut callbacks: Callbacks = HashMap::new();
108
109        // Create default handlers for each event type
110        for event in [Event::Insert, Event::Update, Event::Delete] {
111            let event_name = event.event_type_str().to_string();
112
113            callbacks.insert(
114                event,
115                Arc::new(move |_doc: &Document| {
116                    let event_name = event_name.clone();
117                    Box::pin(async move {
118                        println!("{} event received", event_name);
119                    })
120                }),
121            );
122        }
123
124        callbacks
125    }
126
127    /// Retrieves the callbacks for a specific collection.
128    ///
129    /// If no callbacks are registered for the collection, default callbacks are returned.
130    ///
131    /// # Arguments
132    ///
133    /// * `collection_name` - The name of the collection to get callbacks for.
134    ///
135    /// # Returns
136    ///
137    /// A HashMap mapping event types to their registered callbacks.
138    fn get_collection_callbacks(&self, collection_name: &str) -> Callbacks {
139        match self.collection_callbacks.get(collection_name) {
140            Some(callbacks) => {
141                // Clone the callbacks map
142                callbacks
143                    .iter()
144                    .map(|(event, callback)| (*event, Arc::clone(callback)))
145                    .collect()
146            }
147            None => self.create_default_callbacks(),
148        }
149    }
150
151    /// Starts monitoring a collection for changes.
152    ///
153    /// This method will block and continuously process events from the MongoDB change stream
154    /// until an error occurs or the stream is closed.
155    ///
156    /// # Arguments
157    ///
158    /// * `collection_name` - The name of the collection to monitor.
159    ///
160    /// # Returns
161    ///
162    /// A Result that is Ok if the stream was closed normally, or an error if something went wrong.
163    ///
164    /// # Example
165    ///
166    /// ```rust,ignore
167    ///
168    /// async fn monitor() -> Result<(), MongoStreamError> {
169    ///     let mongo_stream = MongoStream::new(db);
170    ///     // ... register callbacks ...
171    ///     mongo_stream.start_stream("users").await
172    /// }
173    /// ```
174    /// Starts monitoring a collection for changes.
175    ///
176    /// This method will spawn a new task to monitor the collection and return immediately.
177    /// To stop the stream, call `close_stream` with the same collection name.
178    ///
179    /// # Arguments
180    ///
181    /// * `collection_name` - The name of the collection to monitor.
182    ///
183    /// # Returns
184    ///
185    /// A Result that is Ok if the stream was started successfully, or an error if something went wrong.
186    ///
187    /// # Example
188    ///
189    /// ```rust,ignore
190    /// async fn start_monitoring() -> Result<(), MongoStreamError> {
191    ///     let mongo_stream = MongoStream::new(db);
192    ///     // ... register callbacks ...
193    ///     mongo_stream.start_stream("users").await?;
194    ///     // The stream is now running in the background
195    ///     
196    ///     // Later, when you want to stop it:
197    ///     mongo_stream.close_stream("users").await;
198    ///     Ok(())
199    /// }
200    /// ```
201    pub async fn start_stream(&mut self, collection_name: &str) -> Result<(), MongoStreamError> {
202        // Check if the stream is already running
203        if self.active_streams.contains_key(collection_name) {
204            return Err(MongoStreamError::new(format!(
205                "Stream for collection '{}' is already running",
206                collection_name
207            )));
208        }
209
210        let collection = self
211            .db
212            .collection::<mongodb::bson::Document>(collection_name);
213
214        // Create a channel to signal stream closure
215        let (tx, mut rx) = mpsc::channel::<()>(1);
216
217        // Store the sender in the active_streams map
218        self.active_streams.insert(collection_name.to_string(), tx);
219
220        // Get the callbacks
221        let callbacks = self.get_collection_callbacks(collection_name);
222        let collection_name = collection_name.to_string();
223        let db = self.db.clone();
224
225        // Spawn a new task to handle the stream
226        tokio::spawn(async move {
227            let mut stream = match collection.watch(None, None).await {
228                Ok(s) => s,
229                Err(e) => {
230                    eprintln!(
231                        "Failed to start stream for collection '{}': {}",
232                        collection_name, e
233                    );
234                    return;
235                }
236            };
237
238            loop {
239                tokio::select! {
240                    // Check if we've received a signal to close the stream
241                    _ = rx.recv() => {
242                        println!("Closing stream for collection '{}'", collection_name);
243                        break;
244                    }
245                    // Process the next event from the stream
246                    next_event = stream.next() => {
247                        match next_event {
248                            Some(result) => {
249                                match result {
250                                    Ok(change_stream_event) => {
251                                        let event_type = Event::from(change_stream_event.operation_type);
252                                        // Find and execute the appropriate callback
253                                        if let Some(callback) = callbacks.get(&event_type) {
254                                            if let Some(doc) = change_stream_event.full_document {
255                                                callback(&doc).await;
256                                            }
257                                        }
258                                    }
259                                    Err(e) => {
260                                        eprintln!("Error in MongoDB stream for collection '{}': {}. Reconnecting...", collection_name, e);
261                                        // Graceful reconnection
262                                        match db.collection::<mongodb::bson::Document>(&collection_name).watch(None, None).await {
263                                            Ok(new_stream) => stream = new_stream,
264                                            Err(reconnect_err) => {
265                                                eprintln!("Failed to reconnect to collection '{}': {}", collection_name, reconnect_err);
266                                                break;
267                                            }
268                                        }
269                                    }
270                                }
271                            }
272                            None => {
273                                // Stream has ended normally
274                                println!("Stream for collection '{}' has ended", collection_name);
275                                break;
276                            }
277                        }
278                    }
279                }
280            }
281        });
282
283        Ok(())
284    }
285
286    /// Closes a specific stream for a collection.
287    ///
288    /// This method sends a signal to stop the monitoring task for a specific collection.
289    ///
290    /// # Arguments
291    ///
292    /// * `collection_name` - The name of the collection whose stream should be closed.
293    ///
294    /// # Returns
295    ///
296    /// A boolean indicating whether a stream was found and closed. Returns `false` if
297    /// there was no active stream for the given collection.
298    ///
299    /// # Example
300    ///
301    /// ```rust,ignore
302    /// use rs_mongo_stream::MongoStream;
303    ///
304    /// async fn example() {
305    ///     let mongo_stream = MongoStream::new(db);
306    ///     // ... start streams ...
307    ///     
308    ///     // When done with a specific collection
309    ///     let closed = mongo_stream.close_stream("users").await;
310    ///     println!("Stream closed: {}", closed);
311    /// }
312    /// ```
313    pub async fn close_stream(&mut self, collection_name: &str) -> bool {
314        if let Some(tx) = self.active_streams.remove(collection_name) {
315            // Send signal to close the stream
316            // It's okay if this fails - the receiver might have been dropped already
317            let _ = tx.send(()).await;
318            true
319        } else {
320            false
321        }
322    }
323
324    /// Closes all active streams.
325    ///
326    /// This method sends signals to stop all monitoring tasks and clears the active streams list.
327    ///
328    /// # Returns
329    ///
330    /// The number of streams that were closed.
331    ///
332    /// # Example
333    ///
334    /// ```rust,ignore
335    /// async fn shutdown() {
336    ///     let mongo_stream = MongoStream::new(db);
337    ///     // ... start streams ...
338    ///     
339    ///     // When shutting down the application
340    ///     let closed_count = mongo_stream.close_all_streams().await;
341    ///     println!("Closed {} streams", closed_count);
342    /// }
343    /// ```
344    pub async fn close_all_streams(&mut self) -> usize {
345        let count = self.active_streams.len();
346
347        // Send close signal to all active streams
348        for (collection_name, tx) in self.active_streams.drain() {
349            let _ = tx.send(()).await;
350            println!(
351                "Sent close signal to stream for collection '{}'",
352                collection_name
353            );
354        }
355
356        count
357    }
358}
359
360/// Implementation of Stream trait for MongoStream to make it compatible with tokio_stream.
361///
362/// Note: This is a placeholder implementation and would need to be expanded
363/// in a real-world scenario to connect to actual MongoDB change streams.
364impl Stream for MongoStream {
365    type Item = Result<Event, MongoStreamError>;
366
367    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
368        // Placeholder implementation for the Stream trait
369        // In a real implementation, this would be connected to MongoDB change streams
370        Poll::Ready(Some(Ok(Event::Insert)))
371    }
372}
373
#[cfg(test)]
mod tests {
    //! Tests for `MongoStream`; they connect to a MongoDB server at `localhost:27017`.

    use super::*;
    use mongodb::{bson::Document, options::ClientOptions, Client};
    use tokio::sync::mpsc;

    /// Connects to the local MongoDB server and returns its `test_db` database.
    async fn setup_test_db() -> Database {
        let options = ClientOptions::parse("mongodb://localhost:27017")
            .await
            .unwrap();
        let client = Client::with_options(options).unwrap();
        client.database("test_db")
    }

    #[tokio::test]
    async fn test_mongo_stream_creation() {
        let stream = MongoStream::new(setup_test_db().await);

        assert!(stream.collection_callbacks.is_empty());
        assert!(stream.active_streams.is_empty());
    }

    #[tokio::test]
    async fn test_add_callback() {
        let mut stream = MongoStream::new(setup_test_db().await);

        // The callback reports each invocation on this channel.
        let (fired_tx, mut fired_rx) = mpsc::channel(1);
        stream.add_callback("test_collection", Event::Insert, move |_doc| {
            let fired_tx = fired_tx.clone();
            Box::pin(async move {
                let _ = fired_tx.send(()).await;
            })
        });

        let registered = stream.get_collection_callbacks("test_collection");
        assert!(registered.contains_key(&Event::Insert));

        // Invoke the registered callback directly and confirm it signalled the channel.
        if let Some(on_insert) = registered.get(&Event::Insert) {
            let empty = Document::new();
            on_insert(&empty).await;
            assert!(fired_rx.recv().await.is_some());
        }
    }

    #[tokio::test]
    async fn test_start_and_close_stream() {
        let mut stream = MongoStream::new(setup_test_db().await);
        let name = "test_collection";

        assert!(stream.start_stream(name).await.is_ok());
        assert!(stream.active_streams.contains_key(name));

        assert!(stream.close_stream(name).await);
        assert!(!stream.active_streams.contains_key(name));
    }

    #[tokio::test]
    async fn test_close_all_streams() {
        let mut stream = MongoStream::new(setup_test_db().await);

        for name in ["collection1", "collection2"] {
            stream.start_stream(name).await.unwrap();
        }

        assert_eq!(stream.close_all_streams().await, 2);
        assert!(stream.active_streams.is_empty());
    }

    #[tokio::test]
    async fn test_default_callbacks() {
        let stream = MongoStream::new(setup_test_db().await);
        let defaults = stream.get_collection_callbacks("test_collection");

        assert_eq!(defaults.len(), 3);
        for event in [Event::Insert, Event::Update, Event::Delete] {
            assert!(defaults.contains_key(&event));
        }
    }

    #[tokio::test]
    async fn test_double_start_error() {
        let mut stream = MongoStream::new(setup_test_db().await);

        stream.start_stream("test_collection").await.unwrap();
        let second_start = stream.start_stream("test_collection").await;

        assert!(matches!(second_start, Err(MongoStreamError { .. })));
    }
}