use std::collections::HashMap;
use std::io::{BufReader, Read};
use std::sync::Arc;
use crate::entry::Entry;
use crate::error::BinlogError;
use crate::format::{parse_fmt_payload, MessageFormat};
use crate::{FMT_TYPE, HEADER_MAGIC};
/// Number of consecutive undecodable messages after which the reader stops
/// yielding entries for good.
///
/// The count is reset to zero whenever a message decodes successfully.
const MAX_CONSECUTIVE_ERRORS: u32 = 256;

/// Streaming decoder for a binary log read from any [`Read`] source.
///
/// Implements [`Iterator`], yielding one [`Entry`] per decoded message.
/// Undecodable data (a bad header, an unknown message type, or a payload that
/// fails to read or decode) is skipped by scanning forward for the next header.
pub struct Reader<R: Read> {
    // Buffered input; resynchronisation scans it one byte at a time.
    reader: BufReader<R>,
    // Message formats keyed by message type. Seeded with the built-in `FMT`
    // definition; each decoded `FMT` record adds or replaces an entry.
    formats: HashMap<u8, Arc<MessageFormat>>,
    // Messages in a row that failed to decode; compared against
    // `MAX_CONSECUTIVE_ERRORS`.
    consecutive_errors: u32,
}
impl<R: Read> Reader<R> {
pub fn new(reader: R) -> Self {
let mut formats = HashMap::new();
formats.insert(
FMT_TYPE,
Arc::new(MessageFormat {
msg_type: FMT_TYPE,
msg_len: 89,
name: "FMT".into(),
format: "BBnNZ".into(),
labels: Arc::from([
"Type".into(),
"Length".into(),
"Name".into(),
"Format".into(),
"Labels".into(),
]),
}),
);
Reader {
reader: BufReader::new(reader),
formats,
consecutive_errors: 0,
}
}
fn next_inner(&mut self) -> Result<Option<Entry>, BinlogError> {
if self.consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
return Ok(None);
}
let mut header = [0u8; 3];
match self.read_exact_or_eof(&mut header) {
Ok(true) => {}
Ok(false) => return Ok(None), Err(_) => return Ok(None),
}
if header[0] != HEADER_MAGIC[0] || header[1] != HEADER_MAGIC[1] {
self.consecutive_errors += 1;
return self.recover_and_retry();
}
let msg_type = header[2];
self.parse_message(msg_type)
}
#[must_use]
pub fn formats(&self) -> &HashMap<u8, Arc<MessageFormat>> {
&self.formats
}
fn parse_message(&mut self, msg_type: u8) -> Result<Option<Entry>, BinlogError> {
let format = match self.formats.get(&msg_type) {
Some(f) => Arc::clone(f),
None => {
self.consecutive_errors += 1;
return self.recover_and_retry();
}
};
let payload = match self.read_payload(&format) {
Ok(p) => p,
Err(_) => {
self.consecutive_errors += 1;
return self.recover_and_retry();
}
};
let result = if msg_type == FMT_TYPE {
build_fmt_entry(&format, &payload)
} else {
build_data_entry(&format, msg_type, &payload)
};
match result {
Ok((entry, new_fmt)) => {
if let Some(new_fmt) = new_fmt {
self.formats.insert(new_fmt.msg_type, Arc::new(new_fmt));
}
self.consecutive_errors = 0;
Ok(Some(entry))
}
Err(_) => {
self.consecutive_errors += 1;
self.recover_and_retry()
}
}
}
fn read_payload(&mut self, format: &MessageFormat) -> Result<Vec<u8>, BinlogError> {
let payload_len = format.msg_len as usize - 3;
let mut payload = vec![0u8; payload_len];
match self.read_exact_or_eof(&mut payload) {
Ok(true) => Ok(payload),
Ok(false) => Err(BinlogError::UnexpectedEof),
Err(e) => Err(e),
}
}
fn recover_and_retry(&mut self) -> Result<Option<Entry>, BinlogError> {
if self.consecutive_errors >= MAX_CONSECUTIVE_ERRORS {
return Ok(None);
}
match self.scan_for_header()? {
Some(msg_type) => self.parse_message(msg_type),
None => Ok(None),
}
}
fn scan_for_header(&mut self) -> Result<Option<u8>, BinlogError> {
let mut prev = 0u8;
loop {
let mut byte = [0u8; 1];
match self.reader.read(&mut byte) {
Ok(0) => return Ok(None), Ok(_) => {
if prev == HEADER_MAGIC[0] && byte[0] == HEADER_MAGIC[1] {
let mut msg_type = [0u8; 1];
match self.reader.read(&mut msg_type) {
Ok(0) => return Ok(None),
Ok(_) => return Ok(Some(msg_type[0])),
Err(_) => return Ok(None),
}
}
prev = byte[0];
}
Err(_) => return Ok(None),
}
}
}
fn read_exact_or_eof(&mut self, buf: &mut [u8]) -> Result<bool, BinlogError> {
let mut total = 0;
while total < buf.len() {
match self.reader.read(&mut buf[total..]) {
Ok(0) => {
if total == 0 {
return Ok(false); }
return Err(BinlogError::UnexpectedEof);
}
Ok(n) => total += n,
Err(e) => return Err(BinlogError::Io(e)),
}
}
Ok(true)
}
}
/// Decodes an `FMT` record.
///
/// Returns the generic entry view of the record together with the newly
/// described [`MessageFormat`], which the caller registers for later lookups.
///
/// # Errors
/// Propagates failures from `parse_fmt_payload` and from decoding the
/// record's fields with `format`.
fn build_fmt_entry(
    format: &MessageFormat,
    payload: &[u8],
) -> Result<(Entry, Option<MessageFormat>), BinlogError> {
    let described = parse_fmt_payload(payload)?;
    let entry = Entry {
        name: "FMT".into(),
        msg_type: FMT_TYPE,
        // FMT records carry no timestamp.
        timestamp_usec: None,
        labels: format.labels.clone(),
        values: format.decode_fields(payload)?,
    };
    Ok((entry, Some(described)))
}
/// Decodes an ordinary (non-`FMT`) message of type `msg_type`.
///
/// The entry's timestamp is whatever `format.extract_timestamp` finds in the
/// payload. No new format is produced, so the second tuple element is always
/// `None`.
///
/// # Errors
/// Propagates failures from decoding the payload's fields.
fn build_data_entry(
    format: &MessageFormat,
    msg_type: u8,
    payload: &[u8],
) -> Result<(Entry, Option<MessageFormat>), BinlogError> {
    let values = format.decode_fields(payload)?;
    Ok((
        Entry {
            name: format.name.clone(),
            msg_type,
            timestamp_usec: format.extract_timestamp(payload),
            labels: format.labels.clone(),
            values,
        },
        None,
    ))
}
impl<R: Read> Iterator for Reader<R> {
    type Item = Result<Entry, BinlogError>;

    /// Yields the next decoded entry, or `None` once the reader is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        match self.next_inner() {
            Ok(Some(entry)) => Some(Ok(entry)),
            Ok(None) => None,
            Err(err) => Some(Err(err)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::value::FieldValue;

    /// Encodes the `FMT` record that describes the `FMT` message type itself.
    ///
    /// Payload layout used by these helpers: `[0]` described type, `[1]` total
    /// message length (header included), `[2..6]` name, format string from
    /// `[6]`, comma-separated labels from `[22]`; unused bytes stay zero.
    fn build_fmt_bootstrap() -> Vec<u8> {
        let mut msg = Vec::new();
        msg.extend_from_slice(&HEADER_MAGIC);
        msg.push(FMT_TYPE);
        let mut payload = [0u8; 86];
        payload[0] = FMT_TYPE;
        // 89 = 3-byte header + 86-byte payload.
        payload[1] = 89;
        payload[2..6].copy_from_slice(b"FMT\0");
        payload[6..11].copy_from_slice(b"BBnNZ");
        let labels = b"Type,Length,Name,Format,Labels";
        payload[22..22 + labels.len()].copy_from_slice(labels);
        msg.extend_from_slice(&payload);
        msg
    }

    /// Encodes an `FMT` record registering `msg_type` with total length
    /// `msg_len` (header included), a 4-byte `name`, and the given format
    /// string and comma-separated labels.
    ///
    /// Keep `format` to at most 16 bytes (longer strings run into the labels
    /// at offset 22) and `labels` to at most 64 bytes (slicing panics past
    /// the 86-byte payload).
    fn build_fmt_for_type(
        msg_type: u8,
        msg_len: u8,
        name: &[u8; 4],
        format: &str,
        labels: &str,
    ) -> Vec<u8> {
        let mut msg = Vec::new();
        msg.extend_from_slice(&HEADER_MAGIC);
        msg.push(FMT_TYPE);
        let mut payload = [0u8; 86];
        payload[0] = msg_type;
        payload[1] = msg_len;
        payload[2..6].copy_from_slice(name);
        let fmt_bytes = format.as_bytes();
        payload[6..6 + fmt_bytes.len()].copy_from_slice(fmt_bytes);
        let lbl_bytes = labels.as_bytes();
        payload[22..22 + lbl_bytes.len()].copy_from_slice(lbl_bytes);
        msg.extend_from_slice(&payload);
        msg
    }

    /// Encodes a message of `msg_type`: header followed by the raw,
    /// already-encoded `payload`.
    fn build_data_message(msg_type: u8, payload: &[u8]) -> Vec<u8> {
        let mut msg = Vec::new();
        msg.extend_from_slice(&HEADER_MAGIC);
        msg.push(msg_type);
        msg.extend_from_slice(payload);
        msg
    }

    // Empty input yields no entries and no error.
    #[test]
    fn parse_empty_input() {
        let reader = Reader::new(std::io::Cursor::new(Vec::new()));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert!(entries.is_empty());
    }

    // A lone bootstrap record decodes to a single untimestamped FMT entry.
    #[test]
    fn parse_fmt_bootstrap_only() {
        let data = build_fmt_bootstrap();
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].name, "FMT");
        assert_eq!(entries[0].msg_type, FMT_TYPE);
        assert!(entries[0].timestamp_usec.is_none());
    }

    // An FMT record registers type 0x81; a following 0x81 record is decoded
    // with that format.
    #[test]
    fn parse_data_message() {
        let mut data = Vec::new();
        data.extend(build_fmt_bootstrap());
        // 15 = 3-byte header + u64 TimeUS + i16 Roll + i16 Pitch.
        data.extend(build_fmt_for_type(
            0x81,
            15,
            b"ATT\0",
            "Qhh",
            "TimeUS,Roll,Pitch",
        ));
        let mut payload = Vec::new();
        payload.extend_from_slice(&1_000_000u64.to_le_bytes());
        payload.extend_from_slice(&4500i16.to_le_bytes());
        payload.extend_from_slice(&(-200i16).to_le_bytes());
        data.extend(build_data_message(0x81, &payload));
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        // Both FMT records plus the ATT record.
        assert_eq!(entries.len(), 3);
        let att = &entries[2];
        assert_eq!(att.name, "ATT");
        assert_eq!(att.msg_type, 0x81);
        assert_eq!(att.timestamp_usec, Some(1_000_000));
        assert_eq!(att.get("Roll"), Some(&FieldValue::Int(4500)));
        assert_eq!(att.get("Pitch"), Some(&FieldValue::Int(-200)));
    }

    // 50 junk bytes between two valid TST records must be skipped without
    // losing either record.
    #[test]
    fn error_recovery_with_garbage() {
        let mut data = Vec::new();
        data.extend(build_fmt_bootstrap());
        data.extend(build_fmt_for_type(
            0x81, 11, b"TST\0", "Q", "TimeUS",
        ));
        data.extend(build_data_message(0x81, &100u64.to_le_bytes()));
        data.extend_from_slice(&[0xFF; 50]);
        data.extend(build_data_message(0x81, &200u64.to_le_bytes()));
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        let tst_entries: Vec<_> = entries.iter().filter(|e| e.name == "TST").collect();
        assert_eq!(tst_entries.len(), 2);
        assert_eq!(tst_entries[0].timestamp_usec, Some(100));
        assert_eq!(tst_entries[1].timestamp_usec, Some(200));
    }

    // The final 0x81 record carries only 3 of its 8 payload bytes; it is
    // dropped without surfacing an error.
    #[test]
    fn truncated_final_message() {
        let mut data = Vec::new();
        data.extend(build_fmt_bootstrap());
        data.extend(build_fmt_for_type(0x81, 11, b"TST\0", "Q", "TimeUS"));
        data.extend(build_data_message(0x81, &100u64.to_le_bytes()));
        data.extend_from_slice(&HEADER_MAGIC);
        data.push(0x81);
        data.extend_from_slice(&[0; 3]);
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        let tst_entries: Vec<_> = entries.iter().filter(|e| e.name == "TST").collect();
        assert_eq!(tst_entries.len(), 1);
        assert_eq!(tst_entries[0].timestamp_usec, Some(100));
    }

    // A header with unregistered type 0x99 followed by 20 zero bytes is
    // skipped, and the next TST record is still decoded.
    #[test]
    fn unknown_type_recovery() {
        let mut data = Vec::new();
        data.extend(build_fmt_bootstrap());
        data.extend(build_fmt_for_type(0x81, 11, b"TST\0", "Q", "TimeUS"));
        data.extend_from_slice(&HEADER_MAGIC);
        data.push(0x99);
        data.extend_from_slice(&[0; 20]);
        data.extend(build_data_message(0x81, &300u64.to_le_bytes()));
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        let tst_entries: Vec<_> = entries.iter().filter(|e| e.name == "TST").collect();
        assert_eq!(tst_entries.len(), 1);
        assert_eq!(tst_entries[0].timestamp_usec, Some(300));
    }

    // More back-to-back bad headers than the error limit: the reader gives
    // up before it reaches the valid record at the end.
    #[test]
    fn max_consecutive_errors_boundary() {
        let mut data = Vec::new();
        data.extend(build_fmt_bootstrap());
        data.extend(build_fmt_for_type(0x81, 11, b"TST\0", "Q", "TimeUS"));
        let error_count = MAX_CONSECUTIVE_ERRORS + 10;
        for _ in 0..error_count {
            data.extend_from_slice(&HEADER_MAGIC);
            data.push(0x99);
        }
        data.extend(build_data_message(0x81, &999u64.to_le_bytes()));
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        let tst_entries: Vec<_> = entries.iter().filter(|e| e.name == "TST").collect();
        assert!(
            tst_entries.is_empty(),
            "reader should stop before reaching the valid message after {} errors",
            MAX_CONSECUTIVE_ERRORS
        );
    }

    // One fewer bad header than the limit: the valid record after them must
    // still be decoded.
    #[test]
    fn recovery_just_below_max_errors() {
        let mut data = Vec::new();
        data.extend(build_fmt_bootstrap());
        data.extend(build_fmt_for_type(0x81, 11, b"TST\0", "Q", "TimeUS"));
        for _ in 0..(MAX_CONSECUTIVE_ERRORS - 1) {
            data.extend_from_slice(&HEADER_MAGIC);
            data.push(0x99);
        }
        data.extend(build_data_message(0x81, &777u64.to_le_bytes()));
        let reader = Reader::new(std::io::Cursor::new(data));
        let entries: Vec<_> = reader.collect::<Result<Vec<_>, _>>().unwrap();
        let tst_entries: Vec<_> = entries.iter().filter(|e| e.name == "TST").collect();
        assert_eq!(
            tst_entries.len(),
            1,
            "recovery should still work at {} consecutive errors",
            MAX_CONSECUTIVE_ERRORS - 1
        );
        assert_eq!(tst_entries[0].timestamp_usec, Some(777));
    }

    // After decoding the bootstrap record the format table still exposes
    // the FMT definition.
    #[test]
    fn formats_accessible() {
        let data = build_fmt_bootstrap();
        let mut reader = Reader::new(std::io::Cursor::new(data));
        let _ = reader.next();
        assert!(reader.formats().contains_key(&FMT_TYPE));
        assert_eq!(reader.formats().get(&FMT_TYPE).unwrap().name, "FMT");
    }
}