use crate::{
CaskChannel, CaskControlError, CaskMessageSink, CaskWriterConfigBuilder, SharedCaskControl,
WriteMsgStats,
};
use eyre::{bail, Result};
use nodo::{channels::SyncResult, prelude::*};
use serde::{Deserialize, Serialize};
/// Codelet that writes messages received on its [`CaskRxBundle`] into the
/// sink of a cask recording controller while a recording sink is available.
pub struct CaskRecord {
    /// Controller handle used to start/stop recordings, query the session
    /// count and obtain the message sink.
    ctrl: SharedCaskControl,
}
/// Configuration for the [`CaskRecord`] codelet.
#[derive(Config, Default, Serialize, Deserialize)]
pub struct CaskRecordConfig {
    /// When `true`, `start` opens a recording session named `"auto"`
    /// (writer configured via `with_timestamp_filename("auto")`) that covers
    /// the topics listed in `auto_record_channels`.
    pub enable_auto_record: bool,

    /// Topics included in the automatic recording session; only read when
    /// `enable_auto_record` is set.
    pub auto_record_channels: Vec<String>,
}
/// Status reported by [`CaskRecord`] from `start`, `step` and `stop`.
#[derive(Status, Debug, Copy, Clone, PartialEq, Eq)]
pub enum CaskRecordStatus {
    /// Default variant, flagged `#[skipped]` for the `Status` derive; never
    /// returned by the codelet code in this file. NOTE(review): exact meaning
    /// is defined by nodo's `Status` derive — confirm there.
    #[default]
    #[skipped]
    Skipped,
    /// Returned from `start` and `stop`, and from `step` when no recording
    /// sink is available (pending messages are discarded in that case).
    Idle,
    /// Returned from `step` after messages were written to the recording sink.
    Recording,
}
// Signals published by `CaskRecord`; updated in `start`, `step` and `stop`.
signals! {
    CaskRecordSignals {
        // Comma-joined list of all connected topics; set once in `start`.
        topics: String,
        // Session count reported by the controller; refreshed every `step`
        // and reset to 0 in `stop`.
        session_count: usize,
        // Running total of `stats.count` from each successful write in `step`.
        recorded_message_count: usize,
        // Running total of `stats.size` from each successful write in `step`
        // (units not visible here; presumably bytes — TODO confirm).
        recorded_message_size: usize,
    }
}
impl CaskRecord {
pub fn new(ctrl: SharedCaskControl) -> Self {
Self { ctrl }
}
}
impl Codelet for CaskRecord {
    type Status = CaskRecordStatus;
    type Config = CaskRecordConfig;
    type Rx = CaskRxBundle;
    type Tx = ();
    type Signals = CaskRecordSignals;

    /// Starts with an empty receive bundle; channels are added later through
    /// [`CaskRxBundle::connect`]. There are no outgoing channels.
    fn build_bundles(_cfg: &Self::Config) -> (Self::Rx, Self::Tx) {
        let rx = CaskRxBundle::default();
        (rx, ())
    }

    /// Publishes the connected topics and, if enabled in the config, starts
    /// the `"auto"` recording session.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `start_recording`.
    fn start(
        &mut self,
        cx: Context<Self>,
        rx: &mut Self::Rx,
        _tx: &mut Self::Tx,
    ) -> Result<CaskRecordStatus> {
        let topic_list = rx.topics().join(",");
        cx.signals.topics.set(topic_list);

        if !cx.config.enable_auto_record {
            return Ok(CaskRecordStatus::Idle);
        }

        let channels = cx
            .config
            .auto_record_channels
            .iter()
            .map(|topic| CaskChannel {
                topic: topic.into(),
            });
        self.ctrl.start_recording(
            "auto",
            &CaskWriterConfigBuilder::default()
                .with_timestamp_filename("auto")
                .with_channels(channels)
                .into(),
        )?;

        Ok(CaskRecordStatus::Idle)
    }

    /// Calls `stop_all` on the controller and resets the session-count signal.
    fn stop(
        &mut self,
        cx: Context<Self>,
        _rx: &mut Self::Rx,
        _tx: &mut Self::Tx,
    ) -> Result<CaskRecordStatus> {
        self.ctrl.stop_all();
        cx.signals.session_count.set(0);
        Ok(CaskRecordStatus::Idle)
    }

    /// Writes every pending message to the controller's sink when one exists;
    /// otherwise the pending messages are dropped.
    ///
    /// # Errors
    ///
    /// Propagates the first write error reported by the sink.
    fn step(
        &mut self,
        cx: Context<Self>,
        rx: &mut Self::Rx,
        _tx: &mut Self::Tx,
    ) -> Result<CaskRecordStatus> {
        cx.signals.session_count.set(self.ctrl.session_count());

        match self.ctrl.sink() {
            None => {
                // Nothing is recording: discard inputs so they do not pile up.
                rx.clear_all_messages();
                Ok(CaskRecordStatus::Idle)
            }
            Some(mut sink) => {
                let stats = rx.write_messages(&mut sink)?;
                cx.signals.recorded_message_count.add(stats.count);
                cx.signals.recorded_message_size.add(stats.size);
                Ok(CaskRecordStatus::Recording)
            }
        }
    }
}
/// Receive bundle holding one type-erased channel per recorded topic.
///
/// Channels are added with [`CaskRxBundle::connect`], which rejects duplicate
/// topics.
#[derive(Default)]
pub struct CaskRxBundle {
    /// Channels in the order they were connected; positions in this vector are
    /// the channel indices exposed through `RxBundle`.
    rx: Vec<Box<dyn CaskRx>>,
}
impl CaskRxBundle {
    /// Connects `tx` to a new auto-sized receiving channel recorded under
    /// `topic`.
    ///
    /// # Errors
    ///
    /// Fails if a channel with the same topic was already added, or if the
    /// underlying `connect` call fails. Nothing is added in either case.
    pub fn connect<S, T, V>(&mut self, topic: S, tx: T) -> Result<()>
    where
        for<'a> (T, &'a mut MessageRx<V>): nodo::prelude::Connect,
        V: 'static + Send + Sync + Serialize,
        S: Into<String>,
    {
        let topic: String = topic.into();
        if self.has_topic(&topic) {
            bail!("channel with topic '{topic}' already added");
        }

        // `topic` is already a `String`; move it in directly (no conversion).
        let mut rx = CaskRxImpl {
            topic,
            channel: MessageRx::new_auto_size(),
        };
        connect(tx, &mut rx.channel)?;
        self.rx.push(Box::new(rx));
        Ok(())
    }

    /// Clears the pending messages of every channel.
    pub fn clear_all_messages(&mut self) {
        for rx in &mut self.rx {
            rx.clear();
        }
    }

    /// Topics of all connected channels, in connection order.
    pub fn topics(&self) -> Vec<&str> {
        self.rx.iter().map(|rx| rx.topic()).collect()
    }

    /// Returns `true` if a channel with exactly this topic was already added.
    fn has_topic(&self, topic: &str) -> bool {
        self.rx.iter().any(|rx| rx.topic() == topic)
    }

    /// Writes the pending messages of every channel to `sink` and returns the
    /// accumulated statistics.
    ///
    /// # Errors
    ///
    /// Stops at the first channel whose write fails and returns that error;
    /// channels before it have already been written.
    fn write_messages(
        &mut self,
        sink: &mut CaskMessageSink<'_>,
    ) -> eyre::Result<WriteMsgStats, CaskControlError> {
        let mut stats = WriteMsgStats::default();
        for rx in &mut self.rx {
            stats += rx.write_messages(sink)?;
        }
        Ok(stats)
    }
}
/// Type-erased receiver for a single recorded topic, allowing channels with
/// different message types to be stored together in a [`CaskRxBundle`].
trait CaskRx: Send {
    /// Topic this receiver is recorded under.
    fn topic(&self) -> &str;

    /// Number of messages currently held by the channel (reported as the
    /// inbox message count of the bundle).
    fn len(&self) -> usize;

    /// Synchronizes the underlying channel and returns its result.
    fn sync(&mut self) -> SyncResult;

    /// Whether the underlying channel is connected.
    fn is_connected(&self) -> bool;

    /// Clears the underlying channel.
    fn clear(&mut self);

    /// Writes this receiver's messages to `sink` under its topic.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the sink.
    fn write_messages(
        &mut self,
        sink: &mut CaskMessageSink<'_>,
    ) -> eyre::Result<WriteMsgStats, CaskControlError>;
}
/// Concrete [`CaskRx`] for messages of type `T` on a single topic.
struct CaskRxImpl<T> {
    /// Topic name passed to the sink when writing.
    topic: String,
    /// Receiving end that `CaskRxBundle::connect` wires to the sender.
    channel: DoubleBufferRx<Message<T>>,
}
/// Forwards every [`CaskRx`] operation to the wrapped typed channel.
impl<T> CaskRx for CaskRxImpl<T>
where
    T: Serialize + Send + Sync,
{
    fn topic(&self) -> &str {
        &self.topic
    }

    fn len(&self) -> usize {
        self.channel.len()
    }

    fn sync(&mut self) -> SyncResult {
        self.channel.sync()
    }

    fn is_connected(&self) -> bool {
        self.channel.is_connected()
    }

    fn clear(&mut self) {
        self.channel.clear();
    }

    fn write_messages(
        &mut self,
        sink: &mut CaskMessageSink<'_>,
    ) -> eyre::Result<WriteMsgStats, CaskControlError> {
        // Disjoint field borrows: topic by shared ref, channel mutably.
        sink.write_msg_rx(&self.topic, &mut self.channel)
    }
}
impl nodo::channels::RxBundle for CaskRxBundle {
fn channel_count(&self) -> usize {
self.rx.len()
}
fn name(&self, index: usize) -> &str {
if index < self.rx.len() {
self.rx[index].topic()
} else {
panic!(
"invalid index '{index}': number of channels is {}",
self.rx.len()
)
}
}
fn inbox_message_count(&self, index: usize) -> usize {
self.rx[index].len()
}
fn sync_all(&mut self, results: &mut [SyncResult]) {
for (i, channel) in self.rx.iter_mut().enumerate() {
results[i] = channel.sync()
}
}
fn check_connection(&self) -> nodo::channels::ConnectionCheck {
let mut cc = nodo::channels::ConnectionCheck::new(self.rx.len());
for (i, channel) in self.rx.iter().enumerate() {
cc.mark(i, channel.is_connected());
}
cc
}
}