1use std::{
22 fmt::{Debug, Display},
23 io::ErrorKind,
24 marker::PhantomData,
25 pin::Pin,
26 task::{Context, Poll},
27};
28
29use bytes::{Buf, BufMut, BytesMut};
30use futures::{ready, stream::FusedStream, Sink, Stream};
31use kodec::{Decode, Encode};
32use pin_project::pin_project;
33use serde::Serialize;
34use tokio::io::{AsyncRead, AsyncWrite};
35use tokio_util::io::{poll_read_buf, poll_write_buf};
36
/// Default upper bound, in bytes, for the encoded payload of a single message
/// (64 KiB). Used by [`Transport::new`].
pub const DEFAULT_MAX_MESSAGE_SIZE: u32 = 64 * 1024;
38
/// Errors produced by the transport.
///
/// `SerErr` is the error type reported by the codec when encoding an outgoing
/// message and `DeErr` the one reported when decoding an incoming message.
#[derive(Debug)]
pub enum Error<SerErr, DeErr> {
    /// The message payload exceeds the configured maximum message size.
    MessageTooLarge,
    /// The codec failed to encode an outgoing message.
    SerializationError(SerErr),
    /// The codec failed to decode an incoming message.
    DeserializationError(DeErr),
    /// The underlying byte stream reported an I/O error.
    IoError(std::io::Error),
}
46
47impl<SerializationError, DeserializationError> Display
48 for Error<SerializationError, DeserializationError>
49where
50 SerializationError: Display,
51 DeserializationError: Display,
52{
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Error::MessageTooLarge => write!(f, "message was too large"),
56 Error::SerializationError(error) => write!(f, "failed to serialize message: {error}"),
57 Error::DeserializationError(error) => {
58 write!(f, "failed to deserialize message: {error}")
59 }
60 Error::IoError(error) => write!(f, "IO error occurred: {error}"),
61 }
62 }
63}
64
65impl<SerializationError, DeserializationError> std::error::Error
66 for Error<SerializationError, DeserializationError>
67where
68 SerializationError: Debug + Display,
69 DeserializationError: Debug + Display,
70{
71}
72
73struct ReceiveState {
74 pub buffer: BytesMut,
75 pub message_size: u32,
76 pub receiving_size: bool,
77 pub bytes_to_receive: i64,
78 pub bytes_to_skip: u32,
79}
80
81impl ReceiveState {
82 fn new() -> Self {
83 ReceiveState {
84 buffer: BytesMut::new(),
85 message_size: 0,
86 receiving_size: true,
87 bytes_to_receive: 4,
88 bytes_to_skip: 0,
89 }
90 }
91}
92
93#[pin_project]
97pub struct Transport<T, Codec, Incoming, Outgoing>
98where
99 T: AsyncWrite + AsyncRead,
100 Codec: kodec::Codec,
101 for<'de> Incoming: serde::de::Deserialize<'de>,
102 Outgoing: Serialize,
103{
104 #[pin]
105 inner: T,
106 send_buffer: BytesMut,
107 receive_state: ReceiveState,
108 codec: Codec,
109 terminated: bool,
110 max_message_size: u32,
111 _incoming: PhantomData<Incoming>,
112 _outgoing: PhantomData<Outgoing>,
113}
114
115impl<T, Codec, Incoming, Outgoing> Transport<T, Codec, Incoming, Outgoing>
116where
117 T: AsyncWrite + AsyncRead,
118 Codec: kodec::Codec,
119 for<'de> Incoming: serde::de::Deserialize<'de>,
120 Outgoing: Serialize,
121{
122 pub fn new(transport: T, codec: Codec) -> Self {
128 Transport::new_with_max_message_size(transport, codec, DEFAULT_MAX_MESSAGE_SIZE)
129 }
130
131 pub fn new_with_max_message_size(transport: T, codec: Codec, max_message_size: u32) -> Self {
137 Transport {
138 inner: transport,
139 codec,
140 send_buffer: BytesMut::new(),
141 receive_state: ReceiveState::new(),
142 terminated: false,
143 max_message_size,
144 _incoming: PhantomData,
145 _outgoing: PhantomData,
146 }
147 }
148}
149
150impl<T, Codec, Incoming, Outgoing> Sink<Outgoing> for Transport<T, Codec, Incoming, Outgoing>
151where
152 T: AsyncWrite + AsyncRead,
153 Codec: kodec::Codec,
154 for<'de> Incoming: serde::de::Deserialize<'de>,
155 Outgoing: Serialize,
156{
157 type Error = mezzenger::Error<Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
158
159 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
160 Poll::Ready(Ok(()))
161 }
162
163 fn start_send(self: Pin<&mut Self>, item: Outgoing) -> Result<(), Self::Error> {
164 if self.terminated {
165 Err(mezzenger::Error::Closed)
166 } else {
167 let me = self.project();
168 let size_position = me.send_buffer.len();
169 me.send_buffer.put_u32(0);
170 let current_length = me.send_buffer.len();
171 me.codec
172 .encode(me.send_buffer.writer(), &item)
173 .map_err(Error::SerializationError)
174 .map_err(mezzenger::Error::Other)?;
175 let message_size = me.send_buffer.len() - current_length;
176 if message_size > *me.max_message_size as usize {
177 me.send_buffer.truncate(size_position);
178 Err(mezzenger::Error::Other(Error::MessageTooLarge))
179 } else {
180 let size_slice = &mut me.send_buffer[size_position..(size_position + 4)];
181 let message_size = message_size as u32;
182 size_slice.swap_with_slice(&mut message_size.to_be_bytes());
183 Ok(())
184 }
185 }
186 }
187
188 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 let mut me = self.project();
190
191 let result = if me.send_buffer.is_empty() {
192 ready!(me.inner.poll_flush(cx))
193 } else {
194 ready!(poll_write_buf(me.inner.as_mut(), cx, me.send_buffer)).map(|_| ())
195 }
196 .map_err(|error| match error.kind() {
197 ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => mezzenger::Error::Closed,
198 _ => mezzenger::Error::Other(Error::IoError(error)),
199 });
200
201 Poll::Ready(result)
202 }
203
204 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
205 let me = self.project();
206 let result = ready!(me.inner.poll_shutdown(cx)).map_err(|error| match error.kind() {
207 ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => mezzenger::Error::Closed,
208 _ => mezzenger::Error::Other(Error::IoError(error)),
209 });
210 Poll::Ready(result)
211 }
212}
213
214impl<T, Codec, Incoming, Outgoing> Stream for Transport<T, Codec, Incoming, Outgoing>
215where
216 T: AsyncWrite + AsyncRead,
217 Codec: kodec::Codec,
218 for<'de> Incoming: serde::de::Deserialize<'de>,
219 Outgoing: Serialize,
220{
221 type Item = Result<Incoming, Error<<Codec as Encode>::Error, <Codec as Decode>::Error>>;
222
223 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
224 if self.terminated {
225 return Poll::Ready(None);
226 }
227
228 let mut me = self.project();
229 loop {
230 if me.receive_state.bytes_to_receive <= 0 {
231 if me.receive_state.receiving_size {
232 let message_size = me.receive_state.buffer.get_u32();
233 me.receive_state.message_size = message_size;
234 me.receive_state.bytes_to_receive += message_size as i64;
235 if message_size > *me.max_message_size {
236 me.receive_state.bytes_to_receive += 4;
237 if me.receive_state.bytes_to_receive > 0 {
238 me.receive_state.bytes_to_skip = message_size;
239 } else {
240 me.receive_state.buffer.advance(message_size as usize);
241 }
242 return Poll::Ready(Some(Err(Error::MessageTooLarge)));
243 } else {
244 me.receive_state.receiving_size = false;
245 }
246 } else {
247 let message_size = me.receive_state.message_size as usize;
248 let message = &me.receive_state.buffer[0..message_size];
249 let result: Result<Incoming, _> = me.codec.decode(message);
250 me.receive_state.buffer.advance(message_size);
251 me.receive_state.receiving_size = true;
252 me.receive_state.bytes_to_receive += 4;
253 return {
254 match result {
255 Ok(message) => Poll::Ready(Some(Ok(message))),
256 Err(error) => {
257 Poll::Ready(Some(Err(Error::DeserializationError(error))))
258 }
259 }
260 };
261 }
262 } else {
263 let result = ready!(poll_read_buf(
264 me.inner.as_mut(),
265 cx,
266 &mut me.receive_state.buffer
267 ));
268 match result {
269 Ok(bytes_read) => {
270 if bytes_read == 0 {
271 *me.terminated = true;
272 return Poll::Ready(None);
273 }
274 me.receive_state.bytes_to_receive = me
275 .receive_state
276 .bytes_to_receive
277 .saturating_sub_unsigned(bytes_read as u64);
278 if me.receive_state.bytes_to_skip > 0 {
279 let buffer_len = me.receive_state.buffer.len();
280 let skipped = buffer_len.min(me.receive_state.bytes_to_skip as usize);
281 if skipped == buffer_len {
282 me.receive_state.buffer.clear();
283 } else {
284 me.receive_state.buffer.advance(skipped);
285 }
286 me.receive_state.bytes_to_skip = me
287 .receive_state
288 .bytes_to_skip
289 .saturating_sub(skipped as u32);
290 }
291 }
292 Err(error) => match error.kind() {
293 ErrorKind::ConnectionReset | ErrorKind::ConnectionAborted => {
294 *me.terminated = true;
295 return Poll::Ready(None);
296 }
297 _ => return Poll::Ready(Some(Err(Error::IoError(error)))),
298 },
299 }
300 }
301 }
302 }
303}
304
305impl<T, Codec, Incoming, Outgoing> FusedStream for Transport<T, Codec, Incoming, Outgoing>
306where
307 T: AsyncWrite + AsyncRead,
308 Codec: kodec::Codec,
309 for<'de> Incoming: serde::de::Deserialize<'de>,
310 Outgoing: Serialize,
311{
312 fn is_terminated(&self) -> bool {
313 self.terminated
314 }
315}
316
317impl<T, Codec, Incoming, Outgoing> mezzenger::Reliable for Transport<T, Codec, Incoming, Outgoing>
318where
319 T: AsyncWrite + AsyncRead,
320 Codec: kodec::Codec,
321 for<'de> Incoming: serde::de::Deserialize<'de>,
322 Outgoing: Serialize,
323{
324}
325
326impl<T, Codec, Incoming, Outgoing> mezzenger::Order for Transport<T, Codec, Incoming, Outgoing>
327where
328 T: AsyncWrite + AsyncRead,
329 Codec: kodec::Codec,
330 for<'de> Incoming: serde::de::Deserialize<'de>,
331 Outgoing: Serialize,
332{
333}
334
#[cfg(test)]
mod tests {
    use futures::{stream, SinkExt, StreamExt};
    use kodec::binary::Codec;
    use mezzenger::{Messages, Receive};
    use tokio::net::{TcpListener, TcpStream};

    use crate::{Error, Transport};

    /// Returns the two ends `(accepted, connecting)` of a fresh loopback TCP
    /// connection.
    ///
    /// The listener binds to port 0 so the operating system picks a free
    /// port; fixed port numbers make the tests fail whenever the port is
    /// already taken on the machine running them.
    async fn connected_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let address = listener.local_addr().unwrap();
        let right = TcpStream::connect(address).await.unwrap();
        let (left, _) = listener.accept().await.unwrap();
        (left, right)
    }

    /// Messages of different types flow in both directions and keep their
    /// order.
    #[tokio::test]
    async fn test_transport() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, u32, String> =
            Transport::new(left, Codec::default());
        let mut right: Transport<TcpStream, Codec, String, u32> =
            Transport::new(right, Codec::default());

        left.send("Hello World!".to_string()).await.unwrap();
        left.send("Hello World again!".to_string()).await.unwrap();
        right.send(128).await.unwrap();
        right.send(1).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), "Hello World!");
        assert_eq!(right.receive().await.unwrap(), "Hello World again!");
        assert_eq!(left.receive().await.unwrap(), 128);
        assert_eq!(left.receive().await.unwrap(), 1);
    }

    /// Unit messages are delivered as individual messages as well.
    #[tokio::test]
    async fn test_unit_message() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, (), ()> = Transport::new(left, Codec::default());
        let mut right: Transport<TcpStream, Codec, (), ()> =
            Transport::new(right, Codec::default());

        left.send(()).await.unwrap();
        left.send(()).await.unwrap();
        right.send(()).await.unwrap();
        right.send(()).await.unwrap();

        // `let ()` checks that a unit message was received without comparing
        // unit values.
        let () = right.receive().await.unwrap();
        let () = right.receive().await.unwrap();
        let () = left.receive().await.unwrap();
        let () = left.receive().await.unwrap();
    }

    /// The message stream ends once the sending side is dropped.
    #[tokio::test]
    async fn test_stream() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, (), u32> = Transport::new(left, Codec::default());
        let right: Transport<TcpStream, Codec, u32, ()> = Transport::new(right, Codec::default());

        left.send_all(&mut stream::iter(vec![1, 2, 3].into_iter().map(Ok)))
            .await
            .unwrap();
        drop(left);

        assert_eq!(right.messages().collect::<Vec<u32>>().await, vec![1, 2, 3]);
    }

    /// Oversized messages are rejected when sent, and skipped (reported as
    /// `MessageTooLarge`) when received, without disturbing the messages
    /// around them.
    #[tokio::test]
    async fn test_size_limit() {
        let (left, right) = connected_pair().await;

        let mut left: Transport<TcpStream, Codec, String, String> =
            Transport::new_with_max_message_size(left, Codec::default(), 15);
        let mut right: Transport<TcpStream, Codec, String, String> =
            Transport::new(right, Codec::default());

        // Sending side: the oversized message is refused, its neighbours pass.
        left.send("Hey".to_string()).await.unwrap();
        assert!(matches!(
            left.send("Hello, hello, hello".to_string()).await,
            Err(mezzenger::Error::Other(Error::MessageTooLarge))
        ));
        left.send("Hi".to_string()).await.unwrap();

        assert_eq!(right.receive().await.unwrap(), "Hey");
        assert_eq!(right.receive().await.unwrap(), "Hi");

        // Receiving side: a long run of identical oversized messages.
        right.send("Hey".to_string()).await.unwrap();
        for _ in 0..139 {
            right.send("Hello, hello, hello".to_string()).await.unwrap();
        }
        right.send("Hi".to_string()).await.unwrap();

        assert_eq!(left.receive().await.unwrap(), "Hey");
        for _ in 0..139 {
            assert!(matches!(
                left.receive().await,
                Err(mezzenger::Error::Other(Error::MessageTooLarge))
            ));
        }
        assert_eq!(left.receive().await.unwrap(), "Hi");

        // Receiving side: oversized messages of alternating lengths.
        right.send("Hey".to_string()).await.unwrap();
        for _ in 0..17 {
            right.send("Hello, hello, hello".to_string()).await.unwrap();
            right
                .send("Hello, hello, hello, hi".to_string())
                .await
                .unwrap();
        }
        right.send("Hi".to_string()).await.unwrap();

        assert_eq!(left.receive().await.unwrap(), "Hey");
        for _ in 0..17 {
            assert!(matches!(
                left.receive().await,
                Err(mezzenger::Error::Other(Error::MessageTooLarge))
            ));
            assert!(matches!(
                left.receive().await,
                Err(mezzenger::Error::Other(Error::MessageTooLarge))
            ));
        }
        assert_eq!(left.receive().await.unwrap(), "Hi");
    }
}