Skip to main content

librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    future::poll_fn,
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7    task::Poll,
8};
9
10use network_interface::{NetworkInterface, NetworkInterfaceConfig};
11use socket2::SockRef;
12use tracing::{debug, trace};
13
14use crate::{
15    BindDevice, BindOpts, Error, UdpSocket,
16    addr::{Ipv6AddrExt, WithScopeId},
17};
18
19/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
20/// interfaces.
21pub struct MulticastUdpSocket {
22    sock: UdpSocket,
23    ipv4_addr: SocketAddrV4,
24    ipv6_site_local: SocketAddrV6,
25    ipv6_link_local: Option<SocketAddrV6>,
26    nics: Vec<NetworkInterface>,
27}
28
29impl MulticastUdpSocket {
30    pub async fn new(
31        bind_addr: SocketAddr,
32        ipv4_mcast_addr: SocketAddrV4,
33        ipv6_site_local_addr: SocketAddrV6,
34        ipv6_link_local_addr: Option<SocketAddrV6>,
35        bind_device: Option<&BindDevice>,
36    ) -> crate::Result<Self> {
37        if let Some(ll) = ipv6_link_local_addr
38            && !ll.ip().is_link_local_mcast()
39        {
40            return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
41        }
42        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
43            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
44        }
45        let nics = network_interface::NetworkInterface::show()
46            .into_iter()
47            .flatten()
48            .filter(|nic| bind_device.is_none_or(|bd| bd.index().get() == nic.index))
49            .collect::<Vec<_>>();
50        if nics.is_empty() {
51            return Err(Error::NoNics);
52        }
53        let opts = BindOpts {
54            request_dualstack: true,
55            reuseport: true,
56            device: bind_device,
57        };
58        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
59        let sock = Self {
60            sock,
61            ipv4_addr: ipv4_mcast_addr,
62            ipv6_link_local: ipv6_link_local_addr,
63            ipv6_site_local: ipv6_site_local_addr,
64            nics,
65        };
66        sock.bind_multicast().await?;
67        Ok(sock)
68    }
69
70    pub fn nics(&self) -> &[NetworkInterface] {
71        &self.nics
72    }
73
74    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
75        self.sock.recv_from(buf).await
76    }
77
78    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
79        // Ensure the multicast option is erased before sending
80        poll_fn(|cx| {
81            let sref = SockRef::from(self.sock.socket());
82            if self.sock.bind_addr().is_ipv6() {
83                if let Err(e) = sref.set_multicast_if_v6(0) {
84                    trace!("error calling set_multicast_if_v6(0): {e:#}")
85                }
86            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
87                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
88            }
89
90            self.sock.poll_send_to(cx, buf, addr)
91        })
92        .await
93    }
94
95    async fn bind_multicast(&self) -> crate::Result<()> {
96        let mut joined = false;
97        if self.sock.bind_addr().is_ipv4() {
98            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
99        }
100
101        for nic in self.nics.iter() {
102            let mut has_link_local = false;
103            let mut has_site_local = false;
104
105            for addr in nic.addr.iter() {
106                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
107                    (IpAddr::V4(iface_addr), is_ipv6)
108                        if iface_addr.is_private() || iface_addr.is_loopback() =>
109                    {
110                        let is_linux_or_windows =
111                            cfg!(any(target_os = "linux", target_os = "windows"));
112                        if !is_ipv6 || is_linux_or_windows {
113                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
114                        } else {
115                            joined |= try_join_v6(
116                                &self.sock,
117                                self.ipv4_addr.ip().to_ipv6_mapped(),
118                                nic.index,
119                            )
120                        }
121                    }
122                    (IpAddr::V6(addr), true) => {
123                        if addr.is_unicast_link_local() {
124                            has_link_local = true;
125                        } else {
126                            has_site_local = true;
127                        }
128                    }
129                    _ => continue,
130                }
131            }
132
133            if has_site_local {
134                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
135            }
136
137            if let Some(ll) = self.ipv6_link_local
138                && has_link_local
139            {
140                joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
141            }
142        }
143
144        if !joined {
145            return Err(Error::MulticastJoinFail);
146        }
147
148        self.sock
149            .socket()
150            .writable()
151            .await
152            .map_err(Error::Writeable)?;
153
154        Ok(())
155    }
156
157    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
158        self.nics()
159            .iter()
160            .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
161            .find_map(|(nic, naddr)| {
162                let nm = naddr.netmask();
163                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
164                    // For link-local addresses, we reply back to the nic from scope_id, if there's a multicast link-local
165                    // address
166                    (SocketAddr::V6(addr), _, _, Some(mlocal))
167                        if addr.ip().is_unicast_link_local() =>
168                    {
169                        if nic.index != addr.scope_id() {
170                            return None;
171                        }
172                        mlocal.with_scope_id(nic.index).into()
173                    }
174
175                    // For ULAs, multicast to site-local address if in the same netmask
176                    (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
177                        if addr.ip().is_unique_local()
178                            && addr.ip().to_bits() & mask.to_bits()
179                                == naddr.to_bits() & mask.to_bits() =>
180                    {
181                        self.ipv6_site_local.into()
182                    }
183
184                    // For IPv4, if the mask matches, determine the interface.
185                    (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
186                        if addr.ip().to_bits() & mask.to_bits()
187                            == naddr.to_bits() & mask.to_bits() =>
188                    {
189                        self.ipv4_addr.into()
190                    }
191                    _ => return None,
192                };
193                Some(MulticastOpts {
194                    interface_id: nic.index,
195                    interface_addr: naddr.ip(),
196                    mcast_addr,
197                })
198            })
199    }
200
201    pub async fn send_multicast_msg(
202        &self,
203        buf: &[u8],
204        opts: &MulticastOpts,
205    ) -> crate::Result<usize> {
206        // This is .poll_fn() so that we call .set_multicast_if_*() immediately before sending a packet.
207        // If it's repolled it'll get called again just before the send.
208        poll_fn(|cx| {
209            let sref = SockRef::from(self.sock.socket());
210            let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
211            let is_linux = cfg!(target_os = "linux");
212
213            // send ipv4 if either (is_linux && target=ipv4) or (!is_linux && bind_addr=ipv4)
214
215            match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
216                // on linux, v4 multicast messages are sent with IP_MULTICAST_IF
217                // on other platforms' dualstack sockets they are sent with IPV6_MULTICAST_IF
218                (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
219                | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
220                    sref.set_multicast_if_v4(&addr)
221                        .map_err(Error::SetMulticastIpv4)?;
222                }
223                (SocketAddr::V6(_), IpAddr::V6(_), _, _)
224                | (SocketAddr::V4(_), IpAddr::V4(_), _, _) => {
225                    sref.set_multicast_if_v6(opts.interface_id)
226                        .map_err(Error::SetMulticastIpv6)?;
227                }
228                _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
229            }
230
231            self.sock
232                .poll_send_to(cx, buf, opts.mcast_addr)
233                .map_err(Error::Send)
234        })
235        .await
236    }
237
238    pub async fn try_send_mcast_everywhere(
239        &self,
240        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
241    ) {
242        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
243
244        let mut send_specs = self
245            .nics
246            .iter()
247            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
248            .filter_map(|(ifidx, ifaddr)| {
249                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
250                    (_, IpAddr::V4(a), _) if a.is_private() || a.is_loopback() => {
251                        self.ipv4_addr.into()
252                    }
253                    (true, IpAddr::V6(a), Some(mlocal)) if a.is_unicast_link_local() => {
254                        mlocal.with_scope_id(ifidx).into()
255                    }
256                    (true, IpAddr::V6(a), _) if a.is_unique_local() => self.ipv6_site_local.into(),
257                    _ => {
258                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
259                        return None;
260                    }
261                };
262                Some(MulticastOpts {
263                    interface_id: ifidx,
264                    interface_addr: ifaddr,
265                    mcast_addr,
266                })
267            })
268            .collect::<Vec<_>>();
269
270        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
271        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
272
273        let futs = send_specs.into_iter().filter_map(|opts| {
274            let payload = get_payload(&opts)?;
275            let fut = async move {
276                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
277                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
278                    Err(e) => {
279                        debug!(?opts, payload=?payload, "error sending: {e:#}")
280                    }
281                };
282            };
283            Some(fut)
284        });
285
286        futures::future::join_all(futs).await;
287    }
288}
289
290fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
291    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
292    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
293        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
294        return false;
295    }
296    true
297}
298
299fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
300    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
301    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
302        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
303        return false;
304    }
305    true
306}
307
/// Describes one multicast send: which local interface to send from and which group to
/// send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface.
    pub interface_id: u32,
    /// An address of the outgoing interface; its IP family must match `mcast_addr`.
    pub interface_addr: IpAddr,
    /// Destination multicast group address and port.
    pub mcast_addr: SocketAddr,
}
314
315impl MulticastOpts {
316    pub fn iface_ip(&self) -> IpAddr {
317        self.interface_addr
318    }
319
320    pub fn mcast_addr(&self) -> SocketAddr {
321        self.mcast_addr
322    }
323
324    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
325        if bind_addr_is_ipv6 {
326            (Some(self.interface_id), None, self.mcast_addr)
327        } else {
328            (None, Some(self.interface_addr), self.mcast_addr)
329        }
330    }
331}