librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6    task::Poll,
7};
8
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use socket2::SockRef;
11use tracing::{debug, trace};
12
13use crate::{
14    BindOpts, Error, UdpSocket,
15    addr::{Ipv6AddrExt, WithScopeId},
16};
17
18/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
19/// interfaces.
20pub struct MulticastUdpSocket {
21    sock: UdpSocket,
22    ipv4_addr: SocketAddrV4,
23    ipv6_site_local: SocketAddrV6,
24    ipv6_link_local: Option<SocketAddrV6>,
25    nics: Vec<NetworkInterface>,
26}
27
28impl MulticastUdpSocket {
29    pub async fn new(
30        bind_addr: SocketAddr,
31        ipv4_mcast_addr: SocketAddrV4,
32        ipv6_site_local_addr: SocketAddrV6,
33        ipv6_link_local_addr: Option<SocketAddrV6>,
34    ) -> crate::Result<Self> {
35        if let Some(ll) = ipv6_link_local_addr {
36            if !ll.ip().is_link_local_mcast() {
37                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
38            }
39        }
40        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
41            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
42        }
43        let nics = network_interface::NetworkInterface::show()
44            .into_iter()
45            .flatten()
46            .collect::<Vec<_>>();
47        if nics.is_empty() {
48            return Err(Error::NoNics);
49        }
50        let opts = BindOpts {
51            request_dualstack: true,
52            reuseport: true,
53        };
54        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
55        let sock = Self {
56            sock,
57            ipv4_addr: ipv4_mcast_addr,
58            ipv6_link_local: ipv6_link_local_addr,
59            ipv6_site_local: ipv6_site_local_addr,
60            nics,
61        };
62        sock.bind_multicast().await?;
63        Ok(sock)
64    }
65
66    pub fn nics(&self) -> &[NetworkInterface] {
67        &self.nics
68    }
69
70    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
71        self.sock.recv_from(buf).await
72    }
73
74    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
75        // Ensure the multicast option is erased before sending
76        std::future::poll_fn(|cx| {
77            let sref = SockRef::from(self.sock.socket());
78            if self.sock.bind_addr().is_ipv6() {
79                if let Err(e) = sref.set_multicast_if_v6(0) {
80                    trace!("error calling set_multicast_if_v6(0): {e:#}")
81                }
82            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
83                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
84            }
85
86            self.sock.poll_send_to(cx, buf, addr)
87        })
88        .await
89    }
90
91    async fn bind_multicast(&self) -> crate::Result<()> {
92        let mut joined = false;
93        if self.sock.bind_addr().is_ipv4() {
94            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
95        }
96
97        for nic in self.nics.iter() {
98            let mut has_link_local = false;
99            let mut has_site_local = false;
100
101            for addr in nic.addr.iter() {
102                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
103                    (IpAddr::V4(iface_addr), is_ipv6)
104                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
105                    {
106                        let is_linux = cfg!(target_os = "linux");
107                        if !is_ipv6 || is_linux {
108                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
109                        } else {
110                            joined |= try_join_v6(
111                                &self.sock,
112                                self.ipv4_addr.ip().to_ipv6_mapped(),
113                                nic.index,
114                            )
115                        }
116                    }
117                    (IpAddr::V6(addr), true) => {
118                        if addr.is_loopback() {
119                            continue;
120                        }
121                        if addr.is_unicast_link_local() {
122                            has_link_local = true;
123                        } else {
124                            has_site_local = true;
125                        }
126                    }
127                    _ => continue,
128                }
129            }
130
131            if has_site_local {
132                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
133            }
134
135            if let Some(ll) = self.ipv6_link_local {
136                if has_link_local {
137                    joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
138                }
139            }
140        }
141
142        if !joined {
143            return Err(Error::MulticastJoinFail);
144        }
145
146        self.sock
147            .socket()
148            .writable()
149            .await
150            .map_err(Error::Writeable)?;
151
152        Ok(())
153    }
154
155    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
156        self.nics()
157            .iter()
158            .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
159            .find_map(|(nic, naddr)| {
160                let nm = naddr.netmask();
161                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
162                    // For link-local addresses, we reply back to the nic from scope_id, if there's a multicast link-local
163                    // address
164                    (SocketAddr::V6(addr), _, _, Some(mlocal))
165                        if addr.ip().is_unicast_link_local() =>
166                    {
167                        if nic.index != addr.scope_id() {
168                            return None;
169                        }
170                        mlocal.with_scope_id(nic.index).into()
171                    }
172
173                    // For ULAs, multicast to site-local address if in the same netmask
174                    (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
175                        if addr.ip().is_unique_local()
176                            && addr.ip().to_bits() & mask.to_bits()
177                                == naddr.to_bits() & mask.to_bits() =>
178                    {
179                        self.ipv6_site_local.into()
180                    }
181
182                    // For IPv4, if the mask matches, determine the interface.
183                    (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
184                        if addr.ip().to_bits() & mask.to_bits()
185                            == naddr.to_bits() & mask.to_bits() =>
186                    {
187                        self.ipv4_addr.into()
188                    }
189                    _ => return None,
190                };
191                Some(MulticastOpts {
192                    interface_id: nic.index,
193                    interface_addr: naddr.ip(),
194                    mcast_addr,
195                })
196            })
197    }
198
199    pub async fn send_multicast_msg(
200        &self,
201        buf: &[u8],
202        opts: &MulticastOpts,
203    ) -> crate::Result<usize> {
204        // This is .poll_fn() so that we call .set_multicast() immediately before sending a packet.
205        // If it's repolled it'll get called again just before the send.
206
207        std::future::poll_fn(|cx| {
208            let sref = SockRef::from(self.sock.socket());
209            if self.sock.bind_addr().is_ipv6() {
210                if let Err(e) = sref.set_multicast_if_v6(opts.interface_id) {
211                    return Poll::Ready(Err(Error::SetMulticastIpv6(e)));
212                }
213            } else {
214                let ifaddr = match opts.interface_addr {
215                    IpAddr::V4(ipv4_addr) => ipv4_addr,
216                    IpAddr::V6(_) => {
217                        return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch));
218                    }
219                };
220                if let Err(e) = sref.set_multicast_if_v4(&ifaddr) {
221                    return Poll::Ready(Err(Error::SetMulticastIpv4(e)));
222                }
223            };
224
225            self.sock
226                .poll_send_to(cx, buf, opts.mcast_addr)
227                .map_err(Error::Send)
228        })
229        .await
230    }
231
232    pub async fn try_send_mcast_everywhere(
233        &self,
234        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
235    ) {
236        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
237
238        let mut send_specs = self
239            .nics
240            .iter()
241            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
242            .filter_map(|(ifidx, ifaddr)| {
243                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
244                    (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
245                        self.ipv4_addr.into()
246                    }
247                    (true, IpAddr::V6(a), Some(mlocal))
248                        if !a.is_loopback() && a.is_unicast_link_local() =>
249                    {
250                        mlocal.with_scope_id(ifidx).into()
251                    }
252                    (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
253                        self.ipv6_site_local.into()
254                    }
255                    _ => {
256                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
257                        return None;
258                    }
259                };
260                Some(MulticastOpts {
261                    interface_id: ifidx,
262                    interface_addr: ifaddr,
263                    mcast_addr,
264                })
265            })
266            .collect::<Vec<_>>();
267
268        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
269        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
270
271        let futs = send_specs.into_iter().filter_map(|opts| {
272            let payload = get_payload(&opts)?;
273            let fut = async move {
274                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
275                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
276                    Err(e) => {
277                        debug!(?opts, payload=?payload, "error sending: {e:#}")
278                    }
279                };
280            };
281            Some(fut)
282        });
283
284        futures::future::join_all(futs).await;
285    }
286}
287
288fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
289    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
290    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
291        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
292        return false;
293    }
294    true
295}
296
297fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
298    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
299    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
300        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
301        return false;
302    }
303    true
304}
305
/// Describes one multicast send: which interface to send from and which group to target.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface.
    pub interface_id: u32,
    /// An address assigned to the outgoing network interface.
    pub interface_addr: IpAddr,
    /// Multicast destination address (group and port).
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// Returns the interface address stored in `interface_addr`.
    pub fn iface_ip(&self) -> IpAddr {
        let Self { interface_addr, .. } = *self;
        interface_addr
    }

    /// Returns the multicast destination stored in `mcast_addr`.
    pub fn mcast_addr(&self) -> SocketAddr {
        let Self { mcast_addr, .. } = *self;
        mcast_addr
    }

    /// Builds the key used to sort and de-duplicate send specs.
    ///
    /// For an IPv6-bound socket the interface is identified by its index; otherwise by its
    /// address. The multicast destination is always part of the key.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_addr = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_addr, self.mcast_addr)
    }
}