use crate::codec::{CodecError, FramedIo, IntoEngineWriter, Message};
use crate::engine::registry::{make_framed_engine, PeerRegistry};
use crate::engine::PeerEngine;
use crate::PeerIdentity;
use crate::{async_rt, CaptureSocket};
use crate::{
MultiPeerBackend, Socket, SocketBackend, SocketEvent, SocketOptions, SocketSend, SocketType,
ZmqMessage, ZmqResult,
};
use flume::{Receiver, Sender};
use futures::channel::mpsc;
use parking_lot::Mutex;
use std::sync::Arc;
/// One item reported by a stream peer's engine: the sender's registry key plus
/// the decoded message, or a codec error (which the socket treats as a disconnect).
type TaggedInbound = (crate::engine::registry::PeerKey, Result<Message, CodecError>);
/// One-shot slot holding the inbound receiver until the socket's task takes it.
type InboundRx = Mutex<Option<Receiver<TaggedInbound>>>;
/// One-shot slot holding the inproc inbound receiver until the socket's task takes it.
#[cfg(feature = "inproc")]
type InprocInboundRxCell = Mutex<Option<crate::engine::InprocInboundRx>>;
/// Shared state behind a [`PubSocket`]: connected peers, per-peer subscription
/// state, and the channels on which peer engines report upstream traffic.
#[doc(hidden)]
pub struct PubSocketBackend {
    /// Connected peers; addressed by `PeerKey` and looked up by `PeerIdentity`.
    registry: PeerRegistry,
    /// Per-peer subscriptions, consulted by `send` to decide which peers match.
    router: crate::socket::topic_router::TopicRouter,
    /// Cloned into each stream-transport `PeerEngine`; items carry the peer's key.
    inbound_tx: Sender<TaggedInbound>,
    /// Receiving end, taken once by `PubSocket::with_options` for its inbound task.
    inbound_rx: InboundRx,
    /// Sender handed to inproc engines through `connect_inproc_engine`.
    #[cfg(feature = "inproc")]
    inproc_inbound_tx: crate::engine::InprocInboundTx,
    /// Receiving end for inproc traffic, taken once by `PubSocket::with_options`.
    #[cfg(feature = "inproc")]
    inproc_inbound_rx: InprocInboundRxCell,
    /// Shared with inproc engines; the inproc inbound task awaits it between
    /// drains (presumably signalled when new items are queued — confirm).
    #[cfg(feature = "inproc")]
    pub(crate) inproc_notify: Arc<crate::async_rt::notify::RuntimeNotify>,
    /// Options the socket was created with (HWMs, hello/disconnect messages, ...).
    socket_options: SocketOptions,
    /// Optional monitor channel; `peer_disconnected` reports `Disconnected` here.
    pub(crate) socket_monitor: Mutex<Option<mpsc::Sender<SocketEvent>>>,
}
impl PubSocketBackend {
    /// Creates a backend with no peers and no subscriptions.
    ///
    /// Both inbound channels are bounded by `options.receive_hwm`; their
    /// receivers are parked in `Mutex<Option<_>>` slots for the socket to take.
    fn with_options(options: SocketOptions) -> Self {
        let capacity = options.receive_hwm;
        let (inbound_tx, stream_rx) = flume::bounded(capacity);
        #[cfg(feature = "inproc")]
        let (inproc_inbound_tx, inproc_rx) = crossbeam_channel::bounded(capacity);
        Self {
            registry: PeerRegistry::new(),
            router: crate::socket::topic_router::TopicRouter::new(),
            inbound_tx,
            inbound_rx: Mutex::new(Some(stream_rx)),
            #[cfg(feature = "inproc")]
            inproc_inbound_tx,
            #[cfg(feature = "inproc")]
            inproc_inbound_rx: Mutex::new(Some(inproc_rx)),
            #[cfg(feature = "inproc")]
            inproc_notify: Arc::new(crate::async_rt::notify::RuntimeNotify::new()),
            socket_options: options,
            socket_monitor: Mutex::new(None),
        }
    }

    /// Applies an upstream message from `peer_key` as a subscription update.
    ///
    /// Only `Message::Message` values with exactly one frame are forwarded to
    /// the topic router; any other frame count is logged and dropped, and other
    /// message kinds are ignored. The router's return value is discarded.
    fn apply_sub_message(&self, peer_key: crate::engine::registry::PeerKey, message: Message) {
        let Message::Message(frames) = message else {
            return;
        };
        let frame_count = frames.len();
        if frame_count != 1 {
            log::warn!("PUB sub message unexpected length: {}", frame_count);
            return;
        }
        let payload = frames.into_vec().pop().unwrap_or_default();
        let _ = self.router.apply_sub_message(peer_key, &payload);
    }
}
impl SocketBackend for PubSocketBackend {
    fn socket_type(&self) -> SocketType {
        SocketType::PUB
    }

    fn socket_options(&self) -> &SocketOptions {
        &self.socket_options
    }

    /// Clears the peer registry first, then the topic router.
    fn shutdown(&self) {
        let Self {
            registry, router, ..
        } = self;
        registry.clear();
        router.clear();
    }

    fn monitor(&self) -> &Mutex<Option<mpsc::Sender<SocketEvent>>> {
        &self.socket_monitor
    }
}
impl MultiPeerBackend for PubSocketBackend {
    /// Registers a peer connected over a framed stream transport.
    ///
    /// The framed I/O is split into read/write halves and handed to a newly
    /// spawned `PeerEngine`, together with the peer's registry key, the socket's
    /// `send_hwm` and a clone of `inbound_tx` (whose items are tagged with the
    /// key). The key is then registered with the topic router, and the
    /// configured hello message, if any, is queued to the peer (best effort).
    async fn peer_connected<R, W>(
        self: Arc<Self>,
        peer_id: &PeerIdentity,
        io: FramedIo<R, W>,
        _endpoint: Option<crate::endpoint::Endpoint>,
    ) where
        R: futures::Stream<Item = Result<Message, CodecError>> + Unpin + Send + 'static,
        W: futures::Sink<Message, Error = CodecError> + Unpin + Send + IntoEngineWriter + 'static,
        W::Writer: Send + 'static,
    {
        // With CURVE enabled `into_parts` yields a third component; PUB does not use it.
        #[cfg(feature = "curve")]
        let (read_half, write_half, _curve) = io.into_parts();
        #[cfg(not(feature = "curve"))]
        let (read_half, write_half) = io.into_parts();
        let inbound_tx = self.inbound_tx.clone();
        let peer_id_owned = peer_id.clone();
        let writer = write_half.into_engine_writer();
        let send_hwm = self.socket_options.send_hwm;
        // The second value (presumably a replaced earlier entry) is not used here.
        let (key, _prev) = self.registry.insert_with(peer_id.clone(), |key| {
            make_framed_engine(Arc::new(PeerEngine::spawn(
                key,
                peer_id_owned,
                read_half,
                writer,
                send_hwm,
                inbound_tx,
                crate::engine::peer_loop::PeerConfig::default(),
            )))
        });
        self.router.register_peer(key);
        if let Some(hello) = &self.socket_options.hello_msg {
            if let Some((_, engine)) = self.registry.get_by_id(peer_id) {
                let _ = engine.try_send_oneshot(hello.clone());
            }
        }
    }

    /// Registers a peer connected over the in-process transport.
    ///
    /// A placeholder engine is inserted first so the peer already has a
    /// registry key (and router registration) before the asynchronous
    /// `connect_inproc_engine` handshake. On success the placeholder is replaced
    /// by the real engine and the hello message, if configured, is sent
    /// directly. On failure the peer is torn down via `peer_disconnected`.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `connect_inproc_engine`.
    #[cfg(feature = "inproc")]
    #[allow(private_interfaces)]
    async fn peer_connected_inproc(
        self: Arc<Self>,
        peer_id: &PeerIdentity,
        peer: crate::transport::inproc::InprocPeer,
        _endpoint: Option<crate::endpoint::Endpoint>,
    ) -> crate::ZmqResult<()> {
        let inproc_tx = self.inproc_inbound_tx.clone();
        let inproc_notify = self.inproc_notify.clone();
        let (local_key, _) = self.registry.insert_with(peer_id.clone(), |_| {
            crate::engine::registry::AnyEngine::Inproc(Arc::new(
                crate::engine::inproc_placeholder_engine(),
            ))
        });
        self.router.register_peer(local_key);
        let local_socket_type = self.socket_type();
        let local_routing_id = self.socket_options.peer_id.clone();
        let (engine, _remote_routing_id) = match crate::engine::connect_inproc_engine(
            local_key,
            local_socket_type,
            local_routing_id,
            inproc_tx,
            inproc_notify,
            peer,
        )
        .await
        {
            Ok(pair) => pair,
            Err(e) => {
                self.peer_disconnected(peer_id);
                return Err(e);
            }
        };
        let engine = Arc::new(engine);
        self.registry.replace_engine(
            local_key,
            crate::engine::registry::AnyEngine::Inproc(engine.clone()),
        );
        if let Some(hello) = &self.socket_options.hello_msg {
            let _ = engine.try_send_direct(hello.clone());
        }
        Ok(())
    }

    /// Tears down the peer registered under `peer_id`.
    ///
    /// If a disconnect message is configured it is queued to the peer's engine
    /// (best effort) before the peer is removed from the registry and the
    /// topic router. The monitor's `Disconnected` event is emitted only by the
    /// call that actually removed the peer. Repeated or racing calls for the
    /// same peer, such as from the inbound task and from `send`'s dead-peer
    /// cleanup, therefore report the disconnect once.
    fn peer_disconnected(&self, peer_id: &PeerIdentity) {
        if let Some(disc) = &self.socket_options.disconnect_msg {
            if let Some((_, engine)) = self.registry.get_by_id(peer_id) {
                let _ = engine.try_send_oneshot(disc.clone());
            }
        }
        let Some((key, _)) = self.registry.remove_by_id(peer_id) else {
            // Already removed by an earlier call: nothing left to tear down or report.
            return;
        };
        self.router.forget_peer(key);
        if let Some(monitor) = self.monitor().lock().as_mut() {
            let _ = monitor.try_send(SocketEvent::Disconnected(peer_id.clone()));
        }
    }
}
/// A ZeroMQ-style PUB socket: each `send` fans the message out to every
/// connected peer whose subscriptions match the message's first frame.
pub struct PubSocket {
    /// Shared socket machinery wrapping the [`PubSocketBackend`].
    pub(crate) common: crate::socket::common::SocketCommon<PubSocketBackend>,
    /// Buffer that receives the registry snapshot on every `send`; kept on the
    /// socket so its allocation can be reused across sends.
    peer_buf: Vec<(
        crate::engine::registry::PeerKey,
        crate::engine::registry::AnyEngine,
    )>,
    /// Reusable buffer of keys whose engine reported `Closed` during a fan-out.
    dead_buf: Vec<crate::engine::registry::PeerKey>,
}
// Marks PUB as a member of the (sealed) publisher socket family.
impl crate::socket::family::sealed::Sealed for PubSocket {}
impl crate::socket::family::Publisher for PubSocket {}

/// Dropping the socket calls `shutdown` on the shared backend, clearing its
/// peer registry and topic router.
impl Drop for PubSocket {
    fn drop(&mut self) {
        let Self { common, .. } = self;
        common.backend.shutdown();
    }
}

/// Gives the shared socket helpers access to this socket's `SocketCommon`.
impl crate::socket::common::HasCommon for PubSocket {
    type Backend = PubSocketBackend;

    fn common(&self) -> &crate::socket::common::SocketCommon<Self::Backend> {
        let Self { common, .. } = self;
        common
    }

    fn common_mut(&mut self) -> &mut crate::socket::common::SocketCommon<Self::Backend> {
        let Self { common, .. } = self;
        common
    }
}
impl Socket for PubSocket {
type Backend = PubSocketBackend;
fn with_options(options: SocketOptions) -> Self {
let backend = Arc::new(PubSocketBackend::with_options(options));
let inbound_rx = backend
.inbound_rx
.lock()
.take()
.expect("inbound_rx taken twice");
let backend_weak = Arc::downgrade(&backend);
async_rt::task::spawn(async move {
while let Ok((peer_key, res)) = inbound_rx.recv_async().await {
let backend = match backend_weak.upgrade() {
Some(b) => b,
None => return,
};
match res {
Ok(msg) => backend.apply_sub_message(peer_key, msg),
Err(_) => {
if let Some(id) = backend.registry.id_for(peer_key) {
backend.peer_disconnected(&id);
}
}
}
}
});
#[cfg(feature = "inproc")]
{
use crate::async_rt::notify::AsyncNotify;
let inproc_rx = backend
.inproc_inbound_rx
.lock()
.take()
.expect("inproc_inbound_rx taken twice");
let inproc_notify = backend.inproc_notify.clone();
let backend_weak = Arc::downgrade(&backend);
async_rt::task::spawn(async move {
loop {
while let Ok((peer_key, res)) = inproc_rx.try_recv() {
let backend = match backend_weak.upgrade() {
Some(b) => b,
None => return,
};
match res {
Ok(msg) => backend.apply_sub_message(peer_key, msg),
Err(_) => {
if let Some(id) = backend.registry.id_for(peer_key) {
backend.peer_disconnected(&id);
}
}
}
}
if backend_weak.strong_count() == 0 {
return;
}
inproc_notify.notified().await;
}
});
}
Self {
common: crate::socket::common::SocketCommon::new(backend),
peer_buf: Vec::new(),
dead_buf: Vec::new(),
}
}
async fn linger_drain(&mut self) {
let opts = self.common.backend.socket_options();
crate::engine::registry::drain_registry(&self.common.backend.registry, opts).await;
}
}
impl SocketSend for PubSocket {
    /// Publishes `message` to every connected peer whose subscription matches
    /// its first frame (honouring `invert_matching`).
    ///
    /// The message is wrapped in an `Arc` once and shared by all matching
    /// engines. A `TrySendOutcome::Full` result is ignored for that peer; a
    /// `Closed` result marks the peer dead, and dead peers are disconnected
    /// after the fan-out. An empty message is a no-op. Always returns
    /// `Ok(())`, after yielding to the runtime once.
    async fn send(&mut self, message: impl Into<ZmqMessage> + Send) -> ZmqResult<()> {
        use crate::engine::registry::TrySendOutcome;

        let message = message.into();
        // Subscriptions are matched against the first frame only.
        let first_frame = match message.get(0) {
            Some(frame) => frame.clone(),
            None => return Ok(()),
        };
        let shared = Arc::new(message);
        self.common
            .backend
            .registry
            .snapshot_into(&mut self.peer_buf);
        self.dead_buf.clear();
        let invert = self.common.backend.socket_options.invert_matching;
        let dead = &mut self.dead_buf;
        self.common.backend.router.with_match_guard(|m| {
            for (key, engine) in self.peer_buf.iter() {
                if !m.matches(*key, &first_frame, invert) {
                    continue;
                }
                match engine.try_send_fanout(shared.clone()) {
                    TrySendOutcome::Sent | TrySendOutcome::Full => {}
                    TrySendOutcome::Closed => dead.push(*key),
                }
            }
        });
        // Drop the snapshot's engine handles now (the allocation is kept for the
        // next send). Otherwise engines of peers removed before the next send stay
        // alive in this buffer indefinitely.
        self.peer_buf.clear();
        for key in self.dead_buf.drain(..) {
            if let Some(id) = self.common.backend.registry.id_for(key) {
                self.common.backend.peer_disconnected(&id);
            }
        }
        crate::async_rt::task::yield_now().await;
        Ok(())
    }
}
// PUB relies entirely on `CaptureSocket`'s provided items (empty impl).
impl CaptureSocket for PubSocket {}
#[cfg(all(test, feature = "tokio", feature = "tcp"))]
mod tests {
    use super::*;
    use crate::socket::handshake::tests::{
        test_bind_to_any_port_helper, test_bind_to_unspecified_interface_helper,
    };
    use std::net::IpAddr;

    #[async_rt::test]
    async fn test_bind_to_any_port() -> ZmqResult<()> {
        test_bind_to_any_port_helper(PubSocket::new()).await
    }

    #[async_rt::test]
    async fn test_bind_to_any_ipv4_interface() -> ZmqResult<()> {
        // 0.0.0.0
        let unspecified = IpAddr::from(std::net::Ipv4Addr::UNSPECIFIED);
        test_bind_to_unspecified_interface_helper(unspecified, PubSocket::new(), 4000).await
    }

    #[async_rt::test]
    async fn test_bind_to_any_ipv6_interface() -> ZmqResult<()> {
        // ::
        let unspecified = IpAddr::from(std::net::Ipv6Addr::UNSPECIFIED);
        test_bind_to_unspecified_interface_helper(unspecified, PubSocket::new(), 4010).await
    }
}