#[cfg(all(feature = "tokio", any(feature = "tcp", test)))]
use crate::codec::ZmtpVersion;
#[cfg(any(
all(feature = "tokio", any(feature = "tcp", test)),
feature = "tcp",
all(feature = "ipc", target_family = "unix")
))]
use crate::codec::{CodecError, FramedIo, Message, ZmqGreeting};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::codec::{CodecResult, ZmqCommand, ZmqCommandName};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::peer_identity::PeerIdentity;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::SocketOptions;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use crate::SocketType;
#[cfg(any(
all(feature = "tokio", any(feature = "tcp", test)),
feature = "tcp",
all(feature = "ipc", target_family = "unix")
))]
use crate::{ZmqError, ZmqResult};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use std::collections::HashMap;
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use std::convert::{TryFrom, TryInto};
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
use bytes::Bytes;
#[cfg(any(
all(feature = "tokio", any(feature = "tcp", test)),
feature = "tcp",
all(feature = "ipc", target_family = "unix")
))]
use futures::{Sink, SinkExt, Stream, StreamExt};
/// Decides which ZMTP version to speak, given the greeting received from a peer.
///
/// Our own version is the one advertised by `ZmqGreeting::default()`. The peer is
/// accepted when its advertised version is not lower than ours, and in that case our
/// version is returned.
///
/// # Errors
/// - `ZmqError::Other` if `greeting` is not a `Message::Greeting`.
/// - `ZmqError::UnsupportedVersion` (carrying the peer's version) if the peer
///   advertises a version older than ours.
#[cfg(all(feature = "tokio", any(feature = "tcp", test)))]
pub(crate) fn negotiate_version(greeting: Message) -> ZmqResult<ZmtpVersion> {
    let peer = match greeting {
        Message::Greeting(peer) => peer,
        _ => return Err(ZmqError::Other("Failed Greeting exchange".into())),
    };
    let ours = ZmqGreeting::default().version;
    if peer.version < ours {
        return Err(ZmqError::UnsupportedVersion(peer.version));
    }
    Ok(ours)
}
/// Performs the greeting half of the handshake and returns the negotiated ZMTP version.
///
/// This sends our default greeting on `raw_socket.write_half`, reads exactly one
/// message from `raw_socket.read_half`, and passes that message to `negotiate_version`.
///
/// # Errors
/// - Codec errors from sending or receiving are propagated.
/// - `ZmqError::Other` if the read stream ends before anything arrives.
/// - Any error returned by `negotiate_version`.
#[cfg(all(feature = "tokio", any(feature = "tcp", test)))]
pub(crate) async fn greet_exchange<R, W>(raw_socket: &mut FramedIo<R, W>) -> ZmqResult<ZmtpVersion>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin,
    W: Sink<Message, Error = CodecError> + Unpin,
{
    let outgoing = Message::Greeting(ZmqGreeting::default());
    raw_socket.write_half.send(outgoing).await?;

    // A closed stream is reported as a failed exchange; a codec error is propagated below.
    let reply = raw_socket
        .read_half
        .next()
        .await
        .ok_or_else(|| ZmqError::Other("Failed Greeting exchange".into()))?;
    negotiate_version(reply?)
}
/// Exchanges greetings using options-derived settings and returns the peer's greeting.
///
/// Our greeting is built with `ZmqGreeting::from_options(opts)` and sent first. The
/// first message read back must be a greeting whose version is at least the version of
/// `ZmqGreeting::default()`. When it is, the peer's full greeting is returned so the
/// caller can inspect fields beyond the version.
///
/// # Errors
/// - Codec errors from sending or receiving are propagated.
/// - `ZmqError::Other` if the stream ends or the first message is not a greeting.
/// - `ZmqError::UnsupportedVersion` if the peer's version is older than ours.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) async fn greet_exchange_full<R, W>(
    raw_socket: &mut FramedIo<R, W>,
    opts: &SocketOptions,
) -> ZmqResult<ZmqGreeting>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin,
    W: Sink<Message, Error = CodecError> + Unpin,
{
    let ours = ZmqGreeting::from_options(opts);
    raw_socket.write_half.send(Message::Greeting(ours)).await?;

    let peer = match raw_socket.read_half.next().await {
        Some(Ok(Message::Greeting(peer))) => peer,
        Some(Err(e)) => return Err(e.into()),
        Some(Ok(_)) | None => return Err(ZmqError::Other("Failed Greeting exchange".into())),
    };

    if peer.version >= ZmqGreeting::default().version {
        Ok(peer)
    } else {
        Err(ZmqError::UnsupportedVersion(peer.version))
    }
}
/// Sends our READY command and validates the peer's READY reply.
///
/// The outgoing command is built with `ZmqCommand::ready(socket_type)`. Any `props`
/// are added to it before it is sent. The first message read back is accepted in two
/// forms:
/// - a `Message::Command`;
/// - a `Message::SecurityRaw`, which is first decoded with `ZmqCommand::try_from`.
///
/// Both forms are validated by the same routine, so they yield identical results and
/// identical errors.
///
/// Returns the peer's identity taken from its `Identity` property, or
/// `PeerIdentity::default()` when that property is absent.
///
/// # Errors
/// - Codec errors from sending or receiving are propagated.
/// - `ZmqError::Other` if the reply is not a command, is missing `Socket-Type`, or if
///   the stream ends before a reply arrives.
/// - `ZmqError::IncompatiblePeer` if the peer's socket type is not compatible with
///   `socket_type`.
/// - Conversion errors from parsing `Socket-Type` or `Identity` are propagated.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
pub(crate) async fn ready_exchange<R, W>(
    raw_socket: &mut FramedIo<R, W>,
    socket_type: SocketType,
    props: Option<HashMap<String, Bytes>>,
) -> ZmqResult<PeerIdentity>
where
    R: Stream<Item = Result<Message, CodecError>> + Unpin,
    W: Sink<Message, Error = CodecError> + Unpin,
{
    let mut ready = ZmqCommand::ready(socket_type);
    if let Some(props) = props {
        ready.add_properties(props);
    }
    raw_socket.write_half.send(Message::Command(ready)).await?;

    let ready_repl: Option<CodecResult<Message>> = raw_socket.read_half.next().await;
    match ready_repl {
        Some(Ok(Message::Command(command))) => peer_identity_from_ready(socket_type, &command),
        Some(Ok(Message::SecurityRaw(raw))) => {
            // The reply arrived as an undecoded frame; parse it into a command before
            // running the same validation as the `Command` arm.
            let command = ZmqCommand::try_from(raw).map_err(ZmqError::from)?;
            peer_identity_from_ready(socket_type, &command)
        }
        Some(Ok(_)) => Err(ZmqError::Other("Failed to confirm ready state".into())),
        Some(Err(e)) => Err(e.into()),
        None => Err(ZmqError::Other("No reply from server".into())),
    }
}

/// Validates a peer's READY command against our `socket_type` and extracts its identity.
///
/// Requires a `Socket-Type` property that parses into a `SocketType` compatible with
/// ours. The optional `Identity` property is converted into a `PeerIdentity`; when it is
/// missing, `PeerIdentity::default()` is used instead.
#[cfg(any(feature = "tcp", all(feature = "ipc", target_family = "unix")))]
fn peer_identity_from_ready(
    socket_type: SocketType,
    command: &ZmqCommand,
) -> ZmqResult<PeerIdentity> {
    match command.name {
        ZmqCommandName::READY => {
            let other_sock_type = match command.properties.get("Socket-Type") {
                Some(s) => SocketType::try_from(&s[..])?,
                None => {
                    return Err(ZmqError::Other("Failed to parse other socket type".into()))
                }
            };
            let peer_id = command
                .properties
                .get("Identity")
                .map(|x| x.clone().try_into())
                .transpose()?
                .unwrap_or_default();
            if socket_type.compatible(other_sock_type) {
                Ok(peer_id)
            } else {
                // Both reply forms report incompatibility through the same
                // dedicated variant, so callers can match on it.
                Err(ZmqError::IncompatiblePeer)
            }
        }
    }
}
#[cfg(all(test, feature = "tokio"))]
mod tests {
    // Unit tests for `negotiate_version`. The greetings are built in memory, so no
    // socket is involved.
    use super::*;
    use crate::codec::mechanism::ZmqMechanism;
    use crate::message::ZmqMessage;

    // Wraps a client-side (non-server) PLAIN greeting that advertises `version`.
    fn new_greeting(version: ZmtpVersion) -> Message {
        let greeting = ZmqGreeting {
            version,
            mechanism: ZmqMechanism::PLAIN,
            as_server: false,
        };
        Message::Greeting(greeting)
    }

    #[test]
    fn negotiate_version_peer_is_using_the_same_version() {
        let ours = ZmqGreeting::default().version;
        let negotiated = negotiate_version(new_greeting(ours)).unwrap();
        assert_eq!(negotiated, ours);
    }

    #[test]
    fn negotiate_version_peer_is_using_a_newer_version() {
        let negotiated = negotiate_version(new_greeting((3, 1))).unwrap();
        assert_eq!(negotiated, ZmqGreeting::default().version);
    }

    #[test]
    fn negotiate_version_peer_is_using_an_older_version() {
        let peer_version = (2, 1);
        let outcome = negotiate_version(new_greeting(peer_version));
        if let Err(ZmqError::UnsupportedVersion(rejected)) = outcome {
            assert_eq!(rejected, peer_version);
        } else {
            panic!("Unexpected result");
        }
    }

    #[test]
    fn negotiate_version_invalid_greeting() {
        let not_a_greeting = Message::Message(ZmqMessage::from(""));
        if !matches!(negotiate_version(not_a_greeting), Err(ZmqError::Other(_))) {
            panic!("Unexpected result");
        }
    }
}