// borer_core/stream/acceptor.rs
use std::{
    io::IoSlice,
    path::PathBuf,
    pin::Pin,
    task::{Context, Poll},
};

use anyhow::Context as _;
use tokio::{
    io::{AsyncRead, AsyncWrite, ReadBuf},
    net::TcpStream,
};
use tokio_rustls::{TlsAcceptor, server};

use crate::tls::{load_certs, load_private_key, make_tls_acceptor};
16#[derive(Clone)]
18pub struct Acceptor {
19 inner: Option<TlsAcceptor>,
20}
21
22#[non_exhaustive]
23#[derive(Debug)]
24pub enum MaybeTlsStream<S> {
26 Plain(S),
27 Tls(Box<server::TlsStream<S>>),
28}
29
30impl Acceptor {
31 pub fn new(cert: Option<String>, key: Option<String>) -> anyhow::Result<Self> {
33 match (cert, key) {
34 (Some(cert), Some(key)) => {
35 let certs = load_certs(PathBuf::from(cert)).context("load_certs failed")?;
36 let key =
37 load_private_key(PathBuf::from(key)).context("load_private_key failed")?;
38 let tls_acceptor = make_tls_acceptor(certs, key)?;
39 Ok(Self {
40 inner: Some(tls_acceptor),
41 })
42 }
43 _ => Ok(Self { inner: None }),
44 }
45 }
46
47 pub async fn accept(&self, ts: TcpStream) -> anyhow::Result<MaybeTlsStream<TcpStream>> {
48 match &self.inner {
49 Some(acceptor) => {
50 let tls_ts = acceptor.accept(ts).await?;
51 Ok(MaybeTlsStream::Tls(Box::new(tls_ts)))
52 }
53 _ => Ok(MaybeTlsStream::Plain(ts)),
54 }
55 }
56}
57
58impl<S> AsyncRead for MaybeTlsStream<S>
59where
60 S: AsyncRead + AsyncWrite + Unpin,
61{
62 fn poll_read(
63 self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 buf: &mut ReadBuf<'_>,
66 ) -> Poll<std::io::Result<()>> {
67 match self.get_mut() {
68 MaybeTlsStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
69 MaybeTlsStream::Tls(s) => Pin::new(s).poll_read(cx, buf),
70 }
71 }
72}
73
74impl<S> AsyncWrite for MaybeTlsStream<S>
75where
76 S: AsyncRead + AsyncWrite + Unpin,
77{
78 fn poll_write(
79 self: Pin<&mut Self>,
80 cx: &mut Context<'_>,
81 buf: &[u8],
82 ) -> Poll<Result<usize, std::io::Error>> {
83 match self.get_mut() {
84 MaybeTlsStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
85 MaybeTlsStream::Tls(s) => Pin::new(s).poll_write(cx, buf),
86 }
87 }
88
89 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
90 match self.get_mut() {
91 MaybeTlsStream::Plain(s) => Pin::new(s).poll_flush(cx),
92 MaybeTlsStream::Tls(s) => Pin::new(s).poll_flush(cx),
93 }
94 }
95
96 fn poll_shutdown(
97 self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 ) -> Poll<Result<(), std::io::Error>> {
100 match self.get_mut() {
101 MaybeTlsStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
102 MaybeTlsStream::Tls(s) => Pin::new(s).poll_shutdown(cx),
103 }
104 }
105}
106
#[cfg(test)]
mod tests {
    use tokio::{
        io::{AsyncReadExt, AsyncWriteExt},
        net::{TcpListener, TcpStream},
    };

    use super::{Acceptor, MaybeTlsStream};

    /// Opens a loopback TCP connection and returns `(server_side, client_side)`.
    async fn tcp_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        // The client connects first; the pending connection waits in the
        // listener's backlog until `accept` picks it up.
        let client = TcpStream::connect(listener.local_addr().unwrap())
            .await
            .unwrap();
        let (server, _peer_addr) = listener.accept().await.unwrap();
        (server, client)
    }

    #[tokio::test]
    async fn accept_without_tls_returns_plain_stream() {
        let acceptor = Acceptor::new(None, None).unwrap();
        // Keep the client half in a named binding (not `_`) so the connection
        // stays open until the end of the test.
        let (server, _client) = tcp_pair().await;

        let accepted = acceptor.accept(server).await.unwrap();

        assert!(matches!(accepted, MaybeTlsStream::Plain(_)));
    }

    #[tokio::test]
    async fn maybe_tls_plain_stream_reads_and_writes() {
        let (server, mut client) = tcp_pair().await;
        let mut wrapped = MaybeTlsStream::Plain(server);

        // server -> client
        wrapped.write_all(b"ping").await.unwrap();
        let mut from_server = [0u8; 4];
        client.read_exact(&mut from_server).await.unwrap();
        assert_eq!(&from_server, b"ping");

        // client -> server
        client.write_all(b"pong").await.unwrap();
        let mut from_client = [0u8; 4];
        wrapped.read_exact(&mut from_client).await.unwrap();
        assert_eq!(&from_client, b"pong");
    }
}