use crate::async_rt::notify::AsyncNotify;
use crate::async_rt::time::{DefaultClock, RuntimeClock};
use crate::codec::{CodecError, Message};
use crate::engine::registry::PeerKey;
use crate::engine::writer::VectoredWriter;
use crate::engine::{FlushState, Outbound, TaggedInboundTx};
use crate::io_compat::AsyncVectoredWrite;
#[cfg(feature = "curve")]
use crate::mechanism::{build_nonce, CurveSession};
#[cfg(feature = "curve")]
use crate::message::ZmqMessage;
#[cfg(feature = "curve")]
use bytes::Bytes;
#[cfg(feature = "curve")]
use crypto_box::aead::Aead;
use futures::channel::oneshot;
use futures::future::Shared;
use futures::{FutureExt, Stream, StreamExt};
use rand::RngExt;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use crate::engine::HeartbeatConfig;
/// Writer type driven by the peer loop; an alias for [`VectoredWriter`] over
/// the same underlying write half `W`.
pub(crate) type PeerWriterKind<W> = VectoredWriter<W>;
/// Single-entry mailbox used by the receive path when
/// `PeerConfig::conflate_slot` is set.
///
/// `handle_read` overwrites `slot` with the newest `(peer, message)` pair and
/// then calls `notify.notify_one()`; it does not use the shared inbound
/// channel for that peer.
pub(crate) struct ConflateSlotInner {
/// Latest message together with the key of the peer it was read from.
/// Each delivery replaces whatever was stored before.
pub slot: parking_lot::Mutex<Option<(PeerKey, crate::message::ZmqMessage)>>,
/// Signalled once after every write to `slot`.
pub notify: Arc<crate::async_rt::notify::RuntimeNotify>,
}
/// Reference-counted handle to a [`ConflateSlotInner`].
pub(crate) type ConflateSlot = Arc<ConflateSlotInner>;
/// Per-peer tuning knobs consumed by `peer_loop`.
pub(crate) struct PeerConfig {
/// When `Some`, the loop creates a `HeartbeatState` from it and runs the
/// heartbeat timer; when `None` the heartbeat branch never fires.
pub heartbeat: Option<HeartbeatConfig>,
/// Session used to encrypt outgoing and decrypt incoming `Message::Message`
/// frames. `None` leaves traffic untouched.
#[cfg(feature = "curve")]
pub curve: Option<CurveSession>,
/// Upper bound on the summed frame lengths of an inbound message; a larger
/// message is replaced by a decode error. `None` disables the check.
pub max_msg_size: Option<usize>,
/// When `Some`, inbound messages are stored in this slot instead of being
/// sent on the shared inbound channel.
pub conflate_slot: Option<ConflateSlot>,
/// Forwarded unchanged to `VectoredWriter::drain_batch`.
/// NOTE(review): presumably a byte budget per batch — confirm in the writer.
pub out_batch_size: Option<usize>,
/// Forwarded unchanged to `VectoredWriter::drain_batch`.
/// NOTE(review): presumably a message-count budget per batch — confirm in the writer.
pub out_batch_msgs: Option<usize>,
/// Maximum number of inbound messages handled per read wake-up. `None` and
/// `Some(0)` both behave as 1.
pub in_batch_msgs: Option<usize>,
/// The peer loop only checks whether the outer `Option` is `Some` to decide
/// if the writer's inline-write hooks are called; the inner value is not
/// read in this module.
#[allow(clippy::option_option)]
pub inline_write_max: Option<Option<usize>>,
}
impl Default for PeerConfig {
    /// Returns a configuration with every optional feature switched off and
    /// only the outbound batching limits populated.
    fn default() -> Self {
        // Outbound batching limits handed to the writer's `drain_batch`.
        const DEFAULT_OUT_BATCH_SIZE: usize = 8192;
        const DEFAULT_OUT_BATCH_MSGS: usize = 32;

        Self {
            heartbeat: None,
            #[cfg(feature = "curve")]
            curve: None,
            max_msg_size: None,
            conflate_slot: None,
            out_batch_size: Some(DEFAULT_OUT_BATCH_SIZE),
            out_batch_msgs: Some(DEFAULT_OUT_BATCH_MSGS),
            in_batch_msgs: None,
            inline_write_max: None,
        }
    }
}
/// Channel endpoints handed to a peer loop.
pub(crate) struct PeerChannels {
/// Source of items to write to this peer. The loop ends once this channel
/// reports an error on receive.
pub outbound_rx: flume::Receiver<Outbound>,
/// Sink for inbound results; every item is sent as `(PeerKey, Result<..>)`.
pub shared_inbound: TaggedInboundTx,
}
/// Runs the I/O loop for one peer until it terminates, then publishes the
/// writer's death.
///
/// Whatever `peer_loop_inner` returns (success or I/O error) is discarded.
/// Afterwards `flush_state.writer_alive` is cleared and flush waiters are
/// woken so nobody keeps waiting on a writer that no longer exists.
pub(crate) async fn peer_loop<R, W>(
    read_half: R,
    writer: PeerWriterKind<W>,
    channels: PeerChannels,
    peer_key: PeerKey,
    flush_state: Arc<FlushState>,
    shutdown: Shared<oneshot::Receiver<()>>,
    config: PeerConfig,
) where
    R: Stream<Item = Result<Message, CodecError>> + Unpin + Send + 'static,
    W: AsyncVectoredWrite + Send + 'static,
{
    // The outcome is intentionally ignored: both exits lead to the same cleanup.
    let _outcome = peer_loop_inner(
        read_half,
        writer,
        channels,
        peer_key,
        &flush_state,
        shutdown,
        config,
    )
    .await;

    // Mark the writer dead first, then wake waiters so they observe the flag.
    flush_state.writer_alive.store(false, Ordering::Release);
    flush_state.notify_flush_waiters();
}
/// Event loop for a single peer connection.
///
/// Every iteration first tries to flush queued writer data, then waits for
/// the first of: writer writable (only while data is still queued), an
/// inbound message, an outbound item (only while nothing is queued), the
/// heartbeat timer, the writer's inline-overflow notification, or shutdown.
///
/// Returns `Ok(())` when the loop ends in an orderly way: shutdown fired, the
/// outbound channel reported an error, the heartbeat requested eviction, or
/// `process_one_read_message` returned `false`. Returns `Err` when a writer
/// operation, read processing, or CURVE encryption fails.
///
/// NOTE(review): `W` is additionally bounded by `Sync` here while `peer_loop`
/// only requires `Send + 'static` — confirm that `AsyncVectoredWrite` implies
/// `Sync`, otherwise the call from `peer_loop` relies on something not visible
/// in this file.
async fn peer_loop_inner<R, W>(
mut read_half: R,
mut writer: PeerWriterKind<W>,
channels: PeerChannels,
peer_key: PeerKey,
flush_state: &Arc<FlushState>,
shutdown: Shared<oneshot::Receiver<()>>,
config: PeerConfig,
) -> std::io::Result<()>
where
R: Stream<Item = Result<Message, CodecError>> + Unpin,
W: AsyncVectoredWrite + Send + Sync + 'static,
{
let PeerChannels {
outbound_rx,
shared_inbound,
} = channels;
let PeerConfig {
heartbeat: heartbeat_cfg,
#[cfg(feature = "curve")]
curve,
max_msg_size,
conflate_slot,
out_batch_size,
out_batch_msgs,
in_batch_msgs,
inline_write_max,
} = config;
// Only the presence of the outer Option matters here; the inner value is
// not read in this function.
let inline_enabled = inline_write_max.is_some();
let mut hb: Option<HeartbeatState<DefaultClock>> = heartbeat_cfg.map(HeartbeatState::new);
#[cfg(feature = "curve")]
let mut curve = curve;
let mut shutdown = shutdown.fuse();
// Decrypted frames of a multi-frame CURVE message that is not complete yet.
#[cfg(feature = "curve")]
let mut curve_recv_buf: Vec<Bytes> = Vec::new();
loop {
crate::wake_counter::bump(&crate::wake_counter::PEER_LOOP_ITERS);
if inline_enabled {
writer.pull_inline_overflow();
}
// Push queued data out before waiting and report progress to flush waiters.
if !writer.is_empty() {
let flushed = writer.flush_one_pass()?;
if flushed > 0 {
flush_state
.flushed
.fetch_add(flushed as u64, Ordering::Release);
flush_state.notify_flush_waiters();
}
}
// While data is still queued we wait for writability and stop accepting
// new outbound items; the disabled side is a never-ready future so the
// select below always has the same set of arms.
let pending_drain = !writer.is_empty();
use futures::future::Either;
let writable_fut = if pending_drain {
Either::Left(writer.writable_owned())
} else {
Either::Right(std::future::pending::<std::io::Result<()>>())
};
let outbound_fut = if !pending_drain {
Either::Left(outbound_rx.recv_async())
} else {
Either::Right(std::future::pending::<
Result<crate::engine::Outbound, flume::RecvError>,
>())
};
// Heartbeat timer, or never-ready when heartbeats are disabled.
let hb_sleep_fut = match hb {
Some(ref mut h) => Either::Left(h.next_ping_sleep()),
None => Either::Right(std::future::pending::<()>()),
};
// Inline-overflow wake-up, or never-ready when the inline path is off.
let overflow_fut = if inline_enabled {
Either::Left(writer.overflow_notified())
} else {
Either::Right(std::future::pending::<()>())
};
futures::pin_mut!(writable_fut, outbound_fut, hb_sleep_fut, overflow_fut);
futures::select! {
// Writer became writable (or failed); the flush itself happens at the
// top of the next iteration.
ready = writable_fut.fuse() => {
crate::wake_counter::bump(&crate::wake_counter::PEER_LOOP_WRITABLE_WAKES);
ready?;
}
// Inbound: handle the message that woke us, then keep taking messages
// that are already ready (without awaiting) until the budget is spent.
msg = read_half.next().fuse() => {
crate::wake_counter::bump(&crate::wake_counter::PEER_LOOP_READ_WAKES);
// Unset or zero budget means one message per wake-up.
let budget = in_batch_msgs.unwrap_or(1).max(1);
let mut msg = msg;
let mut drained: usize = 0;
loop {
let should_continue = process_one_read_message(
msg,
&mut hb,
#[cfg(feature = "curve")]
&mut curve,
#[cfg(feature = "curve")]
&mut curve_recv_buf,
max_msg_size,
&shared_inbound,
conflate_slot.as_deref(),
peer_key,
&mut writer,
).await?;
// `false` means the connection is finished (error, EOF, or the
// inbound receiver is gone).
if !should_continue {
return Ok(());
}
drained += 1;
if drained >= budget {
break;
}
// Poll once; stop batching as soon as nothing is immediately ready.
match read_half.next().now_or_never() {
Some(next_msg) => msg = next_msg,
None => break,
}
}
}
// Outbound: only armed while the writer queue is empty (see above).
item = outbound_fut.fuse() => {
crate::wake_counter::bump(&crate::wake_counter::PEER_LOOP_OUTBOUND_WAKES);
match item {
Ok(o) => {
// NOTE(review): the `?` exits below return without calling
// clear_peer_loop_busy(); presumably harmless because the writer
// is dropped when the loop ends — confirm.
if inline_enabled {
writer.mark_peer_loop_busy();
}
let msg = o.msg;
#[cfg(feature = "curve")]
{
if let Some(sess) = curve.as_mut() {
// Active CURVE session: every frame of a Message::Message is
// encrypted and queued as a raw frame; any other message kind
// is queued unchanged.
if let Message::Message(zm) = msg {
let n = zm.len();
for (i, frame) in zm.iter().enumerate() {
// MORE is set on every frame except the last one.
let more = i < n - 1;
let wire = curve_encrypt_frame(sess, frame, more)?;
writer.enqueue(Message::SecurityRaw(wire));
}
} else {
writer.enqueue(msg);
}
} else {
// No session: same handling as the non-curve build below.
use crate::engine::writer::FastPath;
match writer.try_fast_path_single_frame(msg)? {
FastPath::Sent => {
flush_state.flushed.fetch_add(1, Ordering::Release);
flush_state.notify_flush_waiters();
}
FastPath::Enqueued => {}
FastPath::NotTaken(msg) => {
writer.enqueue(msg);
writer.drain_batch(&outbound_rx, out_batch_size, out_batch_msgs);
}
}
}
}
#[cfg(not(feature = "curve"))]
{
use crate::engine::writer::FastPath;
match writer.try_fast_path_single_frame(msg)? {
// Written directly: count it as flushed right away.
FastPath::Sent => {
flush_state.flushed.fetch_add(1, Ordering::Release);
flush_state.notify_flush_waiters();
}
FastPath::Enqueued => {}
// Fast path declined: queue the message and pull more items from
// the channel into the same batch.
FastPath::NotTaken(msg) => {
writer.enqueue(msg);
writer.drain_batch(&outbound_rx, out_batch_size, out_batch_msgs);
}
}
}
if inline_enabled {
writer.clear_peer_loop_busy();
}
}
// The outbound channel can no longer deliver items: end the loop.
Err(_) => return Ok(()),
}
}
// Heartbeat timer fired: may queue a PING or decide to evict the peer.
_ = hb_sleep_fut.fuse() => {
if let Some(ref mut h) = hb {
match h.on_ping_tick(&mut writer) {
HeartbeatAction::Evict => return Ok(()),
HeartbeatAction::Continue => {}
}
}
}
// Nothing to do here: the wake-up lets the next iteration call
// pull_inline_overflow() and flush.
_ = overflow_fut.fuse() => {
}
_ = shutdown => return Ok(()),
}
}
}
use crate::codec::HeartbeatFrame;
/// Decision returned by `HeartbeatState::on_ping_tick`.
enum HeartbeatAction {
/// Keep the connection; the peer loop carries on.
Continue,
/// The PONG deadline has passed; the peer loop returns and the connection
/// is torn down.
Evict,
}
/// Heartbeat bookkeeping for one peer, generic over the clock `C`.
struct HeartbeatState<C: RuntimeClock> {
/// Configuration supplying `interval`, `timeout` and `ttl`.
cfg: HeartbeatConfig,
/// Instant at which the next PING is due.
next_ping_at: C::Instant,
/// Instant by which a PONG must have arrived; `None` while no reply is
/// awaited (initially, and after `on_pong`).
pong_deadline: Option<C::Instant>,
/// `cfg.ttl` expressed in units of 100 ms, as sent in PING frames.
ttl_tenths: u16,
}
impl<C: RuntimeClock> HeartbeatState<C> {
fn new(cfg: HeartbeatConfig) -> Self {
let ttl_tenths = (cfg.ttl.as_millis() / 100) as u16;
let next_ping_at = C::now() + cfg.interval;
Self {
cfg,
next_ping_at,
pong_deadline: None,
ttl_tenths,
}
}
fn next_ping_sleep(
&mut self,
) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>> {
let now = C::now();
let deadline = if let Some(pong_dl) = self.pong_deadline {
pong_dl.min(self.next_ping_at)
} else {
self.next_ping_at
};
C::sleep_until(deadline.max(now))
}
fn on_ping_tick<W: AsyncVectoredWrite>(
&mut self,
writer: &mut VectoredWriter<W>,
) -> HeartbeatAction {
let now = C::now();
if let Some(pong_dl) = self.pong_deadline {
if now >= pong_dl {
return HeartbeatAction::Evict;
}
}
if now >= self.next_ping_at {
let ctx_bytes: [u8; 8] = rand::rng().random();
let context = bytes::Bytes::copy_from_slice(&ctx_bytes);
writer.enqueue(Message::Heartbeat(HeartbeatFrame::Ping {
ttl_tenths: self.ttl_tenths,
context,
}));
self.pong_deadline = Some(now + self.cfg.timeout);
self.next_ping_at = now + self.cfg.interval;
}
HeartbeatAction::Continue
}
fn on_pong(&mut self) {
self.pong_deadline = None;
}
}
/// Encrypts one message frame for a CURVE session and returns the complete
/// wire frame (header, `MESSAGE` command name, nonce counter, ciphertext).
///
/// * `sess` – session whose `tx_nonce` is consumed and advanced.
/// * `payload` – plaintext frame body.
/// * `more` – whether further frames of the same message follow; recorded
///   both in the encrypted flags byte and in the outer frame header.
///
/// # Errors
///
/// Fails when the transmit nonce counter is exhausted (it must never wrap,
/// because that would reuse a nonce under the same key) or when the
/// underlying box encryption fails.
#[cfg(feature = "curve")]
fn curve_encrypt_frame(
    sess: &mut CurveSession,
    payload: &Bytes,
    more: bool,
) -> std::io::Result<Bytes> {
    use bytes::BufMut;
    use crypto_box::aead::Aead;

    // The nonce prefix depends on which side of the connection is sending.
    let prefix: &[u8; 16] = if sess.is_server {
        b"CurveZMQMESSAGES"
    } else {
        b"CurveZMQMESSAGEC"
    };

    // Reserve this frame's counter value; refuse to continue rather than wrap.
    let nonce_ctr = sess.tx_nonce;
    sess.tx_nonce = nonce_ctr
        .checked_add(1)
        .ok_or_else(|| std::io::Error::other("CURVE transmit nonce counter exhausted"))?;
    let nonce = build_nonce(prefix, nonce_ctr);

    // Plaintext is one flags byte followed by the payload.
    let mut plain = bytes::BytesMut::with_capacity(1 + payload.len());
    plain.put_u8(if more { 0x01u8 } else { 0x00u8 });
    plain.extend_from_slice(payload);

    let ciphertext = sess
        .session_box
        .encrypt(&nonce, plain.as_ref())
        .map_err(|_e| std::io::Error::from(CodecError::CurveEncryptFailed))?;

    // Body = 8-byte command name + 8-byte nonce counter + ciphertext.
    let body = 8 + 8 + ciphertext.len();
    let is_long = body > 255;
    // Header is the flags byte plus a 1-byte or 8-byte length.
    let header_len = if is_long { 9 } else { 2 };
    let mut wire = bytes::BytesMut::with_capacity(header_len + body);

    if is_long {
        wire.put_u8(if more { 0x03u8 } else { 0x02u8 });
        wire.put_u64(body as u64);
    } else {
        wire.put_u8(if more { 0x01u8 } else { 0x00u8 });
        // `body <= 255` in this branch, so the cast is lossless.
        wire.put_u8(body as u8);
    }
    wire.put_u8(7u8);
    wire.extend_from_slice(b"MESSAGE");
    wire.put_u64(nonce_ctr);
    wire.extend_from_slice(&ciphertext);
    Ok(wire.freeze())
}
/// Decrypts one CURVE `MESSAGE` frame and returns `(payload, more)`, where
/// `more` is bit 0 of the decrypted flags byte.
///
/// The receive nonce counter in `sess` is advanced only after the frame has
/// been decrypted successfully, so an unauthenticated frame cannot move the
/// replay window.
///
/// # Errors
///
/// Fails when the frame is shorter than 33 bytes, does not start with the
/// `MESSAGE` command name, carries a nonce counter that is not strictly
/// greater than the last accepted one, fails decryption, or decrypts to an
/// empty plaintext.
#[cfg(feature = "curve")]
fn curve_decrypt_message_frame(
    sess: &mut CurveSession,
    frame: Bytes,
) -> std::io::Result<(Bytes, bool)> {
    // 8-byte command name + 8-byte nonce counter + 17 more bytes
    // (presumably a 16-byte authenticator plus the flags byte — confirm).
    const MIN_LEN: usize = 33;
    const COMMAND_NAME: &[u8; 8] = b"\x07MESSAGE";

    if frame.len() < MIN_LEN {
        return Err(std::io::Error::from(CodecError::MessageFrameTooShort));
    }
    if &frame[..8] != COMMAND_NAME {
        return Err(std::io::Error::from(CodecError::Decode(
            "not a MESSAGE frame",
        )));
    }

    let nonce_bytes: [u8; 8] = frame[8..16]
        .try_into()
        .expect("an 8-byte slice always converts to [u8; 8]");
    let nonce_ctr = u64::from_be_bytes(nonce_bytes);
    // Counters must be strictly increasing; anything else is a replay or reorder.
    if nonce_ctr <= sess.rx_nonce {
        return Err(std::io::Error::from(CodecError::CurveNonceOutOfOrder));
    }

    // The peer used the opposite role's prefix when it encrypted the frame.
    let prefix: &[u8; 16] = if sess.is_server {
        b"CurveZMQMESSAGEC"
    } else {
        b"CurveZMQMESSAGES"
    };
    let nonce = build_nonce(prefix, nonce_ctr);

    let plain = sess
        .session_box
        .decrypt(&nonce, &frame[16..])
        .map_err(|_e| std::io::Error::from(CodecError::CurveDecryptFailed))?;

    // The frame is authentic: only now commit the counter.
    sess.rx_nonce = nonce_ctr;

    if plain.is_empty() {
        return Err(std::io::Error::from(CodecError::CurveEmptyPlaintext));
    }
    let more = plain[0] & 0x01 != 0;
    let mut payload = Bytes::from(plain);
    // Drop the leading flags byte; the remainder is the application payload.
    let _flags = payload.split_to(1);
    Ok((payload, more))
}
/// Processes one item taken from the inbound stream and returns whether the
/// peer loop should keep running.
///
/// Steps, in order:
/// 1. A heartbeat PONG clears the pending PONG deadline in `hb`.
/// 2. With the `curve` feature and an active session, `Message::Message`
///    items are decrypted frame by frame. Decrypted frames accumulate in
///    `curve_recv_buf` until a frame without the MORE bit completes the
///    message. A `Message::Message` that is not a CURVE `MESSAGE` frame is
///    rejected with a decode error rather than delivered as plaintext, since
///    accepting it would let unauthenticated data bypass the session.
/// 3. A message whose summed frame lengths exceed `max_msg_size` is replaced
///    by a decode error.
/// 4. The result is handed to `handle_read`, whose return value is returned.
///
/// Returns `Ok(true)` without calling `handle_read` when a CURVE message is
/// still incomplete.
#[allow(clippy::too_many_arguments)]
async fn process_one_read_message<W: AsyncVectoredWrite>(
    msg: Option<Result<Message, CodecError>>,
    hb: &mut Option<HeartbeatState<DefaultClock>>,
    #[cfg(feature = "curve")] curve: &mut Option<CurveSession>,
    #[cfg(feature = "curve")] curve_recv_buf: &mut Vec<bytes::Bytes>,
    max_msg_size: Option<usize>,
    shared_inbound: &TaggedInboundTx,
    conflate_slot: Option<&ConflateSlotInner>,
    peer_key: PeerKey,
    writer: &mut PeerWriterKind<W>,
) -> std::io::Result<bool> {
    let pong_received = matches!(
        msg,
        Some(Ok(Message::Heartbeat(
            crate::codec::HeartbeatFrame::Pong { .. }
        )))
    );
    if pong_received {
        if let Some(h) = hb.as_mut() {
            h.on_pong();
        }
    }

    #[cfg(feature = "curve")]
    let msg = match (curve, msg) {
        (Some(sess), Some(Ok(Message::Message(ref zm))))
            if zm
                .get(0)
                .is_some_and(|f| f.len() >= 8 && &f[..8] == b"\x07MESSAGE") =>
        {
            let mut emit = None;
            'frames: for frame in zm.iter() {
                if let Ok((payload, more)) = curve_decrypt_message_frame(sess, frame.clone()) {
                    curve_recv_buf.push(payload);
                    if !more {
                        // Message complete: move the buffered frames out
                        // (keeping the buffer's allocation for the next one).
                        let mut parts = curve_recv_buf.drain(..);
                        // A payload was pushed just above, so a first part exists.
                        if let Some(first) = parts.next() {
                            let mut out = ZmqMessage::from(first);
                            for part in parts {
                                out.push_back(part);
                            }
                            emit = Some(Ok(Message::Message(out)));
                        }
                        // Frames after the terminating one are not processed.
                        break 'frames;
                    }
                } else {
                    // Any failure discards the partial message and ends the
                    // connection via the error path in `handle_read`.
                    curve_recv_buf.clear();
                    emit = Some(Err(CodecError::Decode("CURVE decrypt error")));
                    break 'frames;
                }
            }
            match emit {
                Some(result) => Some(result),
                // Every frame had MORE set: wait for the rest of the message.
                None => return Ok(true),
            }
        }
        // An encrypted session must never deliver unauthenticated payloads.
        (Some(_), Some(Ok(Message::Message(_)))) => {
            curve_recv_buf.clear();
            Some(Err(CodecError::Decode(
                "plaintext message on CURVE session",
            )))
        }
        (_, other) => other,
    };

    // Enforce the size limit on the (possibly decrypted) message.
    let oversized = match (max_msg_size, &msg) {
        (Some(max), Some(Ok(Message::Message(zm)))) => {
            zm.iter().map(|f| f.len()).sum::<usize>() > max
        }
        _ => false,
    };
    let msg = if oversized {
        Some(Err(CodecError::Decode("message exceeds MAXMSGSIZE")))
    } else {
        msg
    };

    handle_read(msg, shared_inbound, conflate_slot, peer_key, writer).await
}
/// Delivers one decoded inbound item and reports whether the peer loop should
/// continue (`true`) or stop (`false`).
///
/// * Stream error or end of stream: the loop stops. Without a conflate slot
///   the error (or `CodecError::PeerDisconnected` for end of stream) is first
///   forwarded on `shared_inbound`; a failed send is ignored.
/// * Heartbeat frames are consumed here; a PING queues a PONG echoing the
///   same context on `writer`.
/// * With a conflate slot, a `Message::Message` replaces the slot contents
///   and the slot is notified; other message kinds are dropped.
/// * Otherwise the frame is sent on `shared_inbound`; the loop stops if that
///   send fails.
async fn handle_read<W: AsyncVectoredWrite>(
    msg: Option<Result<Message, CodecError>>,
    shared_inbound: &TaggedInboundTx,
    conflate_slot: Option<&ConflateSlotInner>,
    peer_key: PeerKey,
    writer: &mut PeerWriterKind<W>,
) -> std::io::Result<bool> {
    use crate::codec::HeartbeatFrame;

    // Anything other than a successfully decoded frame ends the connection.
    let frame = match msg {
        Some(Ok(frame)) => frame,
        terminal => {
            if conflate_slot.is_none() {
                let reason = match terminal {
                    Some(Err(e)) => e,
                    _ => CodecError::PeerDisconnected,
                };
                let _ = shared_inbound.send_async((peer_key, Err(reason))).await;
            }
            return Ok(false);
        }
    };

    // Heartbeats never reach the application.
    if let Message::Heartbeat(beat) = &frame {
        if let HeartbeatFrame::Ping { context, .. } = beat {
            writer.enqueue(Message::Heartbeat(HeartbeatFrame::Pong {
                context: context.clone(),
            }));
        }
        return Ok(true);
    }

    match conflate_slot {
        Some(slot) => {
            if let Message::Message(zm) = frame {
                // Keep only the newest message; the guard is released at the
                // end of this statement, before notifying.
                *slot.slot.lock() = Some((peer_key, zm));
                slot.notify.notify_one();
            }
            Ok(true)
        }
        None => {
            let delivered = shared_inbound
                .send_async((peer_key, Ok(frame)))
                .await
                .is_ok();
            Ok(delivered)
        }
    }
}