use super::command::HeartbeatFrame;
use super::command::ZmqCommand;
use super::greeting::ZmqGreeting;
use super::Message;
use crate::error::CodecError;
use crate::ZmqMessage;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::convert::TryFrom;
/// Flag bits parsed from the first byte of a ZMTP frame header.
#[derive(Debug, Clone, Copy)]
struct Frame {
    /// Bit 2 (`0b100`): this is a command frame rather than a message part.
    command: bool,
    /// Bit 1 (`0b010`): the length prefix is 8 bytes (read with `get_u64`)
    /// instead of a single byte.
    long: bool,
    /// Bit 0 (`0b001`): more parts of the same multipart message follow.
    more: bool,
}
/// Position of the decoder within the inbound byte stream.
#[derive(Debug)]
enum DecoderState {
    /// Initial state: waiting for the 64-byte greeting.
    Greeting,
    /// Waiting for the single flags byte that starts every frame.
    FrameHeader,
    /// Flags parsed; waiting for the 1- or 8-byte length prefix.
    FrameLen(Frame),
    /// Body length known (held in `ZmqCodec::waiting_for`); waiting for the body.
    Frame(Frame),
}
/// ZMTP codec: turns an inbound byte stream into [`Message`]s (greeting,
/// commands, heartbeats, security frames, multipart data) and serializes
/// [`Message`]s back into bytes.
#[derive(Debug)]
pub struct ZmqCodec {
    /// Which part of the stream the decoder expects next.
    state: DecoderState,
    waiting_for: usize, // Bytes that must be buffered before `state` can make progress
    // Parts of an incoming multipart message received so far.
    // Assembling multipart messages inside the codec keeps frame-level
    // details away from higher layers, which only see complete messages.
    buffered_message: Option<ZmqMessage>,
}
impl ZmqCodec {
pub fn new() -> Self {
Self {
state: DecoderState::Greeting,
waiting_for: 64, // len of the greeting frame
buffered_message: None,
}
}
/// Construct a codec already past the ZMTP greeting — the decoder
/// starts expecting a frame header byte. Useful in unit tests that
/// bypass the handshake.
#[cfg(all(test, feature = "tokio"))]
pub(crate) fn post_greeting() -> Self {
Self {
state: DecoderState::FrameHeader,
waiting_for: 1,
buffered_message: None,
}
}
}
impl Default for ZmqCodec {
fn default() -> Self {
Self::new()
}
}
impl ZmqCodec {
/// Core decoder. Exposed via the `tokio_util::codec::Decoder` impl at
/// the bottom of this file; kept as an inherent method so tests can
/// drive it directly without pulling in the trait.
pub(crate) fn decode_inner(
&mut self,
src: &mut BytesMut,
) -> Result<Option<Message>, CodecError> {
loop {
if src.len() < self.waiting_for {
src.reserve(self.waiting_for - src.len());
return Ok(None);
}
match self.state {
DecoderState::Greeting => {
if src[0] != 0xff {
return Err(CodecError::Decode("Bad first byte of greeting"));
}
self.state = DecoderState::FrameHeader;
self.waiting_for = 1;
return Ok(Some(Message::Greeting(ZmqGreeting::try_from(
src.split_to(64).freeze(),
)?)));
}
DecoderState::FrameHeader => {
let flags = src.get_u8();
let frame = Frame {
command: (flags & 0b0000_0100) != 0,
long: (flags & 0b0000_0010) != 0,
more: (flags & 0b0000_0001) != 0,
};
self.state = DecoderState::FrameLen(frame);
self.waiting_for = if frame.long { 8 } else { 1 };
}
DecoderState::FrameLen(frame) => {
self.state = DecoderState::Frame(frame);
self.waiting_for = if frame.long {
src.get_u64() as usize
} else {
src.get_u8() as usize
};
}
DecoderState::Frame(frame) => {
let data = src.split_to(self.waiting_for);
self.state = DecoderState::FrameHeader;
self.waiting_for = 1;
if frame.command {
let frozen = data.freeze();
// Intercept PING/PONG before the general ZmqCommand path.
// The heartbeat name-len byte is first, so we check
// what follows it.
if frozen.len() >= 2 {
let name_len = frozen[0] as usize;
if frozen.len() > name_len {
let name = &frozen[1..1 + name_len];
match name {
b"PING" | b"PONG" => {
return Ok(Some(Message::Heartbeat(
HeartbeatFrame::try_from(frozen)?,
)));
}
// Security handshake command frames — emitted as SecurityRaw
// so mechanism.rs can parse with mechanism context.
// NOTE: MESSAGE is NOT here — libzmq sends encrypted data
// as DATA frames (flag 0x00/0x01), not command frames.
// The body starts with \x07MESSAGE but the ZMTP command
// bit is clear; peer_loop intercepts via Message::Message.
b"HELLO" | b"WELCOME" | b"ERROR" | b"INITIATE" | b"READY" => {
return Ok(Some(Message::SecurityRaw(frozen)));
}
_ => {}
}
}
}
return Ok(Some(Message::Command(ZmqCommand::try_from(frozen)?)));
}
// process incoming message frame
match &mut self.buffered_message {
Some(v) => v.push_back(data.freeze()),
None => self.buffered_message = Some(ZmqMessage::from(data.freeze())),
}
if !frame.more {
// Quoth the Raven “Nevermore.” — multi-part message complete.
return Ok(Some(Message::Message(
self.buffered_message
.take()
.expect("Corrupted decoder state"),
)));
}
// More frames coming in this logical message; loop and parse the next frame header.
}
}
}
}
/// Core encoder. Exposed via the `tokio_util::codec::Encoder<Message>`
/// impl at the bottom of this file.
pub(crate) fn encode_inner(message: Message, dst: &mut BytesMut) -> Result<(), CodecError> {
match message {
Message::Greeting(payload) => dst.unsplit(payload.into()),
Message::Command(command) => dst.unsplit(command.into()),
Message::Message(message) => encode_zmq_message(&message, dst),
Message::Shared(message) => encode_zmq_message(message.as_ref(), dst),
Message::Heartbeat(hb) => {
let encoded: BytesMut = hb.into();
dst.unsplit(encoded);
}
Message::SecurityRaw(raw) => {
// Already a fully-framed command (flag + length + body);
// written verbatim as produced by PlainFrame/CurveFrame encoders.
dst.extend_from_slice(&raw);
}
}
Ok(())
}
}
// asynchronous-codec impls — used by the smol transport.
// Like the tokio_util impls, they simply delegate to the inherent
// `decode_inner`/`encode_inner`; only the trait paths (and the encoder's
// associated `Item<'a>` type) differ.
#[cfg(feature = "smol")]
impl asynchronous_codec::Decoder for ZmqCodec {
    type Item = Message;
    type Error = CodecError;

    /// Decode at most one message from `src`; see `ZmqCodec::decode_inner`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, CodecError> {
        ZmqCodec::decode_inner(self, src)
    }
}
#[cfg(feature = "smol")]
impl asynchronous_codec::Encoder for ZmqCodec {
    type Error = CodecError;
    type Item<'a> = Message;

    /// Append the wire form of `item` to `dst`; see `ZmqCodec::encode_inner`.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), CodecError> {
        ZmqCodec::encode_inner(item, dst)
    }
}
#[cfg(feature = "tokio")]
impl tokio_util::codec::Decoder for ZmqCodec {
    type Item = Message;
    type Error = CodecError;

    /// Decode at most one message from `src`; see `ZmqCodec::decode_inner`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Message>, CodecError> {
        ZmqCodec::decode_inner(self, src)
    }
}
/// Append every part of `message` to `dst` as ZMTP data frames.
///
/// Every part except the last carries the MORE flag. The output buffer is
/// grown once up front so multi-part encodes don't trigger repeated
/// reallocations. An empty message writes nothing.
fn encode_zmq_message(message: &crate::ZmqMessage, dst: &mut BytesMut) {
    // Each frame needs 2 bytes of overhead for len <= 255, or 9 bytes
    // otherwise (1 flags + 8 length) — mirrors the rule in `encode_frame`.
    let total: usize = message
        .iter()
        .map(|f| {
            let len = f.len();
            len + if len > 255 { 9 } else { 2 }
        })
        .sum();
    dst.reserve(total);
    let frame_count = message.len();
    for (idx, part) in message.iter().enumerate() {
        // `idx + 1 < frame_count` rather than `idx != len - 1`: the latter
        // underflows (and panics in debug builds) on an empty message.
        encode_frame(part, dst, idx + 1 < frame_count);
    }
}
fn encode_frame(frame: &Bytes, dst: &mut BytesMut, more: bool) {
let mut flags: u8 = 0;
if more {
flags |= 0b0000_0001;
}
let len = frame.len();
if len > 255 {
flags |= 0b0000_0010;
}
dst.put_u8(flags);
if len > 255 {
dst.put_u64(len as u64);
} else {
dst.put_u8(len as u8);
}
dst.extend_from_slice(frame.as_ref());
}
/// Build a ZMTP frame header (flags byte + length prefix) on the stack.
///
/// Returns a 9-byte buffer together with how many of its leading bytes are
/// meaningful:
/// - 2 for a short frame (`frame_len <= 255`): flags, then a 1-byte length.
/// - 9 for a long frame: flags with the LONG bit, then an 8-byte big-endian
///   length.
///
/// Any unused trailing bytes are zero. The vectored-write path sends this
/// header and the payload `Bytes` as separate `IoSlice`s, so the payload is
/// never copied.
pub(crate) fn encode_frame_header(frame_len: usize, more: bool) -> ([u8; 9], u8) {
    let more_bit: u8 = if more { 0b0000_0001 } else { 0 };
    let mut header = [0u8; 9];
    let used: u8 = if frame_len <= 255 {
        header[0] = more_bit;
        header[1] = frame_len as u8;
        2
    } else {
        header[0] = more_bit | 0b0000_0010;
        header[1..].copy_from_slice(&(frame_len as u64).to_be_bytes());
        9
    };
    (header, used)
}
#[cfg(feature = "tokio")]
impl tokio_util::codec::Encoder<Message> for ZmqCodec {
    type Error = CodecError;

    /// Append the wire form of `item` to `dst`; see `ZmqCodec::encode_inner`.
    fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), CodecError> {
        ZmqCodec::encode_inner(item, dst)
    }
}
#[cfg(all(test, feature = "tokio"))]
pub(crate) mod tests {
    use super::*;
    use tokio_util::codec::Decoder;

    /// Hex dump of a six-part multipart message: a `<IDS|MSG>` delimiter, a
    /// 64-byte signature, then four JSON parts (the content looks like a
    /// Jupyter `execute_reply`; only the framing matters here).
    const MESSAGE_1_HEX: &str = "01093c4944537c4d53473e01403239386166316563653932306635373637656132393438376261363164643436613534636334313262653032303339316139653831636535633234383039653001cb7b226d73675f6964223a2236356336396230312d636634622d343563322d616165612d323263306365326531316533222c2273657373696f6e223a2230326462356631642d386535632d346464612d383064342d303337363835343465616138222c22757365726e616d65223a223c544f444f3e222c2264617465223a22323032312d31322d32395430343a35393a33392e3539333533372b30303a3030222c226d73675f74797065223a22657865637574655f7265706c79222c2276657273696f6e223a22352e33227d01c07b226d73675f6964223a223965303336313036373262393433393961343432316539373330333330326162222c2273657373696f6e223a226231323139393364663235613432643839376135653163383362306337616665222c22757365726e616d65223a22757365726e616d65222c2264617465223a22313937302d30312d30315430303a30303a30302b30303a3030222c226d73675f74797065223a22657865637574655f72657175657374222c2276657273696f6e223a22352e32227d01027b7d00467b22737461747573223a226f6b222c22657865637574696f6e5f636f756e74223a312c227061796c6f6164223a5b5d2c22757365725f65787072657373696f6e73223a7b7d7d";

    /// A second capture with the same six-part framing but different payload bytes.
    const MESSAGE_2_HEX: &str = "01093c4944537c4d53473e01406139346435366530343438353335303831316561623063663730623464356366373933653431653838616330666339646263346562326238616136643635306601cb7b226d73675f6964223a2263383466623933372d333162662d346335622d386430392d386535633230633434333636222c2273657373696f6e223a2230326462356631642d386535632d346464612d383064342d303337363835343465616138222c22757365726e616d65223a223c544f444f3e222c2264617465223a22323032312d31322d32395430343a35393a34332e3037343831332b30303a3030222c226d73675f74797065223a22657865637574655f7265706c79222c2276657273696f6e223a22352e33227d01c07b226d73675f6964223a223238646635316334303933313433643339393131346664333439643530396634222c2273657373696f6e223a226231323139393364663235613432643839376135653163383362306337616665222c22757365726e616d65223a22757365726e616d65222c2264617465223a22313937302d30312d30315430303a30303a30302b30303a3030222c226d73675f74797065223a22657865637574655f72657175657374222c2276657273696f6e223a22352e32227d01027b7d00467b22737461747573223a226f6b222c22657865637574696f6e5f636f756e74223a322c227061796c6f6164223a5b5d2c22757365725f65787072657373696f6e73223a7b7d7d";

    /// Decode `hex_str` in one call using a codec that has already passed the
    /// greeting. Expects exactly one complete message; returns it together
    /// with whatever bytes the decoder left unconsumed.
    fn decode_single(hex_str: &str) -> (Message, BytesMut) {
        let raw = hex::decode(hex_str).expect("valid hex fixture");
        let mut bytes = BytesMut::from(raw.as_slice());
        let mut codec = ZmqCodec::post_greeting();
        let message = codec
            .decode(&mut bytes)
            .expect("decode success")
            .expect("single message");
        (message, bytes)
    }

    /// Number of parts in a decoded data message; panics on any other variant.
    fn part_count(message: Message) -> usize {
        match message {
            Message::Message(m) => m.into_vecdeque().len(),
            other => panic!("wrong message type: {:?}", other),
        }
    }

    /// `encode_frame_header` must emit byte-for-byte the same prefix that
    /// `encode_frame` writes into a `BytesMut` before the payload. The
    /// vectored-write engine relies on this equivalence to skip the payload
    /// memcpy.
    #[test]
    fn encode_frame_header_matches_bytesmut_encoder() {
        for &(len, more) in &[
            (0usize, false),
            (1, false),
            (1, true),
            (255, false),
            (255, true),
            (256, false),
            (256, true),
            (65_536, false),
            (1_048_576, true),
        ] {
            let payload = Bytes::from(vec![0xabu8; len]);
            let mut reference = BytesMut::new();
            encode_frame(&payload, &mut reference, more);
            let (buf, header_len) = encode_frame_header(len, more);
            let header_bytes = &buf[..header_len as usize];
            assert_eq!(
                header_bytes,
                &reference[..header_len as usize],
                "header mismatch at len={} more={}",
                len,
                more,
            );
            assert_eq!(
                &reference[header_len as usize..],
                payload.as_ref(),
                "payload tail mismatch at len={} more={}",
                len,
                more,
            );
        }
    }

    #[test]
    pub fn test_message_decode_1() {
        let (message, rest) = decode_single(MESSAGE_1_HEX);
        assert_eq!(6, part_count(message));
        assert_eq!(rest.len(), 0);
    }

    #[test]
    pub fn test_message_decode_2() {
        let (message, rest) = decode_single(MESSAGE_2_HEX);
        assert_eq!(rest.len(), 0);
        assert_eq!(6, part_count(message));
    }

    /// Feeding the stream one byte at a time must yield the same single
    /// six-part message, exercising every "need more input" path.
    #[test]
    fn test_message_decode_byte_at_a_time() {
        let raw = hex::decode(MESSAGE_1_HEX).expect("valid hex fixture");
        let mut codec = ZmqCodec::post_greeting();
        let mut bytes = BytesMut::new();
        let mut decoded = Vec::new();
        for &byte in &raw {
            bytes.extend_from_slice(&[byte]);
            if let Some(message) = codec.decode(&mut bytes).expect("decode success") {
                decoded.push(message);
            }
        }
        assert_eq!(decoded.len(), 1, "expected exactly one complete message");
        assert_eq!(6, part_count(decoded.pop().expect("one message")));
        assert!(bytes.is_empty());
    }
}