use std::fs::File;
use std::path::Path;
use std::sync::Arc;
use rustls::client::WantsClientCert;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::CryptoProvider;
use rustls::pki_types::UnixTime;
use rustls::pki_types::pem::PemObject;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName};
use rustls::{ClientConfig, ConfigBuilder, RootCertStore};
use rustls::{DigitallySignedStruct, Error as RustlsError, SignatureScheme};
#[cfg(feature = "native-tls-roots")]
use rustls_native_certs::load_native_certs;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_rustls::TlsConnector;
use tokio_rustls::client::TlsStream;
use tracing::warn;
use crate::auth::TlsConfig;
use crate::error::{KrafkaError, Result};
/// A broker connection that is either plain TCP or TLS over TCP.
///
/// [`AsyncRead`] and [`AsyncWrite`] are implemented by delegating every call to
/// the wrapped stream, so callers can treat both transports uniformly.
#[non_exhaustive]
pub enum MaybeSecureStream {
    /// Unencrypted TCP connection.
    Plain(TcpStream),
    /// TLS session over TCP. Boxed, presumably to keep the enum's size close to
    /// the plain variant's.
    Tls(Box<TlsStream<TcpStream>>),
}
impl AsyncRead for MaybeSecureStream {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeSecureStream::Plain(stream) => std::pin::Pin::new(stream).poll_read(cx, buf),
            MaybeSecureStream::Tls(stream) => {
                std::pin::Pin::new(stream.as_mut()).poll_read(cx, buf)
            }
        }
    }
}
impl AsyncWrite for MaybeSecureStream {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeSecureStream::Plain(stream) => std::pin::Pin::new(stream).poll_write(cx, buf),
            MaybeSecureStream::Tls(stream) => {
                std::pin::Pin::new(stream.as_mut()).poll_write(cx, buf)
            }
        }
    }
    /// Forward vectored writes to the inner stream.
    ///
    /// Without this override the trait default writes only the first non-empty
    /// buffer, discarding any vectored-write support of the wrapped stream.
    fn poll_write_vectored(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<std::io::Result<usize>> {
        match self.get_mut() {
            MaybeSecureStream::Plain(stream) => {
                std::pin::Pin::new(stream).poll_write_vectored(cx, bufs)
            }
            MaybeSecureStream::Tls(stream) => {
                std::pin::Pin::new(stream.as_mut()).poll_write_vectored(cx, bufs)
            }
        }
    }
    /// Report whether the wrapped stream has an efficient vectored-write path.
    fn is_write_vectored(&self) -> bool {
        match self {
            MaybeSecureStream::Plain(stream) => stream.is_write_vectored(),
            MaybeSecureStream::Tls(stream) => (**stream).is_write_vectored(),
        }
    }
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeSecureStream::Plain(stream) => std::pin::Pin::new(stream).poll_flush(cx),
            MaybeSecureStream::Tls(stream) => std::pin::Pin::new(stream.as_mut()).poll_flush(cx),
        }
    }
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        match self.get_mut() {
            MaybeSecureStream::Plain(stream) => std::pin::Pin::new(stream).poll_shutdown(cx),
            MaybeSecureStream::Tls(stream) => std::pin::Pin::new(stream.as_mut()).poll_shutdown(cx),
        }
    }
}
/// Build a rustls [`ClientConfig`] from `config` on Tokio's blocking pool.
///
/// Construction reads certificate and key files (and, when native roots are
/// enabled, the platform trust store), so it is moved off the async executor.
///
/// # Errors
///
/// Returns a configuration error if the config cannot be built, or if the
/// blocking task panicked or was cancelled before producing a result.
pub async fn build_tls_config(config: &TlsConfig) -> Result<ClientConfig> {
    // `spawn_blocking` requires a `'static` closure, so the config is cloned.
    let config = config.clone();
    tokio::task::spawn_blocking(move || build_tls_config_sync(&config))
        .await
        // Awaiting the handle yields a `JoinError` only when the task panicked or
        // was cancelled; the spawn itself did not fail.
        .map_err(|e| KrafkaError::config(format!("TLS config build task failed: {e}")))?
}
/// Build a [`TlsConnector`] that shares a freshly built client config.
///
/// # Errors
///
/// Propagates any error from [`build_tls_config`].
pub async fn build_tls_connector(config: &TlsConfig) -> Result<TlsConnector> {
    build_tls_config(config)
        .await
        .map(|client_config| TlsConnector::from(Arc::new(client_config)))
}
/// Perform a TLS handshake over an already-connected TCP stream.
///
/// The SNI / certificate name is taken from `sni_hostname` when given,
/// otherwise from `hostname`. Either is first passed through
/// `crate::util::extract_sni_hostname` before conversion to a [`ServerName`].
///
/// # Errors
///
/// - Configuration error if the host cannot be turned into a valid server name.
/// - Auth error if the TLS handshake fails.
pub async fn connect_tls(
    stream: TcpStream,
    hostname: &str,
    sni_hostname: Option<&str>,
    connector: &TlsConnector,
) -> Result<TlsStream<TcpStream>> {
    let target = match sni_hostname {
        Some(explicit) => explicit,
        None => hostname,
    };
    let bare_host = crate::util::extract_sni_hostname(target)?.to_string();
    let server_name = match ServerName::try_from(bare_host) {
        Ok(name) => name,
        Err(e) => return Err(KrafkaError::config(format!("Invalid server name: {e}"))),
    };
    match connector.connect(server_name, stream).await {
        Ok(tls_stream) => Ok(tls_stream),
        Err(e) => Err(KrafkaError::auth(format!("TLS handshake failed: {e}"))),
    }
}
/// Compute the RFC 5929 `tls-server-end-point` channel-binding value for an
/// established TLS connection.
///
/// The value is a hash of the server's end-entity certificate, which is the
/// first certificate in the peer chain. Per RFC 5929 §4.1, the hash function is
/// the one used in the certificate's `signatureAlgorithm`, except that MD5 and
/// SHA-1 are replaced by SHA-256.
///
/// RSA and ECDSA signatures with SHA-384 or SHA-512 are detected from the
/// certificate's DER encoding. Every other case uses SHA-256: SHA-256 or weaker
/// hashes, RSASSA-PSS (whose hash lives in the algorithm parameters, which are
/// not parsed here), EdDSA, and unparseable certificates.
///
/// Returns `None` when the peer presented no certificate.
pub fn extract_tls_server_end_point(stream: &TlsStream<TcpStream>) -> Option<Vec<u8>> {
    use sha2::{Digest, Sha256, Sha384, Sha512};
    let (_, conn) = stream.get_ref();
    let end_entity = conn.peer_certificates()?.first()?;
    let der = end_entity.as_ref();
    let digest = match end_point_hash_for(der) {
        EndPointHash::Sha384 => Sha384::digest(der).to_vec(),
        EndPointHash::Sha512 => Sha512::digest(der).to_vec(),
        EndPointHash::Sha256 => Sha256::digest(der).to_vec(),
    };
    Some(digest)
}

/// Hash function selected for the `tls-server-end-point` binding.
enum EndPointHash {
    Sha256,
    Sha384,
    Sha512,
}

/// DER-encoded OID 1.2.840.113549.1.1.12 (sha384WithRSAEncryption).
const OID_SHA384_WITH_RSA: &[u8] = &[0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0C];
/// DER-encoded OID 1.2.840.113549.1.1.13 (sha512WithRSAEncryption).
const OID_SHA512_WITH_RSA: &[u8] = &[0x2A, 0x86, 0x48, 0x86, 0xF7, 0x0D, 0x01, 0x01, 0x0D];
/// DER-encoded OID 1.2.840.10045.4.3.3 (ecdsa-with-SHA384).
const OID_ECDSA_WITH_SHA384: &[u8] = &[0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x03];
/// DER-encoded OID 1.2.840.10045.4.3.4 (ecdsa-with-SHA512).
const OID_ECDSA_WITH_SHA512: &[u8] = &[0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x04, 0x03, 0x04];

/// Choose the binding hash from the certificate's signature algorithm OID,
/// defaulting to SHA-256.
fn end_point_hash_for(cert_der: &[u8]) -> EndPointHash {
    match certificate_signature_oid(cert_der) {
        Some(oid) if oid == OID_SHA384_WITH_RSA || oid == OID_ECDSA_WITH_SHA384 => {
            EndPointHash::Sha384
        }
        Some(oid) if oid == OID_SHA512_WITH_RSA || oid == OID_ECDSA_WITH_SHA512 => {
            EndPointHash::Sha512
        }
        _ => EndPointHash::Sha256,
    }
}

/// Return the raw OID bytes of `Certificate.signatureAlgorithm.algorithm`.
///
/// Layout: `Certificate ::= SEQUENCE { tbsCertificate SEQUENCE,
/// signatureAlgorithm SEQUENCE { algorithm OID, .. }, signatureValue BIT STRING }`.
fn certificate_signature_oid(cert_der: &[u8]) -> Option<&[u8]> {
    const TAG_SEQUENCE: u8 = 0x30;
    const TAG_OID: u8 = 0x06;
    let (certificate, _) = der_expect(cert_der, TAG_SEQUENCE)?;
    // Skip tbsCertificate; the signature algorithm follows it.
    let (_tbs, after_tbs) = der_expect(certificate, TAG_SEQUENCE)?;
    let (algorithm_identifier, _) = der_expect(after_tbs, TAG_SEQUENCE)?;
    let (oid, _) = der_expect(algorithm_identifier, TAG_OID)?;
    Some(oid)
}

/// Read one DER TLV with the expected tag, returning `(value, remaining)`.
///
/// Returns `None` on a tag mismatch, malformed length, or truncated input.
fn der_expect(input: &[u8], expected_tag: u8) -> Option<(&[u8], &[u8])> {
    let (&tag, rest) = input.split_first()?;
    if tag != expected_tag {
        return None;
    }
    let (&len_byte, rest) = rest.split_first()?;
    let (len, rest) = if len_byte & 0x80 == 0 {
        // Short form: the byte itself is the length.
        (usize::from(len_byte), rest)
    } else {
        // Long form: the low 7 bits give the number of big-endian length bytes.
        let count = usize::from(len_byte & 0x7F);
        if count == 0 || count > std::mem::size_of::<usize>() || count > rest.len() {
            return None;
        }
        let (len_bytes, rest) = rest.split_at(count);
        let len = len_bytes
            .iter()
            .fold(0usize, |acc, &b| (acc << 8) | usize::from(b));
        (len, rest)
    };
    if len > rest.len() {
        return None;
    }
    Some(rest.split_at(len))
}
fn build_tls_config_sync(config: &TlsConfig) -> Result<ClientConfig> {
if !config.verify_server_cert {
use std::sync::Once;
static WARN_ONCE: Once = Once::new();
WARN_ONCE.call_once(|| {
warn!(
"TLS certificate verification is disabled (verify_server_cert=false). \
This is insecure and must only be used for local development or testing \
with self-signed certificates. Never use in production."
);
});
return build_insecure_tls_config(config);
}
let root_store = load_root_store(config)?;
let builder = ClientConfig::builder().with_root_certificates(root_store);
let client_auth = load_client_auth(config)?;
let mut tls_config = finish_with_client_auth(builder, client_auth)?;
if !config.alpn_protocols.is_empty() {
tls_config.alpn_protocols.clone_from(&config.alpn_protocols);
}
Ok(tls_config)
}
/// Complete a client config builder, attaching a client certificate chain and
/// key when `client_auth` is `Some`, or disabling client auth otherwise.
///
/// # Errors
///
/// Returns a configuration error if rustls rejects the certificate/key pair.
fn finish_with_client_auth(
    builder: ConfigBuilder<ClientConfig, WantsClientCert>,
    client_auth: Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>,
) -> Result<ClientConfig> {
    match client_auth {
        None => Ok(builder.with_no_client_auth()),
        Some((chain, private_key)) => builder
            .with_client_auth_cert(chain, private_key)
            .map_err(|err| KrafkaError::config(format!("Failed to set client auth: {err}"))),
    }
}
/// Load every PEM-encoded certificate from the file at `path`.
///
/// # Errors
///
/// Returns a configuration error if the file cannot be opened, a PEM section
/// cannot be parsed, or the file contains no certificates at all.
fn load_certs(path: &str) -> Result<Vec<CertificateDer<'static>>> {
    let certs = CertificateDer::pem_file_iter(Path::new(path))
        .map_err(|e| KrafkaError::config(format!("Failed to open cert file {path}: {e}")))?
        .collect::<std::result::Result<Vec<_>, _>>()
        .map_err(|e| KrafkaError::config(format!("Failed to parse cert file {path}: {e}")))?;
    // An empty file would otherwise produce an empty trust store or client chain,
    // which only surfaces later as an opaque handshake failure.
    if certs.is_empty() {
        return Err(KrafkaError::config(format!(
            "No certificates found in cert file {path}"
        )));
    }
    Ok(certs)
}
/// Load a PEM-encoded private key from `path`.
///
/// On Unix, logs a warning if any group/other permission bits are set on the
/// file. The key is parsed from the same file handle that was checked.
///
/// # Errors
///
/// Returns a configuration error if the file cannot be opened or no private
/// key can be read from it.
fn load_private_key(path: &str) -> Result<PrivateKeyDer<'static>> {
    let file = File::open(Path::new(path))
        .map_err(|e| KrafkaError::config(format!("Failed to open key file {path}: {e}")))?;
    #[cfg(unix)]
    {
        use std::os::unix::fs::MetadataExt;
        if let Ok(meta) = file.metadata() {
            let mode = meta.mode();
            if mode & 0o077 != 0 {
                tracing::warn!(
                    "Private key file {path} has world/group-readable permissions \
                     (mode {mode:#o}). Consider restricting to owner-only (chmod 600)."
                );
            }
        }
    }
    // Parse from the already-open handle instead of reopening by path, so the
    // permission check above applies to the file that is actually read.
    PrivateKeyDer::from_pem_reader(std::io::BufReader::new(file))
        .map_err(|e| KrafkaError::config(format!("Failed to read private key file {path}: {e}")))
}
/// Build the trust store used to verify broker certificates.
///
/// - Without `ca_cert_path`: only the default roots (native or bundled webpki
///   roots, per [`load_default_roots`]).
/// - With `ca_cert_path`: the certificates from that file, plus the native
///   roots when `use_native_roots` is set.
///
/// # Errors
///
/// Returns a configuration error if default roots or the CA file cannot be
/// loaded, or a CA certificate is rejected by the store.
fn load_root_store(config: &TlsConfig) -> Result<RootCertStore> {
    let mut store = RootCertStore::empty();
    match &config.ca_cert_path {
        None => load_default_roots(&mut store, config)?,
        Some(ca_path) => {
            if config.use_native_roots {
                load_default_roots(&mut store, config)?;
            }
            load_certs(ca_path)?.into_iter().try_for_each(|ca_cert| {
                store
                    .add(ca_cert)
                    .map_err(|e| KrafkaError::config(format!("Failed to add CA cert: {e}")))
            })?;
        }
    }
    Ok(store)
}
/// Add the default trust anchors to `root_store`.
///
/// With `use_native_roots` set, the platform trust store is loaded. This needs
/// the `native-tls-roots` crate feature and is an error without it. Otherwise
/// the bundled `webpki_roots` set is used.
///
/// # Errors
///
/// Returns a configuration error if native roots are requested but the
/// feature is disabled, or if no usable native root certificate was found.
fn load_default_roots(root_store: &mut RootCertStore, config: &TlsConfig) -> Result<()> {
    if config.use_native_roots {
        #[cfg(feature = "native-tls-roots")]
        {
            let result = load_native_certs();
            if !result.errors.is_empty() {
                warn!(
                    error_count = result.errors.len(),
                    "Some native TLS root certificates could not be loaded"
                );
            }
            // The platform store may contain certificates that rustls cannot parse.
            // Skip those instead of failing the whole configuration on the first one.
            let (added, ignored) = root_store.add_parsable_certificates(result.certs);
            if ignored > 0 {
                warn!(
                    ignored_count = ignored,
                    "Some native TLS root certificates could not be parsed and were skipped"
                );
            }
            if added == 0 {
                return Err(KrafkaError::config(
                    "No native TLS root certificates could be loaded",
                ));
            }
            return Ok(());
        }
        #[cfg(not(feature = "native-tls-roots"))]
        {
            return Err(KrafkaError::config(
                "TlsConfig::with_native_roots() requires the 'native-tls-roots' crate feature",
            ));
        }
    }
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    Ok(())
}
fn load_client_auth(
config: &TlsConfig,
) -> Result<Option<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)>> {
if let (Some(cert_path), Some(key_path)) = (&config.client_cert_path, &config.client_key_path) {
let certs = load_certs(cert_path)?;
let key = load_private_key(key_path)?;
Ok(Some((certs, key)))
} else {
Ok(None)
}
}
/// Return the process-default rustls [`CryptoProvider`] if one is installed,
/// otherwise the compiled-in provider: aws-lc-rs with the `rustls-aws-lc-rs`
/// feature, ring without it. The fallback is not installed as the default.
fn resolve_crypto_provider() -> Arc<CryptoProvider> {
    if let Some(installed) = CryptoProvider::get_default() {
        return Arc::clone(installed);
    }
    #[cfg(feature = "rustls-aws-lc-rs")]
    let fallback = rustls::crypto::aws_lc_rs::default_provider();
    #[cfg(not(feature = "rustls-aws-lc-rs"))]
    let fallback = rustls::crypto::ring::default_provider();
    Arc::new(fallback)
}
/// Start a client config builder that accepts any server certificate.
///
/// Uses `provider` for both the TLS handshake and the
/// [`NoServerCertVerifier`]'s advertised signature schemes, with rustls' safe
/// default protocol versions.
///
/// # Errors
///
/// Returns a configuration error if `provider` cannot support the default
/// protocol versions.
fn insecure_builder(
    provider: Arc<CryptoProvider>,
) -> Result<ConfigBuilder<ClientConfig, WantsClientCert>> {
    let accept_all = Arc::new(NoServerCertVerifier::new(Arc::clone(&provider)));
    let versioned = ClientConfig::builder_with_provider(provider)
        .with_safe_default_protocol_versions()
        .map_err(|e| KrafkaError::config(format!("Failed to set protocol versions: {e}")))?;
    Ok(versioned
        .dangerous()
        .with_custom_certificate_verifier(accept_all))
}
/// Build a client config that skips server certificate verification.
///
/// Client authentication and ALPN settings from `config` are still applied.
///
/// # Errors
///
/// Returns a configuration error if the builder, client credentials, or client
/// auth setup fails.
fn build_insecure_tls_config(config: &TlsConfig) -> Result<ClientConfig> {
    let dangerous_builder = insecure_builder(resolve_crypto_provider())?;
    let identity = load_client_auth(config)?;
    let mut client_config = finish_with_client_auth(dangerous_builder, identity)?;
    let requested_alpn = &config.alpn_protocols;
    if !requested_alpn.is_empty() {
        client_config.alpn_protocols = requested_alpn.clone();
    }
    Ok(client_config)
}
/// Server certificate verifier that accepts everything.
///
/// Used only when `verify_server_cert` is disabled. Certificates and handshake
/// signatures are never checked. The stored provider is used only to advertise
/// the signature schemes it supports.
#[derive(Debug)]
struct NoServerCertVerifier {
    provider: Arc<CryptoProvider>,
}
impl NoServerCertVerifier {
    /// Create a verifier that advertises `provider`'s signature schemes.
    fn new(provider: Arc<CryptoProvider>) -> Self {
        NoServerCertVerifier { provider }
    }
}
impl ServerCertVerifier for NoServerCertVerifier {
    /// Accept any certificate chain for any server name.
    fn verify_server_cert(
        &self,
        _: &CertificateDer<'_>,
        _: &[CertificateDer<'_>],
        _: &ServerName<'_>,
        _: &[u8],
        _: UnixTime,
    ) -> std::result::Result<ServerCertVerified, RustlsError> {
        Ok(ServerCertVerified::assertion())
    }
    /// Accept any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _: &[u8],
        _: &CertificateDer<'_>,
        _: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, RustlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }
    /// Accept any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _: &[u8],
        _: &CertificateDer<'_>,
        _: &DigitallySignedStruct,
    ) -> std::result::Result<HandshakeSignatureValid, RustlsError> {
        Ok(HandshakeSignatureValid::assertion())
    }
    /// Advertise every scheme the provider can verify.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.provider.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used, clippy::panic)]
mod tests {
    use super::*;

    /// Install the compiled-in crypto provider as the process default.
    ///
    /// Tests share one process, so later installs return `Err`; that result is
    /// deliberately ignored.
    fn setup_crypto_provider() {
        #[cfg(feature = "rustls-aws-lc-rs")]
        let provider = rustls::crypto::aws_lc_rs::default_provider();
        #[cfg(not(feature = "rustls-aws-lc-rs"))]
        let provider = rustls::crypto::ring::default_provider();
        let _ = provider.install_default();
    }

    #[test]
    fn test_build_tls_config_defaults() {
        setup_crypto_provider();
        assert!(build_tls_config_sync(&TlsConfig::new()).is_ok());
    }

    #[test]
    #[cfg(not(feature = "native-tls-roots"))]
    fn test_build_tls_config_native_roots_requires_feature() {
        setup_crypto_provider();
        let mut config = TlsConfig::new();
        config.use_native_roots = true;
        let message = build_tls_config_sync(&config).unwrap_err().to_string();
        assert!(
            message.contains("native-tls-roots"),
            "expected native root feature error, got: {message}"
        );
    }

    #[test]
    fn test_build_tls_config_insecure_succeeds() {
        setup_crypto_provider();
        let outcome = build_tls_config_sync(&TlsConfig::insecure());
        assert!(
            outcome.is_ok(),
            "insecure TLS config should succeed: {outcome:?}"
        );
    }

    #[test]
    fn test_load_certs_nonexistent() {
        assert!(load_certs("/nonexistent/path/cert.pem").is_err());
    }

    #[test]
    fn test_load_private_key_nonexistent() {
        assert!(load_private_key("/nonexistent/path/key.pem").is_err());
    }

    #[test]
    fn test_build_tls_connector() {
        setup_crypto_provider();
        let connector = build_tls_config_sync(&TlsConfig::new())
            .map(|client_config| TlsConnector::from(Arc::new(client_config)));
        assert!(connector.is_ok());
    }

    #[test]
    fn test_alpn_protocols_set() {
        setup_crypto_provider();
        let tls_config = build_tls_config_sync(&TlsConfig::new().with_kafka_alpn()).unwrap();
        assert_eq!(tls_config.alpn_protocols, vec![b"kafka".to_vec()]);
    }

    #[test]
    fn test_alpn_protocols_empty_by_default() {
        setup_crypto_provider();
        let tls_config = build_tls_config_sync(&TlsConfig::new()).unwrap();
        assert!(tls_config.alpn_protocols.is_empty());
    }

    #[test]
    fn test_alpn_custom_protocols() {
        setup_crypto_provider();
        let requested = vec![b"kafka".to_vec(), b"h2".to_vec()];
        let config = TlsConfig::new().with_alpn_protocols(requested.clone());
        let tls_config = build_tls_config_sync(&config).unwrap();
        assert_eq!(tls_config.alpn_protocols, requested);
    }

    #[test]
    fn test_server_name_accepts_dns_and_ip_literals() {
        let hosts: Vec<String> = ["127.0.0.1:9092", "[::1]:9092", "broker.example.com:9092"]
            .into_iter()
            .map(|addr| crate::util::extract_sni_hostname(addr).unwrap().to_string())
            .collect();
        for host in hosts {
            assert!(
                ServerName::try_from(host.clone()).is_ok(),
                "expected a valid server name for {host}"
            );
        }
    }
}