mezzenger_tcp/
lib.rs

//! Transport for communication over [tokio](https://tokio.rs/)'s
//! TCP implementation.
3//!
4//! See [repository](https://github.com/zduny/mezzenger) for more info.
5//!
6//! ## Example
7//!
8//! ```ignore
9//! let tcp_stream = TcpStream::connect("127.0.0.1:8080").await?;
10//!
11//! use kodec::binary::Codec;
12//! let mut transport: Transport<_, Codec, i32, String> =
13//!     Transport::new(tcp_stream, Codec::default());
14//!
15//! use mezzenger::Receive;
16//! let integer = transport.receive().await?;
17//!
18//! transport.send("Hello World!".to_string()).await?;
19//! ```
20
21use std::{
22    fmt::{Debug, Display},
23    io::ErrorKind,
24    marker::PhantomData,
25    pin::Pin,
26    task::{Context, Poll},
27};
28
29use bytes::{Buf, BufMut, BytesMut};
30use futures::{ready, stream::FusedStream, Sink, Stream};
31use kodec::{Decode, Encode};
32use pin_project::pin_project;
33use serde::Serialize;
34use tokio::io::{AsyncRead, AsyncWrite};
35use tokio_util::io::{poll_read_buf, poll_write_buf};
36
/// Default limit, in bytes, on the serialized size of a single message.
///
/// [Transport::new] passes this value as the `max_message_size` of the
/// transport it creates.
pub const DEFAULT_MAX_MESSAGE_SIZE: u32 = 65536;
38
/// Errors reported by the transport while sending or receiving messages.
///
/// The two type parameters are the error types of the codec's encoding and
/// decoding halves respectively.
#[derive(Debug)]
pub enum Error<SerializationError, DeserializationError> {
    /// A message's serialized size exceeded the configured maximum.
    MessageTooLarge,
    /// The codec failed to encode an outgoing message.
    SerializationError(SerializationError),
    /// The codec failed to decode an incoming message.
    DeserializationError(DeserializationError),
    /// The underlying I/O object returned an error.
    IoError(std::io::Error),
}

impl<SerializationError, DeserializationError> Display
    for Error<SerializationError, DeserializationError>
where
    SerializationError: Display,
    DeserializationError: Display,
{
    /// Writes a short human-readable description, followed by the wrapped
    /// error's own message for the variants that carry one.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every variant except `MessageTooLarge` renders as "<summary>: <cause>".
        let (summary, cause): (&str, &dyn Display) = match self {
            Self::MessageTooLarge => return f.write_str("message was too large"),
            Self::SerializationError(cause) => ("failed to serialize message", cause),
            Self::DeserializationError(cause) => ("failed to deserialize message", cause),
            Self::IoError(cause) => ("IO error occurred", cause),
        };
        write!(f, "{summary}: {cause}")
    }
}

impl<SerializationError, DeserializationError> std::error::Error
    for Error<SerializationError, DeserializationError>
where
    SerializationError: Debug + Display,
    DeserializationError: Debug + Display,
{
    // Only the trait's default methods are used.
}
72
73struct ReceiveState {
74    pub buffer: BytesMut,
75    pub message_size: u32,
76    pub receiving_size: bool,
77    pub bytes_to_receive: i64,
78    pub bytes_to_skip: u32,
79}
80
81impl ReceiveState {
82    fn new() -> Self {
83        ReceiveState {
84            buffer: BytesMut::new(),
85            message_size: 0,
86            receiving_size: true,
87            bytes_to_receive: 4,
88            bytes_to_skip: 0,
89        }
90    }
91}
92
/// Transport for communication over [tokio](https://tokio.rs/)'s TCP implementation.
///
/// Wraps over struct implementing [tokio::io::AsyncWrite] and [tokio::io::AsyncRead].
///
/// Outgoing messages are framed as a 4-byte big-endian payload length followed by
/// the payload produced by the codec; incoming data is parsed using the same framing.
#[pin_project]
pub struct Transport<T, Codec, Incoming, Outgoing>
where
    T: AsyncWrite + AsyncRead,
    Codec: kodec::Codec,
    for<'de> Incoming: serde::de::Deserialize<'de>,
    Outgoing: Serialize,
{
    /// Wrapped I/O object that all reads and writes go through.
    #[pin]
    inner: T,
    /// Framed outgoing messages that have not been written to `inner` yet.
    send_buffer: BytesMut,
    /// Progress of the incremental receive loop.
    receive_state: ReceiveState,
    /// Codec used to encode outgoing and decode incoming messages.
    codec: Codec,
    /// Set once the read side observed end of stream or a connection reset/abort.
    terminated: bool,
    /// Largest accepted payload size in bytes, for both directions.
    max_message_size: u32,
    // Zero-sized markers tying the message types to the transport.
    _incoming: PhantomData<Incoming>,
    _outgoing: PhantomData<Outgoing>,
}
114
115impl<T, Codec, Incoming, Outgoing> Transport<T, Codec, Incoming, Outgoing>
116where
117    T: AsyncWrite + AsyncRead,
118    Codec: kodec::Codec,
119    for<'de> Incoming: serde::de::Deserialize<'de>,
120    Outgoing: Serialize,
121{
122    /// Create new transport wrapping a provided struct implementing
123    /// [tokio::io::AsyncWrite] and [tokio::io::AsyncRead].
124    ///
125    /// **NOTE**: By default serialized message size is limited to [DEFAULT_MAX_MESSAGE_SIZE].<br>
126    /// Sending or receiving messages of larger size will result in [Error::MessageTooLarge].
127    pub fn new(transport: T, codec: Codec) -> Self {
128        Transport::new_with_max_message_size(transport, codec, DEFAULT_MAX_MESSAGE_SIZE)
129    }
130
131    /// Create new transport wrapping a provided struct implementing
132    /// [tokio::io::AsyncWrite] and [tokio::io::AsyncRead].
133    ///
134    /// Serialized message size will be limited to `max_message_size`.<br>
135    /// Sending or receiving messages of larger size will result in [Error::MessageTooLarge].
136    pub fn new_with_max_message_size(transport: T, codec: Codec, max_message_size: u32) -> Self {
137        Transport {
138            inner: transport,
139            codec,
140            send_buffer: BytesMut::new(),
141            receive_state: ReceiveState::new(),
142            terminated: false,
143            max_message_size,
144            _incoming: PhantomData,
145            _outgoing: PhantomData,
146        }
147    }
148}
149
150impl<T, Codec, Incoming, Outgoing> Sink<Outgoing> for Transport<T, Codec, Incoming, Outgoing>
151where
152    T: AsyncWrite + AsyncRead,
153    Codec: kodec::Codec,
154    for<'de> Incoming: serde::de::Deserialize<'de>,
155    Outgoing: Serialize,
156{
157    type Error = mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
158
159    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160        Poll::Ready(Ok(()))
161    }
162
163    fn start_send(self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
164        if self.terminated {
165            Err(mezzenger::Error::Closed)
166        } else {
167            let me = self.project();
168            let size_position = me.send_buffer.len();
169            me.send_buffer.put_u32(0);
170            let current_length = me.send_buffer.len();
171            me.codec
172                .encode(me.send_buffer.writer(), &item)
173                .map_err(Error::SerializationError)
174                .map_err(mezzenger::Error::Other)?;
175            let message_size = me.send_buffer.len() - current_length;
176            if message_size > *me.max_message_size as usize {
177                me.send_buffer.truncate(size_position);
178                Err(mezzenger::Error::Other(Error::MessageTooLarge))
179            } else {
180                let size_slice = &mut me.send_buffer[size_position..(size_position + 4)];
181                let message_size = message_size as u32;
182                size_slice.swap_with_slice(&mut message_size.to_be_bytes());
183                Ok(())
184            }
185        }
186    }
187
188    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189        let mut me = self.project();
190
191        let result = if me.send_buffer.is_empty() {
192            ready!(me.inner.poll_flush(cx))
193        } else {
194            ready!(poll_write_buf(me.inner.as_mut(), cx, me.send_buffer)).map(|_| ())
195        }
196        .map_err(|error| match error.kind() {
197            ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => mezzenger::Error::Closed,
198            _ => mezzenger::Error::Other(Error::IoError(error)),
199        });
200
201        Poll::Ready(result)
202    }
203
204    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205        let me = self.project();
206        let result = ready!(me.inner.poll_shutdown(cx)).map_err(|error| match error.kind() {
207            ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => mezzenger::Error::Closed,
208            _ => mezzenger::Error::Other(Error::IoError(error)),
209        });
210        Poll::Ready(result)
211    }
212}
213
214impl<T, Codec, Incoming, Outgoing> Stream for Transport<T, Codec, Incoming, Outgoing>
215where
216    T: AsyncWrite + AsyncRead,
217    Codec: kodec::Codec,
218    for<'de> Incoming: serde::de::Deserialize<'de>,
219    Outgoing: Serialize,
220{
221    type Item = Result<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
222
223    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224        if self.terminated {
225            return Poll::Ready(None);
226        }
227
228        let mut me = self.project();
229        loop {
230            if me.receive_state.bytes_to_receive <= 0 {
231                if me.receive_state.receiving_size {
232                    let message_size = me.receive_state.buffer.get_u32();
233                    me.receive_state.message_size = message_size;
234                    me.receive_state.bytes_to_receive += message_size as i64;
235                    if message_size > *me.max_message_size {
236                        me.receive_state.bytes_to_receive += 4;
237                        if me.receive_state.bytes_to_receive > 0 {
238                            me.receive_state.bytes_to_skip = message_size;
239                        } else {
240                            me.receive_state.buffer.advance(message_size as usize);
241                        }
242                        return Poll::Ready(Some(Err(Error::MessageTooLarge)));
243                    } else {
244                        me.receive_state.receiving_size = false;
245                    }
246                } else {
247                    let message_size = me.receive_state.message_size as usize;
248                    let message = &me.receive_state.buffer[0..message_size];
249                    let result: Result<Incoming, _> = me.codec.decode(message);
250                    me.receive_state.buffer.advance(message_size);
251                    me.receive_state.receiving_size = true;
252                    me.receive_state.bytes_to_receive += 4;
253                    return {
254                        match result {
255                            Ok(message) => Poll::Ready(Some(Ok(message))),
256                            Err(error) => {
257                                Poll::Ready(Some(Err(Error::DeserializationError(error))))
258                            }
259                        }
260                    };
261                }
262            } else {
263                let result = ready!(poll_read_buf(
264                    me.inner.as_mut(),
265                    cx,
266                    &mut me.receive_state.buffer
267                ));
268                match result {
269                    Ok(bytes_read) => {
270                        if bytes_read == 0 {
271                            *me.terminated = true;
272                            return Poll::Ready(None);
273                        }
274                        me.receive_state.bytes_to_receive = me
275                            .receive_state
276                            .bytes_to_receive
277                            .saturating_sub_unsigned(bytes_read as u64);
278                        if me.receive_state.bytes_to_skip > 0 {
279                            let buffer_len = me.receive_state.buffer.len();
280                            let skipped = buffer_len.min(me.receive_state.bytes_to_skip as usize);
281                            if skipped == buffer_len {
282                                me.receive_state.buffer.clear();
283                            } else {
284                                me.receive_state.buffer.advance(skipped);
285                            }
286                            me.receive_state.bytes_to_skip = me
287                                .receive_state
288                                .bytes_to_skip
289                                .saturating_sub(skipped as u32);
290                        }
291                    }
292                    Err(error) => match error.kind() {
293                        ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => {
294                            *me.terminated = true;
295                            return Poll::Ready(None);
296                        }
297                        _ => return Poll::Ready(Some(Err(Error::IoError(error)))),
298                    },
299                }
300            }
301        }
302    }
303}
304
impl<T, Codec, Incoming, Outgoing> FusedStream for Transport<T, Codec, Incoming, Outgoing>
where
    T: AsyncWrite + AsyncRead,
    Codec: kodec::Codec,
    for<'de> Incoming: serde::de::Deserialize<'de>,
    Outgoing: Serialize,
{
    /// Returns `true` once the `terminated` flag has been set, after which
    /// `poll_next` only yields `None`.
    fn is_terminated(&self) -> bool {
        self.terminated
    }
}
316
// Marker trait implementation with no methods.
// NOTE(review): presumably declares that messages are not lost in transit,
// as expected of a TCP-backed transport; confirm against mezzenger's docs.
impl<T, Codec, Incoming, Outgoing> mezzenger::Reliable for Transport<T, Codec, Incoming, Outgoing>
where
    T: AsyncWrite + AsyncRead,
    Codec: kodec::Codec,
    for<'de> Incoming: serde::de::Deserialize<'de>,
    Outgoing: Serialize,
{
}
325
// Marker trait implementation with no methods.
// NOTE(review): presumably declares that messages arrive in the order they
// were sent, as expected of a TCP-backed transport; confirm against
// mezzenger's docs.
impl<T, Codec, Incoming, Outgoing> mezzenger::Order for Transport<T, Codec, Incoming, Outgoing>
where
    T: AsyncWrite + AsyncRead,
    Codec: kodec::Codec,
    for<'de> Incoming: serde::de::Deserialize<'de>,
    Outgoing: Serialize,
{
}
334
#[cfg(test)]
mod tests {
    use futures::{stream, SinkExt, StreamExt};
    use kodec::binary::Codec;
    use mezzenger::{Messages, Receive};
    use tokio::net::{TcpListener, TcpStream};

    use crate::{Error, Transport};

    /// Creates a connected pair of loopback TCP streams and returns them as
    /// `(accepted, connecting)`.
    ///
    /// The listener binds to port 0 so the OS assigns a free port; the tests
    /// therefore neither collide with each other nor with unrelated
    /// processes holding a fixed port.
    async fn connected_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        let right = TcpStream::connect(address).await.unwrap();
        let (left, _) = listener.accept().await.unwrap();
        (left, right)
    }

    /// Messages of different types travel in both directions and arrive in order.
    #[tokio::test]
    async fn test_transport() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, u32, String> =
            Transport::new(left, Codec::default());
        let mut right: Transport<TcpStream, Codec, String, u32> =
            Transport::new(right, Codec::default());

        left.send("Hello World!".to_string()).await.unwrap();
        left.send("Hello World again!".to_string()).await.unwrap();
        right.send(128).await.unwrap();
        right.send(1).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), "Hello World!");
        assert_eq!(right.receive().await.unwrap(), "Hello World again!");
        assert_eq!(left.receive().await.unwrap(), 128);
        assert_eq!(left.receive().await.unwrap(), 1);
    }

    /// Unit messages can be exchanged in both directions.
    #[tokio::test]
    async fn test_unit_message() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, (), ()> = Transport::new(left, Codec::default());
        let mut right: Transport<TcpStream, Codec, (), ()> =
            Transport::new(right, Codec::default());

        left.send(()).await.unwrap();
        left.send(()).await.unwrap();
        right.send(()).await.unwrap();
        right.send(()).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), ());
        assert_eq!(right.receive().await.unwrap(), ());
        assert_eq!(left.receive().await.unwrap(), ());
        assert_eq!(left.receive().await.unwrap(), ());
    }

    /// The receiving stream ends after the sending side is dropped.
    #[tokio::test]
    async fn test_stream() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, (), u32> = Transport::new(left, Codec::default());
        let right: Transport<TcpStream, Codec, u32, ()> = Transport::new(right, Codec::default());

        left.send_all(&mut stream::iter(vec![1, 2, 3].into_iter().map(Ok)))
            .await
            .unwrap();
        drop(left);

        assert_eq!(right.messages().collect::<Vec<u32>>().await, vec![1, 2, 3]);
    }

    /// Oversized messages are rejected on send and skipped on receive without
    /// disturbing the messages around them.
    #[tokio::test]
    async fn test_size_limit() {
        let (left, right) = connected_pair().await;

        // Only `left` has a reduced limit; `right` uses the default.
        let mut left: Transport<TcpStream, Codec, String, String> =
            Transport::new_with_max_message_size(left, Codec::default(), 15);
        let mut right: Transport<TcpStream, Codec, String, String> =
            Transport::new(right, Codec::default());

        // Sending: the oversized message is rejected, the others go through.
        left.send("Hey".to_string()).await.unwrap();
        assert!(matches!(
            left.send("Hello, hello, hello".to_string()).await,
            Err(mezzenger::Error::Other(Error::MessageTooLarge))
        ));
        left.send("Hi".to_string()).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), "Hey");
        assert_eq!(right.receive().await.unwrap(), "Hi");

        // Receiving: a long run of identical oversized messages.
        right.send("Hey".to_string()).await.unwrap();
        for _i in 0..139 {
            right.send("Hello, hello, hello".to_string()).await.unwrap();
        }
        right.send("Hi".to_string()).await.unwrap();

        assert_eq!(left.receive().await.unwrap(), "Hey");
        for _i in 0..139 {
            assert!(matches!(
                left.receive().await,
                Err(mezzenger::Error::Other(Error::MessageTooLarge))
            ));
        }
        assert_eq!(left.receive().await.unwrap(), "Hi");

        // Receiving: oversized messages of alternating lengths.
        right.send("Hey".to_string()).await.unwrap();
        for _i in 0..17 {
            right.send("Hello, hello, hello".to_string()).await.unwrap();
            right
                .send("Hello, hello, hello, hi".to_string())
                .await
                .unwrap();
        }
        right.send("Hi".to_string()).await.unwrap();

        assert_eq!(left.receive().await.unwrap(), "Hey");
        for _i in 0..17 {
            assert!(matches!(
                left.receive().await,
                Err(mezzenger::Error::Other(Error::MessageTooLarge))
            ));
            assert!(matches!(
                left.receive().await,
                Err(mezzenger::Error::Other(Error::MessageTooLarge))
            ));
        }
        assert_eq!(left.receive().await.unwrap(), "Hi");
    }
}
467}