#[cfg(test)]
mod tests;
use std::{
future::poll_fn,
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
task::Poll,
};
use network_interface::{NetworkInterface, NetworkInterfaceConfig};
use socket2::SockRef;
use tracing::{debug, trace};
use crate::{
BindDevice, BindOpts, Error, UdpSocket,
addr::{Ipv6AddrExt, WithScopeId},
};
/// A UDP socket bundled with the multicast group addresses it joins and the
/// network interfaces it was set up on.
pub struct MulticastUdpSocket {
/// Underlying bound UDP socket; all sends and receives go through it.
sock: UdpSocket,
/// IPv4 multicast group address and port.
ipv4_addr: SocketAddrV4,
/// IPv6 multicast group joined on interfaces that have a non-link-local IPv6 address.
/// `new` rejects values that fail `is_site_local_mcast`.
ipv6_site_local: SocketAddrV6,
/// Optional IPv6 multicast group joined on interfaces that have a link-local IPv6 address.
/// `new` rejects values that fail `is_link_local_mcast`.
ipv6_link_local: Option<SocketAddrV6>,
/// Interfaces selected in `new` (all of them, or only the bind device).
nics: Vec<NetworkInterface>,
}
impl MulticastUdpSocket {
    /// Binds a UDP socket at `bind_addr` and joins the configured multicast groups.
    ///
    /// Candidate interfaces are every NIC reported by the OS, or only the one whose index
    /// equals `bind_device`'s index when a device is given. The socket is bound with
    /// `request_dualstack` and `reuseport` set.
    ///
    /// # Errors
    ///
    /// * [`Error::ProvidedLinkLocalAddrIsntLinkLocal`] if `ipv6_link_local_addr` is set but
    ///   is not a link-local multicast address.
    /// * [`Error::ProvidedSiteLocalAddrIsNotSiteLocal`] if `ipv6_site_local_addr` is not a
    ///   site-local multicast address.
    /// * [`Error::NoNics`] if no interface remains after filtering. A failure to enumerate
    ///   interfaces also ends up here, because the enumeration error is discarded.
    /// * Whatever `UdpSocket::bind_udp` returns on failure.
    /// * [`Error::MulticastJoinFail`] / [`Error::Writeable`] from joining the groups.
    pub async fn new(
        bind_addr: SocketAddr,
        ipv4_mcast_addr: SocketAddrV4,
        ipv6_site_local_addr: SocketAddrV6,
        ipv6_link_local_addr: Option<SocketAddrV6>,
        bind_device: Option<&BindDevice>,
    ) -> crate::Result<Self> {
        if ipv6_link_local_addr.is_some_and(|ll| !ll.ip().is_link_local_mcast()) {
            return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
        }
        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
        }

        // `Result::into_iter().flatten()` yields nothing on `Err`, so an enumeration failure
        // is reported as `NoNics` below.
        let nics = NetworkInterface::show()
            .into_iter()
            .flatten()
            .filter(|nic| bind_device.is_none_or(|bd| bd.index().get() == nic.index))
            .collect::<Vec<_>>();
        if nics.is_empty() {
            return Err(Error::NoNics);
        }

        let opts = BindOpts {
            request_dualstack: true,
            reuseport: true,
            device: bind_device,
        };
        let sock = UdpSocket::bind_udp(bind_addr, opts)?;

        let sock = Self {
            sock,
            ipv4_addr: ipv4_mcast_addr,
            ipv6_link_local: ipv6_link_local_addr,
            ipv6_site_local: ipv6_site_local_addr,
            nics,
        };
        sock.bind_multicast().await?;
        Ok(sock)
    }

    /// Returns the interfaces this socket was set up on.
    pub fn nics(&self) -> &[NetworkInterface] {
        &self.nics
    }

    /// Receives one datagram into `buf`, returning its size and the sender address.
    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
        self.sock.recv_from(buf).await
    }

    /// Sends `buf` to `addr`.
    ///
    /// Before every send attempt the socket's outgoing multicast interface is reset to the
    /// unspecified one (index `0` for IPv6-bound sockets, `0.0.0.0` otherwise). A failure to
    /// reset is only traced; the send is attempted regardless.
    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
        poll_fn(|cx| {
            let sref = SockRef::from(self.sock.socket());
            // NOTE(review): presumably this undoes the per-interface selection made by
            // `send_multicast_msg`; it is repeated on each poll so it is in effect for the
            // actual send.
            if bind_is_ipv6 {
                if let Err(e) = sref.set_multicast_if_v6(0) {
                    trace!("error calling set_multicast_if_v6(0): {e:#}")
                }
            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
            }
            self.sock.poll_send_to(cx, buf, addr)
        })
        .await
    }

    /// Joins the multicast groups on every stored interface and waits for the socket to
    /// become writable.
    ///
    /// Individual join failures are logged by the helpers and tolerated.
    ///
    /// # Errors
    ///
    /// * [`Error::MulticastJoinFail`] if not a single join succeeded.
    /// * [`Error::Writeable`] if waiting for writability fails.
    async fn bind_multicast(&self) -> crate::Result<()> {
        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
        let is_linux_or_windows = cfg!(any(target_os = "linux", target_os = "windows"));
        let ipv4_group = *self.ipv4_addr.ip();

        // IPv4-bound sockets first join on the unspecified interface.
        let mut joined =
            !bind_is_ipv6 && try_join_v4(&self.sock, ipv4_group, Ipv4Addr::UNSPECIFIED);

        for nic in &self.nics {
            let mut has_link_local = false;
            let mut has_site_local = false;

            for addr in &nic.addr {
                match addr.ip() {
                    IpAddr::V4(iface_addr)
                        if iface_addr.is_private() || iface_addr.is_loopback() =>
                    {
                        // IPv4-bound sockets, and IPv6-bound ones on Linux/Windows, join the
                        // IPv4 group by interface address. Other platforms with an IPv6-bound
                        // socket join the IPv4-mapped group by interface index instead.
                        joined |= if !bind_is_ipv6 || is_linux_or_windows {
                            try_join_v4(&self.sock, ipv4_group, iface_addr)
                        } else {
                            try_join_v6(&self.sock, ipv4_group.to_ipv6_mapped(), nic.index)
                        };
                    }
                    IpAddr::V6(iface_addr) if bind_is_ipv6 => {
                        // Every IPv6 address that is not link-local makes the interface
                        // eligible for the site-local group.
                        if iface_addr.is_unicast_link_local() {
                            has_link_local = true;
                        } else {
                            has_site_local = true;
                        }
                    }
                    _ => {}
                }
            }

            // Each IPv6 group is joined at most once per interface.
            if has_site_local {
                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
            }
            if let Some(ll) = self.ipv6_link_local.filter(|_| has_link_local) {
                joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
            }
        }

        if !joined {
            return Err(Error::MulticastJoinFail);
        }

        self.sock
            .socket()
            .writable()
            .await
            .map_err(Error::Writeable)?;
        Ok(())
    }

    /// Picks the interface and multicast group to use when answering a peer at `addr`.
    ///
    /// Interface addresses are scanned in order and the first match wins:
    ///
    /// * link-local IPv6 peer (only when a link-local group is configured): an IPv6 address
    ///   on the interface whose index equals the peer's scope id; the returned group carries
    ///   that scope id;
    /// * unique-local IPv6 peer: an IPv6 interface address in the same subnet (by netmask),
    ///   yielding the site-local group;
    /// * IPv4 peer: an IPv4 interface address in the same subnet, yielding the IPv4 group.
    ///
    /// Returns `None` when nothing matches. The returned `interface_addr` always has the same
    /// address family as `mcast_addr`, which is what `send_multicast_msg` requires.
    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
        self.nics()
            .iter()
            .flat_map(|nic| nic.addr.iter().map(move |naddr| (nic, naddr)))
            .find_map(|(nic, naddr)| {
                let mcast_addr: SocketAddr =
                    match (addr, naddr.ip(), naddr.netmask(), self.ipv6_link_local) {
                        // Only IPv6 interface addresses qualify: pairing an IPv6 group with
                        // an IPv4 interface address would make `send_multicast_msg` fail
                        // with `SendMulticastMsgProtocolMismatch`.
                        (SocketAddr::V6(peer), IpAddr::V6(_), _, Some(mlocal))
                            if peer.ip().is_unicast_link_local() =>
                        {
                            if nic.index != peer.scope_id() {
                                return None;
                            }
                            mlocal.with_scope_id(nic.index).into()
                        }
                        (SocketAddr::V6(peer), IpAddr::V6(nic_ip), Some(IpAddr::V6(mask)), _)
                            if peer.ip().is_unique_local()
                                && (peer.ip().to_bits() & mask.to_bits())
                                    == (nic_ip.to_bits() & mask.to_bits()) =>
                        {
                            self.ipv6_site_local.into()
                        }
                        (SocketAddr::V4(peer), IpAddr::V4(nic_ip), Some(IpAddr::V4(mask)), _)
                            if (peer.ip().to_bits() & mask.to_bits())
                                == (nic_ip.to_bits() & mask.to_bits()) =>
                        {
                            self.ipv4_addr.into()
                        }
                        _ => return None,
                    };
                Some(MulticastOpts {
                    interface_id: nic.index,
                    interface_addr: naddr.ip(),
                    mcast_addr,
                })
            })
    }

    /// Sends `buf` to `opts.mcast_addr` through the interface described by `opts`.
    ///
    /// The outgoing multicast interface is selected on every send attempt, immediately
    /// before the datagram is handed to the socket.
    ///
    /// # Errors
    ///
    /// * [`Error::SetMulticastIpv4`] / [`Error::SetMulticastIpv6`] if selecting the outgoing
    ///   interface fails.
    /// * [`Error::SendMulticastMsgProtocolMismatch`] if the group and the interface address
    ///   belong to different address families.
    /// * [`Error::Send`] if the send itself fails.
    pub async fn send_multicast_msg(
        &self,
        buf: &[u8],
        opts: &MulticastOpts,
    ) -> crate::Result<usize> {
        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
        let is_linux = cfg!(target_os = "linux");
        poll_fn(|cx| {
            let sref = SockRef::from(self.sock.socket());
            match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
                // IPv4 group on Linux (any bind family) or on an IPv4-bound socket:
                // select the interface by its IPv4 address.
                (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
                | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
                    sref.set_multicast_if_v4(&addr)
                        .map_err(Error::SetMulticastIpv4)?;
                }
                // IPv6 group, or IPv4 group on an IPv6-bound non-Linux socket:
                // select the interface by index.
                (SocketAddr::V6(_), IpAddr::V6(_), _, _)
                | (SocketAddr::V4(_), IpAddr::V4(_), _, _) => {
                    sref.set_multicast_if_v6(opts.interface_id)
                        .map_err(Error::SetMulticastIpv6)?;
                }
                _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
            }
            self.sock
                .poll_send_to(cx, buf, opts.mcast_addr)
                .map_err(Error::Send)
        })
        .await
    }

    /// Best-effort multicast of a payload on every eligible interface address.
    ///
    /// Eligible addresses are private or loopback IPv4 addresses and, on IPv6-bound sockets,
    /// link-local addresses (only when a link-local group is configured) and unique-local
    /// addresses. Targets are de-duplicated per interface index (IPv6-bound socket) or per
    /// interface address (IPv4-bound socket). `get_payload` is called once per remaining
    /// target; returning `None` skips that target. All sends run concurrently and failures
    /// are only logged.
    pub async fn try_send_mcast_everywhere(
        &self,
        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
    ) {
        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();

        let mut send_specs = self
            .nics
            .iter()
            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
            .filter_map(|(ifidx, ifaddr)| {
                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
                    (_, IpAddr::V4(a), _) if a.is_private() || a.is_loopback() => {
                        self.ipv4_addr.into()
                    }
                    (true, IpAddr::V6(a), Some(mlocal)) if a.is_unicast_link_local() => {
                        mlocal.with_scope_id(ifidx).into()
                    }
                    (true, IpAddr::V6(a), _) if a.is_unique_local() => self.ipv6_site_local.into(),
                    _ => {
                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
                        return None;
                    }
                };
                Some(MulticastOpts {
                    interface_id: ifidx,
                    interface_addr: ifaddr,
                    mcast_addr,
                })
            })
            .collect::<Vec<_>>();

        // The sort must stay stable: `dedup_by_key` keeps the first entry of each run, so
        // the surviving `interface_addr` depends on the original enumeration order.
        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));

        let futs = send_specs.into_iter().filter_map(|opts| {
            let payload = get_payload(&opts)?;
            let fut = async move {
                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
                    Err(e) => {
                        debug!(?opts, payload=?payload, "error sending: {e:#}")
                    }
                };
            };
            Some(fut)
        });
        futures::future::join_all(futs).await;
    }
}
/// Attempts to join the IPv4 multicast group `addr` on the interface identified by the
/// address `iface`.
///
/// Returns `true` on success. On failure the error is logged at debug level and `false`
/// is returned.
fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
    match sock.socket().join_multicast_v4(addr, iface) {
        Ok(_) => true,
        Err(e) => {
            debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
            false
        }
    }
}
/// Attempts to join the IPv6 multicast group `addr` on the interface with index `ifindex`.
///
/// Returns `true` on success. On failure the error is logged at debug level and `false`
/// is returned.
fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
    match sock.socket().join_multicast_v6(&addr, ifindex) {
        Ok(_) => true,
        Err(e) => {
            debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
            false
        }
    }
}
/// Describes one multicast send target: which interface to send through and which
/// group address to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface.
    pub interface_id: u32,
    /// An address assigned to the outgoing network interface.
    pub interface_addr: IpAddr,
    /// Multicast group address and port to send to.
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// Returns the address of the outgoing interface.
    pub fn iface_ip(&self) -> IpAddr {
        self.interface_addr
    }

    /// Returns the multicast destination address.
    pub fn mcast_addr(&self) -> SocketAddr {
        self.mcast_addr
    }

    /// Builds the key used to de-duplicate send targets.
    ///
    /// When `bind_addr_is_ipv6` is true, targets are distinguished by interface index;
    /// otherwise by interface address. The multicast address is always part of the key.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let (by_index, by_addr) = match bind_addr_is_ipv6 {
            true => (Some(self.interface_id), None),
            false => (None, Some(self.interface_addr)),
        };
        (by_index, by_addr, self.mcast_addr)
    }
}