rustzmq2 0.1.0

A native async Rust implementation of ZeroMQ
Documentation
//! ZMTP greeting + READY command exchange over a `FramedIo`.
//!
//! Implements the wire-level handshake from [ZMTP RFC 23](https://rfc.zeromq.org/spec/23/):
//! swap `ZmqGreeting`, negotiate version, then swap `READY` commands carrying
//! socket-type + metadata properties. Security-mechanism handshakes (PLAIN,
//! CURVE) run *between* greeting and READY and live in `crate::mechanism`.

#[cfg(all(feature = "tokio", any(feature = "tcp", test)))]
use crate::codec::ZmtpVersion;
#[cfg(any(
    all(feature = "tokio", any(feature = "tcp", test)),
    feature = "tcp",
    all(feature = "ipc", target_family = "unix")
))]
use crate::codec::{CodecError, FramedIo, Message, ZmqGreeting};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::{CodecResult, ZmqCommand, ZmqCommandName};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::peer_identity::PeerIdentity;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::SocketOptions;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::SocketType;
#[cfg(any(
    all(feature = "tokio", any(feature = "tcp", test)),
    feature = "tcp",
    all(feature = "ipc", target_family = "unix")
))]
use crate::{ZmqError, ZmqResult};

#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use std::collections::HashMap;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use std::convert::{TryFrom, TryInto};

#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use bytes::Bytes;
#[cfg(any(
    all(feature = "tokio", any(feature = "tcp", test)),
    feature = "tcp",
    all(feature = "ipc", target_family = "unix")
))]
use futures::{Sink, SinkExt, Stream, StreamExt};

/// Picks the ZMTP version to speak with a peer, based on the greeting it sent.
///
/// Applies the version-negotiation rules of
/// [ZeroMQ RFC 23](https://rfc.zeromq.org/spec/23/#version-negotiation).
///
/// # Errors
///
/// * `ZmqError::UnsupportedVersion` when the peer announces a version older
///   than ours; downgrading is not implemented.
/// * `ZmqError::Other` when `greeting` is not a `Message::Greeting`.
#[cfg(all(feature = "tokio", any(feature = "tcp", test)))]
pub(crate) fn negotiate_version(greeting: Message) -> ZmqResult<ZmtpVersion> {
    let ours = ZmqGreeting::default().version;

    match greeting {
        // RFC 23: a peer MUST accept any protocol version greater than or equal
        // to its own, and SHALL keep using its own protocol (framing included)
        // when talking to such a peer. This is what lets future
        // implementations interoperate with current ones.
        Message::Greeting(peer) if peer.version >= ours => Ok(ours),
        // RFC 23: a peer MAY downgrade for an older peer, otherwise it MUST
        // close the connection. Downgrading is not supported here.
        Message::Greeting(peer) => Err(ZmqError::UnsupportedVersion(peer.version)),
        _ => Err(ZmqError::Other("Failed Greeting exchange".into())),
    }
}

/// Sends our default greeting, reads the peer's reply and returns the ZMTP
/// version chosen by `negotiate_version`.
///
/// # Errors
///
/// * Codec errors raised while sending or reading are converted into `ZmqError`.
/// * `ZmqError::Other` when the stream ends before any message arrives.
/// * Whatever `negotiate_version` reports for the received message.
#[cfg(all(feature = "tokio", any(feature = "tcp", test)))]
pub(crate) async fn greet_exchange<R, W>(raw_socket: &mut FramedIo<R, W>) -> ZmqResult<ZmtpVersion>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin,
    W: Sink<Message, Error = CodecError> + Unpin,
{
    let ours = Message::Greeting(ZmqGreeting::default());
    raw_socket.write_half.send(ours).await?;

    // First `?` handles a closed stream, second `?` handles a codec failure.
    let reply = raw_socket
        .read_half
        .next()
        .await
        .ok_or_else(|| ZmqError::Other("Failed Greeting exchange".into()))??;

    negotiate_version(reply)
}

/// Variant of `greet_exchange` that hands back the peer's whole `ZmqGreeting`,
/// so the caller can look at the peer's mechanism and server flag.
///
/// The greeting we send is built from `opts` via `ZmqGreeting::from_options`.
///
/// # Errors
///
/// * `ZmqError::UnsupportedVersion` when the peer's version is lower than the
///   version of `ZmqGreeting::default()`.
/// * `ZmqError::Other` when the stream ends or the first message is not a
///   greeting.
/// * Codec errors raised while sending or reading, converted into `ZmqError`.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) async fn greet_exchange_full<R, W>(
    raw_socket: &mut FramedIo<R, W>,
    opts: &SocketOptions,
) -> ZmqResult<ZmqGreeting>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin,
    W: Sink<Message, Error = CodecError> + Unpin,
{
    let ours = ZmqGreeting::from_options(opts);
    raw_socket.write_half.send(Message::Greeting(ours)).await?;

    let peer = match raw_socket.read_half.next().await {
        Some(Ok(Message::Greeting(peer))) => peer,
        Some(Err(e)) => return Err(e.into()),
        Some(Ok(_)) | None => return Err(ZmqError::Other("Failed Greeting exchange".into())),
    };

    // Older peers are rejected outright.
    let minimum = ZmqGreeting::default().version;
    if peer.version < minimum {
        Err(ZmqError::UnsupportedVersion(peer.version))
    } else {
        Ok(peer)
    }
}

/// Exchanges `READY` commands with the peer and returns the identity it
/// announced.
///
/// Our `READY` carries `socket_type` plus any extra `props`. Exactly one reply
/// is read from the peer and validated by `peer_identity_from_ready`.
///
/// # Errors
///
/// * `ZmqError::IncompatiblePeer` when the peer's socket type cannot be paired
///   with `socket_type`. This is reported the same way whether or not a
///   security handshake preceded the exchange.
/// * `ZmqError::Other` when the stream ends, the reply is neither a command nor
///   a raw security frame, or the reply lacks a `Socket-Type` property.
/// * Codec and conversion errors (sending, reading, re-parsing a raw frame,
///   parsing `Socket-Type` or `Identity`), converted into `ZmqError`.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) async fn ready_exchange<R, W>(
    raw_socket: &mut FramedIo<R, W>,
    socket_type: SocketType,
    props: Option<HashMap<String, Bytes>>,
) -> ZmqResult<PeerIdentity>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin,
    W: Sink<Message, Error = CodecError> + Unpin,
{
    let mut ready = ZmqCommand::ready(socket_type);
    if let Some(props) = props {
        ready.add_properties(props);
    }
    raw_socket.write_half.send(Message::Command(ready)).await?;

    let ready_repl: Option<CodecResult<Message>> = raw_socket.read_half.next().await;
    match ready_repl {
        Some(Ok(Message::Command(command))) => peer_identity_from_ready(socket_type, &command),
        // After a CURVE/PLAIN handshake the codec routes READY as SecurityRaw.
        // Re-parse it as a ZmqCommand so both paths share the same validation.
        Some(Ok(Message::SecurityRaw(raw))) => {
            let command = ZmqCommand::try_from(raw).map_err(ZmqError::from)?;
            peer_identity_from_ready(socket_type, &command)
        }
        Some(Ok(_)) => Err(ZmqError::Other("Failed to confirm ready state".into())),
        Some(Err(e)) => Err(e.into()),
        None => Err(ZmqError::Other("No reply from server".into())),
    }
}

/// Validates a peer's `READY` command against our `socket_type` and extracts
/// the peer identity (the default identity when no `Identity` property is set).
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
fn peer_identity_from_ready(
    socket_type: SocketType,
    command: &ZmqCommand,
) -> ZmqResult<PeerIdentity> {
    // Kept as a `match` so that adding a command name forces a decision here.
    match command.name {
        ZmqCommandName::READY => {
            let raw_socket_type = command
                .properties
                .get("Socket-Type")
                .ok_or_else(|| ZmqError::Other("Failed to parse other socket type".into()))?;
            let other_sock_type = SocketType::try_from(&raw_socket_type[..])?;

            let peer_id = command
                .properties
                .get("Identity")
                .map(|x| x.clone().try_into())
                .transpose()?
                .unwrap_or_default();

            if socket_type.compatible(other_sock_type) {
                Ok(peer_id)
            } else {
                Err(ZmqError::IncompatiblePeer)
            }
        }
    }
}

#[cfg(all(test, feature = "tokio"))]
mod tests {
    use super::*;
    use crate::codec::mechanism::ZmqMechanism;
    use crate::message::ZmqMessage;

    /// Builds a peer greeting announcing `version`; the other fields are fixed.
    fn new_greeting(version: ZmtpVersion) -> Message {
        let greeting = ZmqGreeting {
            version,
            mechanism: ZmqMechanism::PLAIN,
            as_server: false,
        };
        Message::Greeting(greeting)
    }

    /// A peer on our own version negotiates to that version.
    #[test]
    fn negotiate_version_peer_is_using_the_same_version() {
        let ours = ZmqGreeting::default().version;
        let negotiated = negotiate_version(new_greeting(ours)).unwrap();
        assert_eq!(negotiated, ours);
    }

    /// A newer peer is accepted and we keep speaking our own version.
    #[test]
    fn negotiate_version_peer_is_using_a_newer_version() {
        let peer_version = (3, 1);
        let negotiated = negotiate_version(new_greeting(peer_version)).unwrap();
        assert_eq!(negotiated, ZmqGreeting::default().version);
    }

    /// An older peer is rejected, and the error carries the peer's version.
    #[test]
    fn negotiate_version_peer_is_using_an_older_version() {
        let peer_version = (2, 1);
        if let Err(ZmqError::UnsupportedVersion(reported)) =
            negotiate_version(new_greeting(peer_version))
        {
            assert_eq!(reported, peer_version);
        } else {
            panic!("Unexpected result");
        }
    }

    /// Anything other than a greeting is reported as `ZmqError::Other`.
    #[test]
    fn negotiate_version_invalid_greeting() {
        let not_a_greeting = Message::Message(ZmqMessage::from(""));
        let outcome = negotiate_version(not_a_greeting);
        assert!(
            matches!(outcome, Err(ZmqError::Other(_))),
            "Unexpected result"
        );
    }
}