1use crate::codec::{CodecResult, FramedIo};
2use crate::*;
3
4use asynchronous_codec::FramedRead;
5use bytes::Bytes;
6use futures::{SinkExt, StreamExt};
7use rand::Rng;
8
9use std::convert::{TryFrom, TryInto};
10use std::ops::Deref;
11use std::str::FromStr;
12use std::sync::Arc;
13use uuid::Uuid;
14
15#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Hash, Clone)]
16pub struct PeerIdentity(Bytes);
17
18impl PeerIdentity {
19 pub const MAX_LENGTH: usize = 255;
20
21 pub fn new() -> Self {
22 let id = Uuid::new_v4();
23 Self(Bytes::copy_from_slice(id.as_bytes()))
24 }
25}
26
27impl AsRef<[u8]> for PeerIdentity {
28 fn as_ref(&self) -> &[u8] {
29 self.0.as_ref()
30 }
31}
32
33impl Deref for PeerIdentity {
34 type Target = [u8];
35
36 fn deref(&self) -> &[u8] {
37 self.0.as_ref()
38 }
39}
40
41impl Default for PeerIdentity {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl FromStr for PeerIdentity {
48 type Err = ZmqError;
49
50 fn from_str(s: &str) -> Result<Self, Self::Err> {
51 Self::try_from(s.as_bytes())
52 }
53}
54
55impl TryFrom<Bytes> for PeerIdentity {
56 type Error = ZmqError;
57
58 fn try_from(data: Bytes) -> Result<Self, ZmqError> {
59 if data.is_empty() {
60 Ok(Self::new())
61 } else if data.len() > Self::MAX_LENGTH {
62 Err(ZmqError::PeerIdentity)
63 } else {
64 Ok(Self(data))
65 }
66 }
67}
68
69impl TryFrom<&[u8]> for PeerIdentity {
70 type Error = ZmqError;
71
72 fn try_from(data: &[u8]) -> Result<Self, ZmqError> {
73 Self::try_from(Bytes::copy_from_slice(data))
74 }
75}
76
77impl TryFrom<Vec<u8>> for PeerIdentity {
78 type Error = ZmqError;
79
80 fn try_from(data: Vec<u8>) -> Result<Self, ZmqError> {
81 Self::try_from(Bytes::from(data))
82 }
83}
84
85impl From<PeerIdentity> for Bytes {
86 fn from(p_id: PeerIdentity) -> Self {
87 p_id.0
88 }
89}
90
91impl From<PeerIdentity> for Vec<u8> {
92 fn from(p_id: PeerIdentity) -> Self {
93 p_id.0.to_vec()
94 }
95}
96
/// Per-peer state: the peer's identity plus the framed write and read halves
/// of its transport.
pub(crate) struct Peer {
    /// Identity of this peer. NOTE(review): the leading underscore suggests the
    /// field is not currently read anywhere — confirm before relying on it.
    pub(crate) _identity: PeerIdentity,
    /// Outgoing half: messages are encoded with `ZmqCodec` before being written.
    pub(crate) send_queue: FramedWrite<Box<dyn FrameableWrite>, ZmqCodec>,
    /// Incoming half: data read from the transport is decoded with `ZmqCodec`.
    pub(crate) recv_queue: FramedRead<Box<dyn FrameableRead>, ZmqCodec>,
}
102
103fn negotiate_version(greeting: Message) -> ZmqResult<ZmtpVersion> {
107 let my_version = ZmqGreeting::default().version;
108
109 match greeting {
110 Message::Greeting(peer) => {
111 if peer.version >= my_version {
112 Ok(my_version)
120 } else {
121 Err(ZmqError::UnsupportedVersion(peer.version))
127 }
128 }
129 _ => Err(ZmqError::Other("Failed Greeting exchange")),
130 }
131}
132
133pub(crate) async fn greet_exchange(raw_socket: &mut FramedIo) -> ZmqResult<ZmtpVersion> {
134 raw_socket
135 .write_half
136 .send(Message::Greeting(ZmqGreeting::default()))
137 .await?;
138
139 let greeting = match raw_socket.read_half.next().await {
140 Some(message) => message?,
141 None => return Err(ZmqError::Other("Failed Greeting exchange")),
142 };
143 negotiate_version(greeting)
144}
145
146pub(crate) async fn ready_exchange(
147 raw_socket: &mut FramedIo,
148 socket_type: SocketType,
149 props: Option<HashMap<String, Bytes>>,
150) -> ZmqResult<PeerIdentity> {
151 let mut ready = ZmqCommand::ready(socket_type);
152 if let Some(props) = props {
153 ready.add_properties(props);
154 }
155 raw_socket.write_half.send(Message::Command(ready)).await?;
156
157 let ready_repl: Option<CodecResult<Message>> = raw_socket.read_half.next().await;
158 match ready_repl {
159 Some(Ok(Message::Command(command))) => match command.name {
160 ZmqCommandName::READY => {
161 let other_sock_type = match command.properties.get("Socket-Type") {
162 Some(s) => SocketType::try_from(&s[..])?,
163 None => Err(ZmqError::Other("Failed to parse other socket type"))?,
164 };
165
166 let peer_id = command
167 .properties
168 .get("Identity")
169 .map(|x| x.clone().try_into())
170 .transpose()?
171 .unwrap_or_default();
172
173 if socket_type.compatible(other_sock_type) {
174 Ok(peer_id)
175 } else {
176 Err(ZmqError::Other(
177 "Provided sockets combination is not compatible",
178 ))
179 }
180 }
181 },
182 Some(Ok(_)) => Err(ZmqError::Other("Failed to confirm ready state")),
183 Some(Err(e)) => Err(e.into()),
184 None => Err(ZmqError::Other("No reply from server")),
185 }
186}
187
188pub(crate) async fn peer_connected(
189 mut raw_socket: FramedIo,
190 backend: Arc<dyn MultiPeerBackend>,
191) -> ZmqResult<PeerIdentity> {
192 greet_exchange(&mut raw_socket).await?;
193 let mut props = None;
194 if let Some(identity) = &backend.socket_options().peer_id {
195 let mut connect_ops = HashMap::new();
196 connect_ops.insert("Identity".to_string(), identity.clone().into());
197 props = Some(connect_ops);
198 }
199 let peer_id = ready_exchange(&mut raw_socket, backend.socket_type(), props).await?;
200 backend.peer_connected(&peer_id, raw_socket).await;
201 Ok(peer_id)
202}
203
204pub(crate) async fn connect_forever(endpoint: Endpoint) -> ZmqResult<(FramedIo, Endpoint)> {
205 let mut try_num: u64 = 0;
206 loop {
207 match transport::connect(&endpoint).await {
208 Ok(res) => return Ok(res),
209 Err(ZmqError::Network(e)) if e.kind() == std::io::ErrorKind::ConnectionRefused => {
210 if try_num < 5 {
211 try_num += 1;
212 }
213 let delay = {
214 let mut rng = rand::rng();
215 std::f64::consts::E.powf(try_num as f64 / 3.0)
216 + rng.random_range(0.0f64..0.1f64)
217 };
218 async_rt::task::sleep(std::time::Duration::from_secs_f64(delay)).await;
219 }
220 Err(e) => return Err(e),
221 }
222 }
223}
224
#[cfg(test)]
pub(crate) mod tests {
    use super::*;
    use crate::codec::mechanism::ZmqMechanism;

    /// Binds `sock` to the four consecutive TCP ports starting at `start_port`
    /// on the unspecified address `any`. Then checks that `binds()` reports
    /// exactly those four ports, all on `any`.
    pub async fn test_bind_to_unspecified_interface_helper(
        any: std::net::IpAddr,
        mut sock: impl Socket,
        start_port: u16,
    ) -> ZmqResult<()> {
        assert!(sock.binds().is_empty());
        assert!(any.is_unspecified());

        for offset in 0..4u16 {
            let endpoint = Endpoint::Tcp(any.into(), start_port + offset).to_string();
            sock.bind(endpoint.as_str()).await?;
        }

        let bound_to = sock.binds();
        assert_eq!(bound_to.len(), 4);

        let mut port_set = std::collections::HashSet::new();
        for endpoint in bound_to.keys() {
            if let Endpoint::Tcp(host, port) = endpoint {
                assert_eq!(host, &any.into());
                port_set.insert(*port);
            } else {
                unreachable!()
            }
        }

        for expected_port in start_port..start_port + 4 {
            assert!(port_set.contains(&expected_port));
        }

        Ok(())
    }

    /// Binds `sock` to `tcp://localhost:0` four times. Then checks that every
    /// reported endpoint is on `localhost` with its own distinct, non-zero port.
    pub async fn test_bind_to_any_port_helper(mut sock: impl Socket) -> ZmqResult<()> {
        assert!(sock.binds().is_empty());
        for _ in 0..4 {
            sock.bind("tcp://localhost:0").await?;
        }

        let bound_to = sock.binds();
        assert_eq!(bound_to.len(), 4);

        let expected_host = Host::Domain("localhost".to_string());
        let mut seen_ports = std::collections::HashSet::new();
        for endpoint in bound_to.keys() {
            if let Endpoint::Tcp(host, port) = endpoint {
                assert_eq!(host, &expected_host);
                assert_ne!(*port, 0);
                // Every bind must have resolved to a port not seen before.
                assert!(seen_ports.insert(*port));
            } else {
                unreachable!()
            }
        }

        Ok(())
    }

    /// Wraps a PLAIN-mechanism greeting with `as_server` unset, advertising `version`.
    fn new_greeting(version: ZmtpVersion) -> Message {
        let greeting = ZmqGreeting {
            version,
            mechanism: ZmqMechanism::PLAIN,
            as_server: false,
        };
        Message::Greeting(greeting)
    }

    #[test]
    fn negotiate_version_peer_is_using_the_same_version() {
        let ours = ZmqGreeting::default().version;
        let negotiated = negotiate_version(new_greeting(ours)).unwrap();
        assert_eq!(negotiated, ours);
    }

    #[test]
    fn negotiate_version_peer_is_using_a_newer_version() {
        let negotiated = negotiate_version(new_greeting((3, 1))).unwrap();
        assert_eq!(negotiated, ZmqGreeting::default().version);
    }

    #[test]
    fn negotiate_version_peer_is_using_an_older_version() {
        let old_version = (2, 1);
        match negotiate_version(new_greeting(old_version)) {
            Err(ZmqError::UnsupportedVersion(version)) => assert_eq!(version, old_version),
            _ => panic!("Unexpected result"),
        }
    }

    #[test]
    fn negotiate_version_invalid_greeting() {
        let not_a_greeting = Message::Message(ZmqMessage::from(""));
        let outcome = negotiate_version(not_a_greeting);
        assert!(
            matches!(outcome, Err(ZmqError::Other(_))),
            "Unexpected result"
        );
    }
}