#![allow(incomplete_features)]
#![feature(async_drop, impl_trait_in_assoc_type)]
use futures_util::stream::StreamExt;
use mongodb::{
Database,
bson::{Document, doc},
change_stream::event::{ChangeStreamEvent, OperationType},
options::FullDocumentBeforeChangeType,
};
use std::future::AsyncDrop;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
#[cfg(feature = "tracing")]
use tracing;
/// Guard that records every change made to a MongoDB database while it is alive
/// and attempts to revert those changes when it is dropped in an async context.
///
/// A background task pushes change-stream events into `collected_events`; the
/// `AsyncDrop` implementation drains that buffer and issues compensating writes.
/// The synchronous `Drop` implementation only stops the listener.
pub struct MongoDrop {
/// Handle to the watched database (a clone of the one passed to `new`), used
/// to issue the compensating operations during rollback.
database: Database,
/// Change events recorded by the listener task, shared with that task.
collected_events: Arc<Mutex<Vec<ChangeStreamEvent<Document>>>>,
/// One-shot sender used to tell the listener task to stop. Wrapped in `Option`
/// so either drop path can `take()` it and send at most once.
stop_sender: Option<oneshot::Sender<()>>,
/// Join handle of the spawned listener task. It is stored but never awaited
/// or aborted in the code visible here.
_listener_handle: JoinHandle<()>, }
// Explicit `Unpin` marker: `AsyncDrop::drop` receives `Pin<&mut Self>` and calls
// `Pin::get_mut`, which is only available when `Self: Unpin`.
impl Unpin for MongoDrop {}
impl MongoDrop {
pub async fn new(database: &Database) -> Result<Self, mongodb::error::Error> {
let stream = database
.watch()
.full_document_before_change(FullDocumentBeforeChangeType::WhenAvailable)
.await?;
let collected_events: Arc<Mutex<Vec<ChangeStreamEvent<Document>>>> =
Arc::new(Mutex::new(Vec::new()));
let events_clone = Arc::clone(&collected_events);
let (stop_sender, mut stop_receiver) = oneshot::channel();
let stop_sender = Some(stop_sender);
let listener_handle = tokio::spawn(async move {
let mut stream_fused = stream.fuse();
loop {
tokio::select! {
event_opt = stream_fused.next() => {
match event_opt {
Some(Ok(event)) => {
#[cfg(feature = "tracing")]
tracing::info!("Collected change event: {:?}", event.operation_type);
let mut events = events_clone.lock().unwrap();
events.push(event);
}
Some(Err(_e)) => {
#[cfg(feature = "tracing")]
tracing::error!("Change stream error: {:?}", _e);
break;
}
None => {
#[cfg(feature = "tracing")]
tracing::info!("Change stream finished.");
break;
}
}
},
_ = &mut stop_receiver => {
#[cfg(feature = "tracing")]
tracing::info!("Received stop signal for change stream listener.");
break;
},
}
}
});
Ok(MongoDrop {
database: database.clone(),
collected_events,
stop_sender,
_listener_handle: listener_handle,
})
}
}
impl AsyncDrop for MongoDrop {
async fn drop(self: Pin<&mut Self>) {
#[cfg(feature = "tracing")]
tracing::info!("Executing async_drop for MongoDrop...");
let this = self.get_mut();
if let Some(sender) = this.stop_sender.take() {
if let Err(_) = sender.send(()) {
#[cfg(feature = "tracing")]
tracing::info!("Failed to send stop signal to change stream listener.");
}
} else {
#[cfg(feature = "tracing")]
tracing::info!("Stop signal already sent or listener not initialized.");
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let events_to_process = {
let mut events = this.collected_events.lock().unwrap();
events.drain(..).collect::<Vec<_>>()
};
if events_to_process.is_empty() {
#[cfg(feature = "tracing")]
tracing::info!("No changes collected to undo in async_drop.");
return;
}
#[cfg(feature = "tracing")]
tracing::info!("Starting MongoDB rollback via async_drop...");
for event in events_to_process.into_iter() {
#[cfg(feature = "tracing")]
tracing::info!("Undoing change event: {:?}", event.operation_type);
match event.operation_type {
OperationType::Insert => {
if let Some(document_key) = event.document_key {
if let Some(id) = document_key.get("_id") {
let filter = doc! { "_id": id.clone() };
#[cfg(feature = "tracing")]
tracing::info!("Undoing Insert: Deleting document with _id: {:?}", id);
let collection = this
.database
.collection::<Document>(event.ns.unwrap().coll.unwrap().as_str());
if let Err(_e) = collection.delete_one(filter).await {
#[cfg(feature = "tracing")]
tracing::error!("Error undoing Insert: {:?}", _e);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!(
"Insert event missing _id in document_key, cannot undo."
);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!("Insert event with no document_key, cannot undo.");
}
}
OperationType::Delete => {
if let Some(full_document) = event.full_document_before_change {
#[cfg(feature = "tracing")]
tracing::info!("Undoing Delete: Re-inserting document.");
let collection = this
.database
.collection::<Document>(event.ns.unwrap().coll.unwrap().as_str());
if let Err(_e) = collection
.update_one(
doc! {"_id": full_document.get_object_id("_id").unwrap()},
doc! {"$set": full_document},
)
.upsert(true)
.await
{
#[cfg(feature = "tracing")]
tracing::error!("Error undoing Delete: {:?}", _e);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!(
"Delete event missing fullDocumentBeforeChange, cannot fully undo delete."
);
if let Some(_document_key) = event.document_key {
#[cfg(feature = "tracing")]
tracing::info!(
"Document key for un-undoable delete: {:?}",
_document_key
);
}
}
}
OperationType::Update => {
if let Some(document_key) = event.document_key {
if let Some(full_document_before) = event.full_document_before_change {
if let Some(id) = document_key.get("_id") {
let filter = doc! { "_id": id.clone() };
#[cfg(feature = "tracing")]
tracing::info!(
"Undoing Update: Replacing document with pre-update state for _id: {:?}",
id
);
let collection = this.database.collection::<Document>(
event.ns.unwrap().coll.unwrap().as_str(),
);
if let Err(_e) =
collection.replace_one(filter, full_document_before).await
{
#[cfg(feature = "tracing")]
tracing::error!("Error undoing Update: {:?}", _e);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!(
"Update event missing _id in document_key, cannot undo."
);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!(
"Update event missing fullDocumentBeforeChange, cannot fully undo update for key: {:?}",
document_key
);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!("Update event with no document_key, cannot undo.");
}
}
OperationType::Replace => {
if let Some(document_key) = event.document_key {
if let Some(full_document_before) = event.full_document_before_change {
if let Some(id) = document_key.get("_id") {
let filter = doc! { "_id": id.clone() };
#[cfg(feature = "tracing")]
tracing::info!(
"Undoing Replace: Replacing document with pre-replace state for _id: {:?}",
id
);
let collection = this.database.collection::<Document>(
event.ns.unwrap().coll.unwrap().as_str(),
);
if let Err(_e) =
collection.replace_one(filter, full_document_before).await
{
#[cfg(feature = "tracing")]
tracing::error!("Error undoing Replace: {:?}", _e);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!(
"Replace event missing _id in document_key, cannot undo."
);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!(
"Replace event missing fullDocumentBeforeChange, cannot fully undo replace for key: {:?}",
document_key
);
}
} else {
#[cfg(feature = "tracing")]
tracing::info!("Replace event with no document_key, cannot undo.");
}
}
_op_type => {
#[cfg(feature = "tracing")]
tracing::info!(
"Unhandled change stream operation type during async_drop: {:?}",
_op_type
);
}
}
}
#[cfg(feature = "tracing")]
tracing::info!("MongoDB rollback via async_drop finished.");
}
}
impl Drop for MongoDrop {
    /// Synchronous drop path: tells the listener task to stop and nothing more.
    ///
    /// No database operations are issued here, so changes recorded by the
    /// listener are not reverted on this path; a warning is logged to that
    /// effect when the `tracing` feature is enabled.
    fn drop(&mut self) {
        #[cfg(feature = "tracing")]
        tracing::info!("Executing sync drop for MongoDrop - stopping change stream listener");

        match self.stop_sender.take() {
            Some(stop_tx) => {
                // A send error means the receiving side is already gone.
                if stop_tx.send(()).is_err() {
                    #[cfg(feature = "tracing")]
                    tracing::info!(
                        "Failed to send stop signal to change stream listener in sync drop"
                    );
                }
            }
            None => {
                #[cfg(feature = "tracing")]
                tracing::info!("Stop signal already sent (possibly by AsyncDrop)");
            }
        }

        #[cfg(feature = "tracing")]
        tracing::warn!(
            "MongoDrop dropped in sync context - database changes will NOT be rolled back. Use in async context for automatic rollback."
        );
    }
}