use crate::decoder::SbeDecoder;
use crate::encoder::SbeEncoder;
use crate::error::{SbeError, SbeResult};
/// Static description of a single SBE message type, plus default helpers that
/// build typed decoders/encoders from that description.
///
/// Implementors only supply the four associated constants; every method has a
/// default body driven by those constants.
pub trait SbeMessage: Sized {
    /// Template identifier checked on decode and handed to the encoder on encode.
    const TEMPLATE_ID: u16;
    /// Schema version checked on decode and handed to the encoder on encode.
    const SCHEMA_VERSION: u16;
    /// Block length handed to the encoder when a new message is started.
    const BLOCK_LENGTH: u16;
    /// Human-readable name reported through [`SbeMessage::metadata`].
    const MESSAGE_NAME: &'static str;

    /// Wraps `data` in a typed decoder after checking its header.
    ///
    /// # Errors
    ///
    /// Propagates any error from `SbeDecoder::new`, `verify_template_id`
    /// (against [`Self::TEMPLATE_ID`]) or `verify_schema_version`
    /// (against [`Self::SCHEMA_VERSION`]), in that order.
    fn decode(data: &[u8]) -> SbeResult<SbeMessageDecoder<'_, Self>> {
        let checked = SbeDecoder::new(data)?;
        checked.verify_template_id(Self::TEMPLATE_ID)?;
        checked.verify_schema_version(Self::SCHEMA_VERSION)?;

        let typed = SbeMessageDecoder {
            decoder: checked,
            _phantom: std::marker::PhantomData,
        };
        Ok(typed)
    }

    /// Starts a typed encoder initialised from this type's constants.
    fn encode() -> SbeMessageEncoder<Self> {
        SbeMessageEncoder {
            encoder: SbeEncoder::new(
                Self::TEMPLATE_ID,
                Self::SCHEMA_VERSION,
                Self::BLOCK_LENGTH,
            ),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Same as [`SbeMessage::encode`], but forwards `capacity` to
    /// `SbeEncoder::with_capacity`.
    fn encode_with_capacity(capacity: usize) -> SbeMessageEncoder<Self> {
        SbeMessageEncoder {
            encoder: SbeEncoder::with_capacity(
                Self::TEMPLATE_ID,
                Self::SCHEMA_VERSION,
                Self::BLOCK_LENGTH,
                capacity,
            ),
            _phantom: std::marker::PhantomData,
        }
    }

    /// Runs the same header checks as [`SbeMessage::decode`] without keeping
    /// the decoder.
    ///
    /// # Errors
    ///
    /// Propagates any error from constructing the decoder or from the
    /// template-id / schema-version verification.
    fn validate_header(data: &[u8]) -> SbeResult<()> {
        let probe = SbeDecoder::new(data)?;
        probe.verify_template_id(Self::TEMPLATE_ID)?;
        probe.verify_schema_version(Self::SCHEMA_VERSION)?;
        Ok(())
    }

    /// Bundles the four associated constants into an [`SbeMessageMetadata`] value.
    fn metadata() -> SbeMessageMetadata {
        let template_id = Self::TEMPLATE_ID;
        let schema_version = Self::SCHEMA_VERSION;
        let block_length = Self::BLOCK_LENGTH;
        let message_name = Self::MESSAGE_NAME;

        SbeMessageMetadata {
            template_id,
            schema_version,
            block_length,
            message_name,
        }
    }
}
/// Typed, read-only view over an encoded message of type `T`.
///
/// Holds an [`SbeDecoder`] parameterised by the lifetime `'a`; `T` is carried
/// only as a `PhantomData` marker, so no value of `T` is stored.
pub struct SbeMessageDecoder<'a, T: SbeMessage> {
/// Underlying decoder that every accessor below forwards to.
decoder: SbeDecoder<'a>,
/// Zero-sized marker associating this view with message type `T`.
_phantom: std::marker::PhantomData<T>,
}
impl<'a, T: SbeMessage> SbeMessageDecoder<'a, T> {
/// Returns a shared reference to the wrapped [`SbeDecoder`].
pub fn decoder(&self) -> &SbeDecoder<'a> {
&self.decoder
}
/// Returns the template id as reported by the wrapped decoder.
pub fn template_id(&self) -> u16 {
self.decoder.template_id()
}
/// Returns the schema version as reported by the wrapped decoder.
pub fn schema_version(&self) -> u16 {
self.decoder.schema_version()
}
/// Forwards to `SbeDecoder::read_u8` with the given `offset`.
///
/// How `offset` is interpreted and which errors can occur are defined by
/// `SbeDecoder`; the result is returned unchanged.
pub fn read_u8(&self, offset: usize) -> SbeResult<u8> {
self.decoder.read_u8(offset)
}
/// Forwards to `SbeDecoder::read_u16` with the given `offset`; the result is
/// returned unchanged.
pub fn read_u16(&self, offset: usize) -> SbeResult<u16> {
self.decoder.read_u16(offset)
}
/// Forwards to `SbeDecoder::read_u32` with the given `offset`; the result is
/// returned unchanged.
pub fn read_u32(&self, offset: usize) -> SbeResult<u32> {
self.decoder.read_u32(offset)
}
/// Forwards to `SbeDecoder::read_u64` with the given `offset`; the result is
/// returned unchanged.
pub fn read_u64(&self, offset: usize) -> SbeResult<u64> {
self.decoder.read_u64(offset)
}
/// Forwards to `SbeDecoder::read_f32` with the given `offset`; the result is
/// returned unchanged.
pub fn read_f32(&self, offset: usize) -> SbeResult<f32> {
self.decoder.read_f32(offset)
}
/// Forwards to `SbeDecoder::read_string` with `offset` and `length`.
///
/// The returned `&str` carries the decoder's lifetime `'a`, not the lifetime
/// of `&self`, so it may outlive this wrapper.
pub fn read_string(&self, offset: usize, length: usize) -> SbeResult<&'a str> {
self.decoder.read_string(offset, length)
}
/// Forwards to `SbeDecoder::read_bytes` with `offset` and `length`.
///
/// The returned slice carries the decoder's lifetime `'a`, not the lifetime
/// of `&self`, so it may outlive this wrapper.
pub fn read_bytes(&self, offset: usize, length: usize) -> SbeResult<&'a [u8]> {
self.decoder.read_bytes(offset, length)
}
}
/// Typed builder for an encoded message of type `T`.
///
/// Owns an [`SbeEncoder`]; `T` is carried only as a `PhantomData` marker, so no
/// value of `T` is stored.
pub struct SbeMessageEncoder<T: SbeMessage> {
/// Underlying encoder that every method below forwards to.
encoder: SbeEncoder,
/// Zero-sized marker associating this builder with message type `T`.
_phantom: std::marker::PhantomData<T>,
}
impl<T: SbeMessage> SbeMessageEncoder<T> {
/// Returns a mutable reference to the wrapped [`SbeEncoder`] for direct access.
pub fn encoder(&mut self) -> &mut SbeEncoder {
&mut self.encoder
}
/// Forwards to `SbeEncoder::write_u8` with `offset` and `value`.
///
/// How `offset` is interpreted and which errors can occur are defined by
/// `SbeEncoder`; the result is returned unchanged.
pub fn write_u8(&mut self, offset: usize, value: u8) -> SbeResult<()> {
self.encoder.write_u8(offset, value)
}
/// Forwards to `SbeEncoder::write_u16` with `offset` and `value`; the result
/// is returned unchanged.
pub fn write_u16(&mut self, offset: usize, value: u16) -> SbeResult<()> {
self.encoder.write_u16(offset, value)
}
/// Forwards to `SbeEncoder::write_u32` with `offset` and `value`; the result
/// is returned unchanged.
pub fn write_u32(&mut self, offset: usize, value: u32) -> SbeResult<()> {
self.encoder.write_u32(offset, value)
}
/// Forwards to `SbeEncoder::write_u64` with `offset` and `value`; the result
/// is returned unchanged.
pub fn write_u64(&mut self, offset: usize, value: u64) -> SbeResult<()> {
self.encoder.write_u64(offset, value)
}
/// Forwards to `SbeEncoder::write_f32` with `offset` and `value`; the result
/// is returned unchanged.
pub fn write_f32(&mut self, offset: usize, value: f32) -> SbeResult<()> {
self.encoder.write_f32(offset, value)
}
/// Forwards to `SbeEncoder::write_string` with `offset`, `length` and `value`.
///
/// NOTE(review): `length` is presumably the fixed width of the string field;
/// confirm padding/truncation behaviour against `SbeEncoder`.
pub fn write_string(&mut self, offset: usize, length: usize, value: &str) -> SbeResult<()> {
self.encoder.write_string(offset, length, value)
}
/// Forwards to `SbeEncoder::write_bytes` with `offset` and `bytes`; the result
/// is returned unchanged.
pub fn write_bytes(&mut self, offset: usize, bytes: &[u8]) -> SbeResult<()> {
self.encoder.write_bytes(offset, bytes)
}
/// Forwards to `SbeEncoder::write_variable_string`.
///
/// Takes no offset; where the data is placed is decided by `SbeEncoder`.
pub fn write_variable_string(&mut self, value: &str) -> SbeResult<()> {
self.encoder.write_variable_string(value)
}
/// Forwards to `SbeEncoder::write_variable_bytes`.
///
/// Takes no offset; where the data is placed is decided by `SbeEncoder`.
pub fn write_variable_bytes(&mut self, bytes: &[u8]) -> SbeResult<()> {
self.encoder.write_variable_bytes(bytes)
}
/// Consumes the builder and returns whatever `SbeEncoder::finalize` produces
/// (the encoded bytes on success).
pub fn finalize(self) -> SbeResult<Vec<u8>> {
self.encoder.finalize()
}
/// Returns the size currently reported by the wrapped encoder.
pub fn current_size(&self) -> usize {
self.encoder.current_size()
}
}
/// Plain-data snapshot of the constants that identify a message type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SbeMessageMetadata {
    /// Template identifier of the message type.
    pub template_id: u16,
    /// Schema version of the message type.
    pub schema_version: u16,
    /// Block length declared by the message type.
    pub block_length: u16,
    /// Human-readable name of the message type.
    pub message_name: &'static str,
}

impl SbeMessageMetadata {
    /// Returns `true` when `template_id` equals this metadata's template id.
    pub fn matches_template(&self, template_id: u16) -> bool {
        let Self { template_id: own, .. } = *self;
        own == template_id
    }

    /// Returns `true` when `schema_version` equals this metadata's schema
    /// version; only an exact match counts as compatible.
    pub fn is_compatible_version(&self, schema_version: u16) -> bool {
        let Self { schema_version: own, .. } = *self;
        own == schema_version
    }
}
/// Lookup table from template id to the metadata of a registered message type.
pub struct SbeMessageRegistry {
    /// Registered metadata keyed by `SbeMessageMetadata::template_id`.
    messages: std::collections::HashMap<u16, SbeMessageMetadata>,
}

impl SbeMessageRegistry {
    /// Creates a registry with no registered message types.
    pub fn new() -> Self {
        let messages = std::collections::HashMap::new();
        Self { messages }
    }

    /// Stores `T::metadata()` under the template id found in that metadata.
    ///
    /// Registering a second type that reports the same template id replaces
    /// the earlier entry.
    pub fn register<T: SbeMessage>(&mut self) {
        let entry = T::metadata();
        let key = entry.template_id;
        self.messages.insert(key, entry);
    }

    /// Returns the metadata stored for `template_id`, or `None` if nothing was
    /// registered under that id.
    pub fn get_metadata(&self, template_id: u16) -> Option<&SbeMessageMetadata> {
        let Self { messages } = self;
        messages.get(&template_id)
    }

    /// Returns `true` when some message type is registered under `template_id`.
    pub fn is_registered(&self, template_id: u16) -> bool {
        self.get_metadata(template_id).is_some()
    }

    /// Returns every registered template id; the order is the hash map's
    /// iteration order and therefore unspecified.
    pub fn template_ids(&self) -> Vec<u16> {
        self.messages.iter().map(|(&id, _)| id).collect()
    }

    /// Returns references to every stored metadata entry; the order is the
    /// hash map's iteration order and therefore unspecified.
    pub fn all_metadata(&self) -> Vec<&SbeMessageMetadata> {
        self.messages.iter().map(|(_, meta)| meta).collect()
    }
}

impl Default for SbeMessageRegistry {
    /// Equivalent to [`SbeMessageRegistry::new`].
    fn default() -> Self {
        SbeMessageRegistry::new()
    }
}
/// Stateless helpers that read header fields directly from a raw byte slice.
///
/// The layout read by these helpers is, in little-endian byte order:
/// bytes `0..4` message length (`u32`), bytes `4..6` template id (`u16`),
/// bytes `6..8` schema version (`u16`).
pub struct SbeMessageHeader;

impl SbeMessageHeader {
    /// Reads the template id from bytes `4..6` of `data`.
    ///
    /// # Errors
    ///
    /// Returns `SbeError::BufferTooSmall { need: 8, .. }` when `data` holds
    /// fewer than 8 bytes.
    pub fn extract_template_id(data: &[u8]) -> SbeResult<u16> {
        // A slice pattern proves the length and picks the bytes in one step,
        // so there is no separate indexing that could panic.
        match data {
            [_, _, _, _, lo, hi, _, _, ..] => Ok(u16::from_le_bytes([*lo, *hi])),
            _ => Err(SbeError::BufferTooSmall {
                need: 8,
                have: data.len(),
            }),
        }
    }

    /// Reads the schema version from bytes `6..8` of `data`.
    ///
    /// # Errors
    ///
    /// Returns `SbeError::BufferTooSmall { need: 8, .. }` when `data` holds
    /// fewer than 8 bytes.
    pub fn extract_schema_version(data: &[u8]) -> SbeResult<u16> {
        match data {
            [_, _, _, _, _, _, lo, hi, ..] => Ok(u16::from_le_bytes([*lo, *hi])),
            _ => Err(SbeError::BufferTooSmall {
                need: 8,
                have: data.len(),
            }),
        }
    }

    /// Reads the message length from bytes `0..4` of `data`.
    ///
    /// # Errors
    ///
    /// Returns `SbeError::BufferTooSmall { need: 4, .. }` when `data` holds
    /// fewer than 4 bytes.
    pub fn extract_message_length(data: &[u8]) -> SbeResult<u32> {
        match data {
            [b0, b1, b2, b3, ..] => Ok(u32::from_le_bytes([*b0, *b1, *b2, *b3])),
            _ => Err(SbeError::BufferTooSmall {
                need: 4,
                have: data.len(),
            }),
        }
    }

    /// Extracts `(length, template_id, schema_version)` and checks that the
    /// declared length is plausible for `data`.
    ///
    /// # Errors
    ///
    /// * `SbeError::BufferTooSmall` when `data` is too short for a header
    ///   field (`need: 4` below 4 bytes, `need: 8` for 4 to 7 bytes).
    /// * `SbeError::InvalidMessageLength` when the declared length is below 8
    ///   or exceeds `data.len()`. The error field is 16 bits wide, so a
    ///   declared length above `u16::MAX` is reported as `u16::MAX` rather
    ///   than being wrapped to an unrelated small number.
    pub fn validate_basic(data: &[u8]) -> SbeResult<(u32, u16, u16)> {
        let length = Self::extract_message_length(data)?;
        let template_id = Self::extract_template_id(data)?;
        let schema_version = Self::extract_schema_version(data)?;

        let shorter_than_header = length < 8;
        // Compare in u64 so neither side can be truncated by the conversion.
        let overruns_buffer = u64::from(length) > data.len() as u64;

        if shorter_than_header || overruns_buffer {
            // Clamp before narrowing: the cast below is lossless after `min`.
            let reported = length.min(u32::from(u16::MAX)) as u16;
            return Err(SbeError::InvalidMessageLength { length: reported });
        }

        Ok((length, template_id, schema_version))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal message type used to exercise the trait's default methods.
    struct TestMessage;

    impl SbeMessage for TestMessage {
        const TEMPLATE_ID: u16 = 1;
        const SCHEMA_VERSION: u16 = 0;
        const BLOCK_LENGTH: u16 = 16;
        const MESSAGE_NAME: &'static str = "TestMessage";
    }

    /// `metadata()` must mirror the four associated constants.
    #[test]
    fn test_message_metadata() {
        let expected = SbeMessageMetadata {
            template_id: 1,
            schema_version: 0,
            block_length: 16,
            message_name: "TestMessage",
        };
        assert_eq!(TestMessage::metadata(), expected);
    }

    /// A registered type is found by its template id; other ids are not.
    #[test]
    fn test_message_registry() {
        let mut registry = SbeMessageRegistry::new();
        registry.register::<TestMessage>();

        assert!(registry.is_registered(1));
        assert!(!registry.is_registered(2));

        let registered_name = registry.get_metadata(1).map(|entry| entry.message_name);
        assert_eq!(registered_name, Some("TestMessage"));
    }

    /// Header helpers must agree with what the encoder wrote.
    #[test]
    fn test_header_extraction() {
        let mut builder = TestMessage::encode();
        builder.write_u64(0, 12345).unwrap();
        let bytes = builder.finalize().unwrap();

        assert_eq!(SbeMessageHeader::extract_template_id(&bytes).unwrap(), 1);
        assert_eq!(SbeMessageHeader::extract_schema_version(&bytes).unwrap(), 0);
        assert_eq!(
            SbeMessageHeader::extract_message_length(&bytes).unwrap(),
            bytes.len() as u32
        );
    }

    /// Values written through the typed encoder read back through the typed decoder.
    #[test]
    fn test_typed_encoding_decoding() {
        let mut builder = TestMessage::encode();
        builder.write_u32(0, 42).unwrap();
        builder.write_u64(4, 1234567890).unwrap();
        builder.write_u16(12, 999).unwrap();
        let bytes = builder.finalize().unwrap();

        let view = TestMessage::decode(&bytes).unwrap();
        assert_eq!(view.read_u32(0).unwrap(), 42);
        assert_eq!(view.read_u64(4).unwrap(), 1234567890);
        assert_eq!(view.read_u16(12).unwrap(), 999);
    }
}