pub mod binary;
pub mod compression;
pub mod framing;
pub mod json;
pub mod message;
pub use binary::{BinaryCodec, BinaryMessage, GeospatialBinaryProtocol};
pub use compression::{CompressionCodec, CompressionLevel, CompressionType};
pub use framing::{Frame, FrameCodec, FrameHeader, FrameType};
pub use json::{JsonCodec, JsonMessage};
pub use message::{Message, MessageType, Payload};
use crate::error::{Error, Result};
use bytes::{Bytes, BytesMut};
use serde::{Deserialize, Serialize};
/// Protocol version number; `ProtocolCodec::encode` passes it to `Frame::new`
/// whenever framing is enabled.
pub const PROTOCOL_VERSION: u8 = 1;
/// Default payload size limit in bytes (16 MiB), used by `ProtocolConfig::default`.
pub const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Serialization format `ProtocolCodec` uses to turn a [`Message`] into bytes and back.
// NOTE(review): variant order is left untouched on purpose — serde formats that
// encode unit variants by index would change on-the-wire values if it moved.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageFormat {
    /// JSON, via `serde_json::to_vec` / `serde_json::from_slice`.
    Json,
    /// The crate's own [`BinaryCodec`] encoding.
    Binary,
    /// MessagePack, via `rmp_serde::to_vec` / `rmp_serde::from_slice`.
    MessagePack,
}
/// Settings that control how a `ProtocolCodec` encodes and decodes messages.
#[derive(Debug, Clone)]
pub struct ProtocolConfig {
    /// Serialization format for the message body.
    pub format: MessageFormat,
    /// Compression algorithm applied after serialization; `None` disables compression.
    pub compression: Option<CompressionType>,
    /// Level handed to `CompressionCodec::new`; only used when `compression` is `Some`.
    pub compression_level: CompressionLevel,
    /// When true, the (possibly compressed) payload is wrapped in a [`Frame`].
    pub enable_framing: bool,
    /// Maximum payload size in bytes, measured after serialization and compression
    /// but before framing; `encode` returns an error when it is exceeded.
    pub max_message_size: usize,
}
impl Default for ProtocolConfig {
fn default() -> Self {
Self {
format: MessageFormat::Binary,
compression: Some(CompressionType::Zstd),
compression_level: CompressionLevel::Default,
enable_framing: true,
max_message_size: MAX_MESSAGE_SIZE,
}
}
}
/// Converts [`Message`]s to bytes and back according to a [`ProtocolConfig`].
///
/// Encoding runs: serialize (per `config.format`) -> optional compression ->
/// optional framing. Decoding runs the same steps in reverse.
pub struct ProtocolCodec {
    /// Settings the codec was constructed with.
    config: ProtocolConfig,
    /// Built from `config.compression`; `None` when compression is disabled.
    compression_codec: Option<CompressionCodec>,
    /// Used to wrap/unwrap payloads when `config.enable_framing` is true.
    frame_codec: FrameCodec,
}
impl ProtocolCodec {
pub fn new(config: ProtocolConfig) -> Self {
let compression_codec = config
.compression
.map(|ct| CompressionCodec::new(ct, config.compression_level));
Self {
config,
compression_codec,
frame_codec: FrameCodec::new(),
}
}
pub fn encode(&self, message: &Message) -> Result<Bytes> {
let mut data = match self.config.format {
MessageFormat::Json => {
let json = serde_json::to_vec(message)?;
BytesMut::from(&json[..])
}
MessageFormat::Binary => BinaryCodec::encode(message)?,
MessageFormat::MessagePack => {
let msgpack = rmp_serde::to_vec(message)?;
BytesMut::from(&msgpack[..])
}
};
if let Some(ref codec) = self.compression_codec {
data = codec.compress(&data)?;
}
if data.len() > self.config.max_message_size {
return Err(Error::Protocol(format!(
"Message size {} exceeds maximum {}",
data.len(),
self.config.max_message_size
)));
}
if self.config.enable_framing {
let frame = Frame::new(
FrameType::Data,
PROTOCOL_VERSION,
self.compression_codec.is_some(),
data.freeze(),
);
self.frame_codec.encode(&frame)
} else {
Ok(data.freeze())
}
}
pub fn decode(&self, data: &[u8]) -> Result<Message> {
let payload = if self.config.enable_framing {
let frame = self.frame_codec.decode(data)?;
frame.payload
} else {
Bytes::copy_from_slice(data)
};
let decompressed = if let Some(ref codec) = self.compression_codec {
codec.decompress(&payload)?
} else {
payload
};
match self.config.format {
MessageFormat::Json => {
let message: Message = serde_json::from_slice(&decompressed)?;
Ok(message)
}
MessageFormat::Binary => BinaryCodec::decode(&decompressed),
MessageFormat::MessagePack => {
let message: Message = rmp_serde::from_slice(&decompressed)?;
Ok(message)
}
}
}
pub fn config(&self) -> &ProtocolConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a codec from `config`, round-trips a ping through `encode`/`decode`,
    /// and checks that the message type survives.
    fn assert_ping_roundtrip(config: ProtocolConfig) -> Result<()> {
        let codec = ProtocolCodec::new(config);
        let original = Message::ping();
        let wire = codec.encode(&original)?;
        let restored = codec.decode(&wire)?;
        assert_eq!(original.message_type(), restored.message_type());
        Ok(())
    }

    #[test]
    fn test_protocol_codec_json() -> Result<()> {
        assert_ping_roundtrip(ProtocolConfig {
            format: MessageFormat::Json,
            compression: None,
            enable_framing: false,
            ..Default::default()
        })
    }

    #[test]
    fn test_protocol_codec_binary() -> Result<()> {
        assert_ping_roundtrip(ProtocolConfig {
            format: MessageFormat::Binary,
            compression: None,
            enable_framing: false,
            ..Default::default()
        })
    }

    #[test]
    fn test_protocol_codec_with_compression() -> Result<()> {
        assert_ping_roundtrip(ProtocolConfig {
            format: MessageFormat::Binary,
            compression: Some(CompressionType::Zstd),
            compression_level: CompressionLevel::Fast,
            enable_framing: true,
            ..Default::default()
        })
    }
}