#![allow(dead_code)]
#[cfg(feature = "inproc")]
use crate::engine::InprocEngine;
use crate::engine::PeerEngine;
use crate::PeerIdentity;
use parking_lot::RwLock;
use slab::Slab;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
/// Handle to a single peer's engine, independent of the transport behind it.
///
/// `Framed` holds an engine built over a framed I/O transport (see
/// `make_framed_engine`). With the `inproc` feature, `Inproc` holds an
/// in-process engine. Cloning only bumps the inner `Arc`'s reference count.
///
/// Without `inproc` the enum has a single variant, and `repr(transparent)` gives
/// it the same layout as the wrapped `Arc`.
#[cfg_attr(not(feature = "inproc"), repr(transparent))]
#[derive(Clone)]
pub(crate) enum AnyEngine {
    /// Engine driving a framed connection.
    Framed(Arc<PeerEngine>),
    /// In-process engine; only compiled in with the `inproc` feature.
    #[cfg(feature = "inproc")]
    Inproc(Arc<InprocEngine>),
}
/// Wraps a shared framed [`PeerEngine`] as [`AnyEngine::Framed`].
#[inline]
pub(crate) fn make_framed_engine(engine: Arc<PeerEngine>) -> AnyEngine {
    AnyEngine::Framed(engine)
}
/// Result of a non-blocking send attempt to one peer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum TrySendOutcome {
    /// The engine accepted the message.
    Sent,
    /// The engine's outbound queue had no room; the message was not queued.
    Full,
    /// The engine is closed or disconnected.
    Closed,
}
impl AnyEngine {
#[inline]
pub(crate) fn writer_alive(&self) -> bool {
match self {
AnyEngine::Framed(e) => e.writer_alive(),
#[cfg(feature = "inproc")]
AnyEngine::Inproc(e) => e.writer_alive(),
}
}
#[inline]
pub(crate) async fn send_msg(
&self,
m: crate::message::ZmqMessage,
) -> Result<(), crate::error::SendError> {
use crate::codec::Message;
match self {
AnyEngine::Framed(e) => e.send(Message::Message(m)).await,
#[cfg(feature = "inproc")]
AnyEngine::Inproc(e) => e.send_direct(m),
}
}
#[inline]
pub(crate) async fn send_msg_flushed(
&self,
m: crate::message::ZmqMessage,
) -> Result<(), crate::error::SendError> {
use crate::codec::Message;
match self {
AnyEngine::Framed(e) => e.send_flushed(Message::Message(m)).await,
#[cfg(feature = "inproc")]
AnyEngine::Inproc(e) => e.send_direct(m),
}
}
#[inline]
pub(crate) fn try_send_fanout(
&self,
shared: std::sync::Arc<crate::message::ZmqMessage>,
) -> TrySendOutcome {
use crate::codec::Message;
match self {
AnyEngine::Framed(e) => {
if let Some(result) = e.try_inline_fanout(&shared) {
return match result {
Ok(()) => TrySendOutcome::Sent,
Err(_) => TrySendOutcome::Closed,
};
}
match e.try_send_fire_and_forget(Message::Shared(shared)) {
Ok(()) => TrySendOutcome::Sent,
Err(flume::TrySendError::Full(_)) => TrySendOutcome::Full,
Err(flume::TrySendError::Disconnected(_)) => TrySendOutcome::Closed,
}
}
#[cfg(feature = "inproc")]
AnyEngine::Inproc(e) => match e.try_send_direct((*shared).clone()) {
Ok(()) => TrySendOutcome::Sent,
Err(true) => TrySendOutcome::Full,
Err(false) => TrySendOutcome::Closed,
},
}
}
#[inline]
pub(crate) fn try_send_oneshot(&self, m: crate::message::ZmqMessage) -> bool {
use crate::codec::Message;
match self {
AnyEngine::Framed(e) => matches!(
e.try_send_tracked(Message::Message(m)),
Ok(_) | Err(flume::TrySendError::Full(_))
),
#[cfg(feature = "inproc")]
AnyEngine::Inproc(e) => matches!(e.try_send_direct(m), Ok(()) | Err(true)),
}
}
#[inline]
pub(crate) async fn drain_outbound(&self, timeout: Option<std::time::Duration>) {
match self {
AnyEngine::Framed(e) => e.drain_outbound(timeout).await,
#[cfg(feature = "inproc")]
AnyEngine::Inproc(_) => {}
}
}
}
/// Slab index identifying a registered peer.
///
/// Keys of removed peers may be handed out again to later peers.
pub(crate) type PeerKey = u32;
/// One registry slot: a peer's identity and its engine handle.
struct Entry {
    /// Identity the peer is currently registered under.
    id: PeerIdentity,
    /// Engine carrying this peer's traffic.
    engine: AnyEngine,
}
/// Lock-protected state of a [`PeerRegistry`].
///
/// The three collections describe the same set of live peers and are kept the
/// same size. Debug builds check this in `assert_invariants`.
struct Inner {
    /// Live peers, indexed by [`PeerKey`].
    peers: Slab<Entry>,
    /// Reverse lookup from identity to slab key.
    id_to_key: HashMap<PeerIdentity, PeerKey>,
    /// Keys of live peers. The order is not stable because removal uses
    /// `swap_remove`.
    active: Vec<PeerKey>,
}
impl Inner {
fn new() -> Self {
Self {
peers: Slab::new(),
id_to_key: HashMap::new(),
active: Vec::new(),
}
}
#[cfg(debug_assertions)]
fn assert_invariants(&self) {
debug_assert_eq!(self.peers.len(), self.id_to_key.len());
debug_assert_eq!(self.peers.len(), self.active.len());
}
#[cfg(not(debug_assertions))]
fn assert_invariants(&self) {}
fn swap_remove_active(&mut self, key: PeerKey) {
if let Some(pos) = self.active.iter().position(|&k| k == key) {
self.active.swap_remove(pos);
}
}
}
/// Table of connected peers, addressable by [`PeerKey`] or by identity.
///
/// All peer state sits behind a single `RwLock`. The atomic `cursor` only feeds
/// round-robin selection, so picking a peer needs just the read lock.
pub(crate) struct PeerRegistry {
    inner: RwLock<Inner>,
    cursor: AtomicUsize,
}
impl PeerRegistry {
    /// Creates an empty registry whose round-robin cursor starts at zero.
    pub(crate) fn new() -> Self {
        Self {
            inner: RwLock::new(Inner::new()),
            cursor: AtomicUsize::new(0),
        }
    }

    /// Registers `peer_id`, building its engine with the key it will occupy.
    ///
    /// `build` runs while the registry's write lock is held. It must not call
    /// back into this registry, because the lock is not reentrant.
    ///
    /// If `peer_id` is already registered, its existing key is reused, the
    /// newly built engine replaces the old one, and the old engine is returned.
    /// Otherwise a fresh slab key is allocated (possibly a recycled one) and
    /// `None` is returned.
    ///
    /// # Panics
    ///
    /// Panics if the slab key does not fit in a `u32` [`PeerKey`].
    pub(crate) fn insert_with<F>(
        &self,
        peer_id: PeerIdentity,
        build: F,
    ) -> (PeerKey, Option<AnyEngine>)
    where
        F: FnOnce(PeerKey) -> AnyEngine,
    {
        let mut inner = self.inner.write();
        if let Some(&existing_key) = inner.id_to_key.get(&peer_id) {
            let new_engine = build(existing_key);
            let prev = inner
                .peers
                .get_mut(existing_key as usize)
                .map(|e| std::mem::replace(&mut e.engine, new_engine));
            inner.assert_invariants();
            return (existing_key, prev);
        }
        // Reserve the slot first so `build` can be told the key it will live at.
        let entry = inner.peers.vacant_entry();
        let key: PeerKey = entry
            .key()
            .try_into()
            .expect("peer-registry slab key exceeds u32::MAX");
        let engine = build(key);
        entry.insert(Entry {
            id: peer_id.clone(),
            engine,
        });
        inner.id_to_key.insert(peer_id, key);
        inner.active.push(key);
        inner.assert_invariants();
        (key, None)
    }

    /// Unregisters the peer at `key` and returns its engine.
    ///
    /// Returns `None` if `key` is not live. The key becomes free for reuse.
    pub(crate) fn remove_by_key(&self, key: PeerKey) -> Option<AnyEngine> {
        let mut inner = self.inner.write();
        let entry = inner.peers.try_remove(key as usize)?;
        inner.id_to_key.remove(&entry.id);
        inner.swap_remove_active(key);
        inner.assert_invariants();
        Some(entry.engine)
    }

    /// Unregisters the peer registered as `peer_id` and returns its key and
    /// engine, or `None` if the identity is unknown.
    ///
    /// # Panics
    ///
    /// Panics if the identity maps to a key with no slab entry. That would mean
    /// the registry's internal maps have diverged.
    pub(crate) fn remove_by_id(&self, peer_id: &PeerIdentity) -> Option<(PeerKey, AnyEngine)> {
        let mut inner = self.inner.write();
        let key = inner.id_to_key.remove(peer_id)?;
        let entry = inner
            .peers
            .try_remove(key as usize)
            .expect("id_to_key points at a live slab entry");
        inner.swap_remove_active(key);
        inner.assert_invariants();
        Some((key, entry.engine))
    }

    /// Returns a clone of the engine at `key`, if that key is live.
    pub(crate) fn get_by_key(&self, key: PeerKey) -> Option<AnyEngine> {
        self.inner
            .read()
            .peers
            .get(key as usize)
            .map(|e| e.engine.clone())
    }

    /// Looks up a peer by identity and returns its key and a clone of its engine.
    pub(crate) fn get_by_id(&self, peer_id: &PeerIdentity) -> Option<(PeerKey, AnyEngine)> {
        let inner = self.inner.read();
        let key = *inner.id_to_key.get(peer_id)?;
        inner
            .peers
            .get(key as usize)
            .map(|e| (key, e.engine.clone()))
    }

    /// Returns a clone of the identity registered at `key`, if live.
    pub(crate) fn id_for(&self, key: PeerKey) -> Option<PeerIdentity> {
        self.inner
            .read()
            .peers
            .get(key as usize)
            .map(|e| e.id.clone())
    }

    /// Returns the key registered for `peer_id`, if any.
    pub(crate) fn key_for(&self, peer_id: &PeerIdentity) -> Option<PeerKey> {
        self.inner.read().id_to_key.get(peer_id).copied()
    }

    /// Returns `true` when no peers are registered.
    pub(crate) fn is_empty(&self) -> bool {
        self.inner.read().active.is_empty()
    }

    /// Returns some live key: the first entry of the `active` list.
    ///
    /// That order changes as peers are removed, so this is not necessarily the
    /// oldest peer.
    pub(crate) fn any_key(&self) -> Option<PeerKey> {
        self.inner.read().active.first().copied()
    }

    /// Number of registered peers.
    pub(crate) fn len(&self) -> usize {
        self.inner.read().active.len()
    }

    /// Picks the next peer in rotation, or `None` if the registry is empty.
    ///
    /// The shared cursor is advanced atomically and wraps on overflow.
    /// Concurrent callers each get a distinct cursor value. The rotation is
    /// only approximate across insertions and removals, because `active`
    /// reorders.
    pub(crate) fn next_round_robin(&self) -> Option<(PeerKey, AnyEngine)> {
        let inner = self.inner.read();
        let n = inner.active.len();
        if n == 0 {
            return None;
        }
        // Relaxed suffices: the cursor only spreads picks and orders no memory.
        let idx = self.cursor.fetch_add(1, Ordering::Relaxed);
        let key = inner.active[idx % n];
        inner
            .peers
            .get(key as usize)
            .map(|e| (key, e.engine.clone()))
    }

    /// Replaces the contents of `buf` with `(key, engine)` for every live peer.
    ///
    /// The existing allocation is reused. The read lock is released on return,
    /// so callers can work with the snapshot without holding it.
    pub(crate) fn snapshot_into(&self, buf: &mut Vec<(PeerKey, AnyEngine)>) {
        buf.clear();
        let inner = self.inner.read();
        buf.reserve(inner.active.len());
        for &key in &inner.active {
            if let Some(e) = inner.peers.get(key as usize) {
                buf.push((key, e.engine.clone()));
            }
        }
    }

    /// Swaps the engine at `key` for `engine` and returns the previous one.
    ///
    /// If `key` is not live, returns `None` and drops `engine`.
    #[cfg(feature = "inproc")]
    pub(crate) fn replace_engine(&self, key: PeerKey, engine: AnyEngine) -> Option<AnyEngine> {
        let mut inner = self.inner.write();
        inner
            .peers
            .get_mut(key as usize)
            .map(|e| std::mem::replace(&mut e.engine, engine))
    }

    /// Re-registers the peer at `key` under `new_id`.
    ///
    /// Returns `true` if the peer now carries `new_id`, including when it
    /// already did. Returns `false`, changing nothing, if `key` is not live or
    /// `new_id` already belongs to a different peer.
    #[cfg(feature = "inproc")]
    pub(crate) fn rename_peer_id(&self, key: PeerKey, new_id: PeerIdentity) -> bool {
        let mut guard = self.inner.write();
        // Reborrow as `&mut Inner` so `peers` and `id_to_key` can be borrowed
        // independently below.
        let inner = &mut *guard;
        let Some(entry) = inner.peers.get_mut(key as usize) else {
            return false;
        };
        if entry.id == new_id {
            return true;
        }
        if matches!(inner.id_to_key.get(&new_id), Some(&other_key) if other_key != key) {
            return false;
        }
        let old_id = std::mem::replace(&mut entry.id, new_id.clone());
        inner.id_to_key.remove(&old_id);
        inner.id_to_key.insert(new_id, key);
        inner.assert_invariants();
        true
    }

    /// Removes every peer, dropping the registry's engine handles.
    ///
    /// The round-robin cursor is not reset.
    pub(crate) fn clear(&self) {
        let mut inner = self.inner.write();
        inner.peers.clear();
        inner.id_to_key.clear();
        inner.active.clear();
        inner.assert_invariants();
    }
}
impl Default for PeerRegistry {
fn default() -> Self {
Self::new()
}
}
/// Best-effort shutdown flush for every peer currently in `registry`.
///
/// First, if `opts.disconnect_msg` is set, pushes it to each peer without
/// blocking. Peers whose queues are full or closed simply miss it.
///
/// Then, unless `opts.linger` is `Some(Duration::ZERO)`, waits for the peers'
/// outbound queues to drain.
///
/// `linger` bounds the *whole* drain, not each peer. Peers are drained one
/// after another against a single budget measured from the start of the drain
/// phase, so closing a socket with many stalled peers cannot block for
/// `peers × linger`. Once the budget is spent, the remaining peers are not
/// waited on, matching the zero-linger behaviour. `None` waits on every peer
/// without a timeout.
///
/// Because draining is sequential, a peer that stalls early consumes budget
/// that later peers could have used.
pub(crate) async fn drain_registry(registry: &PeerRegistry, opts: &crate::SocketOptions) {
    let linger = opts.linger;
    let mut peers = Vec::new();
    registry.snapshot_into(&mut peers);
    if let Some(disc) = &opts.disconnect_msg {
        for (_, engine) in &peers {
            // Best effort: a full or closed queue just means this peer misses it.
            let _ = engine.try_send_oneshot(disc.clone());
        }
    }
    if linger == Some(std::time::Duration::ZERO) {
        return;
    }
    let started = std::time::Instant::now();
    for (_, engine) in &peers {
        let budget = match linger {
            None => None,
            Some(total) => {
                // Saturating: never panics, bottoms out at zero.
                let remaining = total.saturating_sub(started.elapsed());
                if remaining.is_zero() {
                    break;
                }
                Some(remaining)
            }
        };
        engine.drain_outbound(budget).await;
    }
}
// Registry tests backed by real engines over loopback TCP connections that have
// completed the greeting exchange. They are gated on the `tokio` feature
// because they use tokio's `TcpListener`/`TcpStream`.
#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::async_rt;
    use crate::codec::{DefaultFramedIo as FramedIo, Message};
    use crate::message::ZmqMessage;
    use bytes::Bytes;
    use tokio::net::{TcpListener, TcpStream};

    /// Opens a loopback TCP connection, wraps both ends in `FramedIo`, and runs
    /// `greet_exchange` on both ends concurrently.
    ///
    /// Returns `(accepted end, connecting end)`. Callers keep the unused end
    /// bound so the connection stays open for the test's duration.
    async fn connected_pair() -> (FramedIo, FramedIo) {
        // Port 0: let the OS choose a free port.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let connect_fut = TcpStream::connect(addr);
        // Drive accept and connect concurrently.
        let (accept_res, connect_res) = futures::join!(listener.accept(), connect_fut);
        let (server, _) = accept_res.unwrap();
        let client = connect_res.unwrap();
        let mut io_a = FramedIo::from_tcp(server);
        let mut io_b = FramedIo::from_tcp(client);
        // Greet on both ends at once (presumably each side waits on the other's
        // greeting, so running them sequentially could stall).
        let (greet_a, greet_b) = futures::join!(
            crate::codec::handshake::greet_exchange(&mut io_a),
            crate::codec::handshake::greet_exchange(&mut io_b),
        );
        greet_a.unwrap();
        greet_b.unwrap();
        (io_a, io_b)
    }

    /// Spawns a framed `PeerEngine` over `io` for registry slot `key` and wraps
    /// it as an `AnyEngine`.
    ///
    /// `inbound` is handed to the engine, presumably as the sink for received
    /// messages.
    fn spawn_engine(
        key: PeerKey,
        peer_id: PeerIdentity,
        io: FramedIo,
        inbound: crate::engine::TaggedInboundTx,
    ) -> AnyEngine {
        // With `curve`, `into_parts` yields an extra third element, unused here.
        #[cfg(feature = "curve")]
        let (read, write, _) = io.into_parts();
        #[cfg(not(feature = "curve"))]
        let (read, write) = io.into_parts();
        make_framed_engine(Arc::new(PeerEngine::spawn(
            key,
            peer_id,
            read,
            write.into_engine_writer(),
            // NOTE(review): presumably an outbound queue capacity; confirm
            // against `PeerEngine::spawn`.
            64,
            inbound,
            crate::engine::peer_loop::PeerConfig::default(),
        )))
    }

    // With two peers, any four consecutive cursor values hit each peer exactly
    // twice, whatever the cursor's starting value.
    #[async_rt::test]
    async fn round_robin_alternates() {
        let registry = PeerRegistry::new();
        let (io_a, _io_a_peer) = connected_pair().await;
        let (io_b, _io_b_peer) = connected_pair().await;
        let id_a = PeerIdentity::new();
        let id_b = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let (key_a, _) = registry.insert_with(id_a.clone(), |k| {
            spawn_engine(k, id_a.clone(), io_a, dummy_tx.clone())
        });
        let (key_b, _) = registry.insert_with(id_b.clone(), |k| {
            spawn_engine(k, id_b.clone(), io_b, dummy_tx.clone())
        });
        let mut picks = Vec::new();
        for _ in 0..4 {
            picks.push(registry.next_round_robin().unwrap().0);
        }
        let a_count = picks.iter().filter(|&&k| k == key_a).count();
        let b_count = picks.iter().filter(|&&k| k == key_b).count();
        assert_eq!(a_count, 2);
        assert_eq!(b_count, 2);
    }

    // A snapshot contains every peer, and a single shared message sent through
    // each snapshot engine arrives at both remote ends.
    #[async_rt::test]
    async fn snapshot_fanout() {
        let registry = PeerRegistry::new();
        let (io_a, mut far_a) = connected_pair().await;
        let (io_b, mut far_b) = connected_pair().await;
        let id_a = PeerIdentity::new();
        let id_b = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(64);
        registry.insert_with(id_a.clone(), |k| {
            spawn_engine(k, id_a.clone(), io_a, dummy_tx.clone())
        });
        registry.insert_with(id_b.clone(), |k| {
            spawn_engine(k, id_b.clone(), io_b, dummy_tx.clone())
        });
        let mut buf = Vec::new();
        registry.snapshot_into(&mut buf);
        assert_eq!(buf.len(), 2);
        let msg = ZmqMessage::from(Bytes::from_static(b"fanout"));
        let shared = Arc::new(msg);
        for (_key, engine) in &buf {
            match engine {
                AnyEngine::Framed(e) => e.try_send(Message::Shared(shared.clone())).unwrap(),
                #[cfg(feature = "inproc")]
                AnyEngine::Inproc(_) => {}
            }
        }
        // Release the snapshot's engine clones; the registry still holds its own.
        drop(buf);
        for far in [&mut far_a, &mut far_b] {
            use futures::StreamExt;
            // A shared message is expected to arrive as a plain `Message` frame.
            let got = far.read_half.next().await.expect("closed").unwrap();
            match got {
                Message::Message(m) => {
                    assert_eq!(m.get(0).expect("frame").as_ref(), b"fanout");
                }
                other => panic!("unexpected variant: {:?}", other),
            }
        }
        registry.clear();
    }

    // The closure-provided key is what `insert_with` returns, and both lookup
    // directions agree on it.
    #[async_rt::test]
    async fn insert_returns_stable_key_and_id_map() {
        let registry = PeerRegistry::new();
        let (io, _far) = connected_pair().await;
        let id = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let (key, prev) = registry.insert_with(id.clone(), |k| {
            spawn_engine(k, id.clone(), io, dummy_tx.clone())
        });
        assert!(prev.is_none());
        assert_eq!(registry.key_for(&id), Some(key));
        assert_eq!(registry.id_for(key), Some(id));
    }

    // Removing by key clears both the id->key map and the slab.
    #[async_rt::test]
    async fn remove_by_key_clears_both_sides() {
        let registry = PeerRegistry::new();
        let (io, _far) = connected_pair().await;
        let id = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let (key, _) = registry.insert_with(id.clone(), |k| {
            spawn_engine(k, id.clone(), io, dummy_tx.clone())
        });
        assert!(registry.remove_by_key(key).is_some());
        assert_eq!(registry.key_for(&id), None);
        assert_eq!(registry.id_for(key), None);
        assert!(registry.is_empty());
    }

    // Removing by identity returns the original key and clears both sides.
    #[async_rt::test]
    async fn remove_by_id_clears_both_sides() {
        let registry = PeerRegistry::new();
        let (io, _far) = connected_pair().await;
        let id = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let (key, _) = registry.insert_with(id.clone(), |k| {
            spawn_engine(k, id.clone(), io, dummy_tx.clone())
        });
        let (removed_key, _engine) = registry.remove_by_id(&id).expect("was inserted");
        assert_eq!(removed_key, key);
        assert_eq!(registry.key_for(&id), None);
        assert_eq!(registry.id_for(key), None);
    }

    // Re-inserting the same identity keeps the key, hands back the old engine,
    // and does not grow the registry.
    #[async_rt::test]
    async fn duplicate_identity_reuses_key_replaces_engine() {
        let registry = PeerRegistry::new();
        let (io1, _far1) = connected_pair().await;
        let (io2, _far2) = connected_pair().await;
        let id = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let (key1, prev1) = registry.insert_with(id.clone(), |k| {
            spawn_engine(k, id.clone(), io1, dummy_tx.clone())
        });
        assert!(prev1.is_none());
        let (key2, prev2) = registry.insert_with(id.clone(), |k| {
            spawn_engine(k, id.clone(), io2, dummy_tx.clone())
        });
        assert_eq!(key1, key2);
        assert!(prev2.is_some(), "duplicate insert should return old engine");
        assert_eq!(registry.key_for(&id), Some(key1));
        assert_eq!(registry.len(), 1);
    }

    // After the only peer is removed, the next insert reuses its slab key.
    #[async_rt::test]
    async fn slab_key_recycled_after_disconnect() {
        let registry = PeerRegistry::new();
        let (io_a, _far_a) = connected_pair().await;
        let (io_b, _far_b) = connected_pair().await;
        let id_a = PeerIdentity::new();
        let id_b = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let (key_a, _) = registry.insert_with(id_a.clone(), |k| {
            spawn_engine(k, id_a.clone(), io_a, dummy_tx.clone())
        });
        registry.remove_by_key(key_a);
        let (key_b, _) = registry.insert_with(id_b.clone(), |k| {
            spawn_engine(k, id_b.clone(), io_b, dummy_tx.clone())
        });
        assert_eq!(key_a, key_b);
        assert_eq!(registry.id_for(key_b), Some(id_b));
    }

    // The key passed to the build closure equals the key `insert_with` returns.
    #[async_rt::test]
    async fn insert_with_closure_sees_returned_key() {
        let registry = PeerRegistry::new();
        let (io, _far) = connected_pair().await;
        let id = PeerIdentity::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let captured: std::sync::Mutex<Option<PeerKey>> = std::sync::Mutex::new(None);
        let (key, _) = registry.insert_with(id.clone(), |k| {
            *captured.lock().unwrap() = Some(k);
            spawn_engine(k, id.clone(), io, dummy_tx.clone())
        });
        assert_eq!(*captured.lock().unwrap(), Some(key));
    }

    // Insert six peers, remove those at even positions (0, 2, 4), and check that
    // the odd ones (1, 3, 5) are still mapped both ways. Then clear everything.
    #[async_rt::test]
    async fn registry_invariant_after_mixed_mutations() {
        let registry = PeerRegistry::new();
        let (dummy_tx, _dummy_rx) = flume::bounded(8);
        let mut ids_keys = Vec::new();
        for _ in 0..6 {
            // NOTE(review): `_far` is dropped at the end of each iteration, so
            // every peer's remote end closes right after insertion. The
            // assertions below only inspect registry bookkeeping. Keep the far
            // ends alive (e.g. push them into a Vec) if engine behaviour is ever
            // asserted here.
            let (io, _far) = connected_pair().await;
            let id = PeerIdentity::new();
            let (key, _) = registry.insert_with(id.clone(), |k| {
                spawn_engine(k, id.clone(), io, dummy_tx.clone())
            });
            ids_keys.push((id, key));
        }
        for &(_, key) in ids_keys.iter().step_by(2) {
            registry.remove_by_key(key);
        }
        assert_eq!(registry.len(), 3);
        for (id, key) in ids_keys.iter().skip(1).step_by(2) {
            assert_eq!(registry.key_for(id), Some(*key));
            assert_eq!(registry.id_for(*key), Some(id.clone()));
        }
        registry.clear();
        assert!(registry.is_empty());
    }
}