aldrin_core/
tokio.rs

1use crate::message::{Message, MessageOps, Packetizer};
2use crate::message_deserializer::MessageDeserializeError;
3use crate::message_serializer::MessageSerializeError;
4use crate::transport::AsyncTransport;
5use bytes::{Buf, BytesMut};
6use pin_project_lite::pin_project;
7use std::io::{Error as IoError, ErrorKind as IoErrorKind};
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use thiserror::Error;
11use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
12
/// Capacity, in bytes, that the outgoing write buffer is allocated with (8 KiB).
const INITIAL_CAPACITY: usize = 8192;

/// Number of queued outgoing bytes at which `send_poll_ready` flushes before accepting
/// another message. Tied to `INITIAL_CAPACITY`.
const BACKPRESSURE_BOUNDARY: usize = INITIAL_CAPACITY;
15
16pin_project! {
17    #[derive(Debug)]
18    pub struct TokioTransport<T> {
19        #[pin]
20        io: T,
21        packetizer: Packetizer,
22        write_buf: BytesMut,
23    }
24}
25
26impl<T> TokioTransport<T> {
27    pub fn new(io: T) -> Self {
28        Self {
29            io,
30            packetizer: Packetizer::new(),
31            write_buf: BytesMut::with_capacity(INITIAL_CAPACITY),
32        }
33    }
34}
35
36impl<T> AsyncTransport for TokioTransport<T>
37where
38    T: AsyncRead + AsyncWrite,
39{
40    type Error = TokioTransportError;
41
42    fn receive_poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<Message, Self::Error>> {
43        let mut this = self.project();
44
45        loop {
46            if let Some(buf) = this.packetizer.next_message() {
47                return Poll::Ready(
48                    Message::deserialize_message(buf).map_err(TokioTransportError::Deserialize),
49                );
50            }
51
52            let mut read_buf = ReadBuf::uninit(this.packetizer.spare_capacity_mut());
53            match this.io.as_mut().poll_read(cx, &mut read_buf) {
54                Poll::Ready(Ok(())) if read_buf.filled().is_empty() => {
55                    return Poll::Ready(Err(TokioTransportError::Io(
56                        IoErrorKind::UnexpectedEof.into(),
57                    )))
58                }
59
60                Poll::Ready(Ok(())) => {
61                    // SAFETY: The first len bytes have been initialized.
62                    let len = read_buf.filled().len();
63                    unsafe {
64                        this.packetizer.bytes_written(len);
65                    }
66                }
67
68                Poll::Ready(Err(e)) => return Poll::Ready(Err(TokioTransportError::Io(e))),
69                Poll::Pending => return Poll::Pending,
70            }
71        }
72    }
73
74    fn send_poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
75        if self.write_buf.len() >= BACKPRESSURE_BOUNDARY {
76            self.send_poll_flush(cx)
77        } else {
78            Poll::Ready(Ok(()))
79        }
80    }
81
82    fn send_start(self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> {
83        let this = self.project();
84
85        let msg = msg
86            .serialize_message()
87            .map_err(TokioTransportError::Serialize)?;
88
89        if this.write_buf.is_empty() {
90            *this.write_buf = msg;
91        } else {
92            this.write_buf.extend_from_slice(&msg);
93        }
94
95        Ok(())
96    }
97
98    fn send_poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
99        let mut this = self.project();
100
101        while !this.write_buf.is_empty() {
102            match this.io.as_mut().poll_write(cx, this.write_buf) {
103                Poll::Ready(Ok(0)) => {
104                    return Poll::Ready(Err(TokioTransportError::Io(
105                        IoErrorKind::WriteZero.into(),
106                    )));
107                }
108                Poll::Ready(Ok(n)) => {
109                    this.write_buf.advance(n);
110                }
111                Poll::Ready(Err(e)) => return Poll::Ready(Err(TokioTransportError::Io(e))),
112                Poll::Pending => return Poll::Pending,
113            }
114        }
115
116        this.io.poll_flush(cx).map_err(TokioTransportError::Io)
117    }
118}
119
120#[derive(Error, Debug)]
121pub enum TokioTransportError {
122    #[error(transparent)]
123    Io(#[from] IoError),
124
125    #[error(transparent)]
126    Serialize(#[from] MessageSerializeError),
127
128    #[error(transparent)]
129    Deserialize(#[from] MessageDeserializeError),
130}