nodo_cask 0.18.5

Message recording to MCAP for NODO
Documentation
// Copyright 2024 David Weikersdorfer

use crate::{CaskWriter, CaskWriterConfig};
use mcap::McapError;
use nodo::{
    channels::Pop,
    prelude::{DoubleBufferRx, Message as NodoMessage},
};
use serde::Serialize;
use std::{
    collections::HashMap,
    io,
    ops::{Add, AddAssign},
    path::{Path, PathBuf},
    sync::{Arc, Mutex, MutexGuard},
};
use thiserror::Error;

/// Access to cask recording controls which can be shared across threads and codelets.
///
/// Cloning is cheap: every clone shares the same `Arc<Mutex<_>>`-protected
/// state, so all clones observe and control the same set of recordings.
#[derive(Clone)]
pub struct SharedCaskControl(Arc<Mutex<CaskControl>>);

impl SharedCaskControl {
    /// Create a new control codelet which writes casks to the given directory
    pub fn new(workdir: &Path) -> Self {
        Self(Arc::new(Mutex::new(CaskControl {
            workdir: workdir.into(),
            recordings: HashMap::new(),
        })))
    }

    /// Starts a new recording
    pub fn start_recording<S: Into<String>>(
        &self,
        id: S,
        config: &CaskWriterConfig,
    ) -> Result<PathBuf, CaskControlError> {
        self.0.lock().unwrap().start_recording(id.into(), config)
    }

    /// Stops a previously started recording
    pub fn stop_recording(&self, id: &str) -> Result<PathBuf, CaskControlError> {
        self.0.lock().unwrap().stop_recording(id)
    }

    /// Stops all currently active recordings
    pub fn stop_all(&self) {
        self.0.lock().unwrap().stop_all()
    }

    /// Counts the number of active writers
    pub fn session_count(&self) -> usize {
        self.0.lock().unwrap().session_count()
    }

    /// Appends keys of active writers to given vector
    pub fn session_keys(&self) -> Vec<String> {
        self.0
            .lock()
            .unwrap()
            .iter_session_keys()
            .map(|s| s.into())
            .collect()
    }

    pub(crate) fn sink<'a>(&'a mut self) -> Option<CaskMessageSink<'a>> {
        let inner = self.0.lock().unwrap();

        if inner.recordings.is_empty() {
            None
        } else {
            Some(CaskMessageSink(inner))
        }
    }
}

/// Write handle over the locked cask control. Holding it keeps the mutex
/// guard alive, so no recording can be started or stopped mid-batch.
pub struct CaskMessageSink<'a>(MutexGuard<'a, CaskControl>);

impl CaskMessageSink<'_> {
    /// Writes a message to all currently active recordings.
    ///
    /// Returns the number of recordings written to and the total bytes written.
    ///
    /// # Errors
    /// Propagates the first write error; recordings later in iteration order
    /// are skipped once an error occurs.
    pub fn write_msg<T>(
        &mut self,
        topic: &str,
        msg: &NodoMessage<T>,
    ) -> Result<WriteMsgStats, CaskControlError>
    where
        T: Serialize,
    {
        let mut stats = WriteMsgStats::default();
        // Only the writers are needed; `values_mut` avoids discarding keys.
        for rec in self.0.recordings.values_mut() {
            stats.count += 1;
            stats.size += rec.write_message(topic, msg)?;
        }
        Ok(stats)
    }

    /// Writes all messages from a channel to all currently active recordings.
    ///
    /// # Errors
    /// Propagates the first write error; messages already popped from the
    /// channel at that point are not re-queued.
    pub fn write_msg_rx<T>(
        &mut self,
        topic: &str,
        rx: &mut DoubleBufferRx<NodoMessage<T>>,
    ) -> Result<WriteMsgStats, CaskControlError>
    where
        T: Serialize,
    {
        let mut stats = WriteMsgStats::default();
        // Drain the channel until it reports no further message.
        while let Ok(msg) = rx.pop() {
            stats += self.write_msg(topic, &msg)?;
        }
        Ok(stats)
    }
}

/// Statistics about a batch of message writes.
// `PartialEq`/`Eq` added so stats can be compared and asserted on directly;
// both fields are plain `usize`, so the derives are free.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct WriteMsgStats {
    /// Number of messages written
    pub count: usize,

    /// Size of messages written in bytes
    pub size: usize,
}

impl Add<WriteMsgStats> for WriteMsgStats {
    type Output = WriteMsgStats;

    fn add(self, other: WriteMsgStats) -> Self::Output {
        WriteMsgStats {
            count: self.count + other.count,
            size: self.size + other.size,
        }
    }
}

impl AddAssign<WriteMsgStats> for WriteMsgStats {
    fn add_assign(&mut self, other: WriteMsgStats) {
        self.count += other.count;
        self.size += other.size;
    }
}

/// Lock-protected inner state: the output directory and the set of active
/// recordings keyed by their id.
struct CaskControl {
    // Directory into which new cask writers place their files.
    workdir: PathBuf,
    // Active writers; a key present here means that recording is in progress.
    recordings: HashMap<String, CaskWriter>,
}

impl CaskControl {
    /// Creates a new writer for `id` in the workdir and registers it.
    ///
    /// Returns the path of the new cask.
    ///
    /// # Errors
    /// `AlreadyInProgress` if a recording with this id is already active;
    /// writer-creation errors are propagated.
    fn start_recording(
        &mut self,
        id: String,
        config: &CaskWriterConfig,
    ) -> Result<PathBuf, CaskControlError> {
        if self.recordings.contains_key(&id) {
            // `id` is owned here and not used afterwards, so it can be moved
            // into the error directly (no clone needed).
            return Err(CaskControlError::AlreadyInProgress(id));
        }

        let rec = CaskWriter::new(&self.workdir, config)?;
        let path = rec.path().into();
        log::info!("Started cask recording: {:?}", path);
        self.recordings.insert(id, rec);
        Ok(path)
    }

    /// Removes the writer for `id`, ending that recording, and returns the
    /// path of its cask.
    ///
    /// # Errors
    /// `NotInProgress` if no recording with this id is active.
    fn stop_recording(&mut self, id: &str) -> Result<PathBuf, CaskControlError> {
        // `remove` borrows the key; the former `id.into()` was a no-op.
        if let Some(rec) = self.recordings.remove(id) {
            let path = rec.path().into();
            log::info!("Stopped cask recording");
            Ok(path)
        } else {
            Err(CaskControlError::NotInProgress(id.into()))
        }
    }

    /// Number of currently active recordings.
    fn session_count(&self) -> usize {
        self.recordings.len()
    }

    /// Iterates over the ids of currently active recordings.
    fn iter_session_keys(&self) -> impl Iterator<Item = &str> {
        self.recordings.keys().map(String::as_str)
    }

    /// Ends every active recording by dropping its writer.
    fn stop_all(&mut self) {
        // `clear` over `drain`: the drained iterator was discarded anyway.
        self.recordings.clear();
    }
}

impl Drop for CaskControl {
    fn drop(&mut self) {
        // Same effect as stop_all(): drop every writer so no recording
        // outlives the control state.
        self.recordings.clear();
    }
}

#[derive(Debug, Error)]
pub enum CaskControlError {
    #[error("recording already in progress")]
    AlreadyInProgress(String),

    #[error("recording not in progress")]
    NotInProgress(String),

    #[error("invalid request")]
    InvalidRequest,

    #[error("serialize error: {0}")]
    BincodeError(bincode::Error),

    #[error("mcap error: {0}")]
    McapError(McapError),

    #[error("io error: {0}")]
    IoError(io::Error),
}

impl From<bincode::Error> for CaskControlError {
    fn from(err: bincode::Error) -> Self {
        CaskControlError::BincodeError(err)
    }
}

impl From<McapError> for CaskControlError {
    fn from(err: McapError) -> Self {
        CaskControlError::McapError(err)
    }
}

impl From<io::Error> for CaskControlError {
    fn from(err: io::Error) -> Self {
        CaskControlError::IoError(err)
    }
}