kdeconnect-proto 0.1.0

A pure Rust modular implementation of the KDE Connect protocol
Documentation
//! MDNS discovery mechanism.
use core::net::{IpAddr, Ipv4Addr, SocketAddr};

#[cfg(feature = "std")]
use std::sync::Arc;

#[cfg(not(feature = "std"))]
use alloc::{format, string::ToString, sync::Arc};

use low_dns::{DnsQuestion, DnsRecord, Header, HeaderKind, Name, Packet, Rdata};

use crate::{device::Device, io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl}};

/// IPv4 multicast group address on which MDNS traffic is sent and received (224.0.0.251).
const MDNS_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port used together with [`MDNS_ADDR`] for MDNS traffic.
const MDNS_PORT: u16 = 5353;

/// Size in bytes of the buffer that receives incoming MDNS datagrams.
const MDNS_BUFFER_SIZE: usize = 512;

/// Name of the MDNS service that is both queried and advertised by this module.
const MDNS_SERVICE_TYPE: &str = "_kdeconnect._udp.local";

/// Start an UDP listener on the MDNS address (224.0.0.251).
///
/// As a library user, you should ignore this function as it's only useful to develop other IO
/// backends.
///
/// This function sets up a listener for the `_kdeconnect._udp.local` MDNS service and regularly
/// broadcasts an MDNS question for the same service to discover other devices.
///
/// If an MDNS answer is received, an UDP identity packet is sent to the other device.
#[allow(clippy::too_many_lines)]
pub async fn setup_mdns<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    let Ok(mut socket) = device
        .io_impl
        .bind_udp_reuse_multicast_v4(
            SocketAddr::new(IpAddr::V4(MDNS_ADDR), MDNS_PORT),
            (MDNS_ADDR, Ipv4Addr::UNSPECIFIED),
        )
        .await
    else {
        return;
    };

    let mut question_packet_buf = Packet::max_buffer();
    let mut answer_packet_buf = Packet::max_buffer();

    let (my_ipv4, _my_ipv6) = device.io_impl.get_host_addresses().await;

    // TODO: support IPv6

    let kdeconnect_question = prepare_kdeconnect_question(&mut question_packet_buf);
    prepare_kdeconnect_answer(
        &mut answer_packet_buf,
        my_ipv4.expect("failed to get Ipv4 address of the host, is it connected to the internet?"),
        &device,
    );

    // Send question
    // TODO: send question at specific intervals using a poll_fn instead of recv_from (e.g 4000 ms)
    socket
        .send_to(
            &kdeconnect_question,
            SocketAddr::new(IpAddr::V4(MDNS_ADDR), MDNS_PORT),
        )
        .await
        .unwrap();

    // Send and receive answers
    let mut buf = [0u8; MDNS_BUFFER_SIZE];
    let mut name_buf = Name::max_buffer();

    loop {
        let _ = socket.recv_from(&mut buf).await.unwrap();

        let packet = Packet::parse(&buf);

        if let Ok(dns_packet) = packet {
            for question in dns_packet.questions() {
                if question
                    .name()
                    .as_str(&mut name_buf)
                    .unwrap()
                    .ends_with(MDNS_SERVICE_TYPE)
                {
                    let kdeconnect_answer = Packet::parse_and_builder(&mut answer_packet_buf)
                        .expect("prepared answer is a correct DNS packet")
                        .header(
                            Header::builder()
                                .id(dns_packet.header().id())
                                .kind(HeaderKind::Response)
                                .build(),
                        )
                        .build();

                    socket
                        .send_to(
                            &kdeconnect_answer,
                            SocketAddr::new(IpAddr::V4(MDNS_ADDR), MDNS_PORT),
                        )
                        .await
                        .unwrap();
                }
            }

            let mut ip = None;
            let mut name = None;
            let mut id = None;

            let answers_for_kdeconnect_service = dns_packet
                .answers()
                .any(|answer| answer.name() == MDNS_SERVICE_TYPE);

            if !answers_for_kdeconnect_service {
                continue;
            }

            for record in dns_packet.additional_records() {
                match record.rdata() {
                    Rdata::A { ip: new_ip } => {
                        ip = Some(*new_ip);
                    }
                    // TODO: AAAA
                    Rdata::TXT { text } => {
                        for item in text.clone().filter_map(Result::ok) {
                            if let Some(s) = item.strip_prefix("id=") {
                                id = Some(s.to_string());
                            } else if let Some(s) = item.strip_prefix("name=") {
                                name = Some(s.to_string());
                            }
                        }
                    }
                    _ => {}
                }
            }

            if id.is_none() {
                // If the ID is not set by the TXT record, fallback to the PTR answer
                for record in dns_packet.answers() {
                    if let Rdata::PTR { name } = record.rdata()
                        && let Some(new_id) = name
                            .to_string()
                            .strip_suffix(MDNS_SERVICE_TYPE)
                            .and_then(|s| s.strip_suffix("."))
                    {
                        id = Some(new_id.to_string());
                    }
                }
            }

            if id.as_ref().is_some_and(|id| *id == device.host_device_id) {
                continue;
            }

            if let (Some(ip), Some(id), Some(name)) = (ip, id, name) {
                log::debug!(
                    "Discovered {name} (with ip {ip:?} and id {id}) using MDNS, sending an identity message using UDP"
                );

                if device.links.lock().await.contains_key(&id) {
                    log::debug!(
                        "Device {id} has already established connection, ignore the MDNS request"
                    );
                    continue;
                }

                let my_identity_packet = device.get_identity_packet();
                let serialized_my_identity_packet =
                    serde_json::to_string(&my_identity_packet).unwrap();

                // TODO: select the local-bound port IP version based on `ip` version
                if let Ok(mut socket) = device
                    .io_impl
                    .bind_udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), 0))
                    .await
                {
                    socket
                        .send_to(
                            serialized_my_identity_packet.as_bytes(),
                            SocketAddr::new(IpAddr::V4(ip), crate::config::UDP_PORT),
                        )
                        .await
                        .unwrap();
                }
            }
        }
    }
}

/// Build the MDNS question used to discover other devices.
///
/// A packet holding a single PTR question for [`MDNS_SERVICE_TYPE`] is written into `packet_buf`
/// and the packet backed by that buffer is returned.
fn prepare_kdeconnect_question(packet_buf: &mut [u8]) -> Packet<'_> {
    // Scratch storage for the service name.
    // NOTE(review): the two extra bytes presumably account for the leading label length and the
    // terminating root label of the DNS wire encoding; confirm against `low_dns`.
    let mut name_storage = [0; MDNS_SERVICE_TYPE.len() + 2];

    let service_name = Name::from_str_into_buf(MDNS_SERVICE_TYPE, &mut name_storage);
    let ptr_question = DnsQuestion::ptr(service_name);

    Packet::builder(packet_buf).question(ptr_question).build()
}

/// Write into `packet_buf` the MDNS answer advertising this device.
///
/// The packet is made of:
/// - a PTR answer from [`MDNS_SERVICE_TYPE`] to `<host device id>.<service type>`;
/// - an additional SRV record from that service ID to a placeholder host name, on
///   `crate::config::UDP_PORT`;
/// - an additional A record from the placeholder host name to `my_ipv4_addr`;
/// - an additional TXT record for the service ID with the `protocol`, `id`, `name` and `type`
///   entries of `device`.
fn prepare_kdeconnect_answer<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet_buf: &mut [u8],
    my_ipv4_addr: Ipv4Addr,
    device: &Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>,
) {
    // Placeholder host name targeted by the SRV record and resolved by the A record.
    const FAKE_HOST_NAME: &str = "host.local";

    // TODO: use constcat
    let service_id = format!("{}.{}", device.host_device_id, MDNS_SERVICE_TYPE);

    // Entries of the TXT record.
    let id_entry = format!("id={}", device.host_device_id);
    let name_entry = format!("name={}", device.config.name);
    let type_entry = format!("type={}", device.config.device_type);
    let text_payload = [
        "protocol=8",
        id_entry.as_str(),
        name_entry.as_str(),
        type_entry.as_str(),
    ];

    // Backing storage for the names and the TXT data referenced by the records below.
    let mut type_name_storage = [0; MDNS_SERVICE_TYPE.len() + 2];
    let mut id_name_storage = Name::max_buffer();
    let mut host_name_storage = [0; FAKE_HOST_NAME.len() + 2];
    let mut text_storage = [0; 512];

    let service_type_name = Name::from_str_into_buf(MDNS_SERVICE_TYPE, &mut type_name_storage);
    let service_id_name = Name::from_str_into_buf(&service_id, &mut id_name_storage);
    let host_name = Name::from_str_into_buf(FAKE_HOST_NAME, &mut host_name_storage);

    let ptr_record = DnsRecord::ptr(service_type_name.clone(), service_id_name.clone());
    let srv_record = DnsRecord::srv(
        service_id_name.clone(),
        0,
        0,
        crate::config::UDP_PORT,
        host_name.clone(),
    );
    let a_record = DnsRecord::a(host_name, my_ipv4_addr);
    // TODO: maybe AAAA record with IPv6 address
    let txt_record = DnsRecord::txt(service_id_name, text_payload, &mut text_storage);

    // NOTE(review): unlike `prepare_kdeconnect_question`, the chain is not finished with
    // `.build()`; the packet is later read back from the same buffer with
    // `Packet::parse_and_builder`, so the builder presumably writes into `packet_buf` as records
    // are added. Confirm against `low_dns`.
    Packet::builder(packet_buf)
        .answer(ptr_record)
        .additional_record(srv_record)
        .additional_record(a_record)
        .additional_record(txt_record);
}