use xlsbye_core::error::{Result, XlsByeError};
/// Header that precedes each record's payload: a record type code and a payload length.
///
/// Produced by [`RecordHeader::decode`], which reads both fields as variable-length integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecordHeader {
/// Record type code, decoded from a varint of at most two bytes.
pub record_type: u16,
/// Payload length in bytes, decoded from a varint of at most four bytes.
pub length: u32,
}
impl RecordHeader {
    /// Parses a record header from the front of `data`.
    ///
    /// The header is two consecutive varints: the record type (at most two bytes) followed by
    /// the payload length (at most four bytes).
    ///
    /// Returns the parsed header together with the number of bytes the header occupied, so the
    /// caller can locate the start of the payload.
    ///
    /// # Errors
    ///
    /// Returns [`XlsByeError::Biff12`] if either varint is truncated or too long, or if the
    /// decoded record type does not fit in a `u16`.
    pub fn decode(data: &[u8]) -> Result<(Self, usize)> {
        let (record_type, type_len) = decode_varint(data, 2, "record type")?;
        let record_type = match u16::try_from(record_type) {
            Ok(narrowed) => narrowed,
            Err(_) => {
                return Err(XlsByeError::Biff12(format!(
                    "record type out of range: {record_type}"
                )));
            }
        };

        // The length varint starts immediately after the bytes consumed by the type varint.
        let after_type = &data[type_len..];
        let (length, length_len) = decode_varint(after_type, 4, "record length")?;

        let header = Self {
            record_type,
            length,
        };
        Ok((header, type_len + length_len))
    }
}
/// Cursor over a byte buffer that holds back-to-back records (header followed by payload).
///
/// Use [`RecordIter::next_record`] directly, or the [`Iterator`] implementation, which stops
/// after the first error.
#[derive(Debug, Clone)]
pub struct RecordIter<'a> {
// Entire buffer being walked; yielded payloads borrow from it.
data: &'a [u8],
// Offset into `data` of the next record header to decode.
pos: usize,
// Set by the `Iterator` impl once a decode error has been yielded, so it returns `None` afterwards.
failed: bool,
}
impl<'a> RecordIter<'a> {
    /// Creates a reader positioned at the first byte of `data`.
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            failed: false,
            pos: 0,
            data,
        }
    }

    /// Returns the byte offset of the next record header.
    pub fn position(&self) -> usize {
        self.pos
    }

    /// Returns how many bytes lie between the cursor and the end of the buffer.
    pub fn remaining(&self) -> usize {
        self.data.get(self.pos..).map_or(0, <[u8]>::len)
    }

    /// Decodes the record at the cursor and advances past it.
    ///
    /// Returns `Ok(None)` once the cursor has reached the end of the buffer, otherwise the
    /// record type and a payload slice borrowed from the underlying buffer.
    ///
    /// # Errors
    ///
    /// Returns [`XlsByeError::Biff12`] if the header cannot be decoded or if the declared
    /// payload length runs past the end of the buffer. The cursor is left unchanged on error.
    pub fn next_record(&mut self) -> Result<Option<(u16, &'a [u8])>> {
        let data: &'a [u8] = self.data;
        let total = data.len();
        if self.pos >= total {
            return Ok(None);
        }

        let (header, header_size) = RecordHeader::decode(&data[self.pos..])?;
        let body_start = self.pos + header_size;

        let Ok(body_len) = usize::try_from(header.length) else {
            return Err(XlsByeError::Biff12(format!(
                "record length out of range: {}",
                header.length
            )));
        };
        let Some(body_end) = body_start.checked_add(body_len) else {
            return Err(XlsByeError::Biff12(
                "record length overflow when advancing cursor".to_string(),
            ));
        };

        // `body_start <= body_end` always holds here, so the range lookup fails only when the
        // declared payload extends beyond the buffer.
        match data.get(body_start..body_end) {
            Some(payload) => {
                self.pos = body_end;
                Ok(Some((header.record_type, payload)))
            }
            None => Err(XlsByeError::Biff12(format!(
                "record length {} exceeds remaining {} bytes",
                header.length,
                total.saturating_sub(body_start)
            ))),
        }
    }
}
impl<'a> Iterator for RecordIter<'a> {
    type Item = Result<(u16, &'a [u8])>;

    /// Yields records until the buffer is exhausted.
    ///
    /// The first decode error is yielded once as `Some(Err(_))`; every later call returns
    /// `None`.
    fn next(&mut self) -> Option<Self::Item> {
        if self.failed {
            return None;
        }
        let outcome = self.next_record();
        // `failed` is known to be false here, so this only ever flips it to true on an error.
        self.failed = outcome.is_err();
        // Ok(Some(r)) -> Some(Ok(r)), Ok(None) -> None, Err(e) -> Some(Err(e)).
        outcome.transpose()
    }
}
/// Decodes a base-128 variable-length integer from the front of `data`.
///
/// Each byte contributes its low seven bits, least-significant group first; a set high bit
/// (`0x80`) means another byte follows. At most `max_bytes` bytes are examined.
///
/// Returns the decoded value together with the number of bytes consumed. `field` is used only
/// to label error messages.
///
/// # Errors
///
/// Returns [`XlsByeError::Biff12`] when:
/// - `max_bytes` is zero,
/// - `data` ends before a byte without the continuation bit is found,
/// - the first `max_bytes` bytes all have the continuation bit set, or
/// - the encoded value needs more than 32 bits (only reachable when `max_bytes` exceeds 4).
pub(crate) fn decode_varint(data: &[u8], max_bytes: usize, field: &str) -> Result<(u32, usize)> {
    if max_bytes == 0 {
        return Err(XlsByeError::Biff12(format!(
            "invalid varint config for {field}: max_bytes is zero"
        )));
    }
    let mut value = 0u32;
    for index in 0..max_bytes {
        let Some(&byte) = data.get(index) else {
            return Err(XlsByeError::Biff12(format!(
                "truncated {field} varint after {index} byte(s)"
            )));
        };
        let shift = u32::try_from(index * 7)
            .map_err(|_| XlsByeError::Biff12(format!("invalid varint shift for {field}")))?;
        let group = u32::from(byte & 0x7F);
        // A plain `<<` panics in debug builds (and masks the shift amount in release builds)
        // once `shift` reaches 32, and silently drops high bits of the fifth byte. Report both
        // cases as errors instead of panicking or returning a corrupted value.
        let shifted = group
            .checked_shl(shift)
            .filter(|&bits| bits >> shift == group)
            .ok_or_else(|| {
                XlsByeError::Biff12(format!("{field} varint does not fit in 32 bits"))
            })?;
        value |= shifted;
        if byte & 0x80 == 0 {
            return Ok((value, index + 1));
        }
    }
    Err(XlsByeError::Biff12(format!(
        "{field} varint exceeds maximum of {max_bytes} byte(s)"
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test-only encoder producing the byte form that `decode_varint` reads back.
    fn encode_varint(mut value: u32) -> Vec<u8> {
        let mut encoded = Vec::new();
        // Emit 7-bit groups with the continuation bit while more significant bits remain.
        while value >= 0x80 {
            encoded.push((value & 0x7F) as u8 | 0x80);
            value >>= 7;
        }
        encoded.push((value & 0x7F) as u8);
        encoded
    }

    #[test]
    fn decodes_varint_spec_vectors() {
        // (input bytes, max_bytes, field label, expected (value, consumed))
        let vectors: [(&[u8], usize, &str, (u32, usize)); 4] = [
            (&[0x05], 2, "record type", (0x0005, 1)),
            (&[0xD6, 0x02], 2, "record type", (0x0156, 2)),
            (&[0xFF, 0x7F], 2, "record type", (0x3FFF, 2)),
            (
                &[0xFF, 0xFF, 0xFF, 0x7F],
                4,
                "record length",
                (0x0FFF_FFFF, 4),
            ),
        ];
        for &(bytes, max_bytes, field, expected) in vectors.iter() {
            assert_eq!(decode_varint(bytes, max_bytes, field).unwrap(), expected);
        }
    }

    #[test]
    fn rejects_malformed_varints() {
        // Truncated input, then continuation bits that run past the allowed width.
        let malformed: [(&[u8], usize, &str); 3] = [
            (&[0x80], 2, "record type"),
            (&[0x80, 0x80], 2, "record type"),
            (&[0x80, 0x80, 0x80, 0x80], 4, "record length"),
        ];
        for &(bytes, max_bytes, field) in malformed.iter() {
            assert!(decode_varint(bytes, max_bytes, field).is_err());
        }
    }

    #[test]
    fn decodes_header_with_multibyte_type_and_length() {
        let bytes = [encode_varint(0x0156), encode_varint(16)].concat();
        let (header, consumed) = RecordHeader::decode(&bytes).unwrap();
        assert_eq!(
            header,
            RecordHeader {
                record_type: 0x0156,
                length: 16,
            }
        );
        assert_eq!(consumed, 3);
    }

    #[test]
    fn iterates_records_from_stream() {
        let data = [
            encode_varint(0x0005),
            encode_varint(2),
            vec![0xAA, 0xBB],
            encode_varint(0x0156),
            encode_varint(3),
            vec![1, 2, 3],
        ]
        .concat();
        let mut iter = RecordIter::new(&data);

        let (first_type, first_payload) = iter.next_record().unwrap().unwrap();
        assert_eq!(first_type, 0x0005);
        assert_eq!(first_payload, &[0xAA, 0xBB]);

        let (second_type, second_payload) = iter.next_record().unwrap().unwrap();
        assert_eq!(second_type, 0x0156);
        assert_eq!(second_payload, &[1, 2, 3]);

        assert!(iter.next_record().unwrap().is_none());
        assert_eq!(iter.remaining(), 0);
    }

    #[test]
    fn iterator_detects_payload_overrun() {
        // Header declares four payload bytes but only two follow.
        let data = [encode_varint(0x0005), encode_varint(4), vec![0x11, 0x22]].concat();
        let mut iter = RecordIter::new(&data);
        let err = iter.next_record().unwrap_err();
        assert!(err.to_string().contains("exceeds remaining"));
    }

    #[test]
    fn iterator_impl_stops_after_error() {
        let mut iter = RecordIter::new(&[0x80]);
        assert!(matches!(iter.next(), Some(Err(_))));
        assert!(iter.next().is_none());
    }
}