use crate::next_clip_path;
use eyre::{bail, eyre, Result, WrapErr};
use mcap::read::{RawMessage as McapRawMessage, RawMessageStream as McapRawMessageStream};
use memmap::Mmap;
use nodo::{channels::FlushResult, prelude::*};
use serde::{Deserialize, Serialize};
use std::{
collections::HashMap,
fs::File,
path::{Path, PathBuf},
slice,
time::Duration,
};
/// Codelet that replays the messages of recorded MCAP clips, paced by the app-monotonic clock.
///
/// Each recorded message is published on the output channel registered for its topic in
/// [`CaskReplayTxBundle`]. When a clip ends, replay continues with the next clip if one exists.
pub struct CaskReplay {
    /// Reader over the clip currently being replayed.
    reader: OrderedReplay,
    /// Path of the clip currently being replayed; the next clip is looked up from it.
    path: PathBuf,
    /// Reference app-monotonic time used to pace the replay; `None` until a step sets it.
    start_pubtime: Option<Pubtime>,
    /// Sender used to ask the runtime to stop once the recording is exhausted.
    termination_control: std::sync::mpsc::SyncSender<RuntimeControl>,
}
/// Configuration of [`CaskReplay`].
#[derive(Config, Default, Serialize, Deserialize)]
pub struct CaskReplayConfig {
    // Maximum number of messages published in a single step; `None` means no limit.
    pub max_count_per_step: Option<usize>,
    // When `true`, the codelet requests the runtime to stop once the recording is exhausted and
    // no further clip could be opened.
    // NOTE: the field name is misspelled ("termiante"). It is kept as-is because it is part of the
    // public and serialized configuration interface; renaming it would break existing users.
    pub auto_termiante: bool,
}
/// Status reported by [`CaskReplay`] after each step.
#[derive(Status, Debug, Copy, Clone, PartialEq, Eq)]
pub enum CaskReplayStatus {
    // Default variant, flagged `#[skipped]` for the `Status` derive; the `step` in this file never
    // returns it. NOTE(review): the exact meaning of `#[skipped]` is defined by nodo — confirm there.
    #[default]
    #[skipped]
    Skipped,
    // Returned once the recording is exhausted and no further clip was opened.
    Idle,
    // Returned while messages remain to be replayed, including right after switching clips.
    Replaying,
}
impl CaskReplay {
    /// Creates a replay codelet that starts with the clip at `path`.
    ///
    /// `termination_control` is used to request a runtime stop when `auto_termiante` is enabled
    /// and the recording has been fully replayed.
    ///
    /// # Errors
    /// Fails if the clip cannot be opened, memory-mapped, or parsed as MCAP.
    pub fn new(
        path: &Path,
        termination_control: std::sync::mpsc::SyncSender<RuntimeControl>,
    ) -> eyre::Result<Self> {
        let reader = OrderedReplay::new(path)?;
        Ok(Self {
            path: path.to_path_buf(),
            reader,
            start_pubtime: None,
            termination_control,
        })
    }
}
impl Codelet for CaskReplay {
    type Status = CaskReplayStatus;
    type Config = CaskReplayConfig;
    type Rx = ();
    type Tx = CaskReplayTxBundle;
    type Signals = ();

    /// Starts without output channels; they are added with [`CaskReplayTxBundle::connect`].
    fn build_bundles(_cfg: &Self::Config) -> (Self::Rx, Self::Tx) {
        ((), CaskReplayTxBundle::default())
    }

    /// Publishes every recorded message that is due and reports whether replay is ongoing.
    ///
    /// A message is due once its age (its publish time relative to the earliest publish time
    /// seen in the current clip) is no larger than the app-monotonic time elapsed since this
    /// codelet started replaying that clip. At most `max_count_per_step` messages are published
    /// per step.
    ///
    /// When the current clip is exhausted, the next clip (from `next_clip_path`) is opened and
    /// the replay clock restarts for it. If there is no next clip, or it cannot be opened, the
    /// codelet reports `Idle`. If `auto_termiante` is set, it also asks the runtime to stop, on
    /// every idle step.
    ///
    /// # Errors
    /// Propagates errors from reading the clip or publishing a message (these are also printed)
    /// and from sending the stop request.
    fn step(
        &mut self,
        cx: Context<Self>,
        _rx: &mut Self::Rx,
        tx: &mut Self::Tx,
    ) -> Result<CaskReplayStatus> {
        // The replay clock starts on the first step of each clip.
        let start_pubtime = *self
            .start_pubtime
            .get_or_insert_with(|| cx.clocks.app_mono.now());
        let max_age = (*cx.clocks.app_mono.now() - *start_pubtime).as_secs_f64();

        let status = self
            .reader
            .read_messages(
                Some(max_age),
                cx.config.max_count_per_step,
                |topic, header, data| tx.publish(topic, header, data),
            )
            .map_err(|err| {
                println!("{err:?}");
                err
            })?;

        match status {
            CaskReplayReadMessageStatus::NotTimeYet
            | CaskReplayReadMessageStatus::MaxCountReached => Ok(CaskReplayStatus::Replaying),
            CaskReplayReadMessageStatus::EndOfRecord => {
                if let Some(next_path) = next_clip_path(&self.path) {
                    self.path = next_path;
                    match OrderedReplay::new(&self.path) {
                        Ok(reader) => {
                            self.reader = reader;
                            // The new reader measures message ages from its own first message,
                            // so the replay clock must restart as well. Otherwise the time already
                            // spent on the previous clip would make the head of this clip due
                            // all at once.
                            self.start_pubtime = None;
                            return Ok(CaskReplayStatus::Replaying);
                        }
                        Err(err) => {
                            eprintln!(
                                "failed to open next clip '{}': {err:?}",
                                self.path.display()
                            );
                        }
                    }
                }
                if cx.config.auto_termiante {
                    self.termination_control.send(RuntimeControl::RequestStop)?;
                }
                Ok(CaskReplayStatus::Idle)
            }
        }
    }
}
/// Output bundle of [`CaskReplay`]: one outbox per replayed topic.
///
/// Outboxes are registered with [`CaskReplayTxBundle::connect`]. Recorded messages are routed to
/// them by topic name.
#[derive(Default)]
pub struct CaskReplayTxBundle {
    /// Maps a topic name to the index of its outbox in `tx`.
    channels: HashMap<String, usize>,
    /// Outboxes in registration order; the position is the channel index reported to nodo.
    tx: Vec<Box<dyn CaskReplayTx>>,
}
impl CaskReplayTxBundle {
    /// Registers an outbox for `topic` and connects it to `rx`.
    ///
    /// Recorded messages on `topic` are deserialized into `V` and published on this outbox.
    ///
    /// # Errors
    /// Fails if `topic` was already registered or if connecting to `rx` fails.
    pub fn connect<S, R, V>(&mut self, topic: S, rx: R) -> Result<()>
    where
        for<'a> (&'a mut DoubleBufferTx<Message<V>>, R): nodo::prelude::Connect,
        V: 'static + Send + Sync + Clone + for<'a> Deserialize<'a>,
        S: Into<String>,
    {
        let topic: String = topic.into();
        if self.has_topic(&topic) {
            bail!("channel with topic '{topic}' already added");
        }
        let mut tx = CaskReplayTxImpl {
            topic: topic.clone(),
            channel: DoubleBufferTx::new_auto_size(),
        };
        connect(&mut tx.channel, rx)?;
        // The new outbox is appended, so its index is the current length.
        let index = self.tx.len();
        self.tx.push(Box::new(tx));
        self.channels.insert(topic, index);
        Ok(())
    }

    /// Returns whether an outbox has already been registered for `topic`.
    fn has_topic(&self, topic: &str) -> bool {
        self.channels.contains_key(topic)
    }

    /// Routes one recorded message to the outbox registered for `topic`.
    ///
    /// Messages whose topic was never registered through `connect` are dropped without error.
    ///
    /// # Errors
    /// Propagates failures from the outbox (decoding the payload or queueing the message).
    fn publish(
        &mut self,
        topic: &str,
        header: OrderedReplayHeader,
        data: &[u8],
    ) -> eyre::Result<()> {
        match self.channels.get(topic) {
            Some(&tx_idx) => self.tx[tx_idx].write_message(header, data),
            None => Ok(()),
        }
    }
}
/// Type-erased outbox for a single replayed topic.
///
/// Lets [`CaskReplayTxBundle`] keep outboxes of different message types in one collection.
trait CaskReplayTx: Send {
    /// Converts the recorded payload `data` into a message stamped from `header` and queues it.
    fn write_message(&mut self, header: OrderedReplayHeader, data: &[u8]) -> eyre::Result<()>;

    /// Topic this outbox publishes.
    fn topic(&self) -> &str;

    /// Number of messages currently queued in the outbox.
    fn len(&self) -> usize;

    /// Flushes queued messages to the connected receiver(s).
    fn flush(&mut self) -> FlushResult;

    /// Whether the underlying channel is connected.
    fn is_connected(&self) -> bool;
}
/// Concrete [`CaskReplayTx`] that publishes messages carrying values of type `V`.
struct CaskReplayTxImpl<V> {
    /// Outbox connected to the receiver passed to [`CaskReplayTxBundle::connect`].
    channel: DoubleBufferTx<Message<V>>,
    /// Topic whose recorded messages are routed to `channel`.
    topic: String,
}
impl<T> CaskReplayTx for CaskReplayTxImpl<T>
where
T: Send + Sync + Clone + for<'a> Deserialize<'a>,
{
fn topic(&self) -> &str {
&self.topic
}
fn len(&self) -> usize {
self.channel.len()
}
fn flush(&mut self) -> FlushResult {
self.channel.flush()
}
fn is_connected(&self) -> bool {
self.channel.is_connected()
}
fn write_message(&mut self, header: OrderedReplayHeader, data: &[u8]) -> eyre::Result<()> {
self.channel.push(Message {
seq: header.seq,
stamp: header.stamp,
value: bincode::deserialize(data)?,
})?;
Ok(())
}
}
impl nodo::channels::TxBundle for CaskReplayTxBundle {
    /// Number of registered outboxes.
    fn channel_count(&self) -> usize {
        self.tx.len()
    }

    /// Topic of the outbox at `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    fn name(&self, index: usize) -> &str {
        match self.tx.get(index) {
            Some(channel) => channel.topic(),
            None => panic!(
                "invalid index '{index}': number of channels is {}",
                self.tx.len()
            ),
        }
    }

    /// Number of messages queued in the outbox at `index` (panics if out of range).
    fn outbox_message_count(&self, index: usize) -> usize {
        self.tx[index].len()
    }

    /// Flushes every outbox, storing each result at the outbox's index in `results`.
    fn flush_all(&mut self, results: &mut [FlushResult]) {
        self.tx
            .iter_mut()
            .enumerate()
            .for_each(|(i, channel)| results[i] = channel.flush());
    }

    /// Reports, per outbox index, whether the outbox is connected.
    fn check_connection(&self) -> nodo::channels::ConnectionCheck {
        let count = self.tx.len();
        let mut check = nodo::channels::ConnectionCheck::new(count);
        self.tx.iter().enumerate().for_each(|(i, channel)| {
            check.mark(i, channel.is_connected());
        });
        check
    }
}
/// Streams the messages of one memory-mapped MCAP file in the order the MCAP stream yields them,
/// holding back a message that is not yet due until a later call.
struct OrderedReplay {
    // Keeps the file mapping alive. It is never read directly, but `stream` and `pending` borrow
    // from it through the `'static` slice created in `new`.
    #[allow(dead_code)]
    mmap: Mmap,
    // Channel id -> topic name, taken from the MCAP summary section.
    channels: HashMap<u16, String>,
    // Raw message stream over the mapped bytes; set in `new` and cleared only in `Drop`.
    stream: Option<McapRawMessageStream<'static>>,
    // Earliest publish time seen so far; message ages are measured relative to it.
    first_pubtime: Option<Duration>,
    // Message already read from `stream` but not yet handed to the callback.
    pending: Option<McapRawMessage<'static>>,
}
impl Drop for OrderedReplay {
    /// Releases everything that borrows from `mmap` before the mapping itself is dropped.
    ///
    /// `stream` and `pending` hold `'static` borrows into the mapped file. Fields are dropped in
    /// declaration order, and `mmap` is declared first, so without this it would be unmapped
    /// while they still exist. Field drops run only after this method returns, so clearing them
    /// here keeps those borrows from outliving the mapping.
    fn drop(&mut self) {
        self.pending = None;
        self.stream = None;
    }
}
impl OrderedReplay {
pub fn new(path: &Path) -> eyre::Result<Self> {
let file = File::open(path).context("Couldn't open MCAP file")?;
let mmap = unsafe { Mmap::map(&file) }.context("Couldn't map MCAP file")?;
let channels = Self::create_channel_topic_map(&mmap)?;
let ptr: &'static [u8] = unsafe { slice::from_raw_parts(mmap.as_ptr(), mmap.len()) };
let stream = McapRawMessageStream::new(ptr)?;
Ok(Self {
mmap,
channels,
stream: Some(stream),
first_pubtime: None,
pending: None,
})
}
fn create_channel_topic_map(data: &[u8]) -> eyre::Result<HashMap<u16, String>> {
let summary = mcap::Summary::read(data)?.ok_or_else(|| eyre!("no summary"))?;
let mut map = HashMap::new();
for (id, ch) in summary.channels {
map.insert(id, ch.topic.clone());
}
Ok(map)
}
pub fn read_messages<F>(
&mut self,
max_age: Option<f64>,
max_count: Option<usize>,
mut callback_f: F,
) -> eyre::Result<CaskReplayReadMessageStatus>
where
F: FnMut(&str, OrderedReplayHeader, &[u8]) -> eyre::Result<()>,
{
let Some(stream) = self.stream.as_mut() else {
bail!("stream is None");
};
let mut cnt = 0;
loop {
if let Some(max_count) = max_count {
if cnt >= max_count {
return Ok(CaskReplayReadMessageStatus::MaxCountReached);
}
}
if self.pending.is_none() {
match stream.next() {
None => return Ok(CaskReplayReadMessageStatus::EndOfRecord),
Some(msg) => self.pending = Some(msg?),
}
}
let msg = self.pending.as_ref().unwrap();
let pubtime = Duration::from_nanos(msg.header.publish_time);
if self.first_pubtime.is_none() {
self.first_pubtime = Some(pubtime);
} else {
self.first_pubtime = Some(self.first_pubtime.unwrap().min(pubtime));
}
let age = (pubtime - self.first_pubtime.unwrap()).as_secs_f64();
if let Some(max_age) = max_age {
if age > max_age {
return Ok(CaskReplayReadMessageStatus::NotTimeYet);
}
}
callback_f(
&self.channels[&msg.header.channel_id],
OrderedReplayHeader {
seq: msg.header.sequence as u64,
stamp: Stamp {
acqtime: Duration::from_nanos(msg.header.log_time).into(),
pubtime: Duration::from_nanos(msg.header.publish_time).into(),
},
},
&msg.data,
)?;
self.pending = None;
cnt += 1;
}
}
}
/// Sequence number and timestamps of a replayed message, as recorded in the MCAP file.
struct OrderedReplayHeader {
    /// Recorded acquisition and publish times.
    stamp: Stamp,
    /// Recorded sequence number.
    seq: u64,
}
/// Why a single `OrderedReplay::read_messages` call stopped.
#[derive(Debug, PartialEq)]
enum CaskReplayReadMessageStatus {
    /// The message stream is exhausted.
    EndOfRecord,
    /// The per-call message limit was reached; more messages may already be due.
    MaxCountReached,
    /// The next message is not due yet.
    NotTimeYet,
}