use rustls::{
self, CertificateError, DigitallySignedStruct, DistinguishedName, Error, SignatureScheme,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime, pem::PemObject},
server::danger::{ClientCertVerified, ClientCertVerifier},
};
use x509_cert::{
Certificate,
der::{Decode, DecodePem, Encode},
};
#[cfg(feature = "std")]
use std::sync::Arc;
use core::time::Duration;
#[cfg(not(feature = "std"))]
use alloc::{
string::{String, ToString},
sync::Arc,
vec,
vec::Vec,
};
use crate::{
device::{Device, PairState},
io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
packet::{NetworkPacket, NetworkPacketBody},
};
/// Signature schemes used by the `supported_verify_schemes` implementations of the
/// certificate verifiers in this module.
///
/// NOTE(review): the order presumably expresses preference when the list is offered to the
/// peer — keep it stable. This constant alone does not guarantee that the active crypto
/// provider can verify every entry (e.g. the SHA-1 and Ed448 schemes) — confirm against the
/// provider in use.
const VERIFY_SCHEMES: &[SignatureScheme] = &[
    SignatureScheme::RSA_PKCS1_SHA1,
    SignatureScheme::ECDSA_SHA1_Legacy,
    SignatureScheme::RSA_PKCS1_SHA256,
    SignatureScheme::ECDSA_NISTP256_SHA256,
    SignatureScheme::RSA_PKCS1_SHA384,
    SignatureScheme::ECDSA_NISTP384_SHA384,
    SignatureScheme::RSA_PKCS1_SHA512,
    SignatureScheme::ECDSA_NISTP521_SHA512,
    SignatureScheme::RSA_PSS_SHA256,
    SignatureScheme::RSA_PSS_SHA384,
    SignatureScheme::RSA_PSS_SHA512,
    SignatureScheme::ED25519,
    SignatureScheme::ED448,
];
/// Returns `true` when `current_time` lies outside the certificate's validity window.
///
/// The window is `[not_before, not_after]`, inclusive on both ends, compared with
/// one-second precision. A certificate that is not valid *yet* is therefore also
/// reported as expired.
pub(crate) fn is_expired(cert: &Certificate, current_time: UnixTime) -> bool {
    let validity = &cert.tbs_certificate.validity;
    let now = current_time.as_secs();
    let valid_from = validity.not_before.to_unix_duration().as_secs();
    let valid_until = validity.not_after.to_unix_duration().as_secs();
    !(valid_from..=valid_until).contains(&now)
}
/// Returns the subject Common Name of `cert`, which is used as the device ID.
///
/// Each relative distinguished name of the subject is rendered to text. The value of
/// the first one that starts with `CN=` is returned. Returns `None` when no such entry
/// exists.
pub(crate) fn extract_device_id_from_cert(cert: &Certificate) -> Option<String> {
    for rdn in cert.tbs_certificate.subject.0.iter() {
        let rendered = rdn.to_string();
        if let Some(common_name) = rendered.strip_prefix("CN=") {
            return Some(common_name.to_string());
        }
    }
    None
}
/// Upgrades a plain TCP connection with a peer to TLS, then continues with
/// [`per_tls_stream`].
///
/// * `packet` – identity packet received from the peer before the upgrade.
/// * `socket` – the TCP connection to upgrade.
/// * `device` – the local device (configuration, trust handler, I/O backend).
/// * `is_tls_client` – `true` to act as the TLS client, `false` to act as the server.
///
/// If a certificate is stored for the peer, it is pinned through [`CertVerifier`]. An
/// expired stored certificate un-trusts the device. Every failure is logged with
/// `log::warn!` and aborts the upgrade; the socket is dropped on return.
pub(crate) async fn upgrade_tcp_connection<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet: crate::packet::identity::IdentityPacket,
    socket: TcpStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
    is_tls_client: bool,
) {
    // Load the certificate pinned for this peer, if any. A stored certificate that no
    // longer parses aborts the upgrade instead of silently treating the peer as new.
    let Ok(mut trusted_device_cert) = device
        .trust_handler
        .lock()
        .await
        .get_certificate(&packet.device_id)
        .await
        .map(|f| Certificate::from_pem(f))
        .transpose()
    else {
        log::warn!(
            "Failed to parse the stored certificate of the previously trusted device {}",
            packet.device_id
        );
        return;
    };
    let mut is_trusted = trusted_device_cert.is_some();
    let current_time = device.io_impl.get_current_timestamp().await;
    // An expired pinned certificate can no longer be accepted by the verifier, so the
    // pairing is dropped and the peer is handled as untrusted.
    if let Some(cert) = &trusted_device_cert
        && is_expired(
            cert,
            UnixTime::since_unix_epoch(Duration::from_secs(current_time)),
        )
    {
        log::warn!(
            "SSL certificate of the trusted device `{}` is expired, make the device untrusted",
            packet.device_name.as_ref().unwrap_or(&packet.device_id)
        );
        device
            .trust_handler
            .lock()
            .await
            .untrust_device(&packet.device_id)
            .await;
        trusted_device_cert = None;
        is_trusted = false;
    }
    let Ok(my_cert) = CertificateDer::from_pem_slice(&device.config.cert) else {
        log::warn!(
            "Failed to parse the host certificate stored in the configuration, you need to regenerate it"
        );
        return;
    };
    let Ok(my_private_key) = PrivateKeyDer::from_pem_slice(&device.config.private_key) else {
        log::warn!(
            "Failed to parse the host private key stored in the configuration, you need to regenerate it"
        );
        return;
    };
    let socket_with_tls = if is_tls_client {
        // The device ID comes from the unauthenticated identity packet. An ID that is
        // not a valid TLS server name must abort the upgrade, not panic.
        let Ok(server_name) = ServerName::try_from(packet.device_id.clone()) else {
            log::warn!(
                "The device ID `{}` is not a valid TLS server name",
                packet.device_id
            );
            return;
        };
        #[cfg(feature = "embedded")]
        let config = rustls::ClientConfig::builder_with_details(
            Arc::new(rustls_rustcrypto::provider()),
            Arc::new(CustomTimeProvider { current_time }),
        )
        .with_safe_default_protocol_versions()
        .unwrap();
        #[cfg(not(feature = "embedded"))]
        let config = rustls::ClientConfig::builder();
        // The peer's certificate is checked by `CertVerifier` (pinning), not by a CA chain.
        let config = match config
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(CertVerifier {
                trusted_device_cert,
                trusted_device_id: packet.device_id.clone(),
            }))
            .with_client_auth_cert(vec![my_cert], my_private_key)
        {
            Ok(config) => config,
            Err(e) => {
                log::warn!("Failed to build the TLS client configuration: {e}");
                return;
            }
        };
        log::debug!("Upgrade to TLS connection as client");
        device
            .io_impl
            .connect_client_tls(config, server_name, socket)
            .await
    } else {
        #[cfg(feature = "embedded")]
        let config = rustls::ServerConfig::builder_with_details(
            Arc::new(rustls_rustcrypto::provider()),
            Arc::new(CustomTimeProvider { current_time }),
        )
        .with_safe_default_protocol_versions()
        .unwrap();
        #[cfg(not(feature = "embedded"))]
        let config = rustls::ServerConfig::builder();
        let config = match config
            .with_client_cert_verifier(Arc::new(CertVerifier {
                trusted_device_cert,
                trusted_device_id: packet.device_id.clone(),
            }))
            .with_single_cert(vec![my_cert], my_private_key)
        {
            Ok(config) => config,
            Err(e) => {
                log::warn!("Failed to build the TLS server configuration: {e}");
                return;
            }
        };
        log::debug!("Upgrade to TLS connection as server");
        device.io_impl.accept_server_tls(config, socket).await
    };
    let Ok(socket_with_tls) = socket_with_tls else {
        log::warn!("Failed to upgrade the TCP connection to TLS");
        return;
    };
    log::debug!("TLS connection established successfully");
    per_tls_stream(packet, socket_with_tls, device, is_trusted, is_tls_client).await;
}
/// Exchanges identity packets over an established TLS stream and registers the link.
///
/// * `packet` – identity packet received from the peer before the TLS upgrade.
/// * `socket` – the TLS stream to the peer.
/// * `device` – the local device.
/// * `is_trusted` – whether a newly created link starts as [`PairState::Paired`]
///   (otherwise [`PairState::Unpaired`]).
/// * `send_after` – send our own identity after (`true`) or before (`false`) reading
///   the peer's.
///
/// The identity the peer sends over TLS must carry the same device ID and protocol
/// version as `packet`; otherwise the connection is dropped. Read errors, end of
/// stream, or an identity larger than `crate::config::TLS_BUFFER_SIZE` also drop the
/// connection.
///
/// If a link for the device already exists, only its `info` is replaced. Otherwise a
/// new link is inserted and [`Device::on_conn_established`] takes over the stream.
pub(crate) async fn per_tls_stream<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet: crate::packet::identity::IdentityPacket,
    mut socket: TlsStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
    is_trusted: bool,
    send_after: bool,
) {
    if !send_after {
        log::debug!("Sending my encrypted identity packet");
        device
            .get_identity_packet()
            .write_to_socket(&mut socket)
            .await;
    }
    // Accumulate bytes in `buf[..i]` until they parse as a complete identity packet.
    let mut i = 0;
    let mut buf = [0u8; crate::config::TLS_BUFFER_SIZE];
    loop {
        // Give up on a read error instead of retrying: a stream that keeps failing
        // would otherwise be polled in a tight loop forever.
        let Ok(bytes_read) = socket.read(&mut buf[i..]).await else {
            log::warn!("tls: Failed to read the encrypted identity packet");
            return;
        };
        i += bytes_read;
        if let Ok(NetworkPacket {
            body: NetworkPacketBody::Identity(encrypted_packet),
            ..
        }) = NetworkPacket::try_read_from(&buf[..i])
        {
            log::debug!("Received encrypted identity packet: {encrypted_packet:?}");
            // The identity sent over TLS must describe the same device as the one
            // received before the upgrade.
            if encrypted_packet.device_id != packet.device_id {
                log::warn!("Device ID changed half-way through the handshake");
                return;
            }
            if encrypted_packet.protocol_version != packet.protocol_version {
                log::warn!("Protocol version changed half-way through the handshake");
                return;
            }
            break;
        }
        // End of stream before a complete packet arrived.
        if bytes_read == 0 {
            log::warn!("tls: Failed to parse the received JSON");
            return;
        }
        // No room left for more data: the packet can never be completed.
        if i == buf.len() {
            log::warn!("tls: The encrypted identity packet does not fit in the receive buffer");
            return;
        }
    }
    if send_after {
        log::debug!("Sending my encrypted identity packet");
        device
            .get_identity_packet()
            .write_to_socket(&mut socket)
            .await;
    }
    // NOTE(review): the lookup below and the insert further down take the `links` lock
    // separately, so two concurrent connections from the same device could both reach
    // the insert — confirm whether that needs guarding.
    if let Some(e) = device.links.lock().await.get_mut(&packet.device_id) {
        e.info = packet;
    } else {
        let device_id = packet.device_id.clone();
        let (tx, rx) = async_channel::bounded(32);
        let new_link = device.new_link(
            packet,
            if is_trusted {
                PairState::Paired
            } else {
                PairState::Unpaired
            },
            tx,
        );
        device
            .links
            .lock()
            .await
            .insert(device_id.clone(), new_link);
        device.on_conn_established(device_id, socket, rx).await;
    }
}
/// Certificate verifier used on both sides of the device-to-device TLS handshake.
///
/// Implements both `ClientCertVerifier` (when this device is the TLS server) and
/// `ServerCertVerifier` (when this device is the TLS client). Trust is established by
/// pinning a previously stored certificate rather than by a CA chain.
#[derive(Debug)]
pub(crate) struct CertVerifier {
    /// Certificate stored for the peer from an earlier pairing. When `None`, `verify`
    /// accepts any certificate.
    pub(crate) trusted_device_cert: Option<Certificate>,
    /// Device ID the peer's certificate Common Name must equal when a certificate is
    /// pinned. `upgrade_tcp_connection` sets it from the peer's identity packet.
    pub(crate) trusted_device_id: String,
}
impl CertVerifier {
fn verify(&self, cert: &CertificateDer, current_time: UnixTime) -> Result<(), Error> {
if let Some(ref trusted_cert) = self.trusted_device_cert {
let cert = Certificate::from_der(cert)
.map_err(|_e| Error::InvalidCertificate(CertificateError::BadEncoding))?;
if is_expired(&cert, current_time) {
return Err(Error::InvalidCertificate(CertificateError::Expired));
}
if let Some(common_name) = extract_device_id_from_cert(&cert) {
if common_name != self.trusted_device_id {
return Err(Error::InvalidCertificate(
CertificateError::NotValidForNameContext {
expected: ServerName::try_from(self.trusted_device_id.clone()).unwrap(),
presented: vec![common_name],
},
));
}
} else {
return Err(Error::InvalidCertificate(CertificateError::UnknownIssuer));
}
if trusted_cert.to_der().unwrap() != cert.to_der().unwrap() {
return Err(Error::InvalidCertificate(
CertificateError::ApplicationVerificationFailure,
));
}
} else {
}
Ok(())
}
}
impl ClientCertVerifier for CertVerifier {
fn root_hint_subjects(&self) -> &[DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
current_time: UnixTime,
) -> Result<ClientCertVerified, Error> {
self.verify(end_entity, current_time)
.map(|_e| ClientCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
VERIFY_SCHEMES.to_vec()
}
}
impl ServerCertVerifier for CertVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer,
_intermediates: &[CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
current_time: UnixTime,
) -> Result<ServerCertVerified, Error> {
self.verify(end_entity, current_time)
.map(|_e| ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
VERIFY_SCHEMES.to_vec()
}
}
/// rustls time source for `embedded` builds that always reports one fixed timestamp.
///
/// `upgrade_tcp_connection` creates it from the timestamp it fetched from the I/O
/// backend, so the reported time does not advance while the provider is in use.
#[cfg(feature = "embedded")]
#[derive(Debug)]
pub(crate) struct CustomTimeProvider {
    /// Whole seconds since the Unix epoch.
    pub(crate) current_time: u64,
}
#[cfg(feature = "embedded")]
impl rustls::time_provider::TimeProvider for CustomTimeProvider {
    /// Always reports the stored timestamp; it never fails.
    fn current_time(&self) -> Option<UnixTime> {
        let since_epoch = Duration::from_secs(self.current_time);
        let now = UnixTime::since_unix_epoch(since_epoch);
        Some(now)
    }
}