librqbit_dualstack_sockets/
multicast.rs1#[cfg(test)]
2mod tests;
3
4use std::{
5 future::poll_fn,
6 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7 task::Poll,
8};
9
10use network_interface::{NetworkInterface, NetworkInterfaceConfig};
11use socket2::SockRef;
12use tracing::{debug, trace};
13
14use crate::{
15 BindDevice, BindOpts, Error, UdpSocket,
16 addr::{Ipv6AddrExt, WithScopeId},
17};
18
19pub struct MulticastUdpSocket {
22 sock: UdpSocket,
23 ipv4_addr: SocketAddrV4,
24 ipv6_site_local: SocketAddrV6,
25 ipv6_link_local: Option<SocketAddrV6>,
26 nics: Vec<NetworkInterface>,
27}
28
29impl MulticastUdpSocket {
30 pub async fn new(
31 bind_addr: SocketAddr,
32 ipv4_mcast_addr: SocketAddrV4,
33 ipv6_site_local_addr: SocketAddrV6,
34 ipv6_link_local_addr: Option<SocketAddrV6>,
35 bind_device: Option<&BindDevice>,
36 ) -> crate::Result<Self> {
37 if let Some(ll) = ipv6_link_local_addr
38 && !ll.ip().is_link_local_mcast()
39 {
40 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
41 }
42 if !ipv6_site_local_addr.ip().is_site_local_mcast() {
43 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
44 }
45 let nics = network_interface::NetworkInterface::show()
46 .into_iter()
47 .flatten()
48 .filter(|nic| bind_device.is_none_or(|bd| bd.index().get() == nic.index))
49 .collect::<Vec<_>>();
50 if nics.is_empty() {
51 return Err(Error::NoNics);
52 }
53 let opts = BindOpts {
54 request_dualstack: true,
55 reuseport: true,
56 device: bind_device,
57 };
58 let sock = UdpSocket::bind_udp(bind_addr, opts)?;
59 let sock = Self {
60 sock,
61 ipv4_addr: ipv4_mcast_addr,
62 ipv6_link_local: ipv6_link_local_addr,
63 ipv6_site_local: ipv6_site_local_addr,
64 nics,
65 };
66 sock.bind_multicast().await?;
67 Ok(sock)
68 }
69
70 pub fn nics(&self) -> &[NetworkInterface] {
71 &self.nics
72 }
73
74 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
75 self.sock.recv_from(buf).await
76 }
77
78 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
79 poll_fn(|cx| {
81 let sref = SockRef::from(self.sock.socket());
82 if self.sock.bind_addr().is_ipv6() {
83 if let Err(e) = sref.set_multicast_if_v6(0) {
84 trace!("error calling set_multicast_if_v6(0): {e:#}")
85 }
86 } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
87 trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
88 }
89
90 self.sock.poll_send_to(cx, buf, addr)
91 })
92 .await
93 }
94
95 async fn bind_multicast(&self) -> crate::Result<()> {
96 let mut joined = false;
97 if self.sock.bind_addr().is_ipv4() {
98 joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
99 }
100
101 for nic in self.nics.iter() {
102 let mut has_link_local = false;
103 let mut has_site_local = false;
104
105 for addr in nic.addr.iter() {
106 match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
107 (IpAddr::V4(iface_addr), is_ipv6)
108 if iface_addr.is_private() || iface_addr.is_loopback() =>
109 {
110 let is_linux_or_windows =
111 cfg!(any(target_os = "linux", target_os = "windows"));
112 if !is_ipv6 || is_linux_or_windows {
113 joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
114 } else {
115 joined |= try_join_v6(
116 &self.sock,
117 self.ipv4_addr.ip().to_ipv6_mapped(),
118 nic.index,
119 )
120 }
121 }
122 (IpAddr::V6(addr), true) => {
123 if addr.is_unicast_link_local() {
124 has_link_local = true;
125 } else {
126 has_site_local = true;
127 }
128 }
129 _ => continue,
130 }
131 }
132
133 if has_site_local {
134 joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
135 }
136
137 if let Some(ll) = self.ipv6_link_local
138 && has_link_local
139 {
140 joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
141 }
142 }
143
144 if !joined {
145 return Err(Error::MulticastJoinFail);
146 }
147
148 self.sock
149 .socket()
150 .writable()
151 .await
152 .map_err(Error::Writeable)?;
153
154 Ok(())
155 }
156
157 pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
158 self.nics()
159 .iter()
160 .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
161 .find_map(|(nic, naddr)| {
162 let nm = naddr.netmask();
163 let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
164 (SocketAddr::V6(addr), _, _, Some(mlocal))
167 if addr.ip().is_unicast_link_local() =>
168 {
169 if nic.index != addr.scope_id() {
170 return None;
171 }
172 mlocal.with_scope_id(nic.index).into()
173 }
174
175 (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
177 if addr.ip().is_unique_local()
178 && addr.ip().to_bits() & mask.to_bits()
179 == naddr.to_bits() & mask.to_bits() =>
180 {
181 self.ipv6_site_local.into()
182 }
183
184 (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
186 if addr.ip().to_bits() & mask.to_bits()
187 == naddr.to_bits() & mask.to_bits() =>
188 {
189 self.ipv4_addr.into()
190 }
191 _ => return None,
192 };
193 Some(MulticastOpts {
194 interface_id: nic.index,
195 interface_addr: naddr.ip(),
196 mcast_addr,
197 })
198 })
199 }
200
201 pub async fn send_multicast_msg(
202 &self,
203 buf: &[u8],
204 opts: &MulticastOpts,
205 ) -> crate::Result<usize> {
206 poll_fn(|cx| {
209 let sref = SockRef::from(self.sock.socket());
210 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
211 let is_linux = cfg!(target_os = "linux");
212
213 match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
216 (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
219 | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
220 sref.set_multicast_if_v4(&addr)
221 .map_err(Error::SetMulticastIpv4)?;
222 }
223 (SocketAddr::V6(_), IpAddr::V6(_), _, _)
224 | (SocketAddr::V4(_), IpAddr::V4(_), _, _) => {
225 sref.set_multicast_if_v6(opts.interface_id)
226 .map_err(Error::SetMulticastIpv6)?;
227 }
228 _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
229 }
230
231 self.sock
232 .poll_send_to(cx, buf, opts.mcast_addr)
233 .map_err(Error::Send)
234 })
235 .await
236 }
237
238 pub async fn try_send_mcast_everywhere(
239 &self,
240 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
241 ) {
242 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
243
244 let mut send_specs = self
245 .nics
246 .iter()
247 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
248 .filter_map(|(ifidx, ifaddr)| {
249 let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
250 (_, IpAddr::V4(a), _) if a.is_private() || a.is_loopback() => {
251 self.ipv4_addr.into()
252 }
253 (true, IpAddr::V6(a), Some(mlocal)) if a.is_unicast_link_local() => {
254 mlocal.with_scope_id(ifidx).into()
255 }
256 (true, IpAddr::V6(a), _) if a.is_unique_local() => self.ipv6_site_local.into(),
257 _ => {
258 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
259 return None;
260 }
261 };
262 Some(MulticastOpts {
263 interface_id: ifidx,
264 interface_addr: ifaddr,
265 mcast_addr,
266 })
267 })
268 .collect::<Vec<_>>();
269
270 send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
271 send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
272
273 let futs = send_specs.into_iter().filter_map(|opts| {
274 let payload = get_payload(&opts)?;
275 let fut = async move {
276 match self.send_multicast_msg(payload.as_bytes(), &opts).await {
277 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
278 Err(e) => {
279 debug!(?opts, payload=?payload, "error sending: {e:#}")
280 }
281 };
282 };
283 Some(fut)
284 });
285
286 futures::future::join_all(futs).await;
287 }
288}
289
290fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
291 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
292 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
293 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
294 return false;
295 }
296 true
297}
298
299fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
300 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
301 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
302 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
303 return false;
304 }
305 true
306}
307
/// Describes one multicast send: which interface to send from and which group to send to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MulticastOpts {
    /// Index of the outgoing interface.
    pub interface_id: u32,
    /// An address of the outgoing interface.
    pub interface_addr: IpAddr,
    /// Destination multicast group address (and port).
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// The outgoing interface address.
    pub fn iface_ip(&self) -> IpAddr {
        self.interface_addr
    }

    /// The destination multicast group address.
    pub fn mcast_addr(&self) -> SocketAddr {
        self.mcast_addr
    }

    /// Key used to deduplicate sends.
    ///
    /// IPv6 sockets identify the interface by index; IPv4 sockets identify it by address.
    /// The unused component is `None`.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_addr = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_addr, self.mcast_addr)
    }
}