//! xlsbye-biff12 0.1.0
//!
//! BIFF12 binary record parser for XLSB files.
use xlsbye_core::error::{Result, XlsByeError};

/// Header that precedes every BIFF12 record: which record it is, and how many
/// payload bytes follow the header.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RecordHeader {
    /// Record type identifier (read by `RecordHeader::decode` from a varint of at most 2 bytes).
    pub record_type: u16,
    /// Payload size in bytes (read by `RecordHeader::decode` from a varint of at most 4 bytes).
    pub length: u32,
}

impl RecordHeader {
    /// Decodes a BIFF12 record header from the start of `data`.
    ///
    /// The header is two back-to-back varints: the record type (at most 2 bytes)
    /// followed by the payload length (at most 4 bytes). On success this returns the
    /// header together with the number of bytes it occupied, which is the offset of the
    /// payload relative to the start of `data`.
    ///
    /// # Errors
    ///
    /// Returns [`XlsByeError::Biff12`] if either varint is truncated or longer than its
    /// byte limit, or if the decoded record type does not fit in a `u16`.
    pub fn decode(data: &[u8]) -> Result<(Self, usize)> {
        let (record_type, type_len) = decode_varint(data, 2, "record type")?;
        let record_type = u16::try_from(record_type).map_err(|_| {
            XlsByeError::Biff12(format!("record type out of range: {record_type}"))
        })?;

        // decode_varint only reports bytes it actually read, so this slice is in bounds.
        let after_type = &data[type_len..];
        let (length, length_len) = decode_varint(after_type, 4, "record length")?;

        let header = Self { record_type, length };
        let consumed = type_len + length_len;
        Ok((header, consumed))
    }
}

/// Sequential cursor over a buffer of back-to-back BIFF12 records.
///
/// Records are read either with `RecordIter::next_record` or through the `Iterator`
/// implementation. Payload slices borrow from the input buffer rather than copying it.
#[derive(Debug, Clone)]
pub struct RecordIter<'a> {
    /// The complete input buffer being parsed.
    data: &'a [u8],
    /// Byte offset of the next unread record header within `data`.
    pos: usize,
    /// Set by the `Iterator` impl once it has yielded an error, so iteration stops there.
    failed: bool,
}

impl<'a> RecordIter<'a> {
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            pos: 0,
            failed: false,
        }
    }

    pub fn position(&self) -> usize {
        self.pos
    }

    pub fn remaining(&self) -> usize {
        self.data.len().saturating_sub(self.pos)
    }

    pub fn next_record(&mut self) -> Result<Option<(u16, &'a [u8])>> {
        if self.pos >= self.data.len() {
            return Ok(None);
        }

        let (header, header_size) = RecordHeader::decode(&self.data[self.pos..])?;
        let payload_start = self.pos + header_size;
        let payload_len = usize::try_from(header.length).map_err(|_| {
            XlsByeError::Biff12(format!("record length out of range: {}", header.length))
        })?;

        let payload_end = payload_start.checked_add(payload_len).ok_or_else(|| {
            XlsByeError::Biff12("record length overflow when advancing cursor".to_string())
        })?;

        if payload_end > self.data.len() {
            return Err(XlsByeError::Biff12(format!(
                "record length {} exceeds remaining {} bytes",
                header.length,
                self.data.len().saturating_sub(payload_start)
            )));
        }

        let payload = &self.data[payload_start..payload_end];
        self.pos = payload_end;
        Ok(Some((header.record_type, payload)))
    }
}

impl<'a> Iterator for RecordIter<'a> {
    type Item = Result<(u16, &'a [u8])>;

    /// Yields the next record, or `None` at the end of the input.
    ///
    /// The first decoding error is yielded exactly once; after that the iterator is
    /// exhausted. `next_record` does not advance on error, so without this latch a
    /// malformed stream would yield the same error forever.
    fn next(&mut self) -> Option<Self::Item> {
        if self.failed {
            return None;
        }

        let item = self.next_record().transpose();
        self.failed = matches!(item, Some(Err(_)));
        item
    }
}

/// Decodes a variable-length unsigned integer from the start of `data`.
///
/// Each byte contributes its low 7 bits, least-significant group first. A byte with the
/// high bit (`0x80`) clear terminates the value. At most `max_bytes` bytes are examined.
/// `field` is used only to label error messages.
///
/// Returns the decoded value and the number of bytes consumed.
///
/// # Errors
///
/// Returns [`XlsByeError::Biff12`] in any of these cases:
/// - `max_bytes` is zero;
/// - `data` ends before a terminating byte;
/// - no terminating byte appears within `max_bytes` bytes;
/// - the encoded value does not fit in a `u32`.
pub(crate) fn decode_varint(data: &[u8], max_bytes: usize, field: &str) -> Result<(u32, usize)> {
    if max_bytes == 0 {
        return Err(XlsByeError::Biff12(format!(
            "invalid varint config for {field}: max_bytes is zero"
        )));
    }

    let mut value = 0u32;

    for index in 0..max_bytes {
        let Some(&byte) = data.get(index) else {
            return Err(XlsByeError::Biff12(format!(
                "truncated {field} varint after {index} byte(s)"
            )));
        };

        let shift = u32::try_from(index * 7)
            .map_err(|_| XlsByeError::Biff12(format!("invalid varint shift for {field}")))?;
        let group = u32::from(byte & 0x7F);

        // A zero group contributes nothing at any position. For non-zero groups, a plain
        // `<<` panics in debug builds once `shift >= 32` and silently drops high bits
        // before that. Both cases mean the value does not fit in a u32, so report them.
        if group != 0 {
            let shifted = group
                .checked_shl(shift)
                .filter(|&bits| bits >> shift == group)
                .ok_or_else(|| {
                    XlsByeError::Biff12(format!("{field} varint value does not fit in 32 bits"))
                })?;
            value |= shifted;
        }

        if byte & 0x80 == 0 {
            return Ok((value, index + 1));
        }
    }

    Err(XlsByeError::Biff12(format!(
        "{field} varint exceeds maximum of {max_bytes} byte(s)"
    )))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` in the same 7-bits-per-byte, low-group-first format the
    /// parser reads, for building test inputs.
    fn encode_varint(mut value: u32) -> Vec<u8> {
        let mut out = Vec::new();
        loop {
            let mut byte = (value & 0x7F) as u8;
            value >>= 7;
            if value != 0 {
                byte |= 0x80;
            }
            out.push(byte);
            if value == 0 {
                break;
            }
        }
        out
    }

    #[test]
    fn decodes_varint_spec_vectors() {
        assert_eq!(decode_varint(&[0x05], 2, "record type").unwrap(), (0x0005, 1));
        assert_eq!(
            decode_varint(&[0xD6, 0x02], 2, "record type").unwrap(),
            (0x0156, 2)
        );
        assert_eq!(
            decode_varint(&[0xFF, 0x7F], 2, "record type").unwrap(),
            (0x3FFF, 2)
        );
        assert_eq!(
            decode_varint(&[0xFF, 0xFF, 0xFF, 0x7F], 4, "record length").unwrap(),
            (0x0FFF_FFFF, 4)
        );
    }

    #[test]
    fn rejects_malformed_varints() {
        assert!(decode_varint(&[0x80], 2, "record type").is_err());
        assert!(decode_varint(&[0x80, 0x80], 2, "record type").is_err());
        assert!(decode_varint(&[0x80, 0x80, 0x80, 0x80], 4, "record length").is_err());
    }

    #[test]
    fn rejects_zero_max_bytes_and_empty_input() {
        assert!(decode_varint(&[0x01], 0, "record type").is_err());
        assert!(decode_varint(&[], 4, "record length").is_err());
        assert!(RecordHeader::decode(&[]).is_err());
    }

    #[test]
    fn decodes_header_with_multibyte_type_and_length() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&encode_varint(0x0156));
        bytes.extend_from_slice(&encode_varint(16));

        let (header, consumed) = RecordHeader::decode(&bytes).unwrap();
        assert_eq!(header.record_type, 0x0156);
        assert_eq!(header.length, 16);
        assert_eq!(consumed, 3);
    }

    #[test]
    fn header_rejects_missing_length() {
        // A valid one-byte record type with no length bytes after it.
        let err = RecordHeader::decode(&[0x05]).unwrap_err();
        assert!(format!("{err}").contains("truncated record length"));
    }

    #[test]
    fn iterates_records_from_stream() {
        let mut data = Vec::new();
        data.extend_from_slice(&encode_varint(0x0005));
        data.extend_from_slice(&encode_varint(2));
        data.extend_from_slice(&[0xAA, 0xBB]);
        data.extend_from_slice(&encode_varint(0x0156));
        data.extend_from_slice(&encode_varint(3));
        data.extend_from_slice(&[1, 2, 3]);

        let mut iter = RecordIter::new(&data);
        let first = iter.next_record().unwrap().unwrap();
        assert_eq!(first.0, 0x0005);
        assert_eq!(first.1, &[0xAA, 0xBB]);

        let second = iter.next_record().unwrap().unwrap();
        assert_eq!(second.0, 0x0156);
        assert_eq!(second.1, &[1, 2, 3]);

        assert!(iter.next_record().unwrap().is_none());
        assert_eq!(iter.remaining(), 0);
    }

    #[test]
    fn accepts_zero_length_payload() {
        let mut iter = RecordIter::new(&[0x05, 0x00]);
        let (record_type, payload) = iter.next_record().unwrap().unwrap();
        assert_eq!(record_type, 0x0005);
        assert!(payload.is_empty());
        assert_eq!(iter.position(), 2);
        assert!(iter.next_record().unwrap().is_none());
    }

    #[test]
    fn empty_stream_yields_no_records() {
        let mut iter = RecordIter::new(&[]);
        assert!(iter.next_record().unwrap().is_none());
        assert!(iter.next().is_none());
        assert_eq!(iter.remaining(), 0);
    }

    #[test]
    fn iterator_impl_collects_all_records() {
        let mut data = Vec::new();
        data.extend_from_slice(&encode_varint(0x0005));
        data.extend_from_slice(&encode_varint(1));
        data.push(0xAA);
        data.extend_from_slice(&encode_varint(0x0156));
        data.extend_from_slice(&encode_varint(0));

        let records = RecordIter::new(&data)
            .collect::<Result<Vec<_>>>()
            .unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(records[0].0, 0x0005);
        assert_eq!(records[0].1, &[0xAA]);
        assert_eq!(records[1].0, 0x0156);
        assert!(records[1].1.is_empty());
    }

    #[test]
    fn iterator_detects_payload_overrun() {
        let mut data = Vec::new();
        data.extend_from_slice(&encode_varint(0x0005));
        data.extend_from_slice(&encode_varint(4));
        data.extend_from_slice(&[0x11, 0x22]);

        let mut iter = RecordIter::new(&data);
        let err = iter.next_record().unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("exceeds remaining"));
    }

    #[test]
    fn iterator_impl_stops_after_error() {
        let mut iter = RecordIter::new(&[0x80]);
        assert!(iter.next().unwrap().is_err());
        assert!(iter.next().is_none());
    }
}