use crate::{CaskWriter, CaskWriterConfig};
use mcap::McapError;
use nodo::{
channels::Pop,
prelude::{DoubleBufferRx, Message as NodoMessage},
};
use serde::Serialize;
use std::{
collections::HashMap,
io,
ops::{Add, AddAssign},
path::{Path, PathBuf},
sync::{Arc, Mutex, MutexGuard},
};
use thiserror::Error;
/// Clonable handle to recording state shared behind an `Arc<Mutex<..>>`.
///
/// Cloning copies the `Arc`, so every clone refers to the same `CaskControl`
/// and therefore sees the same set of active recordings.
#[derive(Clone)]
pub struct SharedCaskControl(Arc<Mutex<CaskControl>>);
impl SharedCaskControl {
pub fn new(workdir: &Path) -> Self {
Self(Arc::new(Mutex::new(CaskControl {
workdir: workdir.into(),
recordings: HashMap::new(),
})))
}
pub fn start_recording<S: Into<String>>(
&self,
id: S,
config: &CaskWriterConfig,
) -> Result<PathBuf, CaskControlError> {
self.0.lock().unwrap().start_recording(id.into(), config)
}
pub fn stop_recording(&self, id: &str) -> Result<PathBuf, CaskControlError> {
self.0.lock().unwrap().stop_recording(id)
}
pub fn stop_all(&self) {
self.0.lock().unwrap().stop_all()
}
pub fn session_count(&self) -> usize {
self.0.lock().unwrap().session_count()
}
pub fn session_keys(&self) -> Vec<String> {
self.0
.lock()
.unwrap()
.iter_session_keys()
.map(|s| s.into())
.collect()
}
pub(crate) fn sink<'a>(&'a mut self) -> Option<CaskMessageSink<'a>> {
let inner = self.0.lock().unwrap();
if inner.recordings.is_empty() {
None
} else {
Some(CaskMessageSink(inner))
}
}
}
/// Write access to all active recordings while the shared state stays locked.
///
/// Wraps the mutex guard directly, so the lock is released when the sink is dropped.
pub struct CaskMessageSink<'guard>(MutexGuard<'guard, CaskControl>);
impl CaskMessageSink<'_> {
    /// Writes `msg` on `topic` to every active recording.
    ///
    /// The returned stats have `count` = number of recordings written to and
    /// `size` = sum of the values returned by each writer's `write_message`.
    ///
    /// # Errors
    /// Stops at the first recording whose write fails and returns that error
    /// converted into [`CaskControlError`]. Recordings visited earlier have
    /// already received the message.
    pub fn write_msg<T>(
        &mut self,
        topic: &str,
        msg: &NodoMessage<T>,
    ) -> Result<WriteMsgStats, CaskControlError>
    where
        T: Serialize,
    {
        self.0.recordings.values_mut().try_fold(
            WriteMsgStats::default(),
            |mut totals, writer| -> Result<WriteMsgStats, CaskControlError> {
                totals.count += 1;
                totals.size += writer.write_message(topic, msg)?;
                Ok(totals)
            },
        )
    }

    /// Pops messages from `rx` until `pop` fails, writing each to every recording.
    ///
    /// Returns the summed stats of all writes.
    ///
    /// # Errors
    /// Returns the first write error. The message being written at that point has
    /// already been popped; any later messages remain in `rx`.
    pub fn write_msg_rx<T>(
        &mut self,
        topic: &str,
        rx: &mut DoubleBufferRx<NodoMessage<T>>,
    ) -> Result<WriteMsgStats, CaskControlError>
    where
        T: Serialize,
    {
        std::iter::from_fn(|| rx.pop().ok()).try_fold(
            WriteMsgStats::default(),
            |totals, msg| -> Result<WriteMsgStats, CaskControlError> {
                Ok(totals + self.write_msg(topic, &msg)?)
            },
        )
    }
}
/// Accumulated statistics for message writes to the active recordings.
#[derive(Clone, Copy, Debug, Default)]
pub struct WriteMsgStats {
    /// Number of individual writes (one per recording a message was written to).
    pub count: usize,
    /// Sum of the sizes reported by the writers (presumably bytes — confirm).
    pub size: usize,
}
impl Add<WriteMsgStats> for WriteMsgStats {
    type Output = WriteMsgStats;

    /// Field-wise sum of two stats values.
    fn add(self, other: WriteMsgStats) -> Self::Output {
        let WriteMsgStats {
            count: lhs_count,
            size: lhs_size,
        } = self;
        let WriteMsgStats {
            count: rhs_count,
            size: rhs_size,
        } = other;
        WriteMsgStats {
            count: lhs_count + rhs_count,
            size: lhs_size + rhs_size,
        }
    }
}
impl AddAssign<WriteMsgStats> for WriteMsgStats {
fn add_assign(&mut self, other: WriteMsgStats) {
self.count += other.count;
self.size += other.size;
}
}
/// Lock-protected state behind `SharedCaskControl`.
struct CaskControl {
    /// Directory handed to `CaskWriter::new` whenever a recording starts.
    workdir: PathBuf,
    /// Active writers keyed by the id passed to `start_recording`.
    recordings: HashMap<String, CaskWriter>,
}
impl CaskControl {
    /// Starts a new recording registered under `id` and returns the writer's path.
    ///
    /// A `CaskWriter` is created from `self.workdir` and `config` only when `id`
    /// is not already recording.
    ///
    /// # Errors
    /// - [`CaskControlError::AlreadyInProgress`] if `id` is already recording.
    /// - Any error from `CaskWriter::new`, converted via `?`.
    fn start_recording(
        &mut self,
        id: String,
        config: &CaskWriterConfig,
    ) -> Result<PathBuf, CaskControlError> {
        if self.recordings.contains_key(&id) {
            // `id` is owned and unused afterwards, so move it into the error.
            return Err(CaskControlError::AlreadyInProgress(id));
        }
        let rec = CaskWriter::new(&self.workdir, config)?;
        let path: PathBuf = rec.path().into();
        log::info!("Started cask recording: {:?}", path);
        self.recordings.insert(id, rec);
        Ok(path)
    }

    /// Stops the recording registered under `id` and returns its path.
    ///
    /// The writer is removed from the map and dropped before this returns.
    /// Whatever finalisation `CaskWriter` does on drop is not visible from here.
    ///
    /// # Errors
    /// [`CaskControlError::NotInProgress`] if no recording with `id` exists.
    fn stop_recording(&mut self, id: &str) -> Result<PathBuf, CaskControlError> {
        match self.recordings.remove(id) {
            Some(rec) => {
                let path: PathBuf = rec.path().into();
                log::info!("Stopped cask recording: {:?}", path);
                Ok(path)
            }
            None => Err(CaskControlError::NotInProgress(id.to_owned())),
        }
    }

    /// Number of active recordings.
    fn session_count(&self) -> usize {
        self.recordings.len()
    }

    /// Iterates over the ids of active recordings (order unspecified).
    fn iter_session_keys(&self) -> impl Iterator<Item = &str> {
        self.recordings.keys().map(|s| s.as_str())
    }

    /// Stops (drops) every active recording, logging each path like `stop_recording`.
    fn stop_all(&mut self) {
        for (_, rec) in self.recordings.drain() {
            let path: PathBuf = rec.path().into();
            log::info!("Stopped cask recording: {:?}", path);
            // `rec` is dropped at the end of this iteration.
        }
    }
}
impl Drop for CaskControl {
fn drop(&mut self) {
self.stop_all();
}
}
#[derive(Debug, Error)]
pub enum CaskControlError {
#[error("recording already in progress")]
AlreadyInProgress(String),
#[error("recording not in progress")]
NotInProgress(String),
#[error("invalid request")]
InvalidRequest,
#[error("serialize error: {0}")]
BincodeError(bincode::Error),
#[error("mcap error: {0}")]
McapError(McapError),
#[error("io error: {0}")]
IoError(io::Error),
}
impl From<bincode::Error> for CaskControlError {
    /// Wraps a `bincode` error so `?` can propagate it.
    fn from(err: bincode::Error) -> Self {
        Self::BincodeError(err)
    }
}
impl From<McapError> for CaskControlError {
fn from(err: McapError) -> Self {
CaskControlError::McapError(err)
}
}
impl From<io::Error> for CaskControlError {
fn from(err: io::Error) -> Self {
CaskControlError::IoError(err)
}
}