heliosdb_proxy/backend/
stream.rs1use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::net::TcpStream;
12use tokio_rustls::client::TlsStream;
13
14#[allow(clippy::large_enum_variant)]
16pub enum Stream {
17 Plain(TcpStream),
18 Tls(TlsStream<TcpStream>),
19}
20
21impl Stream {
22 pub fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
24 match self {
25 Stream::Plain(s) => s.peer_addr(),
26 Stream::Tls(s) => s.get_ref().0.peer_addr(),
27 }
28 }
29
30 pub fn is_tls(&self) -> bool {
32 matches!(self, Stream::Tls(_))
33 }
34}
35
36impl AsyncRead for Stream {
37 fn poll_read(
38 self: Pin<&mut Self>,
39 cx: &mut Context<'_>,
40 buf: &mut ReadBuf<'_>,
41 ) -> Poll<io::Result<()>> {
42 match self.get_mut() {
43 Stream::Plain(s) => Pin::new(s).poll_read(cx, buf),
44 Stream::Tls(s) => Pin::new(s).poll_read(cx, buf),
45 }
46 }
47}
48
49impl AsyncWrite for Stream {
50 fn poll_write(
51 self: Pin<&mut Self>,
52 cx: &mut Context<'_>,
53 buf: &[u8],
54 ) -> Poll<io::Result<usize>> {
55 match self.get_mut() {
56 Stream::Plain(s) => Pin::new(s).poll_write(cx, buf),
57 Stream::Tls(s) => Pin::new(s).poll_write(cx, buf),
58 }
59 }
60
61 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
62 match self.get_mut() {
63 Stream::Plain(s) => Pin::new(s).poll_flush(cx),
64 Stream::Tls(s) => Pin::new(s).poll_flush(cx),
65 }
66 }
67
68 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
69 match self.get_mut() {
70 Stream::Plain(s) => Pin::new(s).poll_shutdown(cx),
71 Stream::Tls(s) => Pin::new(s).poll_shutdown(cx),
72 }
73 }
74
75 fn poll_write_vectored(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 bufs: &[io::IoSlice<'_>],
79 ) -> Poll<io::Result<usize>> {
80 match self.get_mut() {
81 Stream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
82 Stream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
83 }
84 }
85
86 fn is_write_vectored(&self) -> bool {
87 match self {
88 Stream::Plain(s) => s.is_write_vectored(),
89 Stream::Tls(s) => s.is_write_vectored(),
90 }
91 }
92}