zeroconf 0.18.0

cross-platform library that wraps ZeroConf/mDNS implementations like Bonjour or Avahi
Documentation
//! Bonjour implementation for cross-platform browser

use super::service_ref::{
    BrowseServicesParams, GetAddressInfoParams, ManagedDNSServiceRef, ServiceResolveParams,
};
use super::txt_record_ref::ManagedTXTRecordRef;
use super::{bonjour_util, constants};
use crate::ffi::{AsRaw, FromRaw, c_str};
use crate::prelude::*;
use crate::{BrowserEvent, ServiceBrowserCallback, ServiceDiscovery, ServiceRemoval};
use crate::{EventLoop, NetworkInterface, Result, ServiceType, TxtRecord};
#[cfg(target_vendor = "pc")]
use bonjour_sys::sockaddr_in;
use bonjour_sys::{DNSServiceErrorType, DNSServiceFlags, DNSServiceRef};
#[cfg(any(target_vendor = "apple", target_os = "freebsd"))]
use libc::sockaddr_in;
use libc::{c_char, c_uchar, c_void};
use std::any::Any;
use std::ffi::CString;
use std::fmt::{self, Formatter};
use std::net::IpAddr;
use std::ptr;
use std::sync::{Arc, Mutex};

/// Bonjour (DNS-SD) backed implementation of [`TMdnsBrowser`].
///
/// `browse_services` starts the browse request on `service` and returns an [`EventLoop`]
/// sharing that same service ref; results are delivered through the callback stored in
/// `context`.
#[derive(Debug)]
pub struct BonjourMdnsBrowser {
    /// DNS-SD service ref, shared (via `Arc`) with the [`EventLoop`] returned by
    /// `browse_services`.
    service: Arc<Mutex<ManagedDNSServiceRef>>,
    /// Registration type passed to the browse request, produced by
    /// `bonjour_util::format_regtype`.
    kind: CString,
    /// Interface index to browse on; starts as `constants::BONJOUR_IF_UNSPEC`.
    interface_index: u32,
    /// Callback and per-discovery state, handed to the C callbacks as a raw pointer.
    /// Boxed, presumably so its address stays stable while DNS-SD holds that pointer — confirm.
    context: Box<BonjourBrowserContext>,
}

// NOTE(review): Send/Sync are asserted manually rather than derived. Soundness depends on the
// raw DNS-SD handle inside `ManagedDNSServiceRef` and on the boxed context being safe to use
// from another thread — confirm these invariants and record them in a `// SAFETY:` comment.
unsafe impl Send for BonjourMdnsBrowser {}
unsafe impl Sync for BonjourMdnsBrowser {}

impl TMdnsBrowser for BonjourMdnsBrowser {
    /// Creates a browser for `service_type` using the unspecified interface index, with no
    /// callback or user context registered yet.
    fn new(service_type: ServiceType) -> Self {
        let kind = bonjour_util::format_regtype(&service_type);
        Self {
            service: Arc::default(),
            kind,
            interface_index: constants::BONJOUR_IF_UNSPEC,
            context: Box::default(),
        }
    }

    /// Restricts browsing to `interface` (translated to a Bonjour interface index).
    fn set_network_interface(&mut self, interface: NetworkInterface) {
        let index = bonjour_util::interface_index(interface);
        self.interface_index = index;
    }

    /// Returns the interface this browser is configured for.
    fn network_interface(&self) -> NetworkInterface {
        let index = self.interface_index;
        bonjour_util::interface_from_index(index)
    }

    /// Registers the callback that receives browse events and errors.
    fn set_service_callback(&mut self, service_discovered_callback: Box<ServiceBrowserCallback>) {
        let ctx = &mut *self.context;
        ctx.service_discovered_callback = Some(service_discovered_callback);
    }

    /// Stores a user context that is passed to every callback invocation.
    fn set_context(&mut self, context: Box<dyn Any + Send + Sync>) {
        let shared: Arc<dyn Any + Send + Sync> = Arc::from(context);
        self.context.user_context = Some(shared);
    }

    /// Returns the user context set via `set_context`, if any.
    fn context(&self) -> Option<&(dyn Any + Send + Sync)> {
        self.context.user_context.as_deref()
    }

    /// Starts the DNS-SD browse request and returns an [`EventLoop`] that drives it.
    ///
    /// # Errors
    ///
    /// Fails if the browse parameters cannot be built or the browse request is rejected.
    fn browse_services(&mut self) -> Result<EventLoop> {
        debug!("Browsing services: {:?}", self);

        let mut guard = self
            .service
            .lock()
            .expect("should have been able to obtain lock on service ref");

        // `self.context` is handed to the C side as the callback context pointer.
        let params = BrowseServicesParams::builder()
            .flags(0)
            .interface_index(self.interface_index)
            .regtype(self.kind.as_ptr())
            .domain(ptr::null_mut())
            .callback(Some(browse_callback))
            .context(self.context.as_raw())
            .build()?;

        unsafe { guard.browse_services(params) }?;

        let shared_service = Arc::clone(&self.service);
        Ok(EventLoop::new(shared_service))
    }
}

/// State shared with the DNS-SD C callbacks through the raw `context` pointer.
///
/// One instance is reused across the browse → resolve → get-address-info chain:
/// `handle_browse_add` stores the service identity, `handle_resolve` stores the port and TXT
/// record, and `handle_get_address_info` takes them to build a [`ServiceDiscovery`].
#[derive(Default, FromRaw, AsRaw)]
struct BonjourBrowserContext {
    /// User callback receiving browse events and errors.
    service_discovered_callback: Option<Box<ServiceBrowserCallback>>,
    /// Service name from the browse callback; taken (left `None`) once reported.
    resolved_name: Option<String>,
    /// Service type (regtype) from the browse callback.
    resolved_kind: Option<String>,
    /// Domain from the browse callback.
    resolved_domain: Option<String>,
    /// Port exactly as delivered to the resolve callback; byte-swapped before reporting.
    resolved_port: u16,
    /// TXT record from the resolve callback, when it is longer than one byte.
    resolved_txt: Option<TxtRecord>,
    /// Opaque user context forwarded (cloned) to every callback invocation.
    user_context: Option<Arc<dyn Any + Send + Sync>>,
}

impl BonjourBrowserContext {
    /// Delivers `result` to the registered user callback together with a clone of the user
    /// context. When no callback is registered, a warning is logged and `result` is dropped.
    fn invoke_callback(&self, result: Result<BrowserEvent>) {
        let Some(callback) = self.service_discovered_callback.as_ref() else {
            warn!("attempted to invoke callback but none was set");
            return;
        };

        callback(result, self.user_context.clone());
    }
}

/// Debug output for the browser context.
///
/// Only the resolved service identity and port are printed. The callback, TXT record and
/// user context are omitted, which `finish_non_exhaustive` signals with a trailing `..`.
impl fmt::Debug for BonjourBrowserContext {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Use this type's real name; it previously printed "BonjourResolverContext"
        // (copy-paste from the resolver), which made debug logs misleading.
        f.debug_struct("BonjourBrowserContext")
            .field("resolved_name", &self.resolved_name)
            .field("resolved_kind", &self.resolved_kind)
            .field("resolved_domain", &self.resolved_domain)
            .field("resolved_port", &self.resolved_port)
            .finish_non_exhaustive()
    }
}

// NOTE(review): these manual impls let the context cross threads even though it owns a
// `Box<ServiceBrowserCallback>` whose Send/Sync bounds are not visible here. Confirm the
// callback type's bounds (or that the context is only touched from one thread) and record the
// invariant in a `// SAFETY:` comment.
unsafe impl Send for BonjourBrowserContext {}
unsafe impl Sync for BonjourBrowserContext {}

/// C callback registered with the browse request in `browse_services`.
///
/// A non-zero `error` is forwarded to the user callback as an error. Otherwise, results
/// flagged with `kDNSServiceFlagsAdd` start resolution via `handle_browse_add`, and all
/// other results are reported as removals via `handle_browse_remove`.
///
/// # Safety
///
/// `context` must be the pointer produced by `BonjourBrowserContext::as_raw`, and the string
/// arguments must be valid C strings supplied by DNS-SD.
unsafe extern "system" fn browse_callback(
    _sd_ref: DNSServiceRef,
    flags: DNSServiceFlags,
    interface_index: u32,
    error: DNSServiceErrorType,
    name: *const c_char,
    regtype: *const c_char,
    domain: *const c_char,
    context: *mut c_void,
) {
    let browser_ctx = unsafe { BonjourBrowserContext::from_raw(context) };

    if error != 0 {
        let message = format!("browse_callback() reported error (code: {})", error);
        browser_ctx.invoke_callback(Err(message.into()));
        return;
    }

    let is_added = (flags & bonjour_sys::kDNSServiceFlagsAdd) != 0;
    if !is_added {
        unsafe { handle_browse_remove(browser_ctx, name, regtype, domain) };
        return;
    }

    let outcome = unsafe { handle_browse_add(browser_ctx, name, regtype, domain, interface_index) };
    match outcome {
        Ok(()) => {}
        Err(failure) => browser_ctx.invoke_callback(Err(failure)),
    }
}

/// Handles a newly-found service: copies its name, type and domain onto `ctx`, then issues a
/// resolve request whose results arrive in `resolve_callback` with the same context.
///
/// # Errors
///
/// Returns an error if the resolve parameters cannot be built or the resolve request fails.
///
/// # Safety
///
/// `name`, `regtype` and `domain` must be valid C strings (they are copied with
/// `c_str::copy_raw` and passed on to DNS-SD).
unsafe fn handle_browse_add(
    ctx: &mut BonjourBrowserContext,
    name: *const c_char,
    regtype: *const c_char,
    domain: *const c_char,
    interface_index: u32,
) -> Result<()> {
    // Owned copies of the identity; consumed later by `handle_get_address_info`.
    unsafe {
        ctx.resolved_name = Some(c_str::copy_raw(name));
        ctx.resolved_kind = Some(c_str::copy_raw(regtype));
        ctx.resolved_domain = Some(c_str::copy_raw(domain));
    }

    let resolve_params = ServiceResolveParams::builder()
        .flags(bonjour_sys::kDNSServiceFlagsForceMulticast)
        .interface_index(interface_index)
        .name(name)
        .regtype(regtype)
        .domain(domain)
        .callback(Some(resolve_callback))
        .context(ctx.as_raw())
        .build()?;

    // The service ref is a temporary that is dropped at the end of this call.
    unsafe { ManagedDNSServiceRef::default().resolve_service(resolve_params) }
}

/// Handles a browse result without `kDNSServiceFlagsAdd`, i.e. a service that went away.
///
/// Strips the trailing `.` from the service type and domain (for consistency with the Avahi
/// implementation) and reports a [`BrowserEvent::Remove`] through the user callback. If the
/// [`ServiceRemoval`] cannot be built, the failure is reported to the callback as an error.
///
/// # Safety
///
/// `name`, `regtype` and `domain` must be valid NUL-terminated C strings for the duration of
/// the call (presumably what `c_str::raw_to_str` requires — confirm).
unsafe fn handle_browse_remove(
    ctx: &mut BonjourBrowserContext,
    name: *const c_char,
    regtype: *const c_char,
    domain: *const c_char,
) {
    let name = unsafe { c_str::raw_to_str(name) };
    let regtype = unsafe { c_str::raw_to_str(regtype) };
    let domain = unsafe { c_str::raw_to_str(domain) };

    // Remove the "." suffix to be consistent with the Avahi implementation. Each value must
    // fall back to *itself* when there is no trailing dot; the regtype previously fell back
    // to `domain`, reporting the domain as the service kind.
    let regtype = regtype.strip_suffix('.').unwrap_or(regtype);
    let domain = domain.strip_suffix('.').unwrap_or(domain);

    let removal = ServiceRemoval::builder()
        .name(name.to_string())
        .kind(regtype.to_string())
        .domain(domain.to_string())
        .build();

    // This runs inside the `extern "system"` browse callback. A panic unwinding out of it
    // aborts the process, so report the builder failure instead of `expect`-ing.
    match removal {
        Ok(removal) => ctx.invoke_callback(Ok(BrowserEvent::Remove(removal))),
        Err(e) => ctx.invoke_callback(Err(format!("could not build ServiceRemoval: {e:?}").into())),
    }
}

/// C callback registered with the resolve request in `handle_browse_add`.
///
/// Delegates to `handle_resolve` and forwards any failure to the user callback.
///
/// # Safety
///
/// `context` must be the pointer produced by `BonjourBrowserContext::as_raw`, and the
/// remaining pointers must be valid for the lengths/strings DNS-SD reports.
unsafe extern "system" fn resolve_callback(
    _sd_ref: DNSServiceRef,
    _flags: DNSServiceFlags,
    interface_index: u32,
    error: DNSServiceErrorType,
    _fullname: *const c_char,
    host_target: *const c_char,
    port: u16,
    txt_len: u16,
    txt_record: *const c_uchar,
    context: *mut c_void,
) {
    let browser_ctx = unsafe { BonjourBrowserContext::from_raw(context) };

    let outcome = unsafe {
        handle_resolve(
            browser_ctx,
            error,
            port,
            interface_index,
            host_target,
            txt_len,
            txt_record,
        )
    };

    if let Some(failure) = outcome.err() {
        browser_ctx.invoke_callback(Err(failure));
    }
}

/// Second step of the discovery chain. Records the resolved port and TXT record on `ctx`,
/// then starts an address lookup for `host_target`, with `get_address_info_callback` as the
/// continuation.
///
/// # Errors
///
/// Returns an error if `error` is non-zero, the TXT record cannot be copied, or the
/// address-info request cannot be built or issued.
///
/// # Safety
///
/// `txt_record` must point to at least `txt_len` readable bytes and `host_target` must be a
/// valid C string — presumably guaranteed by DNS-SD for the duration of the callback; confirm.
unsafe fn handle_resolve(
    ctx: &mut BonjourBrowserContext,
    error: DNSServiceErrorType,
    port: u16,
    interface_index: u32,
    host_target: *const c_char,
    txt_len: u16,
    txt_record: *const c_uchar,
) -> Result<()> {
    if error != 0 {
        return Err(format!("error reported by resolve_callback: (code: {})", error).into());
    }

    ctx.resolved_port = port;

    // A record of 0 or 1 bytes cannot hold any key/value pair (DNS-SD encodes an empty TXT
    // record as a single zero byte), so it is stored as `None`.
    ctx.resolved_txt = match txt_len {
        0 | 1 => None,
        len => {
            let raw = unsafe { ManagedTXTRecordRef::clone_raw(txt_record, len)? };
            Some(TxtRecord::from(raw))
        }
    };

    let lookup_params = GetAddressInfoParams::builder()
        .flags(bonjour_sys::kDNSServiceFlagsForceMulticast)
        .interface_index(interface_index)
        .protocol(0)
        .hostname(host_target)
        .callback(Some(get_address_info_callback))
        .context(ctx.as_raw())
        .build()?;

    unsafe { ManagedDNSServiceRef::default().get_address_info(lookup_params) }
}

unsafe extern "system" fn get_address_info_callback(
    _sd_ref: DNSServiceRef,
    _flags: DNSServiceFlags,
    _interface_index: u32,
    error: DNSServiceErrorType,
    hostname: *const c_char,
    address: *const bonjour_sys::sockaddr,
    _ttl: u32,
    context: *mut c_void,
) {
    let ctx = unsafe { BonjourBrowserContext::from_raw(context) };
    if let Err(e) = unsafe { handle_get_address_info(ctx, error, address, hostname) } {
        ctx.invoke_callback(Err(e));
    }
}

/// Final step of the discovery chain.
///
/// Takes the name, kind and domain stored on `ctx` by `handle_browse_add`, together with the
/// port and TXT record stored by `handle_resolve`. Combines them with the reported host name
/// and address, then delivers a [`BrowserEvent::Add`] to the user callback.
///
/// # Errors
///
/// Returns an error if `error` is non-zero, if the stored domain/kind/name is missing, if the
/// kind cannot be parsed as a service type, or if the [`ServiceDiscovery`] cannot be built.
///
/// # Safety
///
/// `address` must point to a readable socket address and `hostname` to a valid C string —
/// presumably guaranteed by DNS-SD for the duration of the callback; confirm.
unsafe fn handle_get_address_info(
    ctx: &mut BonjourBrowserContext,
    error: DNSServiceErrorType,
    address: *const bonjour_sys::sockaddr,
    hostname: *const c_char,
) -> Result<()> {
    // This callback has been observed to run more than once per lookup (cause not established
    // here). The first run takes `resolved_name` below; later runs find it empty and do nothing.
    if ctx.resolved_name.is_none() {
        return Ok(());
    }

    if error != 0 {
        return Err(format!(
            "get_address_info_callback() reported error (code: {})",
            error
        )
        .into());
    }

    // DNS-SD delivers the port in network (big-endian) byte order.
    let port: u16 = u16::from_be(ctx.resolved_port);

    // NOTE(review): `address` is always reinterpreted as an IPv4 `sockaddr_in`, yet the lookup
    // in `handle_resolve` passes protocol 0. If DNS-SD returns an IPv6 address, the bytes read
    // here are not an address — confirm, and consider checking `sa_family` first.
    //
    // `s_addr` holds the address in network byte order, so its in-memory (native-endian) bytes
    // are already the four octets in order. `to_ne_bytes` is therefore correct on both little-
    // and big-endian targets, whereas `to_le_bytes` only worked on little-endian ones.
    #[cfg(any(target_vendor = "apple", target_os = "freebsd"))]
    let ip = {
        let address = address as *const sockaddr_in;
        assert_not_null!(address);
        let s_addr = unsafe { (*address).sin_addr.s_addr.to_ne_bytes() };
        IpAddr::from(s_addr).to_string()
    };

    #[cfg(target_vendor = "pc")]
    let ip = {
        let address = address as *const sockaddr_in;
        assert_not_null!(address);
        let s_un = unsafe { (*address).sin_addr.S_un.S_un_b };
        let s_addr = [s_un.s_b1, s_un.s_b2, s_un.s_b3, s_un.s_b4];
        IpAddr::from(s_addr).to_string()
    };

    let hostname = unsafe { c_str::copy_raw(hostname) };

    let domain = bonjour_util::normalize_domain(
        &ctx.resolved_domain
            .take()
            .ok_or("could not get domain from BonjourBrowserContext")?,
    );

    let kind = bonjour_util::normalize_domain(
        &ctx.resolved_kind
            .take()
            .ok_or("could not get kind from BonjourBrowserContext")?,
    );

    let name = ctx
        .resolved_name
        .take()
        .ok_or("could not get name from BonjourBrowserContext")?;

    // Propagate instead of `expect`: this runs inside an `extern "system"` callback, where a
    // panic aborts the process. The caller forwards the error to the user callback.
    let result = ServiceDiscovery::builder()
        .name(name)
        .service_type(bonjour_util::parse_regtype(&kind)?)
        .domain(domain)
        .host_name(hostname)
        .address(ip)
        .port(port)
        .txt(ctx.resolved_txt.take())
        .build()
        .map_err(|e| format!("could not build ServiceDiscovery: {e:?}"))?;

    ctx.invoke_callback(Ok(BrowserEvent::Add(result)));

    Ok(())
}