use crate::codec::{CodecError, FramedIo, IntoEngineWriter, Message};
use crate::endpoint::Endpoint;
use crate::engine::peer_loop::{ConflateSlot, ConflateSlotInner};
#[cfg(feature = "inproc")]
use crate::engine::registry::AnyEngine;
use crate::engine::registry::{make_framed_engine, PeerKey, PeerRegistry};
use crate::engine::PeerEngine;
#[cfg(feature = "inproc")]
use crate::engine::{connect_inproc_engine, inproc_placeholder_engine, InprocInboundTx};
use crate::error::SendError;
use crate::PeerIdentity;
use crate::{
MultiPeerBackend, SocketBackend, SocketEvent, SocketOptions, SocketType, ZmqError, ZmqMessage,
ZmqResult,
};
use crate::async_rt::notify::AsyncNotify;
use flume::{Receiver, Sender};
use futures::channel::mpsc;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::sync::Arc;
/// Channel used to tell an interested party that a specific peer went away.
///
/// `peer_disconnected` removes the notifier registered for a peer and makes a
/// single best-effort `try_send` of that peer's identity on it.
pub(crate) type DisconnectNotifier = mpsc::Sender<PeerIdentity>;
/// Shared multi-peer backend used by the socket types in this crate.
///
/// Owns the peer registry, the inbound queues fed by peer engines, optional
/// per-peer conflate slots, and the monitor / disconnect-notification hooks.
#[doc(hidden)]
pub struct GenericSocketBackend {
    // Live peers; looked up both by `PeerKey` and by `PeerIdentity`.
    registry: PeerRegistry,
    // Cloned into every framed engine spawned in `peer_connected`; engines push
    // `(peer key, frame or codec error)` items here.
    inbound_tx: Sender<(
        crate::engine::registry::PeerKey,
        Result<Message, CodecError>,
    )>,
    // Consumer side of `inbound_tx`, bounded by `receive_hwm`.
    inbound_rx: Receiver<(
        crate::engine::registry::PeerKey,
        Result<Message, CodecError>,
    )>,
    // Handed to inproc engines via `connect_inproc_engine`.
    #[cfg(feature = "inproc")]
    inproc_inbound_tx: InprocInboundTx,
    // Drained without blocking by `recv_next` before it waits on anything.
    #[cfg(feature = "inproc")]
    pub(crate) inproc_inbound_rx: crate::engine::InprocInboundRx,
    // Passed to inproc engines and awaited by `recv_next` alongside the engine
    // inbound queue.
    #[cfg(feature = "inproc")]
    pub(crate) inproc_notify: Arc<crate::async_rt::notify::RuntimeNotify>,
    socket_type: SocketType,
    socket_options: SocketOptions,
    // Optional sink for `SocketEvent`s; every send to it is a best-effort `try_send`.
    pub(crate) socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
    // One-shot notifiers, fired and removed by `peer_disconnected`.
    disconnect_notifiers: Mutex<HashMap<PeerIdentity, DisconnectNotifier>>,
    // Per-peer latest-message slots; only populated when `socket_options.conflate` is set.
    conflate_slots: Mutex<HashMap<PeerKey, ConflateSlot>>,
    // Shared by every conflate slot; awaited by `recv_next_conflate_timed`.
    pub(crate) conflate_notify: Arc<crate::async_rt::notify::RuntimeNotify>,
}
impl GenericSocketBackend {
    /// Creates a backend for `socket_type` with no peers attached.
    ///
    /// The engine inbound queue (and, with the `inproc` feature, the inproc
    /// inbound queue) is bounded by `options.receive_hwm`.
    pub(crate) fn with_options(socket_type: SocketType, options: SocketOptions) -> Self {
        let (inbound_tx, inbound_rx) = flume::bounded(options.receive_hwm);
        #[cfg(feature = "inproc")]
        let (inproc_inbound_tx, inproc_inbound_rx) =
            crossbeam_channel::bounded(options.receive_hwm);
        #[cfg(feature = "inproc")]
        let inproc_notify = Arc::new(crate::async_rt::notify::RuntimeNotify::new());
        Self {
            registry: PeerRegistry::new(),
            inbound_tx,
            inbound_rx,
            #[cfg(feature = "inproc")]
            inproc_inbound_tx,
            #[cfg(feature = "inproc")]
            inproc_inbound_rx,
            #[cfg(feature = "inproc")]
            inproc_notify,
            socket_type,
            socket_options: options,
            socket_monitor: Mutex::new(None),
            disconnect_notifiers: Mutex::new(HashMap::new()),
            conflate_slots: Mutex::new(HashMap::new()),
            conflate_notify: Arc::new(crate::async_rt::notify::RuntimeNotify::new()),
        }
    }

    /// Builds the per-peer heartbeat configuration from the socket options.
    ///
    /// Returns `None` unless interval, timeout and TTL are all set.
    fn build_heartbeat_cfg(opts: &SocketOptions) -> Option<crate::engine::HeartbeatConfig> {
        Some(crate::engine::HeartbeatConfig {
            interval: opts.heartbeat_interval?,
            timeout: opts.heartbeat_timeout?,
            ttl: opts.heartbeat_ttl?,
        })
    }

    /// Registers `notifier` to be signalled once when `peer_id` disconnects.
    ///
    /// A notifier previously registered for the same identity is replaced.
    pub(crate) fn register_disconnect_notifier(
        &self,
        peer_id: PeerIdentity,
        notifier: DisconnectNotifier,
    ) {
        self.disconnect_notifiers.lock().insert(peer_id, notifier);
    }

    /// Returns a new handle to the engine inbound queue.
    pub(crate) fn inbound(
        &self,
    ) -> Receiver<(
        crate::engine::registry::PeerKey,
        Result<Message, CodecError>,
    )> {
        self.inbound_rx.clone()
    }

    /// Borrows the peer registry.
    pub(crate) fn registry(&self) -> &PeerRegistry {
        &self.registry
    }

    /// Receives the next application message, optionally bounded by
    /// `receive_timeout`.
    ///
    /// # Errors
    /// `ZmqError::NoMessage` when the timeout elapses; otherwise whatever
    /// `recv_next` reports.
    #[cfg(feature = "inproc")]
    pub(crate) async fn recv_next_timed<B: MultiPeerBackend + HasRegistry + HasInproc + ?Sized>(
        inbound: &Receiver<TaggedMsg>,
        backend: &B,
        receive_timeout: Option<std::time::Duration>,
    ) -> ZmqResult<(crate::engine::registry::PeerKey, crate::message::ZmqMessage)> {
        match receive_timeout {
            None => recv_next(inbound, backend).await,
            Some(d) => crate::async_rt::task::timeout(d, recv_next(inbound, backend))
                .await
                .map_err(|_e| ZmqError::NoMessage)?,
        }
    }

    /// Receives the next application message, optionally bounded by
    /// `receive_timeout`.
    ///
    /// # Errors
    /// `ZmqError::NoMessage` when the timeout elapses; otherwise whatever
    /// `recv_next` reports.
    #[cfg(not(feature = "inproc"))]
    pub(crate) async fn recv_next_timed<B: MultiPeerBackend + HasRegistry + ?Sized>(
        inbound: &Receiver<TaggedMsg>,
        backend: &B,
        receive_timeout: Option<std::time::Duration>,
    ) -> ZmqResult<(crate::engine::registry::PeerKey, crate::message::ZmqMessage)> {
        match receive_timeout {
            None => recv_next(inbound, backend).await,
            Some(d) => crate::async_rt::task::timeout(d, recv_next(inbound, backend))
                .await
                .map_err(|_e| ZmqError::NoMessage)?,
        }
    }

    /// Takes the pending message from the first non-empty conflate slot,
    /// waiting on `conflate_notify` while every slot is empty.
    ///
    /// Slots are scanned in the (unspecified) iteration order of the map.
    ///
    /// # Errors
    /// `ZmqError::NoMessage` when there are neither conflate slots nor
    /// registered peers, or when `receive_timeout` elapses.
    pub(crate) async fn recv_next_conflate_timed(
        &self,
        receive_timeout: Option<std::time::Duration>,
    ) -> ZmqResult<(crate::engine::registry::PeerKey, crate::message::ZmqMessage)> {
        let fut = async {
            loop {
                // Snapshot the slot handles so the map lock is not held while
                // each individual slot is inspected.
                let slots: Vec<ConflateSlot> =
                    self.conflate_slots.lock().values().cloned().collect();
                for slot in &slots {
                    if let Some((key, msg)) = slot.slot.lock().take() {
                        return Ok((key, msg));
                    }
                }
                if slots.is_empty() && self.registry.is_empty() {
                    return Err(ZmqError::NoMessage);
                }
                // NOTE(review): a slot filled between the scan above and this
                // call is only picked up promptly if `RuntimeNotify` retains a
                // notification sent while nobody is waiting — confirm.
                self.conflate_notify.notified().await;
            }
        };
        match receive_timeout {
            None => fut.await,
            Some(d) => crate::async_rt::task::timeout(d, fut)
                .await
                .map_err(|_e| ZmqError::NoMessage)?,
        }
    }

    /// Receives via the conflate slots when `socket_options.conflate` is set,
    /// otherwise via `recv_next_timed` over `inbound`.
    pub(crate) async fn recv_auto(
        &self,
        inbound: &Receiver<TaggedMsg>,
        receive_timeout: Option<std::time::Duration>,
    ) -> ZmqResult<(crate::engine::registry::PeerKey, crate::message::ZmqMessage)> {
        if self.socket_options.conflate {
            return self.recv_next_conflate_timed(receive_timeout).await;
        }
        GenericSocketBackend::recv_next_timed(inbound, self, receive_timeout).await
    }

    /// `send_round_robin`, optionally bounded by `send_timeout`.
    ///
    /// # Errors
    /// `ZmqError::NoMessage` when the timeout elapses; the pending send is
    /// abandoned and the message is not handed back to the caller.
    pub(crate) async fn send_round_robin_timed(
        &self,
        message: crate::message::ZmqMessage,
        send_timeout: Option<std::time::Duration>,
    ) -> ZmqResult<PeerIdentity> {
        match send_timeout {
            None => self.send_round_robin(message).await,
            Some(d) => crate::async_rt::task::timeout(d, self.send_round_robin(message))
                .await
                .map_err(|_e| ZmqError::NoMessage)?,
        }
    }

    /// `send_to`, optionally bounded by `send_timeout`.
    ///
    /// # Errors
    /// `ZmqError::NoMessage` when the timeout elapses; the pending send is
    /// abandoned and the message is not handed back to the caller.
    pub(crate) async fn send_to_timed(
        &self,
        peer_id: &PeerIdentity,
        message: crate::message::ZmqMessage,
        send_timeout: Option<std::time::Duration>,
    ) -> ZmqResult<()> {
        match send_timeout {
            None => self.send_to(peer_id, message).await,
            Some(d) => crate::async_rt::task::timeout(d, self.send_to(peer_id, message))
                .await
                .map_err(|_e| ZmqError::NoMessage)?,
        }
    }

    /// Sends `message` to the next peer in round-robin order and returns that
    /// peer's identity.
    ///
    /// # Errors
    /// - `ZmqError::ReturnToSender` (carrying the message) when no peer is
    ///   available, or when the chosen engine hands the message back on enqueue.
    /// - `ZmqError::Other` when the engine fails in another way, or when the
    ///   peer vanished from the registry after a successful send.
    ///
    /// A peer whose engine fails is dropped from the backend.
    pub(crate) async fn send_round_robin(
        &self,
        message: crate::message::ZmqMessage,
    ) -> ZmqResult<PeerIdentity> {
        if self.socket_options.immediate && self.registry.is_empty() {
            return Err(ZmqError::ReturnToSender {
                reason: "Not connected to peers. Unable to send messages".into(),
                message,
            });
        }
        let (key, engine) = match self.registry.next_round_robin() {
            Some(pair) => pair,
            None => {
                return Err(ZmqError::ReturnToSender {
                    reason: "Not connected to peers. Unable to send messages".into(),
                    message,
                })
            }
        };
        match engine.send_msg(message).await {
            // Build the error lazily: this arm is the hot success path and
            // should not allocate.
            Ok(()) => self
                .registry
                .id_for(key)
                .ok_or_else(|| ZmqError::Other("Peer disappeared mid-send".into())),
            Err(SendError::Enqueue(Message::Message(m))) => {
                self.drop_failed_peer(key);
                Err(ZmqError::ReturnToSender {
                    reason: "Not connected to peers. Unable to send messages".into(),
                    message: m,
                })
            }
            Err(SendError::Enqueue(_) | SendError::Flush) => {
                self.drop_failed_peer(key);
                Err(ZmqError::Other("Peer disconnected during send".into()))
            }
        }
    }

    /// Sends `message` to the peer registered under `peer_id`.
    ///
    /// # Errors
    /// - `ZmqError::ReturnToSender` when `immediate` is set and no peer is
    ///   connected.
    /// - `ZmqError::Other` when no peer has that identity, or when its engine
    ///   fails (the peer is then dropped from the backend).
    pub(crate) async fn send_to(
        &self,
        peer_id: &PeerIdentity,
        message: crate::message::ZmqMessage,
    ) -> ZmqResult<()> {
        if self.socket_options.immediate && self.registry.is_empty() {
            return Err(ZmqError::ReturnToSender {
                reason: "Not connected to peers. Unable to send messages".into(),
                message,
            });
        }
        // Lazy error construction: avoids allocating on every successful lookup.
        let (key, engine) = self.registry.get_by_id(peer_id).ok_or_else(|| {
            ZmqError::Other("Destination client not found by identity".into())
        })?;
        match engine.send_msg(message).await {
            Ok(()) => Ok(()),
            Err(SendError::Enqueue(_) | SendError::Flush) => {
                self.drop_failed_peer(key);
                Err(ZmqError::Other(
                    "Destination client not found by identity".into(),
                ))
            }
        }
    }

    /// Forgets the peer registered under `key` after its engine failed a send,
    /// releasing its conflate slot as well as its registry entry.
    ///
    /// Once the registry entry is removed, `handle_tagged_msg` presumably can
    /// no longer map `key` back to an identity, so `peer_disconnected` (which
    /// normally releases the slot) will not run for this peer. Without this,
    /// the slot would stay in `conflate_slots` until `shutdown`.
    // NOTE(review): no `SocketEvent::Disconnected` / disconnect notifier is
    // fired on this path (unchanged from before) — confirm that is intended.
    fn drop_failed_peer(&self, key: crate::engine::registry::PeerKey) {
        self.conflate_slots.lock().remove(&key);
        self.registry.remove_by_key(key);
    }
}
/// Basic `SocketBackend` plumbing: type, options, shutdown and monitor access.
impl SocketBackend for GenericSocketBackend {
    /// The socket type chosen at construction.
    fn socket_type(&self) -> SocketType {
        self.socket_type
    }

    /// The options this backend was constructed with.
    fn socket_options(&self) -> &SocketOptions {
        &self.socket_options
    }

    /// Empties the peer registry first, then the conflate-slot map.
    fn shutdown(&self) {
        let peers = &self.registry;
        peers.clear();
        let mut slots = self.conflate_slots.lock();
        slots.clear();
    }

    /// The (possibly unset) socket-event sink.
    fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
        &self.socket_monitor
    }
}
/// Backends that expose their peer registry to the shared receive helpers
/// (`handle_tagged_msg`, `recv_next`).
pub(crate) trait HasRegistry {
    /// Borrows the backend's peer registry.
    fn registry(&self) -> &PeerRegistry;
}
/// Exposes the backend's own `registry` field.
impl HasRegistry for GenericSocketBackend {
    fn registry(&self) -> &PeerRegistry {
        &self.registry
    }
}
/// Backends that own an inproc inbound queue and the notifier that is passed
/// to inproc engines and awaited by `recv_next`.
#[cfg(feature = "inproc")]
pub(crate) trait HasInproc {
    /// Receiver side of the inproc inbound queue.
    fn inproc_inbound_rx(&self) -> &crossbeam_channel::Receiver<TaggedMsg>;
    /// Notifier awaited by `recv_next` while both inbound queues are empty.
    fn inproc_notify(&self) -> &Arc<crate::async_rt::notify::RuntimeNotify>;
}
/// Exposes the backend's `inproc_inbound_rx` and `inproc_notify` fields.
#[cfg(feature = "inproc")]
impl HasInproc for GenericSocketBackend {
    #[inline]
    fn inproc_inbound_rx(&self) -> &crossbeam_channel::Receiver<TaggedMsg> {
        &self.inproc_inbound_rx
    }
    #[inline]
    fn inproc_notify(&self) -> &Arc<crate::async_rt::notify::RuntimeNotify> {
        &self.inproc_notify
    }
}
/// Peer lifecycle: attaching framed and inproc peers, and tearing them down.
impl MultiPeerBackend for GenericSocketBackend {
    /// Attaches a framed peer.
    ///
    /// Steps:
    /// - spawns a `PeerEngine` over `io` and registers it under `peer_id`;
    /// - in conflate mode, also registers a per-peer conflate slot;
    /// - sends the configured hello message, if any (best effort);
    /// - emits `SocketEvent::Accepted` when `endpoint` is given and a monitor
    ///   is installed.
    async fn peer_connected<R, W>(
        self: Arc<Self>,
        peer_id: &PeerIdentity,
        io: FramedIo<R, W>,
        endpoint: Option<Endpoint>,
    ) where
        R: futures::Stream<Item = Result<Message, CodecError>> + Unpin + Send + 'static,
        W: futures::Sink<Message, Error = CodecError> + Unpin + Send + IntoEngineWriter + 'static,
        W::Writer: Send + 'static,
    {
        #[cfg(feature = "curve")]
        let (read_half, write_half, curve) = io.into_parts();
        #[cfg(not(feature = "curve"))]
        let (read_half, write_half) = io.into_parts();
        let inbound_tx = self.inbound_tx.clone();
        let peer_id_owned = peer_id.clone();
        #[cfg(feature = "curve")]
        let mut curve = curve;
        let writer = write_half.into_engine_writer();
        let send_hwm = self.socket_options.send_hwm;
        let conflate = self.socket_options.conflate;
        // In conflate mode each peer gets its own single-message slot; all
        // slots share the backend-wide `conflate_notify`.
        let conflate_slot: Option<ConflateSlot> = if conflate {
            Some(Arc::new(ConflateSlotInner {
                slot: parking_lot::Mutex::new(None),
                notify: self.conflate_notify.clone(),
            }))
        } else {
            None
        };
        // One handle goes to the engine; the other is recorded in
        // `conflate_slots` after registration.
        let conflate_slot_for_engine = conflate_slot.clone();
        // Normalise the option: `Some(Some(0))` becomes `Some(None)`,
        // `Some(Some(n))` is kept, and `Some(None)` / `None` both become `None`.
        let inline_write_max: Option<Option<usize>> = match self.socket_options.inline_write_max {
            Some(Some(0)) => Some(None),
            Some(Some(n)) => Some(Some(n)),
            Some(None) | None => None,
        };
        // When a CURVE session is present, no inline-write setting is passed on.
        #[cfg(feature = "curve")]
        let inline_write_max: Option<Option<usize>> = if curve.is_some() {
            None
        } else {
            inline_write_max
        };
        let config = crate::engine::peer_loop::PeerConfig {
            heartbeat: Self::build_heartbeat_cfg(&self.socket_options),
            max_msg_size: self.socket_options.max_msg_size,
            #[cfg(feature = "curve")]
            curve: curve.take(),
            conflate_slot: conflate_slot_for_engine,
            out_batch_size: self.socket_options.out_batch_size,
            // Flattens the nested option: `None` and `Some(None)` both become `None`.
            out_batch_msgs: self.socket_options.out_batch_msgs.unwrap_or(None),
            in_batch_msgs: self.socket_options.in_batch_msgs,
            inline_write_max,
        };
        // The engine is spawned inside the closure so it learns its own registry key.
        let (key, _prev) = self.registry.insert_with(peer_id.clone(), |key| {
            make_framed_engine(Arc::new(PeerEngine::spawn(
                key,
                peer_id_owned,
                read_half,
                writer,
                send_hwm,
                inbound_tx,
                config,
            )))
        });
        // NOTE(review): `_prev` is presumably a previous registration under the
        // same identity; its `conflate_slots` entry (if any) is not removed
        // here — confirm whether that can leak.
        if let Some(slot) = conflate_slot {
            self.conflate_slots.lock().insert(key, slot);
        }
        // Hello message: best effort, a full queue silently drops it.
        if let Some(hello) = &self.socket_options.hello_msg {
            if let Some((_, engine)) = self.registry.get_by_id(peer_id) {
                let _ = engine.try_send_oneshot(hello.clone());
            }
        }
        if let Some(ep) = endpoint {
            if let Some(tx) = self.socket_monitor.lock().as_mut() {
                let _ = tx.try_send(SocketEvent::Accepted(ep, peer_id.clone()));
            }
        }
    }
    /// Attaches an inproc peer.
    ///
    /// Steps:
    /// - registers a placeholder engine so the peer has a key before the
    ///   handshake;
    /// - runs the inproc handshake;
    /// - for socket types that want the remote routing id, tries to rename the
    ///   registry entry to that id;
    /// - swaps the placeholder for the real engine;
    /// - sends the hello message and emits `Accepted`, as in `peer_connected`.
    ///
    /// # Errors
    /// Propagates the handshake error after tearing the placeholder down via
    /// `peer_disconnected`.
    #[cfg(feature = "inproc")]
    // NOTE(review): presumably allows `InprocPeer` being less visible than this
    // trait method — confirm.
    #[allow(private_interfaces)]
    async fn peer_connected_inproc(
        self: Arc<Self>,
        peer_id: &PeerIdentity,
        peer: crate::transport::inproc::InprocPeer,
        endpoint: Option<Endpoint>,
    ) -> crate::ZmqResult<()> {
        let inproc_tx = self.inproc_inbound_tx.clone();
        let inproc_notify = self.inproc_notify.clone();
        let (local_key, _prev) = self.registry.insert_with(peer_id.clone(), |_| {
            AnyEngine::Inproc(Arc::new(inproc_placeholder_engine()))
        });
        let local_socket_type = self.socket_type();
        let local_routing_id = self.socket_options.peer_id.clone();
        let (engine, remote_routing_id) = match connect_inproc_engine(
            local_key,
            local_socket_type,
            local_routing_id,
            inproc_tx,
            inproc_notify,
            peer,
        )
        .await
        {
            Ok(pair) => pair,
            Err(e) => {
                // NOTE(review): this also emits `SocketEvent::Disconnected`
                // even though no `Accepted` was emitted for this peer.
                self.peer_disconnected(peer_id);
                return Err(e);
            }
        };
        // Fall back to the local `peer_id` when no remote id is wanted,
        // none was supplied, or the rename is refused.
        let effective_peer_id = if local_socket_type.wants_remote_routing_id() {
            if let Some(remote_id) = remote_routing_id {
                if self.registry.rename_peer_id(local_key, remote_id.clone()) {
                    remote_id
                } else {
                    peer_id.clone()
                }
            } else {
                peer_id.clone()
            }
        } else {
            peer_id.clone()
        };
        self.registry
            .replace_engine(local_key, AnyEngine::Inproc(Arc::new(engine)));
        if let Some(hello) = &self.socket_options.hello_msg {
            if let Some((_, e)) = self.registry.get_by_id(&effective_peer_id) {
                let _ = e.try_send_oneshot(hello.clone());
            }
        }
        if let Some(ep) = endpoint {
            if let Some(tx) = self.socket_monitor.lock().as_mut() {
                let _ = tx.try_send(SocketEvent::Accepted(ep, effective_peer_id));
            }
        }
        Ok(())
    }
    /// Tears down the peer registered under `peer_id`.
    ///
    /// Every step is best effort and ignores send failures:
    /// - sends `disconnect_msg` to the still-registered engine, if configured;
    /// - removes the peer and its conflate slot;
    /// - emits `SocketEvent::Disconnected`, even when no peer had that id;
    /// - fires and removes the one-shot disconnect notifier.
    fn peer_disconnected(&self, peer_id: &PeerIdentity) {
        if let Some(disc) = &self.socket_options.disconnect_msg {
            if let Some((_, engine)) = self.registry.get_by_id(peer_id) {
                let _ = engine.try_send_oneshot(disc.clone());
            }
        }
        if let Some((key, _)) = self.registry.remove_by_id(peer_id) {
            self.conflate_slots.lock().remove(&key);
        }
        if let Some(tx) = self.socket_monitor.lock().as_mut() {
            let _ = tx.try_send(SocketEvent::Disconnected(peer_id.clone()));
        }
        if let Some(mut notifier) = self.disconnect_notifiers.lock().remove(peer_id) {
            let _ = notifier.try_send(peer_id.clone());
        }
    }
}
/// One item from an inbound queue, tagged with the registry key of the peer
/// that produced it: either a decoded frame or a codec error.
pub(crate) type TaggedMsg = (
    crate::engine::registry::PeerKey,
    Result<Message, CodecError>,
);
/// Interprets one inbound item.
///
/// Return values:
/// - `Some(Ok(..))` for an application message;
/// - `None` for any other frame, and for a clean `PeerDisconnected`;
/// - `Some(Err(..))` for every other codec error.
///
/// Any codec error first tears down the originating peer, if it is still
/// registered.
fn handle_tagged_msg<B: MultiPeerBackend + HasRegistry + ?Sized>(
    item: TaggedMsg,
    backend: &B,
) -> Option<ZmqResult<(crate::engine::registry::PeerKey, ZmqMessage)>> {
    let (peer_key, outcome) = item;
    let failure = match outcome {
        Ok(Message::Message(payload)) => return Some(Ok((peer_key, payload))),
        Ok(_) => return None,
        Err(failure) => failure,
    };
    // Both error flavours drop the peer before deciding what to report.
    if let Some(identity) = backend.registry().id_for(peer_key) {
        backend.peer_disconnected(&identity);
    }
    match failure {
        CodecError::PeerDisconnected => None,
        other => Some(Err(other.into())),
    }
}
/// Receives the next application message from either inbound queue.
///
/// Each loop iteration runs in three phases:
/// 1. drain the inproc queue without blocking;
/// 2. drain the engine queue (`async_inbound`) without blocking;
/// 3. wait for either a new engine item or an `inproc_notify` wake-up.
///
/// Items that `handle_tagged_msg` filters out are skipped.
///
/// # Errors
/// - `ZmqError::NoMessage` once `async_inbound` is disconnected.
/// - Codec errors other than `PeerDisconnected` are converted into
///   `ZmqError` and returned after the peer is torn down.
#[cfg(feature = "inproc")]
pub(crate) async fn recv_next<B>(
    async_inbound: &Receiver<TaggedMsg>,
    backend: &B,
) -> ZmqResult<(crate::engine::registry::PeerKey, ZmqMessage)>
where
    B: MultiPeerBackend + HasRegistry + HasInproc + ?Sized,
{
    use crate::async_rt::notify::AsyncNotify;
    use futures::FutureExt;
    let inproc_inbound = backend.inproc_inbound_rx();
    let inproc_notify = backend.inproc_notify();
    loop {
        // Phase 1: inproc items first. Both empty and disconnected stop the drain.
        while let Ok(item) = inproc_inbound.try_recv() {
            if let Some(result) = handle_tagged_msg(item, backend) {
                crate::wake_counter::bump(&crate::wake_counter::RECV_NEXT_WAKES);
                return result;
            }
        }
        // Phase 2: drain the engine queue without blocking.
        loop {
            match async_inbound.try_recv() {
                Ok(item) => {
                    if let Some(result) = handle_tagged_msg(item, backend) {
                        crate::wake_counter::bump(&crate::wake_counter::RECV_NEXT_WAKES);
                        return result;
                    }
                }
                Err(flume::TryRecvError::Empty) => break,
                Err(flume::TryRecvError::Disconnected) => return Err(ZmqError::NoMessage),
            }
        }
        // Phase 3: block until either source has something; an inproc wake-up
        // just restarts the loop so phase 1 picks the item up.
        // NOTE(review): an inproc item enqueued after phase 1 but before
        // `notified()` is polled is only seen promptly if `RuntimeNotify`
        // retains notifications sent with no waiter — confirm.
        futures::select! {
            item = async_inbound.recv_async().fuse() => {
                match item {
                    Ok(item) => {
                        if let Some(result) = handle_tagged_msg(item, backend) {
                            crate::wake_counter::bump(&crate::wake_counter::RECV_NEXT_WAKES);
                            return result;
                        }
                    }
                    Err(_closed) => return Err(ZmqError::NoMessage),
                }
            }
            _ = inproc_notify.notified().fuse() => {
            }
        }
    }
}
/// Receives the next application message from the engine inbound queue.
///
/// Items that `handle_tagged_msg` filters out are skipped.
///
/// # Errors
/// - `ZmqError::NoMessage` once the queue is disconnected.
/// - Codec errors other than `PeerDisconnected` are converted and returned
///   after the peer is torn down.
#[cfg(not(feature = "inproc"))]
pub(crate) async fn recv_next<B: MultiPeerBackend + HasRegistry + ?Sized>(
    async_inbound: &Receiver<TaggedMsg>,
    backend: &B,
) -> ZmqResult<(crate::engine::registry::PeerKey, ZmqMessage)> {
    while let Ok(tagged) = async_inbound.recv_async().await {
        if let Some(outcome) = handle_tagged_msg(tagged, backend) {
            crate::wake_counter::bump(&crate::wake_counter::RECV_NEXT_WAKES);
            return outcome;
        }
    }
    // The queue is disconnected: nothing further can arrive.
    Err(ZmqError::NoMessage)
}