nodo_cask 0.18.5

Message replay from MCAP for NODO
Documentation
// Copyright 2025 David Weikersdorfer

use crate::next_clip_path;
use eyre::{bail, eyre, Result, WrapErr};
use mcap::read::{RawMessage as McapRawMessage, RawMessageStream as McapRawMessageStream};
use memmap::Mmap;
use nodo::{channels::FlushResult, prelude::*};
use serde::{Deserialize, Serialize};
use std::{
    collections::HashMap,
    fs::File,
    path::{Path, PathBuf},
    slice,
    time::Duration,
};

/// Reads messages from a cask and provides transmitting channels of arbitrary type.
pub struct CaskReplay {
    /// Path of the cask (MCAP file) currently being replayed; advanced to the
    /// next clip when the end of the current record is reached.
    path: PathBuf,
    /// Underlying reader which yields messages in recorded order.
    reader: OrderedReplay,
    /// App-mono time of the first `step`; reference point for message ages.
    start_pubtime: Option<Pubtime>,
    /// Used to request a runtime stop once the end of the record is reached.
    termination_control: std::sync::mpsc::SyncSender<RuntimeControl>,
}

/// Configuration for [`CaskReplay`].
#[derive(Config, Default, Serialize, Deserialize)]
pub struct CaskReplayConfig {
    /// Maximum number of messages replayed per step; unlimited if `None`.
    pub max_count_per_step: Option<usize>,
    /// If true, a runtime stop is requested when the end of the record is reached.
    // NOTE(review): field name has a typo ("termiante"); renaming would break
    // existing callers and serialized configs, so it is left as-is here.
    pub auto_termiante: bool,
}

/// Runtime status reported by [`CaskReplay`].
#[derive(Status, Debug, Copy, Clone, PartialEq, Eq)]
pub enum CaskReplayStatus {
    /// Step was skipped (default status).
    #[default]
    #[skipped]
    Skipped,

    /// End of record reached; nothing left to replay.
    Idle,

    /// Messages are actively being replayed.
    Replaying,
}

impl CaskReplay {
    /// Creates a replay codelet which reads messages from the cask at `path`.
    ///
    /// `termination_control` is used to request a runtime stop once the end
    /// of the record is reached (when enabled in the config).
    ///
    /// # Errors
    /// Fails if the cask file cannot be opened or parsed.
    pub fn new(
        path: &Path,
        termination_control: std::sync::mpsc::SyncSender<RuntimeControl>,
    ) -> eyre::Result<Self> {
        let reader = OrderedReplay::new(path)?;
        Ok(Self {
            path: path.to_path_buf(),
            reader,
            start_pubtime: None,
            termination_control,
        })
    }
}

impl Codelet for CaskReplay {
    type Status = CaskReplayStatus;
    type Config = CaskReplayConfig;
    type Rx = ();
    type Tx = CaskReplayTxBundle;
    type Signals = ();

    fn build_bundles(_cfg: &Self::Config) -> (Self::Rx, Self::Tx) {
        ((), CaskReplayTxBundle::default())
    }

    /// Replays all messages whose recorded age does not exceed the time
    /// elapsed since the first step, publishing them on their topics.
    fn step(
        &mut self,
        cx: Context<Self>,
        _rx: &mut Self::Rx,
        tx: &mut Self::Tx,
    ) -> Result<CaskReplayStatus> {
        // Remember the app-mono time of the first step; message ages are
        // measured relative to this instant so replay keeps the recorded pace.
        if self.start_pubtime.is_none() {
            self.start_pubtime = Some(cx.clocks.app_mono.now())
        }

        // Elapsed wall time since replay started, in seconds.
        let max_age = (*cx.clocks.app_mono.now() - *self.start_pubtime.unwrap()).as_secs_f64();

        let status = self
            .reader
            .read_messages(
                Some(max_age),
                cx.config.max_count_per_step,
                |topic, header, data| tx.publish(topic, header, data),
            )
            // NOTE(review): debug print of the error before propagating it —
            // consider routing this through the project's logging facility.
            .map_err(|err| {
                println!("{err:?}");
                err
            })?;

        // if end of record reached check if there is another clip
        if status == CaskReplayReadMessageStatus::EndOfRecord {
            if let Some(next_path) = next_clip_path(&self.path) {
                self.path = next_path;
                // NOTE(review): an error from opening the next clip is
                // silently swallowed and we fall through to end-of-record
                // handling below — confirm this is intended.
                if let Ok(reader) = OrderedReplay::new(&self.path) {
                    self.reader = reader;
                    return Ok(CaskReplayStatus::Replaying);
                }
            }
        }

        match status {
            // More messages remain: either not due yet, or the per-step
            // message budget was exhausted.
            CaskReplayReadMessageStatus::NotTimeYet
            | CaskReplayReadMessageStatus::MaxCountReached => Ok(CaskReplayStatus::Replaying),
            CaskReplayReadMessageStatus::EndOfRecord => {
                // Optionally request a runtime stop once everything played.
                if cx.config.auto_termiante {
                    self.termination_control.send(RuntimeControl::RequestStop)?;
                }

                Ok(CaskReplayStatus::Idle)
            }
        }
    }
}

/// A bundle of transmitting channels
#[derive(Default)]
pub struct CaskReplayTxBundle {
    /// Type-erased transmit channels, one per connected topic.
    tx: Vec<Box<dyn CaskReplayTx>>,
    /// Maps topic name to the index of its channel in `tx`.
    channels: HashMap<String, usize>,
}

impl CaskReplayTxBundle {
    /// Connect a data channel and starts recording it under the given topic name.
    /// Topics must be unique.
    ///
    /// # Errors
    /// Fails if a channel with the same topic was already added, or if the
    /// underlying `connect` fails.
    pub fn connect<S, R, V>(&mut self, topic: S, rx: R) -> Result<()>
    where
        for<'a> (&'a mut DoubleBufferTx<Message<V>>, R): nodo::prelude::Connect,
        V: 'static + Send + Sync + Clone + for<'a> Deserialize<'a>,
        S: Into<String>,
    {
        let topic: String = topic.into();

        if self.has_topic(&topic) {
            bail!("channel with topic '{topic}' already added");
        }

        let mut tx = CaskReplayTxImpl {
            topic: topic.clone(),
            channel: DoubleBufferTx::new_auto_size(),
        };
        connect(&mut tx.channel, rx)?;

        // Register the channel under the next free index.
        let index = self.tx.len();
        self.tx.push(Box::new(tx));
        self.channels.insert(topic, index);

        Ok(())
    }

    /// Returns true if a channel with given topic name was already added.
    fn has_topic(&self, topic: &str) -> bool {
        self.channels.contains_key(topic)
    }

    /// Writes a raw message to the channel registered for `topic`.
    /// Messages for unknown topics are silently dropped.
    fn publish(
        &mut self,
        topic: &str,
        header: OrderedReplayHeader,
        data: &[u8],
    ) -> eyre::Result<()> {
        // Read-only lookup: the original's unused lifetime parameter and
        // needless `get_mut` were removed.
        if let Some(&tx_idx) = self.channels.get(topic) {
            self.tx[tx_idx].write_message(header, data)?;
        }
        Ok(())
    }
}

/// Type-erased interface over a typed replay transmit channel.
trait CaskReplayTx: Send {
    /// Topic name this channel replays.
    fn topic(&self) -> &str;

    /// Number of messages currently queued in the outbox.
    fn len(&self) -> usize;

    /// Flushes queued messages to connected receivers.
    fn flush(&mut self) -> FlushResult;

    /// True if the channel is connected to at least one receiver.
    fn is_connected(&self) -> bool;

    /// Deserializes `data` and pushes it as a message with the given header.
    fn write_message(&mut self, header: OrderedReplayHeader, data: &[u8]) -> eyre::Result<()>;
}

/// Typed implementation of [`CaskReplayTx`] which deserializes raw bytes into `V`.
struct CaskReplayTxImpl<V> {
    /// Topic under which messages are replayed.
    topic: String,
    /// Transmitting channel carrying the deserialized messages.
    channel: DoubleBufferTx<Message<V>>,
}

impl<T> CaskReplayTx for CaskReplayTxImpl<T>
where
    T: Send + Sync + Clone + for<'a> Deserialize<'a>,
{
    fn topic(&self) -> &str {
        &self.topic
    }

    fn len(&self) -> usize {
        self.channel.len()
    }

    fn flush(&mut self) -> FlushResult {
        self.channel.flush()
    }

    fn is_connected(&self) -> bool {
        self.channel.is_connected()
    }

    fn write_message(&mut self, header: OrderedReplayHeader, data: &[u8]) -> eyre::Result<()> {
        self.channel.push(Message {
            seq: header.seq,
            stamp: header.stamp,
            value: bincode::deserialize(data)?,
        })?;

        Ok(())
    }
}

impl nodo::channels::TxBundle for CaskReplayTxBundle {
    /// Number of transmit channels in this bundle.
    fn channel_count(&self) -> usize {
        self.tx.len()
    }

    /// Topic name of the channel at `index`.
    ///
    /// # Panics
    /// Panics if `index` is out of range.
    fn name(&self, index: usize) -> &str {
        match self.tx.get(index) {
            Some(channel) => channel.topic(),
            None => panic!(
                "invalid index '{index}': number of channels is {}",
                self.tx.len()
            ),
        }
    }

    /// Number of messages queued in the outbox of the channel at `index`.
    fn outbox_message_count(&self, index: usize) -> usize {
        self.tx[index].len()
    }

    /// Flushes every channel, storing each channel's result in `results`.
    fn flush_all(&mut self, results: &mut [FlushResult]) {
        for (idx, channel) in self.tx.iter_mut().enumerate() {
            results[idx] = channel.flush();
        }
    }

    /// Reports for every channel whether it is connected.
    fn check_connection(&self) -> nodo::channels::ConnectionCheck {
        let mut check = nodo::channels::ConnectionCheck::new(self.tx.len());
        for (idx, channel) in self.tx.iter().enumerate() {
            check.mark(idx, channel.is_connected());
        }
        check
    }
}

/// Reads messages from a cask in the order in which they were recorded
struct OrderedReplay {
    // Keeps the file mapping alive; `stream` and `pending` borrow from it via
    // an unsafe 'static slice (see `new` and the manual `Drop` impl).
    #[allow(dead_code)]
    mmap: Mmap,

    // Maps MCAP channel IDs to their topic names (built from the file summary).
    channels: HashMap<u16, String>,

    // The message stream works on a &[u8]. We use the pointer from the mmap which is stored
    // here alongside the stream. Unsafe code and 'static lifetime is a workaround to express this
    // in Rust.
    // TODO investigate further if unsafe code can be avoided
    stream: Option<McapRawMessageStream<'static>>,

    // Pubtime of first message (kept as the minimum pubtime seen so far, as a
    // workaround for records whose pubtimes are not sorted).
    first_pubtime: Option<Duration>,

    // We cannot peek on the mcap message stream thus we always read one more and store it here
    pending: Option<McapRawMessage<'static>>,
}

impl Drop for OrderedReplay {
    fn drop(&mut self) {
        // `pending` and `stream` borrow from `mmap` through the fabricated
        // 'static slice created in `new`. Struct fields are dropped in
        // declaration order, which would drop `mmap` *before* them; dropping
        // them explicitly here ensures they never outlive the mapping.
        // TODO is this the right way to implement a safe drop?
        self.pending = None;
        self.stream = None;
    }
}

impl OrderedReplay {
    /// Opens the cask (MCAP file) at `path` and prepares an ordered message stream.
    ///
    /// # Errors
    /// Fails if the file cannot be opened or memory-mapped, if the MCAP
    /// summary is missing, or if the raw message stream cannot be created.
    pub fn new(path: &Path) -> eyre::Result<Self> {
        let file = File::open(path).context("Couldn't open MCAP file")?;
        // SAFETY(review): mapping a file is only sound if the file is not
        // truncated or modified while mapped — assumed here; TODO confirm.
        let mmap = unsafe { Mmap::map(&file) }.context("Couldn't map MCAP file")?;

        let channels = Self::create_channel_topic_map(&mmap)?;

        // SAFETY: fabricate a 'static slice over the mmap so the stream can
        // be stored alongside it in `Self`; the manual `Drop` impl guarantees
        // the stream (and any pending message) is dropped before the mmap.
        let ptr: &'static [u8] = unsafe { slice::from_raw_parts(mmap.as_ptr(), mmap.len()) };
        let stream = McapRawMessageStream::new(ptr)?;

        Ok(Self {
            mmap,
            channels,
            stream: Some(stream),
            first_pubtime: None,
            pending: None,
        })
    }

    // parse the MCAP summary to map channel IDs to topic names
    fn create_channel_topic_map(data: &[u8]) -> eyre::Result<HashMap<u16, String>> {
        // A missing summary section is treated as an error: without it we
        // cannot resolve channel IDs to topic names.
        let summary = mcap::Summary::read(data)?.ok_or_else(|| eyre!("no summary"))?;

        let mut map = HashMap::new();
        for (id, ch) in summary.channels {
            map.insert(id, ch.topic.clone());
        }

        Ok(map)
    }

    /// Reads messages and passes them to a callback.
    ///
    /// Messages are read as long as their "age" is not greater than the target. The age of a
    /// message is the pubtime difference between the message pubtime and the pubtime of the first
    /// message.
    /// The maximum number of messages processed can be limited with `max_count`.
    ///
    /// Returns which condition stopped reading: not yet due, count budget
    /// exhausted, or end of record.
    pub fn read_messages<F>(
        &mut self,
        max_age: Option<f64>,
        max_count: Option<usize>,
        mut callback_f: F,
    ) -> eyre::Result<CaskReplayReadMessageStatus>
    where
        F: FnMut(&str, OrderedReplayHeader, &[u8]) -> eyre::Result<()>,
    {
        let Some(stream) = self.stream.as_mut() else {
            bail!("stream is None");
        };

        let mut cnt = 0;

        loop {
            // check for maximum count
            if let Some(max_count) = max_count {
                if cnt >= max_count {
                    return Ok(CaskReplayReadMessageStatus::MaxCountReached);
                }
            }

            // If there is no pending message get one
            if self.pending.is_none() {
                match stream.next() {
                    None => return Ok(CaskReplayReadMessageStatus::EndOfRecord),
                    Some(msg) => self.pending = Some(msg?),
                }
            }
            // SAFETY: code above guarantees that we now have a pending message
            let msg = self.pending.as_ref().unwrap();
            let pubtime = Duration::from_nanos(msg.header.publish_time);

            // Compute message age
            if self.first_pubtime.is_none() {
                self.first_pubtime = Some(pubtime);
            } else {
                // TODO This is a workaround in case message pubtimes are not sorted. This can
                //      happen..
                self.first_pubtime = Some(self.first_pubtime.unwrap().min(pubtime));
            }
            // Never underflows: `first_pubtime` is the minimum of all pubtimes
            // seen so far, including the current one.
            let age = (pubtime - self.first_pubtime.unwrap()).as_secs_f64();

            // Check if it is time to publish
            if let Some(max_age) = max_age {
                if age > max_age {
                    // Message stays in `pending` and is retried on next call.
                    return Ok(CaskReplayReadMessageStatus::NotTimeYet);
                }
            }

            // Publish
            callback_f(
                &self.channels[&msg.header.channel_id],
                OrderedReplayHeader {
                    seq: msg.header.sequence as u64,
                    stamp: Stamp {
                        acqtime: Duration::from_nanos(msg.header.log_time).into(),
                        pubtime: Duration::from_nanos(msg.header.publish_time).into(),
                    },
                },
                &msg.data,
            )?;
            self.pending = None;
            cnt += 1;
        }
    }
}

/// Header metadata attached to each replayed message.
struct OrderedReplayHeader {
    /// Sequence number taken from the MCAP message header.
    seq: u64,
    /// Acquisition (log) and publish timestamps reconstructed from the record.
    stamp: Stamp,
}

/// Condition that stopped a call to `OrderedReplay::read_messages`.
// Fieldless status enum: derive the full eager set (`Eq`, `Copy`, `Clone`
// were missing) so values can be compared and copied freely.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum CaskReplayReadMessageStatus {
    /// The next message's age exceeds `max_age`; it stays pending for later.
    NotTimeYet,
    /// The per-call message budget (`max_count`) was exhausted.
    MaxCountReached,
    /// All messages in the record have been replayed.
    EndOfRecord,
}