Skip to main content

heliosdb_proxy/backend/
stream.rs

1//! A pluggable I/O stream — plain TCP or TLS-wrapped TCP.
2//!
3//! Used by the backend client so the rest of the module code (auth,
4//! query, etc.) stays ignorant of whether TLS is on. Implements
5//! `AsyncRead`/`AsyncWrite` by delegating to the inner variant.
6
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::net::TcpStream;
12use tokio_rustls::client::TlsStream;
13
14/// A backend connection stream.
15#[allow(clippy::large_enum_variant)]
16pub enum Stream {
17    Plain(TcpStream),
18    Tls(TlsStream<TcpStream>),
19}
20
21impl Stream {
22    /// Expose the peer address if available (best-effort — TLS hides it).
23    pub fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
24        match self {
25            Stream::Plain(s) => s.peer_addr(),
26            Stream::Tls(s) => s.get_ref().0.peer_addr(),
27        }
28    }
29
30    /// Return `true` if this connection is encrypted.
31    pub fn is_tls(&self) -> bool {
32        matches!(self, Stream::Tls(_))
33    }
34}
35
36impl AsyncRead for Stream {
37    fn poll_read(
38        self: Pin<&mut Self>,
39        cx: &mut Context<'_>,
40        buf: &mut ReadBuf<'_>,
41    ) -> Poll<io::Result<()>> {
42        match self.get_mut() {
43            Stream::Plain(s) => Pin::new(s).poll_read(cx, buf),
44            Stream::Tls(s) => Pin::new(s).poll_read(cx, buf),
45        }
46    }
47}
48
49impl AsyncWrite for Stream {
50    fn poll_write(
51        self: Pin<&mut Self>,
52        cx: &mut Context<'_>,
53        buf: &[u8],
54    ) -> Poll<io::Result<usize>> {
55        match self.get_mut() {
56            Stream::Plain(s) => Pin::new(s).poll_write(cx, buf),
57            Stream::Tls(s) => Pin::new(s).poll_write(cx, buf),
58        }
59    }
60
61    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
62        match self.get_mut() {
63            Stream::Plain(s) => Pin::new(s).poll_flush(cx),
64            Stream::Tls(s) => Pin::new(s).poll_flush(cx),
65        }
66    }
67
68    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
69        match self.get_mut() {
70            Stream::Plain(s) => Pin::new(s).poll_shutdown(cx),
71            Stream::Tls(s) => Pin::new(s).poll_shutdown(cx),
72        }
73    }
74
75    fn poll_write_vectored(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        bufs: &[io::IoSlice<'_>],
79    ) -> Poll<io::Result<usize>> {
80        match self.get_mut() {
81            Stream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
82            Stream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
83        }
84    }
85
86    fn is_write_vectored(&self) -> bool {
87        match self {
88            Stream::Plain(s) => s.is_write_vectored(),
89            Stream::Tls(s) => s.is_write_vectored(),
90        }
91    }
92}