// librqbit-dualstack-sockets 0.6.7
//
// Utilities for creating dual-stack TCP and UDP tokio sockets.
// Documentation: https://docs.rs/librqbit-dualstack-sockets/0.6.7
#[cfg(test)]
mod tests;

use std::{
    future::poll_fn,
    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
    task::Poll,
};

use network_interface::{NetworkInterface, NetworkInterfaceConfig};
use socket2::SockRef;
use tracing::{debug, trace};

use crate::{
    BindOpts, Error, UdpSocket,
    addr::{Ipv6AddrExt, WithScopeId},
};

/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
/// interfaces.
///
/// The multicast group addresses and the list of network interfaces are fixed when the socket
/// is created by [`MulticastUdpSocket::new`]; none of the methods in this file update them
/// afterwards.
pub struct MulticastUdpSocket {
    /// The underlying UDP socket that all sends and receives go through.
    sock: UdpSocket,
    /// IPv4 multicast group (address and port) to join and to send to.
    ipv4_addr: SocketAddrV4,
    /// IPv6 site-local multicast group; checked with `is_site_local_mcast()` on construction.
    ipv6_site_local: SocketAddrV6,
    /// Optional IPv6 link-local multicast group; checked with `is_link_local_mcast()` on
    /// construction when present.
    ipv6_link_local: Option<SocketAddrV6>,
    /// Network interfaces enumerated at construction time.
    nics: Vec<NetworkInterface>,
}

impl MulticastUdpSocket {
    pub async fn new(
        bind_addr: SocketAddr,
        ipv4_mcast_addr: SocketAddrV4,
        ipv6_site_local_addr: SocketAddrV6,
        ipv6_link_local_addr: Option<SocketAddrV6>,
    ) -> crate::Result<Self> {
        if let Some(ll) = ipv6_link_local_addr {
            if !ll.ip().is_link_local_mcast() {
                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
            }
        }
        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
        }
        let nics = network_interface::NetworkInterface::show()
            .into_iter()
            .flatten()
            .collect::<Vec<_>>();
        if nics.is_empty() {
            return Err(Error::NoNics);
        }
        let opts = BindOpts {
            request_dualstack: true,
            reuseport: true,
            device: None,
        };
        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
        let sock = Self {
            sock,
            ipv4_addr: ipv4_mcast_addr,
            ipv6_link_local: ipv6_link_local_addr,
            ipv6_site_local: ipv6_site_local_addr,
            nics,
        };
        sock.bind_multicast().await?;
        Ok(sock)
    }

    /// Returns the network interfaces stored on this socket.
    pub fn nics(&self) -> &[NetworkInterface] {
        self.nics.as_slice()
    }

    /// Receives into `buf` by delegating to the underlying socket's `recv_from`.
    ///
    /// The `(usize, SocketAddr)` pair and any I/O error are passed through unchanged.
    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
        self.sock.recv_from(buf).await
    }

    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
        // Ensure the multicast option is erased before sending
        poll_fn(|cx| {
            let sref = SockRef::from(self.sock.socket());
            if self.sock.bind_addr().is_ipv6() {
                if let Err(e) = sref.set_multicast_if_v6(0) {
                    trace!("error calling set_multicast_if_v6(0): {e:#}")
                }
            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
            }

            self.sock.poll_send_to(cx, buf, addr)
        })
        .await
    }

    async fn bind_multicast(&self) -> crate::Result<()> {
        let mut joined = false;
        if self.sock.bind_addr().is_ipv4() {
            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
        }

        for nic in self.nics.iter() {
            let mut has_link_local = false;
            let mut has_site_local = false;

            for addr in nic.addr.iter() {
                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
                    (IpAddr::V4(iface_addr), is_ipv6)
                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
                    {
                        let is_linux_or_windows =
                            cfg!(any(target_os = "linux", target_os = "windows"));
                        if !is_ipv6 || is_linux_or_windows {
                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
                        } else {
                            joined |= try_join_v6(
                                &self.sock,
                                self.ipv4_addr.ip().to_ipv6_mapped(),
                                nic.index,
                            )
                        }
                    }
                    (IpAddr::V6(addr), true) => {
                        if addr.is_loopback() {
                            continue;
                        }
                        if addr.is_unicast_link_local() {
                            has_link_local = true;
                        } else {
                            has_site_local = true;
                        }
                    }
                    _ => continue,
                }
            }

            if has_site_local {
                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
            }

            if let Some(ll) = self.ipv6_link_local {
                if has_link_local {
                    joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
                }
            }
        }

        if !joined {
            return Err(Error::MulticastJoinFail);
        }

        self.sock
            .socket()
            .writable()
            .await
            .map_err(Error::Writeable)?;

        Ok(())
    }

    /// Picks the interface and multicast group to use when replying to a packet that came from
    /// `addr`.
    ///
    /// Every address of every known interface is examined in order and the first match wins:
    ///
    /// - `addr` is an IPv6 link-local unicast address, a link-local multicast group is
    ///   configured, the interface index equals `addr`'s scope id and the interface address is
    ///   IPv6: reply to the link-local group scoped to that interface.
    /// - `addr` is an IPv6 unique-local address in the same subnet as the interface's IPv6
    ///   address: reply to the site-local group.
    /// - `addr` is IPv4 and in the same subnet as the interface's IPv4 address: reply to the
    ///   IPv4 group.
    ///
    /// Returns `None` if no interface address matches.
    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
        self.nics()
            .iter()
            .flat_map(|nic| nic.addr.iter().map(move |naddr| (nic, naddr)))
            .find_map(|(nic, naddr)| {
                let nm = naddr.netmask();
                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
                    // For link-local addresses, we reply back to the nic from scope_id, if
                    // there's a multicast link-local address.
                    //
                    // Only IPv6 interface addresses are accepted here: `send_multicast_msg`
                    // rejects an IPv6 group paired with an IPv4 interface address, which is what
                    // this would return if the interface's first listed address were IPv4.
                    (SocketAddr::V6(peer), IpAddr::V6(_), _, Some(mlocal))
                        if peer.ip().is_unicast_link_local() =>
                    {
                        if nic.index != peer.scope_id() {
                            return None;
                        }
                        mlocal.with_scope_id(nic.index).into()
                    }

                    // For ULAs, multicast to site-local address if in the same netmask
                    (SocketAddr::V6(peer), IpAddr::V6(nic_ip), Some(IpAddr::V6(mask)), _)
                        if peer.ip().is_unique_local()
                            && peer.ip().to_bits() & mask.to_bits()
                                == nic_ip.to_bits() & mask.to_bits() =>
                    {
                        self.ipv6_site_local.into()
                    }

                    // For IPv4, if the mask matches, determine the interface.
                    (SocketAddr::V4(peer), IpAddr::V4(nic_ip), Some(IpAddr::V4(mask)), _)
                        if peer.ip().to_bits() & mask.to_bits()
                            == nic_ip.to_bits() & mask.to_bits() =>
                    {
                        self.ipv4_addr.into()
                    }
                    _ => return None,
                };
                Some(MulticastOpts {
                    interface_id: nic.index,
                    interface_addr: naddr.ip(),
                    mcast_addr,
                })
            })
    }

    /// Sends `buf` to `opts.mcast_addr` after selecting the outgoing multicast interface
    /// described by `opts`.
    ///
    /// For an IPv4 group with an IPv4 interface address, the interface is selected by address
    /// (IPv4 socket option) on Linux or when the socket is bound to IPv4; otherwise it is
    /// selected by interface index (IPv6 socket option). IPv6 groups with an IPv6 interface
    /// address always use the interface index.
    ///
    /// # Errors
    ///
    /// - [`Error::SendMulticastMsgProtocolMismatch`] if the group and the interface address are
    ///   of different IP families.
    /// - [`Error::SetMulticastIpv4`] / [`Error::SetMulticastIpv6`] if selecting the interface
    ///   fails.
    /// - [`Error::Send`] if the send itself fails.
    pub async fn send_multicast_msg(
        &self,
        buf: &[u8],
        opts: &MulticastOpts,
    ) -> crate::Result<usize> {
        // Everything happens inside `poll_fn` so the interface is (re)selected on every poll,
        // immediately before the send attempt.
        poll_fn(|cx| {
            let raw = SockRef::from(self.sock.socket());
            let v4_selected_by_addr =
                cfg!(target_os = "linux") || !self.sock.bind_addr().is_ipv6();

            // `Some(ip)`: select the interface by its IPv4 address.
            // `None`: select the interface by index.
            let select_by_v4_addr = match (opts.mcast_addr(), opts.iface_ip()) {
                (SocketAddr::V4(_), IpAddr::V4(ifaddr)) => v4_selected_by_addr.then_some(ifaddr),
                (SocketAddr::V6(_), IpAddr::V6(_)) => None,
                _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
            };

            if let Some(ifaddr) = select_by_v4_addr {
                raw.set_multicast_if_v4(&ifaddr)
                    .map_err(Error::SetMulticastIpv4)?;
            } else {
                raw.set_multicast_if_v6(opts.interface_id)
                    .map_err(Error::SetMulticastIpv6)?;
            }

            self.sock
                .poll_send_to(cx, buf, opts.mcast_addr)
                .map_err(Error::Send)
        })
        .await
    }

    /// Sends a multicast message for every eligible interface address, concurrently.
    ///
    /// Eligible addresses and their target groups:
    /// - private IPv4 address: the IPv4 group;
    /// - IPv6 link-local address (socket bound to IPv6, link-local group configured): the
    ///   link-local group scoped to the interface;
    /// - IPv6 unique-local address (socket bound to IPv6): the site-local group.
    ///
    /// Targets are de-duplicated (per interface index when bound to IPv6, per interface address
    /// otherwise). `get_payload` is called once per remaining target; returning `None` skips it.
    /// Send failures are logged and never returned.
    pub async fn try_send_mcast_everywhere(
        &self,
        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
    ) {
        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();

        let mut targets: Vec<MulticastOpts> = Vec::new();
        for nic in &self.nics {
            for ifaddr in nic.addr.iter().map(|a| a.ip()) {
                let group: Option<SocketAddr> = match ifaddr {
                    IpAddr::V4(v4) => {
                        (!v4.is_loopback() && v4.is_private()).then(|| self.ipv4_addr.into())
                    }
                    IpAddr::V6(v6) if !bind_is_ipv6 || v6.is_loopback() => None,
                    IpAddr::V6(v6) if v6.is_unicast_link_local() => self
                        .ipv6_link_local
                        .map(|mlocal| mlocal.with_scope_id(nic.index).into()),
                    IpAddr::V6(v6) => v6.is_unique_local().then(|| self.ipv6_site_local.into()),
                };

                let Some(mcast_addr) = group else {
                    trace!(oif_id=nic.index, addr=%ifaddr, "ignoring address");
                    continue;
                };

                targets.push(MulticastOpts {
                    interface_id: nic.index,
                    interface_addr: ifaddr,
                    mcast_addr,
                });
            }
        }

        // Stable sort + dedup: the first target seen for each key is the one that is kept.
        targets.sort_by_key(|t| t.uniq_key(bind_is_ipv6));
        targets.dedup_by(|later, earlier| {
            later.uniq_key(bind_is_ipv6) == earlier.uniq_key(bind_is_ipv6)
        });

        let mut sends = Vec::with_capacity(targets.len());
        for opts in targets {
            let Some(payload) = get_payload(&opts) else {
                continue;
            };
            sends.push(async move {
                let outcome = self.send_multicast_msg(payload.as_bytes(), &opts).await;
                match outcome {
                    Err(e) => {
                        debug!(?opts, payload=?payload, "error sending: {e:#}")
                    }
                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
                }
            });
        }

        futures::future::join_all(sends).await;
    }
}

/// Tries to join the IPv4 multicast group `addr` on the interface with address `iface`.
///
/// Returns `true` on success. A failure is logged at debug level and reported as `false`.
fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
    match sock.socket().join_multicast_v4(addr, iface) {
        Ok(_) => true,
        Err(e) => {
            debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
            false
        }
    }
}

/// Tries to join the IPv6 multicast group `addr` on the interface with index `ifindex`.
///
/// Returns `true` on success. A failure is logged at debug level and reported as `false`.
fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
    match sock.socket().join_multicast_v6(&addr, ifindex) {
        Ok(_) => true,
        Err(e) => {
            debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
            false
        }
    }
}

/// Describes one multicast send: which interface to send from and which group to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface.
    pub interface_id: u32,
    /// Address of the outgoing network interface.
    pub interface_addr: IpAddr,
    /// Multicast group (address and port) the message is sent to.
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// Address of the outgoing network interface.
    pub fn iface_ip(&self) -> IpAddr {
        self.interface_addr
    }

    /// Multicast group (address and port) the message is sent to.
    pub fn mcast_addr(&self) -> SocketAddr {
        self.mcast_addr
    }

    /// Key used to de-duplicate send targets.
    ///
    /// When the socket is bound to IPv6 the interface is identified by its index; otherwise it
    /// is identified by its address. The unused identifier is `None` so that it does not
    /// contribute to the key.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_addr = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_addr, self.mcast_addr)
    }
}