librqbit_dualstack_sockets/
multicast.rs1#[cfg(test)]
2mod tests;
3
4use std::{
5 future::poll_fn,
6 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7 task::Poll,
8};
9
10use network_interface::{NetworkInterface, NetworkInterfaceConfig};
11use socket2::SockRef;
12use tracing::{debug, trace};
13
14use crate::{
15 BindOpts, Error, UdpSocket,
16 addr::{Ipv6AddrExt, WithScopeId},
17};
18
19pub struct MulticastUdpSocket {
22 sock: UdpSocket,
23 ipv4_addr: SocketAddrV4,
24 ipv6_site_local: SocketAddrV6,
25 ipv6_link_local: Option<SocketAddrV6>,
26 nics: Vec<NetworkInterface>,
27}
28
29impl MulticastUdpSocket {
30 pub async fn new(
31 bind_addr: SocketAddr,
32 ipv4_mcast_addr: SocketAddrV4,
33 ipv6_site_local_addr: SocketAddrV6,
34 ipv6_link_local_addr: Option<SocketAddrV6>,
35 ) -> crate::Result<Self> {
36 if let Some(ll) = ipv6_link_local_addr {
37 if !ll.ip().is_link_local_mcast() {
38 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
39 }
40 }
41 if !ipv6_site_local_addr.ip().is_site_local_mcast() {
42 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
43 }
44 let nics = network_interface::NetworkInterface::show()
45 .into_iter()
46 .flatten()
47 .collect::<Vec<_>>();
48 if nics.is_empty() {
49 return Err(Error::NoNics);
50 }
51 let opts = BindOpts {
52 request_dualstack: true,
53 reuseport: true,
54 };
55 let sock = UdpSocket::bind_udp(bind_addr, opts)?;
56 let sock = Self {
57 sock,
58 ipv4_addr: ipv4_mcast_addr,
59 ipv6_link_local: ipv6_link_local_addr,
60 ipv6_site_local: ipv6_site_local_addr,
61 nics,
62 };
63 sock.bind_multicast().await?;
64 Ok(sock)
65 }
66
67 pub fn nics(&self) -> &[NetworkInterface] {
68 &self.nics
69 }
70
71 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
72 self.sock.recv_from(buf).await
73 }
74
75 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
76 poll_fn(|cx| {
78 let sref = SockRef::from(self.sock.socket());
79 if self.sock.bind_addr().is_ipv6() {
80 if let Err(e) = sref.set_multicast_if_v6(0) {
81 trace!("error calling set_multicast_if_v6(0): {e:#}")
82 }
83 } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
84 trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
85 }
86
87 self.sock.poll_send_to(cx, buf, addr)
88 })
89 .await
90 }
91
92 async fn bind_multicast(&self) -> crate::Result<()> {
93 let mut joined = false;
94 if self.sock.bind_addr().is_ipv4() {
95 joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
96 }
97
98 for nic in self.nics.iter() {
99 let mut has_link_local = false;
100 let mut has_site_local = false;
101
102 for addr in nic.addr.iter() {
103 match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
104 (IpAddr::V4(iface_addr), is_ipv6)
105 if iface_addr.is_private() && !iface_addr.is_loopback() =>
106 {
107 let is_linux = cfg!(target_os = "linux");
108 if !is_ipv6 || is_linux {
109 joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
110 } else {
111 joined |= try_join_v6(
112 &self.sock,
113 self.ipv4_addr.ip().to_ipv6_mapped(),
114 nic.index,
115 )
116 }
117 }
118 (IpAddr::V6(addr), true) => {
119 if addr.is_loopback() {
120 continue;
121 }
122 if addr.is_unicast_link_local() {
123 has_link_local = true;
124 } else {
125 has_site_local = true;
126 }
127 }
128 _ => continue,
129 }
130 }
131
132 if has_site_local {
133 joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
134 }
135
136 if let Some(ll) = self.ipv6_link_local {
137 if has_link_local {
138 joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
139 }
140 }
141 }
142
143 if !joined {
144 return Err(Error::MulticastJoinFail);
145 }
146
147 self.sock
148 .socket()
149 .writable()
150 .await
151 .map_err(Error::Writeable)?;
152
153 Ok(())
154 }
155
156 pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
157 self.nics()
158 .iter()
159 .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
160 .find_map(|(nic, naddr)| {
161 let nm = naddr.netmask();
162 let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
163 (SocketAddr::V6(addr), _, _, Some(mlocal))
166 if addr.ip().is_unicast_link_local() =>
167 {
168 if nic.index != addr.scope_id() {
169 return None;
170 }
171 mlocal.with_scope_id(nic.index).into()
172 }
173
174 (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
176 if addr.ip().is_unique_local()
177 && addr.ip().to_bits() & mask.to_bits()
178 == naddr.to_bits() & mask.to_bits() =>
179 {
180 self.ipv6_site_local.into()
181 }
182
183 (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
185 if addr.ip().to_bits() & mask.to_bits()
186 == naddr.to_bits() & mask.to_bits() =>
187 {
188 self.ipv4_addr.into()
189 }
190 _ => return None,
191 };
192 Some(MulticastOpts {
193 interface_id: nic.index,
194 interface_addr: naddr.ip(),
195 mcast_addr,
196 })
197 })
198 }
199
200 pub async fn send_multicast_msg(
201 &self,
202 buf: &[u8],
203 opts: &MulticastOpts,
204 ) -> crate::Result<usize> {
205 poll_fn(|cx| {
208 let sref = SockRef::from(self.sock.socket());
209 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
210 let is_linux = cfg!(target_os = "linux");
211
212 match (opts.mcast_addr(), opts.iface_ip(), bind_is_ipv6, is_linux) {
215 (SocketAddr::V4(_), IpAddr::V4(addr), _, true)
218 | (SocketAddr::V4(_), IpAddr::V4(addr), false, false) => {
219 sref.set_multicast_if_v4(&addr)
220 .map_err(Error::SetMulticastIpv4)?;
221 }
222 (SocketAddr::V6(_), IpAddr::V6(_), _, _) => {
223 sref.set_multicast_if_v6(opts.interface_id)
224 .map_err(Error::SetMulticastIpv6)?;
225 }
226 _ => return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch)),
227 }
228
229 self.sock
230 .poll_send_to(cx, buf, opts.mcast_addr)
231 .map_err(Error::Send)
232 })
233 .await
234 }
235
236 pub async fn try_send_mcast_everywhere(
237 &self,
238 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
239 ) {
240 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
241
242 let mut send_specs = self
243 .nics
244 .iter()
245 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
246 .filter_map(|(ifidx, ifaddr)| {
247 let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
248 (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
249 self.ipv4_addr.into()
250 }
251 (true, IpAddr::V6(a), Some(mlocal))
252 if !a.is_loopback() && a.is_unicast_link_local() =>
253 {
254 mlocal.with_scope_id(ifidx).into()
255 }
256 (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
257 self.ipv6_site_local.into()
258 }
259 _ => {
260 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
261 return None;
262 }
263 };
264 Some(MulticastOpts {
265 interface_id: ifidx,
266 interface_addr: ifaddr,
267 mcast_addr,
268 })
269 })
270 .collect::<Vec<_>>();
271
272 send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
273 send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
274
275 let futs = send_specs.into_iter().filter_map(|opts| {
276 let payload = get_payload(&opts)?;
277 let fut = async move {
278 match self.send_multicast_msg(payload.as_bytes(), &opts).await {
279 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
280 Err(e) => {
281 debug!(?opts, payload=?payload, "error sending: {e:#}")
282 }
283 };
284 };
285 Some(fut)
286 });
287
288 futures::future::join_all(futs).await;
289 }
290}
291
292fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
293 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
294 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
295 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
296 return false;
297 }
298 true
299}
300
301fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
302 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
303 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
304 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
305 return false;
306 }
307 true
308}
309
/// One concrete multicast send target: the local interface to send from and the multicast
/// group (address and port) to send to.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct MulticastOpts {
    /// Index of the outgoing interface; used when selecting the IPv6 multicast interface.
    pub interface_id: u32,
    /// An address of the outgoing interface; used when selecting the IPv4 multicast interface.
    pub interface_addr: IpAddr,
    /// Destination multicast group address and port.
    pub mcast_addr: SocketAddr,
}
316
317impl MulticastOpts {
318 pub fn iface_ip(&self) -> IpAddr {
319 self.interface_addr
320 }
321
322 pub fn mcast_addr(&self) -> SocketAddr {
323 self.mcast_addr
324 }
325
326 fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
327 if bind_addr_is_ipv6 {
328 (Some(self.interface_id), None, self.mcast_addr)
329 } else {
330 (None, Some(self.interface_addr), self.mcast_addr)
331 }
332 }
333}