librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    future::poll_fn,
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7    task::Poll,
8};
9
10use network_interface::{NetworkInterface, NetworkInterfaceConfig};
11use socket2::SockRef;
12use tracing::{debug, trace};
13
14use crate::{
15    BindOpts, Error, UdpSocket,
16    addr::{Ipv6AddrExt, WithScopeId},
17};
18
19/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
20/// interfaces.
21pub struct MulticastUdpSocket {
22    sock: UdpSocket,
23    ipv4_addr: SocketAddrV4,
24    ipv6_site_local: SocketAddrV6,
25    ipv6_link_local: Option<SocketAddrV6>,
26    nics: Vec<NetworkInterface>,
27}
28
29impl MulticastUdpSocket {
30    pub async fn new(
31        bind_addr: SocketAddr,
32        ipv4_mcast_addr: SocketAddrV4,
33        ipv6_site_local_addr: SocketAddrV6,
34        ipv6_link_local_addr: Option<SocketAddrV6>,
35    ) -> crate::Result<Self> {
36        if let Some(ll) = ipv6_link_local_addr {
37            if !ll.ip().is_link_local_mcast() {
38                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
39            }
40        }
41        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
42            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
43        }
44        let nics = network_interface::NetworkInterface::show()
45            .into_iter()
46            .flatten()
47            .collect::<Vec<_>>();
48        if nics.is_empty() {
49            return Err(Error::NoNics);
50        }
51        let opts = BindOpts {
52            request_dualstack: true,
53            reuseport: true,
54            device: None,
55        };
56        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
57        let sock = Self {
58            sock,
59            ipv4_addr: ipv4_mcast_addr,
60            ipv6_link_local: ipv6_link_local_addr,
61            ipv6_site_local: ipv6_site_local_addr,
62            nics,
63        };
64        sock.bind_multicast().await?;
65        Ok(sock)
66    }
67
68    pub fn nics(&self) -> &[NetworkInterface] {
69        &self.nics
70    }
71
72    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
73        self.sock.recv_from(buf).await
74    }
75
76    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
77        // Ensure the multicast option is erased before sending
78        poll_fn(|cx| {
79            let sref = SockRef::from(self.sock.socket());
80            if self.sock.bind_addr().is_ipv6() {
81                if let Err(e) = sref.set_multicast_if_v6(0) {
82                    trace!("error calling set_multicast_if_v6(0): {e:#}")
83                }
84            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
85                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
86            }
87
88            self.sock.poll_send_to(cx, buf, addr)
89        })
90        .await
91    }
92
93    async fn bind_multicast(&self) -> crate::Result<()> {
94        let mut joined = false;
95        if self.sock.bind_addr().is_ipv4() {
96            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
97        }
98
99        for nic in self.nics.iter() {
100            let mut has_link_local = false;
101            let mut has_site_local = false;
102
103            for addr in nic.addr.iter() {
104                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
105                    (IpAddr::V4(iface_addr), is_ipv6)
106                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
107                    {
108                        let is_linux_or_windows =
109                            cfg!(any(target_os = "linux", target_os = "windows"));
110                        if !is_ipv6 || is_linux_or_windows {
111                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
112                        } else {
113                            joined |= try_join_v6(
114                                &self.sock,
115                                self.ipv4_addr.ip().to_ipv6_mapped(),
116                                nic.index,
117                            )
118                        }
119                    }
120                    (IpAddr::V6(addr), true) => {
121                        if addr.is_loopback() {
122                            continue;
123                        }
124                        if addr.is_unicast_link_local() {
125                            has_link_local = true;
126                        } else {
127                            has_site_local = true;
128                        }
129                    }
130                    _ => continue,
131                }
132            }
133
134            if has_site_local {
135                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
136            }
137
138            if let Some(ll) = self.ipv6_link_local {
139                if has_link_local {
140                    joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
141                }
142            }
143        }
144
145        if !joined {
146            return Err(Error::MulticastJoinFail);
147        }
148
149        self.sock
150            .socket()
151            .writable()
152            .await
153            .map_err(Error::Writeable)?;
154
155        Ok(())
156    }
157
158    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
159        self.nics()
160            .iter()
161            .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
162            .find_map(|(nic, naddr)| {
163                let nm = naddr.netmask();
164                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
165                    // For link-local addresses, we reply back to the nic from scope_id, if there's a multicast link-local
166                    // address
167                    (SocketAddr::V6(addr), _, _, Some(mlocal))
168                        if addr.ip().is_unicast_link_local() =>
169                    {
170                        if nic.index != addr.scope_id() {
171                            return None;
172                        }
173                        mlocal.with_scope_id(nic.index).into()
174                    }
175
176                    // For ULAs, multicast to site-local address if in the same netmask
177                    (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
178                        if addr.ip().is_unique_local()
179                            && addr.ip().to_bits() & mask.to_bits()
180                                == naddr.to_bits() & mask.to_bits() =>
181                    {
182                        self.ipv6_site_local.into()
183                    }
184
185                    // For IPv4, if the mask matches, determine the interface.
186                    (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
187                        if addr.ip().to_bits() & mask.to_bits()
188                            == naddr.to_bits() & mask.to_bits() =>
189                    {
190                        self.ipv4_addr.into()
191                    }
192                    _ => return None,
193                };
194                Some(MulticastOpts {
195                    interface_id: nic.index,
196                    interface_addr: naddr.ip(),
197                    mcast_addr,
198                })
199            })
200    }
201
202    pub async fn send_multicast_msg(
203        &self,
204        buf: &[u8],
205        opts: &MulticastOpts,
206    ) -> crate::Result<usize> {
207        // This is .poll_fn() so that we call .set_multicast_if_*() immediately before sending a packet.
208        // If it's repolled it'll get called again just before the send.
209        poll_fn(|cx| {
210            let sref = SockRef::from(self.sock.socket());
211            let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
212            let is_linux = cfg!(target_os = "linux");
213
214            // send ipv4 if either (is_linux && target=ipv4) or (!is_linux && bind_addr=ipv4)
215
216            match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
217                // on linux, v4 multicast messages are sent with IP_MULTICAST_IF
218                // on other platforms' dualstack sockets they are sent with IPV6_MULTICAST_IF
219                (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
220                | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
221                    sref.set_multicast_if_v4(&addr)
222                        .map_err(Error::SetMulticastIpv4)?;
223                }
224                (SocketAddr::V6(_), IpAddr::V6(_), _, _)
225                | (SocketAddr::V4(_), IpAddr::V4(_), _, _) => {
226                    sref.set_multicast_if_v6(opts.interface_id)
227                        .map_err(Error::SetMulticastIpv6)?;
228                }
229                _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
230            }
231
232            self.sock
233                .poll_send_to(cx, buf, opts.mcast_addr)
234                .map_err(Error::Send)
235        })
236        .await
237    }
238
239    pub async fn try_send_mcast_everywhere(
240        &self,
241        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
242    ) {
243        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
244
245        let mut send_specs = self
246            .nics
247            .iter()
248            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
249            .filter_map(|(ifidx, ifaddr)| {
250                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
251                    (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
252                        self.ipv4_addr.into()
253                    }
254                    (true, IpAddr::V6(a), Some(mlocal))
255                        if !a.is_loopback() && a.is_unicast_link_local() =>
256                    {
257                        mlocal.with_scope_id(ifidx).into()
258                    }
259                    (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
260                        self.ipv6_site_local.into()
261                    }
262                    _ => {
263                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
264                        return None;
265                    }
266                };
267                Some(MulticastOpts {
268                    interface_id: ifidx,
269                    interface_addr: ifaddr,
270                    mcast_addr,
271                })
272            })
273            .collect::<Vec<_>>();
274
275        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
276        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
277
278        let futs = send_specs.into_iter().filter_map(|opts| {
279            let payload = get_payload(&opts)?;
280            let fut = async move {
281                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
282                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
283                    Err(e) => {
284                        debug!(?opts, payload=?payload, "error sending: {e:#}")
285                    }
286                };
287            };
288            Some(fut)
289        });
290
291        futures::future::join_all(futs).await;
292    }
293}
294
295fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
296    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
297    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
298        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
299        return false;
300    }
301    true
302}
303
304fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
305    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
306    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
307        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
308        return false;
309    }
310    true
311}
312
/// One outgoing multicast target: the local interface to send from and the group to send to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MulticastOpts {
    /// Index of the local network interface to send from.
    pub interface_id: u32,
    /// A local address of that interface.
    pub interface_addr: IpAddr,
    /// Destination multicast group and port.
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// The local interface address this target sends from.
    pub fn iface_ip(&self) -> IpAddr {
        self.interface_addr
    }

    /// The destination multicast group and port.
    pub fn mcast_addr(&self) -> SocketAddr {
        self.mcast_addr
    }

    /// Deduplication key for send targets.
    ///
    /// With an IPv6-bound socket, targets are distinguished by interface index. Otherwise they
    /// are distinguished by interface address. The destination group is always part of the key.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_addr = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_addr, self.mcast_addr)
    }
}