use std::path::PathBuf;
/// TLS client settings, assembled through the builder methods on [`TlsConfig`].
///
/// All paths refer to PEM files. Nothing is read from disk when the struct is built;
/// files are opened only when a connector is created from it.
#[derive(Clone, Debug)]
pub struct TlsConfig {
    /// Whether the server's certificate should be verified (`true` in `Default`).
    pub verify_server: bool,
    /// PEM file with extra CA certificates, added on top of the bundled webpki roots.
    pub ca_cert_path: Option<PathBuf>,
    /// PEM file with the client certificate chain; only used together with `client_key_path`.
    pub client_cert_path: Option<PathBuf>,
    /// PEM file with the client private key; only used together with `client_cert_path`.
    pub client_key_path: Option<PathBuf>,
    /// Optional server-name override.
    /// NOTE(review): not read by `rustls_impl::create_connector` — confirm where it is consumed.
    pub server_name: Option<String>,
}
impl Default for TlsConfig {
fn default() -> Self {
TlsConfig {
verify_server: true,
ca_cert_path: None,
client_cert_path: None,
client_key_path: None,
server_name: None,
}
}
}
impl TlsConfig {
    /// Creates a configuration identical to [`TlsConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets `verify_server` to `false` and logs a warning through `tracing::warn!`.
    ///
    /// NOTE(review): whether certificate checks are actually skipped depends on the TLS
    /// backend reading `verify_server` — confirm in the connector code.
    #[must_use]
    pub fn danger_accept_invalid_certs(self) -> Self {
        tracing::warn!(
            "TLS certificate verification disabled - this should only be used for testing. \
             Man-in-the-middle attacks are possible."
        );
        Self {
            verify_server: false,
            ..self
        }
    }

    /// Records the path of a PEM file containing additional CA certificates.
    #[must_use]
    pub fn ca_cert(self, path: impl Into<PathBuf>) -> Self {
        Self {
            ca_cert_path: Some(path.into()),
            ..self
        }
    }

    /// Records the client certificate chain and private key paths (both PEM).
    #[must_use]
    pub fn client_cert(
        self,
        cert_path: impl Into<PathBuf>,
        key_path: impl Into<PathBuf>,
    ) -> Self {
        Self {
            client_cert_path: Some(cert_path.into()),
            client_key_path: Some(key_path.into()),
            ..self
        }
    }

    /// Records a server-name override.
    #[must_use]
    pub fn server_name(self, name: impl Into<String>) -> Self {
        Self {
            server_name: Some(name.into()),
            ..self
        }
    }

    /// `true` only when both the client certificate and key paths are set.
    #[must_use]
    pub fn has_client_cert(&self) -> bool {
        matches!(
            (&self.client_cert_path, &self.client_key_path),
            (Some(_), Some(_))
        )
    }
}
/// How strictly TLS is applied to a connection; defaults to [`TlsMode::Disable`].
///
/// NOTE(review): the variant set resembles PostgreSQL's `sslmode` values — confirm the
/// intended semantics match.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TlsMode {
    /// TLS is not used.
    #[default]
    Disable,
    /// TLS is enabled but not required.
    Prefer,
    /// TLS is required; the server is not verified.
    Require,
    /// TLS is required and the server is verified; the hostname is not checked.
    VerifyCA,
    /// TLS is required, the server is verified, and the hostname is checked.
    VerifyFull,
}
impl TlsMode {
    /// `true` for every mode except [`TlsMode::Disable`].
    #[must_use]
    pub fn is_enabled(&self) -> bool {
        match self {
            TlsMode::Disable => false,
            TlsMode::Prefer | TlsMode::Require | TlsMode::VerifyCA | TlsMode::VerifyFull => true,
        }
    }

    /// `true` when a connection must not fall back to plaintext.
    #[must_use]
    pub fn is_required(&self) -> bool {
        match self {
            TlsMode::Disable | TlsMode::Prefer => false,
            TlsMode::Require | TlsMode::VerifyCA | TlsMode::VerifyFull => true,
        }
    }

    /// `true` when the server certificate should be verified.
    #[must_use]
    pub fn verify_server(&self) -> bool {
        match self {
            TlsMode::Disable | TlsMode::Prefer | TlsMode::Require => false,
            TlsMode::VerifyCA | TlsMode::VerifyFull => true,
        }
    }

    /// `true` only for [`TlsMode::VerifyFull`], which also checks the hostname.
    #[must_use]
    pub fn verify_hostname(&self) -> bool {
        match self {
            TlsMode::VerifyFull => true,
            TlsMode::Disable | TlsMode::Prefer | TlsMode::Require | TlsMode::VerifyCA => false,
        }
    }
}
pub mod rustls_impl {
use super::TlsConfig;
use std::io::BufReader;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::TlsConnector;
use crate::client::error::{Error, ErrorKind, Result};
pub fn create_connector(config: &TlsConfig, _host: &str) -> Result<TlsConnector> {
let mut root_store = RootCertStore::empty();
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
if let Some(ref ca_path) = config.ca_cert_path {
let ca_file = std::fs::File::open(ca_path).map_err(|e| {
Error::new(ErrorKind::Config, format!("failed to open CA cert: {e}"))
})?;
let mut ca_reader = BufReader::new(ca_file);
let certs = rustls_pemfile::certs(&mut ca_reader)
.map(|r| {
r.map_err(|e| Error::new(ErrorKind::Config, format!("invalid CA cert: {e}")))
})
.collect::<Result<Vec<_>>>()?;
for cert in certs {
root_store.add(cert).map_err(|e| {
Error::new(ErrorKind::Config, format!("failed to add CA cert: {e}"))
})?;
}
}
let provider = Arc::new(rustls::crypto::ring::default_provider());
let builder = ClientConfig::builder_with_provider(provider)
.with_safe_default_protocol_versions()
.map_err(|e| Error::new(ErrorKind::Config, format!("TLS protocol config error: {e}")))?
.with_root_certificates(root_store);
let client_config = if config.has_client_cert() {
let cert_path = config.client_cert_path.as_ref().unwrap();
let key_path = config.client_key_path.as_ref().unwrap();
let cert_file = std::fs::File::open(cert_path).map_err(|e| {
Error::new(
ErrorKind::Config,
format!("failed to open client cert: {e}"),
)
})?;
let mut cert_reader = BufReader::new(cert_file);
let certs = rustls_pemfile::certs(&mut cert_reader)
.map(|r| {
r.map_err(|e| {
Error::new(ErrorKind::Config, format!("invalid client cert: {e}"))
})
})
.collect::<Result<Vec<_>>>()?;
let key_file = std::fs::File::open(key_path).map_err(|e| {
Error::new(ErrorKind::Config, format!("failed to open client key: {e}"))
})?;
let mut key_reader = BufReader::new(key_file);
let key = rustls_pemfile::private_key(&mut key_reader)
.map_err(|e| Error::new(ErrorKind::Config, format!("invalid client key: {e}")))?
.ok_or_else(|| Error::new(ErrorKind::Config, "no private key found"))?;
builder
.with_client_auth_cert(certs, key)
.map_err(|e| Error::new(ErrorKind::Config, format!("invalid client auth: {e}")))?
} else {
builder.with_no_client_auth()
};
Ok(TlsConnector::from(Arc::new(client_config)))
}
pub type TlsStream = tokio_rustls::client::TlsStream<TcpStream>;
pub async fn wrap_stream(
stream: TcpStream,
connector: &TlsConnector,
server_name: &str,
) -> Result<TlsStream> {
let domain = rustls::pki_types::ServerName::try_from(server_name.to_string())
.map_err(|_| Error::new(ErrorKind::Config, "invalid server name"))?;
connector
.connect(domain, stream)
.await
.map_err(|e| Error::new(ErrorKind::Connection, format!("TLS handshake failed: {e}")))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_tls_config_default() {
        let config = TlsConfig::default();
        assert!(config.verify_server);
        assert!(config.ca_cert_path.is_none());
        assert!(config.server_name.is_none());
        assert!(!config.has_client_cert());
    }

    #[test]
    fn test_tls_config_builder() {
        let config = TlsConfig::new()
            .ca_cert("/path/to/ca.pem")
            .client_cert("/path/to/cert.pem", "/path/to/key.pem")
            .server_name("example.com");
        assert!(config.has_client_cert());
        assert!(config.verify_server);
        assert_eq!(config.ca_cert_path, Some(PathBuf::from("/path/to/ca.pem")));
        assert_eq!(config.server_name, Some("example.com".to_string()));
    }

    #[test]
    fn test_danger_accept_invalid_certs_clears_verify_flag() {
        let config = TlsConfig::new().danger_accept_invalid_certs();
        assert!(!config.verify_server);
    }

    #[test]
    fn test_has_client_cert_requires_both_paths() {
        let mut config = TlsConfig::new();
        config.client_cert_path = Some(PathBuf::from("/path/to/cert.pem"));
        assert!(!config.has_client_cert());
        config.client_key_path = Some(PathBuf::from("/path/to/key.pem"));
        assert!(config.has_client_cert());
    }

    #[test]
    fn test_tls_mode() {
        // (mode, is_enabled, is_required, verify_server, verify_hostname)
        let table = [
            (TlsMode::Disable, false, false, false, false),
            (TlsMode::Prefer, true, false, false, false),
            (TlsMode::Require, true, true, false, false),
            (TlsMode::VerifyCA, true, true, true, false),
            (TlsMode::VerifyFull, true, true, true, true),
        ];
        for (mode, enabled, required, verify_server, verify_hostname) in table {
            assert_eq!(mode.is_enabled(), enabled, "{mode:?}.is_enabled()");
            assert_eq!(mode.is_required(), required, "{mode:?}.is_required()");
            assert_eq!(mode.verify_server(), verify_server, "{mode:?}.verify_server()");
            assert_eq!(mode.verify_hostname(), verify_hostname, "{mode:?}.verify_hostname()");
        }
        assert_eq!(TlsMode::default(), TlsMode::Disable);
    }

    #[test]
    fn test_create_connector_builds_for_default_and_danger_configs() {
        assert!(rustls_impl::create_connector(&TlsConfig::default(), "localhost").is_ok());
        let danger = TlsConfig::new().danger_accept_invalid_certs();
        assert!(rustls_impl::create_connector(&danger, "localhost").is_ok());
    }

    #[test]
    fn test_create_connector_missing_ca_file_is_error() {
        let config = TlsConfig::new().ca_cert("/nonexistent/definitely-missing-ca.pem");
        assert!(rustls_impl::create_connector(&config, "localhost").is_err());
    }
}