heliosdb_proxy/backend/
stream.rs1use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::net::TcpStream;
12use tokio_rustls::client::TlsStream;
13
14pub enum Stream {
16 Plain(TcpStream),
17 Tls(TlsStream<TcpStream>),
18}
19
20impl Stream {
21 pub fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
23 match self {
24 Stream::Plain(s) => s.peer_addr(),
25 Stream::Tls(s) => s.get_ref().0.peer_addr(),
26 }
27 }
28
29 pub fn is_tls(&self) -> bool {
31 matches!(self, Stream::Tls(_))
32 }
33}
34
35impl AsyncRead for Stream {
36 fn poll_read(
37 self: Pin<&mut Self>,
38 cx: &mut Context<'_>,
39 buf: &mut ReadBuf<'_>,
40 ) -> Poll<io::Result<()>> {
41 match self.get_mut() {
42 Stream::Plain(s) => Pin::new(s).poll_read(cx, buf),
43 Stream::Tls(s) => Pin::new(s).poll_read(cx, buf),
44 }
45 }
46}
47
48impl AsyncWrite for Stream {
49 fn poll_write(
50 self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 buf: &[u8],
53 ) -> Poll<io::Result<usize>> {
54 match self.get_mut() {
55 Stream::Plain(s) => Pin::new(s).poll_write(cx, buf),
56 Stream::Tls(s) => Pin::new(s).poll_write(cx, buf),
57 }
58 }
59
60 fn poll_flush(
61 self: Pin<&mut Self>,
62 cx: &mut Context<'_>,
63 ) -> Poll<io::Result<()>> {
64 match self.get_mut() {
65 Stream::Plain(s) => Pin::new(s).poll_flush(cx),
66 Stream::Tls(s) => Pin::new(s).poll_flush(cx),
67 }
68 }
69
70 fn poll_shutdown(
71 self: Pin<&mut Self>,
72 cx: &mut Context<'_>,
73 ) -> Poll<io::Result<()>> {
74 match self.get_mut() {
75 Stream::Plain(s) => Pin::new(s).poll_shutdown(cx),
76 Stream::Tls(s) => Pin::new(s).poll_shutdown(cx),
77 }
78 }
79
80 fn poll_write_vectored(
81 self: Pin<&mut Self>,
82 cx: &mut Context<'_>,
83 bufs: &[io::IoSlice<'_>],
84 ) -> Poll<io::Result<usize>> {
85 match self.get_mut() {
86 Stream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
87 Stream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
88 }
89 }
90
91 fn is_write_vectored(&self) -> bool {
92 match self {
93 Stream::Plain(s) => s.is_write_vectored(),
94 Stream::Tls(s) => s.is_write_vectored(),
95 }
96 }
97}