rustzmq2 0.1.0

A native async Rust implementation of ZeroMQ
Documentation
//! Top-level peer-handshake orchestration: greeting + mechanism + READY.
//!
//! `peer_connected` runs the full ZMTP handshake on an already-open
//! `FramedIo`; `connect_peer_forever` wraps the transport dial loop around
//! it, retrying on `ConnectionRefused` (TCP only). Wire-level codec
//! primitives live in `crate::codec::handshake`.

use crate::async_rt;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::handshake::{greet_exchange_full, ready_exchange};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::{CodecError, FramedIo, IntoEngineWriter, Message};
use crate::endpoint::Endpoint;
use crate::peer_identity::PeerIdentity;
use crate::transport;
use crate::{MultiPeerBackend, ZmqError, ZmqResult};

#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use std::collections::HashMap;
use std::sync::Arc;

#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use futures::{Sink, Stream};
use rand::RngExt;

/// Runs the ZMTP handshake on an already-open `raw_socket` and, if it
/// succeeds, hands the connection to `backend` via `peer_connected`.
///
/// Steps, in order:
/// 1. Greeting exchange (`greet_exchange_full`) using `backend.socket_options()`.
/// 2. Role check: if the peer's greeting advertises a non-NULL mechanism and
///    both we and the peer claim the server role, fail with
///    `ZmqError::ServerRoleConflict`.
/// 3. Security-mechanism handshake (`crate::mechanism::mech_handshake`); with
///    the `curve` feature enabled, the resulting CURVE state is stored on
///    `raw_socket`.
/// 4. READY exchange (`ready_exchange`) carrying our socket type plus optional
///    `"Identity"` and `opts.metadata` properties. This step is skipped when
///    CURVE state is present or the *local* `opts.mechanism` is PLAIN, in which
///    case the peer is recorded under `PeerIdentity::default()`.
///
/// Steps 1–4 are bounded by `opts.handshake_interval` when it is set.
///
/// * `endpoint` — forwarded unchanged to `backend.peer_connected`.
/// * `peer_addr` — forwarded to `mech_handshake` as a string (`""` when
///   `None`); presumably used by authentication — TODO confirm.
///
/// Returns the `PeerIdentity` that was passed to `backend.peer_connected`.
///
/// # Errors
/// `ZmqError::ServerRoleConflict` from the role check,
/// `ZmqError::HandshakeTimeout` if `handshake_interval` elapses, or any error
/// produced by the greeting, mechanism, or READY exchanges.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) async fn peer_connected<R, W, B>(
    mut raw_socket: FramedIo<R, W>,
    backend: Arc<B>,
    endpoint: Option<crate::endpoint::Endpoint>,
    peer_addr: Option<String>,
) -> ZmqResult<PeerIdentity>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin + Send + 'static,
    W: Sink<Message, Error = CodecError> + Unpin + Send + IntoEngineWriter + 'static,
    W::Writer: Send + 'static,
    B: MultiPeerBackend + 'static,
{
    // Socket options drive every stage below; `handshake_interval` (if set)
    // bounds the whole handshake future.
    let opts = backend.socket_options();
    let handshake_interval = opts.handshake_interval;

    // Built as a future so it can optionally be wrapped in a timeout below.
    // It yields the socket back on success so it can be given to the backend.
    let handshake = async {
        // ── Greeting + mechanism handshake ─────────────────────────────────
        // Scoped so `peer_greeting` and `state` are dropped before READY.
        {
            let peer_greeting = greet_exchange_full(&mut raw_socket, opts).await?;

            // Validate server/client role for authenticated mechanisms only.
            // NULL mechanism doesn't use the as_server flag meaningfully.
            // "We are server" = PLAIN server flag, or (curve feature only) the
            // CURVE server flag.
            let we_are_server = opts.plain_server || {
                #[cfg(feature = "curve")]
                {
                    opts.curve_server
                }
                #[cfg(not(feature = "curve"))]
                {
                    false
                }
            };
            // Keyed off the mechanism the *peer* advertised in its greeting.
            let peer_uses_auth = !matches!(
                peer_greeting.mechanism,
                crate::codec::mechanism::ZmqMechanism::NULL
            );
            // NOTE(review): only a server/server clash is rejected here; a
            // client/client pairing is not checked at this point — confirm
            // whether mech_handshake covers that case.
            if peer_uses_auth && we_are_server && peer_greeting.as_server {
                return Err(ZmqError::ServerRoleConflict);
            }

            // Without the `curve` feature `state` is never read, hence the allow.
            #[cfg_attr(not(feature = "curve"), allow(unused_variables))]
            let state = crate::mechanism::mech_handshake(
                &mut raw_socket,
                opts,
                peer_greeting.mechanism,
                &peer_greeting,
                peer_addr.as_deref().unwrap_or(""),
                backend.socket_type(),
            )
            .await?;

            // Install any negotiated CURVE state on the framed socket.
            #[cfg(feature = "curve")]
            {
                raw_socket.curve = state.curve;
            }
        };

        // Skip READY when CURVE state was installed or the *local* mechanism
        // is PLAIN. NOTE(review): presumably those mechanisms carry the
        // metadata themselves — confirm. When skipped, no identity is read
        // from the peer and `PeerIdentity::default()` is used instead.
        let skip_ready =
            {
                #[cfg(feature = "curve")]
                {
                    raw_socket.curve.is_some()
                }
                #[cfg(not(feature = "curve"))]
                {
                    false
                }
            } || matches!(opts.mechanism, crate::codec::mechanism::ZmqMechanism::PLAIN);
        let peer_id = if skip_ready {
            PeerIdentity::default()
        } else {
            // Our READY properties: "Identity" (if configured), then user
            // metadata. Metadata is inserted second, so a metadata key named
            // "Identity" would overwrite the configured identity.
            let mut connect_ops: HashMap<String, bytes::Bytes> = HashMap::new();
            if let Some(identity) = &opts.peer_id {
                connect_ops.insert("Identity".to_string(), identity.clone().into());
            }
            for (k, v) in &opts.metadata {
                connect_ops.insert(k.clone(), v.clone());
            }
            // Pass no property map at all when there is nothing to advertise.
            let props = if connect_ops.is_empty() {
                None
            } else {
                Some(connect_ops)
            };
            // The peer's identity comes from the READY exchange result.
            ready_exchange(&mut raw_socket, backend.socket_type(), props).await?
        };
        Ok::<_, ZmqError>((peer_id, raw_socket))
    };

    // With a timeout: the first `?` turns an elapsed timer into
    // HandshakeTimeout, the second propagates the handshake's own error.
    let (peer_id, raw_socket) = match handshake_interval {
        Some(d) => crate::async_rt::task::timeout(d, handshake)
            .await
            .map_err(|_e| ZmqError::HandshakeTimeout)??,
        None => handshake.await?,
    };
    backend.peer_connected(&peer_id, raw_socket, endpoint).await;
    Ok(peer_id)
}

/// Dials `endpoint` and hands the resulting connection to `backend`,
/// returning `(resolved_endpoint, peer_id)` once the peer has been handed off.
///
/// Wire transports (`TransportIo::Framed`) go through the full ZMTP handshake
/// in `peer_connected`. Inproc connections (`TransportIo::Inproc`) skip it and
/// are passed straight to `backend.peer_connected_inproc`. This function is
/// intended to be the only code that inspects `TransportIo`.
///
/// Retry policy: a `ZmqError::Network` whose kind is `ConnectionRefused` is
/// never fatal. The dial is retried indefinitely, sleeping
/// `e^(step / 3)` seconds plus `[0, 0.1)` s of random jitter between attempts.
/// `step` counts refusals and is capped at 5, so the delay runs from about
/// 1.40 s up to about 5.29 s. The guard inspects only the error kind, not
/// the transport. NOTE(review): the previous doc said "TCP only" — confirm
/// whether IPC refusals are meant to be retried as well. Any other error is
/// returned immediately.
///
/// NOTE(review): for inproc the returned id is the placeholder created here,
/// even if the backend substitutes a different routing id internally.
pub(crate) async fn connect_peer_forever<B>(
    endpoint: Endpoint,
    backend: Arc<B>,
    connect_timeout: Option<std::time::Duration>,
) -> ZmqResult<(Endpoint, PeerIdentity)>
where
    B: MultiPeerBackend + 'static,
{
    use crate::transport::TransportIo;

    // Highest backoff step; once reached, every further refusal reuses it.
    const MAX_BACKOFF_STEP: u64 = 5;

    #[cfg(feature = "tcp")]
    let tcp_cfg = crate::transport::TcpConfig::from_options(backend.socket_options());
    #[cfg(not(feature = "tcp"))]
    let tcp_cfg: () = ();

    let mut backoff_step: u64 = 0;
    loop {
        let attempt = transport::connect(&endpoint, connect_timeout, &tcp_cfg).await;
        match attempt {
            #[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
            Ok(TransportIo::Framed(io, resolved)) => {
                let peer_addr = resolved.to_string();
                let peer_id =
                    peer_connected(*io, backend, Some(resolved.clone()), Some(peer_addr)).await?;
                return Ok((resolved, peer_id));
            }
            #[cfg(feature = "inproc")]
            Ok(TransportIo::Inproc(peer)) => {
                let resolved = peer.endpoint.clone();
                // Placeholder only: ROUTER swaps in the remote's advertised
                // routing_id inside peer_connected_inproc.
                let placeholder_id = PeerIdentity::new();
                backend
                    .peer_connected_inproc(&placeholder_id, peer, Some(resolved.clone()))
                    .await?;
                return Ok((resolved, placeholder_id));
            }
            Err(ZmqError::Network(err)) if err.kind() == std::io::ErrorKind::ConnectionRefused => {
                backoff_step = (backoff_step + 1).min(MAX_BACKOFF_STEP);
                let base_secs = std::f64::consts::E.powf(backoff_step as f64 / 3.0);
                // Keep the thread-local RNG handle confined to this scope so it
                // is not held across the `.await` below.
                let jitter_secs = {
                    let mut rng = rand::rng();
                    rng.random_range(0.0f64..0.1f64)
                };
                async_rt::task::sleep(std::time::Duration::from_secs_f64(base_secs + jitter_secs))
                    .await;
            }
            Err(other) => return Err(other),
        }
    }
}

#[cfg(all(test, feature = "tokio"))]
pub(crate) mod tests {
    use crate::endpoint::Endpoint;
    use crate::Socket;
    use crate::ZmqResult;

    /// Binds a fresh `sock` to the four TCP ports `start_port..start_port + 4`
    /// on the unspecified address `any`. It then checks that `binds()` reports
    /// exactly those four endpoints, each on host `any`.
    ///
    /// # Panics
    /// If `sock` already has bindings, `any` is not an unspecified address,
    /// or the reported bindings differ from what was requested.
    pub async fn test_bind_to_unspecified_interface_helper(
        any: std::net::IpAddr,
        mut sock: impl Socket,
        start_port: u16,
    ) -> ZmqResult<()> {
        assert!(sock.binds().is_empty());
        assert!(any.is_unspecified());

        for offset in 0..4 {
            let addr = Endpoint::Tcp(any.into(), start_port + offset).to_string();
            sock.bind(addr.as_str()).await?;
        }

        let bound_to = sock.binds();
        assert_eq!(bound_to.len(), 4);

        // Every reported binding must be TCP on `any`; collect its port.
        let port_set: std::collections::HashSet<_> = bound_to
            .keys()
            .map(|b| match b {
                Endpoint::Tcp(host, port) => {
                    assert_eq!(host, &any.into());
                    *port
                }
                _ => unreachable!(),
            })
            .collect();

        for expected in start_port..start_port + 4 {
            assert!(port_set.contains(&expected));
        }

        Ok(())
    }

    /// Binds a fresh `sock` four times to `tcp://localhost:0`. It then checks
    /// that four bindings are reported, all on host `localhost`, each with a
    /// non-zero and distinct OS-assigned port.
    ///
    /// # Panics
    /// If `sock` already has bindings or any of the checks above fails.
    pub async fn test_bind_to_any_port_helper(mut sock: impl Socket) -> ZmqResult<()> {
        use crate::endpoint::Host;

        assert!(sock.binds().is_empty());
        for _ in 0..4 {
            sock.bind("tcp://localhost:0").await?;
        }

        let bound_to = sock.binds();
        assert_eq!(bound_to.len(), 4);

        let expected_host = Host::Domain("localhost".to_string());
        let mut seen_ports = std::collections::HashSet::new();
        for bound in bound_to.keys() {
            let Endpoint::Tcp(host, port) = bound else {
                unreachable!()
            };
            assert_eq!(host, &expected_host);
            assert_ne!(*port, 0);
            // `insert` returns false on a duplicate port.
            assert!(seen_ports.insert(*port));
        }

        Ok(())
    }
}