mezzenger_udp/
lib.rs

//! Transport for communication over [tokio](https://tokio.rs/)'s
//! UDP implementation.
//!
//! **NOTE**: This transport inherits UDP properties:
//! - it is **unreliable** - messages are not guaranteed to reach their destination,
//! - it is **unordered** - messages may arrive at their destination out of order; they
//! may also be duplicated (the same message may arrive two or more times),
//! - message size is limited to datagram size - sending may result in an error if the
//! encoded message is too large.
//!
//! See the [repository](https://github.com/zduny/mezzenger) for more info.
//!
//! ## Example
//!
//! ```ignore
//! let udp_socket = UdpSocket::bind("127.0.0.1:8080").await?;
//! udp_socket.connect(remote_address).await?;
//!
//! use kodec::binary::Codec;
//! let mut transport: Transport<_, Codec, i32, String> =
//!     Transport::new(udp_socket, Codec::default());
//!
//! use mezzenger::Receive;
//! let integer = transport.receive().await?;
//!
//! use futures::SinkExt;
//! transport.send("Hello World!".to_string()).await?;
//! ```
28
29use futures::{future::poll_fn, stream::FusedStream, Sink, SinkExt, Stream};
30use kodec::{Decode, Encode};
31use pin_project::pin_project;
32use serde::Serialize;
33use std::{
34    borrow::Borrow,
35    collections::VecDeque,
36    fmt::{Debug, Display},
37    io::ErrorKind,
38    marker::PhantomData,
39    net::SocketAddr,
40    pin::Pin,
41    task::{Context, Poll},
42};
43use tokio::{
44    io::ReadBuf,
45    net::{ToSocketAddrs, UdpSocket},
46};
47
48#[derive(Debug)]
49pub enum Error<SerializationError, DeserializationError> {
50    SendingError,
51    SerializationError(SerializationError),
52    DeserializationError(DeserializationError),
53    IoError(tokio::io::Error),
54}
55
56impl<SerializationError, DeserializationError> Display
57    for Error<SerializationError, DeserializationError>
58where
59    SerializationError: Display,
60    DeserializationError: Display,
61{
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        match self {
64            Error::SendingError => write!(f, "not all bytes were sent"),
65            Error::SerializationError(error) => write!(f, "failed to serialize message: {error}"),
66            Error::DeserializationError(error) => {
67                write!(f, "failed to deserialize message: {error}")
68            }
69            Error::IoError(error) => write!(f, "IO error occurred: {error}"),
70        }
71    }
72}
73
74impl<SerializationError, DeserializationError> std::error::Error
75    for Error<SerializationError, DeserializationError>
76where
77    SerializationError: Debug + Display,
78    DeserializationError: Debug + Display,
79{
80}
81
82/// Transport over [tokio](https://tokio.rs/)'s UDP implementation.
83///
84/// Wraps over [tokio::net::UdpSocket].
85///
86/// **NOTE**: This transport inherits UDP properties:
87/// - it is **unreliable** - messages are NOT guaranteed to reach destination,
88/// - it is **unordered** - messages may arrive at destination out of order, also they
89/// may be duplicated (the same message may arrive at destination twice or more times).
90/// - message size is limited to datagram size - sending may result in error if encoded
91/// message is too large.
92#[pin_project]
93pub struct Transport<U, Codec, Incoming, Outgoing>
94where
95    U: Borrow<UdpSocket>,
96    Codec: kodec::Codec,
97    for<'de> Incoming: serde::de::Deserialize<'de>,
98    Outgoing: Serialize,
99{
100    udp_socket: Option<U>,
101    codec: Codec,
102    send_queue: VecDeque<Outgoing>,
103    send_buffer: Vec<u8>,
104    message_pending: bool,
105    receive_buffer: Vec<u8>,
106    _incoming: PhantomData<Incoming>,
107}
108
109impl<U, Codec, Incoming, Outgoing> Transport<U, Codec, Incoming, Outgoing>
110where
111    U: Borrow<UdpSocket>,
112    Codec: kodec::Codec,
113    for<'de> Incoming: serde::de::Deserialize<'de>,
114    Outgoing: Serialize,
115{
116    /// Create new transport wrapping a provided `[tokio::net::UdpSocket]`.
117    pub fn new(udp_socket: U, codec: Codec) -> Self {
118        Transport {
119            udp_socket: Some(udp_socket),
120            codec,
121            send_queue: VecDeque::new(),
122            send_buffer: vec![],
123            message_pending: false,
124            receive_buffer: vec![0; 65536],
125            _incoming: PhantomData,
126        }
127    }
128
129    /// Send message to address.
130    pub async fn send_to<A: ToSocketAddrs>(
131        &mut self,
132        message: Outgoing,
133        target: A,
134    ) -> Result<(), mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>>
135    {
136        self.flush().await?;
137        if let Some(udp_socket) = &self.udp_socket {
138            self.codec
139                .encode(&mut self.send_buffer, &message)
140                .map_err(
141                    Error::<<Codec as Encode>::Error, <Codec as Decode>::Error>::SerializationError,
142                )
143                .map_err(mezzenger::Error::Other)?;
144            udp_socket
145                .borrow()
146                .send_to(&self.send_buffer, target)
147                .await
148                .map_err(Error::<<Codec as Encode>::Error, <Codec as Decode>::Error>::IoError)
149                .map_err(mezzenger::Error::Other)?;
150            self.send_buffer.clear();
151            Ok(())
152        } else {
153            Err(mezzenger::Error::Closed)
154        }
155    }
156
157    /// Receive single message.
158    ///
159    /// Returns a pair of incoming message and its origin address.
160    pub async fn receive_from(
161        &mut self,
162    ) -> Result<
163        (Incoming, SocketAddr),
164        mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>,
165    > {
166        if self.udp_socket.is_some() {
167            let result = poll_fn(|cx| self.poll_recv_from(cx)).await;
168            if let Some(result) = result {
169                result.map_err(mezzenger::Error::Other)
170            } else {
171                Err(mezzenger::Error::Closed)
172            }
173        } else {
174            Err(mezzenger::Error::Closed)
175        }
176    }
177
178    #[allow(clippy::type_complexity)]
179    fn poll_recv_from(
180        &mut self,
181        cx: &mut Context<'_>,
182    ) -> Poll<
183        Option<
184            Result<
185                (Incoming, SocketAddr),
186                Error<<Codec as Encode>::Error, <Codec as Decode>::Error>,
187            >,
188        >,
189    > {
190        if let Some(udp_socket) = &self.udp_socket {
191            let mut buf = ReadBuf::new(&mut self.receive_buffer);
192            match udp_socket.borrow().poll_recv_from(cx, &mut buf) {
193                Poll::Ready(result) => match result {
194                    Ok(address) => {
195                        let result: Result<Incoming, _> = self.codec.decode(buf.filled());
196                        match result {
197                            Ok(message) => Poll::Ready(Some(Ok((message, address)))),
198                            Err(error) => {
199                                Poll::Ready(Some(Err(Error::DeserializationError(error))))
200                            }
201                        }
202                    }
203                    Err(error) => match error.kind() {
204                        ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => {
205                            self.udp_socket = None;
206                            Poll::Ready(None)
207                        }
208                        _ => Poll::Ready(Some(Err(Error::IoError(error)))),
209                    },
210                },
211                Poll::Pending => Poll::Pending,
212            }
213        } else {
214            Poll::Ready(None)
215        }
216    }
217}
218
219impl<U, Codec, Incoming, Outgoing> Sink<Outgoing> for Transport<U, Codec, Incoming, Outgoing>
220where
221    U: Borrow<UdpSocket>,
222    Codec: kodec::Codec,
223    for<'de> Incoming: serde::de::Deserialize<'de>,
224    Outgoing: Serialize,
225{
226    type Error = mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
227
228    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
229        Poll::Ready(Ok(()))
230    }
231
232    fn start_send(mut self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
233        if self.udp_socket.is_some() {
234            self.send_queue.push_back(item);
235            Ok(())
236        } else {
237            Err(mezzenger::Error::Closed)
238        }
239    }
240
241    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
242        let me = self.project();
243        if me.send_queue.is_empty() && !*me.message_pending {
244            return Poll::Ready(Ok(()));
245        }
246        if let Some(udp_socket) = &me.udp_socket {
247            loop {
248                if *me.message_pending {
249                    let bytes_to_send = me.send_buffer.len();
250                    let result = udp_socket.borrow().poll_send(cx, me.send_buffer);
251                    match result {
252                        Poll::Ready(result) => {
253                            *me.message_pending = false;
254                            me.send_buffer.clear();
255                            match result {
256                                Ok(bytes_written) => {
257                                    if bytes_written != bytes_to_send {
258                                        return Poll::Ready(Err(mezzenger::Error::Other(
259                                            Error::SendingError,
260                                        )));
261                                    }
262                                }
263                                Err(error) => match error.kind() {
264                                    ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => {
265                                        *me.udp_socket = None;
266                                        return Poll::Ready(Err(mezzenger::Error::Closed));
267                                    }
268                                    _ => {
269                                        return Poll::Ready(Err(mezzenger::Error::Other(
270                                            Error::IoError(error),
271                                        )))
272                                    }
273                                },
274                            }
275                        }
276                        Poll::Pending => return Poll::Pending,
277                    }
278                } else if let Some(message) = me.send_queue.pop_front() {
279                    let result = me.codec.encode(&mut *me.send_buffer, &message);
280                    if let Err(error) = result {
281                        me.send_buffer.clear();
282                        return Poll::Ready(Err(mezzenger::Error::Other(
283                            Error::SerializationError(error),
284                        )));
285                    } else {
286                        *me.message_pending = true;
287                    }
288                } else {
289                    return Poll::Ready(Ok(()));
290                }
291            }
292        } else {
293            Poll::Ready(Err(mezzenger::Error::Closed))
294        }
295    }
296
297    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
298        match self.poll_flush_unpin(cx) {
299            Poll::Ready(_) => {
300                self.udp_socket = None;
301                Poll::Ready(Ok(()))
302            }
303            Poll::Pending => Poll::Pending,
304        }
305    }
306}
307
308impl<U, Codec, Incoming, Outgoing> Stream for Transport<U, Codec, Incoming, Outgoing>
309where
310    U: Borrow<UdpSocket>,
311    Codec: kodec::Codec,
312    for<'de> Incoming: serde::de::Deserialize<'de>,
313    Outgoing: Serialize,
314{
315    type Item = Result<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
316
317    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
318        match self.poll_recv_from(cx) {
319            Poll::Ready(result) => {
320                let result = result.map(|result| result.map(|(incoming, _)| incoming));
321                Poll::Ready(result)
322            }
323            Poll::Pending => Poll::Pending,
324        }
325    }
326}
327
328impl<U, Codec, Incoming, Outgoing> FusedStream for Transport<U, Codec, Incoming, Outgoing>
329where
330    U: Borrow<UdpSocket>,
331    Codec: kodec::Codec,
332    for<'de> Incoming: serde::de::Deserialize<'de>,
333    Outgoing: Serialize,
334{
335    fn is_terminated(&self) -> bool {
336        self.udp_socket.is_none()
337    }
338}
339
#[cfg(test)]
mod tests {
    use futures::SinkExt;
    use kodec::binary::Codec;
    use mezzenger::Receive;
    use tokio::net::UdpSocket;

    use crate::Transport;

    /// Binds two loopback sockets on OS-assigned ports and connects them to each other.
    ///
    /// Port `0` lets the OS pick free ports. This avoids collisions with services
    /// already listening on fixed ports and with other test runs.
    async fn connected_pair() -> (UdpSocket, UdpSocket) {
        let left = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        let right = UdpSocket::bind("127.0.0.1:0").await.unwrap();

        left.connect(right.local_addr().unwrap()).await.unwrap();
        right.connect(left.local_addr().unwrap()).await.unwrap();

        (left, right)
    }

    #[tokio::test]
    async fn test_transport() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<UdpSocket, Codec, u32, String> =
            Transport::new(left, Codec::default());
        let mut right: Transport<UdpSocket, Codec, String, u32> =
            Transport::new(right, Codec::default());

        left.send("Hello World!".to_string()).await.unwrap();
        left.send("Hello World again!".to_string()).await.unwrap();
        right.send(128).await.unwrap();
        right.send(1).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), "Hello World!");
        assert_eq!(right.receive().await.unwrap(), "Hello World again!");
        assert_eq!(left.receive().await.unwrap(), 128);
        assert_eq!(left.receive().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_unit_message() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<UdpSocket, Codec, (), ()> = Transport::new(left, Codec::default());
        let mut right: Transport<UdpSocket, Codec, (), ()> =
            Transport::new(right, Codec::default());

        left.send(()).await.unwrap();
        left.send(()).await.unwrap();
        left.send(()).await.unwrap();
        left.send(()).await.unwrap();
        right.send(()).await.unwrap();
        right.send(()).await.unwrap();
        right.send(()).await.unwrap();

        // The incoming type is pinned to `()` by the annotations above, so a
        // successful unwrap is the whole check. `assert_eq!(…, ())` would always
        // pass and is rejected by `clippy::unit_cmp`.
        right.receive().await.unwrap();
        right.receive().await.unwrap();
        right.receive().await.unwrap();
        right.receive().await.unwrap();
        left.receive().await.unwrap();
        left.receive().await.unwrap();
        left.receive().await.unwrap();
    }
}