librqbit_dualstack_sockets/
multicast.rs1#[cfg(test)]
2mod tests;
3
4use std::{
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6 task::Poll,
7};
8
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use socket2::SockRef;
11use tracing::{debug, trace};
12
13use crate::{
14 BindOpts, Error, UdpSocket,
15 addr::{Ipv6AddrExt, WithScopeId},
16};
17
18pub struct MulticastUdpSocket {
21 sock: UdpSocket,
22 ipv4_addr: SocketAddrV4,
23 ipv6_site_local: SocketAddrV6,
24 ipv6_link_local: Option<SocketAddrV6>,
25 nics: Vec<NetworkInterface>,
26}
27
28impl MulticastUdpSocket {
29 pub async fn new(
30 bind_addr: SocketAddr,
31 ipv4_mcast_addr: SocketAddrV4,
32 ipv6_site_local_addr: SocketAddrV6,
33 ipv6_link_local_addr: Option<SocketAddrV6>,
34 ) -> crate::Result<Self> {
35 if let Some(ll) = ipv6_link_local_addr {
36 if !ll.ip().is_link_local_mcast() {
37 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
38 }
39 }
40 if !ipv6_site_local_addr.ip().is_site_local_mcast() {
41 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
42 }
43 let nics = network_interface::NetworkInterface::show()
44 .into_iter()
45 .flatten()
46 .collect::<Vec<_>>();
47 if nics.is_empty() {
48 return Err(Error::NoNics);
49 }
50 let opts = BindOpts {
51 request_dualstack: true,
52 reuseport: true,
53 };
54 let sock = UdpSocket::bind_udp(bind_addr, opts)?;
55 let sock = Self {
56 sock,
57 ipv4_addr: ipv4_mcast_addr,
58 ipv6_link_local: ipv6_link_local_addr,
59 ipv6_site_local: ipv6_site_local_addr,
60 nics,
61 };
62 sock.bind_multicast().await?;
63 Ok(sock)
64 }
65
66 pub fn nics(&self) -> &[NetworkInterface] {
67 &self.nics
68 }
69
70 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
71 self.sock.recv_from(buf).await
72 }
73
74 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
75 std::future::poll_fn(|cx| {
77 let sref = SockRef::from(self.sock.socket());
78 if self.sock.bind_addr().is_ipv6() {
79 if let Err(e) = sref.set_multicast_if_v6(0) {
80 trace!("error calling set_multicast_if_v6(0): {e:#}")
81 }
82 } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
83 trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
84 }
85
86 self.sock.poll_send_to(cx, buf, addr)
87 })
88 .await
89 }
90
91 async fn bind_multicast(&self) -> crate::Result<()> {
92 let mut joined = false;
93 if self.sock.bind_addr().is_ipv4() {
94 joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
95 }
96
97 for nic in self.nics.iter() {
98 let mut has_link_local = false;
99 let mut has_site_local = false;
100
101 for addr in nic.addr.iter() {
102 match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
103 (IpAddr::V4(iface_addr), is_ipv6)
104 if iface_addr.is_private() && !iface_addr.is_loopback() =>
105 {
106 let is_linux = cfg!(target_os = "linux");
107 if !is_ipv6 || is_linux {
108 joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
109 } else {
110 joined |= try_join_v6(
111 &self.sock,
112 self.ipv4_addr.ip().to_ipv6_mapped(),
113 nic.index,
114 )
115 }
116 }
117 (IpAddr::V6(addr), true) => {
118 if addr.is_loopback() {
119 continue;
120 }
121 if addr.is_unicast_link_local() {
122 has_link_local = true;
123 } else {
124 has_site_local = true;
125 }
126 }
127 _ => continue,
128 }
129 }
130
131 if has_site_local {
132 joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
133 }
134
135 if let Some(ll) = self.ipv6_link_local {
136 if has_link_local {
137 joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
138 }
139 }
140 }
141
142 if !joined {
143 return Err(Error::MulticastJoinFail);
144 }
145
146 self.sock
147 .socket()
148 .writable()
149 .await
150 .map_err(Error::Writeable)?;
151
152 Ok(())
153 }
154
155 pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
156 self.nics()
157 .iter()
158 .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
159 .find_map(|(nic, naddr)| {
160 let nm = naddr.netmask();
161 let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
162 (SocketAddr::V6(addr), _, _, Some(mlocal))
165 if addr.ip().is_unicast_link_local() =>
166 {
167 if nic.index != addr.scope_id() {
168 return None;
169 }
170 mlocal.with_scope_id(nic.index).into()
171 }
172
173 (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
175 if addr.ip().is_unique_local()
176 && addr.ip().to_bits() & mask.to_bits()
177 == naddr.to_bits() & mask.to_bits() =>
178 {
179 self.ipv6_site_local.into()
180 }
181
182 (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
184 if addr.ip().to_bits() & mask.to_bits()
185 == naddr.to_bits() & mask.to_bits() =>
186 {
187 self.ipv4_addr.into()
188 }
189 _ => return None,
190 };
191 Some(MulticastOpts {
192 interface_id: nic.index,
193 interface_addr: naddr.ip(),
194 mcast_addr,
195 })
196 })
197 }
198
199 pub async fn send_multicast_msg(
200 &self,
201 buf: &[u8],
202 opts: &MulticastOpts,
203 ) -> crate::Result<usize> {
204 std::future::poll_fn(|cx| {
208 let sref = SockRef::from(self.sock.socket());
209 if self.sock.bind_addr().is_ipv6() {
210 if let Err(e) = sref.set_multicast_if_v6(opts.interface_id) {
211 return Poll::Ready(Err(Error::SetMulticastIpv6(e)));
212 }
213 } else {
214 let ifaddr = match opts.interface_addr {
215 IpAddr::V4(ipv4_addr) => ipv4_addr,
216 IpAddr::V6(_) => {
217 return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch));
218 }
219 };
220 if let Err(e) = sref.set_multicast_if_v4(&ifaddr) {
221 return Poll::Ready(Err(Error::SetMulticastIpv4(e)));
222 }
223 };
224
225 self.sock
226 .poll_send_to(cx, buf, opts.mcast_addr)
227 .map_err(Error::Send)
228 })
229 .await
230 }
231
232 pub async fn try_send_mcast_everywhere(
233 &self,
234 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
235 ) {
236 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
237
238 let mut send_specs = self
239 .nics
240 .iter()
241 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
242 .filter_map(|(ifidx, ifaddr)| {
243 let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
244 (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
245 self.ipv4_addr.into()
246 }
247 (true, IpAddr::V6(a), Some(mlocal))
248 if !a.is_loopback() && a.is_unicast_link_local() =>
249 {
250 mlocal.with_scope_id(ifidx).into()
251 }
252 (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
253 self.ipv6_site_local.into()
254 }
255 _ => {
256 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
257 return None;
258 }
259 };
260 Some(MulticastOpts {
261 interface_id: ifidx,
262 interface_addr: ifaddr,
263 mcast_addr,
264 })
265 })
266 .collect::<Vec<_>>();
267
268 send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
269 send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
270
271 let futs = send_specs.into_iter().filter_map(|opts| {
272 let payload = get_payload(&opts)?;
273 let fut = async move {
274 match self.send_multicast_msg(payload.as_bytes(), &opts).await {
275 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
276 Err(e) => {
277 debug!(?opts, payload=?payload, "error sending: {e:#}")
278 }
279 };
280 };
281 Some(fut)
282 });
283
284 futures::future::join_all(futs).await;
285 }
286}
287
288fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
289 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
290 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
291 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
292 return false;
293 }
294 true
295}
296
297fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
298 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
299 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
300 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
301 return false;
302 }
303 true
304}
305
/// Describes one multicast send: which interface to send from and which group
/// address to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface (selects the interface on IPv6-bound sockets).
    pub interface_id: u32,
    /// An address of the outgoing interface (selects the interface on IPv4-bound sockets).
    pub interface_addr: IpAddr,
    /// Multicast group address and port to send to.
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// The IP address of the outgoing interface.
    pub fn iface_ip(&self) -> IpAddr {
        self.interface_addr
    }

    /// The multicast destination address.
    pub fn mcast_addr(&self) -> SocketAddr {
        self.mcast_addr
    }

    /// Deduplication key for send targets.
    ///
    /// IPv6-bound sockets pick the interface by index and IPv4-bound sockets by
    /// address, so only the identifier relevant to the bind family is kept; the
    /// other slot is `None`.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_addr = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_addr, self.mcast_addr)
    }
}