use mcap::{
read::{RawMessage, RawMessageStream},
McapResult,
};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
use std::time::Duration;
/// Position of a bag (one input message stream) within a [`MessageQueue`].
pub type BagIndex = usize;

/// A raw message tagged with the index of the bag it was read from.
///
/// Equality and ordering consider only `header.publish_time`; the bag index and the
/// payload are ignored. The ordering is *reversed* so that a [`BinaryHeap`] (a
/// max-heap) pops the message with the earliest publish time first.
pub struct SortableMessage<'a>(pub RawMessage<'a>, pub BagIndex);

impl<'a> PartialEq for SortableMessage<'a> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0.header, &other.0.header);
        lhs.publish_time == rhs.publish_time
    }
}

impl<'a> Eq for SortableMessage<'a> {}

impl<'a> Ord for SortableMessage<'a> {
    /// Compares by publish time, reversed: a later message is considered "smaller".
    fn cmp(&self, other: &Self) -> Ordering {
        self.0
            .header
            .publish_time
            .cmp(&other.0.header.publish_time)
            .reverse()
    }
}

impl<'a> PartialOrd for SortableMessage<'a> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Publish-time filtering applied to every message a [`MessageQueue`] reads.
///
/// A message with publish time `t` (interpreted as `Duration::from_nanos(t)`) is
/// skipped when `t` is earlier than `bag_start_time`, or when `t - bag_start_time`
/// is less than `start_offset`. The [`Default`] value (both fields zero) skips nothing.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MessageQueueOptions {
    /// Minimum distance past `bag_start_time` a message must be published at to be kept.
    pub start_offset: Duration,
    /// Reference time subtracted from each message's publish time before it is compared
    /// with `start_offset`; uses the same clock as `Duration::from_nanos(publish_time)`.
    pub bag_start_time: Duration,
}
/// Merges several message streams ("bags") into a single stream ordered by publish
/// time, reading each stream only as far ahead as a caller-supplied lookahead window.
///
/// The `S: Iterator<Item = McapResult<RawMessage<'a>>>` requirement lives on the `impl`
/// blocks rather than on the type itself, so the struct can be named without repeating it.
pub struct MessageQueue<'a, S> {
    /// Publish-time filtering applied to every message read from the streams.
    pub options: MessageQueueOptions,
    /// Buffered messages; popping yields the earliest publish time first.
    pub queue: BinaryHeap<SortableMessage<'a>>,
    /// One message source per bag, indexed by [`BagIndex`].
    pub message_streams: Vec<S>,
    /// Newest publish time read so far from each bag (`None` until its first accepted
    /// message); same length and indexing as `message_streams`.
    pub last_msg_times: Vec<Option<Duration>>,
}
impl<'a> MessageQueue<'a, RawMessageStream<'a>> {
pub fn new(files: &'a [memmap2::Mmap], options: MessageQueueOptions) -> anyhow::Result<Self> {
let message_streams = files
.iter()
.map(|file_buf| RawMessageStream::new(file_buf))
.collect::<Result<Vec<_>, _>>()?;
Ok(Self {
options,
queue: BinaryHeap::new(),
message_streams,
last_msg_times: vec![None; files.len()],
})
}
}
impl<'a, S> MessageQueue<'a, S>
where
S: Iterator<Item = McapResult<RawMessage<'a>>>,
{
#[cfg(test)]
pub fn new_from_streams(streams: Vec<S>, options: MessageQueueOptions) -> Self {
let len = streams.len();
Self {
options,
queue: BinaryHeap::new(),
message_streams: streams,
last_msg_times: vec![None; len],
}
}
pub fn refill(&mut self, lookahead: &Duration, channels_to_read: &[HashSet<u16>]) {
for (i_bag, stream) in self.message_streams.iter_mut().enumerate() {
if channels_to_read[i_bag].is_empty() {
continue;
}
loop {
if let Some(first_elem) = self.queue.peek() {
let first_stamp_in_heap =
Duration::from_nanos(first_elem.0.header.publish_time);
if let Some(last_msg_time) = self.last_msg_times[i_bag] {
if last_msg_time > first_stamp_in_heap
&& (last_msg_time - first_stamp_in_heap) >= *lookahead
{
break;
}
}
}
let (maybe_next_msg, maybe_new_latest_msg_stamp) = Self::read_next_message(
&self.options,
stream,
&channels_to_read[i_bag],
lookahead,
&self.last_msg_times[i_bag],
);
if let Some(new_latest_msg_stamp) = maybe_new_latest_msg_stamp {
self.last_msg_times[i_bag] = Some(new_latest_msg_stamp);
}
if let Some(next_msg) = maybe_next_msg {
self.queue.push(SortableMessage(next_msg, i_bag));
} else {
break;
}
}
}
}
pub fn get_next(
&mut self,
lookahead: &Duration,
channels_to_read: &[HashSet<u16>],
) -> Option<(usize, RawMessage<'a>)> {
self.refill(lookahead, channels_to_read);
self.queue.pop().map(|wm| (wm.1, wm.0))
}
pub fn read_next_message(
options: &MessageQueueOptions,
msg_stream: &mut S,
channels_to_read: &HashSet<u16>,
lookahead: &Duration,
latest_msg_stamp: &Option<Duration>,
) -> (Option<RawMessage<'a>>, Option<Duration>) {
for maybe_next_message in msg_stream.by_ref() {
match maybe_next_message {
Err(e) => {
log::error!("Error reading message: {}", e);
}
Ok(msg) => {
if !channels_to_read.contains(&msg.header.channel_id) {
continue;
}
let pub_time = Duration::from_nanos(msg.header.publish_time);
if pub_time
.checked_sub(options.bag_start_time)
.is_none_or(|d| d < options.start_offset)
{
continue;
}
if let Some(latest_stamp) = latest_msg_stamp {
if pub_time > *latest_stamp {
return (Some(msg), Some(pub_time));
}
let diff_to_last = *latest_stamp - pub_time;
if diff_to_last > *lookahead {
log::warn!("Found a message that is reordered by more than the lookahead time ({:?}), messages will be published out of order. Please increase the lookahead time.", diff_to_last
);
}
} else {
return (Some(msg), Some(pub_time));
}
return (Some(msg), None);
}
}
}
(None, None)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use mcap::records::MessageHeader;
    use std::borrow::Cow;
    use std::cell::Cell;
    use std::rc::Rc;

    /// Builds a message with an empty payload on `channel_id`, published at
    /// `publish_time_ns` nanoseconds.
    fn make_msg(channel_id: u16, publish_time_ns: u64) -> RawMessage<'static> {
        RawMessage {
            header: MessageHeader {
                channel_id,
                sequence: 0,
                log_time: 0,
                publish_time: publish_time_ns,
            },
            data: Cow::Borrowed(&[]),
        }
    }

    /// Options under which no message is skipped for being too early.
    fn zero_offset_options() -> MessageQueueOptions {
        MessageQueueOptions {
            start_offset: Duration::ZERO,
            bag_start_time: Duration::ZERO,
        }
    }

    /// Pops messages until the queue is exhausted, returning `(bag, publish_time)`
    /// pairs in the order they were emitted.
    fn drain<'a, S>(
        mq: &mut MessageQueue<'a, S>,
        lookahead: Duration,
        channels: &[HashSet<u16>],
    ) -> Vec<(BagIndex, u64)>
    where
        S: Iterator<Item = McapResult<RawMessage<'a>>>,
    {
        let mut out = Vec::new();
        while let Some((bag, msg)) = mq.get_next(&lookahead, channels) {
            out.push((bag, msg.header.publish_time));
        }
        out
    }

    /// Keeps only the publish times from `drain` output.
    fn stamps(drained: Vec<(BagIndex, u64)>) -> Vec<u64> {
        drained.into_iter().map(|(_, stamp)| stamp).collect()
    }

    #[test]
    fn reordering_within_lookahead_is_sorted() {
        let s = vec![
            Ok(make_msg(1, 1_000_000_000)),
            Ok(make_msg(1, 1_300_000_000)),
            Ok(make_msg(1, 1_200_000_000)),
        ]
        .into_iter();
        let mut mq = MessageQueue::new_from_streams(vec![s], zero_offset_options());
        let channels = vec![HashSet::from([1u16])];
        let times = stamps(drain(&mut mq, Duration::from_millis(500), &channels));
        assert_eq!(times, vec![1_000_000_000, 1_200_000_000, 1_300_000_000]);
    }

    #[test]
    fn zips_multiple_streams_sorted() {
        let a = vec![
            Ok(make_msg(10, 1_000_000_000)),
            Ok(make_msg(10, 1_400_000_000)),
        ]
        .into_iter();
        let b = vec![
            Ok(make_msg(20, 1_200_000_000)),
            Ok(make_msg(20, 1_300_000_000)),
        ]
        .into_iter();
        let mut mq = MessageQueue::new_from_streams(vec![a, b], zero_offset_options());
        let channels = vec![HashSet::from([10u16]), HashSet::from([20u16])];
        let merged = drain(&mut mq, Duration::from_millis(500), &channels);
        // Bag indices must follow the messages through the merge.
        assert_eq!(
            merged,
            vec![
                (0, 1_000_000_000),
                (1, 1_200_000_000),
                (1, 1_300_000_000),
                (0, 1_400_000_000),
            ]
        );
    }

    #[test]
    fn skips_unwanted_channels_and_messages_before_start_offset() {
        let s = vec![
            Ok(make_msg(1, 500_000_000)),   // before bag_start_time
            Ok(make_msg(1, 1_500_000_000)), // only 0.5 s into the bag, offset is 1 s
            Ok(make_msg(2, 2_500_000_000)), // channel not requested
            Ok(make_msg(1, 2_000_000_000)), // exactly at the offset: kept
            Ok(make_msg(1, 3_000_000_000)),
        ]
        .into_iter();
        let options = MessageQueueOptions {
            start_offset: Duration::from_secs(1),
            bag_start_time: Duration::from_secs(1),
        };
        let mut mq = MessageQueue::new_from_streams(vec![s], options);
        let channels = vec![HashSet::from([1u16])];
        let times = stamps(drain(&mut mq, Duration::from_millis(500), &channels));
        assert_eq!(times, vec![2_000_000_000, 3_000_000_000]);
    }

    /// Iterator over fixed timestamps that counts how many messages were pulled from it.
    struct CountingIter {
        times: Vec<u64>,
        idx: usize,
        channel_id: u16,
        consumed: Rc<Cell<usize>>,
    }

    impl Iterator for CountingIter {
        type Item = mcap::McapResult<RawMessage<'static>>;
        fn next(&mut self) -> Option<Self::Item> {
            let stamp = *self.times.get(self.idx)?;
            self.idx += 1;
            self.consumed.set(self.consumed.get() + 1);
            Some(Ok(make_msg(self.channel_id, stamp)))
        }
    }

    #[test]
    fn bags_without_requested_channels_are_never_read() {
        let consumed_read = Rc::new(Cell::new(0usize));
        let consumed_skipped = Rc::new(Cell::new(0usize));
        let read = CountingIter {
            times: vec![1_000_000_000, 1_100_000_000],
            idx: 0,
            channel_id: 42,
            consumed: Rc::clone(&consumed_read),
        };
        let skipped = CountingIter {
            times: vec![1_050_000_000],
            idx: 0,
            channel_id: 42,
            consumed: Rc::clone(&consumed_skipped),
        };
        let mut mq = MessageQueue::new_from_streams(vec![read, skipped], zero_offset_options());
        let channels = vec![HashSet::from([42u16]), HashSet::new()];
        let merged = drain(&mut mq, Duration::from_millis(500), &channels);
        assert_eq!(merged, vec![(0, 1_000_000_000), (0, 1_100_000_000)]);
        assert_eq!(consumed_read.get(), 2);
        assert_eq!(consumed_skipped.get(), 0);
    }

    #[test]
    fn refill_only_up_to_lookahead_parametrized() {
        let timestamps: Vec<u64> = (0..7).map(|i| 1_000_000_000 + i * 100_000_000).collect();
        let total = timestamps.len();
        for look_ms in [200u64, 400u64, 500u64] {
            let consumed = Rc::new(Cell::new(0usize));
            let iter = CountingIter {
                times: timestamps.clone(),
                idx: 0,
                channel_id: 42,
                consumed: Rc::clone(&consumed),
            };
            let mut mq = MessageQueue::new_from_streams(vec![iter], zero_offset_options());
            let channels = vec![HashSet::from([42u16])];
            let lookahead = Duration::from_millis(look_ms);
            let _ = mq.get_next(&lookahead, &channels);
            let consumed_n = consumed.get();
            assert!(
                consumed_n < total,
                "should not drain entire stream for lookahead {}ms (consumed {})",
                look_ms,
                consumed_n
            );
            // Messages are 100 ms apart, so reaching `lookahead` past the first message
            // takes the first message plus look_ms / 100 more.
            let expected_min = (look_ms / 100) as usize + 1;
            assert!(
                consumed_n >= expected_min,
                "expected at least {} consumed for {}ms, got {}",
                expected_min,
                look_ms,
                consumed_n
            );
        }
    }
}