kdeconnect-proto 0.2.0

A pure Rust modular implementation of the KDE Connect protocol
Documentation
//! MDNS discovery mechanism.
use core::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

#[cfg(feature = "std")]
use std::sync::Arc;

#[cfg(not(feature = "std"))]
use alloc::{format, vec, string::ToString, sync::Arc};

use low_dns::{DnsQuestion, DnsRecord, Header, HeaderKind, Name, Packet, Rdata};

use crate::{
    device::Device,
    io::{IoImpl, TcpListenerImpl, TcpStreamImpl, TlsStreamImpl, UdpSocketImpl},
};

/// Service type that KDE Connect devices advertise and look up over MDNS.
const MDNS_SERVICE_TYPE: &str = "_kdeconnect._udp.local";

/// IPv4 multicast group that MDNS packets are exchanged on.
const MDNS_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
/// UDP port that MDNS packets are exchanged on.
const MDNS_PORT: u16 = 5353;

/// Size, in bytes, of the buffer that incoming MDNS packets are received into.
const MDNS_BUFFER_SIZE: usize = 512;

/// Start an UDP listener on the MDNS address (224.0.0.251).
///
/// As a library user, you should ignore this function as it's only useful to develop other IO
/// backends.
///
/// This function sets up a listener for the `_kdeconnect._udp.local` MDNS service and regularly
/// broadcasts an MDNS question for the same service to discover other devices.
///
/// If an MDNS answer is received, an UDP identity packet is sent to the other device.
#[allow(clippy::too_many_lines)]
pub async fn setup_mdns<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    device: Arc<Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>>,
) {
    let Ok(mut socket) = device
        .io_impl
        .bind_udp_reuse_multicast_v4(
            SocketAddr::new(IpAddr::V4(MDNS_ADDR), MDNS_PORT),
            (MDNS_ADDR, Ipv4Addr::UNSPECIFIED),
        )
        .await
    else {
        return;
    };

    let mut question_packet_buf = Packet::max_buffer();
    let mut answer_packet_buf = Packet::max_buffer();

    let (my_ipv4, my_ipv6) = device.io_impl.get_host_addresses().await;

    let kdeconnect_question = prepare_kdeconnect_question(&mut question_packet_buf);
    prepare_kdeconnect_answer(
        &mut answer_packet_buf,
        my_ipv4.expect("failed to get Ipv4 address of the host, is it connected to the internet?"),
        my_ipv6,
        &device,
    );

    // Send question
    let _ = socket
        .send_to(
            &kdeconnect_question,
            SocketAddr::new(IpAddr::V4(MDNS_ADDR), MDNS_PORT),
        )
        .await;

    // Send and receive answers
    let mut buf = [0u8; MDNS_BUFFER_SIZE];
    let mut name_buf = Name::max_buffer();

    loop {
        let Ok((bytes_read, _)) = socket.recv_from(&mut buf).await else {
            log::warn!("MDNS packet is too large, it's a bug, you should report it");
            continue;
        };
        let packet = Packet::parse(&buf[..bytes_read]);

        if let Ok(dns_packet) = packet {
            for question in dns_packet.questions() {
                if question
                    .name()
                    .as_str(&mut name_buf)
                    .unwrap()
                    .ends_with(MDNS_SERVICE_TYPE)
                {
                    let kdeconnect_answer = Packet::parse_and_builder(&mut answer_packet_buf)
                        .expect("prepared answer is a correct DNS packet")
                        .header(
                            Header::builder()
                                .id(dns_packet.header().id())
                                .kind(HeaderKind::Response)
                                .build(),
                        )
                        .build();

                    let _ = socket
                        .send_to(
                            &kdeconnect_answer,
                            SocketAddr::new(IpAddr::V4(MDNS_ADDR), MDNS_PORT),
                        )
                        .await;
                }
            }

            let mut ips = vec![];
            let mut name = None;
            let mut id = None;

            let answers_for_kdeconnect_service = dns_packet
                .answers()
                .any(|answer| answer.name() == MDNS_SERVICE_TYPE);

            if !answers_for_kdeconnect_service {
                continue;
            }

            for record in dns_packet.additional_records() {
                match record.rdata() {
                    Rdata::A { ip: new_ip } => {
                        ips.push(IpAddr::V4(*new_ip));
                    }
                    Rdata::AAAA { ip: new_ip } => {
                        ips.push(IpAddr::V6(*new_ip));
                    }
                    Rdata::TXT { text } => {
                        for item in text.clone().filter_map(Result::ok) {
                            if let Some(s) = item.strip_prefix("id=") {
                                id = Some(s.to_string());
                            } else if let Some(s) = item.strip_prefix("name=") {
                                name = Some(s.to_string());
                            }
                        }
                    }
                    _ => {}
                }
            }

            if id.is_none() {
                // If the ID is not set by the TXT record, fallback to the PTR answer
                for record in dns_packet.answers() {
                    if let Rdata::PTR { name } = record.rdata()
                        && let Some(new_id) = name
                            .to_string()
                            .strip_suffix(MDNS_SERVICE_TYPE)
                            .and_then(|s| s.strip_suffix("."))
                    {
                        id = Some(new_id.to_string());
                    }
                }
            }

            if id.as_ref().is_some_and(|id| *id == device.host_device_id) {
                continue;
            }

            if let (Some(id), Some(name)) = (id, name) {
                log::debug!(
                    "Discovered {name} (with ips {ips:?} and id {id}) using MDNS, sending an identity message using UDP"
                );

                if device.links.lock().await.contains_key(&id) {
                    log::debug!(
                        "Device {id} has already established connection, ignore the MDNS request"
                    );
                    continue;
                }

                let my_identity_packet = device.get_identity_packet();
                let serialized_my_identity_packet =
                    serde_json::to_string(&my_identity_packet).unwrap();

                for ip in ips {
                    let bound_addr = match ip {
                        IpAddr::V4(_) => IpAddr::V4(Ipv4Addr::UNSPECIFIED),
                        IpAddr::V6(_) => IpAddr::V6(Ipv6Addr::UNSPECIFIED),
                    };
                    if let Ok(mut socket) = device
                        .io_impl
                        .bind_udp(SocketAddr::new(bound_addr, 0))
                        .await
                    {
                        let _ = socket
                            .send_to(
                                serialized_my_identity_packet.as_bytes(),
                                SocketAddr::new(ip, crate::config::UDP_PORT),
                            )
                            .await;
                    }
                }
            }
        }
    }
}

/// Build, into `packet_buf`, an MDNS query holding one PTR question for the
/// `_kdeconnect._udp.local` service type, and return the resulting packet.
fn prepare_kdeconnect_question(packet_buf: &mut [u8]) -> Packet<'_> {
    // Scratch space for the encoded service type name. Presumably the `+ 2` covers the extra
    // length/terminator bytes of the wire encoding (TODO confirm against `low_dns::Name`).
    let mut encoded_service_type = [0; MDNS_SERVICE_TYPE.len() + 2];
    let service_type = Name::from_str_into_buf(MDNS_SERVICE_TYPE, &mut encoded_service_type);
    let ptr_question = DnsQuestion::ptr(service_type);

    Packet::builder(packet_buf).question(ptr_question).build()
}

/// Build, in place in `packet_buf`, the MDNS response advertising this device.
///
/// The response contains:
/// - a PTR answer from `_kdeconnect._udp.local` to `<device id>._kdeconnect._udp.local`;
/// - an SRV additional record for that service instance, targeting `<device id>.local` on
///   `crate::config::UDP_PORT`;
/// - an A additional record for that host name with `my_ipv4_addr`, plus an AAAA one when
///   `my_ipv6_addr` is `Some`;
/// - a TXT additional record with the `protocol`, `id`, `name` and `type` of this device.
///
/// Nothing is returned. The packet is left in `packet_buf`, which the caller re-parses to patch
/// the header of each response it sends.
fn prepare_kdeconnect_answer<
    Io: IoImpl<UdpSocket, TcpStream, TcpListener, TlsStream> + Unpin + 'static,
    UdpSocket: UdpSocketImpl + Unpin + 'static,
    TcpStream: TcpStreamImpl + Unpin + 'static,
    TcpListener: TcpListenerImpl<TcpStream> + Unpin + 'static,
    TlsStream: TlsStreamImpl + Unpin + 'static,
>(
    packet_buf: &mut [u8],
    my_ipv4_addr: Ipv4Addr,
    my_ipv6_addr: Option<Ipv6Addr>,
    device: &Device<Io, UdpSocket, TcpStream, TcpListener, TlsStream>,
) {
    let service_id = format!("{}.{}", device.host_device_id, MDNS_SERVICE_TYPE);
    // The SRV target and the A/AAAA records need a host name, and it must be specific to this
    // device. With a name shared by every device (such as a fixed `host.local`), MDNS caches on
    // the network would merge the addresses of different devices under one name. Peers
    // resolving the SRV target could then contact the wrong device. The device id already names
    // this device's service instance, so reuse it.
    let host_name_str = format!("{}.local", device.host_device_id);
    let text_payload = &[
        "protocol=8",
        &format!("id={}", device.host_device_id),
        &format!("name={}", device.config.name),
        &format!("type={}", device.config.device_type),
    ];

    let mut service_type_name_buf = [0; MDNS_SERVICE_TYPE.len() + 2];
    let mut service_id_name_buf = Name::max_buffer();
    let mut host_name_buf = Name::max_buffer();
    let mut text_record_buf = [0; 512];

    let service_type_name = Name::from_str_into_buf(MDNS_SERVICE_TYPE, &mut service_type_name_buf);
    let service_id_name = Name::from_str_into_buf(&service_id, &mut service_id_name_buf);
    let host_name = Name::from_str_into_buf(&host_name_str, &mut host_name_buf);

    let mut answer = Packet::builder(packet_buf)
        .answer(DnsRecord::ptr(
            service_type_name.clone(),
            service_id_name.clone(),
        ))
        .additional_record(DnsRecord::srv(
            service_id_name.clone(),
            0,
            0,
            crate::config::UDP_PORT,
            host_name.clone(),
        ))
        .additional_record(DnsRecord::a(host_name.clone(), my_ipv4_addr))
        .additional_record(DnsRecord::txt(
            service_id_name,
            *text_payload,
            &mut text_record_buf,
        ));

    if let Some(my_ipv6_addr) = my_ipv6_addr {
        answer = answer.additional_record(DnsRecord::aaaa(host_name, my_ipv6_addr));
    }

    // Packet::builder directly builds the packet in-place in the provided buffer so it's not
    // needed to return the answer
    let _ = answer;
}