heliosdb_proxy/backend/
tls.rs1use super::error::{BackendError, BackendResult};
12use super::stream::Stream;
13use rustls::{ClientConfig, RootCertStore};
14use rustls_pki_types::ServerName;
15use std::sync::Arc;
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::TcpStream;
18use tokio_rustls::TlsConnector;
19
/// Policy for negotiating TLS with a backend server, as interpreted by `negotiate`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TlsMode {
    /// Skip the `SSLRequest` entirely and use the plaintext socket as-is.
    Disable,
    /// Ask for TLS, but keep the plaintext socket if the server declines.
    Prefer,
    /// Ask for TLS and fail the connection if the server declines.
    Require,
}
31
32pub fn default_client_config() -> Arc<ClientConfig> {
36 let mut roots = RootCertStore::empty();
37 roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
38
39 let config = ClientConfig::builder()
40 .with_root_certificates(roots)
41 .with_no_client_auth();
42
43 Arc::new(config)
44}
45
46pub async fn negotiate(
53 mut tcp: TcpStream,
54 mode: TlsMode,
55 config: Arc<ClientConfig>,
56 sni: &str,
57) -> BackendResult<Stream> {
58 if mode == TlsMode::Disable {
59 return Ok(Stream::Plain(tcp));
60 }
61
62 let ssl_request: [u8; 8] = [
64 0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f, ];
67 tcp.write_all(&ssl_request).await?;
68
69 let mut reply = [0u8; 1];
70 tcp.read_exact(&mut reply).await?;
71
72 match reply[0] {
73 b'S' => {
74 let dns = ServerName::try_from(sni.to_string()).map_err(|_| {
75 BackendError::Tls(format!("invalid SNI hostname: {:?}", sni))
76 })?;
77 let connector = TlsConnector::from(config);
78 let tls = connector
79 .connect(dns, tcp)
80 .await
81 .map_err(|e| BackendError::Tls(e.to_string()))?;
82 Ok(Stream::Tls(tls))
83 }
84 b'N' => {
85 if mode == TlsMode::Require {
86 Err(BackendError::Tls(
87 "server refused TLS and tls_mode=require".to_string(),
88 ))
89 } else {
90 Ok(Stream::Plain(tcp))
91 }
92 }
93 other => Err(BackendError::Tls(format!(
94 "unexpected reply to SSLRequest: 0x{:02x}",
95 other
96 ))),
97 }
98}
99
#[cfg(test)]
mod tests {
    use super::*;

    /// Building the default client configuration must not panic.
    #[test]
    fn test_default_client_config_builds() {
        drop(default_client_config());
    }

    /// Neighbouring TLS modes must compare unequal.
    #[test]
    fn test_tls_mode_variants() {
        assert!(TlsMode::Disable != TlsMode::Prefer);
        assert!(TlsMode::Prefer != TlsMode::Require);
    }
}