1#[cfg(test)]
2mod tests;
3
4use std::{
5 collections::HashSet,
6 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7 sync::Mutex,
8 task::Poll,
9};
10
11use network_interface::{NetworkInterface, NetworkInterfaceConfig};
12use socket2::SockRef;
13use tracing::{debug, trace};
14
15use crate::{BindOpts, Error, UdpSocket};
16
17pub struct MulticastUdpSocket {
18 sock_v4: UdpSocket,
21 sock_v6: UdpSocket,
22 ipv4_addr: Ipv4Addr,
23 ipv6_site_local: Ipv6Addr,
24 ipv6_link_local: Option<Ipv6Addr>,
25 nics: Vec<NetworkInterface>,
26}
27
28impl MulticastUdpSocket {
29 pub fn new(
30 port: u16,
31 ipv4_addr: Ipv4Addr,
32 ipv6_site_local: Ipv6Addr,
33 ipv6_link_local: Option<Ipv6Addr>,
34 ) -> crate::Result<Self> {
35 if let Some(ll) = ipv6_link_local {
36 if !ipv6_is_link_local_mcast(ll) {
37 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
38 }
39 }
40 if !ipv6_is_site_local_mcast(ipv6_site_local) {
41 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
42 }
43 let nics = network_interface::NetworkInterface::show()
44 .into_iter()
45 .flatten()
46 .collect::<Vec<_>>();
47 if nics.is_empty() {
48 return Err(Error::NoNics);
49 }
50 let opts = BindOpts {
51 request_dualstack: false,
52 reuseport: true,
53 };
54 let sock_v4 = UdpSocket::bind_udp((Ipv4Addr::UNSPECIFIED, port).into(), opts)?;
55 let sock_v6 = UdpSocket::bind_udp((Ipv6Addr::UNSPECIFIED, port).into(), opts)?;
56 let sock = Self {
57 sock_v4,
58 sock_v6,
59 ipv4_addr,
60 ipv6_link_local,
61 ipv6_site_local,
62 nics,
63 };
64 sock.bind_multicast()?;
65 Ok(sock)
66 }
67
68 pub fn nics(&self) -> &[NetworkInterface] {
69 &self.nics
70 }
71
72 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
73 std::future::poll_fn(|cx| {
74 let mut buf = tokio::io::ReadBuf::new(buf);
75 if let Poll::Ready(res) = self.sock_v4.socket().poll_recv_from(cx, &mut buf) {
76 return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
77 }
78 if let Poll::Ready(res) = self.sock_v6.socket().poll_recv_from(cx, &mut buf) {
79 return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
80 }
81 Poll::Pending
82 })
83 .await
84 }
85
86 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
87 let sock = if addr.is_ipv6() {
88 &self.sock_v6
89 } else {
90 &self.sock_v4
91 };
92 sock.send_to(buf, addr).await
93 }
94
95 fn bind_multicast(&self) -> crate::Result<()> {
96 let mut joined = try_join_v4(&self.sock_v4, self.ipv4_addr, Ipv4Addr::UNSPECIFIED);
97
98 for nic in self.nics.iter() {
99 let mut has_link_local = false;
100 let mut has_site_local = false;
101
102 for addr in nic.addr.iter() {
103 match addr.ip() {
104 IpAddr::V4(iface_addr)
105 if iface_addr.is_private() && !iface_addr.is_loopback() =>
106 {
107 joined |= try_join_v4(&self.sock_v4, self.ipv4_addr, iface_addr);
108 }
109 IpAddr::V6(addr) => {
110 if addr.is_loopback() {
111 continue;
112 }
113 if ipv6_is_link_local(addr) {
114 has_link_local = true;
115 } else {
116 has_site_local = true;
117 }
118 }
119 _ => continue,
120 }
121 }
122
123 if has_site_local {
124 joined |= try_join_v6(&self.sock_v6, self.ipv6_site_local, nic.index);
125 }
126
127 if let Some(ll) = self.ipv6_link_local {
128 if has_link_local {
129 joined |= try_join_v6(&self.sock_v6, ll, nic.index);
130 }
131 }
132 }
133
134 if !joined {
135 return Err(Error::MulticastJoinFail);
136 }
137
138 Ok(())
139 }
140
141 async fn send_to_once(&self, buf: &[u8], opts: &MulticastOpts) -> std::io::Result<usize> {
142 std::future::poll_fn(|cx| {
146 let sock;
147 let mcast_addr_s: SocketAddr;
148
149 match opts {
150 MulticastOpts::V4 {
151 interface_addr,
152 mcast_addr,
153 } => {
154 sock = &self.sock_v4;
155 mcast_addr_s = (*mcast_addr).into();
156 if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v4(interface_addr)
157 {
158 debug!(addr=%interface_addr, "error calling set_multicast_if_v4: {e:#}");
159 return Poll::Ready(Err(e));
160 }
161 }
162 MulticastOpts::V6 {
163 interface_id,
164 mcast_addr,
165 ..
166 } => {
167 sock = &self.sock_v6;
168 mcast_addr_s = (*mcast_addr).into();
169 if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v6(*interface_id)
170 {
171 debug!(
172 oif_id = interface_id,
173 "error calling set_multicast_if_v6: {e:#}"
174 );
175 return Poll::Ready(Err(e));
176 }
177 }
178 }
179
180 sock.poll_send_to(cx, buf, mcast_addr_s)
181 })
182 .await
183 }
184
185 pub async fn try_send_mcast_everywhere(
186 &self,
187 get_payload: &impl Fn(&MulticastOpts) -> bstr::BString,
188 ) {
189 let _ = self.sock_v6.socket().writable().await;
195
196 let sent = Mutex::new(HashSet::new());
197 let sent = &sent;
198
199 let port = self.sock_v4.bind_addr().port();
200
201 let futs = self
202 .nics
203 .iter()
204 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
205 .filter_map(|(ifidx, ifaddr)| {
206 let ipv6_link_local = self
207 .ipv6_link_local
208 .filter(|_| matches!(ifaddr, IpAddr::V6(v6) if ipv6_is_link_local(v6)));
209 let opts = match (ifaddr, ipv6_link_local) {
210 (IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => MulticastOpts::V4 {
211 interface_addr: a,
212 mcast_addr: SocketAddrV4::new(self.ipv4_addr, port),
213 },
214 (IpAddr::V6(a), Some(mlocal)) if !a.is_loopback() => MulticastOpts::V6 {
215 interface_id: ifidx,
216 interface_addr: a,
217 mcast_addr: SocketAddrV6::new(mlocal, port, 0, ifidx),
218 },
219 (IpAddr::V6(a), None) if !a.is_loopback() => MulticastOpts::V6 {
220 interface_id: ifidx,
221 interface_addr: a,
222 mcast_addr: SocketAddrV6::new(self.ipv6_site_local, port, 0, ifidx),
223 },
224 _ => {
225 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
226 return None;
227 }
228 };
229 Some(opts)
230 })
231 .map(|opts| async move {
232 let payload = get_payload(&opts);
233 if !sent
234 .lock()
235 .unwrap()
236 .insert((payload.clone(), opts.uniq_key()))
237 {
238 trace!(?opts, "not sending duplicate payload");
239 return;
240 }
241
242 match self.send_to_once(payload.as_slice(), &opts).await {
243 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
244 Err(e) => {
245 debug!(?opts, payload=?payload, "error sending: {e:#}")
246 }
247 }
248 });
249
250 futures::future::join_all(futs).await;
251 }
252}
253
254fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
255 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
256 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
257 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
258 return false;
259 }
260 true
261}
262
263fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
264 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
265 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
266 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
267 return false;
268 }
269 true
270}
271
/// Returns `true` when `ip` is a unicast link-local address of the form
/// `fe80::/64`, i.e. the upper 64 bits are exactly `fe80:0:0:0`.
///
/// This is deliberately stricter than a bare `fe80::/10` prefix test: any
/// non-zero bit between the prefix and the interface identifier is rejected.
fn ipv6_is_link_local(ip: Ipv6Addr) -> bool {
    matches!(ip.segments(), [0xfe80, 0, 0, 0, ..])
}
278
/// Returns `true` when `ip` is an IPv6 multicast address with link-local
/// scope (`ffX2::`), ignoring the flags nibble `X`.
///
/// Segments 1 to 3 must be zero as well; only the lower 64 bits are free.
fn ipv6_is_link_local_mcast(ip: Ipv6Addr) -> bool {
    // 0xff0f keeps the `ff` multicast marker and the scope nibble, dropping the flags.
    matches!(ip.segments(), [head, 0, 0, 0, ..] if head & 0xff0f == 0xff02)
}
285
/// Returns `true` when `ip` is an IPv6 multicast address with site-local
/// scope (`ffX5::`), ignoring the flags nibble `X`.
///
/// Segments 1 to 3 must be zero as well; only the lower 64 bits are free.
fn ipv6_is_site_local_mcast(ip: Ipv6Addr) -> bool {
    // 0xff0f keeps the `ff` multicast marker and the scope nibble, dropping the flags.
    matches!(ip.segments(), [head, 0, 0, 0, ..] if head & 0xff0f == 0xff05)
}
292
/// Describes a single multicast send: which local interface to send from and
/// which multicast group and port to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub enum MulticastOpts {
    /// IPv4 send; the outgoing interface is selected by its address.
    V4 {
        /// Address of the local interface to send from.
        interface_addr: Ipv4Addr,
        /// Destination multicast group and port.
        mcast_addr: SocketAddrV4,
    },
    /// IPv6 send; the outgoing interface is selected by its index.
    V6 {
        /// Index of the local interface to send from.
        interface_id: u32,
        /// One of the addresses assigned to that interface.
        interface_addr: Ipv6Addr,
        /// Destination multicast group and port.
        mcast_addr: SocketAddrV6,
    },
}

impl MulticastOpts {
    /// Returns the local interface address carried by these options.
    pub fn iface_ip(&self) -> IpAddr {
        match *self {
            Self::V4 { interface_addr, .. } => IpAddr::V4(interface_addr),
            Self::V6 { interface_addr, .. } => IpAddr::V6(interface_addr),
        }
    }

    /// Returns the destination multicast socket address.
    pub fn mcast_addr(&self) -> SocketAddr {
        match *self {
            Self::V4 { mcast_addr, .. } => SocketAddr::V4(mcast_addr),
            Self::V6 { mcast_addr, .. } => SocketAddr::V6(mcast_addr),
        }
    }

    /// Builds the key used to detect duplicate sends.
    ///
    /// IPv4 is keyed by interface address and destination. IPv6 is keyed by
    /// interface index and destination only, so several addresses on the same
    /// interface produce the same key.
    fn uniq_key(&self) -> (Option<u32>, Option<Ipv4Addr>, SocketAddr) {
        match *self {
            Self::V4 {
                interface_addr,
                mcast_addr,
            } => (None, Some(interface_addr), SocketAddr::V4(mcast_addr)),
            Self::V6 {
                interface_id,
                mcast_addr,
                ..
            } => (Some(interface_id), None, SocketAddr::V6(mcast_addr)),
        }
    }
}