use std::collections::HashMap;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use crate::{Msg, MsgFlags, ZmqError};
/// ZMTP frame-flags bit 1 (`LONG`).
pub const ZMTP_FLAG_LONG: u8 = 0x02;
/// ZMTP frame-flags bit 0 (`MORE`).
pub const ZMTP_FLAG_MORE: u8 = 0x01;
/// ZMTP frame-flags bit 2 (`COMMAND`).
pub const ZMTP_FLAG_COMMAND: u8 = 0x04;

/// Command name of the `READY` command.
pub const ZMTP_CMD_READY_NAME: &[u8] = b"READY";
/// Command name of the `ERROR` command.
pub const ZMTP_CMD_ERROR_NAME: &[u8] = b"ERROR";
/// Command name of the `SUBSCRIBE` command.
pub const ZMTP_CMD_SUBSCRIBE_NAME: &[u8] = b"SUBSCRIBE";
/// Command name of the `CANCEL` command.
pub const ZMTP_CMD_CANCEL_NAME: &[u8] = b"CANCEL";
/// Command name of the `PING` command.
pub const ZMTP_CMD_PING_NAME: &[u8] = b"PING";
/// Command name of the `PONG` command.
pub const ZMTP_CMD_PONG_NAME: &[u8] = b"PONG";
/// A ZMTP command decoded from a command frame by [`ZmtpCommand::parse`].
#[derive(Debug, PartialEq, Eq)]
pub(crate) enum ZmtpCommand {
    /// `PING`: the bytes following the 2-byte TTL (the TTL itself is not kept).
    Ping(Bytes),
    /// `PONG`: the bytes following the command name.
    Pong(Bytes),
    /// `READY`: the parsed metadata properties.
    Ready(ZmtpReady),
    /// `ERROR`: the error payload is not retained.
    Error,
    /// Any unrecognised command; holds a copy of the entire command body.
    Unknown(Bytes),
}
impl ZmtpCommand {
pub fn parse(msg: &Msg) -> Option<Self> {
if !msg.is_command() || msg.is_more() {
return None; }
let body = msg.data()?;
if body.starts_with(b"\x04PING") && body.len() >= 7 {
let context = Bytes::copy_from_slice(&body[5 + 2..]);
Some(ZmtpCommand::Ping(context))
} else if body.starts_with(b"\x04PONG") && body.len() >= 5 {
let context = Bytes::copy_from_slice(&body[5..]);
Some(ZmtpCommand::Pong(context))
} else if body.starts_with(b"\x05READY") && body.len() >= 6 {
match ZmtpReady::parse_properties(&body[6..]) {
Ok(ready_cmd) => Some(ZmtpCommand::Ready(ready_cmd)),
Err(e) => {
tracing::error!("Failed to parse READY properties: {}", e);
None }
}
} else if body.starts_with(b"\x05ERROR") {
Some(ZmtpCommand::Error)
} else {
Some(ZmtpCommand::Unknown(Bytes::copy_from_slice(body)))
}
}
pub fn create_ping(ttl: u16, context: &[u8]) -> Msg {
let mut body = Vec::with_capacity(5 + 2 + context.len());
body.extend_from_slice(b"\x04PING");
body.extend_from_slice(&ttl.to_be_bytes()); body.extend_from_slice(context);
let mut msg = Msg::from_vec(body);
msg.set_flags(MsgFlags::COMMAND);
msg
}
pub fn create_pong(context: &[u8]) -> Msg {
let mut body = Vec::with_capacity(5 + context.len());
body.extend_from_slice(b"\x04PONG");
body.extend_from_slice(context);
let mut msg = Msg::from_vec(body);
msg.set_flags(MsgFlags::COMMAND);
msg
}
}
/// Metadata carried by a ZMTP `READY` command.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ZmtpReady {
    /// Property name mapped to its raw value bytes; names are stored exactly as decoded.
    pub properties: HashMap<String, Vec<u8>>,
}
impl ZmtpReady {
pub fn parse_properties(body: &[u8]) -> Result<Self, ZmqError> {
let mut props = HashMap::new();
let mut cursor = std::io::Cursor::new(body);
while cursor.position() < body.len() as u64 {
let name_len = cursor.get_u8() as usize;
if cursor.remaining() < name_len {
return Err(ZmqError::ProtocolViolation("Invalid metadata name length".into()));
}
let name_bytes = cursor.copy_to_bytes(name_len);
let name = String::from_utf8(name_bytes.to_vec())
.map_err(|_| ZmqError::ProtocolViolation("Metadata name not valid UTF-8".into()))?;
if cursor.remaining() < 4 {
return Err(ZmqError::ProtocolViolation("Invalid metadata value length".into()));
}
let value_len = cursor.get_u32() as usize; if cursor.remaining() < value_len {
return Err(ZmqError::ProtocolViolation("Invalid metadata value length".into()));
}
let value_bytes = cursor.copy_to_bytes(value_len);
props.insert(name, value_bytes.to_vec());
}
Ok(Self { properties: props })
}
pub fn encode_properties(&self, dst: &mut BytesMut) {
for (name, value) in &self.properties {
let name_bytes = name.as_bytes();
if name_bytes.len() > 255 {
tracing::warn!(
"Skipping ZMTP metadata property with name longer than 255 bytes: {}",
name
);
continue;
}
dst.put_u8(name_bytes.len() as u8);
dst.put_slice(name_bytes);
dst.put_u32(value.len() as u32); dst.put_slice(value);
}
}
pub fn create_msg(properties: HashMap<String, Vec<u8>>) -> Msg {
let cmd = ZmtpReady { properties };
let mut body = BytesMut::new();
let name = ZMTP_CMD_READY_NAME;
body.put_u8(name.len() as u8); body.put_slice(name);
cmd.encode_properties(&mut body);
let mut msg = Msg::from_bytes(body.freeze());
msg.set_flags(MsgFlags::COMMAND);
msg
}
}