librqbit_dualstack_sockets/
multicast.rs1#[cfg(test)]
2mod tests;
3
4use std::{
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6 task::Poll,
7};
8
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use socket2::SockRef;
11use tracing::{debug, trace};
12
13use crate::{
14 BindOpts, Error, UdpSocket,
15 addr::{Ipv6AddrExt, WithScopeId},
16};
17
/// A UDP socket joined to a fixed set of IPv4 and IPv6 multicast groups.
///
/// Built by [`MulticastUdpSocket::new`], which validates the IPv6 group addresses,
/// enumerates the host's network interfaces and joins the groups on suitable interfaces.
pub struct MulticastUdpSocket {
    /// Underlying socket, bound with `request_dualstack` and `reuseport` requested.
    sock: UdpSocket,
    /// IPv4 multicast group (with port); joined on private IPv4 interfaces and used as the
    /// destination for IPv4 multicast sends.
    ipv4_addr: SocketAddrV4,
    /// IPv6 site-local multicast group; joined on interfaces that have a non-loopback,
    /// non-link-local IPv6 address.
    ipv6_site_local: SocketAddrV6,
    /// Optional IPv6 link-local multicast group; joined on interfaces that have a
    /// link-local IPv6 address.
    ipv6_link_local: Option<SocketAddrV6>,
    /// Snapshot of the host's network interfaces taken in `new`.
    nics: Vec<NetworkInterface>,
}
27
28impl MulticastUdpSocket {
29 pub async fn new(
30 bind_addr: SocketAddr,
31 ipv4_mcast_addr: SocketAddrV4,
32 ipv6_site_local_addr: SocketAddrV6,
33 ipv6_link_local_addr: Option<SocketAddrV6>,
34 ) -> crate::Result<Self> {
35 if let Some(ll) = ipv6_link_local_addr {
36 if !ll.ip().is_link_local_mcast() {
37 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
38 }
39 }
40 if !ipv6_site_local_addr.ip().is_site_local_mcast() {
41 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
42 }
43 let nics = network_interface::NetworkInterface::show()
44 .into_iter()
45 .flatten()
46 .collect::<Vec<_>>();
47 if nics.is_empty() {
48 return Err(Error::NoNics);
49 }
50 let opts = BindOpts {
51 request_dualstack: true,
52 reuseport: true,
53 };
54 let sock = UdpSocket::bind_udp(bind_addr, opts)?;
55 let sock = Self {
56 sock,
57 ipv4_addr: ipv4_mcast_addr,
58 ipv6_link_local: ipv6_link_local_addr,
59 ipv6_site_local: ipv6_site_local_addr,
60 nics,
61 };
62 sock.bind_multicast().await?;
63 Ok(sock)
64 }
65
66 pub fn nics(&self) -> &[NetworkInterface] {
67 &self.nics
68 }
69
70 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
71 self.sock.recv_from(buf).await
72 }
73
74 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
75 std::future::poll_fn(|cx| {
77 let sref = SockRef::from(self.sock.socket());
78 if self.sock.bind_addr().is_ipv6() {
79 if let Err(e) = sref.set_multicast_if_v6(0) {
80 trace!("error calling set_multicast_if_v6(0): {e:#}")
81 }
82 } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
83 trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
84 }
85
86 self.sock.poll_send_to(cx, buf, addr)
87 })
88 .await
89 }
90
91 async fn bind_multicast(&self) -> crate::Result<()> {
92 let mut joined = false;
93 if self.sock.bind_addr().is_ipv4() {
94 joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
95 }
96
97 for nic in self.nics.iter() {
98 let mut has_link_local = false;
99 let mut has_site_local = false;
100
101 for addr in nic.addr.iter() {
102 match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
103 (IpAddr::V4(iface_addr), is_ipv6)
104 if iface_addr.is_private() && !iface_addr.is_loopback() =>
105 {
106 if !is_ipv6 {
107 joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
108 } else {
109 joined |= try_join_v6(
110 &self.sock,
111 self.ipv4_addr.ip().to_ipv6_mapped(),
112 nic.index,
113 )
114 }
115 }
116 (IpAddr::V6(addr), true) => {
117 if addr.is_loopback() {
118 continue;
119 }
120 if addr.is_unicast_link_local() {
121 has_link_local = true;
122 } else {
123 has_site_local = true;
124 }
125 }
126 _ => continue,
127 }
128 }
129
130 if has_site_local {
131 joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
132 }
133
134 if let Some(ll) = self.ipv6_link_local {
135 if has_link_local {
136 joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
137 }
138 }
139 }
140
141 if !joined {
142 return Err(Error::MulticastJoinFail);
143 }
144
145 self.sock
146 .socket()
147 .writable()
148 .await
149 .map_err(Error::Writeable)?;
150
151 Ok(())
152 }
153
154 pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
155 self.nics()
156 .iter()
157 .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
158 .find_map(|(nic, naddr)| {
159 let nm = naddr.netmask();
160 let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
161 (SocketAddr::V6(addr), _, _, Some(mlocal))
164 if addr.ip().is_unicast_link_local() =>
165 {
166 if nic.index != addr.scope_id() {
167 return None;
168 }
169 mlocal.with_scope_id(nic.index).into()
170 }
171
172 (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
174 if addr.ip().is_unique_local()
175 && addr.ip().to_bits() & mask.to_bits()
176 == naddr.to_bits() & mask.to_bits() =>
177 {
178 self.ipv6_site_local.into()
179 }
180
181 (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
183 if addr.ip().to_bits() & mask.to_bits()
184 == naddr.to_bits() & mask.to_bits() =>
185 {
186 self.ipv4_addr.into()
187 }
188 _ => return None,
189 };
190 Some(MulticastOpts {
191 interface_id: nic.index,
192 interface_addr: naddr.ip(),
193 mcast_addr,
194 })
195 })
196 }
197
198 pub async fn send_multicast_msg(
199 &self,
200 buf: &[u8],
201 opts: &MulticastOpts,
202 ) -> crate::Result<usize> {
203 std::future::poll_fn(|cx| {
207 let sref = SockRef::from(self.sock.socket());
208 if self.sock.bind_addr().is_ipv6() {
209 if let Err(e) = sref.set_multicast_if_v6(opts.interface_id) {
210 return Poll::Ready(Err(Error::SetMulticastIpv6(e)));
211 }
212 } else {
213 let ifaddr = match opts.interface_addr {
214 IpAddr::V4(ipv4_addr) => ipv4_addr,
215 IpAddr::V6(_) => {
216 return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch));
217 }
218 };
219 if let Err(e) = sref.set_multicast_if_v4(&ifaddr) {
220 return Poll::Ready(Err(Error::SetMulticastIpv4(e)));
221 }
222 };
223
224 self.sock
225 .poll_send_to(cx, buf, opts.mcast_addr)
226 .map_err(Error::Send)
227 })
228 .await
229 }
230
231 pub async fn try_send_mcast_everywhere(
232 &self,
233 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
234 ) {
235 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
236
237 let mut send_specs = self
238 .nics
239 .iter()
240 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
241 .filter_map(|(ifidx, ifaddr)| {
242 let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
243 (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
244 self.ipv4_addr.into()
245 }
246 (true, IpAddr::V6(a), Some(mlocal))
247 if !a.is_loopback() && a.is_unicast_link_local() =>
248 {
249 mlocal.with_scope_id(ifidx).into()
250 }
251 (true, IpAddr::V6(a), _) if !a.is_loopback() && a.is_unique_local() => {
252 self.ipv6_site_local.into()
253 }
254 _ => {
255 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
256 return None;
257 }
258 };
259 Some(MulticastOpts {
260 interface_id: ifidx,
261 interface_addr: ifaddr,
262 mcast_addr,
263 })
264 })
265 .collect::<Vec<_>>();
266
267 send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
268 send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
269
270 let futs = send_specs.into_iter().filter_map(|opts| {
271 let payload = get_payload(&opts)?;
272 let fut = async move {
273 match self.send_multicast_msg(payload.as_bytes(), &opts).await {
274 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
275 Err(e) => {
276 debug!(?opts, payload=?payload, "error sending: {e:#}")
277 }
278 };
279 };
280 Some(fut)
281 });
282
283 futures::future::join_all(futs).await;
284 }
285}
286
287fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
288 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
289 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
290 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
291 return false;
292 }
293 true
294}
295
296fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
297 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
298 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
299 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
300 return false;
301 }
302 true
303}
304
/// One multicast send target: the interface to send from and the group to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// OS interface index. `MulticastUdpSocket::send_multicast_msg` uses it to select the
    /// outgoing interface on IPv6-bound sockets.
    pub interface_id: u32,
    /// An address of that interface. `MulticastUdpSocket::send_multicast_msg` uses it to
    /// select the outgoing interface on IPv4-bound sockets, where it must be IPv4.
    pub interface_addr: IpAddr,
    /// Destination multicast group address and port.
    pub mcast_addr: SocketAddr,
}
311
312impl MulticastOpts {
313 pub fn iface_ip(&self) -> IpAddr {
314 self.interface_addr
315 }
316
317 pub fn mcast_addr(&self) -> SocketAddr {
318 self.mcast_addr
319 }
320
321 fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
322 if bind_addr_is_ipv6 {
323 (Some(self.interface_id), None, self.mcast_addr)
324 } else {
325 (None, Some(self.interface_addr), self.mcast_addr)
326 }
327 }
328}