use std::io::BufReader;
use std::sync::Arc;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::CryptoProvider;
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, Error as TlsError, SignatureScheme};
use tokio_rustls::TlsConnector;
use dicom_toolkit_core::error::{DcmError, DcmResult};
/// TLS settings used by [`connect_tls`] to build a client configuration.
///
/// All `*_pem` fields hold raw PEM-encoded bytes. `Debug` is implemented by hand
/// (instead of derived) so that formatting a `TlsConfig` for logs never prints the
/// private key material stored in `client_key_pem`.
#[derive(Clone, Default)]
pub struct TlsConfig {
    /// When `true`, server certificates are accepted without chain or name checks.
    /// Insecure; intended for testing against self-signed peers only.
    pub accept_invalid_certs: bool,
    /// PEM-encoded CA certificate(s) trusted when verifying the server.
    pub ca_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client certificate chain; used together with `client_key_pem`.
    pub client_cert_pem: Option<Vec<u8>>,
    /// PEM-encoded client private key; used together with `client_cert_pem`.
    pub client_key_pem: Option<Vec<u8>>,
}

impl std::fmt::Debug for TlsConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Certificates are summarised by size; the private key is never rendered.
        let summarize = |pem: &Option<Vec<u8>>| pem.as_ref().map(|b| format!("<{} bytes>", b.len()));
        f.debug_struct("TlsConfig")
            .field("accept_invalid_certs", &self.accept_invalid_certs)
            .field("ca_cert_pem", &summarize(&self.ca_cert_pem))
            .field("client_cert_pem", &summarize(&self.client_cert_pem))
            .field("client_key_pem", &self.client_key_pem.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
/// Performs a TLS client handshake over an already-connected TCP `stream`.
///
/// The rustls client configuration is derived from `config`, and `server_name` is the
/// name the handshake is performed against (it is also what the certificate verifier
/// receives).
///
/// # Errors
/// Returns `DcmError::TlsError` when the client configuration cannot be built,
/// when `server_name` is not a valid rustls `ServerName`, or when the handshake fails.
pub async fn connect_tls(
    stream: tokio::net::TcpStream,
    server_name: &str,
    config: &TlsConfig,
) -> DcmResult<tokio_rustls::client::TlsStream<tokio::net::TcpStream>> {
    // Build the configuration before validating the name so error precedence is stable.
    let tls_config = Arc::new(build_client_config(config)?);

    let dns_name = match ServerName::try_from(server_name.to_string()) {
        Ok(name) => name,
        Err(e) => {
            return Err(DcmError::TlsError {
                reason: format!("invalid server name '{server_name}': {e}"),
            })
        }
    };

    let handshake = TlsConnector::from(tls_config).connect(dns_name, stream);
    match handshake.await {
        Ok(tls_stream) => Ok(tls_stream),
        Err(e) => Err(DcmError::TlsError {
            reason: e.to_string(),
        }),
    }
}
/// Builds a server-side TLS acceptor from a PEM certificate chain and private key.
///
/// Client certificates are not requested (`with_no_client_auth`).
///
/// # Errors
/// Returns `DcmError::TlsError` when the PEM data cannot be parsed, when `cert_pem`
/// contains no certificate, when `key_pem` contains no private key, or when rustls
/// rejects the certificate/key pair.
pub fn make_acceptor(cert_pem: &[u8], key_pem: &[u8]) -> DcmResult<tokio_rustls::TlsAcceptor> {
    let certs = parse_certs(cert_pem)?;
    // `parse_certs` yields an empty list (not an error) for input without certificates.
    // Reject that here so callers get an explicit message instead of whatever rustls
    // reports for an empty chain.
    if certs.is_empty() {
        return Err(DcmError::TlsError {
            reason: "no certificate found in PEM data".into(),
        });
    }
    let key = parse_private_key(key_pem)?;
    let config = rustls::ServerConfig::builder()
        .with_no_client_auth()
        .with_single_cert(certs, key)
        .map_err(|e| DcmError::TlsError {
            reason: e.to_string(),
        })?;
    Ok(tokio_rustls::TlsAcceptor::from(Arc::new(config)))
}
/// Builds the rustls client configuration described by `cfg`.
///
/// Server verification:
/// * `accept_invalid_certs == true`: any server certificate is accepted (handshake
///   signatures are still checked by `NoCertificateVerification`); `ca_cert_pem` is ignored.
/// * otherwise: only the CA certificates from `ca_cert_pem` are trusted; when it is `None`
///   the root store is empty.
///
/// Client authentication is applied in BOTH modes: a client certificate is presented
/// when `client_cert_pem` and `client_key_pem` are both set.
///
/// # Errors
/// `DcmError::TlsError` when PEM parsing fails, when `ca_cert_pem` or `client_cert_pem`
/// contains no certificate, when rustls rejects a certificate or the client key, or when
/// only one of `client_cert_pem` / `client_key_pem` is provided.
fn build_client_config(cfg: &TlsConfig) -> DcmResult<rustls::ClientConfig> {
    // Both branches produce the same builder stage (`WantsClientCert`), so client
    // authentication below is configured identically regardless of verification mode.
    let builder = if cfg.accept_invalid_certs {
        let verifier = Arc::new(NoCertificateVerification(default_crypto_provider()));
        rustls::ClientConfig::builder()
            .dangerous()
            .with_custom_certificate_verifier(verifier)
    } else {
        rustls::ClientConfig::builder().with_root_certificates(build_client_root_store(cfg)?)
    };

    match (&cfg.client_cert_pem, &cfg.client_key_pem) {
        (Some(cert_pem), Some(key_pem)) => {
            let certs = parse_certs(cert_pem)?;
            if certs.is_empty() {
                return Err(DcmError::TlsError {
                    reason: "no client certificate found in PEM data".into(),
                });
            }
            let key = parse_private_key(key_pem)?;
            builder
                .with_client_auth_cert(certs, key)
                .map_err(|e| DcmError::TlsError {
                    reason: e.to_string(),
                })
        }
        (None, None) => Ok(builder.with_no_client_auth()),
        // A half-configured client identity is a misconfiguration; report it instead of
        // silently connecting without client authentication.
        (Some(_), None) | (None, Some(_)) => Err(DcmError::TlsError {
            reason: "client_cert_pem and client_key_pem must be provided together".into(),
        }),
    }
}

/// Builds the trust store for verifying servers from `cfg.ca_cert_pem`.
///
/// Returns an empty store when `ca_cert_pem` is `None`, and an error when it is `Some`
/// but contains no certificate (which would otherwise make every handshake fail with no
/// indication of why).
fn build_client_root_store(cfg: &TlsConfig) -> DcmResult<rustls::RootCertStore> {
    let mut root_store = rustls::RootCertStore::empty();
    if let Some(ca_pem) = &cfg.ca_cert_pem {
        let certs = parse_certs(ca_pem)?;
        if certs.is_empty() {
            return Err(DcmError::TlsError {
                reason: "no CA certificate found in PEM data".into(),
            });
        }
        for cert in certs {
            root_store.add(cert).map_err(|e| DcmError::TlsError {
                reason: e.to_string(),
            })?;
        }
    }
    Ok(root_store)
}
/// Collects every certificate that `rustls_pemfile::certs` finds in `pem`.
///
/// Input containing no certificates (e.g. non-PEM text) yields `Ok` with an empty
/// vector rather than an error; callers that need at least one certificate must check.
///
/// # Errors
/// `DcmError::TlsError` on the first error reported by the PEM reader.
fn parse_certs(pem: &[u8]) -> DcmResult<Vec<CertificateDer<'static>>> {
    let mut pem_reader = BufReader::new(pem);
    let mut parsed = Vec::new();
    for item in rustls_pemfile::certs(&mut pem_reader) {
        match item {
            Ok(cert) => parsed.push(cert),
            Err(err) => {
                return Err(DcmError::TlsError {
                    reason: format!("certificate parse error: {err}"),
                })
            }
        }
    }
    Ok(parsed)
}
/// Parses the first private key that `rustls_pemfile::private_key` finds in `pem`.
///
/// # Errors
/// `DcmError::TlsError` when the PEM reader reports an error, or when `pem` contains
/// no private key at all.
fn parse_private_key(pem: &[u8]) -> DcmResult<PrivateKeyDer<'static>> {
    let mut reader = BufReader::new(pem);
    rustls_pemfile::private_key(&mut reader)
        .map_err(|e| DcmError::TlsError {
            reason: format!("private key parse error: {e}"),
        })?
        // `ok_or_else` builds the error (and its String) only when no key was found,
        // instead of allocating it on every successful parse.
        .ok_or_else(|| DcmError::TlsError {
            reason: "no private key found in PEM data".into(),
        })
}
/// Returns the process-wide default rustls `CryptoProvider` if one is installed,
/// otherwise a freshly constructed aws-lc-rs default provider.
///
/// The fallback provider is only returned, not installed as the process default.
fn default_crypto_provider() -> Arc<CryptoProvider> {
    match CryptoProvider::get_default() {
        Some(installed) => Arc::clone(installed),
        None => Arc::new(rustls::crypto::aws_lc_rs::default_provider()),
    }
}
/// Server certificate verifier that accepts any certificate chain for any server name.
///
/// Used only when `TlsConfig::accept_invalid_certs` is set. Handshake signatures are
/// still checked against the presented certificate using the wrapped provider's
/// signature-verification algorithms; only the chain-of-trust and name checks are skipped.
#[derive(Debug)]
struct NoCertificateVerification(Arc<CryptoProvider>);

impl ServerCertVerifier for NoCertificateVerification {
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _name: &ServerName<'_>,
        _ocsp: &[u8],
        _at: UnixTime,
    ) -> Result<ServerCertVerified, TlsError> {
        // Deliberately unconditional: this is the opt-in insecure mode.
        Ok(ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        msg: &[u8],
        signer: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let algorithms = &self.0.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(msg, signer, signed, algorithms)
    }

    fn verify_tls13_signature(
        &self,
        msg: &[u8],
        signer: &CertificateDer<'_>,
        signed: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let algorithms = &self.0.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(msg, signer, signed, algorithms)
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.0.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn tls_config_default_is_strict() {
        let cfg = TlsConfig::default();
        assert!(!cfg.accept_invalid_certs);
        assert!(cfg.ca_cert_pem.is_none());
        assert!(cfg.client_cert_pem.is_none());
        assert!(cfg.client_key_pem.is_none());
    }

    #[test]
    fn build_client_config_no_verify_succeeds() {
        let cfg = TlsConfig {
            accept_invalid_certs: true,
            ..Default::default()
        };
        let result = build_client_config(&cfg);
        assert!(result.is_ok(), "expected Ok, got {:?}", result.err());
    }

    #[test]
    fn build_client_config_empty_root_store_succeeds() {
        let cfg = TlsConfig::default();
        let result = build_client_config(&cfg);
        assert!(result.is_ok(), "expected Ok, got {:?}", result.err());
    }

    /// Non-PEM input is not an error for `parse_certs`: it simply contains no
    /// certificates, so the result is `Ok` with an empty list. (The previous name,
    /// `parse_certs_invalid_pem_returns_error`, contradicted its own assertion.)
    #[test]
    fn parse_certs_non_pem_input_yields_no_certs() {
        let result = parse_certs(b"not a pem file");
        assert!(
            matches!(&result, Ok(certs) if certs.is_empty()),
            "expected Ok(empty), got {:?}",
            result
        );
    }

    #[test]
    fn parse_private_key_missing_returns_error() {
        let result = parse_private_key(b"not a pem file");
        assert!(
            matches!(result, Err(DcmError::TlsError { .. })),
            "expected TlsError, got {:?}",
            result.ok()
        );
    }
}