use crate::codec::zmq_codec::encode_frame_header;
use crate::codec::Message;
use crate::error::CodecError;
use crate::io_compat::AsyncVectoredWrite;
use bytes::{Buf, Bytes};
use parking_lot::Mutex;
use smallvec::SmallVec;
use std::collections::VecDeque;
use std::io::{self, IoSlice};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Maximum number of `IoSlice`s gathered per vectored write pass.
const IOV_CAP: usize = 64;
/// Initial capacity of the writer's pending-message queue.
const PENDING_CAPACITY: usize = 16;
/// Payloads at or above this size (64 KiB) skip the single-frame fast path.
const FAST_PATH_MAX_PAYLOAD: usize = 65536;
/// Thin wrapper around a transport write half, exposing only the vectored
/// non-blocking write and write-readiness operations the engine needs.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) struct EngineWriteHalf<W>(W);
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
impl<W: AsyncVectoredWrite> EngineWriteHalf<W> {
    /// Wraps a raw transport write half.
    pub(crate) fn new(inner: W) -> Self {
        Self(inner)
    }
    /// Non-blocking vectored write; delegates directly to the inner transport.
    #[inline]
    fn try_write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        self.0.try_write_vectored(bufs)
    }
    /// Resolves once the inner transport reports write readiness.
    async fn writable(&self) -> io::Result<()> {
        self.0.writable().await
    }
}
/// A pre-encoded ZMTP frame header.
#[derive(Debug)]
struct HeaderBuf {
    // Header bytes; only the first `len` bytes are meaningful (short headers
    // use 2 bytes, long headers up to 9).
    bytes: [u8; 9],
    // Number of valid bytes in `bytes`; 0 means "no header" (raw payload).
    len: u8,
}
/// One frame awaiting (possibly partial) transmission.
#[derive(Debug)]
struct PendingFrame {
    // Encoded frame header.
    header: HeaderBuf,
    // How many header bytes have already been written to the socket.
    header_pos: u8,
    // Remaining unwritten payload; `advance`d as bytes go out.
    payload: Bytes,
}
impl PendingFrame {
    /// True once both the header and the payload have been fully written.
    #[inline]
    fn is_drained(&self) -> bool {
        self.header_pos == self.header.len && self.payload.is_empty()
    }
}
/// A queued logical message: its frames, each with independent write cursors.
/// Inline storage for two frames covers the common single/dual-frame case.
#[derive(Debug)]
pub(crate) struct PendingMsg {
    frames: SmallVec<[PendingFrame; 2]>,
}
impl PendingMsg {
    /// Encodes every frame of `msg` into a pending entry: one ZMTP header +
    /// payload pair per frame, with the MORE flag set on all but the last.
    fn from_zmq_message(msg: &crate::ZmqMessage) -> Self {
        // NOTE(review): assumes `msg` holds at least one frame — `msg.len() - 1`
        // would underflow on an empty message. Confirm ZmqMessage guarantees this.
        let mut frames: SmallVec<[PendingFrame; 2]> = SmallVec::with_capacity(msg.len());
        let last = msg.len() - 1;
        for (idx, frame) in msg.iter().enumerate() {
            // MORE is set for every frame except the final one.
            let (header_bytes, header_len) = encode_frame_header(frame.len(), idx != last);
            frames.push(PendingFrame {
                header: HeaderBuf {
                    bytes: header_bytes,
                    len: header_len,
                },
                header_pos: 0,
                payload: frame.clone(),
            });
        }
        PendingMsg { frames }
    }
    /// Builds a single-frame entry for a frame that was already partially
    /// written: `bytes_already_written` bytes of header+payload went out on
    /// the wire, so the stored cursors skip exactly that prefix.
    fn from_single_frame_partial(
        header: HeaderBuf,
        payload: Bytes,
        bytes_already_written: usize,
    ) -> Self {
        let mut frame = PendingFrame {
            header,
            header_pos: 0,
            payload,
        };
        let mut n = bytes_already_written;
        // Consume the header first, then the payload — mirroring wire order.
        let hdr_left = (frame.header.len - frame.header_pos) as usize;
        let take = n.min(hdr_left);
        frame.header_pos += take as u8;
        n -= take;
        if n > 0 {
            // Clamp defensively in case the caller over-reports written bytes.
            frame.payload.advance(n.min(frame.payload.len()));
        }
        let mut frames: SmallVec<[PendingFrame; 2]> = SmallVec::new();
        frames.push(frame);
        PendingMsg { frames }
    }
    /// Wraps pre-encoded wire bytes (greeting / command / heartbeat / raw
    /// security data) as a headerless frame: `len == 0` means nothing to prepend.
    fn from_raw_bytes(payload: Bytes) -> Self {
        let mut frames: SmallVec<[PendingFrame; 2]> = SmallVec::new();
        frames.push(PendingFrame {
            header: HeaderBuf {
                bytes: [0; 9],
                len: 0,
            },
            header_pos: 0,
            payload,
        });
        PendingMsg { frames }
    }
}
/// Outcome of an attempted fast-path (inline) send.
#[derive(Debug)]
pub(crate) enum FastPath {
    /// The whole message was written to the socket.
    Sent,
    /// Partially written; the remainder was queued for the writer loop.
    Enqueued,
    /// Fast path declined; the message is handed back for the normal path.
    NotTaken(Message),
}
/// Transport-erased write half for the tokio runtime: one variant per
/// enabled transport feature.
#[cfg(all(
    feature = "tokio",
    any(feature = "tcp", all(feature = "ipc", target_family = "unix"))
))]
pub enum ZmqEngineWriteHalf {
    #[cfg(feature = "tcp")]
    Tcp(EngineWriteHalf<tokio::net::tcp::OwnedWriteHalf>),
    #[cfg(all(feature = "ipc", feature = "tokio", target_family = "unix"))]
    Ipc(EngineWriteHalf<tokio::net::unix::OwnedWriteHalf>),
}
#[cfg(all(
    feature = "tokio",
    any(feature = "tcp", all(feature = "ipc", target_family = "unix"))
))]
impl AsyncVectoredWrite for ZmqEngineWriteHalf {
    /// Dispatches the non-blocking vectored write to the active transport.
    #[inline]
    fn try_write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        match self {
            #[cfg(feature = "tcp")]
            ZmqEngineWriteHalf::Tcp(h) => h.try_write_vectored(bufs),
            #[cfg(all(feature = "ipc", feature = "tokio", target_family = "unix"))]
            ZmqEngineWriteHalf::Ipc(h) => h.try_write_vectored(bufs),
        }
    }
    /// Waits for write readiness on the active transport.
    async fn writable(&self) -> io::Result<()> {
        match self {
            #[cfg(feature = "tcp")]
            ZmqEngineWriteHalf::Tcp(h) => h.writable().await,
            #[cfg(all(feature = "ipc", feature = "tokio", target_family = "unix"))]
            ZmqEngineWriteHalf::Ipc(h) => h.writable().await,
        }
    }
}
/// Write half for the smol runtime (TCP only). The `Arc<Async<TcpStream>>`
/// allows the read half to share the same underlying socket.
#[cfg(all(feature = "smol", not(feature = "tokio"), feature = "tcp"))]
pub struct SmolEngineWriteHalf(
    pub(crate) EngineWriteHalf<std::sync::Arc<async_io::Async<std::net::TcpStream>>>,
);
#[cfg(all(feature = "smol", not(feature = "tokio"), feature = "tcp"))]
impl AsyncVectoredWrite for SmolEngineWriteHalf {
    /// Non-blocking vectored write; delegates to the wrapped half.
    #[inline]
    fn try_write_vectored(&self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
        self.0.try_write_vectored(bufs)
    }
    /// Waits for write readiness on the wrapped half.
    async fn writable(&self) -> io::Result<()> {
        self.0.writable().await
    }
}
/// State shared between the writer loop (`VectoredWriter`) and inline
/// fast-path writers running on other tasks.
pub(crate) struct SharedHalf<W> {
    // The transport write half; all writes are serialized by `write_lock`.
    half: W,
    // Serializes socket access between the writer loop and inline writers.
    write_lock: Mutex<()>,
    // Mirror of the writer loop's `pending` queue length (0 == idle).
    pending_len: AtomicUsize,
    // Set while the writer loop is actively flushing; inline writers back off.
    peer_loop_busy: std::sync::atomic::AtomicBool,
    // Partially written inline messages awaiting pickup by the writer loop.
    pending_overflow: Mutex<VecDeque<PendingMsg>>,
    // Cached `pending_overflow.len()` for lock-free emptiness checks.
    pending_overflow_len: AtomicUsize,
    // Wakes the writer loop when overflow entries are queued.
    pub(crate) overflow_notify: crate::async_rt::notify::RuntimeNotify,
}
impl<W> SharedHalf<W> {
    /// Creates the shared state around a transport write half, with all
    /// queues empty and all flags cleared.
    fn new(half: W) -> Self {
        Self {
            half,
            write_lock: Mutex::new(()),
            pending_len: AtomicUsize::new(0),
            peer_loop_busy: std::sync::atomic::AtomicBool::new(false),
            pending_overflow: Mutex::new(VecDeque::new()),
            pending_overflow_len: AtomicUsize::new(0),
            overflow_notify: crate::async_rt::notify::RuntimeNotify::new(),
        }
    }
    /// Lock-free check for queued overflow entries.
    #[inline]
    pub(crate) fn has_overflow(&self) -> bool {
        self.pending_overflow_len.load(Ordering::Acquire) > 0
    }
    /// Moves all overflow entries to the *front* of `dst`, preserving their
    /// original order (pop_back + push_front prepends back-to-front), so
    /// partially written messages are flushed before anything already queued.
    pub(crate) fn drain_overflow_into(&self, dst: &mut VecDeque<PendingMsg>) {
        let mut g = self.pending_overflow.lock();
        while let Some(msg) = g.pop_back() {
            dst.push_front(msg);
        }
        self.pending_overflow_len.store(0, Ordering::Release);
    }
    /// Signals inline writers that the writer loop is actively flushing.
    #[inline]
    pub(crate) fn mark_peer_loop_busy(&self) {
        self.peer_loop_busy.store(true, Ordering::Release);
    }
    /// Clears the busy flag, re-enabling the inline fast path.
    #[inline]
    pub(crate) fn clear_peer_loop_busy(&self) {
        self.peer_loop_busy.store(false, Ordering::Release);
    }
}
/// Maximum frame count eligible for the inline multi-frame fast path.
const FAST_PATH_MAX_FRAMES: usize = 4;
/// Object-safe hook that lets producer tasks try to write a message inline,
/// bypassing the writer loop while the connection is idle.
pub(crate) trait InlineWriteTarget: Send + Sync {
    /// Try to send one frame inline. `None` means "not attempted / would
    /// block" (caller falls back to the queue); `Some` carries the I/O result.
    fn try_inline_single_frame(&self, payload: &[u8], cap: Option<usize>)
        -> Option<io::Result<()>>;
    /// Multi-frame variant; `frames` are sent together as one logical message.
    fn try_inline_multi_frame(
        &self,
        frames: &[&[u8]],
        cap: Option<usize>,
    ) -> Option<io::Result<()>>;
}
/// True when `payload_len` is at or above the optional cap.
///
/// A payload exactly at the cap counts as exceeding it (`>=`), mirroring the
/// `>= FAST_PATH_MAX_PAYLOAD` checks used elsewhere in the fast path; `None`
/// means "no cap".
#[inline]
fn cap_exceeded(payload_len: usize, cap: Option<usize>) -> bool {
    match cap {
        Some(limit) => payload_len >= limit,
        None => false,
    }
}
impl<W: AsyncVectoredWrite + Send + Sync + 'static> InlineWriteTarget for SharedHalf<W> {
    /// Attempts to write a single-frame message directly from the caller's
    /// task, bypassing the writer loop.
    ///
    /// Returns `None` when the fast path must not be taken (cap exceeded,
    /// writer loop has queued data or is busy, overflow pending, lock
    /// contended, or the socket would block) — the caller then falls back to
    /// the normal queue. Returns `Some(Ok(()))` when the bytes were fully
    /// written *or* the unwritten tail was parked in the overflow queue after
    /// a partial write (the writer loop finishes it), and `Some(Err(_))` on
    /// hard I/O errors.
    fn try_inline_single_frame(
        &self,
        payload: &[u8],
        cap: Option<usize>,
    ) -> Option<io::Result<()>> {
        if cap_exceeded(payload.len(), cap) {
            return None;
        }
        // Racy pre-checks: cheap rejection before paying for header encoding
        // and the lock. They are re-validated under the lock below.
        if self.pending_len.load(Ordering::Acquire) > 0 {
            return None;
        }
        if self.peer_loop_busy.load(Ordering::Acquire) {
            return None;
        }
        if self.has_overflow() {
            return None;
        }
        let (hdr_bytes, hdr_len) = encode_frame_header(payload.len(), false);
        let hdr_len_usize = hdr_len as usize;
        let total = hdr_len_usize + payload.len();
        let iovs: [IoSlice<'_>; 2] = [
            IoSlice::new(&hdr_bytes[..hdr_len_usize]),
            IoSlice::new(payload),
        ];
        // Never block the caller: if someone else holds the socket, bail out.
        let _guard = self.write_lock.try_lock()?;
        // Re-check under the lock: queued data, a busy writer loop, or a
        // partially written overflow frame means inlining now would reorder
        // or interleave bytes on the wire.
        if self.pending_len.load(Ordering::Acquire) > 0
            || self.peer_loop_busy.load(Ordering::Acquire)
            || self.has_overflow()
        {
            return None;
        }
        match self.half.try_write_vectored(&iovs) {
            Ok(n) if n == total => Some(Ok(())),
            Ok(0) => Some(Err(io::Error::from(CodecError::WriteZero))),
            Ok(n) => {
                // Partial write: the wire now holds a half-written frame, so
                // its remainder must go out before anything else. Copy the
                // payload (the borrow ends here), park it in the overflow
                // queue, and wake the writer loop.
                let header = HeaderBuf {
                    bytes: hdr_bytes,
                    len: hdr_len,
                };
                let payload_owned = Bytes::copy_from_slice(payload);
                let pending = PendingMsg::from_single_frame_partial(header, payload_owned, n);
                {
                    let mut g = self.pending_overflow.lock();
                    g.push_back(pending);
                    self.pending_overflow_len.store(g.len(), Ordering::Release);
                }
                use crate::async_rt::notify::AsyncNotify;
                self.overflow_notify.notify_one();
                Some(Ok(()))
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
            Err(e) => Some(Err(e)),
        }
    }
    /// Multi-frame analogue of `try_inline_single_frame` for messages of up
    /// to `FAST_PATH_MAX_FRAMES` frames, written as one vectored call.
    ///
    /// Unlike the single-frame path there is no partial-resume support for
    /// borrowed frames: a short write is surfaced as `WriteZero`.
    /// NOTE(review): treating a legitimate partial write as a hard error can
    /// tear down the session under backpressure — consider copying the
    /// remainder into the overflow queue like the single-frame path.
    fn try_inline_multi_frame(
        &self,
        frames: &[&[u8]],
        cap: Option<usize>,
    ) -> Option<io::Result<()>> {
        let n_frames = frames.len();
        if !(1..=FAST_PATH_MAX_FRAMES).contains(&n_frames) {
            return None;
        }
        let payload_total: usize = frames.iter().map(|f| f.len()).sum();
        if cap_exceeded(payload_total, cap) {
            return None;
        }
        // Racy pre-checks mirroring the single-frame path, including the
        // overflow check: a partially written frame in the overflow queue
        // means writing anything else now would corrupt the stream.
        if self.pending_len.load(Ordering::Acquire) > 0 {
            return None;
        }
        if self.peer_loop_busy.load(Ordering::Acquire) {
            return None;
        }
        if self.has_overflow() {
            return None;
        }
        // Encode all frame headers; MORE set on every frame but the last.
        let mut hdr_bufs: [[u8; 9]; FAST_PATH_MAX_FRAMES] = [[0u8; 9]; FAST_PATH_MAX_FRAMES];
        let mut hdr_lens: [u8; FAST_PATH_MAX_FRAMES] = [0u8; FAST_PATH_MAX_FRAMES];
        let mut total: usize = 0;
        for (i, frame) in frames.iter().enumerate() {
            let more = i + 1 < n_frames;
            let (buf, len) = encode_frame_header(frame.len(), more);
            hdr_bufs[i] = buf;
            hdr_lens[i] = len;
            total += len as usize + frame.len();
        }
        // Two slices per frame: header then payload. IoSlice is Copy, so the
        // placeholder array can be built with repeat syntax.
        let mut iov_storage: [IoSlice<'_>; FAST_PATH_MAX_FRAMES * 2] =
            [IoSlice::new(&[]); FAST_PATH_MAX_FRAMES * 2];
        for (i, frame) in frames.iter().enumerate() {
            let hdr_len_usize = hdr_lens[i] as usize;
            iov_storage[i * 2] = IoSlice::new(&hdr_bufs[i][..hdr_len_usize]);
            iov_storage[i * 2 + 1] = IoSlice::new(frame);
        }
        let iovs = &iov_storage[..n_frames * 2];
        // Never block the caller: contended lock means "not taken".
        let _guard = self.write_lock.try_lock()?;
        // Re-check under the lock (previously the overflow check was missing
        // here, allowing an inline write to interleave after a half-written
        // overflow frame).
        if self.pending_len.load(Ordering::Acquire) > 0
            || self.peer_loop_busy.load(Ordering::Acquire)
            || self.has_overflow()
        {
            return None;
        }
        match self.half.try_write_vectored(iovs) {
            Ok(n) if n == total => Some(Ok(())),
            Ok(0) => Some(Err(io::Error::from(CodecError::WriteZero))),
            Ok(_partial) => {
                // See NOTE(review) above: partial writes are not resumed.
                Some(Err(io::Error::from(CodecError::WriteZero)))
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
            Err(e) => Some(Err(e)),
        }
    }
}
/// The writer-loop side of a connection: owns the pending message queue and
/// shares the socket (via `SharedHalf`) with inline fast-path writers.
pub(crate) struct VectoredWriter<W> {
    // Socket + coordination state shared with inline writers.
    shared: Arc<SharedHalf<W>>,
    // Messages queued for transmission, front = oldest / currently flushing.
    pending: VecDeque<PendingMsg>,
}
impl<W: AsyncVectoredWrite + Send + Sync + 'static> VectoredWriter<W> {
    /// Hands out the shared half as a trait object so producer tasks can
    /// attempt inline fast-path writes.
    pub(crate) fn inline_write_target(&self) -> Arc<dyn InlineWriteTarget> {
        self.shared.clone()
    }
}
impl<W: AsyncVectoredWrite> VectoredWriter<W> {
    /// Creates a writer around a transport half with an empty queue.
    pub(crate) fn new(half: W) -> Self {
        VectoredWriter {
            shared: Arc::new(SharedHalf::new(half)),
            pending: VecDeque::with_capacity(PENDING_CAPACITY),
        }
    }
    /// Pulls any overflow (partially written inline messages) to the front of
    /// the local queue and republishes the queue length for inline writers.
    pub(crate) fn pull_inline_overflow(&mut self) {
        if self.shared.has_overflow() {
            self.shared.drain_overflow_into(&mut self.pending);
            self.shared
                .pending_len
                .store(self.pending.len(), Ordering::Release);
        }
    }
    /// Returns an owned future that resolves when an inline writer has queued
    /// overflow; usable from a `select!` without borrowing `self`.
    pub(crate) fn overflow_notified(&self) -> impl std::future::Future<Output = ()> + Send + 'static
    where
        W: Send + Sync + 'static,
    {
        use crate::async_rt::notify::AsyncNotify;
        let shared = self.shared.clone();
        async move { shared.overflow_notify.notified().await }
    }
    /// Signals inline writers that the loop is actively flushing.
    #[inline]
    pub(crate) fn mark_peer_loop_busy(&self) {
        self.shared.mark_peer_loop_busy();
    }
    /// Clears the busy flag, re-enabling the inline fast path.
    #[inline]
    pub(crate) fn clear_peer_loop_busy(&self) {
        self.shared.clear_peer_loop_busy();
    }
    /// Encodes `msg` onto the back of the pending queue and republishes the
    /// queue length. Control messages (greeting/command/heartbeat/security)
    /// are queued as pre-encoded raw bytes with no frame header.
    pub(crate) fn enqueue(&mut self, msg: Message) {
        match msg {
            Message::Message(m) => self.pending.push_back(PendingMsg::from_zmq_message(&m)),
            Message::Shared(arc) => self
                .pending
                .push_back(PendingMsg::from_zmq_message(arc.as_ref())),
            Message::Greeting(g) => {
                // unsplit onto an empty BytesMut is effectively a move.
                let mut buf = bytes::BytesMut::new();
                buf.unsplit(g.into());
                self.pending
                    .push_back(PendingMsg::from_raw_bytes(buf.freeze()));
            }
            Message::Command(c) => {
                let mut buf = bytes::BytesMut::new();
                buf.unsplit(c.into());
                self.pending
                    .push_back(PendingMsg::from_raw_bytes(buf.freeze()));
            }
            Message::Heartbeat(hb) => {
                let encoded: bytes::BytesMut = hb.into();
                self.pending
                    .push_back(PendingMsg::from_raw_bytes(encoded.freeze()));
            }
            Message::SecurityRaw(raw) => {
                self.pending.push_back(PendingMsg::from_raw_bytes(raw));
            }
        }
        self.shared
            .pending_len
            .store(self.pending.len(), Ordering::Release);
    }
    /// True when nothing is queued for transmission.
    pub(crate) fn is_empty(&self) -> bool {
        self.pending.is_empty()
    }
    /// Writer-loop fast path: try to push a small single-frame message
    /// straight to the socket without queueing.
    ///
    /// Returns `Sent` on a complete write, `Enqueued` when only part went out
    /// (the message is queued and the written prefix accounted via `advance`),
    /// and `NotTaken` for multi-frame/oversized messages or `WouldBlock`.
    /// NOTE(review): unlike the inline path this takes `write_lock` blocking
    /// and does not consult the overflow queue — presumably the caller drains
    /// overflow first; confirm against the engine loop.
    #[inline]
    pub(crate) fn try_fast_path_single_frame(&mut self, msg: Message) -> io::Result<FastPath> {
        debug_assert!(
            self.pending.is_empty(),
            "fast path requires empty pending queue",
        );
        // Only 1-frame Message/Shared below the size cap qualify.
        let payload_slice: &[u8] = match &msg {
            Message::Message(m) if m.len() == 1 => {
                let f = m.get(0).expect("len==1");
                if f.len() >= FAST_PATH_MAX_PAYLOAD {
                    return Ok(FastPath::NotTaken(msg));
                }
                f.as_ref()
            }
            Message::Shared(arc) if arc.len() == 1 => {
                let f = arc.get(0).expect("len==1");
                if f.len() >= FAST_PATH_MAX_PAYLOAD {
                    return Ok(FastPath::NotTaken(msg));
                }
                f.as_ref()
            }
            _ => return Ok(FastPath::NotTaken(msg)),
        };
        let payload_len = payload_slice.len();
        let (hdr_bytes, hdr_len) = encode_frame_header(payload_len, false);
        let hdr_len_usize = hdr_len as usize;
        let total = hdr_len_usize + payload_len;
        let iovs: [IoSlice<'_>; 2] = [
            IoSlice::new(&hdr_bytes[..hdr_len_usize]),
            IoSlice::new(payload_slice),
        ];
        let write_result = {
            let _g = self.shared.write_lock.lock();
            self.shared.half.try_write_vectored(&iovs)
        };
        match write_result {
            Ok(0) => Err(io::Error::from(CodecError::WriteZero)),
            Ok(n) if n == total => Ok(FastPath::Sent),
            Ok(n) => {
                // Partial write: queue the whole message, then advance by the
                // bytes already on the wire so cursors line up.
                let zmsg = match msg {
                    Message::Message(m) => m,
                    Message::Shared(arc) => (*arc).clone(),
                    _ => unreachable!(),
                };
                self.pending.push_back(PendingMsg::from_zmq_message(&zmsg));
                self.shared
                    .pending_len
                    .store(self.pending.len(), Ordering::Release);
                let _ = self.advance(n);
                Ok(FastPath::Enqueued)
            }
            Err(e) if e.kind() == io::ErrorKind::WouldBlock => Ok(FastPath::NotTaken(msg)),
            Err(e) => Err(e),
        }
    }
    /// Drains up to `max_msgs` messages / `max_bytes` payload bytes from the
    /// outbound channel into the pending queue; returns how many were taken.
    /// At least one message is always accepted before limits are applied.
    pub(crate) fn drain_batch(
        &mut self,
        outbound: &flume::Receiver<crate::engine::Outbound>,
        max_bytes: Option<usize>,
        max_msgs: Option<usize>,
    ) -> usize {
        use flume::TryRecvError;
        let mut count = 0;
        let mut accumulated: usize = 0;
        loop {
            // Limits are only checked after the first message so a single
            // oversized message cannot stall the batch forever.
            if count > 0 {
                let byte_limit_hit = max_bytes.is_some_and(|m| accumulated >= m);
                let msg_limit_hit = max_msgs.is_some_and(|m| count >= m);
                if byte_limit_hit || msg_limit_hit {
                    break;
                }
            }
            match outbound.try_recv() {
                Ok(o) => {
                    accumulated += msg_payload_size(&o.msg);
                    self.enqueue(o.msg);
                    count += 1;
                }
                Err(TryRecvError::Empty | TryRecvError::Disconnected) => break,
            }
        }
        count
    }
    /// Performs one vectored write over the queued data. Returns the number
    /// of *whole* messages flushed (0 on `WouldBlock` or an empty queue);
    /// a zero-byte write is surfaced as `WriteZero`.
    pub(crate) fn flush_one_pass(&mut self) -> std::io::Result<usize> {
        if self.pending.is_empty() {
            return Ok(0);
        }
        let iovs: smallvec::SmallVec<[std::io::IoSlice<'_>; IOV_CAP]> = build_iovs(&self.pending);
        let write_result = {
            let _g = self.shared.write_lock.lock();
            self.shared.half.try_write_vectored(&iovs)
        };
        match write_result {
            Ok(0) => Err(io::Error::from(CodecError::WriteZero)),
            Ok(n) => {
                // iovs borrows `pending`; drop before mutating via advance.
                drop(iovs);
                Ok(self.advance(n) as usize)
            }
            Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                drop(iovs);
                Ok(0)
            }
            Err(e) => Err(e),
        }
    }
    /// Returns an owned future resolving on socket write readiness; usable
    /// from a `select!` without borrowing `self`.
    pub(crate) fn writable_owned(
        &self,
    ) -> impl std::future::Future<Output = std::io::Result<()>> + Send + 'static
    where
        W: Send + Sync + 'static,
    {
        let shared = self.shared.clone();
        async move { shared.half.writable().await }
    }
    /// Test helper: loops write/await-writable until the queue is empty,
    /// returning the total count of whole messages flushed.
    #[cfg(all(test, feature = "tokio"))]
    pub(crate) async fn flush_all(&mut self) -> io::Result<u64> {
        let mut whole: u64 = 0;
        while !self.pending.is_empty() {
            let iovs: SmallVec<[IoSlice<'_>; IOV_CAP]> = build_iovs(&self.pending);
            debug_assert!(
                !iovs.is_empty(),
                "non-empty pending must yield at least one slice"
            );
            let write_result = {
                let _g = self.shared.write_lock.lock();
                self.shared.half.try_write_vectored(&iovs)
            };
            match write_result {
                Ok(0) => {
                    return Err(io::Error::from(CodecError::WriteZero));
                }
                Ok(n) => {
                    drop(iovs);
                    whole += self.advance(n);
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
                    drop(iovs);
                    self.shared.half.writable().await?;
                }
                Err(e) => return Err(e),
            }
        }
        Ok(whole)
    }
    /// Advances the queue cursors past `n` freshly written bytes, consuming
    /// header bytes then payload bytes per frame, removing drained frames and
    /// messages. Returns the count of whole messages completed and
    /// republishes the queue length for inline writers.
    fn advance(&mut self, mut n: usize) -> u64 {
        let mut whole: u64 = 0;
        while n > 0 {
            let front = self
                .pending
                .front_mut()
                .expect("advance called with pending empty");
            while n > 0 && !front.frames.is_empty() {
                let frame = &mut front.frames[0];
                // Header bytes go out before payload bytes.
                let hdr_left = (frame.header.len - frame.header_pos) as usize;
                if hdr_left > 0 {
                    let take = n.min(hdr_left);
                    frame.header_pos += take as u8;
                    n -= take;
                    if n == 0 {
                        break;
                    }
                }
                let pay_left = frame.payload.len();
                if pay_left > 0 {
                    let take = n.min(pay_left);
                    frame.payload.advance(take);
                    n -= take;
                }
                if frame.is_drained() {
                    front.frames.remove(0);
                } else {
                    debug_assert_eq!(n, 0, "partial frame must leave n == 0");
                    break;
                }
            }
            if front.frames.is_empty() {
                self.pending.pop_front();
                whole += 1;
            }
        }
        self.shared
            .pending_len
            .store(self.pending.len(), Ordering::Release);
        whole
    }
}
/// Total payload byte count of `msg`, used for drain-batch byte accounting.
///
/// Control traffic (greeting / command / heartbeat) counts as zero — it is
/// not subject to the batching byte budget.
fn msg_payload_size(msg: &Message) -> usize {
    match msg {
        Message::Message(zmsg) => zmsg.iter().fold(0usize, |acc, frame| acc + frame.len()),
        Message::Shared(zmsg) => zmsg.iter().fold(0usize, |acc, frame| acc + frame.len()),
        Message::SecurityRaw(raw) => raw.len(),
        Message::Greeting(_) | Message::Command(_) | Message::Heartbeat(_) => 0,
    }
}
/// Collects up to `IOV_CAP` `IoSlice`s covering the unwritten bytes of
/// `pending`, front (oldest) message first.
///
/// Each frame contributes at most two slices: its remaining header bytes and
/// its remaining payload. The capacity check happens *before* every push, so
/// the result never exceeds `IOV_CAP`; when the cap is hit mid-frame we stop
/// early — `advance` copes with the partially covered tail on the next pass.
fn build_iovs(pending: &VecDeque<PendingMsg>) -> SmallVec<[IoSlice<'_>; IOV_CAP]> {
    let mut slices: SmallVec<[IoSlice<'_>; IOV_CAP]> = SmallVec::new();
    for frame in pending.iter().flat_map(|msg| msg.frames.iter()) {
        let hdr_written = frame.header_pos as usize;
        let hdr_total = frame.header.len as usize;
        if hdr_written < hdr_total {
            if slices.len() == IOV_CAP {
                return slices;
            }
            slices.push(IoSlice::new(&frame.header.bytes[hdr_written..hdr_total]));
        }
        if !frame.payload.is_empty() {
            if slices.len() == IOV_CAP {
                return slices;
            }
            slices.push(IoSlice::new(frame.payload.as_ref()));
        }
    }
    slices
}
// Integration-style tests: a real localhost TCP pair, with the writer on the
// server side and a FramedRead decoder on the client side verifying the wire.
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::codec::Message;
    use crate::message::ZmqMessage;
    use bytes::Bytes;
    use futures::StreamExt;
    use tokio::net::{TcpListener, TcpStream};
    // Helper: wrap the write half of an accepted TCP stream in a VectoredWriter.
    fn engine_half_from_tcp(tcp: TcpStream) -> VectoredWriter<tokio::net::tcp::OwnedWriteHalf> {
        let (_r, w) = tcp.into_split();
        VectoredWriter::new(w)
    }
    // 1000 small single-frame messages survive the queue + vectored flush and
    // decode back in order.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn vectored_writer_roundtrip() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        use crate::codec::ZmqCodec;
        let (read_half, _write_half) = client.into_split();
        let codec = ZmqCodec::post_greeting();
        let mut reader = tokio_util::codec::FramedRead::new(read_half, codec);
        for i in 0..1000u32 {
            let msg = ZmqMessage::from(Bytes::from(i.to_be_bytes().to_vec()));
            writer.enqueue(Message::Message(msg));
        }
        let flushed = writer.flush_all().await.unwrap();
        assert_eq!(flushed, 1000);
        assert!(writer.is_empty());
        for i in 0..1000u32 {
            match reader.next().await.expect("stream closed").unwrap() {
                Message::Message(m) => {
                    let frame = m.get(0).expect("frame").clone();
                    assert_eq!(&frame[..], &i.to_be_bytes()[..]);
                }
                other => panic!("unexpected variant: {:?}", other),
            }
        }
    }
    // 4-frame messages with mixed frame sizes must arrive intact (frames of
    // one message never interleave with another).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn vectored_writer_multiframe_atomicity() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        use crate::codec::ZmqCodec;
        let (read_half, _w) = client.into_split();
        let codec = ZmqCodec::post_greeting();
        let mut reader = tokio_util::codec::FramedRead::new(read_half, codec);
        let sizes = [1usize, 256, 4096, 65_536];
        for i in 0..10u32 {
            let mut m = ZmqMessage::from(Bytes::from(format!("start-{}", i).into_bytes()));
            for (idx, &sz) in sizes.iter().enumerate() {
                if idx == 0 {
                    continue;
                }
                m.push_back(Bytes::from(vec![idx as u8 ^ i as u8; sz]));
            }
            assert_eq!(m.len(), 4);
            writer.enqueue(Message::Message(m));
        }
        // Read concurrently so the writer doesn't stall on a full socket buffer.
        let reader_task = tokio::spawn(async move {
            let mut received = Vec::with_capacity(10);
            for _ in 0..10u32 {
                match reader.next().await.expect("closed").unwrap() {
                    Message::Message(m) => received.push(m),
                    other => panic!("unexpected variant: {:?}", other),
                }
            }
            received
        });
        let flushed = writer.flush_all().await.unwrap();
        assert_eq!(flushed, 10);
        let received = reader_task.await.unwrap();
        for (i, m) in received.into_iter().enumerate() {
            assert_eq!(m.len(), 4, "message {} frame count", i);
            assert_eq!(
                m.get(0).unwrap().as_ref(),
                format!("start-{}", i).as_bytes()
            );
            assert_eq!(m.get(1).unwrap().len(), 256);
            assert_eq!(m.get(2).unwrap().len(), 4096);
            assert_eq!(m.get(3).unwrap().len(), 65_536);
        }
    }
    // A 128 KiB payload forces partial writes; the advance/retry loop must
    // still deliver the payload byte-exact.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn vectored_writer_partial_write() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        server.set_nodelay(true).ok();
        client.set_nodelay(true).ok();
        let mut writer = engine_half_from_tcp(server);
        use crate::codec::ZmqCodec;
        let (read_half, _w) = client.into_split();
        let codec = ZmqCodec::post_greeting();
        let mut reader = tokio_util::codec::FramedRead::new(read_half, codec);
        let payload = Bytes::from(vec![0xa5u8; 128 * 1024]);
        let msg = ZmqMessage::from(payload.clone());
        writer.enqueue(Message::Message(msg));
        let (write_res, read_res) = tokio::join!(writer.flush_all(), reader.next());
        assert_eq!(write_res.unwrap(), 1);
        match read_res.expect("closed").unwrap() {
            Message::Message(m) => {
                assert_eq!(m.len(), 1);
                assert_eq!(m.get(0).unwrap().as_ref(), payload.as_ref());
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }
    // Three 64 KiB messages queued together: partial writes may land on
    // message boundaries mid-queue, but content and order must be preserved.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn vectored_writer_interleaved_partial() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        use crate::codec::ZmqCodec;
        let (read_half, _w) = client.into_split();
        let codec = ZmqCodec::post_greeting();
        let mut reader = tokio_util::codec::FramedRead::new(read_half, codec);
        for i in 0u32..3 {
            let payload = Bytes::from(vec![i as u8 + 1; 64 * 1024]);
            writer.enqueue(Message::Message(ZmqMessage::from(payload)));
        }
        let (write_res, r0) = tokio::join!(writer.flush_all(), reader.next());
        assert_eq!(write_res.unwrap(), 3);
        let mut seen = Vec::new();
        match r0.unwrap().unwrap() {
            Message::Message(m) => seen.push(m.get(0).unwrap().clone()),
            _ => unreachable!(),
        }
        for _ in 0..2 {
            match reader.next().await.unwrap().unwrap() {
                Message::Message(m) => seen.push(m.get(0).unwrap().clone()),
                _ => unreachable!(),
            }
        }
        for (i, frame) in seen.iter().enumerate() {
            assert_eq!(frame.len(), 64 * 1024);
            assert!(frame.iter().all(|&b| b == (i as u8 + 1)));
        }
    }
    // Flushing into a closed peer must eventually surface an I/O error
    // instead of looping forever.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn vectored_writer_peer_close_mid_flush() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        drop(client);
        for _ in 0..128 {
            writer.enqueue(Message::Message(ZmqMessage::from(Bytes::from(vec![
                0xau8;
                4096
            ]))));
        }
        let res = writer.flush_all().await;
        assert!(res.is_err(), "flush_all should surface peer-close error");
    }
    // A small single-frame message on an idle, writable socket takes the fast
    // path (`Sent`) without touching the queue.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fast_path_single_small_frame_sent() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        use crate::codec::ZmqCodec;
        let (read_half, _w) = client.into_split();
        let codec = ZmqCodec::post_greeting();
        let mut reader = tokio_util::codec::FramedRead::new(read_half, codec);
        writer.writable_owned().await.unwrap();
        let payload = Bytes::from(vec![0x77u8; 16]);
        let msg = Message::Message(ZmqMessage::from(payload.clone()));
        match writer.try_fast_path_single_frame(msg).unwrap() {
            FastPath::Sent => {}
            other => panic!("expected Sent, got {:?}", other),
        }
        assert!(writer.is_empty(), "fast path must not queue");
        match reader.next().await.expect("closed").unwrap() {
            Message::Message(m) => {
                assert_eq!(m.len(), 1);
                assert_eq!(m.get(0).unwrap().as_ref(), payload.as_ref());
            }
            other => panic!("unexpected variant: {:?}", other),
        }
    }
    // The size cap is exclusive at FAST_PATH_MAX_PAYLOAD: cap-1 is eligible,
    // exactly cap is NotTaken.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fast_path_threshold_boundary() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (accept_res, connect_res) =
            futures::join!(listener.accept(), TcpStream::connect(addr),);
        let (server, _) = accept_res.unwrap();
        let _client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        writer.writable_owned().await.unwrap();
        let ok_msg = Message::Message(ZmqMessage::from(Bytes::from(vec![
            0x11u8;
            FAST_PATH_MAX_PAYLOAD - 1
        ])));
        match writer.try_fast_path_single_frame(ok_msg).unwrap() {
            FastPath::Sent | FastPath::Enqueued => {}
            other @ FastPath::NotTaken(_) => {
                panic!("expected Sent/Enqueued at cap-1 B, got {:?}", other)
            }
        }
        let _ = writer.flush_all().await;
        let big_msg = Message::Message(ZmqMessage::from(Bytes::from(vec![
            0x22u8;
            FAST_PATH_MAX_PAYLOAD
        ])));
        match writer.try_fast_path_single_frame(big_msg).unwrap() {
            FastPath::NotTaken(_) => {}
            other => panic!("expected NotTaken at cap, got {:?}", other),
        }
        assert!(writer.is_empty(), "NotTaken must not touch queue");
    }
    // Multi-frame messages are never eligible for the single-frame fast path.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fast_path_skipped_for_multiframe() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (accept_res, connect_res) =
            futures::join!(listener.accept(), TcpStream::connect(addr),);
        let (server, _) = accept_res.unwrap();
        let _client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        let mut m = ZmqMessage::from(Bytes::from(vec![0xaau8; 16]));
        m.push_back(Bytes::from(vec![0xbbu8; 16]));
        assert_eq!(m.len(), 2);
        let msg = Message::Message(m);
        match writer.try_fast_path_single_frame(msg).unwrap() {
            FastPath::NotTaken(_) => {}
            _ => panic!("2-frame must not fast-path"),
        }
        assert!(writer.is_empty());
    }
    // Message::Shared (Arc-wrapped) single frames are fast-path eligible too.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fast_path_shared_variant_sent() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (accept_res, connect_res) =
            futures::join!(listener.accept(), TcpStream::connect(addr),);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut writer = engine_half_from_tcp(server);
        use crate::codec::ZmqCodec;
        let (read_half, _w) = client.into_split();
        let codec = ZmqCodec::post_greeting();
        let mut reader = tokio_util::codec::FramedRead::new(read_half, codec);
        writer.writable_owned().await.unwrap();
        let payload = Bytes::from(vec![0x3cu8; 32]);
        let arc = std::sync::Arc::new(ZmqMessage::from(payload.clone()));
        let msg = Message::Shared(arc);
        assert!(matches!(
            writer.try_fast_path_single_frame(msg).unwrap(),
            FastPath::Sent
        ));
        match reader.next().await.expect("closed").unwrap() {
            Message::Message(m) => {
                assert_eq!(m.get(0).unwrap().as_ref(), payload.as_ref());
            }
            other => panic!("unexpected: {:?}", other),
        }
    }
    // After the peer closes, repeated fast-path attempts must eventually
    // report an error (TCP may accept a few writes before RST is observed).
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn fast_path_peer_close_errors() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let (accept_res, connect_res) =
            futures::join!(listener.accept(), TcpStream::connect(addr),);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        drop(client);
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let mut writer = engine_half_from_tcp(server);
        let _ = writer.try_fast_path_single_frame(Message::Message(ZmqMessage::from(Bytes::from(
            vec![0u8; 8],
        ))));
        let mut saw_err = false;
        for _ in 0..64 {
            if writer
                .try_fast_path_single_frame(Message::Message(ZmqMessage::from(Bytes::from(
                    vec![0u8; 8],
                ))))
                .is_err()
            {
                saw_err = true;
                break;
            }
            tokio::time::sleep(std::time::Duration::from_millis(5)).await;
        }
        assert!(saw_err, "expected I/O error once peer-close is detected");
    }
}