1use crate::message::{Message, MessageOps, Packetizer};
2use crate::message_deserializer::MessageDeserializeError;
3use crate::message_serializer::MessageSerializeError;
4use crate::transport::AsyncTransport;
5use bytes::{Buf, BytesMut};
6use pin_project_lite::pin_project;
7use std::io::{Error as IoError, ErrorKind as IoErrorKind};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
/// Capacity, in bytes, pre-allocated for the outgoing write buffer of a new
/// `TokioTransport` (8 KiB).
const INITIAL_CAPACITY: usize = 8 * 1024;
/// Once at least this many bytes are waiting in the write buffer,
/// `send_poll_ready` flushes before reporting that the transport is ready.
const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
15
pin_project! {
    /// Message transport layered on top of an I/O object `T`.
    ///
    /// Bytes read from `io` are handed to a `Packetizer`, which yields
    /// complete message frames; serialized outgoing messages are staged in
    /// `write_buf` until they are written out to `io`.
    #[derive(Debug)]
    pub struct TokioTransport<T> {
        // Underlying I/O object; structurally pinned so its poll methods can be called.
        #[pin]
        io: T,
        // Accumulates incoming bytes and hands back complete message frames.
        packetizer: Packetizer,
        // Serialized outgoing bytes that have not been written to `io` yet.
        write_buf: BytesMut,
    }
}
25
26impl<T> TokioTransport<T> {
27 pub fn new(io: T) -> Self {
28 Self {
29 io,
30 packetizer: Packetizer::new(),
31 write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
32 }
33 }
34}
35
36impl<T> AsyncTransport for TokioTransport<T>
37where
38 T: AsyncRead + AsyncWrite,
39{
40 type Error = TokioTransportError;
41
42 fn receive_poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<Message, Self::Error>> {
43 let mut this = self.project();
44
45 loop {
46 if let Some(buf) = this.packetizer.next_message() {
47 return Poll::Ready(
48 Message::deserialize_message(buf).map_err(TokioTransportError::Deserialize),
49 );
50 }
51
52 let mut read_buf = ReadBuf::uninit(this.packetizer.spare_capacity_mut());
53 match this.io.as_mut().poll_read(cx, &mut read_buf) {
54 Poll::Ready(Ok(())) if read_buf.filled().is_empty() => {
55 return Poll::Ready(Err(TokioTransportError::Io(
56 IoErrorKind::UnexpectedEof.into(),
57 )))
58 }
59
60 Poll::Ready(Ok(())) => {
61 let len = read_buf.filled().len();
63 unsafe {
64 this.packetizer.bytes_written(len);
65 }
66 }
67
68 Poll::Ready(Err(e)) => return Poll::Ready(Err(TokioTransportError::Io(e))),
69 Poll::Pending => return Poll::Pending,
70 }
71 }
72 }
73
74 fn send_poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
75 if self.write_buf.len() >= BACKPRESSURE_BOUNDARY {
76 self.send_poll_flush(cx)
77 } else {
78 Poll::Ready(Ok(()))
79 }
80 }
81
82 fn send_start(self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> {
83 let this = self.project();
84
85 let msg = msg
86 .serialize_message()
87 .map_err(TokioTransportError::Serialize)?;
88
89 if this.write_buf.is_empty() {
90 *this.write_buf = msg;
91 } else {
92 this.write_buf.extend_from_slice(&msg);
93 }
94
95 Ok(())
96 }
97
98 fn send_poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
99 let mut this = self.project();
100
101 while !this.write_buf.is_empty() {
102 match this.io.as_mut().poll_write(cx, this.write_buf) {
103 Poll::Ready(Ok(0)) => {
104 return Poll::Ready(Err(TokioTransportError::Io(
105 IoErrorKind::WriteZero.into(),
106 )));
107 }
108 Poll::Ready(Ok(n)) => {
109 this.write_buf.advance(n);
110 }
111 Poll::Ready(Err(e)) => return Poll::Ready(Err(TokioTransportError::Io(e))),
112 Poll::Pending => return Poll::Pending,
113 }
114 }
115
116 this.io.poll_flush(cx).map_err(TokioTransportError::Io)
117 }
118}
119
/// Errors reported by `TokioTransport`.
///
/// Every variant is transparent: its message and source come from the
/// wrapped error.
#[derive(Error, Debug)]
pub enum TokioTransportError {
    /// The underlying I/O object failed; also used for an unexpected end of
    /// stream while reading (`UnexpectedEof`) and for a write that accepted
    /// zero bytes (`WriteZero`).
    #[error(transparent)]
    Io(#[from] IoError),

    /// An outgoing message could not be serialized.
    #[error(transparent)]
    Serialize(#[from] MessageSerializeError),

    /// A received frame could not be deserialized into a message.
    #[error(transparent)]
    Deserialize(#[from] MessageDeserializeError),
}