use std::fs::File;
use std::io::{self, BufReader, Seek, Write};
use std::net::SocketAddr;
use std::path::Path;
use std::sync::Arc;
use rustls::client::danger::{ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, ServerConfig};
use rustls_pemfile::{certs, pkcs8_private_keys};
use tokio::net::{TcpListener, TcpStream};
use tokio_rustls::client::TlsStream as ClientTlsStream;
use tokio_rustls::server::TlsStream as ServerTlsStream;
use tokio_rustls::{TlsAcceptor, TlsConnector};
use tokio_util::codec::Framed;
use tracing::{debug, error, info, instrument, warn};
use crate::core::codec::PacketCodec;
use crate::core::packet::Packet;
use crate::error::{ProtocolError, Result};
use futures::{SinkExt, StreamExt};
/// Server-certificate verifier that decides trust by comparing the SHA-256 digest of the
/// server's end-entity certificate with a pinned value.
#[derive(Debug)]
struct CertificateFingerprint {
    /// Expected SHA-256 digest of the DER-encoded end-entity certificate.
    fingerprint: Vec<u8>,
}
impl ServerCertVerifier for CertificateFingerprint {
    /// Accepts the server only if the SHA-256 digest of its DER end-entity certificate
    /// equals the pinned fingerprint.
    ///
    /// Intermediates, server name, OCSP data and validity time are intentionally not
    /// consulted: the pin alone is the trust decision.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(end_entity);
        let hash = hasher.finalize();
        if hash.as_slice() == self.fingerprint.as_slice() {
            Ok(ServerCertVerified::assertion())
        } else {
            Err(rustls::Error::General(
                "Pinned certificate hash mismatch".into(),
            ))
        }
    }

    /// Verifies the TLS 1.2 handshake signature against the presented certificate's key.
    ///
    /// A matching pin only shows that the server *sent* the expected certificate, which is
    /// public data. This check proves the server also holds the matching private key.
    /// Without it, anyone replaying the pinned certificate would be trusted.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Verifies the TLS 1.3 handshake signature against the presented certificate's key
    /// (see `verify_tls12_signature` for why this matters for pinning).
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Advertises every scheme the ring provider can verify.
    ///
    /// A short hand-written list (without RSA-PSS) makes TLS 1.3 handshakes with RSA
    /// servers fail, because TLS 1.3 forbids PKCS#1 v1.5 handshake signatures.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
/// Server-certificate verifier whose `verify_server_cert` accepts every certificate
/// without inspecting it.
///
/// `TlsClientConfig` uses it in insecure mode when no pinned hash is configured.
/// Development and testing only.
#[derive(Debug)]
struct AcceptAnyServerCert;
impl ServerCertVerifier for AcceptAnyServerCert {
    /// Accepts any server certificate; no chain, name, revocation or time checks are made.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> std::result::Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Still checks that the TLS 1.2 handshake was signed by the key of the presented
    /// certificate. The certificate itself is untrusted, but the handshake stays intact.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls12_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Still checks that the TLS 1.3 handshake was signed by the key of the presented
    /// certificate.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> std::result::Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        rustls::crypto::verify_tls13_signature(
            message,
            cert,
            dss,
            &rustls::crypto::ring::default_provider().signature_verification_algorithms,
        )
    }

    /// Advertises every scheme the ring provider supports.
    ///
    /// RSA-PSS must be included, or TLS 1.3 servers with RSA keys cannot complete a
    /// handshake.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        rustls::crypto::ring::default_provider()
            .signature_verification_algorithms
            .supported_schemes()
    }
}
fn load_private_key(reader: &mut BufReader<File>) -> Result<PrivateKeyDer<'static>> {
reader
.seek(std::io::SeekFrom::Start(0))
.map_err(ProtocolError::Io)?;
let keys: std::result::Result<Vec<_>, _> = pkcs8_private_keys(reader).collect();
let keys =
keys.map_err(|_| ProtocolError::TlsError("Failed to parse PKCS8 private key".into()))?;
if !keys.is_empty() {
return Ok(PrivateKeyDer::Pkcs8(keys[0].clone_key()));
}
Err(ProtocolError::TlsError(
"No supported private key format found".into(),
))
}
/// TLS protocol versions that a client or server configuration may be restricted to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TlsVersion {
    /// TLS 1.2.
    TLS12,
    /// TLS 1.3.
    TLS13,
    /// Both TLS 1.2 and TLS 1.3.
    All,
}
/// Builder-style settings for a TLS server, turned into a rustls `ServerConfig` by
/// `load_server_config`.
pub struct TlsServerConfig {
    /// Path to the PEM certificate chain presented to clients.
    cert_path: String,
    /// Path to the PEM private key matching the certificate chain.
    key_path: String,
    /// PEM bundle of CA certificates used to verify client certificates (mutual TLS).
    client_ca_path: Option<String>,
    /// Client-authentication flag set by `with_client_auth` and `require_client_auth`.
    require_client_auth: bool,
    /// Protocol versions requested via `with_tls_versions` (`None` when never set).
    tls_versions: Option<Vec<TlsVersion>>,
    /// Cipher suites requested via `with_cipher_suites` (`None` when never set).
    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
    /// ALPN protocol identifiers copied into the server config; `None` leaves ALPN unset.
    alpn_protocols: Option<Vec<Vec<u8>>>,
}
impl TlsServerConfig {
pub fn new<P: AsRef<std::path::Path>>(cert_path: P, key_path: P) -> Self {
Self {
cert_path: cert_path.as_ref().to_string_lossy().to_string(),
key_path: key_path.as_ref().to_string_lossy().to_string(),
client_ca_path: None,
require_client_auth: false,
tls_versions: None,
cipher_suites: None,
alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
}
}
pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
self.tls_versions = Some(versions);
self
}
pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
self.cipher_suites = Some(cipher_suites);
self
}
pub fn with_client_auth<S: Into<String>>(mut self, client_ca_path: S) -> Self {
self.client_ca_path = Some(client_ca_path.into());
self.require_client_auth = true;
self
}
pub fn require_client_auth(mut self, required: bool) -> Self {
self.require_client_auth = required;
self
}
pub fn with_alpn_protocols(mut self, protocols: Vec<Vec<u8>>) -> Self {
self.alpn_protocols = Some(protocols);
self
}
pub fn generate_self_signed<P: AsRef<Path>>(cert_path: P, key_path: P) -> io::Result<Self> {
let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()])
.map_err(|e| io::Error::other(format!("Certificate generation error: {e}")))?;
let mut cert_file = File::create(&cert_path)?;
let pem = cert.cert.pem();
cert_file.write_all(pem.as_bytes())?;
let mut key_file = File::create(&key_path)?;
key_file.write_all(cert.signing_key.serialize_pem().as_bytes())?;
Ok(Self {
cert_path: cert_path.as_ref().to_string_lossy().to_string(),
key_path: key_path.as_ref().to_string_lossy().to_string(),
client_ca_path: None,
require_client_auth: false,
tls_versions: None,
cipher_suites: None,
alpn_protocols: Some(vec![b"h2".to_vec(), b"http/1.1".to_vec()]),
})
}
pub fn load_server_config(&self) -> Result<ServerConfig> {
let cert_file = File::open(&self.cert_path)
.map_err(|e| ProtocolError::TlsError(format!("Failed to open cert file: {e}")))?;
let mut cert_reader = BufReader::new(cert_file);
let cert_chain: std::result::Result<Vec<_>, _> = certs(&mut cert_reader).collect();
let cert_chain: Vec<CertificateDer<'static>> = cert_chain
.map_err(|_| ProtocolError::TlsError("Failed to parse certificate".into()))?;
if cert_chain.is_empty() {
return Err(ProtocolError::TlsError("No certificates found".into()));
}
let key_file = File::open(&self.key_path)
.map_err(|e| ProtocolError::TlsError(format!("Failed to open key file: {e}")))?;
let mut key_reader = BufReader::new(key_file);
let private_key = load_private_key(&mut key_reader)?;
if let Some(versions) = &self.tls_versions {
let mut has_tls13 = false;
let mut has_tls12 = false;
for v in versions {
match v {
TlsVersion::TLS12 => has_tls12 = true,
TlsVersion::TLS13 => has_tls13 = true,
TlsVersion::All => {
has_tls13 = true;
has_tls12 = true;
}
}
}
debug!(
"TLS versions requested: TLS1.2={}, TLS1.3={}",
has_tls12, has_tls13
);
}
let config_builder = ServerConfig::builder_with_provider(std::sync::Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_safe_default_protocol_versions()
.map_err(|_| ProtocolError::TlsError("Failed to configure TLS protocol versions".into()))?;
let cert_builder = config_builder.with_no_client_auth();
let mut config = cert_builder
.with_single_cert(cert_chain.clone(), private_key.clone_key())
.map_err(|e| ProtocolError::TlsError(format!("TLS error: {e}")))?;
if let Some(client_ca_path) = &self.client_ca_path {
let client_ca_file = File::open(client_ca_path).map_err(|e| {
ProtocolError::TlsError(format!("Failed to open client CA file: {e}"))
})?;
let mut client_ca_reader = BufReader::new(client_ca_file);
let client_ca_certs: std::result::Result<Vec<_>, _> =
certs(&mut client_ca_reader).collect();
let client_ca_certs: Vec<CertificateDer<'static>> = client_ca_certs.map_err(|_| {
ProtocolError::TlsError("Failed to parse client CA certificate".into())
})?;
if client_ca_certs.is_empty() {
return Err(ProtocolError::TlsError(
"No client CA certificates found".into(),
));
}
let mut client_root_store = RootCertStore::empty();
for cert in client_ca_certs {
client_root_store.add(cert).map_err(|e| {
ProtocolError::TlsError(format!("Failed to add client CA cert: {e}"))
})?;
}
let client_auth = rustls::server::WebPkiClientVerifier::builder(std::sync::Arc::new(
client_root_store,
))
.build()
.map_err(|e| {
ProtocolError::TlsError(format!("Failed to build client verifier: {e}"))
})?;
let new_builder = ServerConfig::builder_with_provider(std::sync::Arc::new(
rustls::crypto::ring::default_provider(),
))
.with_safe_default_protocol_versions()
.map_err(|_| {
ProtocolError::TlsError("Failed to configure TLS protocol versions".into())
})?;
let new_cert_builder = new_builder.with_client_cert_verifier(client_auth);
config = new_cert_builder
.with_single_cert(cert_chain, private_key.clone_key())
.map_err(|e| ProtocolError::TlsError(format!("TLS error with client auth: {e}")))?;
debug!("mTLS enabled with client certificate verification required");
}
if let Some(protocols) = &self.alpn_protocols {
config.alpn_protocols = protocols.clone();
debug!(
protocol_count = protocols.len(),
"ALPN protocols configured"
);
}
Ok(config)
}
pub fn calculate_cert_hash(cert: &CertificateDer<'_>) -> Vec<u8> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(cert.as_ref());
hasher.finalize().to_vec()
}
}
/// Builder-style settings for a TLS client, turned into a rustls `ClientConfig` by
/// `load_client_config`.
pub struct TlsClientConfig {
    /// Name converted to a `ServerName` for the handshake (see `server_name`).
    server_name: String,
    /// When true, `load_client_config` uses the verifier from `create_custom_verifier`
    /// instead of the system root store.
    insecure: bool,
    /// Expected SHA-256 digest of the server's end-entity certificate; consumed by
    /// `create_custom_verifier`.
    pinned_cert_hash: Option<Vec<u8>>,
    /// Path to a PEM client certificate chain; used only together with `client_key_path`.
    client_cert_path: Option<String>,
    /// Path to the PEM private key matching `client_cert_path`.
    client_key_path: Option<String>,
    /// Protocol versions requested via `with_tls_versions` (`None` when never set).
    tls_versions: Option<Vec<TlsVersion>>,
    /// Cipher suites requested via `with_cipher_suites` (`None` when never set).
    cipher_suites: Option<Vec<rustls::SupportedCipherSuite>>,
}
impl TlsClientConfig {
    /// Creates a client configuration for connecting to `server_name`.
    ///
    /// By default:
    /// - the server certificate is validated against the operating system's root store;
    /// - no client certificate is sent;
    /// - rustls safe-default protocol versions and ring cipher suites are used.
    pub fn new<S: Into<String>>(server_name: S) -> Self {
        Self {
            server_name: server_name.into(),
            insecure: false,
            pinned_cert_hash: None,
            client_cert_path: None,
            client_key_path: None,
            tls_versions: None,
            cipher_suites: None,
        }
    }

    /// Restricts the protocol versions the client offers.
    ///
    /// An empty list means "no restriction" and falls back to the rustls safe defaults.
    pub fn with_tls_versions(mut self, versions: Vec<TlsVersion>) -> Self {
        self.tls_versions = Some(versions);
        self
    }

    /// Restricts the cipher suites the client offers.
    ///
    /// The suites replace those of the ring crypto provider. They must be compatible with
    /// the selected protocol versions, or [`Self::load_client_config`] fails.
    pub fn with_cipher_suites(mut self, cipher_suites: Vec<rustls::SupportedCipherSuite>) -> Self {
        self.cipher_suites = Some(cipher_suites);
        self
    }

    /// Presents the PEM certificate chain at `cert_path` and the PEM private key at
    /// `key_path` when the server requests client authentication.
    pub fn with_client_certificate<S: Into<String>>(mut self, cert_path: S, key_path: S) -> Self {
        self.client_cert_path = Some(cert_path.into());
        self.client_key_path = Some(key_path.into());
        self
    }

    /// Disables CA-based validation of the server certificate.
    ///
    /// Without a pinned hash any server certificate is accepted. With one (see
    /// [`Self::with_pinned_cert_hash`]) only the pinned certificate is accepted. Intended
    /// for development and testing.
    pub fn insecure(mut self) -> Self {
        warn!("INSECURE MODE ENABLED: Certificate verification is disabled. This should only be used for development/testing.");
        self.insecure = true;
        self
    }

    /// Pins the server's end-entity certificate by its SHA-256 digest. The digest is
    /// 32 bytes, as returned by `TlsServerConfig::calculate_cert_hash`.
    ///
    /// The pin is enforced only in [`Self::insecure`] mode, where it replaces CA-based
    /// validation. In the default mode certificates are validated against the system roots
    /// and the pin is not consulted. A warning is logged when such a configuration is built.
    pub fn with_pinned_cert_hash(mut self, hash: Vec<u8>) -> Self {
        if hash.len() != 32 {
            warn!(
                "Certificate hash has unexpected length: {} (expected 32 bytes for SHA-256)",
                hash.len()
            );
        }
        self.pinned_cert_hash = Some(hash);
        self
    }

    /// Builds the rustls `ClientConfig` described by this value.
    ///
    /// # Errors
    /// Returns `ProtocolError::TlsError` or `ProtocolError::Io` if any of the following
    /// holds:
    /// - the system roots or client credentials cannot be loaded;
    /// - the version/cipher-suite combination is unusable;
    /// - rustls rejects the client certificate/key pair.
    pub fn load_client_config(&self) -> Result<ClientConfig> {
        self.log_tls_version_info();
        if self.insecure {
            return self.build_insecure_client_config();
        }
        if self.pinned_cert_hash.is_some() {
            // Make it visible that the pin is not protecting this connection.
            warn!("Pinned certificate hash is ignored unless insecure mode is enabled; validating against system root certificates instead");
        }
        self.build_secure_client_config()
    }

    /// Logs which protocol versions were requested, if any.
    fn log_tls_version_info(&self) {
        if let Some(versions) = &self.tls_versions {
            let has_tls12 = versions
                .iter()
                .any(|v| matches!(v, TlsVersion::TLS12 | TlsVersion::All));
            let has_tls13 = versions
                .iter()
                .any(|v| matches!(v, TlsVersion::TLS13 | TlsVersion::All));
            debug!(
                "TLS client versions requested: TLS1.2={}, TLS1.3={}",
                has_tls12, has_tls13
            );
        }
    }

    /// Builds a config that validates the server against the system root store.
    fn build_secure_client_config(&self) -> Result<ClientConfig> {
        let root_store = self.load_system_root_certificates()?;
        let builder = self.versioned_builder()?.with_root_certificates(root_store);
        self.finish_with_client_auth(builder)
    }

    /// Builds a config that uses the pinned-hash or accept-any verifier.
    fn build_insecure_client_config(&self) -> Result<ClientConfig> {
        let verifier = self.create_custom_verifier();
        let builder = self
            .versioned_builder()?
            .dangerous()
            .with_custom_certificate_verifier(verifier);
        self.finish_with_client_auth(builder)
    }

    /// Loads the operating system's trusted root certificates.
    ///
    /// Certificates that webpki cannot parse are skipped with a warning rather than
    /// failing the whole configuration. One odd system certificate should not make every
    /// connection impossible.
    fn load_system_root_certificates(&self) -> Result<RootCertStore> {
        let native_certs = rustls_native_certs::load_native_certs()
            .map_err(|e| ProtocolError::TlsError(format!("Failed to load native certs: {e}")))?;
        let mut root_store = RootCertStore::empty();
        let (added, ignored) = root_store.add_parsable_certificates(native_certs);
        if ignored > 0 {
            warn!(count = ignored, "Skipped unparsable system root certificates");
        }
        if added == 0 {
            warn!("No usable system root certificates found; server verification will fail");
        }
        Ok(root_store)
    }

    /// Returns the pinned-hash verifier when a pin is set, otherwise the accept-any verifier.
    fn create_custom_verifier(&self) -> Arc<dyn ServerCertVerifier> {
        if let Some(hash) = &self.pinned_cert_hash {
            Arc::new(CertificateFingerprint {
                fingerprint: hash.clone(),
            })
        } else {
            Arc::new(AcceptAnyServerCert)
        }
    }

    /// Reads the client certificate chain and private key from PEM files.
    fn load_client_credentials(
        &self,
        cert_path: &str,
        key_path: &str,
    ) -> Result<(Vec<CertificateDer<'static>>, PrivateKeyDer<'static>)> {
        let cert_file = File::open(cert_path).map_err(ProtocolError::Io)?;
        let chain = rustls_pemfile::certs(&mut BufReader::new(cert_file))
            .collect::<std::result::Result<Vec<_>, _>>()
            .map_err(|_| ProtocolError::TlsError("Failed to parse client certificate".into()))?;
        if chain.is_empty() {
            return Err(ProtocolError::TlsError(
                "No client certificates found".into(),
            ));
        }
        let key_file = File::open(key_path).map_err(ProtocolError::Io)?;
        let key = load_private_key(&mut BufReader::new(key_file))?;
        Ok((chain, key))
    }

    /// Parses the configured server name for use in the TLS handshake.
    pub fn server_name(&self) -> Result<ServerName<'_>> {
        ServerName::try_from(self.server_name.as_str())
            .map_err(|_| ProtocolError::TlsError("Invalid server name".into()))
    }

    /// Returns an owned copy of the configured server name.
    pub fn server_name_string(&self) -> String {
        self.server_name.clone()
    }

    /// Returns the ring provider, restricted to the configured cipher suites if any.
    fn crypto_provider(&self) -> Arc<rustls::crypto::CryptoProvider> {
        let mut provider = rustls::crypto::ring::default_provider();
        if let Some(suites) = &self.cipher_suites {
            provider.cipher_suites = suites.clone();
        }
        Arc::new(provider)
    }

    /// Maps the requested [`TlsVersion`]s to rustls versions, most preferred first.
    ///
    /// Returns `None` when nothing was requested, meaning "use the safe defaults".
    fn protocol_versions(&self) -> Option<Vec<&'static rustls::SupportedProtocolVersion>> {
        let requested = self.tls_versions.as_ref()?;
        let mut versions = Vec::with_capacity(2);
        if requested
            .iter()
            .any(|v| matches!(v, TlsVersion::TLS13 | TlsVersion::All))
        {
            versions.push(&rustls::version::TLS13);
        }
        if requested
            .iter()
            .any(|v| matches!(v, TlsVersion::TLS12 | TlsVersion::All))
        {
            versions.push(&rustls::version::TLS12);
        }
        if versions.is_empty() {
            None
        } else {
            Some(versions)
        }
    }

    /// Starts a client config builder with the configured provider and protocol versions.
    fn versioned_builder(
        &self,
    ) -> Result<rustls::ConfigBuilder<ClientConfig, rustls::WantsVerifier>> {
        let builder = ClientConfig::builder_with_provider(self.crypto_provider());
        let configured = match self.protocol_versions() {
            Some(versions) => builder.with_protocol_versions(&versions),
            None => builder.with_safe_default_protocol_versions(),
        };
        configured.map_err(|e| {
            ProtocolError::TlsError(format!("Failed to configure TLS protocol versions: {e}"))
        })
    }

    /// Finishes `builder` with the client certificate when both a certificate path and a
    /// key path are configured, and without client auth otherwise.
    fn finish_with_client_auth(
        &self,
        builder: rustls::ConfigBuilder<ClientConfig, rustls::client::WantsClientCert>,
    ) -> Result<ClientConfig> {
        let (Some(client_cert_path), Some(client_key_path)) =
            (&self.client_cert_path, &self.client_key_path)
        else {
            return Ok(builder.with_no_client_auth());
        };
        let (cert_chain, key) = self.load_client_credentials(client_cert_path, client_key_path)?;
        builder.with_client_auth_cert(cert_chain, key).map_err(|e| {
            ProtocolError::TlsError(format!("Failed to set client certificate: {e}"))
        })
    }
}
#[instrument(skip(config))]
pub async fn start_server(addr: &str, config: TlsServerConfig) -> Result<()> {
let tls_config = config.load_server_config()?;
let acceptor = TlsAcceptor::from(Arc::new(tls_config));
let listener = TcpListener::bind(addr).await?;
info!(address=%addr, "TLS server listening");
loop {
let (stream, peer) = listener.accept().await?;
let acceptor = acceptor.clone();
tokio::spawn(async move {
match acceptor.accept(stream).await {
Ok(tls_stream) => {
if let Err(e) = handle_tls_connection(tls_stream, peer).await {
error!(%peer, error=%e, "Connection error");
}
}
Err(e) => {
error!(%peer, error=%e, "TLS handshake failed");
}
}
});
}
}
#[instrument(skip(tls_stream), fields(peer=%peer))]
async fn handle_tls_connection(
tls_stream: ServerTlsStream<TcpStream>,
peer: SocketAddr,
) -> Result<()> {
let mut framed = Framed::new(tls_stream, PacketCodec);
info!("TLS connection established");
while let Some(packet) = framed.next().await {
match packet {
Ok(pkt) => {
debug!(bytes = pkt.payload.len(), "Received data");
on_packet(pkt, &mut framed).await?;
}
Err(e) => {
error!(error=%e, "Protocol error");
break;
}
}
}
info!("TLS connection closed");
Ok(())
}
/// Handles one decoded packet by echoing it back to the sender unchanged.
///
/// # Errors
/// Returns the encoder/transport error if the reply cannot be sent.
#[instrument(skip(framed), fields(packet_version=pkt.version, payload_size=pkt.payload.len()))]
async fn on_packet<T>(pkt: Packet, framed: &mut Framed<T, PacketCodec>) -> Result<()>
where
    T: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
    // The reply carries exactly the version and payload that arrived.
    let Packet { version, payload } = pkt;
    let echo = Packet { version, payload };
    framed.send(echo).await.map_err(Into::into)
}
/// Connects to `addr` over TCP, performs the TLS handshake described by `config`, and
/// returns the stream framed with `PacketCodec`.
///
/// # Errors
/// Returns an error if the client configuration cannot be built, the configured server
/// name is invalid, the TCP connection fails, or the TLS handshake fails.
pub async fn connect(
    addr: &str,
    config: TlsClientConfig,
) -> Result<Framed<ClientTlsStream<TcpStream>, PacketCodec>> {
    let connector = TlsConnector::from(Arc::new(config.load_client_config()?));
    // Validate the name before opening a socket. `to_owned` yields the `'static`
    // `ServerName` the connector needs without leaking a string per connection.
    let domain = config.server_name()?.to_owned();
    let stream = TcpStream::connect(addr).await?;
    let tls_stream = connector
        .connect(domain, stream)
        .await
        .map_err(|e| ProtocolError::TlsError(format!("TLS connection failed: {e}")))?;
    Ok(Framed::new(tls_stream, PacketCodec))
}