// socket_flow/stream.rs

1use std::pin::Pin;
2use std::task::{Context, Poll};
3use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
4use tokio::net::TcpStream;
5use tokio_rustls::TlsStream as RustTlsStream;
6
7// We need to implement AsyncRead and AsyncWrite for SocketFlowStream,
8// because when we split a TlsStream, it returns a ReadHalf<T>, WriteHalf<T>
9// where T: AsyncRead + AsyncWrite
10// This is a good solution, when you don't want to use a generic for your own functions
11// If we use a generic in accept_async(handshake.rs), we will need to add trait signatures to all the
12// functions that are called inside accept_async recursively.
13pub enum SocketFlowStream {
14    Plain(TcpStream),
15    Secure(RustTlsStream<TcpStream>),
16}
17
18impl AsyncRead for SocketFlowStream {
19    fn poll_read(
20        self: Pin<&mut Self>,
21        cx: &mut Context<'_>,
22        buf: &mut ReadBuf<'_>,
23    ) -> Poll<std::io::Result<()>> {
24        match self.get_mut() {
25            SocketFlowStream::Plain(ref mut s) => Pin::new(s).poll_read(cx, buf),
26            SocketFlowStream::Secure(s) => Pin::new(s).poll_read(cx, buf),
27        }
28    }
29}
30
31impl AsyncWrite for SocketFlowStream {
32    fn poll_write(
33        self: Pin<&mut Self>,
34        cx: &mut Context<'_>,
35        buf: &[u8],
36    ) -> Poll<Result<usize, std::io::Error>> {
37        match self.get_mut() {
38            SocketFlowStream::Plain(ref mut s) => Pin::new(s).poll_write(cx, buf),
39            SocketFlowStream::Secure(s) => Pin::new(s).poll_write(cx, buf),
40        }
41    }
42
43    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
44        match self.get_mut() {
45            SocketFlowStream::Plain(ref mut s) => Pin::new(s).poll_flush(cx),
46            SocketFlowStream::Secure(s) => Pin::new(s).poll_flush(cx),
47        }
48    }
49
50    fn poll_shutdown(
51        self: Pin<&mut Self>,
52        cx: &mut Context<'_>,
53    ) -> Poll<Result<(), std::io::Error>> {
54        match self.get_mut() {
55            SocketFlowStream::Plain(ref mut s) => Pin::new(s).poll_shutdown(cx),
56            SocketFlowStream::Secure(s) => Pin::new(s).poll_shutdown(cx),
57        }
58    }
59}