use crate::error::CodecError;
use crate::SocketType;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::collections::HashMap;
use std::convert::TryFrom;
use std::fmt::Display;
/// Heartbeat command frames, carried on the wire under the names `PING` and `PONG`.
///
/// Conversions in this file encode a frame into `BytesMut` (with its command header)
/// and decode one from a header-less command body held in `Bytes`.
#[derive(Debug, Clone)]
pub enum HeartbeatFrame {
/// Encoded as the name `PING`, a 16-bit TTL, then the raw context bytes.
Ping {
/// Time-to-live value written as a 16-bit integer right after the name.
/// NOTE(review): presumably expressed in tenths of a second, as the field name
/// suggests — confirm against the protocol spec.
ttl_tenths: u16,
/// Opaque bytes that follow the TTL; the decoder takes everything left in the body.
context: Bytes,
},
/// Encoded as the name `PONG` followed by the raw context bytes.
Pong {
/// Opaque bytes that follow the name; the decoder takes everything left in the body.
context: Bytes,
},
}
impl From<HeartbeatFrame> for BytesMut {
fn from(hb: HeartbeatFrame) -> Self {
match hb {
HeartbeatFrame::Ping {
ttl_tenths,
context,
} => {
let body_len = 1 + 4 + 2 + context.len();
let mut buf = BytesMut::with_capacity(2 + body_len);
if body_len > 255 {
buf.put_u8(0x06); buf.put_u64(body_len as u64);
} else {
buf.put_u8(0x04); buf.put_u8(body_len as u8);
}
buf.put_u8(4u8); buf.extend_from_slice(b"PING");
buf.put_u16(ttl_tenths);
buf.extend_from_slice(&context);
buf
}
HeartbeatFrame::Pong { context } => {
let body_len = 1 + 4 + context.len();
let mut buf = BytesMut::with_capacity(2 + body_len);
if body_len > 255 {
buf.put_u8(0x06);
buf.put_u64(body_len as u64);
} else {
buf.put_u8(0x04);
buf.put_u8(body_len as u8);
}
buf.put_u8(4u8); buf.extend_from_slice(b"PONG");
buf.extend_from_slice(&context);
buf
}
}
}
}
impl TryFrom<Bytes> for HeartbeatFrame {
    type Error = CodecError;

    /// Decodes a heartbeat command body (the bytes after the command header).
    ///
    /// # Errors
    ///
    /// Returns `CodecError::Decode` when the body is shorter than five bytes, when the
    /// declared name length exceeds the remaining bytes, when a `PING` lacks its
    /// two-byte TTL, or when the name is neither `PING` nor `PONG`.
    fn try_from(mut data: Bytes) -> Result<Self, Self::Error> {
        // Smallest valid body: one length byte plus a four-byte name.
        const MIN_BODY_LEN: usize = 5;
        const TTL_LEN: usize = 2;

        if data.len() < MIN_BODY_LEN {
            return Err(CodecError::Decode("Heartbeat frame too short"));
        }

        let declared_name_len = usize::from(data.get_u8());
        if declared_name_len > data.len() {
            return Err(CodecError::Decode("Heartbeat frame: name length overflow"));
        }

        let command_name = data.split_to(declared_name_len);
        let frame = match &command_name[..] {
            b"PING" => {
                if data.len() < TTL_LEN {
                    return Err(CodecError::Decode("PING frame: missing TTL"));
                }
                let ttl_tenths = data.get_u16();
                // Whatever follows the TTL is the context.
                HeartbeatFrame::Ping {
                    ttl_tenths,
                    context: data,
                }
            }
            b"PONG" => HeartbeatFrame::Pong { context: data },
            _ => return Err(CodecError::Decode("Unknown heartbeat command")),
        };
        Ok(frame)
    }
}
/// Handshake commands handled by the PLAIN codec in this file (the decoder reports
/// unrecognized names as "Unknown PLAIN command").
#[derive(Debug, Clone)]
pub(crate) enum PlainFrame {
/// `HELLO`: carries two fields, each written with a one-byte length prefix.
Hello {
/// First length-prefixed field after the command name.
username: Bytes,
/// Second length-prefixed field, directly after `username`.
password: Bytes,
},
/// `WELCOME`: consists of the command name only.
Welcome,
/// `INITIATE`: the command name followed by raw metadata bytes.
Initiate {
/// Everything after the command name, stored undecoded.
metadata: Bytes,
},
/// `READY`: the command name followed by raw metadata bytes.
Ready {
/// Everything after the command name, stored undecoded.
metadata: Bytes,
},
/// `ERROR`: the command name followed by a one-byte length and the reason text.
Error {
/// Reason text; the decoder substitutes "invalid utf8" when the bytes are not UTF-8.
reason: String,
},
}
impl From<PlainFrame> for BytesMut {
    /// Serializes a PLAIN handshake command, command header included.
    ///
    /// `HELLO` username/password and the `ERROR` reason are each preceded by a single
    /// length byte, so they cannot exceed 255 bytes on the wire. Longer values are
    /// clamped to 255 bytes (the reason on a UTF-8 character boundary) so that the
    /// length byte, the bytes written and the header's body length always agree.
    /// Previously the length byte wrapped modulo 256 while the full field was still
    /// written, which produced a frame the peer could not parse.
    fn from(f: PlainFrame) -> Self {
        // Largest value a one-byte length prefix can describe.
        const MAX_FIELD_LEN: usize = u8::MAX as usize;

        /// Returns at most the first `MAX_FIELD_LEN` bytes of `field`.
        fn clamp_field(field: &[u8]) -> &[u8] {
            &field[..field.len().min(MAX_FIELD_LEN)]
        }

        /// Returns the longest prefix of `reason` that fits in `MAX_FIELD_LEN` bytes
        /// without splitting a UTF-8 character.
        fn clamp_reason(reason: &str) -> &[u8] {
            let mut end = reason.len().min(MAX_FIELD_LEN);
            // Index 0 is always a boundary, so this loop terminates.
            while !reason.is_char_boundary(end) {
                end -= 1;
            }
            &reason.as_bytes()[..end]
        }

        /// Writes `field` preceded by its one-byte length; `field` must already be clamped.
        fn put_short_field(buf: &mut BytesMut, field: &[u8]) {
            buf.put_u8(field.len() as u8);
            buf.extend_from_slice(field);
        }

        match f {
            PlainFrame::Hello { username, password } => {
                let username = clamp_field(&username);
                let password = clamp_field(&password);
                let body_len = 1 + 5 + 1 + username.len() + 1 + password.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(5);
                buf.extend_from_slice(b"HELLO");
                put_short_field(&mut buf, username);
                put_short_field(&mut buf, password);
                buf
            }
            PlainFrame::Welcome => {
                let body_len = 1 + 7;
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(7);
                buf.extend_from_slice(b"WELCOME");
                buf
            }
            PlainFrame::Initiate { metadata } => {
                let body_len = 1 + 8 + metadata.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(8);
                buf.extend_from_slice(b"INITIATE");
                buf.extend_from_slice(&metadata);
                buf
            }
            PlainFrame::Ready { metadata } => {
                let body_len = 1 + 5 + metadata.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(5);
                buf.extend_from_slice(b"READY");
                buf.extend_from_slice(&metadata);
                buf
            }
            PlainFrame::Error { reason } => {
                let rb = clamp_reason(&reason);
                let body_len = 1 + 5 + 1 + rb.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(5);
                buf.extend_from_slice(b"ERROR");
                put_short_field(&mut buf, rb);
                buf
            }
        }
    }
}
impl TryFrom<Bytes> for PlainFrame {
type Error = CodecError;
fn try_from(mut data: Bytes) -> Result<Self, CodecError> {
if data.is_empty() {
return Err(CodecError::Decode("PlainFrame: empty body"));
}
let name_len = data.get_u8() as usize;
if data.len() < name_len {
return Err(CodecError::Decode("PlainFrame: name length overflow"));
}
let name = data.split_to(name_len);
match name.as_ref() {
b"HELLO" => {
if data.is_empty() {
return Err(CodecError::Decode("HELLO: too short"));
}
let ulen = data.get_u8() as usize;
if data.len() < ulen + 1 {
return Err(CodecError::Decode("HELLO: username overflow"));
}
let username = data.split_to(ulen);
let plen = data.get_u8() as usize;
if data.len() < plen {
return Err(CodecError::Decode("HELLO: password overflow"));
}
let password = data.split_to(plen);
Ok(PlainFrame::Hello { username, password })
}
b"WELCOME" => Ok(PlainFrame::Welcome),
b"INITIATE" => Ok(PlainFrame::Initiate { metadata: data }),
b"READY" => Ok(PlainFrame::Ready { metadata: data }),
b"ERROR" => {
if data.is_empty() {
return Ok(PlainFrame::Error {
reason: String::new(),
});
}
let rlen = data.get_u8() as usize;
if data.len() < rlen {
return Err(CodecError::Decode("ERROR: reason overflow"));
}
let reason = String::from_utf8(data.split_to(rlen).to_vec())
.unwrap_or_else(|_| "invalid utf8".into());
Ok(PlainFrame::Error { reason })
}
_ => Err(CodecError::Decode("Unknown PLAIN command")),
}
}
}
/// Handshake commands handled by the CURVE codec in this file (the decoder reports
/// unrecognized names as "Unknown CURVE command"). Only compiled with the `curve` feature.
#[cfg(feature = "curve")]
#[derive(Debug, Clone)]
pub(crate) enum CurveFrame {
/// `HELLO`: version, 72 bytes of padding, a 32-byte key, an 8-byte nonce and a box.
Hello {
/// Two single-byte version components, written in order right after the name.
version: (u8, u8),
/// 32 bytes written after the padding.
/// NOTE(review): presumably the client's ephemeral public key, per the name — confirm.
client_ephemeral_pub: [u8; 32],
/// 8 bytes written after the key.
nonce_short: [u8; 8],
/// Remaining bytes of the body; the decoder requires at least 80 of them.
box_: Bytes,
},
/// `WELCOME`: a 16-byte nonce followed by a box.
Welcome {
/// 16 bytes written right after the name.
nonce_random: [u8; 16],
/// Remaining bytes of the body; the decoder requires at least 144 of them.
box_: Bytes,
},
/// `INITIATE`: cookie nonce, cookie ciphertext, an 8-byte nonce and a box.
Initiate {
/// 16 bytes written right after the name.
cookie_nonce: [u8; 16],
/// Bytes written after `cookie_nonce`; the decoder always reads exactly 80 bytes here.
cookie_cipher: Bytes,
/// 8 bytes written after the cookie ciphertext.
nonce_short: [u8; 8],
/// Remaining bytes of the body; the decoder requires at least 16 of them.
box_: Bytes,
},
/// `READY`: an 8-byte nonce followed by a box.
Ready {
/// 8 bytes written right after the name.
nonce_short: [u8; 8],
/// Remaining bytes of the body; the decoder requires at least 16 of them.
box_: Bytes,
},
/// `ERROR`: a one-byte length followed by the reason text.
Error {
/// Reason text; the decoder substitutes "invalid utf8" when the bytes are not UTF-8.
reason: String,
},
}
#[cfg(feature = "curve")]
impl From<CurveFrame> for BytesMut {
    /// Serializes a CURVE handshake command, command header included.
    ///
    /// Fixed-size fields are written verbatim in declaration order. The `ERROR` reason
    /// is preceded by a single length byte, so it is clamped to at most 255 bytes (on a
    /// UTF-8 character boundary). Previously a longer reason made the length byte wrap
    /// modulo 256 while the whole text was still written, yielding a malformed frame.
    fn from(f: CurveFrame) -> Self {
        // Largest value a one-byte length prefix can describe.
        const MAX_REASON_LEN: usize = u8::MAX as usize;
        // Zero padding between the version bytes and the public key in HELLO.
        const HELLO_PADDING_LEN: usize = 72;

        /// Returns the longest prefix of `reason` that fits in `MAX_REASON_LEN` bytes
        /// without splitting a UTF-8 character.
        fn clamp_reason(reason: &str) -> &[u8] {
            let mut end = reason.len().min(MAX_REASON_LEN);
            // Index 0 is always a boundary, so this loop terminates.
            while !reason.is_char_boundary(end) {
                end -= 1;
            }
            &reason.as_bytes()[..end]
        }

        match f {
            CurveFrame::Hello {
                version,
                client_ephemeral_pub,
                nonce_short,
                box_,
            } => {
                let body_len = 1 + 5 + 2 + HELLO_PADDING_LEN + 32 + 8 + box_.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(5);
                buf.extend_from_slice(b"HELLO");
                buf.put_u8(version.0);
                buf.put_u8(version.1);
                buf.extend_from_slice(&[0u8; HELLO_PADDING_LEN]);
                buf.extend_from_slice(&client_ephemeral_pub);
                buf.extend_from_slice(&nonce_short);
                buf.extend_from_slice(&box_);
                buf
            }
            CurveFrame::Welcome { nonce_random, box_ } => {
                let body_len = 1 + 7 + 16 + box_.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(7);
                buf.extend_from_slice(b"WELCOME");
                buf.extend_from_slice(&nonce_random);
                buf.extend_from_slice(&box_);
                buf
            }
            CurveFrame::Initiate {
                cookie_nonce,
                cookie_cipher,
                nonce_short,
                box_,
            } => {
                let body_len = 1 + 8 + 16 + cookie_cipher.len() + 8 + box_.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(8);
                buf.extend_from_slice(b"INITIATE");
                buf.extend_from_slice(&cookie_nonce);
                buf.extend_from_slice(&cookie_cipher);
                buf.extend_from_slice(&nonce_short);
                buf.extend_from_slice(&box_);
                buf
            }
            CurveFrame::Ready { nonce_short, box_ } => {
                let body_len = 1 + 5 + 8 + box_.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(5);
                buf.extend_from_slice(b"READY");
                buf.extend_from_slice(&nonce_short);
                buf.extend_from_slice(&box_);
                buf
            }
            CurveFrame::Error { reason } => {
                let rb = clamp_reason(&reason);
                let body_len = 1 + 5 + 1 + rb.len();
                let mut buf = BytesMut::new();
                encode_command_header(&mut buf, body_len);
                buf.put_u8(5);
                buf.extend_from_slice(b"ERROR");
                // `rb` is at most 255 bytes here, so the cast cannot truncate.
                buf.put_u8(rb.len() as u8);
                buf.extend_from_slice(rb);
                buf
            }
        }
    }
}
#[cfg(feature = "curve")]
impl TryFrom<Bytes> for CurveFrame {
    type Error = CodecError;

    /// Decodes a CURVE handshake command body (the bytes after the command header).
    ///
    /// # Errors
    ///
    /// Returns `CodecError::Decode` when the body is empty, when the declared name
    /// length runs past the end of the body, when a command is shorter than its
    /// minimum size, or when the command name is not recognized.
    fn try_from(mut data: Bytes) -> Result<Self, CodecError> {
        /// Removes the first `N` bytes from `src` and returns them as an array.
        /// Callers must have verified that `src` holds at least `N` bytes.
        fn take_array<const N: usize>(src: &mut Bytes) -> [u8; N] {
            let mut out = [0u8; N];
            out.copy_from_slice(&src.split_to(N));
            out
        }

        if data.is_empty() {
            return Err(CodecError::Decode("CurveFrame: empty body"));
        }

        let declared_name_len = usize::from(data.get_u8());
        if declared_name_len > data.len() {
            return Err(CodecError::Decode("CurveFrame: name length overflow"));
        }

        let command_name = data.split_to(declared_name_len);
        let frame = match &command_name[..] {
            b"HELLO" => {
                // version (2) + padding (72) + public key (32) + nonce (8) + box (>= 80)
                if data.len() < 2 + 72 + 32 + 8 + 80 {
                    return Err(CodecError::Decode("HELLO: too short"));
                }
                let major = data.get_u8();
                let minor = data.get_u8();
                // The padding carries no information; skip it.
                data.advance(72);
                let client_ephemeral_pub = take_array::<32>(&mut data);
                let nonce_short = take_array::<8>(&mut data);
                CurveFrame::Hello {
                    version: (major, minor),
                    client_ephemeral_pub,
                    nonce_short,
                    box_: data,
                }
            }
            b"WELCOME" => {
                // nonce (16) + box (>= 144)
                if data.len() < 16 + 144 {
                    return Err(CodecError::Decode("WELCOME: too short"));
                }
                let nonce_random = take_array::<16>(&mut data);
                CurveFrame::Welcome {
                    nonce_random,
                    box_: data,
                }
            }
            b"INITIATE" => {
                const COOKIE_CIPHER_LEN: usize = 80;
                // cookie nonce (16) + cookie ciphertext + nonce (8) + box (>= 16)
                if data.len() < 16 + COOKIE_CIPHER_LEN + 8 + 16 {
                    return Err(CodecError::Decode("INITIATE: too short"));
                }
                let cookie_nonce = take_array::<16>(&mut data);
                let cookie_cipher = data.split_to(COOKIE_CIPHER_LEN);
                let nonce_short = take_array::<8>(&mut data);
                CurveFrame::Initiate {
                    cookie_nonce,
                    cookie_cipher,
                    nonce_short,
                    box_: data,
                }
            }
            b"READY" => {
                // nonce (8) + box (>= 16)
                if data.len() < 8 + 16 {
                    return Err(CodecError::Decode("READY: too short"));
                }
                let nonce_short = take_array::<8>(&mut data);
                CurveFrame::Ready {
                    nonce_short,
                    box_: data,
                }
            }
            b"ERROR" => {
                let reason = if data.is_empty() {
                    // A bare ERROR with no length byte decodes to an empty reason.
                    String::new()
                } else {
                    let reason_len = usize::from(data.get_u8());
                    if reason_len > data.len() {
                        return Err(CodecError::Decode("CURVE ERROR: reason overflow"));
                    }
                    match String::from_utf8(data.split_to(reason_len).to_vec()) {
                        Ok(text) => text,
                        Err(_) => String::from("invalid utf8"),
                    }
                };
                CurveFrame::Error { reason }
            }
            _ => return Err(CodecError::Decode("Unknown CURVE command")),
        };
        Ok(frame)
    }
}
fn encode_command_header(buf: &mut BytesMut, body_len: usize) {
if body_len > 255 {
buf.put_u8(0x06); buf.put_u64(body_len as u64);
} else {
buf.put_u8(0x04); buf.put_u8(body_len as u8);
}
}
/// Names of the commands that `ZmqCommand` can carry.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Copy, Clone)]
pub enum ZmqCommandName {
    /// Spelled `READY` in its textual form.
    READY,
}

impl ZmqCommandName {
    /// Returns the textual spelling of the command name.
    pub const fn as_str(&self) -> &'static str {
        match *self {
            Self::READY => "READY",
        }
    }
}

impl Display for ZmqCommandName {
    /// Writes the same text as [`ZmqCommandName::as_str`], ignoring width/fill flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &'static str = Self::as_str(self);
        f.write_str(label)
    }
}
/// A named command together with its key/value properties.
///
/// On the wire each property is written as a one-byte name length, the name, a
/// 32-bit value length and the value bytes (see the `From`/`TryFrom` conversions).
#[derive(Debug, Clone)]
pub struct ZmqCommand {
/// Which command this is.
pub name: ZmqCommandName,
/// Property values keyed by property name; iteration order is unspecified.
pub properties: HashMap<String, Bytes>,
}
impl ZmqCommand {
pub fn ready(socket: SocketType) -> Self {
let mut properties = HashMap::new();
properties.insert("Socket-Type".into(), socket.as_str().into());
Self {
name: ZmqCommandName::READY,
properties,
}
}
pub fn add_prop(&mut self, name: String, value: Bytes) -> &mut Self {
self.properties.insert(name, value);
self
}
pub fn add_properties(&mut self, map: HashMap<String, Bytes>) -> &mut Self {
self.properties.extend(map);
self
}
}
impl TryFrom<Bytes> for ZmqCommand {
    type Error = CodecError;

    /// Decodes a command body (the bytes after the command header).
    ///
    /// The body is a one-byte name length and the name, followed by zero or more
    /// properties, each a one-byte name length, the name, a 32-bit value length and
    /// the value bytes.
    ///
    /// Every length is validated against the bytes actually remaining before it is
    /// used, so truncated or hostile input yields an error instead of a panic.
    ///
    /// # Errors
    ///
    /// * `CodecError::Decode` when the body is empty, when any declared length runs
    ///   past the end of the body, or when a property name is not valid UTF-8.
    /// * `CodecError::Command` when the command name is not recognized.
    fn try_from(mut buf: Bytes) -> Result<Self, Self::Error> {
        // Size of the big-endian length that precedes each property value.
        const VALUE_LEN_SIZE: usize = 4;

        if buf.is_empty() {
            return Err(CodecError::Decode("Command frame: empty body"));
        }
        let command_len = buf.get_u8() as usize;
        if buf.len() < command_len {
            return Err(CodecError::Decode("Command frame: name length overflow"));
        }
        let command_name = buf.split_to(command_len);
        let command = match &command_name[..] {
            b"READY" => ZmqCommandName::READY,
            _ => return Err(CodecError::Command("Unknown command received")),
        };

        let mut properties = HashMap::new();
        while !buf.is_empty() {
            let prop_len = buf.get_u8() as usize;
            if buf.len() < prop_len {
                return Err(CodecError::Decode(
                    "Command property: name length overflow",
                ));
            }
            let property = String::from_utf8(buf.split_to(prop_len).to_vec())
                .map_err(|_| CodecError::Decode("Invalid property identifier"))?;

            if buf.len() < VALUE_LEN_SIZE {
                return Err(CodecError::Decode(
                    "Command property: missing value length",
                ));
            }
            let prop_val_len = buf.get_u32() as usize;
            if buf.len() < prop_val_len {
                return Err(CodecError::Decode(
                    "Command property: value length overflow",
                ));
            }
            let prop_value = buf.split_to(prop_val_len);

            properties.insert(property, prop_value);
        }

        Ok(Self {
            name: command,
            properties,
        })
    }
}
impl From<ZmqCommand> for BytesMut {
fn from(command: ZmqCommand) -> Self {
let mut message_len = 0;
let command_name = command.name.as_str();
message_len += command_name.len() + 1;
for (prop, val) in command.properties.iter() {
message_len += prop.len() + 1;
message_len += val.len() + 4;
}
let long_message = message_len > 255;
let mut bytes = BytesMut::new();
if long_message {
bytes.reserve(message_len + 9);
bytes.put_u8(0x06);
bytes.put_u64(message_len as u64);
} else {
bytes.reserve(message_len + 2);
bytes.put_u8(0x04);
bytes.put_u8(message_len as u8);
};
bytes.put_u8(command_name.len() as u8);
bytes.extend_from_slice(command_name.as_ref());
for (prop, val) in command.properties.iter() {
bytes.put_u8(prop.len() as u8);
bytes.extend_from_slice(prop.as_ref());
bytes.put_u32(val.len() as u32);
bytes.extend_from_slice(val.as_ref());
}
bytes
}
}
#[cfg(test)]
mod heartbeat_tests {
    use super::*;
    use std::convert::TryFrom;

    /// Encodes `frame` and returns everything after the first two bytes, i.e. the
    /// body of a short-form command.
    fn encode_and_strip(frame: HeartbeatFrame) -> Bytes {
        let wire = BytesMut::from(frame);
        assert!(wire.len() >= 2, "encoded frame too short");
        wire.freeze().slice(2..)
    }

    /// Encodes `frame`, strips the header and decodes the body again.
    fn roundtrip(frame: HeartbeatFrame) -> HeartbeatFrame {
        HeartbeatFrame::try_from(encode_and_strip(frame)).expect("decode failed")
    }

    #[test]
    fn heartbeat_ping_roundtrip() {
        let ttl = 300u16;
        let ctx = Bytes::from_static(b"hello123");

        let decoded = roundtrip(HeartbeatFrame::Ping {
            ttl_tenths: ttl,
            context: ctx.clone(),
        });

        if let HeartbeatFrame::Ping {
            ttl_tenths,
            context,
        } = decoded
        {
            assert_eq!(ttl_tenths, ttl, "TTL did not round-trip");
            assert_eq!(&context[..], &ctx[..], "context did not round-trip");
        } else {
            panic!("expected Ping, got Pong");
        }
    }

    #[test]
    fn heartbeat_pong_roundtrip() {
        let ctx = Bytes::from_static(b"hello123");

        let decoded = roundtrip(HeartbeatFrame::Pong {
            context: ctx.clone(),
        });

        if let HeartbeatFrame::Pong { context } = decoded {
            assert_eq!(&context[..], &ctx[..], "context did not round-trip");
        } else {
            panic!("expected Pong, got Ping");
        }
    }

    #[test]
    fn heartbeat_ttl_values_roundtrip() {
        // Covers both byte boundaries and the extremes of the 16-bit range.
        let samples = [0u16, 1, 255, 256, 1000, u16::MAX];
        for ttl in samples.iter().copied() {
            let decoded = roundtrip(HeartbeatFrame::Ping {
                ttl_tenths: ttl,
                context: Bytes::from_static(b"ctx"),
            });

            if let HeartbeatFrame::Ping { ttl_tenths, .. } = decoded {
                assert_eq!(ttl_tenths, ttl, "TTL {ttl} did not survive round-trip");
            } else {
                panic!("expected Ping");
            }
        }
    }

    #[test]
    fn heartbeat_ping_empty_context() {
        let decoded = roundtrip(HeartbeatFrame::Ping {
            ttl_tenths: 42,
            context: Bytes::new(),
        });

        if let HeartbeatFrame::Ping {
            ttl_tenths,
            context,
        } = decoded
        {
            assert_eq!(ttl_tenths, 42);
            assert!(context.is_empty());
        } else {
            panic!("expected Ping");
        }
    }
}