use rustls::{
self, CertificateError, DigitallySignedStruct, DistinguishedName, Error, SignatureScheme,
client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
pki_types::{CertificateDer, ServerName, UnixTime},
server::danger::{ClientCertVerified, ClientCertVerifier},
};
use x509_cert::{
Certificate,
der::{Decode, Encode},
};
#[cfg(feature = "std")]
use std::sync::Arc;
#[cfg(not(feature = "std"))]
use alloc::{
string::{String, ToString},
sync::Arc,
vec,
vec::Vec,
};
use crate::{
device::{Device, PairState},
io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
packet::{NetworkPacket, NetworkPacketBody},
};
/// Handshake signature schemes that `CertVerifier` works with; its `supported_verify_schemes`
/// implementations (client and server side) build their answer from this list.
///
/// The list includes the legacy SHA-1 schemes (`RSA_PKCS1_SHA1`, `ECDSA_SHA1_Legacy`) and
/// `ED448`. NOTE(review): presumably these are kept for compatibility with older peers — confirm
/// that the active crypto provider can actually verify every scheme listed here.
const VERIFY_SCHEMES: &[SignatureScheme] = &[
// RSA PKCS#1 v1.5 and ECDSA, weakest hash first.
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
// RSA-PSS.
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
// EdDSA.
SignatureScheme::ED25519,
SignatureScheme::ED448,
];
/// Returns `true` when `current_time` lies outside `cert`'s validity window.
///
/// Both `notBefore` and `notAfter` are compared at whole-second precision and are inclusive, so
/// a certificate is still valid at exactly either bound. The function answers `true` both for
/// certificates that have expired and for certificates that are not valid yet.
pub(crate) fn is_expired(cert: &Certificate, current_time: UnixTime) -> bool {
    let window = &cert.tbs_certificate.validity;
    let first_valid_second = window.not_before.to_unix_duration().as_secs();
    let last_valid_second = window.not_after.to_unix_duration().as_secs();
    let now = current_time.as_secs();
    !(first_valid_second..=last_valid_second).contains(&now)
}
/// Returns the value of the first subject RDN that renders as `CN=<value>`.
///
/// Each relative distinguished name of the subject is rendered with its `Display`
/// implementation, in order. The text after the `CN=` prefix of the first match is returned.
/// Returns `None` when no RDN starts with `CN=`.
pub(crate) fn extract_device_id_from_cert(cert: &Certificate) -> Option<String> {
    for rdn in cert.tbs_certificate.subject.0.iter() {
        let rendered = rdn.to_string();
        if let Some(common_name) = rendered.strip_prefix("CN=") {
            return Some(common_name.to_string());
        }
    }
    None
}
/// Runs the encrypted identity exchange on a freshly established TLS stream and registers the
/// resulting link on `device`.
///
/// `packet` is the identity the peer announced before TLS was set up. Both sides send their
/// identity again over the encrypted stream. Ours is written before reading the peer's when
/// `send_after` is `false`, and after it otherwise. The peer's encrypted identity must agree with
/// `packet` on `device_id` and `protocol_version`; if it does not, the connection is abandoned.
///
/// If `device.links` already holds an entry for this device ID, only its `info` is replaced with
/// `packet` and this socket is not used further. Otherwise a new link is created with
/// `PairState::Paired` when `is_trusted` is set, and `PairState::Unpaired` otherwise. The link is
/// inserted into `device.links` and handed to `Device::on_conn_established` together with the
/// socket.
///
/// A read error, a stream that closes before a full identity packet arrives, or an identity
/// packet larger than `TLS_BUFFER_SIZE` ends the handshake with a warning instead of a panic.
pub(crate) async fn per_tls_stream<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet: crate::packet::identity::IdentityPacket,
    mut socket: TlsStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
    is_trusted: bool,
    send_after: bool,
) {
    if !send_after {
        log::debug!("Sending my encrypted identity packet");
        device
            .get_identity_packet()
            .write_to_socket(&mut socket)
            .await;
    }

    // Accumulate bytes until they parse as a complete identity packet.
    let mut filled = 0;
    let mut buf = [0u8; crate::config::TLS_BUFFER_SIZE];
    loop {
        if filled == buf.len() {
            // No room left: a further read could only return 0 bytes, or block on some
            // transports, so stop explicitly.
            log::warn!("tls: Identity packet does not fit in the receive buffer");
            return;
        }
        // The peer is untrusted and the network can fail: never panic here.
        let bytes_read = match socket.read(&mut buf[filled..]).await {
            Ok(n) => n,
            Err(e) => {
                log::warn!("tls: Failed to read the encrypted identity packet: {e:?}");
                return;
            }
        };
        filled += bytes_read;

        if let Ok(NetworkPacket {
            body: NetworkPacketBody::Identity(encrypted_packet),
            ..
        }) = NetworkPacket::try_read_from(&buf[..filled])
        {
            log::debug!("Received encrypted identity packet: {encrypted_packet:?}");
            if encrypted_packet.device_id != packet.device_id {
                log::warn!("Device ID changed half-way through the handshake");
                return;
            }
            if encrypted_packet.protocol_version != packet.protocol_version {
                log::warn!("Protocol version changed half-way through the handshake");
                return;
            }
            break;
        }
        // A 0-byte read means the stream closed; the data we have will never parse.
        if bytes_read == 0 {
            log::warn!("tls: Failed to parse the received JSON");
            return;
        }
    }

    if send_after {
        log::debug!("Sending my encrypted identity packet");
        device
            .get_identity_packet()
            .write_to_socket(&mut socket)
            .await;
    }

    // Known device: just refresh its identity. The guard is scoped to this block on purpose.
    // Under Rust 2021 rules, a guard created in an `if let` scrutinee stays alive through the
    // `else` branch. That branch locks `links` again and awaits `on_conn_established`, which can
    // deadlock with a non-reentrant mutex.
    {
        let mut links = device.links.lock().await;
        if let Some(link) = links.get_mut(&packet.device_id) {
            link.info = packet;
            return;
        }
    }

    // New device. NOTE(review): another connection for the same device could insert between
    // the lookup above and the insert below; this insert would then replace that link.
    let device_id = packet.device_id.clone();
    let (tx, rx) = async_channel::bounded(32);
    let pair_state = if is_trusted {
        PairState::Paired
    } else {
        PairState::Unpaired
    };
    let new_link = device.new_link(packet, pair_state, tx);
    // Temporary guard: released at the end of this statement, before the await below.
    device
        .links
        .lock()
        .await
        .insert(device_id.clone(), new_link);
    device.on_conn_established(device_id, socket, rx).await;
}
/// rustls certificate verifier for both TLS roles: it implements `ClientCertVerifier` and
/// `ServerCertVerifier`.
///
/// Peers are authenticated by pinning a known certificate rather than by chaining to a CA
/// (see `CertVerifier::verify`).
#[derive(Debug)]
pub(crate) struct CertVerifier {
/// Certificate pinned for the peer. When `Some`, the presented certificate must match it. When
/// `None`, any presented certificate is accepted.
pub(crate) trusted_device_cert: Option<Certificate>,
/// Device ID that the presented certificate's subject `CN` must equal. It is only consulted
/// when `trusted_device_cert` is `Some`.
pub(crate) trusted_device_id: String,
}
impl CertVerifier {
fn verify(&self, cert: &CertificateDer, current_time: UnixTime) -> Result<(), Error> {
if let Some(ref trusted_cert) = self.trusted_device_cert {
let cert = Certificate::from_der(cert)
.map_err(|_e| Error::InvalidCertificate(CertificateError::BadEncoding))?;
if is_expired(&cert, current_time) {
return Err(Error::InvalidCertificate(CertificateError::Expired));
}
if let Some(common_name) = extract_device_id_from_cert(&cert) {
if common_name != self.trusted_device_id {
return Err(Error::InvalidCertificate(
CertificateError::NotValidForNameContext {
expected: ServerName::try_from(self.trusted_device_id.clone()).unwrap(),
presented: vec![common_name],
},
));
}
} else {
return Err(Error::InvalidCertificate(CertificateError::UnknownIssuer));
}
if trusted_cert.to_der().unwrap() != cert.to_der().unwrap() {
return Err(Error::InvalidCertificate(
CertificateError::ApplicationVerificationFailure,
));
}
} else {
}
Ok(())
}
}
impl ClientCertVerifier for CertVerifier {
fn root_hint_subjects(&self) -> &[DistinguishedName] {
&[]
}
fn verify_client_cert(
&self,
end_entity: &CertificateDer<'_>,
_intermediates: &[CertificateDer<'_>],
current_time: UnixTime,
) -> Result<ClientCertVerified, Error> {
self.verify(end_entity, current_time)
.map(|_e| ClientCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
VERIFY_SCHEMES.to_vec()
}
}
impl ServerCertVerifier for CertVerifier {
fn verify_server_cert(
&self,
end_entity: &CertificateDer,
_intermediates: &[CertificateDer],
_server_name: &ServerName,
_ocsp_response: &[u8],
current_time: UnixTime,
) -> Result<ServerCertVerified, Error> {
self.verify(end_entity, current_time)
.map(|_e| ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &CertificateDer,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
VERIFY_SCHEMES.to_vec()
}
}
/// Fixed-clock rustls time source, compiled only with the `embedded` feature.
///
/// It always reports the instant stored in `current_time`; nothing in this type advances it.
#[cfg(feature = "embedded")]
#[derive(Debug)]
pub(crate) struct CustomTimeProvider {
    /// Seconds since the Unix epoch reported as "now".
    pub(crate) current_time: u64,
}

#[cfg(feature = "embedded")]
impl rustls::time_provider::TimeProvider for CustomTimeProvider {
    /// Returns `current_time` converted to a `UnixTime`; never `None`.
    fn current_time(&self) -> Option<UnixTime> {
        let since_epoch = core::time::Duration::from_secs(self.current_time);
        Some(UnixTime::since_unix_epoch(since_epoch))
    }
}