rustbag 0.1.1

A high-performance ROS 2 bag player
// Copyright 2025 Ivo Ivanov.
// Copyright 2018 Open Source Robotics Foundation, Inc.
// Copyright 2018, Bosch Software Innovations GmbH.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use mcap::{
    read::{RawMessage, RawMessageStream},
    McapResult,
};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::time::Duration;
/// Index of a bag within the set of bags being played back.
pub type BagIndex = usize;

/// A raw message paired with the index of the bag it came from.
/// Ordered by publish time, reversed (see the `Ord` impl) so that
/// `BinaryHeap` behaves as a min-heap on publish time.
pub struct SortableMessage<'a>(pub RawMessage<'a>, pub BagIndex);

impl PartialEq for SortableMessage<'_> {
    fn eq(&self, other: &Self) -> bool {
        self.0.header.publish_time == other.0.header.publish_time
    }
}
impl Eq for SortableMessage<'_> {}
impl Ord for SortableMessage<'_> {
    fn cmp(&self, other: &Self) -> Ordering {
        // `BinaryHeap` is a max-heap; flipping the natural timestamp order
        // makes it pop the earliest publish time first.
        self.0
            .header
            .publish_time
            .cmp(&other.0.header.publish_time)
            .reverse()
    }
}
impl PartialOrd for SortableMessage<'_> {
    /// Defers to the total order defined by `Ord`, as the two traits must
    /// agree; `From<T> for Option<T>` wraps the result in `Some`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Ord::cmp(self, other).into()
    }
}

/// Playback configuration shared by all bags in a `MessageQueue`.
pub struct MessageQueueOptions {
    // When to start reading: all messages published before
    // `bag_start_time + start_offset` will be ignored.
    pub start_offset: Duration,
    // The earliest start time over all bags; `start_offset` is measured
    // relative to this instant.
    pub bag_start_time: Duration,
}

/// The message queue reads the messages from potentially multiple rosbags and zips them to provide a single sequence of messages, sorted by publish time order.
/// This struct is generic on the stream, because we need to mock the stream for the unit-tests below.
pub struct MessageQueue<'a, S>
where
    S: Iterator<Item = McapResult<RawMessage<'a>>>,
{
    /// Playback options (start offset and common bag start time).
    pub options: MessageQueueOptions,
    /// Min-heap of buffered messages, ordered by publish time via
    /// `SortableMessage`'s reversed `Ord`.
    pub queue: BinaryHeap<SortableMessage<'a>>,
    /// One raw-message stream per bag, indexed by `BagIndex`.
    pub message_streams: Vec<S>,
    /// Publish time of the newest message read from each stream so far;
    /// `None` until the first message of that bag has been read.
    pub last_msg_times: Vec<Option<Duration>>,
}

impl<'a> MessageQueue<'a, RawMessageStream<'a>> {
    /// Builds a queue over raw MCAP message streams, one per memory-mapped
    /// bag file.
    ///
    /// Returns an error as soon as any file fails to open as an MCAP stream.
    pub fn new(files: &'a [memmap2::Mmap], options: MessageQueueOptions) -> anyhow::Result<Self> {
        let mut message_streams = Vec::with_capacity(files.len());
        for file_buf in files {
            message_streams.push(RawMessageStream::new(file_buf)?);
        }
        Ok(Self {
            options,
            queue: BinaryHeap::new(),
            message_streams,
            last_msg_times: vec![None; files.len()],
        })
    }
}

impl<'a, S> MessageQueue<'a, S>
where
    S: Iterator<Item = McapResult<RawMessage<'a>>>,
{
    /// Test-only constructor that takes pre-built streams directly instead of
    /// memory-mapping MCAP files.
    #[cfg(test)]
    pub fn new_from_streams(streams: Vec<S>, options: MessageQueueOptions) -> Self {
        let len = streams.len();
        Self {
            options,
            queue: BinaryHeap::new(),
            message_streams: streams,
            last_msg_times: vec![None; len],
        }
    }

    /// Refills the message queue with constant time lookahead from all the rosbags.
    ///
    /// For each bag, messages are pulled from its stream until the newest
    /// message read from that bag is at least `lookahead` ahead of the
    /// earliest message currently buffered in the heap, or the stream runs
    /// dry. `channels_to_read` is indexed by bag: it must contain one channel
    /// set per stream (indexing panics otherwise), and bags whose set is
    /// empty are skipped entirely.
    pub fn refill(&mut self, lookahead: &Duration, channels_to_read: &[HashSet<u16>]) {
        for (i_bag, stream) in self.message_streams.iter_mut().enumerate() {
            // if we are ignoring i_bag, we can't possibly refill it enough
            if channels_to_read[i_bag].is_empty() {
                continue;
            }
            loop {
                // check if we refilled enough from this bag: stop once this
                // bag has been read `lookahead` past the head of the heap,
                // since anything earlier than that is already buffered.
                if let Some(first_elem) = self.queue.peek() {
                    let first_stamp_in_heap =
                        Duration::from_nanos(first_elem.0.header.publish_time);
                    if let Some(last_msg_time) = self.last_msg_times[i_bag] {
                        if last_msg_time > first_stamp_in_heap
                            && (last_msg_time - first_stamp_in_heap) >= *lookahead
                        {
                            break;
                        }
                    }
                }
                let (maybe_next_msg, maybe_new_latest_msg_stamp) = Self::read_next_message(
                    &self.options,
                    stream,
                    &channels_to_read[i_bag],
                    lookahead,
                    &self.last_msg_times[i_bag],
                );
                // Only advance the latest-seen stamp when the helper says the
                // message actually moved it forward.
                if let Some(new_latest_msg_stamp) = maybe_new_latest_msg_stamp {
                    self.last_msg_times[i_bag] = Some(new_latest_msg_stamp);
                }
                if let Some(next_msg) = maybe_next_msg {
                    self.queue.push(SortableMessage(next_msg, i_bag));
                } else {
                    // Stream exhausted: nothing more to buffer from this bag.
                    break;
                }
            }
        }
    }

    /// Refills the queue and gets the next message as well as its bag index.
    /// Returns `None` once all streams are exhausted.
    pub fn get_next(
        &mut self,
        lookahead: &Duration,
        channels_to_read: &[HashSet<u16>],
    ) -> Option<(usize, RawMessage<'a>)> {
        self.refill(lookahead, channels_to_read);
        // The heap's reversed ordering makes `pop` yield the earliest message.
        self.queue.pop().map(|wm| (wm.1, wm.0))
    }

    /// Returns the next message from the given stream that matches the channels_to_read, returns `None` if there are no more messages in the stream.
    /// Maybe updates the latest message timestamp: the second tuple element is
    /// `Some(new_latest)` only when the message advances this bag's
    /// latest-seen publish time; out-of-order messages are still returned but
    /// leave the latest stamp untouched.
    pub fn read_next_message(
        options: &MessageQueueOptions,
        msg_stream: &mut S,
        channels_to_read: &HashSet<u16>,
        lookahead: &Duration,
        latest_msg_stamp: &Option<Duration>,
    ) -> (Option<RawMessage<'a>>, Option<Duration>) {
        // Search next message loop
        for maybe_next_message in msg_stream.by_ref() {
            match maybe_next_message {
                Err(e) => {
                    // Unreadable records are logged and skipped rather than
                    // aborting playback of the whole bag.
                    log::error!("Error reading message: {}", e);
                }
                Ok(msg) => {
                    if !channels_to_read.contains(&msg.header.channel_id) {
                        continue;
                    }
                    let pub_time = Duration::from_nanos(msg.header.publish_time);
                    // Skip everything published before bag_start_time + start_offset;
                    // `checked_sub` also catches messages stamped before bag_start_time.
                    if pub_time
                        .checked_sub(options.bag_start_time)
                        .is_none_or(|d| d < options.start_offset)
                    {
                        continue;
                    }
                    if let Some(latest_stamp) = latest_msg_stamp {
                        if pub_time > *latest_stamp {
                            return (Some(msg), Some(pub_time));
                        }
                        // Message is at or behind the newest one already seen
                        // from this bag; warn if it is further back than the
                        // lookahead window can compensate for.
                        let diff_to_last = *latest_stamp - pub_time;
                        if diff_to_last > *lookahead {
                            log::warn!("Found a message that is reordered by more than the lookahead time ({:?}), messages will be published out of order. Please increase the lookahead time.", diff_to_last
                            );
                        }
                    } else {
                        // First accepted message from this bag.
                        return (Some(msg), Some(pub_time));
                    }
                    return (Some(msg), None);
                }
            }
        }
        // Stream exhausted.
        (None, None)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use mcap::records::MessageHeader;
    use std::borrow::Cow;

    /// Builds an empty-payload message with the given channel id and publish
    /// time (nanoseconds).
    fn make_msg(channel_id: u16, publish_time_ns: u64) -> RawMessage<'static> {
        RawMessage {
            header: MessageHeader {
                channel_id,
                sequence: 0,
                log_time: 0,
                publish_time: publish_time_ns,
            },
            data: Cow::Borrowed(&[]),
        }
    }

    /// A message reordered by less than the lookahead window must come out of
    /// the queue in publish-time order.
    #[test]
    fn reordering_within_lookahead_is_sorted() {
        let s = vec![
            Ok(make_msg(1, 1_000_000_000)),
            Ok(make_msg(1, 1_300_000_000)),
            Ok(make_msg(1, 1_200_000_000)),
        ]
        .into_iter();
        let options = MessageQueueOptions {
            start_offset: Duration::from_secs_f64(0.0),
            bag_start_time: Duration::from_nanos(0),
        };
        let mut mq = MessageQueue::new_from_streams(vec![s], options);
        let channels: Vec<HashSet<u16>> = vec![vec![1u16].into_iter().collect()];
        let lookahead = Duration::from_millis(500);
        let mut times = Vec::new();
        while let Some((_bag, m)) = mq.get_next(&lookahead, &channels) {
            times.push(m.header.publish_time);
        }
        assert_eq!(times, vec![1_000_000_000, 1_200_000_000, 1_300_000_000]);
    }

    /// Two individually-sorted streams must be merged into one globally
    /// sorted sequence.
    #[test]
    fn zips_multiple_streams_sorted() {
        let a = vec![
            Ok(make_msg(10, 1_000_000_000)),
            Ok(make_msg(10, 1_400_000_000)),
        ]
        .into_iter();
        let b = vec![
            Ok(make_msg(20, 1_200_000_000)),
            Ok(make_msg(20, 1_300_000_000)),
        ]
        .into_iter();
        let options = MessageQueueOptions {
            start_offset: Duration::from_secs_f64(0.0),
            bag_start_time: Duration::from_nanos(0),
        };
        let mut mq = MessageQueue::new_from_streams(vec![a, b], options);
        let channels: Vec<HashSet<u16>> = vec![
            vec![10u16].into_iter().collect(),
            vec![20u16].into_iter().collect(),
        ];
        let lookahead = Duration::from_millis(500);
        let mut times = Vec::new();
        while let Some((_bag, m)) = mq.get_next(&lookahead, &channels) {
            times.push(m.header.publish_time);
        }
        assert_eq!(
            times,
            vec![1_000_000_000, 1_200_000_000, 1_300_000_000, 1_400_000_000,]
        );
    }

    /// Iterator over fixed timestamps that counts (via a shared `Cell`) how
    /// many messages have been pulled from it — used to observe how far
    /// `refill` reads ahead.
    struct CountingIter {
        times: Vec<u64>,
        idx: usize,
        channel_id: u16,
        consumed: std::rc::Rc<std::cell::Cell<usize>>,
    }
    impl Iterator for CountingIter {
        type Item = mcap::McapResult<RawMessage<'static>>;
        fn next(&mut self) -> Option<Self::Item> {
            if self.idx >= self.times.len() {
                return None;
            }
            let it = self.idx;
            self.idx += 1;
            self.consumed.set(self.consumed.get() + 1);
            Some(Ok(make_msg(self.channel_id, self.times[it])))
        }
    }

    /// `refill` must buffer roughly `lookahead` worth of messages — enough to
    /// cover the window, but without draining the whole stream.
    #[test]
    fn refill_only_up_to_lookahead_parametrized() {
        // messages: 1.0s, 1.1s, ..., 1.6s (7 total)
        let timestamps: Vec<u64> = (0..7).map(|i| 1_000_000_000 + i * 100_000_000).collect();
        let total = timestamps.len();

        for look_ms in [200u64, 400u64, 500u64] {
            let consumed = std::rc::Rc::new(std::cell::Cell::new(0usize));
            let iter = CountingIter {
                times: timestamps.clone(),
                idx: 0,
                channel_id: 42,
                consumed: consumed.clone(),
            };

            let options = MessageQueueOptions {
                start_offset: Duration::from_secs_f64(0.0),
                bag_start_time: Duration::from_nanos(0),
            };
            let mut mq = MessageQueue::new_from_streams(vec![iter], options);
            let channels: Vec<HashSet<u16>> = vec![vec![42u16].into_iter().collect()];
            let lookahead = Duration::from_millis(look_ms);

            // Trigger one get_next which internally refills once.
            let _ = mq.get_next(&lookahead, &channels);

            let consumed_n = consumed.get();
            // Ensure we didn't drain the entire stream
            assert!(
                consumed_n < total,
                "should not drain entire stream for lookahead {}ms (consumed {})",
                look_ms,
                consumed_n
            );
            // Ensure we consumed enough to cover lookahead
            let expected_min = (look_ms / 100) as usize + 1; // number of 100ms steps + first elem
            assert!(
                consumed_n >= expected_min,
                "expected at least {} consumed for {}ms, got {}",
                expected_min,
                look_ms,
                consumed_n
            );
        }
    }
}