use crate::{
buffer::BufferUtils,
error::{EncodeError, GpbError},
FieldValue, FixMessage, GpbWriter,
};
use fastrace::prelude::*;
use smallvec::SmallVec;
use std::collections::HashMap;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
/// Wire-type discriminant stored in the low three bits of every field tag.
///
/// The numeric values are cast with `as u32` when a field header is built, so
/// they must stay exactly as listed. They match the Protocol Buffers wire-type
/// numbers of the same names.
#[derive(Debug, Clone, Copy)]
enum WireType {
// Payload is a variable-length integer.
Varint = 0,
// Payload is eight bytes (used here for `f64` values).
Fixed64 = 1,
// Payload is a varint byte length followed by that many bytes.
LengthDelimited = 2,
// Payload is four bytes (used here for the checksum).
Fixed32 = 5,
}
/// Options controlling how [`GpbEncoder`] validates, sizes and frames its output.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct EncodeConfig {
/// When `true`, each message's `validate()` is called before it is encoded.
pub validate_messages: bool,
/// When `true`, a checksum field is appended after the encoded payload.
pub include_checksums: bool,
/// NOTE(review): not read by any encoder code visible in this file —
/// confirm where (or whether) this flag is honored.
pub compress_repeated: bool,
/// Upper bound, in bytes, on the *estimated* size of a message or batch;
/// inputs whose estimate exceeds it are rejected before encoding starts.
pub max_message_size: usize,
/// Selects the initial capacity of the encoder's output buffer.
pub buffer_strategy: BufferStrategy,
}
impl Default for EncodeConfig {
fn default() -> Self {
Self {
validate_messages: true,
include_checksums: true,
compress_repeated: true,
max_message_size: 1024 * 1024, buffer_strategy: BufferStrategy::Adaptive,
}
}
}
/// How much output-buffer capacity an encoder requests when it is constructed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum BufferStrategy {
/// Request exactly the given number of bytes.
Fixed(usize),
/// Request 8192 bytes.
Adaptive,
/// Request 64 KiB, intended for batch encoding.
Batch,
}
/// Encoder that serializes `FixMessage` values into a tagged binary format.
///
/// The encoder owns its output buffer; slices returned by the `encode*`
/// methods borrow from it and are overwritten by the next encode call.
#[derive(Debug)]
pub struct GpbEncoder {
// Behavior switches and limits supplied at construction time.
config: EncodeConfig,
// Output buffer; cleared at the start of every encode call.
writer: GpbWriter,
// Lookup table from FIX tag number to the field number written on the wire.
field_mappings: HashMap<u32, u32>,
}
impl GpbEncoder {
/// Creates an encoder that uses [`EncodeConfig::default`].
pub fn new() -> Self {
    let config = EncodeConfig::default();
    Self::with_config(config)
}
/// Creates an encoder with the supplied configuration.
///
/// The output buffer's initial capacity is derived from
/// `config.buffer_strategy`, and the FIX-tag lookup table is populated with
/// the built-in mappings.
pub fn with_config(config: EncodeConfig) -> Self {
    let initial_capacity = match config.buffer_strategy {
        BufferStrategy::Batch => 64 * 1024,
        BufferStrategy::Adaptive => 8192,
        BufferStrategy::Fixed(bytes) => bytes,
    };
    let writer = GpbWriter::with_capacity(initial_capacity);
    let field_mappings = Self::create_field_mappings();
    Self {
        config,
        writer,
        field_mappings,
    }
}
/// Encodes a single message and returns the encoded bytes.
///
/// The internal buffer is cleared first, so the returned slice is only valid
/// until the next `encode`/`encode_batch` call. Output order is: message
/// header, message fields, then the checksum field when `include_checksums`
/// is set.
///
/// # Errors
///
/// Fails when validation is enabled and `message.validate()` fails, when the
/// estimated size exceeds `max_message_size`, when the buffer capacity check
/// fails, or when writing to the buffer fails.
#[trace]
pub fn encode(&mut self, message: &FixMessage) -> Result<&[u8], GpbError> {
    self.writer.clear();

    if self.config.validate_messages {
        message.validate()?;
    }

    let size_hint = message.estimated_size();
    if size_hint > self.config.max_message_size {
        let reason = format!("Message too large: {} bytes", size_hint);
        return Err(GpbError::Encode(EncodeError::InvalidFieldValue {
            field_id: 0,
            reason,
        }));
    }
    self.ensure_capacity(size_hint)?;

    self.encode_message_header(message)?;
    self.encode_fields(message)?;
    if self.config.include_checksums {
        self.encode_checksum()?;
    }

    let encoded = self.writer.as_bytes();
    Ok(encoded)
}
/// Encodes several messages into one length-prefixed batch.
///
/// Bytes are written to the internal buffer in this order:
/// 1. the batch header carrying the message count;
/// 2. for each message, a varint byte length followed by that message's
///    header and fields (no per-message checksum);
/// 3. one trailing checksum field over everything above, when
///    `include_checksums` is set.
///
/// The internal buffer is cleared first, so the returned slice is only valid
/// until the next `encode`/`encode_batch` call.
///
/// # Errors
///
/// Fails when the summed size estimate exceeds `max_message_size`, when the
/// buffer capacity check fails, when validation is enabled and any message
/// fails `validate()`, or when writing to a buffer fails.
#[trace]
pub fn encode_batch(&mut self, messages: &[FixMessage]) -> Result<&[u8], GpbError> {
    self.writer.clear();

    // Saturate instead of overflowing: an absurdly large estimate must end up
    // rejected by the limit check, not panic (debug) or wrap (release).
    let total_size = messages
        .iter()
        .fold(0usize, |acc, m| acc.saturating_add(m.estimated_size()));
    if total_size > self.config.max_message_size {
        return Err(GpbError::Encode(EncodeError::InvalidFieldValue {
            field_id: 0,
            reason: format!("Batch too large: {} bytes", total_size),
        }));
    }

    // The extra 1024 bytes leave headroom for the batch header, the
    // per-message length prefixes and the checksum.
    self.ensure_capacity(total_size.saturating_add(1024))?;
    self.encode_batch_header(messages.len())?;

    // One scratch encoder serves the whole batch. Building it per message
    // would re-allocate its buffer and rebuild the tag table every iteration.
    let mut scratch = GpbEncoder::new();
    for message in messages {
        if self.config.validate_messages {
            message.validate()?;
        }

        scratch.writer.clear();
        scratch.encode_message_header(message)?;
        scratch.encode_fields(message)?;

        // The length prefix needs the final size, which is why each message
        // is staged in the scratch buffer before being copied out.
        let payload = scratch.writer.as_bytes();
        self.encode_varint(payload.len() as u64)?;
        self.write_bytes(payload)?;
    }

    if self.config.include_checksums {
        self.encode_checksum()?;
    }
    Ok(self.writer.as_bytes())
}
/// Writes the fixed header fields of `message`.
///
/// Field 1 (message type) is always written. Fields 2–5 (sequence number,
/// sender, target, sending time) are written only when the corresponding
/// option on the message is `Some`.
#[trace]
fn encode_message_header(&mut self, message: &FixMessage) -> Result<(), EncodeError> {
    // Field 1: message type.
    self.encode_field_header(1, WireType::LengthDelimited)?;
    self.encode_string(message.message_type.as_str())?;

    // Field 2: sequence number.
    if let Some(seq) = message.seq_num {
        self.encode_field_header(2, WireType::Varint)?;
        self.encode_varint(seq as u64)?;
    }

    // Field 3: sender identifier.
    if let Some(sender) = &message.sender_comp_id {
        self.encode_field_header(3, WireType::LengthDelimited)?;
        self.encode_string(sender)?;
    }

    // Field 4: target identifier.
    if let Some(target) = &message.target_comp_id {
        self.encode_field_header(4, WireType::LengthDelimited)?;
        self.encode_string(target)?;
    }

    // Field 5: sending time.
    if let Some(sent_at) = message.sending_time {
        self.encode_field_header(5, WireType::Varint)?;
        self.encode_varint(sent_at)?;
    }

    Ok(())
}
/// Writes every entry of `message.fields`, in ascending FIX-tag order.
///
/// Sorting makes the output independent of the iteration order of the
/// underlying field collection.
#[trace]
fn encode_fields(&mut self, message: &FixMessage) -> Result<(), EncodeError> {
    let mut ordered: SmallVec<[_; 64]> = message.fields.iter().collect();
    ordered.sort_by_key(|&(tag, _)| *tag);

    ordered.into_iter().try_for_each(|(&tag, value)| {
        let wire_field = self.map_fix_tag_to_gpb_field(tag);
        self.encode_fix_field(wire_field, value)
    })
}
/// Writes one field value under `field_num`, picking the wire type from the
/// variant of `value`.
///
/// * strings, byte arrays and decimals are length-delimited;
/// * signed integers are zig-zag varints;
/// * unsigned integers, timestamps and booleans are plain varints;
/// * floats are eight little-endian bytes;
/// * `Optional(Some(_))` is encoded as its inner value, and `Optional(None)`
///   writes nothing at all.
#[trace]
fn encode_fix_field(&mut self, field_num: u32, value: &FieldValue) -> Result<(), EncodeError> {
    match value {
        FieldValue::Optional(None) => Ok(()),
        FieldValue::Optional(Some(inner)) => self.encode_fix_field(field_num, inner),
        FieldValue::String(text) => {
            self.encode_field_header(field_num, WireType::LengthDelimited)?;
            self.encode_string(text)
        }
        FieldValue::Bytes(raw) => {
            self.encode_field_header(field_num, WireType::LengthDelimited)?;
            self.encode_bytes(raw)
        }
        FieldValue::Decimal { mantissa, scale } => {
            self.encode_field_header(field_num, WireType::LengthDelimited)?;
            let nested = self.encode_decimal(*mantissa, *scale)?;
            self.encode_bytes(&nested)
        }
        FieldValue::Int(signed) => {
            self.encode_field_header(field_num, WireType::Varint)?;
            self.encode_varint_signed(*signed)
        }
        FieldValue::UInt(unsigned) | FieldValue::Timestamp(unsigned) => {
            self.encode_field_header(field_num, WireType::Varint)?;
            self.encode_varint(*unsigned)
        }
        FieldValue::Bool(flag) => {
            self.encode_field_header(field_num, WireType::Varint)?;
            self.encode_varint(u64::from(*flag))
        }
        FieldValue::Float(number) => {
            self.encode_field_header(field_num, WireType::Fixed64)?;
            self.encode_double(*number)
        }
    }
}
/// Writes the batch prefix: a varint field numbered 0 holding `count`.
///
/// NOTE(review): field number 0 is not a legal Protocol Buffers field number,
/// so generic protobuf decoders will presumably reject this tag — confirm the
/// matching decoder expects it before changing anything here.
fn encode_batch_header(&mut self, count: usize) -> Result<(), EncodeError> {
    self.encode_field_header(0, WireType::Varint)?;
    self.encode_varint(count as u64)
}
/// Appends field 999 (four bytes, little-endian) holding the value returned
/// by `BufferUtils::crc32` over everything written to the buffer so far.
fn encode_checksum(&mut self) -> Result<(), EncodeError> {
    // Compute first: the checksum must not cover its own field.
    let crc = BufferUtils::crc32(self.writer.as_bytes());
    self.encode_field_header(999, WireType::Fixed32)?;
    self.encode_fixed32(crc)
}
/// Writes a field tag: `(field_num << 3) | wire_type`, encoded as a varint.
///
/// The tag is computed in 64 bits. A `u32` shift silently discards the top
/// three bits of any `field_num >= 2^29`, which would alias such a field onto
/// an unrelated, smaller field number. Protocol Buffers itself only allows
/// field numbers up to `2^29 - 1`, so larger values are preserved here but are
/// outside what standard decoders accept.
///
/// # Errors
///
/// Returns the error from the underlying buffer write.
fn encode_field_header(
    &mut self,
    field_num: u32,
    wire_type: WireType,
) -> Result<(), EncodeError> {
    let tag = (u64::from(field_num) << 3) | (wire_type as u64);
    self.encode_varint(tag)
}
/// Writes `value` using the encoding produced by `BufferUtils::encode_varint`.
fn encode_varint(&mut self, value: u64) -> Result<(), EncodeError> {
    let encoded = BufferUtils::encode_varint(value);
    self.write_bytes(&encoded)
}
/// Writes a signed integer as a zig-zag varint.
///
/// Zig-zag interleaves negative and non-negative values (0, -1, 1, -2, …) so
/// that numbers of small magnitude produce short varints.
fn encode_varint_signed(&mut self, value: i64) -> Result<(), EncodeError> {
    // Arithmetic shift: all ones for negative input, all zeros otherwise.
    let sign_mask = value >> 63;
    let doubled = value << 1;
    self.encode_varint((doubled ^ sign_mask) as u64)
}
/// Writes the UTF-8 bytes of `s`, prefixed with their varint byte length.
fn encode_string(&mut self, s: &str) -> Result<(), EncodeError> {
    // A string payload is framed exactly like a raw byte payload.
    self.encode_bytes(s.as_bytes())
}
/// Writes `bytes`, prefixed with their length as a varint.
fn encode_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
    let length_prefix = bytes.len() as u64;
    self.encode_varint(length_prefix)?;
    self.write_bytes(bytes)
}
/// Writes the eight-byte little-endian bit pattern of `value`.
fn encode_double(&mut self, value: f64) -> Result<(), EncodeError> {
    self.write_bytes(&value.to_bits().to_le_bytes())
}
/// Writes `value` as four little-endian bytes.
fn encode_fixed32(&mut self, value: u32) -> Result<(), EncodeError> {
    let le: [u8; 4] = value.to_le_bytes();
    self.write_bytes(&le)
}
/// Builds the nested payload for a decimal value and returns it as a byte
/// vector (nothing is written to the encoder's own buffer).
///
/// The payload holds two varint fields: field 1 carries the zig-zag encoded
/// mantissa and field 2 the zig-zag encoded scale. This function itself never
/// returns an error.
fn encode_decimal(&mut self, mantissa: i64, scale: i32) -> Result<Vec<u8>, EncodeError> {
    let zigzag = |v: i64| ((v << 1) ^ (v >> 63)) as u64;

    // Tag, value, tag, value — both tags use the varint wire type (0).
    let sequence: [u64; 4] = [
        1 << 3,
        zigzag(mantissa),
        2 << 3,
        zigzag(scale as i64),
    ];

    let mut payload = Vec::new();
    for &item in &sequence {
        payload.extend(BufferUtils::encode_varint(item));
    }
    Ok(payload)
}
/// Appends `bytes` to the output buffer.
///
/// # Errors
///
/// An I/O error from the writer is reported as
/// `EncodeError::InvalidFieldValue` with field id 0 and the I/O error text in
/// the reason.
fn write_bytes(&mut self, bytes: &[u8]) -> Result<(), EncodeError> {
    use std::io::Write;

    match self.writer.write_all(bytes) {
        Ok(()) => Ok(()),
        Err(io_err) => Err(EncodeError::InvalidFieldValue {
            field_id: 0,
            reason: format!("Write error: {}", io_err),
        }),
    }
}
/// Checks that the output buffer can hold `needed` more bytes.
///
/// The sum of the current length and `needed` is computed with saturating
/// arithmetic, so an oversized request is reported as `BufferTooSmall` rather
/// than panicking in debug builds or wrapping past the check in release
/// builds.
///
/// NOTE(review): despite its name this only *checks* capacity and never grows
/// the buffer, so with the 8192-byte adaptive default any message estimated
/// above that size is rejected even though `max_message_size` allows it —
/// confirm whether growing the writer was intended.
///
/// # Errors
///
/// Returns `EncodeError::BufferTooSmall` when the required size exceeds the
/// buffer's capacity.
fn ensure_capacity(&mut self, needed: usize) -> Result<(), EncodeError> {
    let available = self.writer.buffer().capacity();
    let required = self.writer.buffer().len().saturating_add(needed);
    if required > available {
        return Err(EncodeError::BufferTooSmall {
            needed: required,
            available,
        });
    }
    Ok(())
}
/// Translates a FIX tag into the field number used on the wire.
///
/// Tags present in the lookup table use their mapped number; every other tag
/// is shifted up by 100.
///
/// NOTE(review): the fallback can collide with reserved numbers — an unmapped
/// tag 899 becomes 999, the field the checksum is written under. Verify that
/// such tags cannot occur.
fn map_fix_tag_to_gpb_field(&self, fix_tag: u32) -> u32 {
    let fallback = fix_tag + 100;
    match self.field_mappings.get(&fix_tag) {
        Some(&mapped) => mapped,
        None => fallback,
    }
}
/// Builds the built-in table mapping FIX tag numbers to wire field numbers.
///
/// The trailing comments give the conventional FIX field name for each tag
/// (per the FIX field dictionary — verify against the FIX version in use).
fn create_field_mappings() -> HashMap<u32, u32> {
    const TAG_TO_FIELD: [(u32, u32); 19] = [
        // Session-level tags -> fields 10..=16.
        (8, 10),   // BeginString
        (9, 11),   // BodyLength
        (35, 12),  // MsgType
        (34, 13),  // MsgSeqNum
        (49, 14),  // SenderCompID
        (56, 15),  // TargetCompID
        (52, 16),  // SendingTime
        // Order tags -> fields 20..=25.
        (55, 20),  // Symbol
        (44, 21),  // Price
        (38, 22),  // OrderQty
        (54, 23),  // Side
        (40, 24),  // OrdType
        (59, 25),  // TimeInForce
        // Execution tags -> fields 30..=35.
        (37, 30),  // OrderID
        (17, 31),  // ExecID
        (150, 32), // ExecType
        (39, 33),  // OrdStatus
        (32, 34),  // LastQty
        (31, 35),  // LastPx
    ];
    TAG_TO_FIELD.iter().copied().collect()
}
/// Returns a snapshot of the encoder's buffer usage and mapping-table size.
pub fn stats(&self) -> EncoderStats {
    let buffer_capacity = self.writer.buffer().capacity();
    let buffer_used = self.writer.buffer().len();
    let field_mappings_count = self.field_mappings.len();
    EncoderStats {
        buffer_capacity,
        buffer_used,
        field_mappings_count,
    }
}
}
impl Default for GpbEncoder {
fn default() -> Self {
Self::new()
}
}
/// Point-in-time figures reported by `GpbEncoder::stats`.
#[derive(Debug, Clone)]
pub struct EncoderStats {
/// Capacity of the encoder's output buffer.
pub buffer_capacity: usize,
/// Length of the data currently held in the output buffer.
pub buffer_used: usize,
/// Number of entries in the FIX-tag lookup table.
pub field_mappings_count: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MessageType;

    /// A fresh encoder has an allocated but empty buffer and a populated
    /// tag table.
    #[test]
    fn test_encoder_creation() {
        let snapshot = GpbEncoder::new().stats();
        assert!(snapshot.buffer_capacity > 0);
        assert_eq!(snapshot.buffer_used, 0);
        assert!(snapshot.field_mappings_count > 0);
    }

    /// A well-formed new-order message encodes to a non-trivial payload.
    #[test]
    fn test_encode_new_order_single() {
        let order =
            FixMessage::new_order_single("BTCUSD".to_string(), 50000.0, 1.5, "1".to_string());
        let mut encoder = GpbEncoder::new();
        let bytes = encoder.encode(&order).unwrap();
        assert!(!bytes.is_empty());
        assert!(bytes.len() > 10);
    }

    /// With validation on, an empty order is rejected and a complete one
    /// is accepted.
    #[test]
    fn test_encode_with_validation() {
        let config = EncodeConfig {
            validate_messages: true,
            ..Default::default()
        };
        let mut encoder = GpbEncoder::with_config(config);

        let incomplete = FixMessage::new(MessageType::NewOrderSingle);
        assert!(encoder.encode(&incomplete).is_err());

        let complete =
            FixMessage::new_order_single("ETHUSD".to_string(), 3000.0, 2.0, "2".to_string());
        assert!(encoder.encode(&complete).is_ok());
    }

    /// Two orders encoded as one batch yield a payload larger than 20 bytes.
    #[test]
    fn test_encode_batch() {
        let config = EncodeConfig {
            buffer_strategy: BufferStrategy::Batch,
            ..Default::default()
        };
        let mut encoder = GpbEncoder::with_config(config);
        let orders = vec![
            FixMessage::new_order_single("BTC".to_string(), 50000.0, 1.0, "1".to_string()),
            FixMessage::new_order_single("ETH".to_string(), 3000.0, 2.0, "2".to_string()),
        ];
        let bytes = encoder.encode_batch(&orders).unwrap();
        assert!(!bytes.is_empty());
        assert!(bytes.len() > 20);
    }

    /// A message carrying several field-value variants encodes without error.
    #[test]
    fn test_field_value_encoding() {
        let mut heartbeat = FixMessage::new(MessageType::Heartbeat);
        heartbeat.set_field(1, FieldValue::String("test".to_string()));
        heartbeat.set_field(2, FieldValue::Int(-123));
        heartbeat.set_field(3, FieldValue::UInt(456));
        heartbeat.set_field(4, FieldValue::Float(123.45));
        heartbeat.set_field(5, FieldValue::Bool(true));
        heartbeat.set_field(
            6,
            FieldValue::Decimal {
                mantissa: 12345,
                scale: 2,
            },
        );

        let mut encoder = GpbEncoder::new();
        let bytes = encoder.encode(&heartbeat).unwrap();
        assert!(!bytes.is_empty());
    }

    /// Each buffer strategy yields at least the capacity it promises.
    #[test]
    fn test_buffer_strategies() {
        let capacity_for = |strategy| {
            GpbEncoder::with_config(EncodeConfig {
                buffer_strategy: strategy,
                ..Default::default()
            })
            .writer
            .buffer()
            .capacity()
        };

        assert!(capacity_for(BufferStrategy::Fixed(1024)) >= 1024);
        assert!(capacity_for(BufferStrategy::Adaptive) > 0);
        assert!(capacity_for(BufferStrategy::Batch) >= 64 * 1024);
    }
}