// kdeconnect-proto 0.2.0
//
// A pure Rust modular implementation of the KDE Connect protocol.
// Documentation
use core::{
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
    task::{Context, Poll},
};
use net2::unix::UnixUdpBuilderExt;
use rustls::{ClientConfig, CommonState, ServerConfig, pki_types::ServerName};
use tokio_rustls::{TlsAcceptor, TlsConnector};

#[cfg(feature = "std")]
use std::sync::Arc;

#[cfg(not(feature = "std"))]
use alloc::sync::Arc;

use crate::{
    device::Device,
    io::{
        IoImpl, KnownFunctionName, Result, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl,
        UdpSocketImpl,
    },
};

/// Default implementation of the [`IoImpl`] trait.
///
/// It uses [`tokio`] to do IO tasks asynchronously.
///
/// This is a stateless, zero-sized handle, so it is cheap to copy and can be
/// created with [`Default`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct TokioIoImpl;

impl
    IoImpl<
        tokio::net::UdpSocket,
        tokio::net::TcpStream,
        tokio::net::TcpListener,
        tokio_rustls::TlsStream<tokio::net::TcpStream>,
    > for TokioIoImpl
{
    /// Binds a UDP socket to `addr`.
    async fn bind_udp(&self, addr: SocketAddr) -> Result<tokio::net::UdpSocket> {
        tokio::net::UdpSocket::bind(addr).await.map_err(Into::into)
    }

    /// Binds an IPv4 UDP socket to `addr` with `SO_REUSEADDR` and `SO_REUSEPORT`
    /// enabled, then hands it over to tokio.
    ///
    /// The socket is switched to non-blocking mode first because tokio drives
    /// it through readiness events rather than blocking calls.
    async fn bind_udp_reuse_v4(&self, addr: SocketAddr) -> Result<tokio::net::UdpSocket> {
        let raw_socket = net2::UdpBuilder::new_v4()?
            .reuse_address(true)?
            .reuse_port(true)?
            .bind(addr)?;
        raw_socket.set_nonblocking(true)?;

        tokio::net::UdpSocket::from_std(raw_socket).map_err(Into::into)
    }

    /// Same as `bind_udp_reuse_v4`, but also joins an IPv4 multicast group.
    ///
    /// `multicast_addr` is passed to `join_multicast_v4` as
    /// `(multicast group address, local interface address)`.
    async fn bind_udp_reuse_multicast_v4(
        &self,
        addr: SocketAddr,
        multicast_addr: (Ipv4Addr, Ipv4Addr),
    ) -> Result<tokio::net::UdpSocket> {
        let raw_socket = net2::UdpBuilder::new_v4()?
            .reuse_address(true)?
            .reuse_port(true)?
            .bind(addr)?;
        raw_socket.join_multicast_v4(&multicast_addr.0, &multicast_addr.1)?;
        raw_socket.set_nonblocking(true)?;

        tokio::net::UdpSocket::from_std(raw_socket).map_err(Into::into)
    }

    /// Binds a TCP listener to `addr`.
    async fn listen_tcp(&self, addr: SocketAddr) -> Result<tokio::net::TcpListener> {
        tokio::net::TcpListener::bind(addr)
            .await
            .map_err(Into::into)
    }

    /// Opens a TCP connection to `addr`.
    async fn connect_tcp(&self, addr: SocketAddr) -> Result<tokio::net::TcpStream> {
        tokio::net::TcpStream::connect(addr)
            .await
            .map_err(Into::into)
    }

    /// Performs the server side of a TLS handshake over `stream` using `config`.
    async fn accept_server_tls(
        &self,
        config: ServerConfig,
        stream: tokio::net::TcpStream,
    ) -> Result<tokio_rustls::TlsStream<tokio::net::TcpStream>> {
        TlsAcceptor::from(Arc::new(config))
            .accept(stream)
            .await
            .map(tokio_rustls::TlsStream::Server)
            .map_err(Into::into)
    }

    /// Performs the client side of a TLS handshake over `stream`.
    ///
    /// The handshake uses `config` and presents `server_name` to the peer.
    async fn connect_client_tls(
        &self,
        config: ClientConfig,
        server_name: ServerName<'static>,
        stream: tokio::net::TcpStream,
    ) -> Result<tokio_rustls::TlsStream<tokio::net::TcpStream>> {
        TlsConnector::from(Arc::new(config))
            .connect(server_name, stream)
            .await
            .map(tokio_rustls::TlsStream::Client)
            .map_err(Into::into)
    }

    /// Picks a non-loopback IPv4 and IPv6 address from interfaces that are
    /// operationally up.
    ///
    /// Interfaces are scanned in the order `if_addrs` reports them. An address
    /// of one family replaces any earlier one of the same family until an
    /// address of the other family is seen. At that point the pair is returned
    /// immediately. If only one family is present, the last address of that
    /// family is returned.
    ///
    /// If the interface list cannot be read, `(None, None)` is returned instead
    /// of panicking.
    async fn get_host_addresses(&self) -> (Option<Ipv4Addr>, Option<Ipv6Addr>) {
        let Ok(interfaces) = if_addrs::get_if_addrs() else {
            return (None, None);
        };
        let mut ipv4_addr = None;
        let mut ipv6_addr = None;

        for iface in interfaces {
            if iface.is_loopback() || !iface.is_oper_up() {
                continue;
            }
            match iface.ip() {
                IpAddr::V4(v4) => {
                    if ipv6_addr.is_some() {
                        return (Some(v4), ipv6_addr);
                    }
                    ipv4_addr = Some(v4);
                }
                IpAddr::V6(v6) => {
                    if ipv4_addr.is_some() {
                        return (ipv4_addr, Some(v6));
                    }
                    ipv6_addr = Some(v6);
                }
            }
        }

        (ipv4_addr, ipv6_addr)
    }

    /// Suspends the current task for `duration` using tokio's timer.
    async fn sleep(&self, duration: std::time::Duration) {
        tokio::time::sleep(duration).await;
    }

    /// Spawns the transport task selected by `name` for `device`.
    ///
    /// Tasks are started with `tokio::task::spawn_local`, so this must run
    /// inside a `LocalSet` (such as the one driven by `start`). The join handle
    /// is dropped, which leaves the task running detached.
    fn spawn(
        &self,
        name: KnownFunctionName<tokio::net::TcpStream>,
        device: Arc<
            Device<
                Self,
                tokio::net::UdpSocket,
                tokio::net::TcpStream,
                tokio::net::TcpListener,
                tokio_rustls::TlsStream<tokio::net::TcpStream>,
            >,
        >,
    ) {
        match name {
            KnownFunctionName::SetupUdp => {
                tokio::task::spawn_local(crate::transport::udp::setup_udp(device))
            }
            KnownFunctionName::SetupMdns => {
                tokio::task::spawn_local(crate::transport::mdns::setup_mdns(device))
            }
            KnownFunctionName::PerTcpStream(stream) => {
                tokio::task::spawn_local(crate::transport::tcp::per_tcp_stream(stream, device))
            }
        };
    }

    /// Builds a current-thread tokio runtime and runs `setup_tcp` for `device`
    /// on a `LocalSet`.
    ///
    /// `spawn` relies on `spawn_local`, which requires a `LocalSet`, so one is
    /// provided here. This call blocks until every task on the set has finished.
    ///
    /// # Panics
    ///
    /// Panics if the tokio runtime cannot be built.
    fn start(
        &self,
        device: Arc<
            Device<
                Self,
                tokio::net::UdpSocket,
                tokio::net::TcpStream,
                tokio::net::TcpListener,
                tokio_rustls::TlsStream<tokio::net::TcpStream>,
            >,
        >,
    ) {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build the tokio current-thread runtime");

        // `LocalSet::spawn_local` only queues the task. The set is actually
        // entered when `block_on` polls it, so no `enter()` guard is needed.
        let local = tokio::task::LocalSet::new();
        local.spawn_local(crate::transport::tcp::setup_tcp(device));

        rt.block_on(local);
    }

    /// Returns the current Unix time in whole seconds.
    ///
    /// If the system clock is set before the Unix epoch, this returns `0`
    /// instead of panicking.
    async fn get_current_timestamp(&self) -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }
}

// Every delegation below calls tokio's inherent method by path
// (`tokio::net::UdpSocket::…`). A plain `self.method()` that shares its name with
// the trait method only avoids infinite recursion because inherent methods win
// method resolution at the same receiver type. That already fails for
// `send_to` (`&mut self` here vs `&self` in tokio), so all calls use the explicit
// form for consistency.
impl UdpSocketImpl for tokio::net::UdpSocket {
    /// Enables or disables `SO_BROADCAST` on the socket.
    fn set_broadcast(&self, on: bool) -> Result<()> {
        tokio::net::UdpSocket::set_broadcast(self, on).map_err(Into::into)
    }

    /// Polls for an incoming datagram and writes its bytes into `buf`.
    ///
    /// NOTE(review): the trait returns `Poll<Result<()>>`, so the number of
    /// bytes received (tracked by the temporary `ReadBuf`) is discarded here.
    /// Callers cannot tell how much of `buf` was filled.
    fn poll_recv(&self, cx: &mut Context, buf: &mut [u8]) -> Poll<Result<()>> {
        let mut read_buf = tokio::io::ReadBuf::new(buf);
        tokio::net::UdpSocket::poll_recv(self, cx, &mut read_buf).map_err(Into::into)
    }

    /// Receives one datagram into `buf`.
    ///
    /// Returns the number of bytes written and the sender's address.
    async fn recv_from(&self, buf: &mut [u8]) -> Result<(usize, SocketAddr)> {
        tokio::net::UdpSocket::recv_from(self, buf)
            .await
            .map_err(Into::into)
    }

    /// Sends `buf` as one datagram to `addr` and returns the number of bytes sent.
    async fn send_to(&mut self, buf: &[u8], addr: SocketAddr) -> Result<usize> {
        tokio::net::UdpSocket::send_to(self, buf, addr)
            .await
            .map_err(Into::into)
    }
}

impl TcpStreamImpl for tokio::net::TcpStream {
    /// Reads into `buf`.
    ///
    /// Returns the number of bytes read, as reported by tokio's `AsyncReadExt::read`.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        <Self as tokio::io::AsyncReadExt>::read(self, buf)
            .await
            .map_err(Into::into)
    }

    /// Waits until tokio reports the socket as writable.
    async fn writable(&self) -> Result<()> {
        // Call tokio's inherent method by path, like the other methods in this
        // impl, so it can never resolve back to `TcpStreamImpl::writable`.
        tokio::net::TcpStream::writable(self)
            .await
            .map_err(Into::into)
    }

    /// Writes all of `src` to the socket.
    async fn write_all(&mut self, src: &[u8]) -> Result<()> {
        <Self as tokio::io::AsyncWriteExt>::write_all(self, src)
            .await
            .map_err(Into::into)
    }
}

impl TlsStreamImpl for tokio_rustls::TlsStream<tokio::net::TcpStream> {
    /// Borrows the rustls [`CommonState`] of this TLS session.
    ///
    /// The underlying TCP stream is ignored.
    fn get_common_state(&self) -> &CommonState {
        let (_tcp, state) = self.get_ref();
        state
    }

    /// Reads decrypted application data into `buf`.
    ///
    /// Returns the byte count reported by tokio's `AsyncReadExt::read`.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let outcome = tokio::io::AsyncReadExt::read(self, buf).await;
        outcome.map_err(Into::into)
    }

    /// Writes all of `src` through the TLS session.
    async fn write_all(&mut self, src: &[u8]) -> Result<()> {
        let outcome = tokio::io::AsyncWriteExt::write_all(self, src).await;
        outcome.map_err(Into::into)
    }
}

impl TcpListenerImpl<tokio::net::TcpStream> for tokio::net::TcpListener {
    /// Waits for the next incoming connection and returns its stream.
    ///
    /// The peer address is discarded.
    async fn accept(&self) -> Result<tokio::net::TcpStream> {
        // Call tokio's inherent `accept` by path so it cannot resolve back to
        // `TcpListenerImpl::accept`.
        tokio::net::TcpListener::accept(self)
            .await
            .map(|(stream, _peer_addr)| stream)
            .map_err(Into::into)
    }
}