live-entity 0.0.7

A framework for storing entities and watching them for updates.
Documentation
use crate::{Entity, Event, NotFoundError, Store};
use async_trait::async_trait;
use futures_util::{StreamExt, TryStreamExt};
use mongodb::bson::{doc, from_bson, from_document, to_bson, to_document, Document};
use mongodb::change_stream::event::{ChangeStreamEvent, OperationType};
use mongodb::options::{ChangeStreamOptions, ClientOptions, FullDocumentType};
use mongodb::{Client, Database};
use std::error::Error;
use std::fmt::Formatter;
use tokio::sync::broadcast::Sender;

/// [`Store`] implementation that persists entities in a MongoDB database.
///
/// Every entity type is kept in its own collection, whose name is the type's
/// `E::TYPE_NAME`. Cloning the store clones the wrapped `Database` handle.
#[derive(Clone)]
pub struct MongoDBStore {
    /// Database that holds one collection per entity type.
    db: Database,
}

impl MongoDBStore {
    pub async fn new(
        connection_string: String,
        database_name: String,
        app_name: Option<String>,
    ) -> Result<Self, mongodb::error::Error> {
        let mut options = ClientOptions::parse(connection_string).await?;
        options.app_name = app_name;
        let client = Client::with_options(options);
        client
            .map(|c| c.database(&database_name))
            .map(|db| Self { db })
    }

    pub async fn delete_filtered<E: Entity>(
        &self,
        filter: Option<Document>,
    ) -> Result<(), Box<dyn Error>> {
        let collection = self.db.collection::<E>(E::TYPE_NAME);
        collection
            .delete_many(filter.unwrap_or(doc! {}), None)
            .await?;
        Ok(())
    }

    pub async fn get_filtered<E: Entity>(
        &self,
        filter: Option<Document>,
    ) -> Result<Vec<E>, Box<dyn Error>> {
        let collection = self.db.collection::<E>(E::TYPE_NAME);
        let res = collection.find(filter, None).await?;
        Ok(res.try_collect().await?)
    }

    pub async fn watch_filtered<E: Entity>(
        &self,
        channel: Sender<Event<E>>,
        filter: Option<Document>,
    ) -> Result<(), Box<dyn Error>> {
        let collection = self.db.collection::<Document>(E::TYPE_NAME);
        let mut mtch = doc! { "$match": {
            "operationType": {
                "$in": to_bson(&[OperationType::Update, OperationType::Insert, OperationType::Delete, OperationType::Replace])?
            }
        } };
        if let Some(f) = filter {
            for (k, v) in f {
                mtch.insert(&format!("fullDocument.{}", k), v);
            }
        }
        let options = ChangeStreamOptions::builder()
            .full_document(Some(FullDocumentType::UpdateLookup))
            .build();
        let mut watch = collection.watch([mtch], options).await?;
        while let Some(evt) = watch.next().await.transpose()? {
            match evt.operation_type {
                OperationType::Insert => {
                    let doc = evt.full_document.ok_or(MongoDBContractViolationError(
                        "MongoDB did not provide full document on insert event".to_owned(),
                    ))?;
                    let entity = from_document(doc)?;
                    channel.send(Event::Create(entity))?;
                }
                OperationType::Update => {
                    let id = get_id_from_change_event::<E>(&evt)?;
                    let doc = evt
                        .update_description
                        .ok_or(MongoDBContractViolationError(
                            "MongoDB did not provide update description on update event".to_owned(),
                        ))?
                        .updated_fields;
                    let update: E::Update = from_document(doc)?;
                    channel.send(Event::Update { id, update })?;
                }
                OperationType::Delete => {
                    let id = get_id_from_change_event::<E>(&evt)?;
                    channel.send(Event::Delete(id))?;
                }
                OperationType::Replace => {
                    let id = get_id_from_change_event::<E>(&evt)?;
                    let doc = evt.full_document.ok_or(MongoDBContractViolationError(
                        "MongoDB did not provide full document on replace event".to_owned(),
                    ))?;
                    let update: E::Update = from_document(doc)?;
                    channel.send(Event::Update { id, update })?;
                }
                _ => {
                    return Err(MongoDBContractViolationError(format!(
                        "MongoDB returned an event type that was filtered out: {:?}.",
                        evt.operation_type
                    ))
                    .into())
                }
            }
        }
        Ok(())
    }
}

/// Wraps an existing MongoDB [`Database`] handle in a [`MongoDBStore`].
///
/// Implementing `From` (instead of `Into`) also provides
/// `Into<MongoDBStore> for Database` through the standard library's blanket
/// impl, so `let store: MongoDBStore = db.into();` keeps working.
impl From<Database> for MongoDBStore {
    fn from(db: Database) -> Self {
        MongoDBStore { db }
    }
}

/// Error reported when a MongoDB response lacks data this store depends on
/// (for example a change event without a document key), or contains an event
/// kind that the change-stream filter should have excluded.
///
/// The wrapped string is the message printed by `Display`.
#[derive(Debug)]
pub struct MongoDBContractViolationError(String);

impl std::fmt::Display for MongoDBContractViolationError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Written verbatim: width/precision flags are intentionally not applied.
        let Self(message) = self;
        f.write_str(message)
    }
}

impl Error for MongoDBContractViolationError {}

#[async_trait]
impl Store for MongoDBStore {
    /// Inserts `entity` into the collection named `E::TYPE_NAME`.
    ///
    /// The entity's `get_id()` value is stored as the document's `_id`. Because
    /// MongoDB requires `_id` to be unique, creating a second entity with the
    /// same id fails with the driver's error.
    async fn create<E: Entity>(&self, entity: &E) -> Result<(), Box<dyn Error>> {
        let collection = self.db.collection::<Document>(E::TYPE_NAME);
        let mut doc = to_document(entity)?;
        doc.insert("_id", to_bson(entity.get_id())?);
        collection.insert_one(doc, None).await?;
        Ok(())
    }

    /// Applies `update` to the entity whose `_id` equals `id`.
    ///
    /// The serialized update is sent as a single `$set` stage of a
    /// pipeline-style update. Updating an id that does not exist is not an
    /// error: nothing is modified and `Ok(())` is returned.
    ///
    /// NOTE(review): in pipeline updates, MongoDB evaluates string values that
    /// start with `$` as expressions rather than storing them literally. Confirm
    /// `E::Update` can never contain such strings, or use a classic update
    /// document.
    async fn update<E: Entity>(
        &self,
        id: &E::ID,
        update: &E::Update,
    ) -> Result<(), Box<dyn Error>> {
        let collection = self.db.collection::<E>(E::TYPE_NAME);
        let query = doc! { "_id": to_bson(id)? };
        let pipeline = vec![doc! {
            "$set": to_document(update)?
        }];
        collection.update_one(query, pipeline, None).await?;
        Ok(())
    }

    /// Removes every entity of type `E`.
    async fn delete_all<E: Entity>(&self) -> Result<(), Box<dyn Error>> {
        self.delete_filtered::<E>(None).await
    }

    /// Removes the entity whose `_id` equals `id`; a missing id is not an error.
    async fn delete_by_id<E: Entity>(&self, id: &E::ID) -> Result<(), Box<dyn Error>> {
        let collection = self.db.collection::<E>(E::TYPE_NAME);
        let query = doc! { "_id": to_bson(id)? };
        collection.delete_one(query, None).await?;
        Ok(())
    }

    /// Loads every entity of type `E`.
    async fn get_all<E: Entity>(&self) -> Result<Vec<E>, Box<dyn Error>> {
        self.get_filtered(None).await
    }

    /// Loads the entity whose `_id` equals `id`.
    ///
    /// # Errors
    /// Returns a `NotFoundError` carrying `id` when no such entity exists, or
    /// the driver/deserialization error if the lookup fails.
    async fn get_by_id<E: Entity>(&self, id: &E::ID) -> Result<E, Box<dyn Error>> {
        let collection = self.db.collection::<E>(E::TYPE_NAME);
        let query = doc! { "_id": to_bson(id)? };
        collection
            .find_one(query, None)
            .await?
            // Lazily build the error so the id is only cloned on a miss.
            .ok_or_else(|| NotFoundError(id.clone()).into())
    }

    /// Streams every change of entities of type `E` into `channel`.
    async fn watch<E: Entity>(&self, channel: Sender<Event<E>>) -> Result<(), Box<dyn Error>> {
        self.watch_filtered(channel, None).await
    }
}

/// Extracts and deserializes the `_id` of the document a change event refers to.
///
/// # Errors
/// Returns a [`MongoDBContractViolationError`] if the event has no
/// `document_key` or the key has no `_id` field. Returns the BSON
/// deserialization error if the `_id` cannot be converted into `E::ID`.
fn get_id_from_change_event<E: Entity>(
    event: &ChangeStreamEvent<Document>,
) -> Result<E::ID, Box<dyn Error>> {
    // Errors are built lazily so the happy path allocates no messages.
    let document_key = event.document_key.as_ref().ok_or_else(|| {
        MongoDBContractViolationError(
            "MongoDB did not provide document key on change event".to_owned(),
        )
    })?;
    let raw_id = document_key.get("_id").cloned().ok_or_else(|| {
        MongoDBContractViolationError("MongoDB provided no _id on document key".to_owned())
    })?;
    Ok(from_bson(raw_id)?)
}