librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6    task::Poll,
7};
8
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use socket2::SockRef;
11use tracing::{debug, trace};
12
13use crate::{
14    BindOpts, Error, UdpSocket,
15    addr::{Ipv6AddrExt, WithScopeId},
16};
17
18/// An IPv6 + IPv4 multicast socket that sends payloads generated by user callbacks to all
19/// interfaces.
20pub struct MulticastUdpSocket {
21    sock: UdpSocket,
22    ipv4_addr: SocketAddrV4,
23    ipv6_site_local: SocketAddrV6,
24    ipv6_link_local: Option<SocketAddrV6>,
25    nics: Vec<NetworkInterface>,
26}
27
28impl MulticastUdpSocket {
29    pub async fn new(
30        bind_addr: SocketAddr,
31        ipv4_mcast_addr: SocketAddrV4,
32        ipv6_site_local_addr: SocketAddrV6,
33        ipv6_link_local_addr: Option<SocketAddrV6>,
34    ) -> crate::Result<Self> {
35        if let Some(ll) = ipv6_link_local_addr {
36            if !ll.ip().is_link_local_mcast() {
37                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
38            }
39        }
40        if !ipv6_site_local_addr.ip().is_site_local_mcast() {
41            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
42        }
43        let nics = network_interface::NetworkInterface::show()
44            .into_iter()
45            .flatten()
46            .collect::<Vec<_>>();
47        if nics.is_empty() {
48            return Err(Error::NoNics);
49        }
50        let opts = BindOpts {
51            request_dualstack: true,
52            reuseport: true,
53        };
54        let sock = UdpSocket::bind_udp(bind_addr, opts)?;
55        let sock = Self {
56            sock,
57            ipv4_addr: ipv4_mcast_addr,
58            ipv6_link_local: ipv6_link_local_addr,
59            ipv6_site_local: ipv6_site_local_addr,
60            nics,
61        };
62        sock.bind_multicast().await?;
63        Ok(sock)
64    }
65
66    pub fn nics(&self) -> &[NetworkInterface] {
67        &self.nics
68    }
69
70    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
71        self.sock.recv_from(buf).await
72    }
73
74    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
75        // Ensure the multicast option is erased before sending
76        std::future::poll_fn(|cx| {
77            let sref = SockRef::from(self.sock.socket());
78            if self.sock.bind_addr().is_ipv6() {
79                if let Err(e) = sref.set_multicast_if_v6(0) {
80                    trace!("error calling set_multicast_if_v6(0): {e:#}")
81                }
82            } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
83                trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
84            }
85
86            self.sock.poll_send_to(cx, buf, addr)
87        })
88        .await
89    }
90
91    async fn bind_multicast(&self) -> crate::Result<()> {
92        let mut joined = false;
93        if self.sock.bind_addr().is_ipv4() {
94            joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
95        }
96
97        for nic in self.nics.iter() {
98            let mut has_link_local = false;
99            let mut has_site_local = false;
100
101            for addr in nic.addr.iter() {
102                match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
103                    (IpAddr::V4(iface_addr), is_ipv6)
104                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
105                    {
106                        if !is_ipv6 {
107                            joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
108                        } else {
109                            joined |= try_join_v6(
110                                &self.sock,
111                                self.ipv4_addr.ip().to_ipv6_mapped(),
112                                nic.index,
113                            )
114                        }
115                    }
116                    (IpAddr::V6(addr), true) => {
117                        if addr.is_loopback() {
118                            continue;
119                        }
120                        if addr.is_unicast_link_local() {
121                            has_link_local = true;
122                        } else {
123                            has_site_local = true;
124                        }
125                    }
126                    _ => continue,
127                }
128            }
129
130            if has_site_local {
131                joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
132            }
133
134            if let Some(ll) = self.ipv6_link_local {
135                if has_link_local {
136                    joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
137                }
138            }
139        }
140
141        if !joined {
142            return Err(Error::MulticastJoinFail);
143        }
144
145        self.sock
146            .socket()
147            .writable()
148            .await
149            .map_err(Error::Writeable)?;
150
151        Ok(())
152    }
153
154    pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
155        self.nics()
156            .iter()
157            .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
158            .find_map(|(nic, naddr)| {
159                let nm = naddr.netmask();
160                let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
161                    // For link-local addresses, we reply back to the nic from scope_id, if there's a multicast link-local
162                    // address
163                    (SocketAddr::V6(addr), _, _, Some(mlocal))
164                        if addr.ip().is_unicast_link_local() =>
165                    {
166                        if nic.index != addr.scope_id() {
167                            return None;
168                        }
169                        mlocal.with_scope_id(nic.index).into()
170                    }
171
172                    // For ULAs, multicast to site-local address if in the same netmask
173                    (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
174                        if addr.ip().is_unique_local()
175                            && addr.ip().to_bits() & mask.to_bits()
176                                == naddr.to_bits() & mask.to_bits() =>
177                    {
178                        self.ipv6_site_local.into()
179                    }
180
181                    // For IPv4, if the mask matches, determine the interface.
182                    (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
183                        if addr.ip().to_bits() & mask.to_bits()
184                            == naddr.to_bits() & mask.to_bits() =>
185                    {
186                        self.ipv4_addr.into()
187                    }
188                    _ => return None,
189                };
190                Some(MulticastOpts {
191                    interface_id: nic.index,
192                    interface_addr: naddr.ip(),
193                    mcast_addr,
194                })
195            })
196    }
197
198    pub async fn send_multicast_msg(
199        &self,
200        buf: &[u8],
201        opts: &MulticastOpts,
202    ) -> crate::Result<usize> {
203        // This is .poll_fn() so that we call .set_multicast() immediately before sending a packet.
204        // If it's repolled it'll get called again just before the send.
205
206        std::future::poll_fn(|cx| {
207            let sref = SockRef::from(self.sock.socket());
208            if self.sock.bind_addr().is_ipv6() {
209                if let Err(e) = sref.set_multicast_if_v6(opts.interface_id) {
210                    return Poll::Ready(Err(Error::SetMulticastIpv6(e)));
211                }
212            } else {
213                let ifaddr = match opts.interface_addr {
214                    IpAddr::V4(ipv4_addr) => ipv4_addr,
215                    IpAddr::V6(_) => {
216                        return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch));
217                    }
218                };
219                if let Err(e) = sref.set_multicast_if_v4(&ifaddr) {
220                    return Poll::Ready(Err(Error::SetMulticastIpv4(e)));
221                }
222            };
223
224            self.sock
225                .poll_send_to(cx, buf, opts.mcast_addr)
226                .map_err(Error::Send)
227        })
228        .await
229    }
230
231    pub async fn try_send_mcast_everywhere(
232        &self,
233        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
234    ) {
235        let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
236
237        let mut send_specs = self
238            .nics
239            .iter()
240            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
241            .filter_map(|(ifidx, ifaddr)| {
242                let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
243                    (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
244                        self.ipv4_addr.into()
245                    }
246                    (true, IpAddr::V6(a), Some(mlocal))
247                        if !a.is_loopback() && a.is_unicast_link_local() =>
248                    {
249                        mlocal.with_scope_id(ifidx).into()
250                    }
251                    (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
252                        self.ipv6_site_local.into()
253                    }
254                    _ => {
255                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
256                        return None;
257                    }
258                };
259                Some(MulticastOpts {
260                    interface_id: ifidx,
261                    interface_addr: ifaddr,
262                    mcast_addr,
263                })
264            })
265            .collect::<Vec<_>>();
266
267        send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
268        send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
269
270        let futs = send_specs.into_iter().filter_map(|opts| {
271            let payload = get_payload(&opts)?;
272            let fut = async move {
273                match self.send_multicast_msg(payload.as_bytes(), &opts).await {
274                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
275                    Err(e) => {
276                        debug!(?opts, payload=?payload, "error sending: {e:#}")
277                    }
278                };
279            };
280            Some(fut)
281        });
282
283        futures::future::join_all(futs).await;
284    }
285}
286
287fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
288    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
289    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
290        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
291        return false;
292    }
293    true
294}
295
296fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
297    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
298    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
299        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
300        return false;
301    }
302    true
303}
304
/// One multicast send target: the outgoing interface plus the destination group.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// OS index of the outgoing interface (used to pick the interface on IPv6 sockets).
    pub interface_id: u32,
    /// An address of the outgoing interface (used to pick the interface on IPv4 sockets).
    pub interface_addr: IpAddr,
    /// Destination multicast group (address and port).
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// The outgoing interface's address.
    pub fn iface_ip(&self) -> IpAddr {
        let Self { interface_addr, .. } = *self;
        interface_addr
    }

    /// The destination multicast group.
    pub fn mcast_addr(&self) -> SocketAddr {
        let Self { mcast_addr, .. } = *self;
        mcast_addr
    }

    /// Deduplication key for send targets.
    ///
    /// IPv6 sockets choose the interface by index, so the key is (index, group). IPv4 sockets
    /// choose it by address, so the key is (address, group). The unused component is `None`.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_address = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_address, self.mcast_addr)
    }
}
328}