librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    collections::HashSet,
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7    sync::Mutex,
8    task::Poll,
9};
10
11use network_interface::{NetworkInterface, NetworkInterfaceConfig};
12use socket2::SockRef;
13use tracing::{debug, trace};
14
15use crate::{BindOpts, Error, UdpSocket};
16
/// Multicast UDP endpoint made of two sockets: one for IPv4 and one for IPv6.
///
/// Constructed by [`MulticastUdpSocket::new`], which validates the IPv6 group addresses,
/// binds both sockets and joins the multicast groups on the interfaces listed at
/// construction time.
pub struct MulticastUdpSocket {
    // At least on macOS, multicast doesn't seem to work on dual-stack sockets, so we need
    // to create two of them.
    /// Socket used for all IPv4 multicast traffic.
    sock_v4: UdpSocket,
    /// Socket used for all IPv6 multicast traffic.
    sock_v6: UdpSocket,
    /// IPv4 multicast group that is joined and used as the IPv4 send destination.
    ipv4_addr: Ipv4Addr,
    /// IPv6 multicast group with site-local scope (validated in `new`).
    ipv6_site_local: Ipv6Addr,
    /// Optional IPv6 multicast group with link-local scope. When set, link-local interface
    /// addresses send to this group instead of `ipv6_site_local`.
    ipv6_link_local: Option<Ipv6Addr>,
    /// Network interfaces listed once at construction; used for joining and sending.
    nics: Vec<NetworkInterface>,
}
27
28impl MulticastUdpSocket {
29    pub fn new(
30        port: u16,
31        ipv4_addr: Ipv4Addr,
32        ipv6_site_local: Ipv6Addr,
33        ipv6_link_local: Option<Ipv6Addr>,
34    ) -> crate::Result<Self> {
35        if let Some(ll) = ipv6_link_local {
36            if !ipv6_is_link_local_mcast(ll) {
37                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
38            }
39        }
40        if !ipv6_is_site_local_mcast(ipv6_site_local) {
41            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
42        }
43        let nics = network_interface::NetworkInterface::show()
44            .into_iter()
45            .flatten()
46            .collect::<Vec<_>>();
47        if nics.is_empty() {
48            return Err(Error::NoNics);
49        }
50        let opts = BindOpts {
51            request_dualstack: false,
52            reuseport: true,
53        };
54        let sock_v4 = UdpSocket::bind_udp((Ipv4Addr::UNSPECIFIED, port).into(), opts)?;
55        let sock_v6 = UdpSocket::bind_udp((Ipv6Addr::UNSPECIFIED, port).into(), opts)?;
56        let sock = Self {
57            sock_v4,
58            sock_v6,
59            ipv4_addr,
60            ipv6_link_local,
61            ipv6_site_local,
62            nics,
63        };
64        sock.bind_multicast()?;
65        Ok(sock)
66    }
67
68    pub fn nics(&self) -> &[NetworkInterface] {
69        &self.nics
70    }
71
72    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
73        std::future::poll_fn(|cx| {
74            let mut buf = tokio::io::ReadBuf::new(buf);
75            if let Poll::Ready(res) = self.sock_v4.socket().poll_recv_from(cx, &mut buf) {
76                return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
77            }
78            if let Poll::Ready(res) = self.sock_v6.socket().poll_recv_from(cx, &mut buf) {
79                return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
80            }
81            Poll::Pending
82        })
83        .await
84    }
85
86    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
87        let sock = if addr.is_ipv6() {
88            &self.sock_v6
89        } else {
90            &self.sock_v4
91        };
92        sock.send_to(buf, addr).await
93    }
94
95    fn bind_multicast(&self) -> crate::Result<()> {
96        let mut joined = try_join_v4(&self.sock_v4, self.ipv4_addr, Ipv4Addr::UNSPECIFIED);
97
98        for nic in self.nics.iter() {
99            let mut has_link_local = false;
100            let mut has_site_local = false;
101
102            for addr in nic.addr.iter() {
103                match addr.ip() {
104                    IpAddr::V4(iface_addr)
105                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
106                    {
107                        joined |= try_join_v4(&self.sock_v4, self.ipv4_addr, iface_addr);
108                    }
109                    IpAddr::V6(addr) => {
110                        if addr.is_loopback() {
111                            continue;
112                        }
113                        if ipv6_is_link_local(addr) {
114                            has_link_local = true;
115                        } else {
116                            has_site_local = true;
117                        }
118                    }
119                    _ => continue,
120                }
121            }
122
123            if has_site_local {
124                joined |= try_join_v6(&self.sock_v6, self.ipv6_site_local, nic.index);
125            }
126
127            if let Some(ll) = self.ipv6_link_local {
128                if has_link_local {
129                    joined |= try_join_v6(&self.sock_v6, ll, nic.index);
130                }
131            }
132        }
133
134        if !joined {
135            return Err(Error::MulticastJoinFail);
136        }
137
138        Ok(())
139    }
140
141    async fn send_to_once(&self, buf: &[u8], opts: &MulticastOpts) -> std::io::Result<usize> {
142        // This is .poll_fn() so that we call .set_multicast() immediately before sending a packet.
143        // If it's repolled it'll get called again just before the send.
144
145        std::future::poll_fn(|cx| {
146            let sock;
147            let mcast_addr_s: SocketAddr;
148
149            match opts {
150                MulticastOpts::V4 {
151                    interface_addr,
152                    mcast_addr,
153                } => {
154                    sock = &self.sock_v4;
155                    mcast_addr_s = (*mcast_addr).into();
156                    if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v4(interface_addr)
157                    {
158                        debug!(addr=%interface_addr, "error calling set_multicast_if_v4: {e:#}");
159                        return Poll::Ready(Err(e));
160                    }
161                }
162                MulticastOpts::V6 {
163                    interface_id,
164                    mcast_addr,
165                    ..
166                } => {
167                    sock = &self.sock_v6;
168                    mcast_addr_s = (*mcast_addr).into();
169                    if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v6(*interface_id)
170                    {
171                        debug!(
172                            oif_id = interface_id,
173                            "error calling set_multicast_if_v6: {e:#}"
174                        );
175                        return Poll::Ready(Err(e));
176                    }
177                }
178            }
179
180            sock.poll_send_to(cx, buf, mcast_addr_s)
181        })
182        .await
183    }
184
185    pub async fn try_send_mcast_everywhere(
186        &self,
187        get_payload: &impl Fn(&MulticastOpts) -> bstr::BString,
188    ) {
189        // Without this it blocks for some reason. Maybe we need to do it once in new(), so that all multicast joining
190        // messages are actually sent?
191        //
192        // It also works if we call .send_to() vs .poll_send_to() underneath. Maybe a bug in tokio/mio or I'm just
193        // misusing it.
194        let _ = self.sock_v6.socket().writable().await;
195
196        let sent = Mutex::new(HashSet::new());
197        let sent = &sent;
198
199        let port = self.sock_v4.bind_addr().port();
200
201        let futs = self
202            .nics
203            .iter()
204            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
205            .filter_map(|(ifidx, ifaddr)| {
206                let ipv6_link_local = self
207                    .ipv6_link_local
208                    .filter(|_| matches!(ifaddr, IpAddr::V6(v6) if ipv6_is_link_local(v6)));
209                let opts = match (ifaddr, ipv6_link_local) {
210                    (IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => MulticastOpts::V4 {
211                        interface_addr: a,
212                        mcast_addr: SocketAddrV4::new(self.ipv4_addr, port),
213                    },
214                    (IpAddr::V6(a), Some(mlocal)) if !a.is_loopback() => MulticastOpts::V6 {
215                        interface_id: ifidx,
216                        interface_addr: a,
217                        mcast_addr: SocketAddrV6::new(mlocal, port, 0, ifidx),
218                    },
219                    (IpAddr::V6(a), None) if !a.is_loopback() => MulticastOpts::V6 {
220                        interface_id: ifidx,
221                        interface_addr: a,
222                        mcast_addr: SocketAddrV6::new(self.ipv6_site_local, port, 0, ifidx),
223                    },
224                    _ => {
225                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
226                        return None;
227                    }
228                };
229                Some(opts)
230            })
231            .map(|opts| async move {
232                let payload = get_payload(&opts);
233                if !sent
234                    .lock()
235                    .unwrap()
236                    .insert((payload.clone(), opts.uniq_key()))
237                {
238                    trace!(?opts, "not sending duplicate payload");
239                    return;
240                }
241
242                match self.send_to_once(payload.as_slice(), &opts).await {
243                    Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
244                    Err(e) => {
245                        debug!(?opts, payload=?payload, "error sending: {e:#}")
246                    }
247                }
248            });
249
250        futures::future::join_all(futs).await;
251    }
252}
253
254fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
255    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
256    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
257        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
258        return false;
259    }
260    true
261}
262
263fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
264    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
265    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
266        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
267        return false;
268    }
269    true
270}
271
/// Returns `true` if `ip` lies in `fe80::/64`, i.e. its upper 64 bits are exactly
/// `fe80:0:0:0`.
///
/// This is stricter than the whole `fe80::/10` block: addresses such as `fe80:1::1` or
/// `febf::1` are rejected.
fn ipv6_is_link_local(ip: Ipv6Addr) -> bool {
    matches!(ip.segments(), [0xfe80, 0, 0, 0, _, _, _, _])
}
278
/// Returns `true` if `ip` is an IPv6 multicast address with link-local scope.
///
/// An IPv6 multicast address starts with `0xff`, followed by a 4-bit flags field and a
/// 4-bit scope field (RFC 4291 §2.7); scope `0x2` means link-local. Only those first 16 bits
/// decide the scope. The group ID that follows is not inspected, because it may legitimately
/// have non-zero upper bits (e.g. prefix-based groups such as `ff32:ff:…`). Any flags value is
/// accepted.
fn ipv6_is_link_local_mcast(ip: Ipv6Addr) -> bool {
    const SCOPE_LINK_LOCAL: u16 = 0x2;
    ip.is_multicast() && (ip.segments()[0] & 0x000f) == SCOPE_LINK_LOCAL
}
285
/// Returns `true` if `ip` is an IPv6 multicast address with site-local scope.
///
/// An IPv6 multicast address starts with `0xff`, followed by a 4-bit flags field and a
/// 4-bit scope field (RFC 4291 §2.7); scope `0x5` means site-local. Only those first 16 bits
/// decide the scope. The group ID that follows is not inspected, because it may legitimately
/// have non-zero upper bits (e.g. prefix-based groups such as `ff35:40:…`). Any flags value is
/// accepted.
fn ipv6_is_site_local_mcast(ip: Ipv6Addr) -> bool {
    const SCOPE_SITE_LOCAL: u16 = 0x5;
    ip.is_multicast() && (ip.segments()[0] & 0x000f) == SCOPE_SITE_LOCAL
}
292
/// One multicast send: which local interface to send from and which group address/port to
/// send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub enum MulticastOpts {
    /// Send over IPv4.
    V4 {
        /// Local interface address used to select the outgoing interface.
        interface_addr: Ipv4Addr,
        /// Destination multicast group and port.
        mcast_addr: SocketAddrV4,
    },
    /// Send over IPv6.
    V6 {
        /// Index of the outgoing interface.
        interface_id: u32,
        /// Local IPv6 address this send was derived from. The outgoing interface itself is
        /// chosen by `interface_id`.
        interface_addr: Ipv6Addr,
        /// Destination multicast group, port and scope id.
        mcast_addr: SocketAddrV6,
    },
}

impl MulticastOpts {
    /// Local interface address this send originates from.
    pub fn iface_ip(&self) -> IpAddr {
        match *self {
            Self::V4 { interface_addr, .. } => IpAddr::V4(interface_addr),
            Self::V6 { interface_addr, .. } => IpAddr::V6(interface_addr),
        }
    }

    /// Destination multicast group address and port.
    pub fn mcast_addr(&self) -> SocketAddr {
        match *self {
            Self::V4 { mcast_addr, .. } => SocketAddr::V4(mcast_addr),
            Self::V6 { mcast_addr, .. } => SocketAddr::V6(mcast_addr),
        }
    }

    /// Key identifying the outgoing interface plus destination, used to de-duplicate sends.
    ///
    /// IPv4 is keyed by interface address and IPv6 by interface index. The IPv6 local address
    /// is not part of the key.
    fn uniq_key(&self) -> (Option<u32>, Option<Ipv4Addr>, SocketAddr) {
        let destination = self.mcast_addr();
        match *self {
            Self::V4 { interface_addr, .. } => (None, Some(interface_addr), destination),
            Self::V6 { interface_id, .. } => (Some(interface_id), None, destination),
        }
    }
}