// Unit tests for this module; compiled only in test builds.
#[cfg(test)]
mod tests;
3
4use std::{
5 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
6 task::Poll,
7};
8
9use network_interface::{NetworkInterface, NetworkInterfaceConfig};
10use socket2::SockRef;
11use tracing::{debug, trace};
12
13use crate::{BindOpts, Error, UdpSocket, addr::WithScopeId};
14
/// A UDP socket joined to IPv4 and IPv6 multicast groups on the host's network interfaces.
///
/// Constructed via `MulticastUdpSocket::new`, which validates the IPv6 group scopes, records
/// the network interfaces present at that moment and joins the groups on them.
pub struct MulticastUdpSocket {
    /// The underlying bound UDP socket.
    sock: UdpSocket,
    /// IPv4 multicast group (and port); joined in `bind_multicast` and used as the IPv4
    /// destination for sends and replies.
    ipv4_addr: SocketAddrV4,
    /// IPv6 multicast group with site-local scope (validated in `new`).
    ipv6_site_local: SocketAddrV6,
    /// Optional IPv6 multicast group with link-local scope; when `None`, no link-local group
    /// is joined or sent to.
    ipv6_link_local: Option<SocketAddrV6>,
    /// Network interfaces enumerated in `new`; used for joining, sending and reply routing.
    nics: Vec<NetworkInterface>,
}
24
25impl MulticastUdpSocket {
26 pub async fn new(
27 bind_addr: SocketAddr,
28 ipv4_mcast_addr: SocketAddrV4,
29 ipv6_site_local_addr: SocketAddrV6,
30 ipv6_link_local_addr: Option<SocketAddrV6>,
31 ) -> crate::Result<Self> {
32 if let Some(ll) = ipv6_link_local_addr {
33 if !ipv6_is_link_local_mcast(ll.ip()) {
34 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
35 }
36 }
37 if !ipv6_is_site_local_mcast(ipv6_site_local_addr.ip()) {
38 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
39 }
40 let nics = network_interface::NetworkInterface::show()
41 .into_iter()
42 .flatten()
43 .collect::<Vec<_>>();
44 if nics.is_empty() {
45 return Err(Error::NoNics);
46 }
47 let opts = BindOpts {
48 request_dualstack: true,
49 reuseport: true,
50 };
51 let sock = UdpSocket::bind_udp(bind_addr, opts)?;
52 let sock = Self {
53 sock,
54 ipv4_addr: ipv4_mcast_addr,
55 ipv6_link_local: ipv6_link_local_addr,
56 ipv6_site_local: ipv6_site_local_addr,
57 nics,
58 };
59 sock.bind_multicast().await?;
60 Ok(sock)
61 }
62
63 pub fn nics(&self) -> &[NetworkInterface] {
64 &self.nics
65 }
66
67 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
68 self.sock.recv_from(buf).await
69 }
70
71 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
72 std::future::poll_fn(|cx| {
74 let sref = SockRef::from(self.sock.socket());
75 if self.sock.bind_addr().is_ipv6() {
76 if let Err(e) = sref.set_multicast_if_v6(0) {
77 trace!("error calling set_multicast_if_v6(0): {e:#}")
78 }
79 } else if let Err(e) = sref.set_multicast_if_v4(&Ipv4Addr::UNSPECIFIED) {
80 trace!("error calling set_multicast_if_v4(0.0.0.0): {e:#}")
81 }
82
83 self.sock.poll_send_to(cx, buf, addr)
84 })
85 .await
86 }
87
88 async fn bind_multicast(&self) -> crate::Result<()> {
89 let mut joined = false;
90 if self.sock.bind_addr().is_ipv4() {
91 joined = try_join_v4(&self.sock, *self.ipv4_addr.ip(), Ipv4Addr::UNSPECIFIED);
92 }
93
94 for nic in self.nics.iter() {
95 let mut has_link_local = false;
96 let mut has_site_local = false;
97
98 for addr in nic.addr.iter() {
99 match (addr.ip(), self.sock.bind_addr().is_ipv6()) {
100 (IpAddr::V4(iface_addr), is_ipv6)
101 if iface_addr.is_private() && !iface_addr.is_loopback() =>
102 {
103 if !is_ipv6 {
104 joined |= try_join_v4(&self.sock, *self.ipv4_addr.ip(), iface_addr);
105 } else {
106 joined |= try_join_v6(
107 &self.sock,
108 self.ipv4_addr.ip().to_ipv6_mapped(),
109 nic.index,
110 )
111 }
112 }
113 (IpAddr::V6(addr), true) => {
114 if addr.is_loopback() {
115 continue;
116 }
117 if ipv6_is_link_local(&addr) {
118 has_link_local = true;
119 } else {
120 has_site_local = true;
121 }
122 }
123 _ => continue,
124 }
125 }
126
127 if has_site_local {
128 joined |= try_join_v6(&self.sock, *self.ipv6_site_local.ip(), nic.index);
129 }
130
131 if let Some(ll) = self.ipv6_link_local {
132 if has_link_local {
133 joined |= try_join_v6(&self.sock, *ll.ip(), nic.index);
134 }
135 }
136 }
137
138 if !joined {
139 return Err(Error::MulticastJoinFail);
140 }
141
142 self.sock
143 .socket()
144 .writable()
145 .await
146 .map_err(Error::Writeable)?;
147
148 Ok(())
149 }
150
151 pub fn find_mcast_opts_for_replying_to(&self, addr: &SocketAddr) -> Option<MulticastOpts> {
152 self.nics()
153 .iter()
154 .flat_map(|nic| nic.addr.iter().map(move |addr| (nic, addr)))
155 .find_map(|(nic, naddr)| {
156 let nm = naddr.netmask();
157 let mcast_addr: SocketAddr = match (addr, naddr.ip(), nm, self.ipv6_link_local) {
158 (SocketAddr::V6(addr), _, _, Some(mlocal)) if ipv6_is_link_local(addr.ip()) => {
161 if nic.index != addr.scope_id() {
162 return None;
163 }
164 mlocal.with_scope_id(nic.index).into()
165 }
166
167 (SocketAddr::V6(addr), IpAddr::V6(naddr), Some(IpAddr::V6(mask)), _)
169 if ipv6_is_unique_local_address(addr.ip())
170 && addr.ip().to_bits() & mask.to_bits()
171 == naddr.to_bits() & mask.to_bits() =>
172 {
173 self.ipv6_site_local.into()
174 }
175
176 (SocketAddr::V4(addr), IpAddr::V4(naddr), Some(IpAddr::V4(mask)), _)
178 if addr.ip().to_bits() & mask.to_bits()
179 == naddr.to_bits() & mask.to_bits() =>
180 {
181 self.ipv4_addr.into()
182 }
183 _ => return None,
184 };
185 Some(MulticastOpts {
186 interface_id: nic.index,
187 interface_addr: naddr.ip(),
188 mcast_addr,
189 })
190 })
191 }
192
193 pub async fn send_multicast_msg(
194 &self,
195 buf: &[u8],
196 opts: &MulticastOpts,
197 ) -> crate::Result<usize> {
198 std::future::poll_fn(|cx| {
202 let sref = SockRef::from(self.sock.socket());
203 if self.sock.bind_addr().is_ipv6() {
204 if let Err(e) = sref.set_multicast_if_v6(opts.interface_id) {
205 return Poll::Ready(Err(Error::SetMulticastIpv6(e)));
206 }
207 } else {
208 let ifaddr = match opts.interface_addr {
209 IpAddr::V4(ipv4_addr) => ipv4_addr,
210 IpAddr::V6(_) => {
211 return Poll::Ready(Err(Error::SendMulticastMsgProtocolMismatch));
212 }
213 };
214 if let Err(e) = sref.set_multicast_if_v4(&ifaddr) {
215 return Poll::Ready(Err(Error::SetMulticastIpv4(e)));
216 }
217 };
218
219 self.sock
220 .poll_send_to(cx, buf, opts.mcast_addr)
221 .map_err(Error::Send)
222 })
223 .await
224 }
225
226 pub async fn try_send_mcast_everywhere(
227 &self,
228 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
229 ) {
230 let bind_is_ipv6 = self.sock.bind_addr().is_ipv6();
231
232 let mut send_specs = self
233 .nics
234 .iter()
235 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
236 .filter_map(|(ifidx, ifaddr)| {
237 let mcast_addr: SocketAddr = match (bind_is_ipv6, ifaddr, self.ipv6_link_local) {
238 (_, IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => {
239 self.ipv4_addr.into()
240 }
241 (true, IpAddr::V6(a), Some(mlocal))
242 if !a.is_loopback() && ipv6_is_link_local(&a) =>
243 {
244 mlocal.with_scope_id(ifidx).into()
245 }
246 (true, IpAddr::V6(a), _)
247 if !a.is_loopback() && ipv6_is_unique_local_address(&a) =>
248 {
249 self.ipv6_site_local.into()
250 }
251 _ => {
252 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
253 return None;
254 }
255 };
256 Some(MulticastOpts {
257 interface_id: ifidx,
258 interface_addr: ifaddr,
259 mcast_addr,
260 })
261 })
262 .collect::<Vec<_>>();
263
264 send_specs.sort_by_key(|s| s.uniq_key(bind_is_ipv6));
265 send_specs.dedup_by_key(|s| s.uniq_key(bind_is_ipv6));
266
267 let futs = send_specs.into_iter().filter_map(|opts| {
268 let payload = get_payload(&opts)?;
269 let fut = async move {
270 match self.send_multicast_msg(payload.as_bytes(), &opts).await {
271 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
272 Err(e) => {
273 debug!(?opts, payload=?payload, "error sending: {e:#}")
274 }
275 };
276 };
277 Some(fut)
278 });
279
280 futures::future::join_all(futs).await;
281 }
282}
283
284fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
285 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
286 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
287 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
288 return false;
289 }
290 true
291}
292
293fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
294 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
295 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
296 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
297 return false;
298 }
299 true
300}
301
/// Returns `true` if `ip` lies in `fe80::/64`, the strict form of a link-local unicast
/// address (first four 16-bit segments are `fe80:0:0:0`).
fn ipv6_is_link_local(ip: &Ipv6Addr) -> bool {
    let [head, s1, s2, s3, ..] = ip.segments();
    head == 0xfe80 && s1 == 0 && s2 == 0 && s3 == 0
}
308
/// Returns `true` if `ip` is a unique local address, i.e. lies in `fc00::/7`
/// (the top seven bits are `1111110`).
fn ipv6_is_unique_local_address(ip: &Ipv6Addr) -> bool {
    const ULA_PREFIX_MASK: u16 = 0xfe00;
    const ULA_PREFIX: u16 = 0xfc00;
    ip.segments()[0] & ULA_PREFIX_MASK == ULA_PREFIX
}
315
/// Returns `true` if `ip` is a multicast address (`ff00::/8`) with link-local scope.
///
/// In the multicast format (`ff` | 4-bit flags | 4-bit scope | 112-bit group ID), link-local
/// scope is the value `2`. The flags nibble and the group ID are deliberately ignored, so
/// `ff02::fb`, `ff12::1` and unicast-prefix-based groups such as `ff32:ff:2001:db8::1` all
/// qualify. A previous mask also required group-ID bits 16..64 to be zero, which wrongly
/// rejected such groups.
fn ipv6_is_link_local_mcast(ip: &Ipv6Addr) -> bool {
    // Keep the `ff` prefix and the scope nibble; drop the flags nibble.
    ip.segments()[0] & 0xff0f == 0xff02
}
322
/// Returns `true` if `ip` is a multicast address (`ff00::/8`) with site-local scope.
///
/// In the multicast format (`ff` | 4-bit flags | 4-bit scope | 112-bit group ID), site-local
/// scope is the value `5`. The flags nibble and the group ID are deliberately ignored, so
/// `ff05::1:3`, `ff15::1` and unicast-prefix-based groups such as `ff35:ff:2001:db8::1` all
/// qualify. A previous mask also required group-ID bits 16..64 to be zero, which wrongly
/// rejected such groups.
fn ipv6_is_site_local_mcast(ip: &Ipv6Addr) -> bool {
    // Keep the `ff` prefix and the scope nibble; drop the flags nibble.
    ip.segments()[0] & 0xff0f == 0xff05
}
329
/// Describes one multicast send: which interface to send from and where to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub struct MulticastOpts {
    /// Index of the outgoing network interface (selects the interface on IPv6 sockets).
    pub interface_id: u32,
    /// An address of the outgoing interface (selects the interface on IPv4 sockets).
    pub interface_addr: IpAddr,
    /// Multicast destination: group address and port.
    pub mcast_addr: SocketAddr,
}

impl MulticastOpts {
    /// Returns the outgoing interface's address.
    pub fn iface_ip(&self) -> IpAddr {
        let Self { interface_addr, .. } = *self;
        interface_addr
    }

    /// Returns the multicast destination.
    pub fn mcast_addr(&self) -> SocketAddr {
        let Self { mcast_addr, .. } = *self;
        mcast_addr
    }

    /// Key used to deduplicate sends.
    ///
    /// On an IPv6-bound socket, the interface is identified by its index. On an IPv4-bound
    /// socket, it is identified by its address. The destination is always part of the key.
    fn uniq_key(&self, bind_addr_is_ipv6: bool) -> (Option<u32>, Option<IpAddr>, SocketAddr) {
        let by_index = bind_addr_is_ipv6.then_some(self.interface_id);
        let by_addr = (!bind_addr_is_ipv6).then_some(self.interface_addr);
        (by_index, by_addr, self.mcast_addr)
    }
}
353}