Skip to main content

borer_core/stream/
acceptor.rs

1use std::{
2    path::PathBuf,
3    pin::Pin,
4    task::{Context, Poll},
5};
6
7use anyhow::Context as _;
8use tokio::{
9    io::{AsyncRead, AsyncWrite, ReadBuf},
10    net::TcpStream,
11};
12use tokio_rustls::{TlsAcceptor, server};
13
14use crate::tls::{load_certs, load_private_key, make_tls_acceptor};
15
16/// Accepts plain TCP or upgrades accepted sockets to TLS.
17#[derive(Clone)]
18pub struct Acceptor {
19    inner: Option<TlsAcceptor>,
20}
21
22#[non_exhaustive]
23#[derive(Debug)]
24/// Stream returned by [`Acceptor`], either plain TCP or TLS-wrapped TCP.
25pub enum MaybeTlsStream<S> {
26    Plain(S),
27    Tls(Box<server::TlsStream<S>>),
28}
29
30impl Acceptor {
31    /// Create a plain or TLS acceptor from optional certificate and key paths.
32    pub fn new(cert: Option<String>, key: Option<String>) -> anyhow::Result<Self> {
33        match (cert, key) {
34            (Some(cert), Some(key)) => {
35                let certs = load_certs(PathBuf::from(cert)).context("load_certs failed")?;
36                let key =
37                    load_private_key(PathBuf::from(key)).context("load_private_key failed")?;
38                let tls_acceptor = make_tls_acceptor(certs, key)?;
39                Ok(Self {
40                    inner: Some(tls_acceptor),
41                })
42            }
43            _ => Ok(Self { inner: None }),
44        }
45    }
46
47    pub async fn accept(&self, ts: TcpStream) -> anyhow::Result<MaybeTlsStream<TcpStream>> {
48        match &self.inner {
49            Some(acceptor) => {
50                let tls_ts = acceptor.accept(ts).await?;
51                Ok(MaybeTlsStream::Tls(Box::new(tls_ts)))
52            }
53            _ => Ok(MaybeTlsStream::Plain(ts)),
54        }
55    }
56}
57
58impl<S> AsyncRead for MaybeTlsStream<S>
59where
60    S: AsyncRead + AsyncWrite + Unpin,
61{
62    fn poll_read(
63        self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65        buf: &mut ReadBuf<'_>,
66    ) -> Poll<std::io::Result<()>> {
67        match self.get_mut() {
68            MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
69            MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
70        }
71    }
72}
73
74impl<S> AsyncWrite for MaybeTlsStream<S>
75where
76    S: AsyncRead + AsyncWrite + Unpin,
77{
78    fn poll_write(
79        self: Pin<&mut Self>,
80        cx: &mut Context<'_>,
81        buf: &[u8],
82    ) -> Poll<Result<usize, std::io::Error>> {
83        match self.get_mut() {
84            MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
85            MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
86        }
87    }
88
89    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
90        match self.get_mut() {
91            MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
92            MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
93        }
94    }
95
96    fn poll_shutdown(
97        self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99    ) -> Poll<Result<(), std::io::Error>> {
100        match self.get_mut() {
101            MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
102            MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
103        }
104    }
105}
106
#[cfg(test)]
mod tests {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::{TcpListener, TcpStream};

    use super::{Acceptor, MaybeTlsStream};

    /// Builds a connected loopback pair, returned as `(server_side, client_side)`.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let local = listener.local_addr().unwrap();
        let client_side = TcpStream::connect(local).await.unwrap();
        let server_side = listener.accept().await.unwrap().0;
        (server_side, client_side)
    }

    #[tokio::test]
    async fn accept_without_tls_returns_plain_stream() {
        let plain_acceptor = Acceptor::new(None, None).unwrap();
        let (server_side, _client_side) = tcp_pair().await;

        let accepted = plain_acceptor.accept(server_side).await.unwrap();
        let is_plain = matches!(accepted, MaybeTlsStream::Plain(_));

        assert!(is_plain);
    }

    #[tokio::test]
    async fn maybe_tls_plain_stream_reads_and_writes() {
        let (server_side, mut peer) = tcp_pair().await;
        let mut wrapped = MaybeTlsStream::Plain(server_side);

        // Server -> client direction.
        wrapped.write_all(b"ping").await.unwrap();
        let mut at_peer = [0u8; 4];
        peer.read_exact(&mut at_peer).await.unwrap();
        assert_eq!(&at_peer, b"ping");

        // Client -> server direction.
        peer.write_all(b"pong").await.unwrap();
        let mut at_server = [0u8; 4];
        wrapped.read_exact(&mut at_server).await.unwrap();
        assert_eq!(&at_server, b"pong");
    }
}