Skip to main content

zeromq/
util.rs

1use crate::codec::{CodecResult, FramedIo};
2use crate::*;
3
4use asynchronous_codec::FramedRead;
5use bytes::Bytes;
6use futures::{SinkExt, StreamExt};
7use rand::Rng;
8
9use std::convert::{TryFrom, TryInto};
10use std::ops::Deref;
11use std::str::FromStr;
12use std::sync::Arc;
13use uuid::Uuid;
14
15#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Clone)]
16pub struct PeerIdentity(Bytes);
17
18impl PeerIdentity {
19    pub const MAX_LENGTH: usize = 255;
20
21    pub fn new() -> Self {
22        let id = Uuid::new_v4();
23        Self(Bytes::copy_from_slice(id.as_bytes()))
24    }
25}
26
27impl AsRef<[u8]> for PeerIdentity {
28    fn as_ref(&self) -> &[u8] {
29        self.0.as_ref()
30    }
31}
32
33impl Deref for PeerIdentity {
34    type Target = [u8];
35
36    fn deref(&self) -> &[u8] {
37        self.0.as_ref()
38    }
39}
40
41impl Default for PeerIdentity {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl FromStr for PeerIdentity {
48    type Err = ZmqError;
49
50    fn from_str(s: &str) -> Result<Self, Self::Err> {
51        Self::try_from(s.as_bytes())
52    }
53}
54
55impl TryFrom<Bytes> for PeerIdentity {
56    type Error = ZmqError;
57
58    fn try_from(data: Bytes) -> Result<Self, ZmqError> {
59        if data.is_empty() {
60            Ok(Self::new())
61        } else if data.len() > Self::MAX_LENGTH {
62            Err(ZmqError::PeerIdentity)
63        } else {
64            Ok(Self(data))
65        }
66    }
67}
68
69impl TryFrom<&[u8]> for PeerIdentity {
70    type Error = ZmqError;
71
72    fn try_from(data: &[u8]) -> Result<Self, ZmqError> {
73        Self::try_from(Bytes::copy_from_slice(data))
74    }
75}
76
77impl TryFrom<Vec<u8>> for PeerIdentity {
78    type Error = ZmqError;
79
80    fn try_from(data: Vec<u8>) -> Result<Self, ZmqError> {
81        Self::try_from(Bytes::from(data))
82    }
83}
84
85impl From<PeerIdentity> for Bytes {
86    fn from(p_id: PeerIdentity) -> Self {
87        p_id.0
88    }
89}
90
91impl From<PeerIdentity> for Vec<u8> {
92    fn from(p_id: PeerIdentity) -> Self {
93        p_id.0.to_vec()
94    }
95}
96
/// State for a single connected peer: its identity plus the framed write and
/// read halves of its connection.
pub(crate) struct Peer {
    /// Identity of the peer. The leading underscore suggests the field is not
    /// currently read — NOTE(review): confirm before relying on it.
    pub(crate) _identity: PeerIdentity,
    /// Framed sink used to send outgoing messages to the peer.
    pub(crate) send_queue: FramedWrite<Box<dyn FrameableWrite>, ZmqCodec>,
    /// Framed stream used to receive incoming messages from the peer.
    pub(crate) recv_queue: FramedRead<Box<dyn FrameableRead>, ZmqCodec>,
}
102
103/// Given the result of the greetings exchange, determines the version of the
104/// ZMTP protocol that should be used for communication with the peer according
105/// to [ZeroMQ RFC 23](https://rfc.zeromq.org/spec/23/#version-negotiation).
106fn negotiate_version(greeting: Message) -> ZmqResult<ZmtpVersion> {
107    let my_version = ZmqGreeting::default().version;
108
109    match greeting {
110        Message::Greeting(peer) => {
111            if peer.version >= my_version {
112                // A peer MUST accept higher protocol versions as valid. That is,
113                // a ZMTP peer MUST accept protocol versions greater or equal to 3.0.
114                // This allows future implementations to safely interoperate with
115                // current implementations.
116                //
117                // A peer SHALL always use its own protocol (including framing)
118                // when talking to an equal or higher protocol peer.
119                Ok(my_version)
120            } else {
121                // A peer MAY downgrade its protocol to talk to a lower protocol peer.
122                //
123                // If a peer cannot downgrade its protocol to match its peer, it MUST
124                // close the connection.
125                // TODO: implement interoperability with older protocol versions
126                Err(ZmqError::UnsupportedVersion(peer.version))
127            }
128        }
129        _ => Err(ZmqError::Other("Failed Greeting exchange")),
130    }
131}
132
133pub(crate) async fn greet_exchange(raw_socket: &mut FramedIo) -> ZmqResult<ZmtpVersion> {
134    raw_socket
135        .write_half
136        .send(Message::Greeting(ZmqGreeting::default()))
137        .await?;
138
139    let greeting = match raw_socket.read_half.next().await {
140        Some(message) => message?,
141        None => return Err(ZmqError::Other("Failed Greeting exchange")),
142    };
143    negotiate_version(greeting)
144}
145
146pub(crate) async fn ready_exchange(
147    raw_socket: &mut FramedIo,
148    socket_type: SocketType,
149    props: Option<HashMap<String, Bytes>>,
150) -> ZmqResult<PeerIdentity> {
151    let mut ready = ZmqCommand::ready(socket_type);
152    if let Some(props) = props {
153        ready.add_properties(props);
154    }
155    raw_socket.write_half.send(Message::Command(ready)).await?;
156
157    let ready_repl: Option<CodecResult<Message>> = raw_socket.read_half.next().await;
158    match ready_repl {
159        Some(Ok(Message::Command(command))) => match command.name {
160            ZmqCommandName::READY => {
161                let other_sock_type = match command.properties.get("Socket-Type") {
162                    Some(s) => SocketType::try_from(&s[..])?,
163                    None => Err(ZmqError::Other("Failed to parse other socket type"))?,
164                };
165
166                let peer_id = command
167                    .properties
168                    .get("Identity")
169                    .map(|x| x.clone().try_into())
170                    .transpose()?
171                    .unwrap_or_default();
172
173                if socket_type.compatible(other_sock_type) {
174                    Ok(peer_id)
175                } else {
176                    Err(ZmqError::Other(
177                        "Provided sockets combination is not compatible",
178                    ))
179                }
180            }
181        },
182        Some(Ok(_)) => Err(ZmqError::Other("Failed to confirm ready state")),
183        Some(Err(e)) => Err(e.into()),
184        None => Err(ZmqError::Other("No reply from server")),
185    }
186}
187
188pub(crate) async fn peer_connected(
189    mut raw_socket: FramedIo,
190    backend: Arc<dyn MultiPeerBackend>,
191) -> ZmqResult<PeerIdentity> {
192    greet_exchange(&mut raw_socket).await?;
193    let mut props = None;
194    if let Some(identity) = &backend.socket_options().peer_id {
195        let mut connect_ops = HashMap::new();
196        connect_ops.insert("Identity".to_string(), identity.clone().into());
197        props = Some(connect_ops);
198    }
199    let peer_id = ready_exchange(&mut raw_socket, backend.socket_type(), props).await?;
200    backend.peer_connected(&peer_id, raw_socket).await;
201    Ok(peer_id)
202}
203
204pub(crate) async fn connect_forever(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> {
205    let mut try_num: u64 = 0;
206    loop {
207        match transport::connect(&endpoint).await {
208            Ok(res) => return Ok(res),
209            Err(ZmqError::Network(e)) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
210                if try_num < 5 {
211                    try_num += 1;
212                }
213                let delay = {
214                    let mut rng = rand::rng();
215                    std::f64::consts::E.powf(try_num as f64 / 3.0)
216                        + rng.random_range(0.0f64..0.1f64)
217                };
218                async_rt::task::sleep(std::time::Duration::from_secs_f64(delay)).await;
219            }
220            Err(e) => return Err(e),
221        }
222    }
223}
224
/// Unit tests for this module, plus socket-agnostic bind helpers.
///
/// The `test_bind_*_helper` functions are `pub`, presumably so that other test
/// modules can run them against different `Socket` implementations — confirm
/// against callers.
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::codec::mechanism::ZmqMechanism;

    /// Binds `sock` to four consecutive TCP ports, `start_port..start_port + 4`,
    /// on the unspecified address `any`. It then checks that `binds()` reports
    /// exactly those four endpoints, each with host `any`.
    ///
    /// `start_port` must be at most `u16::MAX - 4`. Otherwise the `u16`
    /// additions below overflow, which panics in debug builds.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `sock.bind`.
    ///
    /// # Panics
    ///
    /// * Panics if `sock` already has bindings.
    /// * Panics if `any` is not an unspecified address.
    /// * Panics if the reported bindings differ from the requested ones.
    pub async fn test_bind_to_unspecified_interface_helper(
        any: std::net::IpAddr,
        mut sock: impl Socket,
        start_port: u16,
    ) -> ZmqResult<()> {
        assert!(sock.binds().is_empty());
        assert!(any.is_unspecified());

        for i in 0..4 {
            sock.bind(
                Endpoint::Tcp(any.into(), start_port + i)
                    .to_string()
                    .as_str(),
            )
            .await?;
        }

        let bound_to = sock.binds();
        assert_eq!(bound_to.len(), 4);

        let mut port_set = std::collections::HashSet::new();
        for b in bound_to.keys() {
            if let Endpoint::Tcp(host, port) = b {
                // Each binding must keep the unspecified host it was requested on.
                assert_eq!(host, &any.into());
                port_set.insert(*port);
            } else {
                unreachable!()
            }
        }

        // Every requested port must appear among the bindings.
        (start_port..start_port + 4).for_each(|p| assert!(port_set.contains(&p)));

        Ok(())
    }

    /// Binds `sock` four times to `tcp://localhost:0`. It then checks that each
    /// reported binding has host `localhost` and a non-zero port, and that no
    /// two bindings share a port.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `sock.bind`.
    ///
    /// # Panics
    ///
    /// * Panics if `sock` already has bindings.
    /// * Panics if any of the checks above fails.
    pub async fn test_bind_to_any_port_helper(mut sock: impl Socket) -> ZmqResult<()> {
        assert!(sock.binds().is_empty());
        for _ in 0..4 {
            sock.bind("tcp://localhost:0").await?;
        }

        let bound_to = sock.binds();
        assert_eq!(bound_to.len(), 4);
        let mut port_set = std::collections::HashSet::new();
        for b in bound_to.keys() {
            if let Endpoint::Tcp(host, port) = b {
                assert_eq!(host, &Host::Domain("localhost".to_string()));
                assert_ne!(*port, 0);
                // Insert and check that it wasn't already present
                assert!(port_set.insert(*port));
            } else {
                unreachable!()
            }
        }

        Ok(())
    }

    /// Builds a peer greeting announcing `version`.
    ///
    /// `mechanism` and `as_server` are fixed filler values: `negotiate_version`
    /// only inspects the version field.
    fn new_greeting(version: ZmtpVersion) -> Message {
        Message::Greeting(ZmqGreeting {
            version,
            mechanism: ZmqMechanism::PLAIN,
            as_server: false,
        })
    }

    #[test]
    fn negotiate_version_peer_is_using_the_same_version() {
        // if both peers are using the same protocol version, negotiation is trivial
        let peer_version = ZmqGreeting::default().version;
        let expected = ZmqGreeting::default().version;
        let actual = negotiate_version(new_greeting(peer_version)).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn negotiate_version_peer_is_using_a_newer_version() {
        // if the other end is using a newer protocol version, they should adjust to us
        let peer_version = (3, 1);
        let expected = ZmqGreeting::default().version;
        let actual = negotiate_version(new_greeting(peer_version)).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn negotiate_version_peer_is_using_an_older_version() {
        // if the other end is using an older protocol version, we should adjust to
        // them, but interoperability with older peers is not implemented at the
        // moment, so we just give up immediately, which is allowed by the spec
        let peer_version = (2, 1);
        let actual = negotiate_version(new_greeting(peer_version));
        match actual {
            Err(ZmqError::UnsupportedVersion(version)) => assert_eq!(version, peer_version),
            _ => panic!("Unexpected result"),
        }
    }

    #[test]
    fn negotiate_version_invalid_greeting() {
        // unexpected message during greetings exchange
        let message = Message::Message(ZmqMessage::from(""));
        let actual = negotiate_version(message);
        match actual {
            Err(ZmqError::Other(_)) => {}
            _ => panic!("Unexpected result"),
        }
    }
}