use std::{error::Error, fmt::Display};
pub use nautilus_serialization::sbe::{GroupSize16Encoding, GroupSizeEncoding, decode_var_string8};
use crate::spot::sbe::{cursor::SbeCursor, error::SbeDecodeError};
mod best_bid_ask;
mod depth_diff;
mod depth_snapshot;
mod trades;
pub use best_bid_ask::BestBidAskStreamEvent;
pub use depth_diff::DepthDiffStreamEvent;
pub use depth_snapshot::DepthSnapshotStreamEvent;
pub use trades::{Trade, TradesStreamEvent};
/// Schema ID expected in the `schema_id` field of every stream message header.
///
/// [`MessageHeader::validate_schema`] rejects any header whose `schema_id` differs.
pub const STREAM_SCHEMA_ID: u16 = 1;
/// Schema version of the stream messages handled by this module.
///
/// NOTE(review): in the code visible here, header `version` fields are never compared
/// against this value; it is only used as the `expected` value when converting an
/// `SbeDecodeError::VersionMismatch` into a [`StreamDecodeError`].
pub const STREAM_SCHEMA_VERSION: u16 = 0;
/// Upper bound on the entry count of a repeating group.
///
/// NOTE(review): not referenced by decoding code in this file. The tests expect
/// `GroupSizeEncoding::decode` to reject `MAX_GROUP_SIZE + 1`, so this presumably mirrors
/// the limit enforced in `nautilus_serialization::sbe`. Confirm the two stay in sync.
pub const MAX_GROUP_SIZE: usize = 10_000;
/// SBE template IDs found in [`MessageHeader::template_id`], one per stream message type.
pub mod template_id {
/// Template ID of trade stream event messages (presumably decoded by `TradesStreamEvent`).
pub const TRADES_STREAM_EVENT: u16 = 10000;
/// Template ID of best bid/ask stream event messages (presumably `BestBidAskStreamEvent`).
pub const BEST_BID_ASK_STREAM_EVENT: u16 = 10001;
/// Template ID of depth snapshot stream event messages (presumably `DepthSnapshotStreamEvent`).
pub const DEPTH_SNAPSHOT_STREAM_EVENT: u16 = 10002;
/// Template ID of depth diff stream event messages (presumably `DepthDiffStreamEvent`).
pub const DEPTH_DIFF_STREAM_EVENT: u16 = 10003;
}
/// Errors produced while decoding spot SBE stream messages.
///
/// Most variants are translated one-to-one from the lower-level `SbeDecodeError`
/// via the `From` impl below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamDecodeError {
/// The input held fewer bytes than needed: `expected` bytes were required,
/// but only `actual` were available.
BufferTooShort { expected: usize, actual: usize },
/// A repeating group declared `count` entries, more than the allowed `max`.
GroupSizeTooLarge { count: usize, max: usize },
/// A string field was not valid UTF-8 (the `Display` message attributes this to the symbol).
InvalidUtf8,
/// The header's schema ID (`actual`) did not match `expected`.
///
/// Also produced, lossily, when converting `SbeDecodeError::VersionMismatch`:
/// in that case `expected` is `STREAM_SCHEMA_VERSION` and `actual` is `0`.
SchemaMismatch { expected: u16, actual: u16 },
/// The header carried a template ID that the decoder does not handle.
UnknownTemplateId(u16),
/// A message's block length (`actual`) differed from the value the decoder `expected`.
InvalidBlockLength { expected: u16, actual: u16 },
}
impl Display for StreamDecodeError {
    /// Renders a short, human-readable description of the decode failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every payload field is `Copy`, so matching on `*self` binds plain values.
        match *self {
            Self::InvalidUtf8 => f.write_str("Invalid UTF-8 in symbol"),
            Self::UnknownTemplateId(id) => write!(f, "Unknown template ID: {id}"),
            Self::GroupSizeTooLarge { count, max } => {
                write!(f, "Group size {count} exceeds maximum {max}")
            }
            Self::BufferTooShort { expected, actual } => {
                write!(f, "Buffer too short: expected {expected} bytes, was {actual}")
            }
            Self::SchemaMismatch { expected, actual } => {
                write!(f, "Schema mismatch: expected {expected}, was {actual}")
            }
            Self::InvalidBlockLength { expected, actual } => {
                write!(f, "Invalid block length: expected {expected}, was {actual}")
            }
        }
    }
}

/// Lets `StreamDecodeError` be used as a `dyn Error` (e.g. boxed via `?`).
///
/// No variant wraps an underlying cause, so `source()` keeps its default of `None`.
impl Error for StreamDecodeError {}
impl From<SbeDecodeError> for StreamDecodeError {
    /// Translates a low-level `SbeDecodeError` into the stream-level error type.
    ///
    /// Payloads are carried across unchanged, with two exceptions:
    /// - group-size `count`/`max` are widened to `usize`;
    /// - `VersionMismatch` has no stream-level counterpart and becomes a `SchemaMismatch`
    ///   with `expected` = [`STREAM_SCHEMA_VERSION`] and `actual` = `0`.
    fn from(err: SbeDecodeError) -> Self {
        match err {
            SbeDecodeError::InvalidUtf8 => Self::InvalidUtf8,
            SbeDecodeError::UnknownTemplateId(id) => Self::UnknownTemplateId(id),
            SbeDecodeError::BufferTooShort { expected, actual } => {
                Self::BufferTooShort { expected, actual }
            }
            SbeDecodeError::SchemaMismatch { expected, actual } => {
                Self::SchemaMismatch { expected, actual }
            }
            SbeDecodeError::InvalidBlockLength { expected, actual } => {
                Self::InvalidBlockLength { expected, actual }
            }
            SbeDecodeError::GroupSizeTooLarge { count, max } => {
                let (count, max) = (count as usize, max as usize);
                Self::GroupSizeTooLarge { count, max }
            }
            // NOTE(review): this arm discards the version values carried by `VersionMismatch`.
            // Because STREAM_SCHEMA_VERSION is 0, the resulting message reads
            // "Schema mismatch: expected 0, was 0". Consider forwarding the real values
            // (or adding a dedicated variant) once the variant's field layout is confirmed.
            SbeDecodeError::VersionMismatch { .. } => Self::SchemaMismatch {
                expected: STREAM_SCHEMA_VERSION,
                actual: 0,
            },
        }
    }
}
/// Fixed-size (8-byte) SBE message header, decoded from the start of a message buffer
/// by [`MessageHeader::decode`].
#[derive(Debug, Clone, Copy)]
pub struct MessageHeader {
/// Declared block length, little-endian `u16` at bytes 0-1
/// (presumably the root block size in bytes; confirm against the schema).
pub block_length: u16,
/// Message type identifier, little-endian `u16` at bytes 2-3; see the [`template_id`] constants.
pub template_id: u16,
/// Schema identifier, little-endian `u16` at bytes 4-5; must equal [`STREAM_SCHEMA_ID`]
/// for [`MessageHeader::validate_schema`] to succeed.
pub schema_id: u16,
/// Schema version, little-endian `u16` at bytes 6-7; not checked by `validate_schema`.
pub version: u16,
}
impl MessageHeader {
    /// Number of bytes occupied by an encoded header.
    pub const ENCODED_LENGTH: usize = 8;

    /// Decodes a header from the first [`Self::ENCODED_LENGTH`] bytes of `buf`.
    ///
    /// Any bytes after the header are ignored. All four fields are little-endian `u16`s,
    /// laid out in declaration order.
    ///
    /// # Errors
    ///
    /// Returns [`StreamDecodeError::BufferTooShort`] when `buf` is shorter than the header.
    pub fn decode(buf: &[u8]) -> Result<Self, StreamDecodeError> {
        let Some(raw) = buf.get(..Self::ENCODED_LENGTH) else {
            return Err(StreamDecodeError::BufferTooShort {
                expected: Self::ENCODED_LENGTH,
                actual: buf.len(),
            });
        };

        // `raw` is exactly ENCODED_LENGTH bytes, so every offset below is in bounds.
        let word_at = |offset: usize| u16::from_le_bytes([raw[offset], raw[offset + 1]]);

        Ok(Self {
            block_length: word_at(0),
            template_id: word_at(2),
            schema_id: word_at(4),
            version: word_at(6),
        })
    }

    /// Checks that the header's schema ID matches [`STREAM_SCHEMA_ID`].
    ///
    /// # Errors
    ///
    /// Returns [`StreamDecodeError::SchemaMismatch`] carrying the unexpected ID.
    pub fn validate_schema(&self) -> Result<(), StreamDecodeError> {
        match self.schema_id {
            STREAM_SCHEMA_ID => Ok(()),
            other => Err(StreamDecodeError::SchemaMismatch {
                expected: STREAM_SCHEMA_ID,
                actual: other,
            }),
        }
    }
}
/// One order-book price level, stored as raw decimal mantissas.
///
/// Convert to floating point with [`mantissa_to_f64`] using the matching exponent.
/// The exponent is not stored here; presumably the enclosing message carries it (confirm).
#[derive(Debug, Clone, Copy)]
pub struct PriceLevel {
/// Price mantissa (first 8 bytes of the encoded level).
pub price_mantissa: i64,
/// Quantity mantissa (second 8 bytes of the encoded level).
pub qty_mantissa: i64,
}
impl PriceLevel {
    /// Encoded size of one level: two 8-byte mantissas.
    pub const ENCODED_LENGTH: usize = 16;

    /// Reads one price level from `cursor`: the price mantissa first, then the quantity
    /// mantissa, each via `SbeCursor::read_i64_le`.
    ///
    /// # Errors
    ///
    /// Propagates any [`SbeDecodeError`] raised by the cursor reads
    /// (presumably when the cursor runs out of bytes).
    pub fn decode(cursor: &mut SbeCursor<'_>) -> Result<Self, SbeDecodeError> {
        // Read order matters: it must follow the wire layout (price, then quantity).
        let price_mantissa = cursor.read_i64_le()?;
        let qty_mantissa = cursor.read_i64_le()?;
        Ok(Self {
            price_mantissa,
            qty_mantissa,
        })
    }
}
/// Converts a decimal `mantissa × 10^exponent` into an `f64`.
///
/// For negative exponents the mantissa is *divided* by the power of ten `10^-exponent`.
/// Multiplying by `10^exponent` instead would use an inexact binary fraction and round
/// twice: `3 × 0.1` gives `0.30000000000000004`, whereas `3 / 10` gives `0.3`.
/// Powers of ten up to `10^22` are exact in `f64`, so for `|exponent| <= 22` and
/// `|mantissa| <= 2^53` the result is the correctly rounded value.
///
/// Non-negative exponents keep the plain multiplication.
#[inline]
#[must_use]
pub fn mantissa_to_f64(mantissa: i64, exponent: i8) -> f64 {
    // `unsigned_abs` keeps `i8::MIN` (-128) from overflowing.
    let scale = 10_f64.powi(i32::from(exponent.unsigned_abs()));
    let value = mantissa as f64;
    if exponent < 0 {
        value / scale
    } else {
        value * scale
    }
}
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::spot::sbe::error::SbeDecodeError;

    #[rstest]
    fn test_mantissa_to_f64() {
        assert!((mantissa_to_f64(12345, -2) - 123.45).abs() < 1e-10);
        assert!((mantissa_to_f64(100, 0) - 100.0).abs() < 1e-10);
        assert!((mantissa_to_f64(5, 3) - 5000.0).abs() < 1e-10);
        assert!((mantissa_to_f64(-15, -1) - (-1.5)).abs() < 1e-10);
        assert!(mantissa_to_f64(0, -8).abs() < 1e-10);
    }

    #[rstest]
    fn test_message_header_decode() {
        let mut buf = Vec::with_capacity(MessageHeader::ENCODED_LENGTH + 1);
        for field in [
            18u16,
            template_id::TRADES_STREAM_EVENT,
            STREAM_SCHEMA_ID,
            STREAM_SCHEMA_VERSION,
        ] {
            buf.extend_from_slice(&field.to_le_bytes());
        }
        // Bytes past the header belong to the message body and must be ignored.
        buf.push(0xFF);

        let header = MessageHeader::decode(&buf).unwrap();
        assert_eq!(header.block_length, 18);
        assert_eq!(header.template_id, template_id::TRADES_STREAM_EVENT);
        assert_eq!(header.schema_id, STREAM_SCHEMA_ID);
        assert_eq!(header.version, STREAM_SCHEMA_VERSION);
        assert!(header.validate_schema().is_ok());
    }

    #[rstest]
    fn test_message_header_too_short() {
        let buf = [0u8; 7];
        let err = MessageHeader::decode(&buf).unwrap_err();
        assert_eq!(
            err,
            StreamDecodeError::BufferTooShort {
                expected: 8,
                actual: 7
            }
        );
    }

    #[rstest]
    fn test_group_size_too_large() {
        let mut buf = [0u8; 6];
        let count = (MAX_GROUP_SIZE + 1) as u32;
        buf[2..6].copy_from_slice(&count.to_le_bytes());
        let err = GroupSizeEncoding::decode(&buf).unwrap_err();
        assert!(matches!(err, SbeDecodeError::GroupSizeTooLarge { .. }));
    }

    #[rstest]
    fn test_decode_var_string8_empty_buffer() {
        let err = decode_var_string8(&[]).unwrap_err();
        assert!(matches!(err, SbeDecodeError::BufferTooShort { .. }));
    }

    #[rstest]
    fn test_decode_var_string8_truncated() {
        let buf = [10u8, b'H', b'E', b'L', b'L'];
        let err = decode_var_string8(&buf).unwrap_err();
        assert!(matches!(err, SbeDecodeError::BufferTooShort { .. }));
    }

    #[rstest]
    fn test_decode_var_string8_valid() {
        let buf = [5u8, b'H', b'E', b'L', b'L', b'O'];
        let (s, consumed) = decode_var_string8(&buf).unwrap();
        assert_eq!(s, "HELLO");
        assert_eq!(consumed, 6);
    }

    #[rstest]
    fn test_schema_validation() {
        let header = MessageHeader {
            block_length: 50,
            template_id: 10001,
            schema_id: 99,
            version: 0,
        };
        let err = header.validate_schema().unwrap_err();
        assert_eq!(
            err,
            StreamDecodeError::SchemaMismatch {
                expected: STREAM_SCHEMA_ID,
                actual: 99
            }
        );
    }

    #[rstest]
    fn test_stream_decode_error_display() {
        assert_eq!(
            StreamDecodeError::BufferTooShort {
                expected: 8,
                actual: 7
            }
            .to_string(),
            "Buffer too short: expected 8 bytes, was 7"
        );
        assert_eq!(
            StreamDecodeError::GroupSizeTooLarge {
                count: 10_001,
                max: 10_000
            }
            .to_string(),
            "Group size 10001 exceeds maximum 10000"
        );
        assert_eq!(
            StreamDecodeError::UnknownTemplateId(42).to_string(),
            "Unknown template ID: 42"
        );
    }

    #[rstest]
    fn test_from_sbe_decode_error() {
        assert_eq!(
            StreamDecodeError::from(SbeDecodeError::UnknownTemplateId(42)),
            StreamDecodeError::UnknownTemplateId(42)
        );
        assert_eq!(
            StreamDecodeError::from(SbeDecodeError::BufferTooShort {
                expected: 8,
                actual: 3
            }),
            StreamDecodeError::BufferTooShort {
                expected: 8,
                actual: 3
            }
        );
        assert_eq!(
            StreamDecodeError::from(SbeDecodeError::GroupSizeTooLarge {
                count: 10_001,
                max: 10_000
            }),
            StreamDecodeError::GroupSizeTooLarge {
                count: 10_001,
                max: 10_000
            }
        );
        assert_eq!(
            StreamDecodeError::from(SbeDecodeError::InvalidUtf8),
            StreamDecodeError::InvalidUtf8
        );
    }
}