heliosdb_proxy/
client_tls.rs1use std::pin::Pin;
12use std::sync::Arc;
13use std::task::{Context, Poll};
14
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16use tokio::net::TcpStream;
17use tokio_rustls::server::TlsStream;
18use tokio_rustls::TlsAcceptor;
19
20use crate::config::TlsConfig;
21
/// A client connection that is either plain TCP or TLS layered over TCP.
///
/// The `AsyncRead` and `AsyncWrite` impls in this file delegate to whichever
/// variant is active, so code holding a `ClientStream` can perform I/O
/// without matching on the transport.
pub enum ClientStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// Server-side TLS session (`tokio_rustls::server::TlsStream`) wrapping
    /// the TCP connection.
    // NOTE(review): the Box presumably keeps the enum small, since the TLS
    // stream is likely much larger than a bare `TcpStream` — confirm.
    Tls(Box<TlsStream<TcpStream>>),
}
30
31impl ClientStream {
32 pub fn peer_cert_present(&self) -> bool {
35 match self {
36 ClientStream::Plain(_) => false,
37 ClientStream::Tls(s) => s
38 .get_ref()
39 .1
40 .peer_certificates()
41 .map(|c| !c.is_empty())
42 .unwrap_or(false),
43 }
44 }
45}
46
47impl AsyncRead for ClientStream {
48 fn poll_read(
49 self: Pin<&mut Self>,
50 cx: &mut Context<'_>,
51 buf: &mut ReadBuf<'_>,
52 ) -> Poll<std::io::Result<()>> {
53 match self.get_mut() {
54 ClientStream::Plain(s) => Pin::new(s).poll_read(cx, buf),
55 ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_read(cx, buf),
56 }
57 }
58}
59
60impl AsyncWrite for ClientStream {
61 fn poll_write(
62 self: Pin<&mut Self>,
63 cx: &mut Context<'_>,
64 buf: &[u8],
65 ) -> Poll<std::io::Result<usize>> {
66 match self.get_mut() {
67 ClientStream::Plain(s) => Pin::new(s).poll_write(cx, buf),
68 ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_write(cx, buf),
69 }
70 }
71
72 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
73 match self.get_mut() {
74 ClientStream::Plain(s) => Pin::new(s).poll_flush(cx),
75 ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_flush(cx),
76 }
77 }
78
79 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
80 match self.get_mut() {
81 ClientStream::Plain(s) => Pin::new(s).poll_shutdown(cx),
82 ClientStream::Tls(s) => Pin::new(s.as_mut()).poll_shutdown(cx),
83 }
84 }
85}
86
87pub fn build_tls_acceptor(tls: &TlsConfig) -> Result<TlsAcceptor, String> {
91 use rustls::pki_types::{CertificateDer, PrivateKeyDer};
92
93 let cert_chain: Vec<CertificateDer<'static>> = {
94 let data = std::fs::read(&tls.cert_path)
95 .map_err(|e| format!("reading cert {}: {}", tls.cert_path, e))?;
96 rustls_pemfile::certs(&mut &data[..])
97 .collect::<Result<Vec<_>, _>>()
98 .map_err(|e| format!("parsing cert {}: {}", tls.cert_path, e))?
99 };
100 if cert_chain.is_empty() {
101 return Err(format!("no certificates found in {}", tls.cert_path));
102 }
103
104 let key: PrivateKeyDer<'static> = {
105 let data = std::fs::read(&tls.key_path)
106 .map_err(|e| format!("reading key {}: {}", tls.key_path, e))?;
107 rustls_pemfile::private_key(&mut &data[..])
108 .map_err(|e| format!("parsing key {}: {}", tls.key_path, e))?
109 .ok_or_else(|| format!("no private key found in {}", tls.key_path))?
110 };
111
112 let builder = rustls::ServerConfig::builder();
113
114 let config = if tls.require_client_cert {
115 let ca_path = tls
116 .ca_path
117 .as_ref()
118 .ok_or_else(|| "require_client_cert is set but ca_path is missing".to_string())?;
119 let ca_data =
120 std::fs::read(ca_path).map_err(|e| format!("reading ca {}: {}", ca_path, e))?;
121 let mut roots = rustls::RootCertStore::empty();
122 for ca in rustls_pemfile::certs(&mut &ca_data[..]) {
123 let ca = ca.map_err(|e| format!("parsing ca {}: {}", ca_path, e))?;
124 roots
125 .add(ca)
126 .map_err(|e| format!("adding ca cert: {}", e))?;
127 }
128 let verifier =
129 rustls::server::WebPkiClientVerifier::builder(Arc::new(roots))
130 .build()
131 .map_err(|e| format!("building client verifier: {}", e))?;
132 builder
133 .with_client_cert_verifier(verifier)
134 .with_single_cert(cert_chain, key)
135 .map_err(|e| format!("server config (mTLS): {}", e))?
136 } else {
137 builder
138 .with_no_client_auth()
139 .with_single_cert(cert_chain, key)
140 .map_err(|e| format!("server config: {}", e))?
141 };
142
143 Ok(TlsAcceptor::from(Arc::new(config)))
144}