librqbit_dualstack_sockets/
multicast.rs

1#[cfg(test)]
2mod tests;
3
4use std::{
5    collections::HashSet,
6    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7    sync::{Arc, Mutex},
8    task::Poll,
9};
10
11use network_interface::{NetworkInterface, NetworkInterfaceConfig};
12use parking_lot::RwLock;
13use socket2::SockRef;
14use tracing::{debug, trace};
15
16use crate::{BindOpts, Error, UdpSocket};
17
/// An IPv6 + IPv4 multicast socket that joins the configured multicast groups on all suitable
/// interfaces, receives on both address families, and sends payloads generated by user
/// callbacks out of every interface.
pub struct MulticastUdpSocket {
    // At least on OSX, multicast doesn't seem to work on dualstack sockets, so we need
    // to create 2 of them: one per address family.
    // IPv4-only socket; the IPv4 group is joined on it.
    sock_v4: UdpSocket,
    // IPv6 socket (bound without requesting dualstack); the IPv6 groups are joined on it.
    sock_v6: UdpSocket,
    // IPv4 multicast group address.
    ipv4_addr: Ipv4Addr,
    // IPv6 multicast group with site-local scope.
    ipv6_site_local: Ipv6Addr,
    // Optional IPv6 multicast group with link-local scope, used on interfaces that have a
    // link-local IPv6 address.
    ipv6_link_local: Option<Ipv6Addr>,
    // Interfaces enumerated once at construction; not refreshed afterwards.
    nics: Vec<NetworkInterface>,
}
30
31impl MulticastUdpSocket {
32    pub fn new(
33        port: u16,
34        ipv4_addr: Ipv4Addr,
35        ipv6_site_local: Ipv6Addr,
36        ipv6_link_local: Option<Ipv6Addr>,
37    ) -> crate::Result<Self> {
38        if let Some(ll) = ipv6_link_local {
39            if !ipv6_is_link_local_mcast(ll) {
40                return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
41            }
42        }
43        if !ipv6_is_site_local_mcast(ipv6_site_local) {
44            return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
45        }
46        let nics = network_interface::NetworkInterface::show()
47            .into_iter()
48            .flatten()
49            .collect::<Vec<_>>();
50        if nics.is_empty() {
51            return Err(Error::NoNics);
52        }
53        let opts = BindOpts {
54            request_dualstack: false,
55            reuseport: true,
56        };
57        let sock_v4 = UdpSocket::bind_udp((Ipv4Addr::UNSPECIFIED, port).into(), opts)?;
58        let sock_v6 = UdpSocket::bind_udp((Ipv6Addr::UNSPECIFIED, port).into(), opts)?;
59        let sock = Self {
60            sock_v4,
61            sock_v6,
62            ipv4_addr,
63            ipv6_link_local,
64            ipv6_site_local,
65            nics,
66        };
67        sock.bind_multicast()?;
68        Ok(sock)
69    }
70
71    pub fn nics(&self) -> &[NetworkInterface] {
72        &self.nics
73    }
74
75    pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
76        std::future::poll_fn(|cx| {
77            let mut buf = tokio::io::ReadBuf::new(buf);
78            if let Poll::Ready(res) = self.sock_v4.socket().poll_recv_from(cx, &mut buf) {
79                return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
80            }
81            if let Poll::Ready(res) = self.sock_v6.socket().poll_recv_from(cx, &mut buf) {
82                return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
83            }
84            Poll::Pending
85        })
86        .await
87    }
88
89    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
90        let sock = if addr.is_ipv6() {
91            &self.sock_v6
92        } else {
93            &self.sock_v4
94        };
95        sock.send_to(buf, addr).await
96    }
97
98    fn bind_multicast(&self) -> crate::Result<()> {
99        let mut joined = try_join_v4(&self.sock_v4, self.ipv4_addr, Ipv4Addr::UNSPECIFIED);
100
101        for nic in self.nics.iter() {
102            let mut has_link_local = false;
103            let mut has_site_local = false;
104
105            for addr in nic.addr.iter() {
106                match addr.ip() {
107                    IpAddr::V4(iface_addr)
108                        if iface_addr.is_private() && !iface_addr.is_loopback() =>
109                    {
110                        joined |= try_join_v4(&self.sock_v4, self.ipv4_addr, iface_addr);
111                    }
112                    IpAddr::V6(addr) => {
113                        if addr.is_loopback() {
114                            continue;
115                        }
116                        if ipv6_is_link_local(addr) {
117                            has_link_local = true;
118                        } else {
119                            has_site_local = true;
120                        }
121                    }
122                    _ => continue,
123                }
124            }
125
126            if has_site_local {
127                joined |= try_join_v6(&self.sock_v6, self.ipv6_site_local, nic.index);
128            }
129
130            if let Some(ll) = self.ipv6_link_local {
131                if has_link_local {
132                    joined |= try_join_v6(&self.sock_v6, ll, nic.index);
133                }
134            }
135        }
136
137        if !joined {
138            return Err(Error::MulticastJoinFail);
139        }
140
141        Ok(())
142    }
143
144    async fn send_to_once(&self, buf: &[u8], opts: &MulticastOpts) -> std::io::Result<usize> {
145        // This is .poll_fn() so that we call .set_multicast() immediately before sending a packet.
146        // If it's repolled it'll get called again just before the send.
147
148        std::future::poll_fn(|cx| {
149            let sock;
150            let mcast_addr_s: SocketAddr;
151
152            match opts {
153                MulticastOpts::V4 {
154                    interface_addr,
155                    mcast_addr,
156                } => {
157                    sock = &self.sock_v4;
158                    mcast_addr_s = (*mcast_addr).into();
159                    if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v4(interface_addr)
160                    {
161                        debug!(addr=%interface_addr, "error calling set_multicast_if_v4: {e:#}");
162                        return Poll::Ready(Err(e));
163                    }
164                }
165                MulticastOpts::V6 {
166                    interface_id,
167                    mcast_addr,
168                    ..
169                } => {
170                    sock = &self.sock_v6;
171                    mcast_addr_s = (*mcast_addr).into();
172                    if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v6(*interface_id)
173                    {
174                        debug!(
175                            oif_id = interface_id,
176                            "error calling set_multicast_if_v6: {e:#}"
177                        );
178                        return Poll::Ready(Err(e));
179                    }
180                }
181            }
182
183            sock.poll_send_to(cx, buf, mcast_addr_s)
184        })
185        .await
186    }
187
188    pub async fn try_send_mcast_everywhere(
189        &self,
190        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
191    ) {
192        // Without this it blocks for some reason. Maybe we need to do it once in new(), so that all multicast joining
193        // messages are actually sent?
194        //
195        // It also works if we call .send_to() vs .poll_send_to() underneath. Maybe a bug in tokio/mio or I'm just
196        // misusing it.
197        let _ = self.sock_v6.socket().writable().await;
198
199        let sent = Mutex::new(HashSet::new());
200        let sent = &sent;
201
202        let port = self.sock_v4.bind_addr().port();
203
204        let futs = self
205            .nics
206            .iter()
207            .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
208            .filter_map(|(ifidx, ifaddr)| {
209                let ipv6_link_local = self
210                    .ipv6_link_local
211                    .filter(|_| matches!(ifaddr, IpAddr::V6(v6) if ipv6_is_link_local(v6)));
212                let opts = match (ifaddr, ipv6_link_local) {
213                    (IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => MulticastOpts::V4 {
214                        interface_addr: a,
215                        mcast_addr: SocketAddrV4::new(self.ipv4_addr, port),
216                    },
217                    (IpAddr::V6(a), Some(mlocal)) if !a.is_loopback() => MulticastOpts::V6 {
218                        interface_id: ifidx,
219                        interface_addr: a,
220                        mcast_addr: SocketAddrV6::new(mlocal, port, 0, ifidx),
221                    },
222                    (IpAddr::V6(a), None) if !a.is_loopback() => MulticastOpts::V6 {
223                        interface_id: ifidx,
224                        interface_addr: a,
225                        mcast_addr: SocketAddrV6::new(self.ipv6_site_local, port, 0, ifidx),
226                    },
227                    _ => {
228                        trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
229                        return None;
230                    }
231                };
232                Some(opts)
233            })
234            .filter_map(|opts| {
235                let payload = get_payload(&opts)?;
236                let fut = async move {
237                    if !sent
238                        .lock()
239                        .unwrap()
240                        .insert((payload.clone(), opts.uniq_key()))
241                    {
242                        trace!(?opts, "not sending duplicate payload");
243                        return;
244                    }
245
246                    match self.send_to_once(payload.as_bytes(), &opts).await {
247                        Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
248                        Err(e) => {
249                            debug!(?opts, payload=?payload, "error sending: {e:#}")
250                        }
251                    };
252                };
253                Some(fut)
254            });
255
256        futures::future::join_all(futs).await;
257    }
258}
259
260fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
261    trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
262    if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
263        debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
264        return false;
265    }
266    true
267}
268
269fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
270    trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
271    if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
272        debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
273        return false;
274    }
275    true
276}
277
/// Returns `true` if `ip` is a link-local unicast address, i.e. lies within `fe80::/10`.
///
/// RFC 4291 section 2.4 identifies link-local unicast addresses by the 10-bit prefix
/// `1111111010`; the OS assigns link scope on the same basis. Checking the full `fe80::/64`
/// instead would misclassify addresses such as `fe80:0:0:1::1` as non-link-local.
fn ipv6_is_link_local(ip: Ipv6Addr) -> bool {
    // Only the top 10 bits of the first 16-bit segment identify the prefix.
    const PREFIX_MASK: u16 = 0xffc0;
    const LINK_LOCAL_PREFIX: u16 = 0xfe80;

    ip.segments()[0] & PREFIX_MASK == LINK_LOCAL_PREFIX
}
284
/// Returns `true` if `ip` is an IPv6 multicast address with link-local scope (`ffX2::/16`).
///
/// Per RFC 4291 section 2.7 a multicast address is `0xff`, a 4-bit flags field, a 4-bit scope
/// field, then a 112-bit group ID; scope `2` is link-local. Flags and group ID are ignored,
/// so `ff02::c`, `ff12::1` and unicast-prefix-based groups like `ff32:40:fd00:1::1234` all
/// match.
fn ipv6_is_link_local_mcast(ip: Ipv6Addr) -> bool {
    // Keep the 0xff prefix and the scope nibble; drop the flags nibble.
    const SCOPE_MASK: u16 = 0xff0f;
    const LINK_LOCAL_SCOPE: u16 = 0xff02;

    ip.segments()[0] & SCOPE_MASK == LINK_LOCAL_SCOPE
}
291
/// Returns `true` if `ip` is an IPv6 multicast address with site-local scope (`ffX5::/16`).
///
/// Per RFC 4291 section 2.7 a multicast address is `0xff`, a 4-bit flags field, a 4-bit scope
/// field, then a 112-bit group ID; scope `5` is site-local. Flags and group ID are ignored,
/// so `ff05::c`, `ff15::1` and unicast-prefix-based groups like `ff35:40:fd00:1::1` all
/// match.
fn ipv6_is_site_local_mcast(ip: Ipv6Addr) -> bool {
    // Keep the 0xff prefix and the scope nibble; drop the flags nibble.
    const SCOPE_MASK: u16 = 0xff0f;
    const SITE_LOCAL_SCOPE: u16 = 0xff05;

    ip.segments()[0] & SCOPE_MASK == SITE_LOCAL_SCOPE
}
298
/// Describes one multicast send: the group/port to send to and the local interface the
/// datagram should leave from.
///
/// Passed to the payload callback of `try_send_mcast_everywhere`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum MulticastOpts {
    /// Send over IPv4.
    V4 {
        /// Local IPv4 address used to select the outgoing interface.
        interface_addr: Ipv4Addr,
        /// IPv4 multicast group and port to send to.
        mcast_addr: SocketAddrV4,
    },
    /// Send over IPv6.
    V6 {
        /// Index of the outgoing interface.
        interface_id: u32,
        /// One of that interface's IPv6 addresses; the interface itself is selected by
        /// `interface_id`.
        interface_addr: Ipv6Addr,
        /// IPv6 multicast group and port to send to.
        mcast_addr: SocketAddrV6,
    },
}
311
312impl MulticastOpts {
313    pub fn iface_ip(&self) -> IpAddr {
314        match self {
315            MulticastOpts::V4 { interface_addr, .. } => (*interface_addr).into(),
316            MulticastOpts::V6 { interface_addr, .. } => (*interface_addr).into(),
317        }
318    }
319
320    pub fn mcast_addr(&self) -> SocketAddr {
321        match self {
322            MulticastOpts::V4 { mcast_addr, .. } => (*mcast_addr).into(),
323            MulticastOpts::V6 { mcast_addr, .. } => (*mcast_addr).into(),
324        }
325    }
326
327    fn uniq_key(&self) -> (Option<u32>, Option<Ipv4Addr>, SocketAddr) {
328        match self {
329            MulticastOpts::V4 {
330                interface_addr,
331                mcast_addr,
332            } => (None, Some(*interface_addr), (*mcast_addr).into()),
333            MulticastOpts::V6 {
334                interface_id,
335                mcast_addr,
336                ..
337            } => (Some(*interface_id), None, (*mcast_addr).into()),
338        }
339    }
340}
341
/// Callback invoked for each received datagram with its payload bytes and the sender's address.
pub type HandlerFn = dyn Fn(&[u8], SocketAddr) + Sync + Send + 'static;
/// Owned, type-erased [`HandlerFn`], as stored by `SharedMulticastUdpSocket`.
pub type Handler = Box<HandlerFn>;
344
/// A multicast socket with shared .recv_from(). It'll call all subscribed handlers.
/// You MUST spawn task_listen_forever(): it is the only place this type reads from the
/// socket, so without it no handler is ever called.
pub struct SharedMulticastUdpSocket {
    // The wrapped socket; only task_listen_forever() receives from it.
    sock: MulticastUdpSocket,
    // Subscribers, called in registration order for every received datagram.
    handlers: RwLock<Vec<Handler>>,
}
351
352impl SharedMulticastUdpSocket {
353    pub fn new(sock: MulticastUdpSocket) -> crate::Result<Arc<Self>> {
354        let sock = Arc::new(Self {
355            sock,
356            handlers: Default::default(),
357        });
358        Ok(sock)
359    }
360
361    pub fn add_handler(self: &Arc<Self>, handler: Handler) {
362        self.handlers.write().push(handler);
363    }
364
365    pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
366        self.sock.send_to(buf, addr).await
367    }
368
369    pub async fn task_listen_forever(self: Arc<Self>) -> std::io::Result<()> {
370        let mut buf = [0u8; 4096];
371        loop {
372            let (sz, addr) = self.sock.recv_from(&mut buf).await?;
373            for handler in self.handlers.read().iter() {
374                handler(&buf[..sz], addr);
375            }
376        }
377    }
378
379    pub async fn try_send_mcast_everywhere(
380        &self,
381        get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
382    ) {
383        self.sock.try_send_mcast_everywhere(get_payload).await
384    }
385}