librqbit_dualstack_sockets/
multicast.rs1#[cfg(test)]
2mod tests;
3
4use std::{
5 future::poll_fn,
6 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7 task::Poll,
8};
9
10use network_interface::{NetworkInterface, NetworkInterfaceConfig};
11use socket2::SockRef;
12use tracing::{debug, trace};
13
14use crate::{
15 BindOpts, Error, UdpSocket,
16 addr::{Ipv6AddrExt, WithScopeId},
17};
18
19pub struct MulticastUdpSocket {
22 sock: UdpSocket,
23 ipv4_addr: SocketAddrV4,
24 ipv6_site_local: SocketAddrV6,
25 ipv6_link_local: Option<SocketAddrV6>,
26 nics: Vec<NetworkInterface>,
27}
28
29impl MulticastUdpSocket {
30 pub async fn new(
31 bind_addr: SocketAddr,
32 ipv4_mcast_addr: SocketAddrV4,
33 ipv6_site_local_addr: SocketAddrV6,
34 ipv6_link_local_addr: Option<SocketAddrV6>,
35 ) -> crate::Result<Self> {
36 if let Some(ll) = ipv6_link_local_addr {
37 if !ll.ip().is_link_local_mcast() {
38 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
39 }
40 }
41 if !ipv6_site_local_addr.ip().is_site_local_mcast() {
42 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
43 }
44 let nics = network_interface::NetworkInterface::show()
45 .into_iter()
46 .flatten()
47 .collect::<Vec<_>>();
48 if nics.is_empty() {
49 return Err(Error::NoNics);
50 }
51 let opts = BindOpts {
52 request_dualstack: true,
53 reuseport: true,
54 device: None,
55 };
56 let sock = UdpSocket::bind_udp(bind_addr, opts)?;
57 let sock = Self {
58 sock,
59 ipv4_addr: ipv4_mcast_addr,
60 ipv6_link_local: ipv6_link_local_addr,
61 ipv6_site_local: ipv6_site_local_addr,
62 nics,
63 };
64 sock.bind_multicast().await?;
65 Ok(sock)
66 }
67
68 pub fn nics(&self) -> &[NetworkInterface] {
69 &self.nics
70 }
71
72 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
73 self.sock.recv_from(buf).await
74 }
75
76 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
77 poll_fn(|cx| {
79 let sref = SockRef::from(self.sock.socket());
80 if self.sock.bind_addr().is_ipv6() {
81 if let Err(e) = sref.set_multicast_if_v6(0) {
82 trace!("error calling set_multicast_if_v6(0): {e:#}")
83 }
84 } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
85 trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
86 }
87
88 self.sock.poll_send_to(cx, buf, addr)
89 })
90 .await
91 }
92
93 async fn bind_multicast(&self) -> crate::Result<()> {
94 let mut joined = false;
95 if self.sock.bind_addr().is_ipv4() {
96 joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
97 }
98
99 for nic in self.nics.iter() {
100 let mut has_link_local = false;
101 let mut has_site_local = false;
102
103 for addr in nic.addr.iter() {
104 match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
105 (IpAddr::V4(iface_addr), is_ipv6)
106 if iface_addr.is_private() && !iface_addr.is_loopback() =>
107 {
108 let is_linux_or_windows =
109 cfg!(any(target_os = "linux", target_os = "windows"));
110 if !is_ipv6 || is_linux_or_windows {
111 joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
112 } else {
113 joined |= try_join_v6(
114 &self.sock,
115 self.ipv4_addr.ip().to_ipv6_mapped(),
116 nic.index,
117 )
118 }
119 }
120 (IpAddr::V6(addr), true) => {
121 if addr.is_loopback() {
122 continue;
123 }
124 if addr.is_unicast_link_local() {
125 has_link_local = true;
126 } else {
127 has_site_local = true;
128 }
129 }
130 _ => continue,
131 }
132 }
133
134 if has_site_local {
135 joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
136 }
137
138 if let Some(ll) = self.ipv6_link_local {
139 if has_link_local {
140 joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
141 }
142 }
143 }
144
145 if !joined {
146 return Err(Error::MulticastJoinFail);
147 }
148
149 self.sock
150 .socket()
151 .writable()
152 .await
153 .map_err(Error::Writeable)?;
154
155 Ok(())
156 }
157
158 pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
159 self.nics()
160 .iter()
161 .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
162 .find_map(|(nic, naddr)| {
163 let nm = naddr.netmask();
164 let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
165 (SocketAddr::V6(addr), _, _, Some(mlocal))
168 if addr.ip().is_unicast_link_local() =>
169 {
170 if nic.index != addr.scope_id() {
171 return None;
172 }
173 mlocal.with_scope_id(nic.index).into()
174 }
175
176 (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
178 if addr.ip().is_unique_local()
179 && addr.ip().to_bits() & mask.to_bits()
180 == naddr.to_bits() & mask.to_bits() =>
181 {
182 self.ipv6_site_local.into()
183 }
184
185 (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
187 if addr.ip().to_bits() & mask.to_bits()
188 == naddr.to_bits() & mask.to_bits() =>
189 {
190 self.ipv4_addr.into()
191 }
192 _ => return None,
193 };
194 Some(MulticastOpts {
195 interface_id: nic.index,
196 interface_addr: naddr.ip(),
197 mcast_addr,
198 })
199 })
200 }
201
202 pub async fn send_multicast_msg(
203 &self,
204 buf: &[u8],
205 opts: &MulticastOpts,
206 ) -> crate::Result<usize> {
207 poll_fn(|cx| {
210 let sref = SockRef::from(self.sock.socket());
211 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
212 let is_linux = cfg!(target_os = "linux");
213
214 match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
217 (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
220 | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
221 sref.set_multicast_if_v4(&addr)
222 .map_err(Error::SetMulticastIpv4)?;
223 }
224 (SocketAddr::V6(_), IpAddr::V6(_), _, _)
225 | (SocketAddr::V4(_), IpAddr::V4(_), _, _) => {
226 sref.set_multicast_if_v6(opts.interface_id)
227 .map_err(Error::SetMulticastIpv6)?;
228 }
229 _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
230 }
231
232 self.sock
233 .poll_send_to(cx, buf, opts.mcast_addr)
234 .map_err(Error::Send)
235 })
236 .await
237 }
238
239 pub async fn try_send_mcast_everywhere(
240 &self,
241 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
242 ) {
243 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
244
245 let mut send_specs = self
246 .nics
247 .iter()
248 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
249 .filter_map(|(ifidx, ifaddr)| {
250 let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
251 (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
252 self.ipv4_addr.into()
253 }
254 (true, IpAddr::V6(a), Some(mlocal))
255 if !a.is_loopback() && a.is_unicast_link_local() =>
256 {
257 mlocal.with_scope_id(ifidx).into()
258 }
259 (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
260 self.ipv6_site_local.into()
261 }
262 _ => {
263 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
264 return None;
265 }
266 };
267 Some(MulticastOpts {
268 interface_id: ifidx,
269 interface_addr: ifaddr,
270 mcast_addr,
271 })
272 })
273 .collect::<Vec<_>>();
274
275 send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
276 send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
277
278 let futs = send_specs.into_iter().filter_map(|opts| {
279 let payload = get_payload(&opts)?;
280 let fut = async move {
281 match self.send_multicast_msg(payload.as_bytes(), &opts).await {
282 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
283 Err(e) => {
284 debug!(?opts, payload=?payload, "error sending: {e:#}")
285 }
286 };
287 };
288 Some(fut)
289 });
290
291 futures::future::join_all(futs).await;
292 }
293}
294
295fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
296 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
297 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
298 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
299 return false;
300 }
301 true
302}
303
304fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
305 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
306 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
307 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
308 return false;
309 }
310 true
311}
312
/// Describes one multicast send target: which interface to send through and
/// which group address to send to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MulticastOpts {
    /// Index of the network interface to send through.
    pub interface_id: u32,
    /// An address of that network interface.
    pub interface_addr: IpAddr,
    /// Multicast group address (with port) the message is sent to.
    pub mcast_addr: SocketAddr,
}
319
320impl MulticastOpts {
321 pub fn iface_ip(&self) -> IpAddr {
322 self.interface_addr
323 }
324
325 pub fn mcast_addr(&self) -> SocketAddr {
326 self.mcast_addr
327 }
328
329 fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
330 if bind_addr_is_ipv6 {
331 (Some(self.interface_id), None, self.mcast_addr)
332 } else {
333 (None, Some(self.interface_addr), self.mcast_addr)
334 }
335 }
336}