kdeconnect-proto 0.1.0

A pure Rust modular implementation of the KDE Connect protocol
Documentation
use rustls::{
    self, CertificateError, DigitallySignedStruct, DistinguishedName, Error, SignatureScheme,
    client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier},
    pki_types::{CertificateDer, ServerName, UnixTime},
    server::danger::{ClientCertVerified, ClientCertVerifier},
};
use x509_cert::{
    Certificate,
    der::{Decode, Encode},
};

#[cfg(feature = "std")]
use std::sync::Arc;

#[cfg(not(feature = "std"))]
use alloc::{
    string::{String, ToString},
    sync::Arc,
    vec,
    vec::Vec,
};

use crate::{
    device::{Device, PairState},
    io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
    packet::{NetworkPacket, NetworkPacketBody},
};

/// Signature schemes reported by both `CertVerifier` trait impls through
/// `supported_verify_schemes`; the list is copied out in exactly this order.
///
/// It includes the legacy SHA-1 schemes (`RSA_PKCS1_SHA1`, `ECDSA_SHA1_Legacy`).
/// NOTE(review): presumably kept for compatibility with older peers — confirm they are still
/// required before removing them.
const VERIFY_SCHEMES: &[SignatureScheme] = &[
    SignatureScheme::RSA_PKCS1_SHA1,
    SignatureScheme::ECDSA_SHA1_Legacy,
    SignatureScheme::RSA_PKCS1_SHA256,
    SignatureScheme::ECDSA_NISTP256_SHA256,
    SignatureScheme::RSA_PKCS1_SHA384,
    SignatureScheme::ECDSA_NISTP384_SHA384,
    SignatureScheme::RSA_PKCS1_SHA512,
    SignatureScheme::ECDSA_NISTP521_SHA512,
    SignatureScheme::RSA_PSS_SHA256,
    SignatureScheme::RSA_PSS_SHA384,
    SignatureScheme::RSA_PSS_SHA512,
    SignatureScheme::ED25519,
    SignatureScheme::ED448,
];

/// Returns `true` when `current_time` lies outside the certificate's validity window.
///
/// The window is `notBefore..=notAfter`, compared at whole-second precision, so both
/// boundary instants still count as valid. A window whose `notBefore` is after its
/// `notAfter` is empty, and such a certificate is always reported as expired.
pub(crate) fn is_expired(cert: &Certificate, current_time: UnixTime) -> bool {
    let now = current_time.as_secs();
    let window = &cert.tbs_certificate.validity;
    let starts = window.not_before.to_unix_duration().as_secs();
    let ends = window.not_after.to_unix_duration().as_secs();

    !(starts..=ends).contains(&now)
}

/// Returns the subject Common Name of `cert`, which is used as the peer's device ID.
///
/// The subject RDNs are scanned in order. Each one is rendered with its `Display` form, and
/// the text after `CN=` of the first RDN with that prefix is returned. The result is `None`
/// when no RDN starts with `CN=`.
pub(crate) fn extract_device_id_from_cert(cert: &Certificate) -> Option<String> {
    for rdn in cert.tbs_certificate.subject.0.iter() {
        let rendered = rdn.to_string();
        if let Some(common_name) = rendered.strip_prefix("CN=") {
            return Some(String::from(common_name));
        }
    }
    None
}

/// Finishes the KDE Connect identity exchange on an established TLS stream and registers
/// the peer with `device`.
///
/// * `packet` is the peer's identity packet received earlier in the handshake. The identity
///   the peer sends over TLS must carry the same `device_id` and `protocol_version`.
///   Otherwise the connection is dropped.
/// * `socket` is the TLS stream to the peer. It is handed to `device.on_conn_established`
///   when a new link is created, and dropped otherwise.
/// * `device` is the local device. Its `links` map is updated.
/// * `is_trusted` sets the starting state of a newly created link: `PairState::Paired` when
///   `true`, `PairState::Unpaired` otherwise.
/// * `send_after` controls when our own identity is sent. When `false`, it is sent before
///   reading the peer's identity. When `true`, it is sent only after the peer's identity has
///   been received.
///
/// If a link for the peer's device ID already exists, only its `info` is replaced with
/// `packet`.
///
/// The following are logged and abort the exchange without panicking:
/// * read errors;
/// * EOF before a complete identity packet arrives;
/// * a packet larger than the TLS buffer;
/// * a mismatched identity.
pub(crate) async fn per_tls_stream<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet: crate::packet::identity::IdentityPacket,
    mut socket: TlsStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
    is_trusted: bool,
    send_after: bool,
) {
    if !send_after {
        log::debug!("Sending my encrypted identity packet");
        device
            .get_identity_packet()
            .write_to_socket(&mut socket)
            .await;
    }

    let mut filled = 0;
    let mut buf = [0u8; crate::config::TLS_BUFFER_SIZE];

    // Need to read before writing because it's used to send the handshake and receive the response
    // on embedded devices
    loop {
        // A full buffer with no parseable packet means the peer sent more than we can hold.
        // Stop here rather than issuing a zero-length read that would look like EOF.
        if filled == buf.len() {
            log::warn!("tcp: Identity packet does not fit in the TLS buffer");
            return;
        }

        // The peer controls this stream, so a read error must not panic the task.
        let bytes_read = match socket.read(&mut buf[filled..]).await {
            Ok(n) => n,
            Err(e) => {
                log::warn!("tcp: Failed to read the encrypted identity packet: {e:?}");
                return;
            }
        };
        filled += bytes_read;

        if let Ok(NetworkPacket {
            body: NetworkPacketBody::Identity(encrypted_packet),
            ..
        }) = NetworkPacket::try_read_from(&buf[..filled])
        {
            log::debug!("Received encrypted identity packet: {encrypted_packet:?}");

            if encrypted_packet.device_id != packet.device_id {
                log::warn!("Device ID changed half-way through the handshake");
                return;
            }

            if encrypted_packet.protocol_version != packet.protocol_version {
                log::warn!("Protocol version changed half-way through the handshake");
                return;
            }

            break;
        }

        // EOF without a complete identity packet.
        if bytes_read == 0 {
            log::warn!("tcp: Failed to parse the received JSON");
            return;
        }
    }

    if send_after {
        log::debug!("Sending my encrypted identity packet");
        device
            .get_identity_packet()
            .write_to_socket(&mut socket)
            .await;
    }

    // Refresh an existing link. The guard is confined to this block so it is released before
    // `links` is locked again below and before `on_conn_established` runs. A guard created in
    // an `if let` scrutinee would, under edition-2021 temporary rules, stay alive through the
    // `else` branch and deadlock on the second lock.
    {
        let mut links = device.links.lock().await;
        if let Some(link) = links.get_mut(&packet.device_id) {
            link.info = packet;
            return;
        }
    }

    let device_id = packet.device_id.clone();
    let (tx, rx) = async_channel::bounded(32);
    let pair_state = if is_trusted {
        PairState::Paired
    } else {
        PairState::Unpaired
    };
    let new_link = device.new_link(packet, pair_state, tx);
    device.links.lock().await.insert(device_id.clone(), new_link);

    device.on_conn_established(device_id, socket, rx).await;
}

/// Certificate verifier for both ends of the TLS connection. It implements rustls'
/// [`ClientCertVerifier`] and [`ServerCertVerifier`].
///
/// When `trusted_device_cert` is `Some`, the presented end-entity certificate must:
/// * be within its validity period;
/// * have `trusted_device_id` as its subject CN;
/// * have a DER encoding equal to the trusted certificate's.
///
/// When it is `None`, any certificate is accepted.
#[derive(Debug)]
pub(crate) struct CertVerifier {
    /// Certificate the peer is expected to present; `None` disables all certificate checks.
    pub(crate) trusted_device_cert: Option<Certificate>,
    /// Device ID that the peer certificate's subject CN must equal. It is only checked when
    /// `trusted_device_cert` is `Some`.
    pub(crate) trusted_device_id: String,
}

impl CertVerifier {
    fn verify(&self, cert: &CertificateDer, current_time: UnixTime) -> Result<(), Error> {
        if let Some(ref trusted_cert) = self.trusted_device_cert {
            // Check expiration
            let cert = Certificate::from_der(cert)
                .map_err(|_e| Error::InvalidCertificate(CertificateError::BadEncoding))?;

            if is_expired(&cert, current_time) {
                return Err(Error::InvalidCertificate(CertificateError::Expired));
            }

            // Check common name
            if let Some(common_name) = extract_device_id_from_cert(&cert) {
                if common_name != self.trusted_device_id {
                    return Err(Error::InvalidCertificate(
                        CertificateError::NotValidForNameContext {
                            expected: ServerName::try_from(self.trusted_device_id.clone()).unwrap(),
                            presented: vec![common_name],
                        },
                    ));
                }
            } else {
                return Err(Error::InvalidCertificate(CertificateError::UnknownIssuer));
            }

            // Check certificate contents
            if trusted_cert.to_der().unwrap() != cert.to_der().unwrap() {
                return Err(Error::InvalidCertificate(
                    CertificateError::ApplicationVerificationFailure,
                ));
            }
        } else {
            // No reference certificate to compare against, blindly trust the certificate
        }

        Ok(())
    }
}

impl ClientCertVerifier for CertVerifier {
    /// Sends an empty CA hint list. Client certificates are checked against the trusted
    /// device certificate by `CertVerifier::verify`, not against a CA.
    fn root_hint_subjects(&self) -> &[DistinguishedName] {
        &[]
    }

    /// Accepts the client's end-entity certificate if and only if `CertVerifier::verify`
    /// accepts it. Intermediates are ignored.
    fn verify_client_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        current_time: UnixTime,
    ) -> Result<ClientCertVerified, Error> {
        self.verify(end_entity, current_time)?;
        Ok(ClientCertVerified::assertion())
    }

    // NOTE(review): both signature hooks below accept every handshake signature without
    // checking it against the certificate. Certificates are not secret, so this leaves
    // possession of the peer's private key unproven. Consider delegating to
    // `rustls::crypto::verify_tls12_signature` / `verify_tls13_signature` with the active
    // provider's `signature_verification_algorithms`.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Returns a copy of `VERIFY_SCHEMES`, preserving its order.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        Vec::from(VERIFY_SCHEMES)
    }
}

impl ServerCertVerifier for CertVerifier {
    /// Accepts the server's end-entity certificate if and only if `CertVerifier::verify`
    /// accepts it. Intermediates, the requested server name and any OCSP response are ignored.
    fn verify_server_cert(
        &self,
        end_entity: &CertificateDer,
        _intermediates: &[CertificateDer],
        _server_name: &ServerName,
        _ocsp_response: &[u8],
        current_time: UnixTime,
    ) -> Result<ServerCertVerified, Error> {
        self.verify(end_entity, current_time)
            .map(|()| ServerCertVerified::assertion())
    }

    // NOTE(review): both signature hooks below accept every handshake signature without
    // checking it against the certificate, so the server never proves it holds the private
    // key of the certificate it presented. Consider `rustls::crypto::verify_tls12_signature`
    // / `verify_tls13_signature` with the provider's `signature_verification_algorithms`.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer,
        _dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    /// Returns a copy of `VERIFY_SCHEMES`, preserving its order.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        Vec::from(VERIFY_SCHEMES)
    }
}

/// rustls time source for `embedded` builds that reports a fixed Unix timestamp.
///
/// The stored value is returned unchanged on every call; it does not advance by itself.
#[cfg(feature = "embedded")]
#[derive(Debug)]
pub(crate) struct CustomTimeProvider {
    /// Seconds since the Unix epoch reported to rustls as "now".
    pub(crate) current_time: u64,
}

#[cfg(feature = "embedded")]
impl rustls::time_provider::TimeProvider for CustomTimeProvider {
    /// Always returns `Some`, converting the stored seconds into a `UnixTime`.
    fn current_time(&self) -> Option<UnixTime> {
        let since_epoch = core::time::Duration::from_secs(self.current_time);
        Some(UnixTime::since_unix_epoch(since_epoch))
    }
}