kdeconnect-proto 0.1.0

A pure Rust modular implementation of the KDE Connect protocol
Documentation
//! TCP connection response implementation.
use core::{
    net::{IpAddr, Ipv6Addr, SocketAddr},
    time::Duration,
};

#[cfg(feature = "std")]
use std::sync::Arc;

#[cfg(not(feature = "std"))]
use alloc::{sync::Arc, vec};

use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime, pem::PemObject};
use x509_cert::{Certificate, der::DecodePem};

use crate::{
    device::Device, io::{IoImpl, KnownFunctionName, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl}, packet::{NetworkPacket, NetworkPacketBody}
};

/// Start a TCP listener on a port in the range 1716-1764.
///
/// As a library user, you should ignore this function as it's only useful to develop other IO
/// backends.
///
/// The first port of `MIN_TCP_PORT..=MAX_TCP_PORT` that can be bound on the IPv6 unspecified
/// address is used and stored in `device.my_tcp_port`, then the discovery mechanisms are
/// started. If no port of the range can be bound, an error is logged and the function returns
/// without starting discovery.
///
/// If a TCP connection is made with this device, it's upgraded to a TLS connection and used to
/// send application packets. Every accepted stream is handed over to the IO backend as a
/// `KnownFunctionName::PerTcpStream` task; this function then keeps accepting connections and
/// never returns on its own.
pub async fn setup_tcp<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    // Try every port of the allowed range in order and keep the first one which can be bound.
    let mut bound = None;
    for port in crate::config::MIN_TCP_PORT..=crate::config::MAX_TCP_PORT {
        if let Ok(listener) = device
            .io_impl
            .listen_tcp(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port))
            .await
        {
            bound = Some((listener, port));
            break;
        }
    }

    // `mut` is only required with the `embedded` feature, where the listener is re-created
    // after each accepted connection.
    #[allow(unused_mut)]
    let Some((mut tcp_listener, tcp_port)) = bound else {
        log::error!(
            "No port available in {}..{}",
            crate::config::MIN_TCP_PORT,
            crate::config::MAX_TCP_PORT
        );
        return;
    };

    // The result of storing the port is deliberately ignored.
    let _ = device.my_tcp_port.set(tcp_port).await;

    // The TCP server is initialized, launch discovery mechanisms
    super::start_discovering(Arc::clone(&device));

    loop {
        let Ok(socket) = tcp_listener.accept().await else {
            #[cfg(feature = "embedded")]
            device.io_impl.sleep(Duration::from_millis(500)).await;
            continue;
        };

        // A task is spawned per TCP-stream which will be further converted to a TLS stream
        device
            .io_impl
            .spawn(KnownFunctionName::PerTcpStream(socket), Arc::clone(&device));

        #[cfg(feature = "embedded")]
        {
            // Listen again on the same port. A failure here must not panic the accept loop:
            // log it and retry after a short pause until the port can be bound again.
            tcp_listener = loop {
                match device
                    .io_impl
                    .listen_tcp(SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), tcp_port))
                    .await
                {
                    Ok(listener) => break listener,
                    Err(err) => {
                        log::error!(
                            "Failed to listen again on TCP port {}: {:?}",
                            tcp_port,
                            err
                        );
                        device.io_impl.sleep(Duration::from_millis(500)).await;
                    }
                }
            };
        }
    }
}

/// Function which should be called when a new TCP stream is made.
///
/// As a library user, you should ignore this function as it's only useful to develop other IO
/// backends.
///
/// It reads from the stream until an identity packet can be parsed, then upgrades the
/// connection to a TLS connection and uses it to send application packets.
///
/// The stream is dropped without being upgraded when reading fails, when a read returns zero
/// bytes, or when the receive buffer (`TCP_BUFFER_SIZE` bytes) is full before an identity
/// packet could be parsed.
pub async fn per_tcp_stream<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    mut socket: TcpStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    // Number of bytes of `buf` which are filled with received data.
    let mut filled = 0;
    let mut buf = [0; crate::config::TCP_BUFFER_SIZE];

    loop {
        // A read error is caused by the remote peer or the network, so it must not panic.
        let bytes_read = match socket.read(&mut buf[filled..]).await {
            Ok(bytes_read) => bytes_read,
            Err(err) => {
                log::warn!("tcp: Failed to read from the stream: {:?}", err);
                break;
            }
        };
        filled += bytes_read;

        // The packet may arrive in several chunks: retry parsing everything received so far.
        if let Ok(NetworkPacket {
            body: NetworkPacketBody::Identity(packet),
            ..
        }) = NetworkPacket::try_read_from(&buf[..filled])
        {
            log::debug!("TCP Identity packet received, upgrading connection");
            upgrade_tcp_connection(packet, socket, device).await;
            break;
        }

        // No more data was received and what we have isn't an identity packet.
        if bytes_read == 0 {
            log::warn!("tcp: Failed to parse the received JSON");
            break;
        }

        // No room left for further data: give up instead of reading into an empty slice.
        if filled >= buf.len() {
            log::warn!("tcp: Receive buffer is full but no identity packet could be parsed");
            break;
        }
    }
}

/// Upgrade a TCP stream to a TLS stream after an identity packet was received on it.
///
/// This device acts as the TLS client. Nothing is done if a link already exists for
/// `packet.device_id`. If a certificate is stored for the remote device and is expired, the
/// device is untrusted before connecting. On success, the TLS stream is handed over to
/// `super::tls::per_tls_stream`.
///
/// The connection is dropped (with a log message) when the received device ID can't be used
/// as a TLS server name or when the TLS connection can't be established.
///
/// # Panics
///
/// Panics if the certificate stored for the remote device can't be parsed as PEM, if the
/// certificate or private key of `device.config` can't be parsed as PEM, or if building the
/// TLS client configuration from them fails.
async fn upgrade_tcp_connection<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet: crate::packet::identity::IdentityPacket,
    socket: TcpStream,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    if device.links.lock().await.contains_key(&packet.device_id) {
        log::debug!(
            "Device {} has already established connection, ignore the MDNS request",
            packet.device_id
        );
        return;
    }

    // I'm the SSL client
    let mut trusted_device_cert = device
        .trust_handler
        .lock()
        .await
        .get_certificate(&packet.device_id)
        .await
        .map(|f| Certificate::from_pem(f).unwrap());
    let current_time = device.io_impl.get_current_timestamp().await;

    // Check expiration of the trusted certificate
    if let Some(cert) = &trusted_device_cert
        && super::tls::is_expired(
            cert,
            UnixTime::since_unix_epoch(Duration::from_secs(current_time)),
        )
    {
        log::warn!(
            "SSL certificate of the trusted device `{}` is expired, make the device untrusted",
            packet.device_name.as_ref().unwrap_or(&packet.device_id)
        );
        device
            .trust_handler
            .lock()
            .await
            .untrust_device(&packet.device_id)
            .await;
        trusted_device_cert = None;
    }

    // The device is trusted only if a (non-expired) certificate is still known for it.
    let is_trusted = trusted_device_cert.is_some();

    // The device ID comes from the remote peer before any authentication: reject it instead
    // of panicking when it isn't a valid TLS server name.
    let Ok(server_name) = ServerName::try_from(packet.device_id.clone()) else {
        log::warn!(
            "Device ID `{}` can't be used as a TLS server name, drop the connection",
            packet.device_id
        );
        return;
    };

    let certs = vec![CertificateDer::from_pem_slice(&device.config.cert).unwrap()];
    let key = PrivateKeyDer::from_pem_slice(&device.config.private_key).unwrap();

    #[cfg(feature = "embedded")]
    let config = rustls::ClientConfig::builder_with_details(
        Arc::new(rustls_rustcrypto::provider()),
        Arc::new(super::tls::CustomTimeProvider { current_time }),
    )
    .with_safe_default_protocol_versions()
    .unwrap();

    #[cfg(not(feature = "embedded"))]
    let config = rustls::ClientConfig::builder();

    // Certificate validation is delegated to the crate's own verifier, which receives the
    // certificate stored for this device (if any) and the announced device ID.
    let config = config
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(super::tls::CertVerifier {
            trusted_device_cert,
            trusted_device_id: packet.device_id.clone(),
        }))
        .with_client_auth_cert(certs, key)
        .unwrap();

    log::debug!("Upgrade to SSL connection as client");
    let Ok(socket_with_tls) = device
        .io_impl
        .connect_client_tls(config, server_name, socket)
        .await
    else {
        log::warn!(
            "Failed to establish the SSL connection as client with device `{}`",
            packet.device_id
        );
        return;
    };
    log::debug!("SSL connection established successfully");

    super::tls::per_tls_stream(packet, socket_with_tls, device, is_trusted, true).await;
}