use std::str;
use super::error::{MAX_GROUP_SIZE, SbeDecodeError};
/// Forward-reading cursor over a borrowed SBE-encoded byte buffer.
///
/// Multi-byte integers are decoded little-endian. Fixed-size reads check bounds
/// before consuming anything, so a failed read leaves the position unchanged;
/// length-prefixed reads consume their prefix before validating the payload.
#[derive(Debug, Clone)]
pub struct SbeCursor<'a> {
    /// Entire underlying buffer; reads borrow from it without copying.
    buf: &'a [u8],
    /// Offset of the next unread byte. May exceed `buf.len()` (see `new_at`,
    /// `skip`, `set_pos`), in which case nothing remains to be read.
    pos: usize,
}

impl<'a> SbeCursor<'a> {
    /// Creates a cursor positioned at the start of `buf`.
    #[must_use]
    pub const fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Creates a cursor positioned at byte offset `pos` of `buf`.
    ///
    /// `pos` is not validated; an offset past the end yields a cursor with
    /// nothing remaining.
    #[must_use]
    pub const fn new_at(buf: &'a [u8], pos: usize) -> Self {
        Self { buf, pos }
    }

    /// Returns the current byte offset into the buffer.
    #[must_use]
    pub const fn pos(&self) -> usize {
        self.pos
    }

    /// Returns the number of unread bytes (zero when positioned past the end).
    #[must_use]
    pub const fn remaining(&self) -> usize {
        self.buf.len().saturating_sub(self.pos)
    }

    /// Returns the whole underlying buffer, independent of the current position.
    #[must_use]
    pub const fn buffer(&self) -> &'a [u8] {
        self.buf
    }

    /// Returns the unread tail of the buffer without consuming it.
    ///
    /// Returns an empty slice when the position is at or past the end.
    #[must_use]
    pub fn peek(&self) -> &'a [u8] {
        // `get` rather than indexing: `pos` may legitimately exceed the buffer
        // length (via `new_at`/`skip`/`set_pos`) and must not panic here.
        self.buf.get(self.pos..).unwrap_or_default()
    }

    /// Checks that at least `n` more bytes can be read.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if fewer than `n` bytes remain;
    /// `expected` is the buffer length the read would have needed.
    #[inline]
    pub fn require(&self, n: usize) -> Result<(), SbeDecodeError> {
        if self.remaining() >= n {
            return Ok(());
        }
        Err(SbeDecodeError::BufferTooShort {
            // Saturate so an absurd `n` (e.g. a corrupt length field) yields an
            // error instead of an arithmetic overflow.
            expected: self.pos.saturating_add(n),
            actual: self.buf.len(),
        })
    }

    /// Consumes `n` bytes after checking they are available.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if fewer
    /// than `n` bytes remain.
    #[inline]
    pub fn advance(&mut self, n: usize) -> Result<(), SbeDecodeError> {
        self.require(n)?;
        self.pos += n;
        Ok(())
    }

    /// Moves the position forward by `n` bytes without a bounds check.
    ///
    /// The position may end up past the buffer, after which reads fail. It
    /// saturates at `usize::MAX` instead of overflowing.
    #[inline]
    pub fn skip(&mut self, n: usize) {
        self.pos = self.pos.saturating_add(n);
    }

    /// Rewinds the cursor to the start of the buffer.
    pub fn reset(&mut self) {
        self.pos = 0;
    }

    /// Sets the position to an arbitrary offset; the value is not validated.
    pub fn set_pos(&mut self, pos: usize) {
        self.pos = pos;
    }

    /// Reads one unsigned byte.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the buffer is exhausted.
    #[inline]
    pub fn read_u8(&mut self) -> Result<u8, SbeDecodeError> {
        Ok(u8::from_le_bytes(self.read_array::<1>()?))
    }

    /// Reads one signed byte.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the buffer is exhausted.
    #[inline]
    pub fn read_i8(&mut self) -> Result<i8, SbeDecodeError> {
        Ok(i8::from_le_bytes(self.read_array::<1>()?))
    }

    /// Reads a little-endian `u16`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_u16_le(&mut self) -> Result<u16, SbeDecodeError> {
        Ok(u16::from_le_bytes(self.read_array::<2>()?))
    }

    /// Reads a little-endian `i16`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_i16_le(&mut self) -> Result<i16, SbeDecodeError> {
        Ok(i16::from_le_bytes(self.read_array::<2>()?))
    }

    /// Reads a little-endian `u32`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_u32_le(&mut self) -> Result<u32, SbeDecodeError> {
        Ok(u32::from_le_bytes(self.read_array::<4>()?))
    }

    /// Reads a little-endian `i32`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_i32_le(&mut self) -> Result<i32, SbeDecodeError> {
        Ok(i32::from_le_bytes(self.read_array::<4>()?))
    }

    /// Reads a little-endian `u64`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_u64_le(&mut self) -> Result<u64, SbeDecodeError> {
        Ok(u64::from_le_bytes(self.read_array::<8>()?))
    }

    /// Reads a little-endian `i64`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_i64_le(&mut self) -> Result<i64, SbeDecodeError> {
        Ok(i64::from_le_bytes(self.read_array::<8>()?))
    }

    /// Reads a little-endian `u128`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_u128_le(&mut self) -> Result<u128, SbeDecodeError> {
        Ok(u128::from_le_bytes(self.read_array::<16>()?))
    }

    /// Reads a little-endian `i128`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_i128_le(&mut self) -> Result<i128, SbeDecodeError> {
        Ok(i128::from_le_bytes(self.read_array::<16>()?))
    }

    /// Reads a little-endian `i64`, treating `i64::MIN` as the null sentinel.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if short.
    #[inline]
    pub fn read_optional_i64_le(&mut self) -> Result<Option<i64>, SbeDecodeError> {
        let value = self.read_i64_le()?;
        Ok(if value == i64::MIN { None } else { Some(value) })
    }

    /// Reads `n` raw bytes, borrowing them from the underlying buffer.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] (position unchanged) if fewer
    /// than `n` bytes remain.
    #[inline]
    pub fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], SbeDecodeError> {
        let bytes = self.peek_bytes(n)?;
        self.pos += n;
        Ok(bytes)
    }

    /// Returns the next `n` bytes without consuming them.
    #[inline]
    fn peek_bytes(&self, n: usize) -> Result<&'a [u8], SbeDecodeError> {
        self.require(n)?;
        // `require` guarantees `pos + n <= len` whenever `n > 0`. For `n == 0`
        // the position may lie past the end, so avoid panicking slice indexing.
        Ok(self.buf.get(self.pos..self.pos + n).unwrap_or_default())
    }

    /// Reads exactly `N` bytes into an array.
    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], SbeDecodeError> {
        let mut bytes = [0u8; N];
        bytes.copy_from_slice(self.peek_bytes(N)?);
        self.pos += N;
        Ok(bytes)
    }

    /// Validates the next `len` bytes as UTF-8 and consumes them.
    ///
    /// On invalid UTF-8 the bytes are left unconsumed.
    fn read_utf8(&mut self, len: usize) -> Result<&'a str, SbeDecodeError> {
        let text =
            str::from_utf8(self.peek_bytes(len)?).map_err(|_| SbeDecodeError::InvalidUtf8)?;
        self.pos += len;
        Ok(text)
    }

    /// Reads a string prefixed by a `u8` byte length, as an owned `String`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the prefix or payload is
    /// truncated, or [`SbeDecodeError::InvalidUtf8`] if the payload is not UTF-8.
    /// The length prefix is consumed in either payload-error case.
    #[inline]
    pub fn read_var_string8(&mut self) -> Result<String, SbeDecodeError> {
        self.read_var_string8_ref().map(String::from)
    }

    /// Reads a string prefixed by a `u8` byte length, borrowing from the buffer.
    ///
    /// # Errors
    ///
    /// Same as [`SbeCursor::read_var_string8`].
    #[inline]
    pub fn read_var_string8_ref(&mut self) -> Result<&'a str, SbeDecodeError> {
        let len = usize::from(self.read_u8()?);
        self.read_utf8(len)
    }

    /// Reads a string prefixed by a little-endian `u16` byte length, as an owned
    /// `String`.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the prefix or payload is
    /// truncated, or [`SbeDecodeError::InvalidUtf8`] if the payload is not UTF-8.
    /// The length prefix is consumed in either payload-error case.
    #[inline]
    pub fn read_var_string16(&mut self) -> Result<String, SbeDecodeError> {
        self.read_var_string16_ref().map(String::from)
    }

    /// Reads a string prefixed by a little-endian `u16` byte length, borrowing
    /// from the buffer.
    ///
    /// # Errors
    ///
    /// Same as [`SbeCursor::read_var_string16`].
    #[inline]
    pub fn read_var_string16_ref(&mut self) -> Result<&'a str, SbeDecodeError> {
        let len = usize::from(self.read_u16_le()?);
        self.read_utf8(len)
    }

    /// Skips a `u8`-length-prefixed variable-data field.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the prefix or payload is
    /// truncated.
    pub fn skip_var_data8(&mut self) -> Result<(), SbeDecodeError> {
        let len = usize::from(self.read_u8()?);
        self.advance(len)
    }

    /// Reads a `u8`-length-prefixed variable-data field into an owned vector.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the prefix or payload is
    /// truncated.
    pub fn read_var_bytes8(&mut self) -> Result<Vec<u8>, SbeDecodeError> {
        let len = usize::from(self.read_u8()?);
        self.read_bytes(len).map(<[u8]>::to_vec)
    }

    /// Skips a little-endian `u16`-length-prefixed variable-data field.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the prefix or payload is
    /// truncated.
    pub fn skip_var_data16(&mut self) -> Result<(), SbeDecodeError> {
        let len = usize::from(self.read_u16_le()?);
        self.advance(len)
    }

    /// Reads a little-endian `u16`-length-prefixed variable-data field into an
    /// owned vector.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the prefix or payload is
    /// truncated.
    pub fn read_var_bytes16(&mut self) -> Result<Vec<u8>, SbeDecodeError> {
        let len = usize::from(self.read_u16_le()?);
        self.read_bytes(len).map(<[u8]>::to_vec)
    }

    /// Rejects group entry counts above [`MAX_GROUP_SIZE`].
    fn check_group_size(count: u32) -> Result<(), SbeDecodeError> {
        if count > MAX_GROUP_SIZE {
            return Err(SbeDecodeError::GroupSizeTooLarge {
                count,
                max: MAX_GROUP_SIZE,
            });
        }
        Ok(())
    }

    /// Reads a group header: `u16` block length followed by `u32` entry count.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the header is truncated, or
    /// [`SbeDecodeError::GroupSizeTooLarge`] if the count exceeds
    /// [`MAX_GROUP_SIZE`] (the header has then already been consumed).
    #[inline]
    pub fn read_group_header(&mut self) -> Result<(u16, u32), SbeDecodeError> {
        let block_length = self.read_u16_le()?;
        let num_in_group = self.read_u32_le()?;
        Self::check_group_size(num_in_group)?;
        Ok((block_length, num_in_group))
    }

    /// Reads a compact group header: `u16` block length followed by `u16` count.
    ///
    /// # Errors
    ///
    /// Same as [`SbeCursor::read_group_header`].
    #[inline]
    pub fn read_group_header_16(&mut self) -> Result<(u16, u16), SbeDecodeError> {
        let block_length = self.read_u16_le()?;
        let num_in_group = self.read_u16_le()?;
        Self::check_group_size(u32::from(num_in_group))?;
        Ok((block_length, num_in_group))
    }

    /// Decodes `num_in_group` repeating-group entries with `decode_item`.
    ///
    /// `decode_item` is called once per entry with the cursor at the entry's
    /// first byte. Afterwards the cursor is moved to the end of the entry's
    /// `block_length`-byte fixed block if the decoder stopped short of it. If the
    /// decoder read past the block (e.g. trailing variable-length data), the
    /// cursor is left where the decoder put it.
    ///
    /// # Errors
    ///
    /// Returns [`SbeDecodeError::BufferTooShort`] if the fixed blocks of all
    /// entries cannot fit in the remaining buffer, or if skipping to the end of
    /// an entry's fixed block runs off the buffer. Any error from `decode_item`
    /// is propagated.
    pub fn read_group<T, F>(
        &mut self,
        block_length: u16,
        num_in_group: u32,
        mut decode_item: F,
    ) -> Result<Vec<T>, SbeDecodeError>
    where
        F: FnMut(&mut Self) -> Result<T, SbeDecodeError>,
    {
        let block_len = usize::from(block_length);
        let count = num_in_group as usize;

        // Every entry occupies at least its fixed block, so all blocks must fit.
        // Saturating keeps a corrupt count from wrapping on 32-bit targets.
        self.require(count.saturating_mul(block_len))?;

        // The count normally originates from an untrusted header: bound the
        // pre-allocation by the bytes actually present, so a bogus count
        // (notably with a zero block length) cannot force a huge allocation.
        let mut items = Vec::with_capacity(count.min(self.remaining()));
        for _ in 0..count {
            let block_end = self.pos.saturating_add(block_len);
            items.push(decode_item(self)?);
            // Skip any unread tail of the fixed block (padding or fields this
            // decoder does not read), but never rewind over bytes the decoder
            // consumed beyond it.
            if self.pos < block_end {
                self.advance(block_end - self.pos)?;
            }
        }
        Ok(items)
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_new_starts_at_zero() {
        let data = [1, 2, 3, 4];
        let cur = SbeCursor::new(&data);
        assert_eq!(cur.pos(), 0);
        assert_eq!(cur.remaining(), data.len());
    }

    #[rstest]
    fn test_new_at_starts_at_offset() {
        let data = [1, 2, 3, 4];
        let offset = 2;
        let cur = SbeCursor::new_at(&data, offset);
        assert_eq!(cur.pos(), offset);
        assert_eq!(cur.remaining(), data.len() - offset);
    }

    #[rstest]
    fn test_read_u8() {
        let data = [0x42, 0xFF];
        let mut cur = SbeCursor::new(&data);

        let first = cur.read_u8().unwrap();
        assert_eq!(first, 0x42);
        assert_eq!(cur.pos(), 1);

        let second = cur.read_u8().unwrap();
        assert_eq!(second, 0xFF);
        assert_eq!(cur.pos(), 2);

        assert!(cur.read_u8().is_err(), "buffer is exhausted");
    }

    #[rstest]
    fn test_read_i8() {
        let data = [0x7F, 0x80];
        let mut cur = SbeCursor::new(&data);
        assert_eq!(cur.read_i8().unwrap(), i8::MAX);
        assert_eq!(cur.read_i8().unwrap(), i8::MIN);
    }

    #[rstest]
    fn test_read_u16_le() {
        // 0x1234 encoded little-endian.
        let data = [0x34, 0x12];
        let mut cur = SbeCursor::new(&data);
        assert_eq!(cur.read_u16_le().unwrap(), 0x1234);
        assert_eq!(cur.pos(), data.len());
    }

    #[rstest]
    fn test_read_i64_le() {
        let expected: i64 = -1_234_567_890_123_456_789;
        let encoded = expected.to_le_bytes();
        let mut cur = SbeCursor::new(&encoded);
        assert_eq!(cur.read_i64_le().unwrap(), expected);
        assert_eq!(cur.pos(), encoded.len());
    }

    #[rstest]
    #[case::u16(&[0x34][..], 2)]
    #[case::u32(&[0x34, 0x12, 0x00][..], 4)]
    #[case::u64(&[0; 7][..], 8)]
    #[case::u128(&[0; 15][..], 16)]
    fn test_multi_byte_reads_buffer_too_short(#[case] buf: &[u8], #[case] needed: usize) {
        let mut cur = SbeCursor::new(buf);
        let outcome = match needed {
            2 => cur.read_u16_le().err(),
            4 => cur.read_u32_le().err(),
            8 => cur.read_u64_le().err(),
            16 => cur.read_u128_le().err(),
            _ => unreachable!(),
        };
        let err = outcome.expect("short buffer must fail");
        assert!(matches!(err, SbeDecodeError::BufferTooShort { .. }));
        assert_eq!(cur.pos(), 0, "position must not advance on error");
    }

    #[rstest]
    fn test_read_optional_i64_null() {
        let encoded = i64::MIN.to_le_bytes();
        let mut cur = SbeCursor::new(&encoded);
        assert!(cur.read_optional_i64_le().unwrap().is_none());
    }

    #[rstest]
    fn test_read_optional_i64_present() {
        let encoded = 12345_i64.to_le_bytes();
        let mut cur = SbeCursor::new(&encoded);
        assert_eq!(cur.read_optional_i64_le().unwrap(), Some(12345));
    }

    #[rstest]
    fn test_read_var_string8() {
        // u8 length prefix followed by the payload.
        let data: Vec<u8> = std::iter::once(5).chain(b"hello".iter().copied()).collect();
        let mut cur = SbeCursor::new(&data);
        assert_eq!(cur.read_var_string8().unwrap(), "hello");
        assert_eq!(cur.pos(), 6);
    }

    #[rstest]
    fn test_read_var_string8_empty() {
        let data = [0];
        let mut cur = SbeCursor::new(&data);
        assert!(cur.read_var_string8().unwrap().is_empty());
        assert_eq!(cur.pos(), 1);
    }

    #[rstest]
    fn test_read_var_string8_invalid_utf8() {
        let data = [2, 0xFF, 0xFE];
        let mut cur = SbeCursor::new(&data);
        let outcome = cur.read_var_string8();
        assert!(matches!(outcome, Err(SbeDecodeError::InvalidUtf8)));
    }

    #[rstest]
    fn test_read_group_header() {
        // Block length 24 (u16 LE) followed by 3 entries (u32 LE).
        let data = [24, 0, 3, 0, 0, 0];
        let mut cur = SbeCursor::new(&data);
        assert_eq!(cur.read_group_header().unwrap(), (24, 3));
        assert_eq!(cur.pos(), 6);
    }

    #[rstest]
    fn test_read_group_header_too_large() {
        let mut data = vec![24, 0];
        data.extend_from_slice(&(MAX_GROUP_SIZE + 1).to_le_bytes());
        let mut cur = SbeCursor::new(&data);
        let outcome = cur.read_group_header();
        assert!(matches!(outcome, Err(SbeDecodeError::GroupSizeTooLarge { .. })));
    }

    #[rstest]
    fn test_read_group() {
        let data: Vec<u8> = [100u32, 200]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let mut cur = SbeCursor::new(&data);
        let items: Vec<u32> = cur.read_group(4, 2, SbeCursor::read_u32_le).unwrap();
        assert_eq!(items, [100, 200]);
        assert_eq!(cur.pos(), 8);
    }

    #[rstest]
    fn test_read_group_respects_block_length() {
        // Two 8-byte entries, each a u32 followed by 4 bytes of padding.
        let mut data = Vec::new();
        for value in [100u32, 200] {
            data.extend_from_slice(&value.to_le_bytes());
            data.extend_from_slice(&[0; 4]);
        }
        let mut cur = SbeCursor::new(&data);
        let items: Vec<u32> = cur.read_group(8, 2, SbeCursor::read_u32_le).unwrap();
        assert_eq!(items, [100, 200]);
        assert_eq!(cur.pos(), 16);
    }

    #[rstest]
    fn test_require_success() {
        let data = [1, 2, 3, 4];
        let cur = SbeCursor::new(&data);
        for n in [4, 3] {
            assert!(cur.require(n).is_ok());
        }
    }

    #[rstest]
    fn test_require_failure() {
        let data = [1, 2];
        let cur = SbeCursor::new(&data);
        assert_eq!(
            cur.require(3),
            Err(SbeDecodeError::BufferTooShort {
                expected: 3,
                actual: 2,
            })
        );
    }

    #[rstest]
    fn test_advance() {
        let data = [1, 2, 3, 4];
        let mut cur = SbeCursor::new(&data);
        cur.advance(2).unwrap();
        assert_eq!((cur.pos(), cur.remaining()), (2, 2));
        assert!(cur.advance(3).is_err());
    }

    #[rstest]
    fn test_peek() {
        let data = [1, 2, 3, 4];
        let mut cur = SbeCursor::new(&data);
        assert_eq!(cur.peek(), &data[..]);
        cur.advance(2).unwrap();
        assert_eq!(cur.peek(), &data[2..]);
    }

    #[rstest]
    fn test_reset() {
        let data = [1, 2, 3, 4];
        let mut cur = SbeCursor::new(&data);
        cur.advance(3).unwrap();
        assert_eq!(cur.pos(), 3);
        cur.reset();
        assert_eq!(cur.pos(), 0);
        assert_eq!(cur.remaining(), data.len());
    }
}