use std::sync::Arc;
use std::thread::available_parallelism;
use std::time::Duration;
use std::{fmt::Display, str::FromStr};
use log::{debug, error, info, trace};
use nethsm_sdk_rs::ureq::{Agent, AgentBuilder};
use rustls::client::{
ClientConfig,
danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
};
use rustls::crypto::{CryptoProvider, ring as tls_provider};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{DigitallySignedStruct, SignatureScheme};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use crate::Error;
#[cfg(doc)]
use crate::NetHsm;
/// Fallback maximum number of idle connections, used by `create_agent` only when no explicit value
/// is passed and the available parallelism of the system cannot be determined.
pub const DEFAULT_MAX_IDLE_CONNECTIONS: usize = 100;
/// Fallback connect timeout in seconds, used by `create_agent` when no explicit value is passed.
pub const DEFAULT_TIMEOUT_SECONDS: u64 = 10;
/// The raw bytes of a certificate digest (for example a SHA-256 fingerprint).
///
/// With serde the bytes are (de)serialized as a hex string. The [`Display`] implementation renders
/// them as lowercase, zero-padded hex digits without any separators.
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct CertFingerprint(
    #[serde(
        deserialize_with = "hex::serde::deserialize",
        serialize_with = "hex::serde::serialize"
    )]
    Vec<u8>,
);

impl Display for CertFingerprint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Two lowercase hex digits per byte, stopping at the first formatter error.
        self.0
            .iter()
            .try_for_each(|octet| write!(f, "{octet:02x}"))
    }
}
impl FromStr for CertFingerprint {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.as_bytes().to_vec()))
}
}
impl From<Vec<u8>> for CertFingerprint {
fn from(value: Vec<u8>) -> Self {
Self(value)
}
}
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub struct HostCertificateFingerprints {
sha256: Option<Vec<CertFingerprint>>,
}
impl Display for HostCertificateFingerprints {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{}",
if let Some(fingerprints) = self.sha256.as_ref() {
if fingerprints.is_empty() {
"n/a".to_string()
} else {
fingerprints
.iter()
.map(|fingerprint| format!("sha256:{fingerprint}"))
.collect::<Vec<String>>()
.join("\n")
}
} else {
"n/a".to_string()
}
)
}
}
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
pub enum ConnectionSecurity {
Unsafe,
Native,
Fingerprints(HostCertificateFingerprints),
}
impl Display for ConnectionSecurity {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Unsafe => write!(f, "unsafe"),
Self::Native => write!(f, "native"),
Self::Fingerprints(fingerprints) => write!(f, "{fingerprints}"),
}
}
}
impl FromStr for ConnectionSecurity {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"unsafe" | "Unsafe" => Ok(Self::Unsafe),
"native" | "Native" => Ok(Self::Native),
_ => {
let sha256_fingerprints: Vec<Vec<u8>> = s
.split(',')
.filter_map(|checksum| {
checksum
.strip_prefix("sha256:")
.filter(|x| x.len() == 64 && x.chars().all(|x| x.is_ascii_hexdigit()))
.map(|checksum| checksum.as_bytes().to_vec())
})
.collect();
if sha256_fingerprints.is_empty() {
Err(Error::Default(
"No valid TLS certificate fingerprints detected.".to_string(),
))
} else {
Ok(Self::Fingerprints(HostCertificateFingerprints {
sha256: Some(
sha256_fingerprints
.iter()
.map(|checksum| checksum.clone().into())
.collect(),
),
}))
}
}
}
}
}
/// A [`ServerCertVerifier`] that accepts every server certificate.
///
/// None of the end-entity certificate, intermediates, server name, OCSP response or current time is
/// inspected, so this gives no protection against an impostor server. Handshake signatures are
/// still checked using the signature verification algorithms of the wrapped [`CryptoProvider`].
#[derive(Debug)]
pub struct DangerIgnoreVerifier(pub CryptoProvider);

impl ServerCertVerifier for DangerIgnoreVerifier {
    /// Unconditionally reports the server certificate as verified.
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    /// Checks a TLS 1.2 handshake signature with the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let Self(provider) = self;
        let algorithms = &provider.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let Self(provider) = self;
        let algorithms = &provider.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Lists the signature schemes supported by the wrapped provider.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let Self(provider) = self;
        provider.signature_verification_algorithms.supported_schemes()
    }
}
/// A [`ServerCertVerifier`] that pins the server's end-entity certificate by its SHA-256 digest.
///
/// The SHA-256 digest of the DER-encoded end-entity certificate must equal one of the configured
/// fingerprints. Intermediates, server name, OCSP response and current time are not inspected.
/// Handshake signatures are checked with the signature verification algorithms of `provider`.
#[derive(Debug)]
pub struct FingerprintVerifier {
    /// The fingerprints an end-entity certificate is matched against.
    pub fingerprints: HostCertificateFingerprints,
    /// Supplies the algorithms used to verify handshake signatures.
    pub provider: CryptoProvider,
}

impl ServerCertVerifier for FingerprintVerifier {
    /// Accepts the certificate only if its SHA-256 digest matches a pinned fingerprint.
    ///
    /// # Errors
    ///
    /// Returns [`rustls::Error::General`] if no SHA-256 fingerprints are configured, or if none of
    /// them matches.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &ServerName<'_>,
        _ocsp_response: &[u8],
        _now: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        let Some(pinned) = self.fingerprints.sha256.as_ref() else {
            return Err(rustls::Error::General(
                "Could not verify certificate fingerprint as no fingerprints were provided to match against".to_string(),
            ));
        };

        let digest = Sha256::digest(end_entity.as_ref());
        if pinned.iter().any(|candidate| candidate.0 == digest[..]) {
            trace!("Certificate fingerprint matches");
            Ok(ServerCertVerified::assertion())
        } else {
            Err(rustls::Error::General(
                "Could not verify certificate fingerprint".to_string(),
            ))
        }
    }

    /// Checks a TLS 1.2 handshake signature with the provider's algorithms.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.provider.signature_verification_algorithms;
        rustls::crypto::verify_tls12_signature(message, cert, dss, algorithms)
    }

    /// Checks a TLS 1.3 handshake signature with the provider's algorithms.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let algorithms = &self.provider.signature_verification_algorithms;
        rustls::crypto::verify_tls13_signature(message, cert, dss, algorithms)
    }

    /// Lists the signature schemes supported by the configured provider.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let algorithms = &self.provider.signature_verification_algorithms;
        algorithms.supported_schemes()
    }
}
pub(crate) fn create_agent(
tls_security: ConnectionSecurity,
max_idle_connections: Option<usize>,
timeout_seconds: Option<u64>,
) -> Result<Agent, Error> {
let tls_conf = {
let tls_conf = ClientConfig::builder_with_provider(Arc::new(CryptoProvider {
cipher_suites: tls_provider::ALL_CIPHER_SUITES.into(),
..tls_provider::default_provider()
}))
.with_protocol_versions(rustls::DEFAULT_VERSIONS)?;
match tls_security {
ConnectionSecurity::Unsafe => {
let dangerous = tls_conf.dangerous();
dangerous
.with_custom_certificate_verifier(Arc::new(DangerIgnoreVerifier(
tls_provider::default_provider(),
)))
.with_no_client_auth()
}
ConnectionSecurity::Native => {
let native_certs = rustls_native_certs::load_native_certs();
if !native_certs.errors.is_empty() {
return Err(Error::CertLoading(native_certs.errors));
}
let native_certs = native_certs.certs;
let roots = {
let mut roots = rustls::RootCertStore::empty();
let (added, failed) = roots.add_parsable_certificates(native_certs);
debug!("Added {added} certificates and failed to parse {failed} certificates");
if added == 0 {
error!("Added no native certificates");
return Err(Error::NoSystemCertsAdded { failed });
}
roots
};
tls_conf.with_root_certificates(roots).with_no_client_auth()
}
ConnectionSecurity::Fingerprints(fingerprints) => {
let dangerous = tls_conf.dangerous();
dangerous
.with_custom_certificate_verifier(Arc::new(FingerprintVerifier {
fingerprints,
provider: tls_provider::default_provider(),
}))
.with_no_client_auth()
}
}
};
let max_idle_connections = max_idle_connections
.or_else(|| available_parallelism().ok().map(Into::into))
.unwrap_or(DEFAULT_MAX_IDLE_CONNECTIONS);
let timeout_seconds = timeout_seconds.unwrap_or(DEFAULT_TIMEOUT_SECONDS);
info!(
"NetHSM connection configured with \"max_idle_connection\" {max_idle_connections} and \"timeout_seconds\" {timeout_seconds}."
);
Ok(AgentBuilder::new()
.tls_config(Arc::new(tls_conf))
.max_idle_connections(max_idle_connections)
.max_idle_connections_per_host(max_idle_connections)
.timeout_connect(Duration::from_secs(timeout_seconds))
.build())
}
// Unit tests for the Display/FromStr behavior of the fingerprint and connection security types.
#[cfg(test)]
mod tests {
use rstest::rstest;
use testresult::TestResult;
use super::*;
// Display renders the raw digest bytes as lowercase, zero-padded hex without separators.
#[test]
fn certfingerprint_display() -> TestResult {
// 32 raw bytes (the size of a SHA-256 digest).
let digest = vec![
181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
];
let cert_fingerprint = CertFingerprint::from(digest);
assert_eq!(
cert_fingerprint.to_string(),
"b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c"
);
Ok(())
}
// HostCertificateFingerprints prefixes each digest with "sha256:" and renders "n/a" when the
// list is empty or absent.
#[rstest]
#[case(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from(vec![
181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
])]) }, "sha256:b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c")]
#[case(HostCertificateFingerprints { sha256: Some(Vec::new()) }, "n/a")]
#[case(HostCertificateFingerprints { sha256: None }, "n/a")]
fn hostcertfingerprints_display(
#[case] fingerprints: HostCertificateFingerprints,
#[case] expected: &str,
) -> TestResult {
assert_eq!(fingerprints.to_string(), expected);
Ok(())
}
// ConnectionSecurity renders the keyword variants in lowercase and delegates the Fingerprints
// variant to HostCertificateFingerprints.
#[rstest]
#[case(ConnectionSecurity::Native, "native")]
#[case(ConnectionSecurity::Unsafe, "unsafe")]
#[case(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from(vec![
181, 187, 157, 128, 20, 160, 249, 177, 214, 30, 33, 231, 150, 215, 141, 204, 223, 19,
82, 242, 60, 211, 40, 18, 244, 133, 11, 135, 138, 228, 148, 76,
])]) }), "sha256:b5bb9d8014a0f9b1d61e21e796d78dccdf1352f23cd32812f4850b878ae4944c")]
fn connectionsecurity_display(
#[case] connection_security: ConnectionSecurity,
#[case] expected: &str,
) -> TestResult {
assert_eq!(connection_security.to_string(), expected);
Ok(())
}
// Parsing: `Some(expected)` must parse to exactly that value; `None` must be rejected.
// The fingerprint case builds its expectation with CertFingerprint::from_str on the bare hex
// digest, so it checks that both parsers produce the same bytes. The rejected cases cover a
// missing "sha256:" prefix, a digest that is too short, and one that is too long.
#[rstest]
#[case("native", Some(ConnectionSecurity::Native))]
#[case("unsafe", Some(ConnectionSecurity::Unsafe))]
#[case("sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7", Some(ConnectionSecurity::Fingerprints(HostCertificateFingerprints { sha256: Some(vec![CertFingerprint::from_str("324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7")?]) })))]
#[case(
"324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b7",
None
)]
#[case(
"sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e",
None
)]
#[case(
"sha256:324f7bd1530c55cf6812ca6865445de21dfc74cf7a3bb5fae7585e849e3553b73553b7",
None
)]
fn connection_security_fromstr(
#[case] input: &str,
#[case] expected: Option<ConnectionSecurity>,
) -> TestResult {
if let Some(expected) = expected {
assert_eq!(ConnectionSecurity::from_str(input)?, expected);
} else {
assert!(ConnectionSecurity::from_str(input).is_err());
}
Ok(())
}
}