rustzmq2 0.1.0

A native async Rust implementation of ZeroMQ
Documentation
use super::command::HeartbeatFrame;
use super::command::ZmqCommand;
use super::greeting::ZmqGreeting;
use super::Message;
use crate::error::CodecError;
use crate::ZmqMessage;

use bytes::{Buf, BufMut, Bytes, BytesMut};

use std::convert::TryFrom;

/// Flags decoded from the first (flags) byte of a ZMTP frame header.
#[derive(Debug, Clone, Copy)]
struct Frame {
    /// Bit 2 (`0b0000_0100`): the frame is a command, not message data.
    command: bool,
    /// Bit 1 (`0b0000_0010`): the length prefix is 8 bytes (read with
    /// `Buf::get_u64`) instead of a single byte.
    long: bool,
    /// Bit 0 (`0b0000_0001`): further frames of the same multipart message
    /// follow this one.
    more: bool,
}

/// Where the decoder currently is in the incoming byte stream. Each state
/// makes progress only once `ZmqCodec::waiting_for` bytes are buffered.
#[derive(Debug)]
enum DecoderState {
    /// Waiting for the 64-byte greeting (the initial state of `ZmqCodec::new`).
    Greeting,
    /// Waiting for the one-byte flags field of the next frame.
    FrameHeader,
    /// Flags parsed; waiting for the 1-byte or 8-byte length prefix.
    FrameLen(Frame),
    /// Length parsed; waiting for that many bytes of frame body.
    Frame(Frame),
}

/// ZMTP codec: turns raw bytes into [`Message`]s (greeting, commands,
/// heartbeats, security frames, multipart data) and encodes them back.
#[derive(Debug)]
pub struct ZmqCodec {
    /// Current position of the decoder state machine.
    state: DecoderState,
    /// Number of buffered bytes the current state needs before it can make
    /// progress (64 for the greeting, 1 for a flags byte, 1 or 8 for a length
    /// prefix, or the frame length for a frame body).
    waiting_for: usize,
    // Holds the parts of an incoming multipart message until its final frame
    // (MORE flag clear) arrives. Keeping this inside the codec means higher
    // layers only ever see complete messages, not individual frames.
    buffered_message: Option<ZmqMessage>,
}

impl ZmqCodec {
    pub fn new() -> Self {
        Self {
            state: DecoderState::Greeting,
            waiting_for: 64, // len of the greeting frame
            buffered_message: None,
        }
    }

    /// Construct a codec already past the ZMTP greeting — the decoder
    /// starts expecting a frame header byte. Useful in unit tests that
    /// bypass the handshake.
    #[cfg(all(test, feature = "tokio"))]
    pub(crate) fn post_greeting() -> Self {
        Self {
            state: DecoderState::FrameHeader,
            waiting_for: 1,
            buffered_message: None,
        }
    }
}

impl Default for ZmqCodec {
    fn default() -> Self {
        Self::new()
    }
}

impl ZmqCodec {
    /// Core decoder. Exposed via the `tokio_util::codec::Decoder` impl at
    /// the bottom of this file; kept as an inherent method so tests can
    /// drive it directly without pulling in the trait.
    pub(crate) fn decode_inner(
        &mut self,
        src: &mut BytesMut,
    ) -> Result<Option<Message>, CodecError> {
        loop {
            if src.len() < self.waiting_for {
                src.reserve(self.waiting_for - src.len());
                return Ok(None);
            }
            match self.state {
                DecoderState::Greeting => {
                    if src[0] != 0xff {
                        return Err(CodecError::Decode("Bad first byte of greeting"));
                    }
                    self.state = DecoderState::FrameHeader;
                    self.waiting_for = 1;
                    return Ok(Some(Message::Greeting(ZmqGreeting::try_from(
                        src.split_to(64).freeze(),
                    )?)));
                }
                DecoderState::FrameHeader => {
                    let flags = src.get_u8();

                    let frame = Frame {
                        command: (flags & 0b0000_0100) != 0,
                        long: (flags & 0b0000_0010) != 0,
                        more: (flags & 0b0000_0001) != 0,
                    };
                    self.state = DecoderState::FrameLen(frame);
                    self.waiting_for = if frame.long { 8 } else { 1 };
                }
                DecoderState::FrameLen(frame) => {
                    self.state = DecoderState::Frame(frame);
                    self.waiting_for = if frame.long {
                        src.get_u64() as usize
                    } else {
                        src.get_u8() as usize
                    };
                }
                DecoderState::Frame(frame) => {
                    let data = src.split_to(self.waiting_for);
                    self.state = DecoderState::FrameHeader;
                    self.waiting_for = 1;
                    if frame.command {
                        let frozen = data.freeze();
                        // Intercept PING/PONG before the general ZmqCommand path.
                        // The heartbeat name-len byte is first, so we check
                        // what follows it.
                        if frozen.len() >= 2 {
                            let name_len = frozen[0] as usize;
                            if frozen.len() > name_len {
                                let name = &frozen[1..1 + name_len];
                                match name {
                                    b"PING" | b"PONG" => {
                                        return Ok(Some(Message::Heartbeat(
                                            HeartbeatFrame::try_from(frozen)?,
                                        )));
                                    }
                                    // Security handshake command frames — emitted as SecurityRaw
                                    // so mechanism.rs can parse with mechanism context.
                                    // NOTE: MESSAGE is NOT here — libzmq sends encrypted data
                                    // as DATA frames (flag 0x00/0x01), not command frames.
                                    // The body starts with \x07MESSAGE but the ZMTP command
                                    // bit is clear; peer_loop intercepts via Message::Message.
                                    b"HELLO" | b"WELCOME" | b"ERROR" | b"INITIATE" | b"READY" => {
                                        return Ok(Some(Message::SecurityRaw(frozen)));
                                    }
                                    _ => {}
                                }
                            }
                        }
                        return Ok(Some(Message::Command(ZmqCommand::try_from(frozen)?)));
                    }

                    // process incoming message frame
                    match &mut self.buffered_message {
                        Some(v) => v.push_back(data.freeze()),
                        None => self.buffered_message = Some(ZmqMessage::from(data.freeze())),
                    }

                    if !frame.more {
                        // Quoth the Raven “Nevermore.” — multi-part message complete.
                        return Ok(Some(Message::Message(
                            self.buffered_message
                                .take()
                                .expect("Corrupted decoder state"),
                        )));
                    }
                    // More frames coming in this logical message; loop and parse the next frame header.
                }
            }
        }
    }

    /// Core encoder. Exposed via the `tokio_util::codec::Encoder<Message>`
    /// impl at the bottom of this file.
    pub(crate) fn encode_inner(message: Message, dst: &mut BytesMut) -> Result<(), CodecError> {
        match message {
            Message::Greeting(payload) => dst.unsplit(payload.into()),
            Message::Command(command) => dst.unsplit(command.into()),
            Message::Message(message) => encode_zmq_message(&message, dst),
            Message::Shared(message) => encode_zmq_message(message.as_ref(), dst),
            Message::Heartbeat(hb) => {
                let encoded: BytesMut = hb.into();
                dst.unsplit(encoded);
            }
            Message::SecurityRaw(raw) => {
                // Already a fully-framed command (flag + length + body);
                // written verbatim as produced by PlainFrame/CurveFrame encoders.
                dst.extend_from_slice(&raw);
            }
        }
        Ok(())
    }
}

// `asynchronous_codec` adapters, used by the smol transport. They forward to
// the same inherent methods as the tokio_util impls; only the trait paths
// (and the shape of the encoder's associated `Item` type) differ.
#[cfg(feature = "smol")]
impl asynchronous_codec::Decoder for ZmqCodec {
    type Item = Message;
    type Error = CodecError;

    /// Delegates to [`ZmqCodec::decode_inner`].
    fn decode(&mut self, buf: &mut bytes::BytesMut) -> Result<Option<Message>, CodecError> {
        ZmqCodec::decode_inner(self, buf)
    }
}

#[cfg(feature = "smol")]
impl asynchronous_codec::Encoder for ZmqCodec {
    type Error = CodecError;
    type Item<'a> = Message;

    /// Delegates to [`ZmqCodec::encode_inner`]; the codec holds no encoder state.
    fn encode(&mut self, item: Message, buf: &mut bytes::BytesMut) -> Result<(), CodecError> {
        ZmqCodec::encode_inner(item, buf)
    }
}

#[cfg(feature = "tokio")]
impl tokio_util::codec::Decoder for ZmqCodec {
    type Item = Message;
    type Error = CodecError;

    /// Delegates to [`ZmqCodec::decode_inner`].
    fn decode(&mut self, buf: &mut BytesMut) -> Result<Option<Message>, CodecError> {
        ZmqCodec::decode_inner(self, buf)
    }
}

/// Appends every part of `message` to `dst` as ZMTP data frames. Every frame
/// except the last has the MORE flag set. An empty message writes nothing
/// (apart from a no-op reserve).
fn encode_zmq_message(message: &crate::ZmqMessage, dst: &mut BytesMut) {
    // Pre-reserve the whole message once so multi-frame encodes don't cause N
    // incremental grows. Each frame needs 2 bytes of overhead for len <= 255,
    // or 9 bytes otherwise (1 flags + 8 length).
    let total: usize = message
        .iter()
        .map(|f| {
            let len = f.len();
            len + if len > 255 { 9 } else { 2 }
        })
        .sum();
    dst.reserve(total);

    // Peek ahead to decide the MORE flag. Comparing against `len() - 1`
    // would underflow (panic in debug builds) on an empty message.
    let mut parts = message.iter().peekable();
    while let Some(part) = parts.next() {
        encode_frame(part, dst, parts.peek().is_some());
    }
}

fn encode_frame(frame: &Bytes, dst: &mut BytesMut, more: bool) {
    let mut flags: u8 = 0;
    if more {
        flags |= 0b0000_0001;
    }
    let len = frame.len();
    if len > 255 {
        flags |= 0b0000_0010;
    }
    dst.put_u8(flags);
    if len > 255 {
        dst.put_u64(len as u64);
    } else {
        dst.put_u8(len as u8);
    }
    dst.extend_from_slice(frame.as_ref());
}

/// Builds a ZMTP frame header (flags byte followed by the length prefix) in a
/// stack buffer, without touching the heap.
///
/// Returns the 9-byte buffer together with how many leading bytes are
/// meaningful. That is 2 for frames of at most 255 bytes (1-byte length) and
/// 9 for longer frames (LONG flag plus an 8-byte big-endian length).
///
/// Used by the vectored-write path, which writes this header and the payload
/// `Bytes` as separate `IoSlice`s so the payload is never copied.
pub(crate) fn encode_frame_header(frame_len: usize, more: bool) -> ([u8; 9], u8) {
    const MORE_FLAG: u8 = 0b0000_0001;
    const LONG_FLAG: u8 = 0b0000_0010;

    let is_long = frame_len > 255;
    let mut header = [0u8; 9];
    header[0] = (if more { MORE_FLAG } else { 0 }) | (if is_long { LONG_FLAG } else { 0 });

    let used: u8 = if is_long {
        header[1..].copy_from_slice(&(frame_len as u64).to_be_bytes());
        9
    } else {
        header[1] = frame_len as u8;
        2
    };
    (header, used)
}

#[cfg(feature = "tokio")]
impl tokio_util::codec::Encoder<Message> for ZmqCodec {
    type Error = CodecError;

    /// Delegates to [`ZmqCodec::encode_inner`]; the codec holds no encoder state.
    fn encode(&mut self, item: Message, buf: &mut BytesMut) -> Result<(), CodecError> {
        ZmqCodec::encode_inner(item, buf)
    }
}

#[cfg(all(test, feature = "tokio"))]
pub(crate) mod tests {
    use super::*;
    use tokio_util::codec::Decoder;

    /// `encode_frame_header` must emit byte-for-byte the same prefix that
    /// `encode_frame` writes into a `BytesMut` before the payload. The
    /// vectored-write engine relies on this equivalence to skip the payload
    /// memcpy.
    #[test]
    fn encode_frame_header_matches_bytesmut_encoder() {
        for &(len, more) in &[
            (0usize, false),
            (1, false),
            (1, true),
            (255, false),
            (255, true),
            (256, false),
            (256, true),
            (65_536, false),
            (1_048_576, true),
        ] {
            let payload = Bytes::from(vec![0xabu8; len]);
            let mut reference = BytesMut::new();
            encode_frame(&payload, &mut reference, more);

            let (buf, header_len) = encode_frame_header(len, more);
            let header_bytes = &buf[..header_len as usize];
            assert_eq!(
                header_bytes,
                &reference[..header_len as usize],
                "header mismatch at len={} more={}",
                len,
                more,
            );
            assert_eq!(
                &reference[header_len as usize..],
                payload.as_ref(),
                "payload tail mismatch at len={} more={}",
                len,
                more,
            );
        }
    }

    /// Decodes the hex-encoded frames in `hex_frames` with a codec that has
    /// already passed the greeting. Asserts that they yield exactly one data
    /// message with `expected_parts` parts and leave no trailing bytes.
    fn assert_decodes_single_message(hex_frames: &str, expected_parts: usize) {
        let raw = hex::decode(hex_frames).expect("valid hex fixture");
        let mut bytes = BytesMut::from(raw.as_slice());
        let mut codec = ZmqCodec::post_greeting();

        let message = codec
            .decode(&mut bytes)
            .expect("decode success")
            .expect("single message");

        match message {
            Message::Message(m) => assert_eq!(expected_parts, m.into_vecdeque().len()),
            other => panic!("wrong message type: {:?}", other),
        }
        assert_eq!(bytes.len(), 0, "decoder left trailing bytes");
    }

    #[test]
    fn test_message_decode_1() {
        let data = "01093c4944537c4d53473e01403239386166316563653932306635373637656132393438376261363164643436613534636334313262653032303339316139653831636535633234383039653001cb7b226d73675f6964223a2236356336396230312d636634622d343563322d616165612d323263306365326531316533222c2273657373696f6e223a2230326462356631642d386535632d346464612d383064342d303337363835343465616138222c22757365726e616d65223a223c544f444f3e222c2264617465223a22323032312d31322d32395430343a35393a33392e3539333533372b30303a3030222c226d73675f74797065223a22657865637574655f7265706c79222c2276657273696f6e223a22352e33227d01c07b226d73675f6964223a223965303336313036373262393433393961343432316539373330333330326162222c2273657373696f6e223a226231323139393364663235613432643839376135653163383362306337616665222c22757365726e616d65223a22757365726e616d65222c2264617465223a22313937302d30312d30315430303a30303a30302b30303a3030222c226d73675f74797065223a22657865637574655f72657175657374222c2276657273696f6e223a22352e32227d01027b7d00467b22737461747573223a226f6b222c22657865637574696f6e5f636f756e74223a312c227061796c6f6164223a5b5d2c22757365725f65787072657373696f6e73223a7b7d7d";
        assert_decodes_single_message(data, 6);
    }

    #[test]
    fn test_message_decode_2() {
        let data = "01093c4944537c4d53473e01406139346435366530343438353335303831316561623063663730623464356366373933653431653838616330666339646263346562326238616136643635306601cb7b226d73675f6964223a2263383466623933372d333162662d346335622d386430392d386535633230633434333636222c2273657373696f6e223a2230326462356631642d386535632d346464612d383064342d303337363835343465616138222c22757365726e616d65223a223c544f444f3e222c2264617465223a22323032312d31322d32395430343a35393a34332e3037343831332b30303a3030222c226d73675f74797065223a22657865637574655f7265706c79222c2276657273696f6e223a22352e33227d01c07b226d73675f6964223a223238646635316334303933313433643339393131346664333439643530396634222c2273657373696f6e223a226231323139393364663235613432643839376135653163383362306337616665222c22757365726e616d65223a22757365726e616d65222c2264617465223a22313937302d30312d30315430303a30303a30302b30303a3030222c226d73675f74797065223a22657865637574655f72657175657374222c2276657273696f6e223a22352e32227d01027b7d00467b22737461747573223a226f6b222c22657865637574696f6e5f636f756e74223a322c227061796c6f6164223a5b5d2c22757365725f65787072657373696f6e73223a7b7d7d";
        assert_decodes_single_message(data, 6);
    }

    /// A multipart message fed to the decoder one byte at a time must be
    /// reassembled across `decode` calls. Every call returns `Ok(None)` until
    /// the final frame (MORE flag clear) is complete.
    #[test]
    fn test_message_decode_incremental() {
        // Two short data frames: "a" with MORE set, then "bc" as the last part.
        let wire: &[u8] = &[0x01, 0x01, b'a', 0x00, 0x02, b'b', b'c'];
        let (last, head) = wire.split_last().expect("non-empty fixture");

        let mut codec = ZmqCodec::post_greeting();
        let mut bytes = BytesMut::new();
        for &byte in head {
            bytes.extend_from_slice(&[byte]);
            assert!(codec.decode(&mut bytes).expect("decode success").is_none());
        }
        bytes.extend_from_slice(&[*last]);

        let message = codec
            .decode(&mut bytes)
            .expect("decode success")
            .expect("complete message");
        match message {
            Message::Message(m) => {
                let parts: Vec<Bytes> = m.into_vecdeque().into_iter().collect();
                assert_eq!(
                    parts,
                    vec![Bytes::from_static(b"a"), Bytes::from_static(b"bc")]
                );
            }
            other => panic!("wrong message type: {:?}", other),
        }
        assert!(bytes.is_empty());
    }
}