use crate::async_rt;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::handshake::{greet_exchange_full, ready_exchange};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::{CodecError, FramedIo, IntoEngineWriter, Message};
use crate::endpoint::Endpoint;
use crate::peer_identity::PeerIdentity;
use crate::transport;
use crate::{MultiPeerBackend, ZmqError, ZmqResult};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use futures::{Sink, Stream};
use rand::RngExt;
/// Completes the ZMTP handshake on a framed connection and, on success, hands the
/// connection to `backend` as a new peer.
///
/// The handshake, as implemented below, is:
/// 1. a greeting exchange via `greet_exchange_full`;
/// 2. a role check: if the peer's greeting advertises a non-NULL mechanism and both
///    this socket and the peer claim the server role, the connection is rejected;
/// 3. the security-mechanism handshake via `crate::mechanism::mech_handshake`
///    (with the `curve` feature, its resulting CURVE state is stored on `raw_socket`);
/// 4. a READY exchange advertising our `Identity` (if configured) and metadata.
///    This step is skipped when a CURVE session is active or `opts.mechanism` is
///    PLAIN; in that case the returned identity is `PeerIdentity::default()`.
///
/// When `opts.handshake_interval` is `Some`, steps 1–4 run under that timeout.
///
/// * `raw_socket` – framed reader/writer pair for the connection.
/// * `backend` – socket backend that supplies the options and receives the new peer.
/// * `endpoint` – forwarded unchanged to `backend.peer_connected`.
/// * `peer_addr` – remote address passed to the mechanism handshake (`""` when `None`).
///
/// Returns the peer identity that was registered with `backend`.
///
/// # Errors
/// * `ZmqError::HandshakeTimeout` if `handshake_interval` elapses before the handshake ends.
/// * `ZmqError::ServerRoleConflict` for the role clash described in step 2.
/// * Any error returned by the greeting, mechanism, or READY exchange.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) async fn peer_connected<R, W, B>(
    mut raw_socket: FramedIo<R, W>,
    backend: Arc<B>,
    endpoint: Option<crate::endpoint::Endpoint>,
    peer_addr: Option<String>,
) -> ZmqResult<PeerIdentity>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin + Send + 'static,
    W: Sink<Message, Error = CodecError> + Unpin + Send + IntoEngineWriter + 'static,
    W::Writer: Send + 'static,
    B: MultiPeerBackend + 'static,
{
    let opts = backend.socket_options();
    let handshake_interval = opts.handshake_interval;
    // The whole handshake is built as one future so it can optionally be bounded by
    // `handshake_interval` further down.
    let handshake = async {
        {
            let peer_greeting = greet_exchange_full(&mut raw_socket, opts).await?;
            // We act as the mechanism server if configured as a PLAIN server or,
            // with the `curve` feature, as a CURVE server.
            let we_are_server = opts.plain_server || {
                #[cfg(feature = "curve")]
                {
                    opts.curve_server
                }
                #[cfg(not(feature = "curve"))]
                {
                    false
                }
            };
            // Any mechanism other than NULL is treated as authenticated.
            let peer_uses_auth = !matches!(
                peer_greeting.mechanism,
                crate::codec::mechanism::ZmqMechanism::NULL
            );
            // Both ends claiming the server role for an authenticated mechanism is rejected.
            if peer_uses_auth && we_are_server && peer_greeting.as_server {
                return Err(ZmqError::ServerRoleConflict);
            }
            // `state` is only read when the `curve` feature is enabled.
            #[cfg_attr(not(feature = "curve"), allow(unused_variables))]
            let state = crate::mechanism::mech_handshake(
                &mut raw_socket,
                opts,
                peer_greeting.mechanism,
                &peer_greeting,
                peer_addr.as_deref().unwrap_or(""),
                backend.socket_type(),
            )
            .await?;
            // Keep the negotiated CURVE state on the framed socket for later use.
            #[cfg(feature = "curve")]
            {
                raw_socket.curve = state.curve;
            }
        };
        // The READY exchange is skipped if a CURVE session was set up or the
        // configured mechanism is PLAIN.
        // NOTE(review): presumably those mechanism handshakes already carry the
        // READY/metadata step — confirm. As written, no peer identity is read on
        // this path, and `PeerIdentity::default()` is used instead.
        let skip_ready =
            {
                #[cfg(feature = "curve")]
                {
                    raw_socket.curve.is_some()
                }
                #[cfg(not(feature = "curve"))]
                {
                    false
                }
            } || matches!(opts.mechanism, crate::codec::mechanism::ZmqMechanism::PLAIN);
        let peer_id = if skip_ready {
            PeerIdentity::default()
        } else {
            // Properties advertised in our READY command: the optional `Identity` plus
            // every configured metadata entry; `None` when there is nothing to send.
            let mut connect_ops: HashMap<String, bytes::Bytes> = HashMap::new();
            if let Some(identity) = &opts.peer_id {
                connect_ops.insert("Identity".to_string(), identity.clone().into());
            }
            for (k, v) in &opts.metadata {
                connect_ops.insert(k.clone(), v.clone());
            }
            let props = if connect_ops.is_empty() {
                None
            } else {
                Some(connect_ops)
            };
            ready_exchange(&mut raw_socket, backend.socket_type(), props).await?
        };
        Ok::<_, ZmqError>((peer_id, raw_socket))
    };
    // With a timeout: the first `?` reports an elapsed timeout as `HandshakeTimeout`,
    // and the second `?` propagates the handshake's own error.
    let (peer_id, raw_socket) = match handshake_interval {
        Some(d) => crate::async_rt::task::timeout(d, handshake)
            .await
            .map_err(|_e| ZmqError::HandshakeTimeout)??,
        None => handshake.await?,
    };
    backend.peer_connected(&peer_id, raw_socket, endpoint).await;
    Ok(peer_id)
}
/// Connects to `endpoint`, retrying for as long as the remote side refuses the
/// connection, then completes the peer setup once a transport is established.
///
/// Retry policy (only for `ZmqError::Network` errors of kind `ConnectionRefused`):
/// before each new attempt the task sleeps `e^(k / 3)` seconds plus a random jitter in
/// `[0, 0.1)` seconds. Here `k` counts the refusals seen so far, capped at 5.
///
/// * `endpoint` – address to connect to.
/// * `backend` – socket backend; supplies TCP options and receives the connected peer.
/// * `connect_timeout` – forwarded to `transport::connect` for each attempt.
///
/// Returns the resolved endpoint together with the identity of the new peer.
///
/// # Errors
/// Any other transport error is returned immediately, without retrying. So is any
/// error from finishing the connection (framed handshake or inproc registration).
pub(crate) async fn connect_peer_forever<B>(
    endpoint: Endpoint,
    backend: Arc<B>,
    connect_timeout: Option<std::time::Duration>,
) -> ZmqResult<(Endpoint, PeerIdentity)>
where
    B: MultiPeerBackend + 'static,
{
    use crate::transport::TransportIo;

    // The backoff exponent stops growing after this many refusals.
    const MAX_BACKOFF_STEP: u64 = 5;

    #[cfg(feature = "tcp")]
    let tcp_cfg = crate::transport::TcpConfig::from_options(backend.socket_options());
    #[cfg(not(feature = "tcp"))]
    let tcp_cfg: () = ();

    let mut refusals: u64 = 0;
    loop {
        let attempt = transport::connect(&endpoint, connect_timeout, &tcp_cfg).await;
        match attempt {
            #[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
            Ok(TransportIo::Framed(io, resolved)) => {
                let remote = resolved.to_string();
                let peer_id =
                    peer_connected(*io, backend, Some(resolved.clone()), Some(remote)).await?;
                return Ok((resolved, peer_id));
            }
            #[cfg(feature = "inproc")]
            Ok(TransportIo::Inproc(peer)) => {
                let resolved = peer.endpoint.clone();
                let peer_id = PeerIdentity::new();
                backend
                    .peer_connected_inproc(&peer_id, peer, Some(resolved.clone()))
                    .await?;
                return Ok((resolved, peer_id));
            }
            Err(ZmqError::Network(err)) if err.kind() == std::io::ErrorKind::ConnectionRefused => {
                refusals = (refusals + 1).min(MAX_BACKOFF_STEP);
                // The thread-local RNG is a statement temporary, so it is dropped before
                // the `.await` below and never held across a suspension point.
                let jitter = rand::rng().random_range(0.0f64..0.1f64);
                let delay_secs = std::f64::consts::E.powf(refusals as f64 / 3.0) + jitter;
                async_rt::task::sleep(std::time::Duration::from_secs_f64(delay_secs)).await;
            }
            Err(other) => return Err(other),
        }
    }
}
/// Reusable bind-behaviour checks shared by per-socket-type test suites.
#[cfg(all(test, feature = "tokio"))]
pub(crate) mod tests {
    use crate::endpoint::Endpoint;
    use crate::Socket;
    use crate::ZmqResult;

    /// Binds `sock` to the four consecutive TCP ports starting at `start_port` on the
    /// unspecified address `any`. It then checks that `binds()` reports exactly those
    /// four endpoints, all on host `any`.
    pub async fn test_bind_to_unspecified_interface_helper(
        any: std::net::IpAddr,
        mut sock: impl Socket,
        start_port: u16,
    ) -> ZmqResult<()> {
        assert!(sock.binds().is_empty());
        assert!(any.is_unspecified());

        for offset in 0..4 {
            let addr = Endpoint::Tcp(any.into(), start_port + offset).to_string();
            sock.bind(addr.as_str()).await?;
        }

        let bindings = sock.binds();
        assert_eq!(bindings.len(), 4);

        // Every bound endpoint must be TCP on `any`; collect the ports it used.
        let port_set: std::collections::HashSet<u16> = bindings
            .keys()
            .map(|ep| match ep {
                Endpoint::Tcp(host, port) => {
                    assert_eq!(host, &any.into());
                    *port
                }
                _ => unreachable!(),
            })
            .collect();

        for p in start_port..start_port + 4 {
            assert!(port_set.contains(&p));
        }
        Ok(())
    }

    /// Binds `sock` four times to `tcp://localhost:0`. It then checks that each binding
    /// keeps the `localhost` host name and received a distinct, non-zero port.
    pub async fn test_bind_to_any_port_helper(mut sock: impl Socket) -> ZmqResult<()> {
        use crate::endpoint::Host;

        assert!(sock.binds().is_empty());
        for _ in 0..4 {
            sock.bind("tcp://localhost:0").await?;
        }

        let bindings = sock.binds();
        assert_eq!(bindings.len(), 4);

        let localhost = Host::Domain("localhost".to_string());
        let mut port_set = std::collections::HashSet::new();
        for ep in bindings.keys() {
            match ep {
                Endpoint::Tcp(host, port) => {
                    assert_eq!(host, &localhost);
                    assert_ne!(*port, 0);
                    assert!(port_set.insert(*port));
                }
                _ => unreachable!(),
            }
        }
        Ok(())
    }
}