librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    future::poll_fn,
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7    task::Poll,
8};
9
10use network_interface::{NetworkInterface, NetworkInterfaceConfig};
11use socket2::SockRef;
12use tracing::{debug, trace};
13
14use crate::{
15    BindOpts, Error, UdpSocket,
16    addr::{Ipv6AddrExt, WithScopeId},
17};
18
19/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
20/// interfaces.
21pub struct MulticastUdpSocket {
22    sock: UdpSocket,
23    ipv4_addr: SocketAddrV4,
24    ipv6_site_local: SocketAddrV6,
25    ipv6_link_local: Option<SocketAddrV6>,
26    nics: Vec<NetworkInterface>,
27}
28
29impl MulticastUdpSocket {
30    pub async fn new(
31        bind_addr: SocketAddr,
32        ipv4_mcast_addr: SocketAddrV4,
33        ipv6_site_local_addr: SocketAddrV6,
34        ipv6_link_local_addr: Option<SocketAddrV6>,
35    ) -> crate::Result<Self> {
36        if let Some(ll) = ipv6_link_local_addr {
37            if !ll.ip().is_link_local_mcast() {
38                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
39            }
40        }
41        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
42            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
43        }
44        let nics = network_interface::NetworkInterface::show()
45            .into_iter()
46            .flatten()
47            .collect::<Vec<_>>();
48        if nics.is_empty() {
49            return Err(Error::NoNics);
50        }
51        let opts = BindOpts {
52            request_dualstack: true,
53            reuseport: true,
54        };
55        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
56        let sock = Self {
57            sock,
58            ipv4_addr: ipv4_mcast_addr,
59            ipv6_link_local: ipv6_link_local_addr,
60            ipv6_site_local: ipv6_site_local_addr,
61            nics,
62        };
63        sock.bind_multicast().await?;
64        Ok(sock)
65    }
66
67    pub fn nics(&self) -> &[NetworkInterface] {
68        &self.nics
69    }
70
71    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
72        self.sock.recv_from(buf).await
73    }
74
75    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
76        // Ensure the multicast option is erased before sending
77        poll_fn(|cx| {
78            let sref = SockRef::from(self.sock.socket());
79            if self.sock.bind_addr().is_ipv6() {
80                if let Err(e) = sref.set_multicast_if_v6(0) {
81                    trace!("error calling set_multicast_if_v6(0): {e:#}")
82                }
83            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
84                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
85            }
86
87            self.sock.poll_send_to(cx, buf, addr)
88        })
89        .await
90    }
91
92    async fn bind_multicast(&self) -> crate::Result<()> {
93        let mut joined = false;
94        if self.sock.bind_addr().is_ipv4() {
95            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
96        }
97
98        for nic in self.nics.iter() {
99            let mut has_link_local = false;
100            let mut has_site_local = false;
101
102            for addr in nic.addr.iter() {
103                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
104                    (IpAddr::V4(iface_addr), is_ipv6)
105                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
106                    {
107                        let is_linux = cfg!(target_os = "linux");
108                        if !is_ipv6 || is_linux {
109                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
110                        } else {
111                            joined |= try_join_v6(
112                                &self.sock,
113                                self.ipv4_addr.ip().to_ipv6_mapped(),
114                                nic.index,
115                            )
116                        }
117                    }
118                    (IpAddr::V6(addr), true) => {
119                        if addr.is_loopback() {
120                            continue;
121                        }
122                        if addr.is_unicast_link_local() {
123                            has_link_local = true;
124                        } else {
125                            has_site_local = true;
126                        }
127                    }
128                    _ => continue,
129                }
130            }
131
132            if has_site_local {
133                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
134            }
135
136            if let Some(ll) = self.ipv6_link_local {
137                if has_link_local {
138                    joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
139                }
140            }
141        }
142
143        if !joined {
144            return Err(Error::MulticastJoinFail);
145        }
146
147        self.sock
148            .socket()
149            .writable()
150            .await
151            .map_err(Error::Writeable)?;
152
153        Ok(())
154    }
155
156    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
157        self.nics()
158            .iter()
159            .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
160            .find_map(|(nic, naddr)| {
161                let nm = naddr.netmask();
162                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
163                    // For link-local addresses, we reply back to the nic from scope_id, if there's a multicast link-local
164                    // address
165                    (SocketAddr::V6(addr), _, _, Some(mlocal))
166                        if addr.ip().is_unicast_link_local() =>
167                    {
168                        if nic.index != addr.scope_id() {
169                            return None;
170                        }
171                        mlocal.with_scope_id(nic.index).into()
172                    }
173
174                    // For ULAs, multicast to site-local address if in the same netmask
175                    (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
176                        if addr.ip().is_unique_local()
177                            && addr.ip().to_bits() & mask.to_bits()
178                                == naddr.to_bits() & mask.to_bits() =>
179                    {
180                        self.ipv6_site_local.into()
181                    }
182
183                    // For IPv4, if the mask matches, determine the interface.
184                    (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
185                        if addr.ip().to_bits() & mask.to_bits()
186                            == naddr.to_bits() & mask.to_bits() =>
187                    {
188                        self.ipv4_addr.into()
189                    }
190                    _ => return None,
191                };
192                Some(MulticastOpts {
193                    interface_id: nic.index,
194                    interface_addr: naddr.ip(),
195                    mcast_addr,
196                })
197            })
198    }
199
200    pub async fn send_multicast_msg(
201        &self,
202        buf: &[u8],
203        opts: &MulticastOpts,
204    ) -> crate::Result<usize> {
205        // This is .poll_fn() so that we call .set_multicast_if_*() immediately before sending a packet.
206        // If it's repolled it'll get called again just before the send.
207        poll_fn(|cx| {
208            let sref = SockRef::from(self.sock.socket());
209            let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
210            let is_linux = cfg!(target_os = "linux");
211
212            // send ipv4 if either (is_linux && target=ipv4) or (!is_linux && bind_addr=ipv4)
213
214            match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
215                // on linux, v4 multicast messages are sent with IP_MULTICAST_IF
216                // on other platforms' dualstack sockets they are sent with IPV6_MULTICAST_IF
217                (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
218                | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
219                    sref.set_multicast_if_v4(&addr)
220                        .map_err(Error::SetMulticastIpv4)?;
221                }
222                (SocketAddr::V6(_), IpAddr::V6(_), _, _) => {
223                    sref.set_multicast_if_v6(opts.interface_id)
224                        .map_err(Error::SetMulticastIpv6)?;
225                }
226                _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
227            }
228
229            self.sock
230                .poll_send_to(cx, buf, opts.mcast_addr)
231                .map_err(Error::Send)
232        })
233        .await
234    }
235
236    pub async fn try_send_mcast_everywhere(
237        &self,
238        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
239    ) {
240        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
241
242        let mut send_specs = self
243            .nics
244            .iter()
245            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
246            .filter_map(|(ifidx, ifaddr)| {
247                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
248                    (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
249                        self.ipv4_addr.into()
250                    }
251                    (true, IpAddr::V6(a), Some(mlocal))
252                        if !a.is_loopback() && a.is_unicast_link_local() =>
253                    {
254                        mlocal.with_scope_id(ifidx).into()
255                    }
256                    (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
257                        self.ipv6_site_local.into()
258                    }
259                    _ => {
260                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
261                        return None;
262                    }
263                };
264                Some(MulticastOpts {
265                    interface_id: ifidx,
266                    interface_addr: ifaddr,
267                    mcast_addr,
268                })
269            })
270            .collect::<Vec<_>>();
271
272        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
273        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
274
275        let futs = send_specs.into_iter().filter_map(|opts| {
276            let payload = get_payload(&opts)?;
277            let fut = async move {
278                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
279                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
280                    Err(e) => {
281                        debug!(?opts, payload=?payload, "error sending: {e:#}")
282                    }
283                };
284            };
285            Some(fut)
286        });
287
288        futures::future::join_all(futs).await;
289    }
290}
291
292fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
293    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
294    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
295        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
296        return false;
297    }
298    true
299}
300
301fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
302    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
303    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
304        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
305        return false;
306    }
307    true
308}
309
/// Describes one multicast send: the local interface to send from and the group to send to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MulticastOpts {
    /// Index of the local network interface to send from.
    pub interface_id: u32,
    /// An address assigned to that interface.
    pub interface_addr: IpAddr,
    /// Destination multicast group address and port.
    pub mcast_addr: SocketAddr,
}
316
317impl MulticastOpts {
318    pub fn iface_ip(&self) -> IpAddr {
319        self.interface_addr
320    }
321
322    pub fn mcast_addr(&self) -> SocketAddr {
323        self.mcast_addr
324    }
325
326    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
327        if bind_addr_is_ipv6 {
328            (Some(self.interface_id), None, self.mcast_addr)
329        } else {
330            (None, Some(self.interface_addr), self.mcast_addr)
331        }
332    }
333}