librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6    task::Poll,
7};
8
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use socket2::SockRef;
11use tracing::{debug, trace};
12
13use crate::{BindOpts, Error, UdpSocket, addr::WithScopeId};
14
15/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
16/// interfaces.
17pub struct MulticastUdpSocket {
18    sock: UdpSocket,
19    ipv4_addr: SocketAddrV4,
20    ipv6_site_local: SocketAddrV6,
21    ipv6_link_local: Option<SocketAddrV6>,
22    nics: Vec<NetworkInterface>,
23}
24
25impl MulticastUdpSocket {
26    pub async fn new(
27        bind_addr: SocketAddr,
28        ipv4_mcast_addr: SocketAddrV4,
29        ipv6_site_local_addr: SocketAddrV6,
30        ipv6_link_local_addr: Option<SocketAddrV6>,
31    ) -> crate::Result<Self> {
32        if let Some(ll) = ipv6_link_local_addr {
33            if !ipv6_is_link_local_mcast(ll.ip()) {
34                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
35            }
36        }
37        if !ipv6_is_site_local_mcast(ipv6_site_local_addr.ip()) {
38            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
39        }
40        let nics = network_interface::NetworkInterface::show()
41            .into_iter()
42            .flatten()
43            .collect::<Vec<_>>();
44        if nics.is_empty() {
45            return Err(Error::NoNics);
46        }
47        let opts = BindOpts {
48            request_dualstack: true,
49            reuseport: true,
50        };
51        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
52        let sock = Self {
53            sock,
54            ipv4_addr: ipv4_mcast_addr,
55            ipv6_link_local: ipv6_link_local_addr,
56            ipv6_site_local: ipv6_site_local_addr,
57            nics,
58        };
59        sock.bind_multicast().await?;
60        Ok(sock)
61    }
62
63    pub fn nics(&self) -> &[NetworkInterface] {
64        &self.nics
65    }
66
67    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
68        self.sock.recv_from(buf).await
69    }
70
71    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
72        // Ensure the multicast option is erased before sending
73        std::future::poll_fn(|cx| {
74            let sref = SockRef::from(self.sock.socket());
75            if self.sock.bind_addr().is_ipv6() {
76                if let Err(e) = sref.set_multicast_if_v6(0) {
77                    trace!("error calling set_multicast_if_v6(0): {e:#}")
78                }
79            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
80                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
81            }
82
83            self.sock.poll_send_to(cx, buf, addr)
84        })
85        .await
86    }
87
88    async fn bind_multicast(&self) -> crate::Result<()> {
89        let mut joined = false;
90        if self.sock.bind_addr().is_ipv4() {
91            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
92        }
93
94        for nic in self.nics.iter() {
95            let mut has_link_local = false;
96            let mut has_site_local = false;
97
98            for addr in nic.addr.iter() {
99                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
100                    (IpAddr::V4(iface_addr), is_ipv6)
101                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
102                    {
103                        if !is_ipv6 {
104                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
105                        } else {
106                            joined |= try_join_v6(
107                                &self.sock,
108                                self.ipv4_addr.ip().to_ipv6_mapped(),
109                                nic.index,
110                            )
111                        }
112                    }
113                    (IpAddr::V6(addr), true) => {
114                        if addr.is_loopback() {
115                            continue;
116                        }
117                        if ipv6_is_link_local(&addr) {
118                            has_link_local = true;
119                        } else {
120                            has_site_local = true;
121                        }
122                    }
123                    _ => continue,
124                }
125            }
126
127            if has_site_local {
128                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
129            }
130
131            if let Some(ll) = self.ipv6_link_local {
132                if has_link_local {
133                    joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
134                }
135            }
136        }
137
138        if !joined {
139            return Err(Error::MulticastJoinFail);
140        }
141
142        self.sock
143            .socket()
144            .writable()
145            .await
146            .map_err(Error::Writeable)?;
147
148        Ok(())
149    }
150
151    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
152        self.nics()
153            .iter()
154            .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
155            .find_map(|(nic, naddr)| {
156                let nm = naddr.netmask();
157                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
158                    // For link-local addresses, we reply back to the nic from scope_id, if there's a multicast link-local
159                    // address
160                    (SocketAddr::V6(addr), _, _, Some(mlocal)) if ipv6_is_link_local(addr.ip()) => {
161                        if nic.index != addr.scope_id() {
162                            return None;
163                        }
164                        mlocal.with_scope_id(nic.index).into()
165                    }
166
167                    // For ULAs, multicast to site-local address if in the same netmask
168                    (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
169                        if ipv6_is_unique_local_address(addr.ip())
170                            && addr.ip().to_bits() & mask.to_bits()
171                                == naddr.to_bits() & mask.to_bits() =>
172                    {
173                        self.ipv6_site_local.into()
174                    }
175
176                    // For IPv4, if the mask matches, determine the interface.
177                    (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
178                        if addr.ip().to_bits() & mask.to_bits()
179                            == naddr.to_bits() & mask.to_bits() =>
180                    {
181                        self.ipv4_addr.into()
182                    }
183                    _ => return None,
184                };
185                Some(MulticastOpts {
186                    interface_id: nic.index,
187                    interface_addr: naddr.ip(),
188                    mcast_addr,
189                })
190            })
191    }
192
193    pub async fn send_multicast_msg(
194        &self,
195        buf: &[u8],
196        opts: &MulticastOpts,
197    ) -> crate::Result<usize> {
198        // This is .poll_fn() so that we call .set_multicast() immediately before sending a packet.
199        // If it's repolled it'll get called again just before the send.
200
201        std::future::poll_fn(|cx| {
202            let sref = SockRef::from(self.sock.socket());
203            if self.sock.bind_addr().is_ipv6() {
204                if let Err(e) = sref.set_multicast_if_v6(opts.interface_id) {
205                    return Poll::Ready(Err(Error::SetMulticastIpv6(e)));
206                }
207            } else {
208                let ifaddr = match opts.interface_addr {
209                    IpAddr::V4(ipv4_addr) => ipv4_addr,
210                    IpAddr::V6(_) => {
211                        return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch));
212                    }
213                };
214                if let Err(e) = sref.set_multicast_if_v4(&ifaddr) {
215                    return Poll::Ready(Err(Error::SetMulticastIpv4(e)));
216                }
217            };
218
219            self.sock
220                .poll_send_to(cx, buf, opts.mcast_addr)
221                .map_err(Error::Send)
222        })
223        .await
224    }
225
226    pub async fn try_send_mcast_everywhere(
227        &self,
228        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
229    ) {
230        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
231
232        let mut send_specs = self
233            .nics
234            .iter()
235            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
236            .filter_map(|(ifidx, ifaddr)| {
237                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
238                    (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
239                        self.ipv4_addr.into()
240                    }
241                    (true, IpAddr::V6(a), Some(mlocal))
242                        if !a.is_loopback() && ipv6_is_link_local(&a) =>
243                    {
244                        mlocal.with_scope_id(ifidx).into()
245                    }
246                    (true, IpAddr::V6(a), _)
247                        if !a.is_loopback() && ipv6_is_unique_local_address(&a) =>
248                    {
249                        self.ipv6_site_local.into()
250                    }
251                    _ => {
252                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
253                        return None;
254                    }
255                };
256                Some(MulticastOpts {
257                    interface_id: ifidx,
258                    interface_addr: ifaddr,
259                    mcast_addr,
260                })
261            })
262            .collect::<Vec<_>>();
263
264        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
265        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
266
267        let futs = send_specs.into_iter().filter_map(|opts| {
268            let payload = get_payload(&opts)?;
269            let fut = async move {
270                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
271                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
272                    Err(e) => {
273                        debug!(?opts, payload=?payload, "error sending: {e:#}")
274                    }
275                };
276            };
277            Some(fut)
278        });
279
280        futures::future::join_all(futs).await;
281    }
282}
283
284fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
285    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
286    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
287        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
288        return false;
289    }
290    true
291}
292
293fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
294    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
295    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
296        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
297        return false;
298    }
299    true
300}
301
/// Returns `true` if `ip` is an IPv6 link-local unicast address, i.e. inside `fe80::/10`
/// (RFC 4291 section 2.5.6).
///
/// The whole /10 prefix is checked rather than only `fe80::/64`. Addresses whose bits after
/// the prefix are non-zero are therefore still classified as link-local; some BSD-derived
/// stacks may, for example, embed the scope id in the second 16-bit group.
fn ipv6_is_link_local(ip: &Ipv6Addr) -> bool {
    // Top 10 bits: 1111 1110 10
    const PREFIX_MASK: u16 = 0xffc0;
    const LINK_LOCAL_PREFIX: u16 = 0xfe80;

    ip.segments()[0] & PREFIX_MASK == LINK_LOCAL_PREFIX
}
308
/// Returns `true` if `ip` is an IPv6 unique local address (ULA), i.e. inside `fc00::/7`.
fn ipv6_is_unique_local_address(ip: &Ipv6Addr) -> bool {
    // Only the top 7 bits (1111 110x) of the first 16-bit group select the ULA range.
    const ULA_MASK: u16 = 0xfe00;
    const ULA_PREFIX: u16 = 0xfc00;

    let first_group = ip.segments()[0];
    first_group & ULA_MASK == ULA_PREFIX
}
315
/// Returns `true` if `ip` is an IPv6 multicast address with link-local scope (`ffX2::/16`).
///
/// Per RFC 4291 section 2.7, a multicast address is `ff` + 4 flag bits + 4 scope bits + a
/// 112-bit group ID. Only the `ff` prefix and the scope nibble (2 = link-local) are checked.
/// The flags and group ID may hold anything, including the non-zero prefix bits of RFC 3306
/// unicast-prefix-based groups.
fn ipv6_is_link_local_mcast(ip: &Ipv6Addr) -> bool {
    // Keep the `ff` prefix and the scope nibble, drop the flags nibble.
    const MASK: u16 = 0xff0f;
    const LINK_LOCAL_SCOPE: u16 = 0xff02;

    ip.segments()[0] & MASK == LINK_LOCAL_SCOPE
}
322
/// Returns `true` if `ip` is an IPv6 multicast address with site-local scope (`ffX5::/16`).
///
/// Per RFC 4291 section 2.7, a multicast address is `ff` + 4 flag bits + 4 scope bits + a
/// 112-bit group ID. Only the `ff` prefix and the scope nibble (5 = site-local) are checked.
/// The flags and group ID may hold anything, including the non-zero prefix bits of RFC 3306
/// unicast-prefix-based groups.
fn ipv6_is_site_local_mcast(ip: &Ipv6Addr) -> bool {
    // Keep the `ff` prefix and the scope nibble, drop the flags nibble.
    const MASK: u16 = 0xff0f;
    const SITE_LOCAL_SCOPE: u16 = 0xff05;

    ip.segments()[0] & MASK == SITE_LOCAL_SCOPE
}
329
/// One multicast send target: the outgoing interface and the destination group.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface.
    pub interface_id: u32,
    /// An address assigned to the outgoing network interface.
    pub interface_addr: IpAddr,
    /// Destination multicast group and port.
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// The address of the outgoing interface.
    pub fn iface_ip(&self) -> IpAddr {
        self.interface_addr
    }

    /// The destination multicast socket address.
    pub fn mcast_addr(&self) -> SocketAddr {
        self.mcast_addr
    }

    /// Key used to deduplicate send targets.
    ///
    /// For IPv6-bound sockets, targets are distinguished by interface index (and
    /// `interface_addr` is ignored). For IPv4-bound sockets, they are distinguished by
    /// interface address (and the index is ignored).
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let (by_index, by_addr) = match bind_addr_is_ipv6 {
            true => (Some(self.interface_id), None),
            false => (None, Some(self.interface_addr)),
        };
        (by_index, by_addr, self.mcast_addr)
    }
}
353}