samod 0.9.0

A rust library for managing automerge documents, compatible with the js automerge-repo library
use std::collections::HashMap;
use std::sync::Arc;

use automerge::{Automerge, ReadDoc};
use futures::{Stream, channel::mpsc};
use samod_core::{
    ConnectionId, DocumentActorId, DocumentChanged, DocumentId, SyncDirection, SyncMessageStat,
    UnixTimestamp,
    actors::{
        DocToHubMsg,
        document::{DocActorResult, DocumentActor, WithDocResult},
    },
};

use crate::{
    actor_task::ActorTask,
    io_loop::{self, IoLoopTask},
    observer::{self, RepoObserver},
    peer_connection_info::PeerDocState,
    unbounded::UnboundedSender,
};

pub(crate) struct DocActorInner {
    document_id: DocumentId,
    actor_id: DocumentActorId,
    tx_to_core: UnboundedSender<(DocumentActorId, DocToHubMsg)>,
    tx_io: UnboundedSender<io_loop::IoLoopTask>,
    ephemera_listeners: Vec<mpsc::UnboundedSender<Vec<u8>>>,
    change_listeners: Vec<mpsc::UnboundedSender<DocumentChanged>>,
    peer_state_change_listeners: Vec<mpsc::UnboundedSender<HashMap<ConnectionId, PeerDocState>>>,
    actor: DocumentActor,
    observer: Option<Arc<dyn RepoObserver>>,
}

impl DocActorInner {
    pub(crate) fn new(
        document_id: DocumentId,
        actor_id: DocumentActorId,
        actor: DocumentActor,
        tx_to_core: UnboundedSender<(DocumentActorId, DocToHubMsg)>,
        tx_io: UnboundedSender<io_loop::IoLoopTask>,
        observer: Option<Arc<dyn RepoObserver>>,
    ) -> Self {
        DocActorInner {
            document_id,
            actor_id,
            tx_to_core,
            tx_io,
            ephemera_listeners: Vec::new(),
            change_listeners: Vec::new(),
            peer_state_change_listeners: Vec::new(),
            actor,
            observer,
        }
    }

    pub(crate) fn document(&self) -> &Automerge {
        self.actor.document()
    }

    pub(crate) fn with_document<F, R>(&mut self, f: F) -> R
    where
        F: FnOnce(&mut Automerge) -> R,
    {
        let WithDocResult {
            actor_result,
            value,
        } = self.actor.with_document(UnixTimestamp::now(), f).unwrap();

        self.handle_results(actor_result);

        value
    }

    pub(crate) fn create_ephemera_listener(&mut self) -> mpsc::UnboundedReceiver<Vec<u8>> {
        let (tx, rx) = mpsc::unbounded();
        self.ephemera_listeners.push(tx);
        rx
    }

    pub(crate) fn create_change_listener(&mut self) -> mpsc::UnboundedReceiver<DocumentChanged> {
        let (tx, rx) = mpsc::unbounded();
        self.change_listeners.push(tx);
        rx
    }

    pub(crate) fn broadcast_ephemeral_message(&mut self, message: Vec<u8>) {
        let result = self.actor.broadcast(UnixTimestamp::now(), message);
        self.handle_results(result);
    }

    pub(crate) fn handle_results(&mut self, results: DocActorResult) {
        let DocActorResult {
            io_tasks,
            outgoing_messages,
            ephemeral_messages,
            change_events,
            stopped,
            peer_state_changes,
            sync_message_stats,
            pending_sync_messages,
        } = results;

        let peer_count = self.actor.peers().len();
        let stats = self.actor.document().stats();
        log_sync_message_stats(
            &self.document_id,
            &sync_message_stats,
            peer_count,
            stats.num_ops,
            stats.num_changes,
        );

        if let Some(ref obs) = self.observer {
            for stat in &sync_message_stats {
                let event = match stat.direction {
                    SyncDirection::Received => observer::RepoEvent::SyncMessageReceived {
                        document_id: self.document_id.clone(),
                        connection_id: stat.connection_id,
                        bytes: stat.bytes,
                        duration: stat.duration,
                        queue_duration: stat.queue_duration,
                    },
                    SyncDirection::Generated => observer::RepoEvent::SyncMessageGenerated {
                        document_id: self.document_id.clone(),
                        connection_id: stat.connection_id,
                        bytes: stat.bytes,
                        duration: stat.duration,
                    },
                };
                obs.observe(&event);
            }
            if pending_sync_messages > 0 {
                obs.observe(&observer::RepoEvent::DocumentPendingSyncMessages {
                    document_id: self.document_id.clone(),
                    count: pending_sync_messages,
                });
            }
            if stopped {
                obs.observe(&observer::RepoEvent::DocumentClosed {
                    document_id: self.document_id.clone(),
                });
            }
        }
        for task in io_tasks {
            if let Err(_e) = self.tx_io.unbounded_send(IoLoopTask::Storage {
                doc_id: self.document_id.clone(),
                task,
                actor_id: self.actor_id,
            }) {
                tracing::error!("io receiver dropped whilst document actor is still running");
                return;
            }
        }

        for msg in outgoing_messages {
            if let Err(_e) = self.tx_to_core.unbounded_send((self.actor_id, msg)) {
                tracing::error!("core receiver dropped whilst document actor is still running");
                return;
            }
        }

        if !ephemeral_messages.is_empty() {
            self.ephemera_listeners.retain_mut(|listener| {
                for msg in &ephemeral_messages {
                    if listener.unbounded_send(msg.clone()).is_err() {
                        return false;
                    }
                }
                true
            });
        }

        if !change_events.is_empty() {
            self.change_listeners.retain_mut(|listener| {
                for change in &change_events {
                    if listener.unbounded_send(change.clone()).is_err() {
                        return false;
                    }
                }
                true
            });
        }

        if !peer_state_changes.is_empty() {
            let changes = peer_state_changes
                .into_iter()
                .map(|(k, v)| (k, v.into()))
                .collect::<HashMap<_, _>>();
            self.peer_state_change_listeners
                .retain_mut(move |listener| {
                    if listener.unbounded_send(changes.clone()).is_err() {
                        return false;
                    }
                    true
                });
        }
    }

    pub(crate) fn handle_task(&mut self, task: ActorTask) {
        let result = match task {
            ActorTask::HandleMessage(samod_to_actor_message) => self
                .actor
                .handle_message(UnixTimestamp::now(), samod_to_actor_message),
            ActorTask::IoComplete(io_result) => self
                .actor
                .handle_io_complete(UnixTimestamp::now(), io_result),
        };
        self.handle_results(result.unwrap());
    }

    pub(crate) fn is_stopped(&self) -> bool {
        self.actor.is_stopped()
    }

    pub(crate) fn peers(
        &mut self,
    ) -> (
        HashMap<ConnectionId, PeerDocState>,
        impl Stream<Item = HashMap<ConnectionId, PeerDocState>> + 'static + use<>,
    ) {
        let peers = self.actor.peers();
        let (tx, rx) = mpsc::unbounded();
        self.peer_state_change_listeners.push(tx);
        let peers = peers.into_iter().map(|(k, v)| (k, v.into())).collect();
        (peers, rx)
    }
}

fn log_sync_message_stats(
    doc_id: &DocumentId,
    sync_message_stats: &[SyncMessageStat],
    peer_count: usize,
    num_ops: u64,
    num_changes: u64,
) {
    for stat in sync_message_stats {
        let direction = match stat.direction {
            SyncDirection::Received => "received",
            SyncDirection::Generated => "generated",
        };
        tracing::debug!(
            document_id = %doc_id,
            connection_id = ?stat.connection_id,
            direction,
            bytes = stat.bytes,
            duration_us = stat.duration.as_micros() as u64,
            queue_duration_us = stat.queue_duration.as_micros() as u64,
            peer_count,
            num_ops,
            num_changes,
            "sync_message"
        );
    }
}