use std::str;
use super::error::{MAX_GROUP_SIZE, SbeDecodeError};
/// Positioned reader over a borrowed byte buffer used to decode SBE messages.
///
/// Multi-byte integers are decoded little-endian. Slices and `&str`s returned by
/// the cursor borrow from the underlying buffer for `'a`, not from the cursor.
#[derive(Debug, Clone)]
pub struct SbeCursor<'a> {
    /// The complete buffer being decoded.
    buf: &'a [u8],
    /// Offset of the next unread byte. Not clamped: `new_at`, `skip` and `set_pos`
    /// may place it past `buf.len()`, in which case no bytes remain.
    pos: usize,
}

impl<'a> SbeCursor<'a> {
    /// Creates a cursor positioned at the start of `buf`.
    #[must_use]
    pub const fn new(buf: &'a [u8]) -> Self {
        Self { buf, pos: 0 }
    }

    /// Creates a cursor positioned at byte offset `pos` within `buf`.
    ///
    /// The offset is not validated; a position past the end leaves nothing to read.
    #[must_use]
    pub const fn new_at(buf: &'a [u8], pos: usize) -> Self {
        Self { buf, pos }
    }

    /// Returns the offset of the next unread byte.
    #[must_use]
    pub const fn pos(&self) -> usize {
        self.pos
    }

    /// Returns how many bytes are left to read (0 if the cursor is past the end).
    #[must_use]
    pub const fn remaining(&self) -> usize {
        self.buf.len().saturating_sub(self.pos)
    }

    /// Returns the whole underlying buffer, independent of the current position.
    #[must_use]
    pub const fn buffer(&self) -> &'a [u8] {
        self.buf
    }

    /// Returns the unread bytes without advancing.
    ///
    /// Yields an empty slice (instead of panicking) when the cursor sits past the
    /// end of the buffer, which `new_at`, `skip` and `set_pos` all permit.
    #[must_use]
    pub fn peek(&self) -> &'a [u8] {
        self.buf.get(self.pos..).unwrap_or_default()
    }

    /// Checks that at least `n` unread bytes are available, without advancing.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` when fewer than `n` bytes remain.
    /// `expected` is the buffer length the read would need (saturating at
    /// `usize::MAX` rather than overflowing) and `actual` is the real length.
    pub fn require(&self, n: usize) -> Result<(), SbeDecodeError> {
        if self.remaining() >= n {
            return Ok(());
        }
        Err(SbeDecodeError::BufferTooShort {
            expected: self.pos.saturating_add(n),
            actual: self.buf.len(),
        })
    }

    /// Advances the cursor by `n` bytes after checking they are available.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort`, leaving the position unchanged,
    /// when fewer than `n` bytes remain.
    pub fn advance(&mut self, n: usize) -> Result<(), SbeDecodeError> {
        self.require(n)?;
        self.pos += n;
        Ok(())
    }

    /// Advances the cursor by `n` bytes without a bounds check.
    ///
    /// The cursor may move past the end of the buffer; subsequent reads then fail.
    /// The position saturates at `usize::MAX` rather than overflowing, so an
    /// oversized skip can never wrap around and rewind the cursor.
    pub fn skip(&mut self, n: usize) {
        self.pos = self.pos.saturating_add(n);
    }

    /// Moves the cursor back to the start of the buffer.
    pub fn reset(&mut self) {
        self.pos = 0;
    }

    /// Moves the cursor to absolute offset `pos` (not validated).
    pub fn set_pos(&mut self, pos: usize) {
        self.pos = pos;
    }

    /// Returns the next `n` bytes without advancing.
    fn peek_exact(&self, n: usize) -> Result<&'a [u8], SbeDecodeError> {
        self.require(n)?;
        // `require` guarantees `pos + n <= buf.len()` unless the cursor is already
        // past the end, in which case it only passes for `n == 0`. Either way the
        // addition cannot overflow. `get` maps that past-the-end, zero-length case
        // to an empty slice instead of the panic `buf[pos..pos]` would raise.
        let end = self.pos + n;
        Ok(self.buf.get(self.pos..end).unwrap_or_default())
    }

    /// Consumes the next `n` bytes and returns them as a slice of the buffer.
    fn take(&mut self, n: usize) -> Result<&'a [u8], SbeDecodeError> {
        let bytes = self.peek_exact(n)?;
        self.pos += n;
        Ok(bytes)
    }

    /// Consumes the next `N` bytes and copies them into a fixed-size array.
    fn take_array<const N: usize>(&mut self) -> Result<[u8; N], SbeDecodeError> {
        let mut array = [0u8; N];
        array.copy_from_slice(self.take(N)?);
        Ok(array)
    }

    /// Reads one unsigned byte.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if no byte remains.
    pub fn read_u8(&mut self) -> Result<u8, SbeDecodeError> {
        Ok(u8::from_le_bytes(self.take_array::<1>()?))
    }

    /// Reads one signed byte (two's complement).
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if no byte remains.
    pub fn read_i8(&mut self) -> Result<i8, SbeDecodeError> {
        Ok(i8::from_le_bytes(self.take_array::<1>()?))
    }

    /// Reads a little-endian `u16`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than 2 bytes remain.
    pub fn read_u16_le(&mut self) -> Result<u16, SbeDecodeError> {
        Ok(u16::from_le_bytes(self.take_array::<2>()?))
    }

    /// Reads a little-endian `i16`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than 2 bytes remain.
    pub fn read_i16_le(&mut self) -> Result<i16, SbeDecodeError> {
        Ok(i16::from_le_bytes(self.take_array::<2>()?))
    }

    /// Reads a little-endian `u32`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than 4 bytes remain.
    pub fn read_u32_le(&mut self) -> Result<u32, SbeDecodeError> {
        Ok(u32::from_le_bytes(self.take_array::<4>()?))
    }

    /// Reads a little-endian `i32`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than 4 bytes remain.
    pub fn read_i32_le(&mut self) -> Result<i32, SbeDecodeError> {
        Ok(i32::from_le_bytes(self.take_array::<4>()?))
    }

    /// Reads a little-endian `u64`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than 8 bytes remain.
    pub fn read_u64_le(&mut self) -> Result<u64, SbeDecodeError> {
        Ok(u64::from_le_bytes(self.take_array::<8>()?))
    }

    /// Reads a little-endian `i64`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than 8 bytes remain.
    pub fn read_i64_le(&mut self) -> Result<i64, SbeDecodeError> {
        Ok(i64::from_le_bytes(self.take_array::<8>()?))
    }

    /// Reads a little-endian `i64` in which `i64::MIN` is the null sentinel.
    ///
    /// Returns `Ok(None)` for the sentinel and `Ok(Some(v))` otherwise; 8 bytes are
    /// consumed in both cases.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` if fewer than 8 bytes remain.
    pub fn read_optional_i64_le(&mut self) -> Result<Option<i64>, SbeDecodeError> {
        let value = self.read_i64_le()?;
        Ok(if value == i64::MIN { None } else { Some(value) })
    }

    /// Consumes `n` bytes and returns them borrowed from the underlying buffer.
    ///
    /// `read_bytes(0)` always succeeds with an empty slice, even when the cursor
    /// is past the end of the buffer.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` (without advancing) if fewer than
    /// `n` bytes remain.
    pub fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], SbeDecodeError> {
        self.take(n)
    }

    /// Reads a `u8`-length-prefixed UTF-8 string into an owned `String`.
    ///
    /// # Errors
    ///
    /// Same as `read_var_string8_ref`.
    pub fn read_var_string8(&mut self) -> Result<String, SbeDecodeError> {
        self.read_var_string8_ref().map(String::from)
    }

    /// Reads a `u8`-length-prefixed UTF-8 string, borrowing it from the buffer.
    ///
    /// # Errors
    ///
    /// * `SbeDecodeError::BufferTooShort` if the length prefix or the payload is
    ///   missing.
    /// * `SbeDecodeError::InvalidUtf8` if the payload is not valid UTF-8.
    ///
    /// When the prefix was read but the payload is rejected, only the prefix byte
    /// has been consumed.
    pub fn read_var_string8_ref(&mut self) -> Result<&'a str, SbeDecodeError> {
        let len = usize::from(self.read_u8()?);
        let text = str::from_utf8(self.peek_exact(len)?)
            .map_err(|_| SbeDecodeError::InvalidUtf8)?;
        // Advance past the payload only once it is known to be valid UTF-8.
        self.pos += len;
        Ok(text)
    }

    /// Skips a `u8`-length-prefixed variable-length field.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` if the prefix or the payload is missing.
    pub fn skip_var_data8(&mut self) -> Result<(), SbeDecodeError> {
        let len = usize::from(self.read_u8()?);
        self.advance(len)
    }

    /// Reads a `u8`-length-prefixed byte field into an owned `Vec<u8>`.
    ///
    /// # Errors
    ///
    /// Returns `SbeDecodeError::BufferTooShort` if the prefix or the payload is missing.
    pub fn read_var_bytes8(&mut self) -> Result<Vec<u8>, SbeDecodeError> {
        let len = usize::from(self.read_u8()?);
        self.take(len).map(|bytes| bytes.to_vec())
    }

    /// Reads a 6-byte group header: a `u16` block length followed by a `u32` count.
    ///
    /// # Errors
    ///
    /// * `SbeDecodeError::BufferTooShort` if the header is truncated.
    /// * `SbeDecodeError::GroupSizeTooLarge` if the count exceeds `MAX_GROUP_SIZE`
    ///   (the header bytes have already been consumed at that point).
    pub fn read_group_header(&mut self) -> Result<(u16, u32), SbeDecodeError> {
        let block_length = self.read_u16_le()?;
        let num_in_group = self.read_u32_le()?;
        Self::check_group_size(num_in_group)?;
        Ok((block_length, num_in_group))
    }

    /// Reads a 4-byte group header: a `u16` block length followed by a `u16` count.
    ///
    /// # Errors
    ///
    /// * `SbeDecodeError::BufferTooShort` if the header is truncated.
    /// * `SbeDecodeError::GroupSizeTooLarge` if the count exceeds `MAX_GROUP_SIZE`
    ///   (the header bytes have already been consumed at that point).
    pub fn read_group_header_16(&mut self) -> Result<(u16, u16), SbeDecodeError> {
        let block_length = self.read_u16_le()?;
        let num_in_group = self.read_u16_le()?;
        Self::check_group_size(u32::from(num_in_group))?;
        Ok((block_length, num_in_group))
    }

    /// Rejects group entry counts above `MAX_GROUP_SIZE`.
    fn check_group_size(count: u32) -> Result<(), SbeDecodeError> {
        if count > MAX_GROUP_SIZE {
            return Err(SbeDecodeError::GroupSizeTooLarge {
                count,
                max: MAX_GROUP_SIZE,
            });
        }
        Ok(())
    }

    /// Decodes `num_in_group` entries of `block_length` bytes each.
    ///
    /// `decode_item` is called once per entry with the cursor at the entry start.
    /// Afterwards the cursor is repositioned to exactly `block_length` bytes past
    /// that start, however many bytes `decode_item` consumed. Unread trailing
    /// fields are therefore skipped, but anything read beyond the block is not
    /// kept in the position.
    ///
    /// # Errors
    ///
    /// * `SbeDecodeError::BufferTooShort` (without advancing) if the buffer cannot
    ///   hold `num_in_group * block_length` bytes.
    /// * Any error returned by `decode_item`, which aborts decoding.
    pub fn read_group<T, F>(
        &mut self,
        block_length: u16,
        num_in_group: u32,
        mut decode_item: F,
    ) -> Result<Vec<T>, SbeDecodeError>
    where
        F: FnMut(&mut Self) -> Result<T, SbeDecodeError>,
    {
        let block_len = usize::from(block_length);
        // u32 -> usize is lossless on the 32/64-bit targets this decoder runs on.
        let count = num_in_group as usize;
        // The product can exceed usize::MAX on 32-bit targets. Saturating keeps the
        // bounds check failing, where wrapping could let it falsely succeed.
        self.require(count.saturating_mul(block_len))?;
        let mut items = Vec::with_capacity(count);
        for _ in 0..count {
            let item_start = self.pos;
            items.push(decode_item(self)?);
            self.pos = item_start + block_len;
        }
        Ok(items)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `SbeCursor`. Every case here also holds for the plain
    // (non-hardened) cursor semantics: no test moves the cursor past the end and
    // then slices.
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_new_starts_at_zero() {
        let buf = [1, 2, 3, 4];
        let cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.pos(), 0);
        assert_eq!(cursor.remaining(), 4);
    }

    #[rstest]
    fn test_new_at_starts_at_offset() {
        let buf = [1, 2, 3, 4];
        let cursor = SbeCursor::new_at(&buf, 2);
        assert_eq!(cursor.pos(), 2);
        assert_eq!(cursor.remaining(), 2);
    }

    #[rstest]
    fn test_read_u8() {
        let buf = [0x42, 0xFF];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_u8().unwrap(), 0x42);
        assert_eq!(cursor.pos(), 1);
        assert_eq!(cursor.read_u8().unwrap(), 0xFF);
        assert_eq!(cursor.pos(), 2);
        assert!(cursor.read_u8().is_err());
    }

    #[rstest]
    fn test_read_i8() {
        let buf = [0x7F, 0x80];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_i8().unwrap(), 127);
        assert_eq!(cursor.read_i8().unwrap(), -128);
    }

    #[rstest]
    fn test_read_u16_le() {
        let buf = [0x34, 0x12];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_u16_le().unwrap(), 0x1234);
        assert_eq!(cursor.pos(), 2);
    }

    #[rstest]
    fn test_read_fixed_width_integers() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&(-2i16).to_le_bytes());
        buf.extend_from_slice(&0xDEAD_BEEFu32.to_le_bytes());
        buf.extend_from_slice(&(-5i32).to_le_bytes());
        buf.extend_from_slice(&u64::MAX.to_le_bytes());
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_i16_le().unwrap(), -2);
        assert_eq!(cursor.read_u32_le().unwrap(), 0xDEAD_BEEF);
        assert_eq!(cursor.read_i32_le().unwrap(), -5);
        assert_eq!(cursor.read_u64_le().unwrap(), u64::MAX);
        assert_eq!(cursor.pos(), 18);
        assert_eq!(cursor.remaining(), 0);
    }

    #[rstest]
    fn test_short_read_does_not_advance() {
        let buf = [1, 2, 3];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(
            cursor.read_u32_le().unwrap_err(),
            SbeDecodeError::BufferTooShort {
                expected: 4,
                actual: 3
            }
        );
        assert_eq!(cursor.pos(), 0);
    }

    #[rstest]
    fn test_read_i64_le() {
        let value: i64 = -1234567890123456789;
        let buf = value.to_le_bytes();
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_i64_le().unwrap(), value);
        assert_eq!(cursor.pos(), 8);
    }

    #[rstest]
    fn test_read_optional_i64_null() {
        let buf = i64::MIN.to_le_bytes();
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_optional_i64_le().unwrap(), None);
    }

    #[rstest]
    fn test_read_optional_i64_present() {
        let value: i64 = 12345;
        let buf = value.to_le_bytes();
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_optional_i64_le().unwrap(), Some(12345));
    }

    #[rstest]
    fn test_read_bytes_borrows_buffer() {
        let buf = [1, 2, 3, 4];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_bytes(3).unwrap(), &[1, 2, 3]);
        assert_eq!(cursor.buffer(), &buf);
        assert!(cursor.read_bytes(2).is_err());
        assert_eq!(cursor.pos(), 3);
    }

    #[rstest]
    fn test_read_var_string8() {
        let mut buf = vec![5];
        buf.extend_from_slice(b"hello");
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_var_string8().unwrap(), "hello");
        assert_eq!(cursor.pos(), 6);
    }

    #[rstest]
    fn test_read_var_string8_empty() {
        let buf = [0];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_var_string8().unwrap(), "");
        assert_eq!(cursor.pos(), 1);
    }

    #[rstest]
    fn test_read_var_string8_invalid_utf8() {
        let buf = [2, 0xFF, 0xFE];
        let mut cursor = SbeCursor::new(&buf);
        assert!(matches!(
            cursor.read_var_string8(),
            Err(SbeDecodeError::InvalidUtf8)
        ));
    }

    #[rstest]
    fn test_read_var_string8_truncated() {
        let buf = [4, b'a', b'b'];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(
            cursor.read_var_string8().unwrap_err(),
            SbeDecodeError::BufferTooShort {
                expected: 5,
                actual: 3
            }
        );
    }

    #[rstest]
    fn test_read_var_string8_ref() {
        let mut buf = vec![3];
        buf.extend_from_slice(b"abc");
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_var_string8_ref().unwrap(), "abc");
        assert_eq!(cursor.pos(), 4);
    }

    #[rstest]
    fn test_read_var_bytes8_and_skip_var_data8() {
        let buf = [2, 0xAA, 0xBB, 0, 1, 0xCC];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_var_bytes8().unwrap(), vec![0xAA, 0xBB]);
        assert_eq!(cursor.read_var_bytes8().unwrap(), Vec::<u8>::new());
        cursor.skip_var_data8().unwrap();
        assert_eq!(cursor.pos(), 6);
        assert!(cursor.skip_var_data8().is_err());
    }

    #[rstest]
    fn test_read_group_header() {
        let buf = [24, 0, 3, 0, 0, 0];
        let mut cursor = SbeCursor::new(&buf);
        let (block_len, count) = cursor.read_group_header().unwrap();
        assert_eq!(block_len, 24);
        assert_eq!(count, 3);
        assert_eq!(cursor.pos(), 6);
    }

    #[rstest]
    fn test_read_group_header_16() {
        let buf = [16, 0, 3, 0];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.read_group_header_16().unwrap(), (16, 3));
        assert_eq!(cursor.pos(), 4);
    }

    #[rstest]
    fn test_read_group_header_too_large() {
        let count = MAX_GROUP_SIZE + 1;
        let mut buf = vec![24, 0];
        buf.extend_from_slice(&count.to_le_bytes());
        let mut cursor = SbeCursor::new(&buf);
        assert!(matches!(
            cursor.read_group_header(),
            Err(SbeDecodeError::GroupSizeTooLarge { .. })
        ));
    }

    #[rstest]
    fn test_read_group() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&100u32.to_le_bytes());
        buf.extend_from_slice(&200u32.to_le_bytes());
        let mut cursor = SbeCursor::new(&buf);
        let items: Vec<u32> = cursor.read_group(4, 2, |c| c.read_u32_le()).unwrap();
        assert_eq!(items, vec![100, 200]);
        assert_eq!(cursor.pos(), 8);
    }

    #[rstest]
    fn test_read_group_respects_block_length() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&100u32.to_le_bytes());
        buf.extend_from_slice(&[0, 0, 0, 0]);
        buf.extend_from_slice(&200u32.to_le_bytes());
        buf.extend_from_slice(&[0, 0, 0, 0]);
        let mut cursor = SbeCursor::new(&buf);
        let items: Vec<u32> = cursor.read_group(8, 2, |c| c.read_u32_le()).unwrap();
        assert_eq!(items, vec![100, 200]);
        assert_eq!(cursor.pos(), 16);
    }

    #[rstest]
    fn test_read_group_insufficient_buffer() {
        let buf = [0u8; 8];
        let mut cursor = SbeCursor::new(&buf);
        let result = cursor.read_group(4, 3, |c| c.read_u32_le());
        assert_eq!(
            result.unwrap_err(),
            SbeDecodeError::BufferTooShort {
                expected: 12,
                actual: 8
            }
        );
        assert_eq!(cursor.pos(), 0);
    }

    #[rstest]
    fn test_require_success() {
        let buf = [1, 2, 3, 4];
        let cursor = SbeCursor::new(&buf);
        assert!(cursor.require(4).is_ok());
        assert!(cursor.require(3).is_ok());
    }

    #[rstest]
    fn test_require_failure() {
        let buf = [1, 2];
        let cursor = SbeCursor::new(&buf);
        let err = cursor.require(3).unwrap_err();
        assert_eq!(
            err,
            SbeDecodeError::BufferTooShort {
                expected: 3,
                actual: 2
            }
        );
    }

    #[rstest]
    fn test_advance() {
        let buf = [1, 2, 3, 4];
        let mut cursor = SbeCursor::new(&buf);
        cursor.advance(2).unwrap();
        assert_eq!(cursor.pos(), 2);
        assert_eq!(cursor.remaining(), 2);
        assert!(cursor.advance(3).is_err());
    }

    #[rstest]
    fn test_peek() {
        let buf = [1, 2, 3, 4];
        let mut cursor = SbeCursor::new(&buf);
        assert_eq!(cursor.peek(), &[1, 2, 3, 4]);
        cursor.advance(2).unwrap();
        assert_eq!(cursor.peek(), &[3, 4]);
    }

    #[rstest]
    fn test_skip_and_set_pos() {
        let buf = [1, 2, 3, 4];
        let mut cursor = SbeCursor::new(&buf);
        cursor.skip(3);
        assert_eq!(cursor.peek(), &[4]);
        cursor.set_pos(1);
        assert_eq!(cursor.read_u8().unwrap(), 2);
        cursor.skip(10);
        assert_eq!(cursor.remaining(), 0);
        assert!(cursor.read_u8().is_err());
    }

    #[rstest]
    fn test_reset() {
        let buf = [1, 2, 3, 4];
        let mut cursor = SbeCursor::new(&buf);
        cursor.advance(3).unwrap();
        assert_eq!(cursor.pos(), 3);
        cursor.reset();
        assert_eq!(cursor.pos(), 0);
        assert_eq!(cursor.remaining(), 4);
    }
}