kdeconnect-proto 0.1.0

A pure Rust modular implementation of the KDE Connect protocol
Documentation
//! UDP discovery mechanism.
use core::{
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    time::Duration,
};

#[cfg(feature = "std")]
use std::sync::Arc;

#[cfg(not(feature = "std"))]
use alloc::{sync::Arc, vec};

use rustls::pki_types::{CertificateDer, PrivateKeyDer, UnixTime, pem::PemObject};
use x509_cert::{Certificate, der::DecodePem};

use crate::{
    device::Device, io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl}, packet::{NetworkPacket, NetworkPacketBody}
};

async fn broadcast_udp_identity<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    let Ok(mut socket) = device
        .io_impl
        .bind_udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0))
        .await
    else {
        return;
    };

    let my_identity_packet = device.get_identity_packet();
    let serialized_my_identity_packet = serde_json::to_string(&my_identity_packet).unwrap();
    socket.set_broadcast(true).unwrap();
    socket
        .send_to(
            serialized_my_identity_packet.as_bytes(),
            SocketAddr::new(IpAddr::V4(Ipv4Addr::BROADCAST), crate::config::UDP_PORT),
        )
        .await
        .unwrap();
}

/// Start an UDP listener on the reserved udp port (1716).
///
/// As a library user, you should ignore this function as it's only useful to develop other IO
/// backends.
///
/// This function sets up a listener incoming UDP connections and regularly
/// broadcasts an identity packet over UDP to discover other devices.
///
/// If a valid identity packet is received, a TCP connection is established with the other device
/// which is then upgraded to a TLS connection used to send application packets.
/// Start a UDP listener on the reserved KDE Connect UDP port (1716).
///
/// As a library user, you should ignore this function as it's only useful to develop other IO
/// backends.
///
/// This function first broadcasts an identity packet once over UDP so that other devices can
/// discover this one. It then binds a listener on `crate::config::UDP_PORT` (on the IPv6
/// unspecified address) and waits for incoming identity packets.
///
/// Every received datagram is parsed on its own:
/// - If it contains a valid identity packet from another device, a TCP connection is
///   established with that device. That connection is then upgraded to a TLS connection used
///   to send application packets.
/// - Identity packets carrying this device's own id are skipped.
/// - Other packets are ignored.
/// - Datagrams that cannot be parsed are logged and ignored.
///
/// If the listener cannot be bound, the function returns immediately. Otherwise it loops
/// forever.
pub async fn setup_udp<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    broadcast_udp_identity(Arc::clone(&device)).await;

    let Ok(udp_listener) = device
        .io_impl
        .bind_udp_reuse_v6(SocketAddr::new(
            IpAddr::V6(Ipv6Addr::UNSPECIFIED),
            crate::config::UDP_PORT,
        ))
        .await
    else {
        return;
    };

    let mut buf = [0u8; crate::config::UDP_BUFFER_SIZE];

    loop {
        // UDP keeps message boundaries: each `recv_from` yields one datagram, and an identity
        // packet is sent as a single datagram (see `broadcast_udp_identity`). Each datagram is
        // therefore parsed on its own. Accumulating bytes across datagrams would let one
        // malformed or unrelated datagram corrupt the parsing of every following one.
        let Ok((bytes_read, addr)) = udp_listener.recv_from(&mut buf[..]).await else {
            // Keep listening instead of tearing down discovery on a receive error.
            // NOTE(review): a persistent error would make this loop spin; consider a backoff
            // if the IO backend can provide one.
            log::warn!("udp: Failed to receive a datagram");
            continue;
        };

        match NetworkPacket::try_read_from(&buf[..bytes_read]) {
            // A UDP packet received must be an Identity packet
            Ok(NetworkPacket {
                body: NetworkPacketBody::Identity(packet),
                ..
            }) => {
                // Skip identity packets carrying our own id (e.g. our own broadcast).
                if packet.device_id == device.host_device_id {
                    continue;
                }

                log::debug!("UDP Identity packet received, upgrading connection");
                // NOTE(review): this awaits the whole upgrade, including
                // `super::tls::per_tls_stream`, before the next datagram is received. If that
                // call lives as long as the connection, discovery stalls meanwhile — confirm
                // whether this should be spawned instead.
                upgrade_udp_connection(addr, packet, Arc::clone(&device)).await;
            }
            Ok(_) => {
                log::debug!("udp: Ignoring a received packet that is not an identity packet");
            }
            Err(_) => {
                log::warn!("udp: Failed to parse the received JSON");
            }
        }
    }
}

/// Connect back to a device that announced itself over UDP and upgrade the link to TLS.
///
/// Steps:
/// 1. Replace the port of `addr` (the sender of the identity datagram) with the TCP port
///    advertised in `packet`.
/// 2. Open a TCP connection to that address.
/// 3. Send this device's identity packet unencrypted.
/// 4. Perform the TLS handshake as the *server*. The handshake uses `super::tls::CertVerifier`,
///    configured with the certificate stored for `packet.device_id`, if there is one and it has
///    not expired. An expired stored certificate causes the device to be untrusted through the
///    trust handler.
/// 5. On success, hand the TLS stream over to `super::tls::per_tls_stream`.
///
/// The function returns early, without panicking, when:
/// - a link with this device already exists;
/// - the identity packet does not advertise a TCP port;
/// - the TCP connection fails or never becomes writable;
/// - the certificate stored for this device cannot be parsed;
/// - the TLS handshake fails.
async fn upgrade_udp_connection<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    mut addr: SocketAddr,
    packet: crate::packet::identity::IdentityPacket,
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    if device.links.lock().await.contains_key(&packet.device_id) {
        log::debug!(
            "Device {} has already established connection, ignore the UDP identity packet",
            packet.device_id
        );
        return;
    }

    // The identity packet arrived unauthenticated from the network, so a missing TCP port must
    // be rejected rather than trusted.
    // NOTE(review): assumes `get_tcp_port` returns an `Option` — confirm against `IdentityPacket`.
    let Some(tcp_port) = packet.get_tcp_port() else {
        log::warn!(
            "udp: Identity packet of `{}` does not advertise a TCP port, ignoring it",
            packet.device_id
        );
        return;
    };
    addr.set_port(tcp_port);

    // Make a TCP connection
    let Ok(mut socket) = device.io_impl.connect_tcp(addr).await else {
        return;
    };
    if socket.writable().await.is_err() {
        return;
    }

    device
        .get_identity_packet()
        .write_to_socket_unencrypted(&mut socket)
        .await;

    // I'm the SSL server
    // The lookup and parse stay in one statement so that any borrow of the trust handler's
    // data ends together with its lock guard.
    let Ok(mut trusted_device_cert) = device
        .trust_handler
        .lock()
        .await
        .get_certificate(&packet.device_id)
        .await
        .map(|pem| Certificate::from_pem(pem))
        .transpose()
    else {
        log::warn!(
            "Stored SSL certificate of the device `{}` could not be parsed, aborting the connection",
            packet.device_name.as_ref().unwrap_or(&packet.device_id)
        );
        return;
    };
    let mut is_trusted = trusted_device_cert.is_some();
    let current_time = device.io_impl.get_current_timestamp().await;

    // Check expiration of the trusted certificate
    if let Some(cert) = &trusted_device_cert
        && super::tls::is_expired(
            cert,
            UnixTime::since_unix_epoch(Duration::from_secs(current_time)),
        )
    {
        log::warn!(
            "SSL certificate of the trusted device `{}` is expired, make the device untrusted",
            packet.device_name.as_ref().unwrap_or(&packet.device_id)
        );
        device
            .trust_handler
            .lock()
            .await
            .untrust_device(&packet.device_id)
            .await;
        trusted_device_cert = None;
        is_trusted = false;
    }

    // This device's own certificate and key come from its configuration.
    let certs = vec![CertificateDer::from_pem_slice(&device.config.cert).unwrap()];
    let key = PrivateKeyDer::from_pem_slice(&device.config.private_key).unwrap();

    // On `embedded`, the crypto provider and the time source must be supplied explicitly.
    #[cfg(feature = "embedded")]
    let config = rustls::ServerConfig::builder_with_details(
        Arc::new(rustls_rustcrypto::provider()),
        Arc::new(super::tls::CustomTimeProvider { current_time }),
    )
    .with_safe_default_protocol_versions()
    .unwrap();

    #[cfg(not(feature = "embedded"))]
    let config = rustls::ServerConfig::builder();

    let config = config
        .with_client_cert_verifier(Arc::new(super::tls::CertVerifier {
            trusted_device_cert,
            trusted_device_id: packet.device_id.clone(),
        }))
        .with_single_cert(certs, key)
        .unwrap();

    log::debug!("Upgrade to SSL connection as server");
    let Ok(socket_with_tls) = device.io_impl.accept_server_tls(config, socket).await else {
        return;
    };
    log::debug!("SSL connection established successfully");

    super::tls::per_tls_stream(packet, socket_with_tls, device, is_trusted, false).await;
}