use crate::error::ZmqError;
use crate::message::{Msg, MsgFlags};
use crate::protocol::zmtp::command::*;
use bytes::{Buf, BufMut, BytesMut};
use tokio_util::codec::{Decoder, Encoder};
/// Stateful ZMTP frame codec used with tokio-util's `Encoder`/`Decoder` traits.
///
/// Decoding is incremental: the parser position is kept in `decoding_state`, so a
/// frame that arrives in several reads is resumed on the next `decode` call.
#[derive(Default, Debug)]
pub struct ZmtpCodec {
    /// Current position of the incremental frame parser (header vs. body).
    decoding_state: DecodingState,
    /// Optional bytes that `decode` prepends to its input on its next call.
    prefix_bytes: Option<BytesMut>,
}
/// Which part of a ZMTP frame the decoder expects next.
#[derive(Clone, Copy, Debug, Default)]
enum DecodingState {
    /// Waiting for the flags byte and size field of a new frame.
    #[default]
    ReadHeader,
    /// Header parsed; waiting for `size` bytes of frame body.
    ReadBody(FrameHeader),
}
/// Parsed ZMTP frame header: the raw flags byte and the decoded body length.
#[derive(Clone, Copy, Debug)]
struct FrameHeader {
    /// Raw ZMTP flags byte exactly as received.
    flags: u8,
    /// Number of body bytes that follow the header.
    size: usize,
}
impl ZmtpCodec {
    /// Creates a codec that expects a frame header next and holds no prefix bytes.
    ///
    /// Delegates to the derived `Default` so the two constructors cannot drift apart.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the stored prefix bytes, if any, without consuming them.
    pub fn prefix_bytes(&self) -> Option<&BytesMut> {
        self.prefix_bytes.as_ref()
    }

    /// Removes and returns the stored prefix bytes, leaving none behind.
    pub fn take_prefix_bytes(&mut self) -> Option<BytesMut> {
        self.prefix_bytes.take()
    }

    /// Replaces the stored prefix bytes unconditionally.
    ///
    /// Unlike [`ZmtpCodec::prime_with_prefix`], this also accepts `None` or an empty
    /// buffer (an empty buffer is discarded by the next `decode` call).
    pub fn set_prefix_bytes(&mut self, prefix_bytes: Option<BytesMut>) {
        self.prefix_bytes = prefix_bytes;
    }

    /// Stores `prefix` so the next `decode` call prepends it to its input.
    ///
    /// An empty `prefix` is ignored and leaves any previously stored prefix untouched.
    pub fn prime_with_prefix(&mut self, prefix: BytesMut) {
        if prefix.is_empty() {
            return;
        }
        tracing::trace!(prefix_len = prefix.len(), "ZmtpCodec primed with prefix bytes");
        self.prefix_bytes = Some(prefix);
    }

    /// Writes only the ZMTP frame header for `item` (flags byte + size field) into `dst`.
    ///
    /// The size field is taken from `item.size()` and the body is NOT written.
    /// Presumably the caller sends the body separately; confirm at call sites.
    /// Sizes up to 255 use the 1-byte short form. Larger sizes set `ZMTP_FLAG_LONG`
    /// and use an 8-byte big-endian length.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is part of the public signature.
    pub fn encode_header_only(&self, item: &Msg, dst: &mut BytesMut) -> Result<(), ZmqError> {
        let data_size = item.size();
        let msg_flags = item.flags();

        let mut zmtp_flags_byte = 0u8;
        if msg_flags.contains(MsgFlags::MORE) {
            zmtp_flags_byte |= ZMTP_FLAG_MORE;
        }
        if msg_flags.contains(MsgFlags::COMMAND) {
            zmtp_flags_byte |= ZMTP_FLAG_COMMAND;
        }

        if data_size <= 255 {
            // Short frame: flags byte + 1-byte size.
            dst.reserve(1 + 1);
            dst.put_u8(zmtp_flags_byte);
            dst.put_u8(data_size as u8);
        } else {
            // Long frame: flags byte (with LONG set) + 8-byte size. The usize -> u64
            // cast is lossless on every supported (<= 64-bit) target.
            zmtp_flags_byte |= ZMTP_FLAG_LONG;
            dst.reserve(1 + 8);
            dst.put_u8(zmtp_flags_byte);
            dst.put_u64(data_size as u64);
        }
        Ok(())
    }
}
impl Encoder<Msg> for ZmtpCodec {
    type Error = ZmqError;

    /// Serializes `item` as one complete ZMTP frame: flags byte, size field, then body.
    ///
    /// A message without data is encoded as an empty body. Bodies that fit in a `u8`
    /// use the 1-byte short size. Larger bodies set `ZMTP_FLAG_LONG` and use an
    /// 8-byte big-endian size. Space for the whole frame is reserved before writing.
    fn encode(&mut self, item: Msg, dst: &mut BytesMut) -> Result<(), Self::Error> {
        let body = item.data().unwrap_or(&[]);
        let body_len = body.len();
        let flags = item.flags();

        let header_flags = {
            let mut bits = 0u8;
            if flags.contains(MsgFlags::MORE) {
                bits |= ZMTP_FLAG_MORE;
            }
            if flags.contains(MsgFlags::COMMAND) {
                bits |= ZMTP_FLAG_COMMAND;
            }
            bits
        };

        match u8::try_from(body_len) {
            Ok(short_len) => {
                dst.reserve(2 + body_len);
                dst.put_u8(header_flags);
                dst.put_u8(short_len);
            }
            Err(_) => {
                dst.reserve(9 + body_len);
                dst.put_u8(header_flags | ZMTP_FLAG_LONG);
                dst.put_u64(body_len as u64);
            }
        }
        dst.extend_from_slice(body);
        Ok(())
    }
}
impl Decoder for ZmtpCodec {
type Item = Msg; type Error = ZmqError;
fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
if let Some(prefix) = self.prefix_bytes.take() {
if !prefix.is_empty() {
tracing::trace!(
prefix_len = prefix.len(),
src_len_before = src.len(),
"ZmtpCodec::decode: Prepending stored prefix bytes to src buffer"
);
src.reserve(prefix.len());
let original_src_content = src.split(); src.put(prefix); src.put(original_src_content); tracing::trace!(
src_len_after = src.len(),
"ZmtpCodec::decode: Finished prepending prefix."
);
}
}
loop {
match self.decoding_state {
DecodingState::ReadHeader => {
if src.is_empty() {
return Ok(None); }
let flags = src[0]; let is_long = (flags & ZMTP_FLAG_LONG) != 0;
let header_len = if is_long { 1 + 8 } else { 1 + 1 };
if src.len() < header_len {
src.reserve(header_len - src.len()); return Ok(None); }
let header_bytes = src.split_to(header_len);
let flags = header_bytes[0];
let size = if is_long {
let mut len_bytes = &header_bytes[1..]; len_bytes.get_u64() as usize
} else {
header_bytes[1] as usize
};
let header = FrameHeader { flags, size };
self.decoding_state = DecodingState::ReadBody(header);
}
DecodingState::ReadBody(header) => {
if src.len() < header.size {
src.reserve(header.size - src.len()); return Ok(None);
}
let body_bytes = src.split_to(header.size).freeze();
self.decoding_state = DecodingState::ReadHeader;
let mut msg = Msg::from_bytes(body_bytes);
let mut msg_flags = MsgFlags::empty();
if (header.flags & ZMTP_FLAG_MORE) != 0 {
msg_flags |= MsgFlags::MORE;
}
if (header.flags & ZMTP_FLAG_COMMAND) != 0 {
msg_flags |= MsgFlags::COMMAND;
}
msg.set_flags(msg_flags);
return Ok(Some(msg));
}
} } } }