Skip to main content

heliosdb_proxy/backend/
stream.rs

//! A pluggable I/O stream — plain TCP or TLS-wrapped TCP.
//!
//! Used by the backend client so the rest of the module code (auth,
//! query, etc.) does not need to know whether TLS is enabled. Implements
//! `AsyncRead`/`AsyncWrite` by delegating to the inner variant.
6
7use std::io;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
11use tokio::net::TcpStream;
12use tokio_rustls::client::TlsStream;
13
14/// A backend connection stream.
15pub enum Stream {
16    Plain(TcpStream),
17    Tls(TlsStream<TcpStream>),
18}
19
20impl Stream {
21    /// Expose the peer address if available (best-effort — TLS hides it).
22    pub fn peer_addr(&self) -> io::Result<std::net::SocketAddr> {
23        match self {
24            Stream::Plain(s) => s.peer_addr(),
25            Stream::Tls(s) => s.get_ref().0.peer_addr(),
26        }
27    }
28
29    /// Return `true` if this connection is encrypted.
30    pub fn is_tls(&self) -> bool {
31        matches!(self, Stream::Tls(_))
32    }
33}
34
35impl AsyncRead for Stream {
36    fn poll_read(
37        self: Pin<&mut Self>,
38        cx: &mut Context<'_>,
39        buf: &mut ReadBuf<'_>,
40    ) -> Poll<io::Result<()>> {
41        match self.get_mut() {
42            Stream::Plain(s) => Pin::new(s).poll_read(cx, buf),
43            Stream::Tls(s) => Pin::new(s).poll_read(cx, buf),
44        }
45    }
46}
47
48impl AsyncWrite for Stream {
49    fn poll_write(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52        buf: &[u8],
53    ) -> Poll<io::Result<usize>> {
54        match self.get_mut() {
55            Stream::Plain(s) => Pin::new(s).poll_write(cx, buf),
56            Stream::Tls(s) => Pin::new(s).poll_write(cx, buf),
57        }
58    }
59
60    fn poll_flush(
61        self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63    ) -> Poll<io::Result<()>> {
64        match self.get_mut() {
65            Stream::Plain(s) => Pin::new(s).poll_flush(cx),
66            Stream::Tls(s) => Pin::new(s).poll_flush(cx),
67        }
68    }
69
70    fn poll_shutdown(
71        self: Pin<&mut Self>,
72        cx: &mut Context<'_>,
73    ) -> Poll<io::Result<()>> {
74        match self.get_mut() {
75            Stream::Plain(s) => Pin::new(s).poll_shutdown(cx),
76            Stream::Tls(s) => Pin::new(s).poll_shutdown(cx),
77        }
78    }
79
80    fn poll_write_vectored(
81        self: Pin<&mut Self>,
82        cx: &mut Context<'_>,
83        bufs: &[io::IoSlice<'_>],
84    ) -> Poll<io::Result<usize>> {
85        match self.get_mut() {
86            Stream::Plain(s) => Pin::new(s).poll_write_vectored(cx, bufs),
87            Stream::Tls(s) => Pin::new(s).poll_write_vectored(cx, bufs),
88        }
89    }
90
91    fn is_write_vectored(&self) -> bool {
92        match self {
93            Stream::Plain(s) => s.is_write_vectored(),
94            Stream::Tls(s) => s.is_write_vectored(),
95        }
96    }
97}