use super::error::{BackendError, BackendResult};
use super::stream::Stream;
use rustls::{ClientConfig, RootCertStore};
use rustls_pki_types::ServerName;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
/// How [`negotiate`] should treat TLS on a freshly opened backend connection.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum TlsMode {
    /// Never send an `SSLRequest`; the connection stays plaintext.
    Disable,
    /// Send an `SSLRequest`, but continue in plaintext if the server declines.
    Prefer,
    /// Send an `SSLRequest` and fail if the server declines.
    Require,
}
/// Builds a shared TLS client configuration.
///
/// The configuration trusts the root certificates bundled in
/// `webpki_roots::TLS_SERVER_ROOTS` and presents no client certificate.
pub fn default_client_config() -> Arc<ClientConfig> {
    let trust_anchors = webpki_roots::TLS_SERVER_ROOTS.iter().cloned();
    let mut root_store = RootCertStore::empty();
    root_store.extend(trust_anchors);

    Arc::new(
        ClientConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    )
}
/// Performs the `SSLRequest` exchange on `tcp` and upgrades the connection
/// to TLS when the server agrees.
///
/// Behaviour by `mode`:
/// - [`TlsMode::Disable`]: nothing is sent; `tcp` is returned as [`Stream::Plain`].
/// - [`TlsMode::Prefer`]: an `SSLRequest` is sent. An `'S'` reply starts a TLS
///   handshake using `config`. An `'N'` reply continues in plaintext on the
///   same socket.
/// - [`TlsMode::Require`]: like `Prefer`, but an `'N'` reply is an error.
///
/// `sni` is parsed into the server name handed to the TLS handshake. It is
/// only parsed when the server has agreed to TLS.
///
/// # Errors
///
/// - I/O errors while writing the request or reading the one-byte reply are
///   propagated with `?`.
/// - `BackendError::Tls` is returned in these cases:
///   - `sni` is not a valid server name;
///   - the TLS handshake fails;
///   - the server refuses TLS under `Require`;
///   - the server answers with an ErrorResponse (`'E'`);
///   - the reply byte is anything else.
pub async fn negotiate(
    mut tcp: TcpStream,
    mode: TlsMode,
    config: Arc<ClientConfig>,
    sni: &str,
) -> BackendResult<Stream> {
    // SSLRequest packet layout, all big-endian:
    // - Int32 message length: 8, counting the length field itself;
    // - Int32 request code: 80877103, i.e. (1234 << 16) | 5679.
    const SSL_REQUEST: [u8; 8] = [0x00, 0x00, 0x00, 0x08, 0x04, 0xd2, 0x16, 0x2f];

    if mode == TlsMode::Disable {
        return Ok(Stream::Plain(tcp));
    }

    tcp.write_all(&SSL_REQUEST).await?;

    // The server replies with a single byte. It is read directly from the
    // TcpStream (no userspace buffer in this function). Any bytes that follow
    // an 'S' are therefore left on the socket for the TLS handshake rather
    // than being consumed here as plaintext.
    let reply = tcp.read_u8().await?;

    match reply {
        b'S' => {
            let dns = ServerName::try_from(sni.to_string()).map_err(|_| {
                BackendError::Tls(format!("invalid SNI hostname: {:?}", sni))
            })?;
            let tls = TlsConnector::from(config)
                .connect(dns, tcp)
                .await
                .map_err(|e| BackendError::Tls(e.to_string()))?;
            Ok(Stream::Tls(tls))
        }
        b'N' if mode == TlsMode::Require => Err(BackendError::Tls(
            "server refused TLS and tls_mode=require".to_string(),
        )),
        b'N' => Ok(Stream::Plain(tcp)),
        // Per the PostgreSQL protocol docs, a server without SSL support
        // answers SSLRequest with an ErrorResponse. The connection must then
        // be closed, so there is no in-place plaintext fallback; dropping
        // `tcp` on return closes it.
        b'E' => Err(BackendError::Tls(
            "server answered SSLRequest with an ErrorResponse (it may not support TLS); \
             reconnect with tls_mode=disable to proceed without TLS"
                .to_string(),
        )),
        other => Err(BackendError::Tls(format!(
            "unexpected reply to SSLRequest: 0x{:02x}",
            other
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Building the default client configuration must not panic.
    #[test]
    fn test_default_client_config_builds() {
        let config = default_client_config();
        drop(config);
    }

    /// Neighbouring TLS modes must compare unequal.
    #[test]
    fn test_tls_mode_variants() {
        let pairs = [
            (TlsMode::Disable, TlsMode::Prefer),
            (TlsMode::Prefer, TlsMode::Require),
        ];
        for &(lhs, rhs) in pairs.iter() {
            assert_ne!(lhs, rhs);
        }
    }
}