use std::io;
use std::vec;
use std::vec::Vec;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream};
/// A multipart message: an ordered sequence of frames, each an opaque byte payload.
#[derive(Debug)]
pub struct Groot2Message(Vec<Bytes>);

impl Groot2Message {
    /// Returns the frame stored at `index`, or `None` when `index` is past the last frame.
    pub fn get(&self, index: usize) -> Option<&Bytes> {
        let Self(frames) = self;
        frames.get(index)
    }

    /// Appends `bytes` as a new frame after all frames already in the message.
    pub fn push_back(&mut self, bytes: Bytes) {
        let Self(frames) = self;
        frames.push(bytes);
    }
}
impl From<Bytes> for Groot2Message {
fn from(bytes: Bytes) -> Self {
Self(vec![bytes])
}
}
/// Frame flag bit 0: another frame of the same message follows this one.
pub(crate) const FLAG_MORE: u8 = 0b0000_0001;
/// Frame flag bit 1: the payload length is encoded as eight bytes instead of one.
const FLAG_LONG: u8 = 0b0000_0010;
/// Frame flag bit 2: the frame carries a command rather than message data.
pub(crate) const FLAG_COMMAND: u8 = 0b0000_0100;
/// Writes one frame: a flags byte, the payload length, then the payload itself.
///
/// Payloads of at most 255 bytes use a one-byte length. Longer payloads set
/// `FLAG_LONG` in the flags byte and use an eight-byte big-endian length.
///
/// Any `FLAG_LONG` bit present in `flags` is ignored: the bit is always derived
/// from `data.len()`, so the flags byte can never disagree with the length
/// encoding that follows it.
///
/// # Errors
///
/// Returns any I/O error reported by `writer`.
async fn write_frame<W: AsyncWriteExt + Unpin>(writer: &mut W, flags: u8, data: &[u8]) -> io::Result<()> {
    // Largest payload that still fits the one-byte length encoding.
    const SHORT_MAX: usize = 255;

    // The LONG bit must describe the encoding chosen below, not the caller's guess.
    let flags = flags & !FLAG_LONG;

    // One flags byte followed by at most eight length bytes.
    let mut header = [0u8; 9];
    let header_len = if data.len() > SHORT_MAX {
        header[0] = flags | FLAG_LONG;
        // `usize` is never wider than 64 bits, so this widening cast is lossless.
        header[1..].copy_from_slice(&(data.len() as u64).to_be_bytes());
        header.len()
    } else {
        header[0] = flags;
        // Guarded by the branch condition: the length fits in a single byte.
        header[1] = data.len() as u8;
        2
    };

    // Send the whole header with one write instead of one write per field.
    writer.write_all(&header[..header_len]).await?;
    writer.write_all(data).await
}
/// Reads one frame and returns its flags byte together with its payload.
///
/// The length is one byte, or eight big-endian bytes when `FLAG_LONG` is set in
/// the flags byte. The returned flags are exactly as received, including
/// `FLAG_LONG`.
///
/// # Errors
///
/// Returns any I/O error reported by `reader`, and `UnexpectedEof` when the
/// stream ends before the announced number of payload bytes has arrived.
async fn read_frame<R: AsyncReadExt + Unpin>(reader: &mut R) -> io::Result<(u8, Vec<u8>)> {
    // Upper bound on the buffer reserved before any payload byte has arrived.
    const MAX_PREALLOC: u64 = 64 * 1024;

    let flags = reader.read_u8().await?;
    let declared_len = if flags & FLAG_LONG != 0 {
        reader.read_u64().await?
    } else {
        u64::from(reader.read_u8().await?)
    };

    // The length is supplied by the peer and may be arbitrarily large, so it
    // must not size an allocation directly. Reserve a bounded amount and let
    // the buffer grow only as payload bytes actually arrive.
    // The cast is lossless: the value is capped at `MAX_PREALLOC`.
    let mut data = Vec::with_capacity(declared_len.min(MAX_PREALLOC) as usize);
    let received = reader.take(declared_len).read_to_end(&mut data).await?;
    if received as u64 != declared_len {
        return Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "connection closed in the middle of a frame",
        ));
    }
    Ok((flags, data))
}
/// Size in bytes of the fixed-length greeting that is written and read when a connection opens.
const GREETING_SIZE: usize = 64;
/// Writes the fixed-size greeting to `writer`.
///
/// Layout of the bytes that are set (everything else stays zero):
/// - bytes 0..10: signature `0xff`, seven zero bytes, `0x01`, `0x7f`
/// - bytes 10..12: version, major `0x03` then minor `0x01`
/// - bytes 12..16: the mechanism name `NULL`
/// - byte 32: `0x01` when `is_server` is true, otherwise `0x00`
///
/// # Errors
///
/// Returns any I/O error reported by `writer`.
async fn write_greeting<W: AsyncWriteExt + Unpin>(writer: &mut W, is_server: bool) -> io::Result<()> {
    const SIGNATURE: [u8; 10] = [0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x7f];
    const VERSION: [u8; 2] = [0x03, 0x01];
    const MECHANISM: &[u8] = b"NULL";
    const MECHANISM_OFFSET: usize = 12;
    const AS_SERVER_OFFSET: usize = 32;

    let mut greeting = [0u8; GREETING_SIZE];
    greeting[..10].copy_from_slice(&SIGNATURE);
    greeting[10..12].copy_from_slice(&VERSION);
    greeting[MECHANISM_OFFSET..MECHANISM_OFFSET + MECHANISM.len()].copy_from_slice(MECHANISM);
    greeting[AS_SERVER_OFFSET] = if is_server { 0x01 } else { 0x00 };

    writer.write_all(&greeting).await
}
/// Reads exactly `GREETING_SIZE` bytes from `reader` and returns them unvalidated.
///
/// # Errors
///
/// Returns any I/O error reported by `reader`, including the one raised when
/// the stream ends before the full greeting has been received.
async fn read_greeting<R: AsyncReadExt + Unpin>(reader: &mut R) -> io::Result<[u8; GREETING_SIZE]> {
    let mut greeting = [0u8; GREETING_SIZE];
    let outcome = reader.read_exact(&mut greeting).await;
    outcome.map(|_bytes_read| greeting)
}
/// Body of the READY command sent during the handshake.
///
/// Byte layout: length-prefixed command name (`0x05`, `READY`), then one
/// property made of a length-prefixed name (`0x0b`, `Socket-Type`) and a value
/// with a four-byte big-endian length (`0x00 0x00 0x00 0x03`, `REP`).
const READY_BODY: &[u8] = b"\x05READY\x0bSocket-Type\x00\x00\x00\x03REP";
/// Sends `READY_BODY` as a single frame flagged `FLAG_COMMAND`.
///
/// # Errors
///
/// Returns any I/O error reported by `writer`.
async fn write_ready<W: AsyncWriteExt + Unpin>(writer: &mut W) -> io::Result<()> {
    let flags = FLAG_COMMAND;
    let body = READY_BODY;
    write_frame(writer, flags, body).await
}
/// Reads the peer's handshake frame and checks that it is a READY command.
///
/// The frame must carry `FLAG_COMMAND` and its body must begin with the
/// length-prefixed command name `READY`. Any properties that follow the name
/// are ignored.
///
/// # Errors
///
/// Returns `InvalidData` when the frame is not a command or is a command other
/// than READY, and any I/O error raised while reading the frame.
async fn read_ready<R: AsyncReadExt + Unpin>(reader: &mut R) -> io::Result<()> {
    // A command body starts with a one-byte name length followed by the name.
    const READY_NAME: &[u8] = b"\x05READY";

    let (flags, body) = read_frame(reader).await?;
    if flags & FLAG_COMMAND == 0 {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "expected COMMAND frame for READY"));
    }
    // Without this check any command, including a peer's ERROR, would be
    // treated as a successful handshake.
    if !body.starts_with(READY_NAME) {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "expected READY command"));
    }
    Ok(())
}
pub struct Groot2Socket {
listener: TcpListener,
conn: Option<TcpStream>,
}
impl Groot2Socket {
pub async fn bind(port: u16) -> io::Result<Self> {
let listener = TcpListener::bind(("0.0.0.0", port)).await?;
Ok(Self { listener, conn: None })
}
pub fn local_addr(&self) -> io::Result<std::net::SocketAddr> {
self.listener.local_addr()
}
async fn connect(&mut self) -> io::Result<()> {
let (mut stream, _peer) = self.listener.accept().await?;
write_greeting(&mut stream, true).await?;
read_greeting(&mut stream).await?;
write_ready(&mut stream).await?;
read_ready(&mut stream).await?;
self.conn = Some(stream);
Ok(())
}
pub async fn recv(&mut self) -> io::Result<Groot2Message> {
if self.conn.is_none() {
self.connect().await?;
}
let stream = self
.conn
.as_mut()
.expect("conn is Some after connect");
match Self::read_message(stream).await {
Ok(msg) => Ok(msg),
Err(e) => {
self.conn = None;
Err(e)
}
}
}
async fn read_message(stream: &mut TcpStream) -> io::Result<Groot2Message> {
let mut frames: Vec<Bytes> = Vec::new();
loop {
let (flags, data) = read_frame(stream).await?;
frames.push(Bytes::from(data));
if flags & FLAG_MORE == 0 {
break;
}
}
if frames
.first()
.map(|f| f.is_empty())
.unwrap_or(false)
{
frames.remove(0);
}
Ok(Groot2Message(frames))
}
pub async fn send(&mut self, msg: Groot2Message) -> io::Result<()> {
let stream = self
.conn
.as_mut()
.ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "no active ZMTP connection"))?;
match Self::write_message(stream, msg).await {
Ok(()) => Ok(()),
Err(e) => {
self.conn = None;
Err(e)
}
}
}
async fn write_message(stream: &mut TcpStream, msg: Groot2Message) -> io::Result<()> {
debug_assert!(!msg.0.is_empty(), "ZMTP send called with zero-frame message");
write_frame(stream, FLAG_MORE, &[]).await?;
let frames = &msg.0;
for (i, frame) in frames.iter().enumerate() {
let flags = if i == frames.len() - 1 { 0x00 } else { FLAG_MORE };
write_frame(stream, flags, frame).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::{
        FLAG_COMMAND, FLAG_LONG, FLAG_MORE, Groot2Message, Groot2Socket, read_frame, read_greeting, read_ready, write_frame,
        write_greeting, write_ready,
    };
    use bytes::Bytes;
    use std::vec;
    use tokio::net::TcpStream;

    /// A message built from one payload exposes exactly that one frame.
    #[test]
    fn single_frame_message() {
        let payload = Bytes::from_static(b"hello");
        let message = Groot2Message::from(payload.clone());

        assert_eq!(message.get(0), Some(&payload));
        assert_eq!(message.get(1), None);
    }

    /// `push_back` appends after the existing frames and keeps their order.
    #[test]
    fn push_back_adds_frames() {
        let first = Bytes::from_static(b"frame1");
        let second = Bytes::from_static(b"frame2");

        let mut message = Groot2Message::from(first.clone());
        message.push_back(second.clone());

        assert_eq!(message.get(0), Some(&first));
        assert_eq!(message.get(1), Some(&second));
        assert_eq!(message.get(2), None);
    }

    /// A short payload with no flags survives a write/read round trip.
    #[tokio::test]
    async fn frame_roundtrip_short() {
        let (mut tx, mut rx) = tokio::io::duplex(64);

        write_frame(&mut tx, 0x00, b"hello").await.unwrap();
        let (flags, data) = read_frame(&mut rx).await.unwrap();

        assert_eq!(flags, 0x00);
        assert_eq!(data, b"hello");
    }

    /// `FLAG_MORE` is carried through unchanged for a short payload.
    #[tokio::test]
    async fn frame_roundtrip_with_more_flag() {
        let (mut tx, mut rx) = tokio::io::duplex(64);

        write_frame(&mut tx, FLAG_MORE, b"part1").await.unwrap();
        let (flags, data) = read_frame(&mut rx).await.unwrap();

        assert_eq!(flags, FLAG_MORE);
        assert_eq!(data, b"part1");
    }

    /// A 256-byte payload switches to the long encoding and gains `FLAG_LONG`.
    #[tokio::test]
    async fn frame_roundtrip_long() {
        let (mut tx, mut rx) = tokio::io::duplex(4096);
        let payload = vec![0xab_u8; 256];

        write_frame(&mut tx, FLAG_MORE, &payload).await.unwrap();
        let (flags, data) = read_frame(&mut rx).await.unwrap();

        assert_eq!(flags, FLAG_MORE | FLAG_LONG);
        assert_eq!(data, payload);
    }

    /// Every populated field of the server greeting has the expected value.
    #[tokio::test]
    async fn greeting_server_bytes_are_correct() {
        let (mut tx, mut rx) = tokio::io::duplex(128);

        write_greeting(&mut tx, true).await.unwrap();
        let greeting = read_greeting(&mut rx).await.unwrap();

        assert_eq!(greeting[0], 0xff, "signature[0]");
        assert_eq!(greeting[9], 0x7f, "signature[9]");
        assert_eq!(greeting[10], 0x03, "ZMTP major");
        assert_eq!(greeting[11], 0x01, "ZMTP minor");
        assert_eq!(&greeting[12..16], b"NULL", "mechanism");
        assert_eq!(greeting[32], 0x01, "as-server=true");
        assert_eq!(&greeting[33..64], &[0u8; 31], "filler is zero");
    }

    /// The as-server byte is zero when the greeting is written for a client.
    #[tokio::test]
    async fn greeting_client_has_as_server_false() {
        let (mut tx, mut rx) = tokio::io::duplex(128);

        write_greeting(&mut tx, false).await.unwrap();
        let greeting = read_greeting(&mut rx).await.unwrap();

        assert_eq!(greeting[32], 0x00, "as-server=false");
    }

    /// What `write_ready` sends is accepted by `read_ready`.
    #[tokio::test]
    async fn ready_roundtrip() {
        let (mut tx, mut rx) = tokio::io::duplex(64);

        write_ready(&mut tx).await.unwrap();
        read_ready(&mut rx).await.unwrap();
    }

    /// A READY body sent without `FLAG_COMMAND` is rejected as invalid data.
    #[tokio::test]
    async fn ready_rejects_non_command_frame() {
        let (mut tx, mut rx) = tokio::io::duplex(64);

        write_frame(&mut tx, 0x00, b"\x05READY").await.unwrap();
        let error = read_ready(&mut rx).await.unwrap_err();

        assert_eq!(error.kind(), std::io::ErrorKind::InvalidData);
    }

    /// Full exchange over a real TCP connection: handshake, one request, one reply.
    #[tokio::test]
    async fn rep_socket_full_roundtrip() {
        let mut server = Groot2Socket::bind(0).await.unwrap();
        let port = server.local_addr().unwrap().port();

        // Hand-rolled peer that drives the client side of the protocol.
        let peer = tokio::spawn(async move {
            let mut stream = TcpStream::connect(("127.0.0.1", port)).await.unwrap();

            // Handshake.
            write_greeting(&mut stream, false).await.unwrap();
            read_greeting(&mut stream).await.unwrap();
            write_frame(&mut stream, FLAG_COMMAND, b"\x05READY").await.unwrap();
            read_ready(&mut stream).await.unwrap();

            // Request: empty delimiter frame followed by the body.
            write_frame(&mut stream, FLAG_MORE, &[]).await.unwrap();
            write_frame(&mut stream, 0x00, b"ping").await.unwrap();

            // Reply: delimiter frame, then the body that is handed back.
            let (_delimiter_flags, _delimiter) = read_frame(&mut stream).await.unwrap();
            let (_body_flags, body) = read_frame(&mut stream).await.unwrap();
            body
        });

        let request = server.recv().await.unwrap();
        assert_eq!(request.get(0).unwrap().as_ref(), b"ping");

        let response = Groot2Message::from(Bytes::from_static(b"pong"));
        server.send(response).await.unwrap();

        let reply_bytes = peer.await.unwrap();
        assert_eq!(reply_bytes, b"pong");
    }
}