rs_mongo_stream/
stream.rs

//! Core functionality for MongoDB change stream processing.
//!
//! This module provides the main `MongoStream` type, which lets callers subscribe
//! to MongoDB change events with custom callbacks.

6use mongodb::Database;
7use std::{
8    collections::HashMap,
9    future::Future,
10    pin::Pin,
11    sync::Arc,
12    task::{Context, Poll},
13};
14use tokio::sync::mpsc;
15use tokio_stream::{Stream, StreamExt};
16
17use crate::error::MongoStreamError;
18use crate::event::Event;
19
20/// Type alias for a callback function that processes MongoDB events.
21///
22/// Callbacks should be async functions that take an Event reference
23/// and return nothing.
24type CallbackFn = dyn Fn(&Event) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync;
25
26/// A collection of callbacks mapped to their respective event types.
27///
28/// This type maps Event types to their corresponding callback functions.
29pub type Callbacks = HashMap<Event, Arc<CallbackFn>>;
30
31/// The main wrapper around MongoDB change streams.
32///
33/// `MongoStream` provides an interface for subscribing to MongoDB change events
34/// with custom callbacks for different event types per collection.
35pub struct MongoStream {
36    /// The MongoDB database to monitor
37    db: Database,
38    /// Maps collection names to their registered callbacks
39    collection_callbacks: HashMap<String, Callbacks>,
40    /// Active streams that can be closed
41    active_streams: HashMap<String, tokio::sync::mpsc::Sender<()>>,
42}
43
44impl MongoStream {
45    /// Creates a new MongoStream instance for the given database.
46    ///
47    /// # Arguments
48    ///
49    /// * `db` - The MongoDB database to monitor.
50    ///
51    /// # Returns
52    ///
53    /// A new `MongoStream` instance configured to work with the provided database.
54    pub fn new(db: Database) -> Self {
55        Self {
56            db,
57            collection_callbacks: HashMap::new(),
58            active_streams: HashMap::new(),
59        }
60    }
61
62    /// Registers a callback for a specific event type on a collection.
63    ///
64    /// # Arguments
65    ///
66    /// * `collection_name` - The name of the MongoDB collection to monitor.
67    /// * `event` - The event type to subscribe to.
68    /// * `callback` - The async function to call when the event occurs.
69    ///
70    /// # Example
71    ///
72    /// ```rust,ignore
73    /// use MongoStream::Event;
74    ///
75    /// // load the mongodb database correctly
76    /// let db = Database{};
77    /// let mut mongo_stream = MongoStream::new(db: Database);
78    /// mongo_stream.add_callback("users", Event::Insert, |event| {
79    ///     Box::pin(async move {
80    ///         println!("New user inserted: {:?}", event);
81    ///     })
82    /// });
83    /// ```
84    pub fn add_callback<F>(&mut self, collection_name: impl Into<String>, event: Event, callback: F)
85    where
86        F: Fn(&Event) -> Pin<Box<dyn Future<Output = ()> + Send>> + 'static + Send + Sync,
87    {
88        let collection_name = collection_name.into();
89        let callback_arc = Arc::new(callback);
90        let default_callbacks = self.create_default_callbacks();
91
92        self.collection_callbacks
93            .entry(collection_name)
94            .or_insert_with(|| default_callbacks)
95            .insert(event, callback_arc);
96    }
97
98    /// Creates a set of default callbacks for all event types.
99    ///
100    /// These callbacks simply log that an event was received.
101    ///
102    /// # Returns
103    ///
104    /// A HashMap mapping event types to their default callbacks.
105    fn create_default_callbacks(&self) -> Callbacks {
106        let mut callbacks: Callbacks = HashMap::new();
107
108        // Create default handlers for each event type
109        for event in [Event::Insert, Event::Update, Event::Delete] {
110            let event_name = event.event_type_str().to_string();
111
112            callbacks.insert(
113                event,
114                Arc::new(move |_event: &Event| {
115                    let event_name = event_name.clone();
116                    Box::pin(async move {
117                        println!("{} event received", event_name);
118                    })
119                }),
120            );
121        }
122
123        callbacks
124    }
125
126    /// Retrieves the callbacks for a specific collection.
127    ///
128    /// If no callbacks are registered for the collection, default callbacks are returned.
129    ///
130    /// # Arguments
131    ///
132    /// * `collection_name` - The name of the collection to get callbacks for.
133    ///
134    /// # Returns
135    ///
136    /// A HashMap mapping event types to their registered callbacks.
137    fn get_collection_callbacks(&self, collection_name: &str) -> Callbacks {
138        match self.collection_callbacks.get(collection_name) {
139            Some(callbacks) => {
140                // Clone the callbacks map
141                callbacks
142                    .iter()
143                    .map(|(event, callback)| (*event, Arc::clone(callback)))
144                    .collect()
145            }
146            None => self.create_default_callbacks(),
147        }
148    }
149
150    /// Starts monitoring a collection for changes.
151    ///
152    /// This method will block and continuously process events from the MongoDB change stream
153    /// until an error occurs or the stream is closed.
154    ///
155    /// # Arguments
156    ///
157    /// * `collection_name` - The name of the collection to monitor.
158    ///
159    /// # Returns
160    ///
161    /// A Result that is Ok if the stream was closed normally, or an error if something went wrong.
162    ///
163    /// # Example
164    ///
165    /// ```rust,ignore
166    ///
167    /// async fn monitor() -> Result<(), MongoStreamError> {
168    ///     let mongo_stream = MongoStream::new(db);
169    ///     // ... register callbacks ...
170    ///     mongo_stream.start_stream("users").await
171    /// }
172    /// ```
173    /// Starts monitoring a collection for changes.
174    ///
175    /// This method will spawn a new task to monitor the collection and return immediately.
176    /// To stop the stream, call `close_stream` with the same collection name.
177    ///
178    /// # Arguments
179    ///
180    /// * `collection_name` - The name of the collection to monitor.
181    ///
182    /// # Returns
183    ///
184    /// A Result that is Ok if the stream was started successfully, or an error if something went wrong.
185    ///
186    /// # Example
187    ///
188    /// ```rust,ignore
189    /// async fn start_monitoring() -> Result<(), MongoStreamError> {
190    ///     let mongo_stream = MongoStream::new(db);
191    ///     // ... register callbacks ...
192    ///     mongo_stream.start_stream("users").await?;
193    ///     // The stream is now running in the background
194    ///     
195    ///     // Later, when you want to stop it:
196    ///     mongo_stream.close_stream("users").await;
197    ///     Ok(())
198    /// }
199    /// ```
200    pub async fn start_stream(&mut self, collection_name: &str) -> Result<(), MongoStreamError> {
201        // Check if the stream is already running
202        if self.active_streams.contains_key(collection_name) {
203            return Err(MongoStreamError::new(format!(
204                "Stream for collection '{}' is already running",
205                collection_name
206            )));
207        }
208
209        let collection = self
210            .db
211            .collection::<mongodb::bson::Document>(collection_name);
212
213        // Create a channel to signal stream closure
214        let (tx, mut rx) = mpsc::channel::<()>(1);
215
216        // Store the sender in the active_streams map
217        self.active_streams.insert(collection_name.to_string(), tx);
218
219        // Get the callbacks
220        let callbacks = self.get_collection_callbacks(collection_name);
221        let collection_name = collection_name.to_string();
222        let db = self.db.clone();
223
224        // Spawn a new task to handle the stream
225        tokio::spawn(async move {
226            let mut stream = match collection.watch(None, None).await {
227                Ok(s) => s,
228                Err(e) => {
229                    eprintln!(
230                        "Failed to start stream for collection '{}': {}",
231                        collection_name, e
232                    );
233                    return;
234                }
235            };
236
237            loop {
238                tokio::select! {
239                    // Check if we've received a signal to close the stream
240                    _ = rx.recv() => {
241                        println!("Closing stream for collection '{}'", collection_name);
242                        break;
243                    }
244                    // Process the next event from the stream
245                    next_event = stream.next() => {
246                        match next_event {
247                            Some(result) => {
248                                match result {
249                                    Ok(change_stream_event) => {
250                                        let event_type = Event::from(change_stream_event.operation_type);
251                                        // Find and execute the appropriate callback
252                                        if let Some(callback) = callbacks.get(&event_type) {
253                                            callback(&event_type).await;
254                                        }
255                                    }
256                                    Err(e) => {
257                                        eprintln!("Error in MongoDB stream for collection '{}': {}. Reconnecting...", collection_name, e);
258                                        // Graceful reconnection
259                                        match db.collection::<mongodb::bson::Document>(&collection_name).watch(None, None).await {
260                                            Ok(new_stream) => stream = new_stream,
261                                            Err(reconnect_err) => {
262                                                eprintln!("Failed to reconnect to collection '{}': {}", collection_name, reconnect_err);
263                                                break;
264                                            }
265                                        }
266                                    }
267                                }
268                            }
269                            None => {
270                                // Stream has ended normally
271                                println!("Stream for collection '{}' has ended", collection_name);
272                                break;
273                            }
274                        }
275                    }
276                }
277            }
278        });
279
280        Ok(())
281    }
282
283    /// Closes a specific stream for a collection.
284    ///
285    /// This method sends a signal to stop the monitoring task for a specific collection.
286    ///
287    /// # Arguments
288    ///
289    /// * `collection_name` - The name of the collection whose stream should be closed.
290    ///
291    /// # Returns
292    ///
293    /// A boolean indicating whether a stream was found and closed. Returns `false` if
294    /// there was no active stream for the given collection.
295    ///
296    /// # Example
297    ///
298    /// ```rust,ignore
299    /// use rs_mongo_stream::MongoStream;
300    ///
301    /// async fn example() {
302    ///     let mongo_stream = MongoStream::new(db);
303    ///     // ... start streams ...
304    ///     
305    ///     // When done with a specific collection
306    ///     let closed = mongo_stream.close_stream("users").await;
307    ///     println!("Stream closed: {}", closed);
308    /// }
309    /// ```
310    pub async fn close_stream(&mut self, collection_name: &str) -> bool {
311        if let Some(tx) = self.active_streams.remove(collection_name) {
312            // Send signal to close the stream
313            // It's okay if this fails - the receiver might have been dropped already
314            let _ = tx.send(()).await;
315            true
316        } else {
317            false
318        }
319    }
320
321    /// Closes all active streams.
322    ///
323    /// This method sends signals to stop all monitoring tasks and clears the active streams list.
324    ///
325    /// # Returns
326    ///
327    /// The number of streams that were closed.
328    ///
329    /// # Example
330    ///
331    /// ```rust,ignore
332    /// async fn shutdown() {
333    ///     let mongo_stream = MongoStream::new(db);
334    ///     // ... start streams ...
335    ///     
336    ///     // When shutting down the application
337    ///     let closed_count = mongo_stream.close_all_streams().await;
338    ///     println!("Closed {} streams", closed_count);
339    /// }
340    /// ```
341    pub async fn close_all_streams(&mut self) -> usize {
342        let count = self.active_streams.len();
343
344        // Send close signal to all active streams
345        for (collection_name, tx) in self.active_streams.drain() {
346            let _ = tx.send(()).await;
347            println!(
348                "Sent close signal to stream for collection '{}'",
349                collection_name
350            );
351        }
352
353        count
354    }
355}
356
357/// Implementation of Stream trait for MongoStream to make it compatible with tokio_stream.
358///
359/// Note: This is a placeholder implementation and would need to be expanded
360/// in a real-world scenario to connect to actual MongoDB change streams.
361impl Stream for MongoStream {
362    type Item = Result<Event, MongoStreamError>;
363
364    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
365        // Placeholder implementation for the Stream trait
366        // In a real implementation, this would be connected to MongoDB change streams
367        Poll::Ready(Some(Ok(Event::Insert)))
368    }
369}
370
#[cfg(test)]
mod tests {
    // These tests build a client for `mongodb://localhost:27017` and the `test_db` database.
    use super::*;
    use mongodb::{options::ClientOptions, Client};
    use tokio::sync::mpsc;

    /// Returns a handle to the `test_db` database on the local MongoDB instance.
    async fn setup_test_db() -> Database {
        let client_options = ClientOptions::parse("mongodb://localhost:27017")
            .await
            .unwrap();
        let client = Client::with_options(client_options).unwrap();
        client.database("test_db")
    }

    #[tokio::test]
    async fn test_mongo_stream_creation() {
        let db = setup_test_db().await;
        let mongo_stream = MongoStream::new(db.clone());

        assert!(mongo_stream.collection_callbacks.is_empty());
        assert!(mongo_stream.active_streams.is_empty());
    }

    #[tokio::test]
    async fn test_add_callback() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        let (tx, mut rx) = mpsc::channel(1);
        mongo_stream.add_callback("test_collection", Event::Insert, move |_event| {
            let tx = tx.clone();
            Box::pin(async move {
                let _ = tx.send(()).await;
            })
        });

        let callbacks = mongo_stream.get_collection_callbacks("test_collection");
        // The first registration seeds the collection with the default callbacks.
        assert_eq!(callbacks.len(), 3);
        assert!(callbacks.contains_key(&Event::Update));
        assert!(callbacks.contains_key(&Event::Delete));

        // Fail loudly instead of silently skipping if the callback is missing.
        let callback = callbacks
            .get(&Event::Insert)
            .expect("Insert callback should be registered");
        callback(&Event::Insert).await;
        assert!(rx.recv().await.is_some());
    }

    #[tokio::test]
    async fn test_start_and_close_stream() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        // Test starting stream
        assert!(mongo_stream.start_stream("test_collection").await.is_ok());
        assert!(mongo_stream.active_streams.contains_key("test_collection"));

        // Test closing stream
        assert!(mongo_stream.close_stream("test_collection").await);
        assert!(!mongo_stream.active_streams.contains_key("test_collection"));

        // Nothing is left to close the second time.
        assert!(!mongo_stream.close_stream("test_collection").await);
    }

    #[tokio::test]
    async fn test_restart_after_close() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        mongo_stream.start_stream("test_collection").await.unwrap();
        assert!(mongo_stream.close_stream("test_collection").await);

        // Closing removes the entry, so the same collection can be started again.
        assert!(mongo_stream.start_stream("test_collection").await.is_ok());
        assert!(mongo_stream.close_stream("test_collection").await);
    }

    #[tokio::test]
    async fn test_close_all_streams() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        mongo_stream.start_stream("collection1").await.unwrap();
        mongo_stream.start_stream("collection2").await.unwrap();

        let closed_count = mongo_stream.close_all_streams().await;
        assert_eq!(closed_count, 2);
        assert!(mongo_stream.active_streams.is_empty());

        // The registry was drained, so a second call has nothing to close.
        assert_eq!(mongo_stream.close_all_streams().await, 0);
    }

    #[tokio::test]
    async fn test_default_callbacks() {
        let db = setup_test_db().await;
        let mongo_stream = MongoStream::new(db);

        let callbacks = mongo_stream.get_collection_callbacks("test_collection");
        assert_eq!(callbacks.len(), 3);
        assert!(callbacks.contains_key(&Event::Insert));
        assert!(callbacks.contains_key(&Event::Update));
        assert!(callbacks.contains_key(&Event::Delete));
    }

    #[tokio::test]
    async fn test_double_start_error() {
        let db = setup_test_db().await;
        let mut mongo_stream = MongoStream::new(db);

        mongo_stream.start_stream("test_collection").await.unwrap();
        let result = mongo_stream.start_stream("test_collection").await;

        assert!(matches!(result, Err(MongoStreamError { .. })));

        // Stop the background task started above instead of leaking it.
        assert!(mongo_stream.close_stream("test_collection").await);
    }
}