use crate::buffer::{SBE_HEADER_SIZE, SbeBuffer, SbeSmallVec};
use crate::error::{SbeError, SbeResult};
use smallvec::SmallVec;
/// Builder for one encoded message: a fixed-size header, a fixed-size root
/// block, and whatever groups / variable-length data are appended after it.
pub struct SbeEncoder {
/// Backing byte buffer that holds the message while it is being built.
buffer: SbeBuffer,
/// Template identifier; written into the header by `finalize`.
template_id: u16,
/// Schema version; written into the header by `finalize`.
schema_version: u16,
/// Size in bytes of the fixed root block that follows the header.
block_length: u16,
/// Buffer offset handed to the next group as its start position.
/// Starts at the end of the root block.
variable_data_offset: usize,
/// Checked by `begin_group`, which refuses to start a group while this is
/// non-empty.
group_stack: SbeSmallVec<GroupEncoder>,
}
impl SbeEncoder {
pub fn new(template_id: u16, schema_version: u16, block_length: u16) -> Self {
let mut buffer = SbeBuffer::new();
buffer.write_bytes(&[0u8; SBE_HEADER_SIZE]).unwrap();
buffer
.write_bytes(&vec![0u8; block_length as usize])
.unwrap();
Self {
buffer,
template_id,
schema_version,
block_length,
variable_data_offset: SBE_HEADER_SIZE + block_length as usize,
group_stack: SmallVec::new(),
}
}
pub fn with_capacity(
template_id: u16,
schema_version: u16,
block_length: u16,
capacity: usize,
) -> Self {
let mut buffer = SbeBuffer::with_capacity(capacity);
buffer.write_bytes(&[0u8; SBE_HEADER_SIZE]).unwrap();
buffer
.write_bytes(&vec![0u8; block_length as usize])
.unwrap();
Self {
buffer,
template_id,
schema_version,
block_length,
variable_data_offset: SBE_HEADER_SIZE + block_length as usize,
group_stack: SmallVec::new(),
}
}
pub fn write_u8(&mut self, offset: usize, value: u8) -> SbeResult<()> {
if offset >= self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
self.buffer.write_at_offset(SBE_HEADER_SIZE + offset, value)
}
pub fn write_u16(&mut self, offset: usize, value: u16) -> SbeResult<()> {
if offset + 2 > self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
let le_bytes = value.to_le_bytes();
self.buffer
.write_at_offset(SBE_HEADER_SIZE + offset, le_bytes)
}
pub fn write_u32(&mut self, offset: usize, value: u32) -> SbeResult<()> {
if offset + 4 > self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
let le_bytes = value.to_le_bytes();
self.buffer
.write_at_offset(SBE_HEADER_SIZE + offset, le_bytes)
}
pub fn write_u64(&mut self, offset: usize, value: u64) -> SbeResult<()> {
if offset + 8 > self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
let le_bytes = value.to_le_bytes();
self.buffer
.write_at_offset(SBE_HEADER_SIZE + offset, le_bytes)
}
pub fn write_f32(&mut self, offset: usize, value: f32) -> SbeResult<()> {
if offset + 4 > self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
let bytes = value.to_le_bytes();
self.buffer.write_at_offset(SBE_HEADER_SIZE + offset, bytes)
}
pub fn write_bytes(&mut self, offset: usize, bytes: &[u8]) -> SbeResult<()> {
if offset + bytes.len() > self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
let start_offset = SBE_HEADER_SIZE + offset;
self.buffer.as_mut_slice()[start_offset..start_offset + bytes.len()].copy_from_slice(bytes);
Ok(())
}
pub fn write_string(&mut self, offset: usize, length: usize, value: &str) -> SbeResult<()> {
if offset + length > self.block_length as usize {
return Err(SbeError::FieldOffsetOutOfBounds {
offset,
length: self.block_length as usize,
});
}
let value_bytes = value.as_bytes();
if value_bytes.len() > length {
return Err(SbeError::Custom {
message: format!(
"String too long: {} bytes, field size is {}",
value_bytes.len(),
length
),
});
}
let start_offset = SBE_HEADER_SIZE + offset;
let end_offset = start_offset + value_bytes.len();
self.buffer.as_mut_slice()[start_offset..end_offset].copy_from_slice(value_bytes);
if value_bytes.len() < length {
let pad_start = end_offset;
let pad_end = start_offset + length;
self.buffer.as_mut_slice()[pad_start..pad_end].fill(0);
}
Ok(())
}
pub fn begin_group(
&mut self,
offset: usize,
block_length: u16,
) -> SbeResult<GroupEncoderBuilder> {
if !self.group_stack.is_empty() {
return Err(SbeError::Custom {
message: "Cannot start group while another group is active".to_string(),
});
}
self.buffer.reserve(6)?;
let group_encoder = GroupEncoder::new(self.variable_data_offset, block_length);
Ok(GroupEncoderBuilder {
encoder: self,
group_encoder,
offset,
})
}
pub fn write_variable_string(&mut self, value: &str) -> SbeResult<()> {
self.write_variable_bytes(value.as_bytes())
}
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> SbeResult<()> {
if bytes.len() > u16::MAX as usize {
return Err(SbeError::Custom {
message: format!("Variable data too large: {} bytes", bytes.len()),
});
}
self.buffer.reserve(2 + bytes.len())?;
let length = bytes.len() as u16;
self.buffer.write_bytes(&length.to_le_bytes())?;
self.buffer.write_bytes(bytes)?;
Ok(())
}
pub fn finalize(mut self) -> SbeResult<Vec<u8>> {
let total_length = self.buffer.len() as u32;
self.buffer.write_at_offset(0, total_length.to_le_bytes())?;
self.buffer
.write_at_offset(4, self.template_id.to_le_bytes())?;
self.buffer
.write_at_offset(6, self.schema_version.to_le_bytes())?;
Ok(self.buffer.as_slice().to_vec())
}
pub fn current_size(&self) -> usize {
self.buffer.len()
}
pub fn template_id(&self) -> u16 {
self.template_id
}
pub fn schema_version(&self) -> u16 {
self.schema_version
}
}
/// Handle returned by `SbeEncoder::begin_group` for appending the elements of
/// one repeating group and then finishing it.
// `offset` is stored but never read in this file, which is presumably why
// `dead_code` is allowed here.
#[allow(dead_code)]
pub struct GroupEncoderBuilder<'a> {
/// Encoder whose buffer receives the group's bytes.
encoder: &'a mut SbeEncoder,
/// Start offset, element size and running element count of the group.
group_encoder: GroupEncoder,
/// The `offset` argument passed to `begin_group`.
offset: usize,
}
impl<'a> GroupEncoderBuilder<'a> {
    /// Appends one zero-filled element to the group and returns a writer for
    /// its fields.
    ///
    /// If the buffer cannot take the element, the element count is rolled
    /// back, so a later `finish` never records more elements than were
    /// actually written.
    ///
    /// # Errors
    /// `SbeError::GroupCountTooLarge` once the group's element limit is
    /// reached, or whatever error the buffer reports while growing.
    pub fn add_element(&mut self) -> SbeResult<GroupElementEncoder<'_>> {
        self.group_encoder.add_element()?;
        let element_offset = match self.append_zeroed_element() {
            Ok(element_offset) => element_offset,
            Err(err) => {
                // The element never made it into the buffer: undo the count.
                self.group_encoder.element_count -= 1;
                return Err(err);
            }
        };
        Ok(GroupElementEncoder {
            encoder: &mut self.encoder.buffer,
            offset: element_offset,
            block_length: self.group_encoder.block_length,
        })
    }

    /// Grows the buffer by one zero-filled element and returns the offset at
    /// which that element starts.
    fn append_zeroed_element(&mut self) -> SbeResult<usize> {
        let element_len = usize::from(self.group_encoder.block_length);
        self.encoder.buffer.reserve(element_len)?;
        let element_offset = self.encoder.buffer.len();
        self.encoder.buffer.write_bytes(&vec![0u8; element_len])?;
        Ok(element_offset)
    }

    /// Completes the group.
    ///
    /// Writes the element count (`u32`, little-endian) at the group's start
    /// offset and the element block length (`u16`, little-endian) four bytes
    /// after it, then moves the encoder's `variable_data_offset` to the
    /// current end of the buffer.
    ///
    /// NOTE(review): the six bytes written here must not overlap element
    /// data; confirm that `SbeEncoder::begin_group` sets them aside.
    ///
    /// # Errors
    /// Whatever error the buffer writes report.
    pub fn finish(self) -> SbeResult<()> {
        let header_offset = self.group_encoder.start_offset;
        let count_bytes = self.group_encoder.element_count.to_le_bytes();
        let block_length_bytes = self.group_encoder.block_length.to_le_bytes();
        self.encoder
            .buffer
            .write_at_offset(header_offset, count_bytes)?;
        self.encoder
            .buffer
            .write_at_offset(header_offset + 4, block_length_bytes)?;
        self.encoder.variable_data_offset = self.encoder.buffer.len();
        Ok(())
    }
}
/// Book-keeping for one repeating group while it is being encoded.
struct GroupEncoder {
    /// Buffer offset at which the group's header is written.
    start_offset: usize,
    /// Size in bytes of each element of the group.
    block_length: u16,
    /// Number of elements added so far.
    element_count: u32,
}

impl GroupEncoder {
    /// Largest number of elements a single group may hold.
    const MAX_ELEMENTS: u32 = 10_000_000;

    /// Creates the book-keeping record for an empty group.
    fn new(start_offset: usize, block_length: u16) -> Self {
        Self {
            start_offset,
            block_length,
            element_count: 0,
        }
    }

    /// Counts one more element.
    ///
    /// Fails with `SbeError::GroupCountTooLarge` (carrying the count that was
    /// refused) once the group already holds `MAX_ELEMENTS` elements.
    fn add_element(&mut self) -> SbeResult<()> {
        if self.element_count < Self::MAX_ELEMENTS {
            self.element_count += 1;
            Ok(())
        } else {
            Err(SbeError::GroupCountTooLarge {
                count: self.element_count + 1,
            })
        }
    }
}
/// Writer for the fields of a single group element.
pub struct GroupElementEncoder<'a> {
/// Buffer that holds the element's bytes.
encoder: &'a mut SbeBuffer,
/// Absolute buffer offset of the element's first byte.
offset: usize,
/// Size in bytes of the element; field writes are bounds-checked against it.
block_length: u16,
}
impl<'a> GroupElementEncoder<'a> {
    /// Verifies that `width` bytes starting at `field_offset` lie inside the
    /// element.
    ///
    /// Uses `checked_add` so that a huge `field_offset` cannot wrap around and
    /// slip past the bounds check in release builds.
    fn check_field(&self, field_offset: usize, width: usize) -> SbeResult<()> {
        let block_length = usize::from(self.block_length);
        match field_offset.checked_add(width) {
            Some(end) if end <= block_length => Ok(()),
            _ => Err(SbeError::FieldOffsetOutOfBounds {
                offset: field_offset,
                length: block_length,
            }),
        }
    }

    /// Writes a single byte at `field_offset` within the element.
    ///
    /// # Errors
    /// `SbeError::FieldOffsetOutOfBounds` if the field does not fit in the
    /// element, or whatever error the buffer write reports.
    pub fn write_u8(&mut self, field_offset: usize, value: u8) -> SbeResult<()> {
        self.check_field(field_offset, 1)?;
        self.encoder
            .write_at_offset(self.offset + field_offset, value)
    }

    /// Writes `value` as 2 little-endian bytes at `field_offset` within the element.
    ///
    /// # Errors
    /// `SbeError::FieldOffsetOutOfBounds` if the field does not fit in the
    /// element, or whatever error the buffer write reports.
    pub fn write_u16(&mut self, field_offset: usize, value: u16) -> SbeResult<()> {
        self.check_field(field_offset, 2)?;
        self.encoder
            .write_at_offset(self.offset + field_offset, value.to_le_bytes())
    }

    /// Writes `value` as 4 little-endian bytes at `field_offset` within the element.
    ///
    /// # Errors
    /// `SbeError::FieldOffsetOutOfBounds` if the field does not fit in the
    /// element, or whatever error the buffer write reports.
    pub fn write_u32(&mut self, field_offset: usize, value: u32) -> SbeResult<()> {
        self.check_field(field_offset, 4)?;
        self.encoder
            .write_at_offset(self.offset + field_offset, value.to_le_bytes())
    }

    /// Writes `value` as 8 little-endian bytes at `field_offset` within the element.
    ///
    /// # Errors
    /// `SbeError::FieldOffsetOutOfBounds` if the field does not fit in the
    /// element, or whatever error the buffer write reports.
    pub fn write_u64(&mut self, field_offset: usize, value: u64) -> SbeResult<()> {
        self.check_field(field_offset, 8)?;
        self.encoder
            .write_at_offset(self.offset + field_offset, value.to_le_bytes())
    }

    /// Writes `value` as 4 little-endian bytes at `field_offset` within the element.
    ///
    /// # Errors
    /// `SbeError::FieldOffsetOutOfBounds` if the field does not fit in the
    /// element, or whatever error the buffer write reports.
    pub fn write_f32(&mut self, field_offset: usize, value: f32) -> SbeResult<()> {
        self.check_field(field_offset, 4)?;
        self.encoder
            .write_at_offset(self.offset + field_offset, value.to_le_bytes())
    }

    /// Writes `value` into a fixed-width string field of `length` bytes at
    /// `field_offset`, zero-filling whatever the string does not occupy.
    ///
    /// # Errors
    /// `SbeError::FieldOffsetOutOfBounds` if the field does not fit in the
    /// element; `SbeError::Custom` if `value` is longer than `length` bytes.
    pub fn write_string(
        &mut self,
        field_offset: usize,
        length: usize,
        value: &str,
    ) -> SbeResult<()> {
        self.check_field(field_offset, length)?;
        let value_bytes = value.as_bytes();
        if value_bytes.len() > length {
            return Err(SbeError::Custom {
                message: format!(
                    "String too long: {} bytes, field size is {}",
                    value_bytes.len(),
                    length
                ),
            });
        }
        let start = self.offset + field_offset;
        let field = &mut self.encoder.as_mut_slice()[start..start + length];
        let (text, padding) = field.split_at_mut(value_bytes.len());
        text.copy_from_slice(value_bytes);
        padding.fill(0);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::SbeDecoder;

    /// A fresh encoder reports its identifiers and already holds the header
    /// plus the root block.
    #[test]
    fn test_encoder_creation() {
        let enc = SbeEncoder::new(1, 0, 32);
        assert_eq!(enc.template_id(), 1);
        assert_eq!(enc.schema_version(), 0);
        assert!(enc.current_size() >= SBE_HEADER_SIZE + 32);
    }

    /// Fixed-width integers written into the root block decode to the same
    /// values.
    #[test]
    fn test_field_encoding() {
        let mut enc = SbeEncoder::new(1, 0, 16);
        enc.write_u32(0, 42).unwrap();
        enc.write_u64(4, 1234567890).unwrap();
        enc.write_u16(12, 999).unwrap();
        let bytes = enc.finalize().unwrap();

        let dec = SbeDecoder::new(&bytes).unwrap();
        assert_eq!(dec.template_id(), 1);
        assert_eq!(dec.read_u32(0).unwrap(), 42);
        assert_eq!(dec.read_u64(4).unwrap(), 1234567890);
        assert_eq!(dec.read_u16(12).unwrap(), 999);
    }

    /// Fixed-width strings come back unchanged once the padding is trimmed.
    #[test]
    fn test_string_encoding() {
        let mut enc = SbeEncoder::new(1, 0, 16);
        enc.write_string(0, 8, "BTCUSDT").unwrap();
        enc.write_string(8, 8, "BUY").unwrap();
        let bytes = enc.finalize().unwrap();

        let dec = SbeDecoder::new(&bytes).unwrap();
        let symbol = dec.read_string(0, 8).unwrap();
        assert_eq!(symbol.trim_end_matches('\0'), "BTCUSDT");
        let side = dec.read_string(8, 8).unwrap();
        assert_eq!(side.trim_end_matches('\0'), "BUY");
    }

    /// Variable-length data grows the message and leaves the root block
    /// readable.
    #[test]
    fn test_variable_data_encoding() {
        let mut enc = SbeEncoder::new(1, 0, 8);
        enc.write_u64(0, 12345).unwrap();
        enc.write_variable_string("Hello SBE").unwrap();
        enc.write_variable_string("World").unwrap();
        let bytes = enc.finalize().unwrap();

        // Header + root block + two length-prefixed strings.
        let expected_min_size = SBE_HEADER_SIZE + 8 + (2 + 9) + (2 + 5);
        assert!(
            bytes.len() >= expected_min_size,
            "Message length {} should be at least {}",
            bytes.len(),
            expected_min_size
        );

        let dec = SbeDecoder::new(&bytes).unwrap();
        assert_eq!(dec.read_u64(0).unwrap(), 12345);
    }

    /// Every supported field type survives an encode/decode round trip at
    /// its extreme value.
    #[test]
    fn test_round_trip_encoding_decoding() {
        let pi = std::f32::consts::PI;

        let mut enc = SbeEncoder::new(123, 1, 24);
        enc.write_u8(0, u8::MAX).unwrap();
        enc.write_u16(1, u16::MAX).unwrap();
        enc.write_u32(3, u32::MAX).unwrap();
        enc.write_u64(7, u64::MAX).unwrap();
        enc.write_f32(15, pi).unwrap();
        enc.write_string(19, 5, "TEST").unwrap();
        let bytes = enc.finalize().unwrap();

        let dec = SbeDecoder::new(&bytes).unwrap();
        assert_eq!(dec.template_id(), 123);
        assert_eq!(dec.schema_version(), 1);
        assert_eq!(dec.read_u8(0).unwrap(), u8::MAX);
        assert_eq!(dec.read_u16(1).unwrap(), u16::MAX);
        assert_eq!(dec.read_u32(3).unwrap(), u32::MAX);
        assert_eq!(dec.read_u64(7).unwrap(), u64::MAX);
        assert!((dec.read_f32(15).unwrap() - pi).abs() < 0.001);
        let text = dec.read_string(19, 5).unwrap();
        assert_eq!(text.trim_end_matches('\0'), "TEST");
    }
}