1#[cfg(test)]
2mod tests;
3
4use std::{
5 collections::HashSet,
6 net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
7 sync::{Arc, Mutex},
8 task::Poll,
9};
10
11use network_interface::{NetworkInterface, NetworkInterfaceConfig};
12use parking_lot::RwLock;
13use socket2::SockRef;
14use tracing::{debug, trace};
15
16use crate::{BindOpts, Error, UdpSocket};
17
18pub struct MulticastUdpSocket {
21 sock_v4: UdpSocket,
24 sock_v6: UdpSocket,
25 ipv4_addr: Ipv4Addr,
26 ipv6_site_local: Ipv6Addr,
27 ipv6_link_local: Option<Ipv6Addr>,
28 nics: Vec<NetworkInterface>,
29}
30
31impl MulticastUdpSocket {
32 pub fn new(
33 port: u16,
34 ipv4_addr: Ipv4Addr,
35 ipv6_site_local: Ipv6Addr,
36 ipv6_link_local: Option<Ipv6Addr>,
37 ) -> crate::Result<Self> {
38 if let Some(ll) = ipv6_link_local {
39 if !ipv6_is_link_local_mcast(ll) {
40 return Err(Error::ProvidedLinkLocalAddrIsntLinkLocal);
41 }
42 }
43 if !ipv6_is_site_local_mcast(ipv6_site_local) {
44 return Err(Error::ProvidedSiteLocalAddrIsNotSiteLocal);
45 }
46 let nics = network_interface::NetworkInterface::show()
47 .into_iter()
48 .flatten()
49 .collect::<Vec<_>>();
50 if nics.is_empty() {
51 return Err(Error::NoNics);
52 }
53 let opts = BindOpts {
54 request_dualstack: false,
55 reuseport: true,
56 };
57 let sock_v4 = UdpSocket::bind_udp((Ipv4Addr::UNSPECIFIED, port).into(), opts)?;
58 let sock_v6 = UdpSocket::bind_udp((Ipv6Addr::UNSPECIFIED, port).into(), opts)?;
59 let sock = Self {
60 sock_v4,
61 sock_v6,
62 ipv4_addr,
63 ipv6_link_local,
64 ipv6_site_local,
65 nics,
66 };
67 sock.bind_multicast()?;
68 Ok(sock)
69 }
70
71 pub fn nics(&self) -> &[NetworkInterface] {
72 &self.nics
73 }
74
75 pub async fn recv_from(&self, buf: &mut [u8]) -> std::io::Result<(usize, SocketAddr)> {
76 std::future::poll_fn(|cx| {
77 let mut buf = tokio::io::ReadBuf::new(buf);
78 if let Poll::Ready(res) = self.sock_v4.socket().poll_recv_from(cx, &mut buf) {
79 return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
80 }
81 if let Poll::Ready(res) = self.sock_v6.socket().poll_recv_from(cx, &mut buf) {
82 return Poll::Ready(res.map(|addr| (buf.filled().len(), addr)));
83 }
84 Poll::Pending
85 })
86 .await
87 }
88
89 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
90 let sock = if addr.is_ipv6() {
91 &self.sock_v6
92 } else {
93 &self.sock_v4
94 };
95 sock.send_to(buf, addr).await
96 }
97
98 fn bind_multicast(&self) -> crate::Result<()> {
99 let mut joined = try_join_v4(&self.sock_v4, self.ipv4_addr, Ipv4Addr::UNSPECIFIED);
100
101 for nic in self.nics.iter() {
102 let mut has_link_local = false;
103 let mut has_site_local = false;
104
105 for addr in nic.addr.iter() {
106 match addr.ip() {
107 IpAddr::V4(iface_addr)
108 if iface_addr.is_private() && !iface_addr.is_loopback() =>
109 {
110 joined |= try_join_v4(&self.sock_v4, self.ipv4_addr, iface_addr);
111 }
112 IpAddr::V6(addr) => {
113 if addr.is_loopback() {
114 continue;
115 }
116 if ipv6_is_link_local(addr) {
117 has_link_local = true;
118 } else {
119 has_site_local = true;
120 }
121 }
122 _ => continue,
123 }
124 }
125
126 if has_site_local {
127 joined |= try_join_v6(&self.sock_v6, self.ipv6_site_local, nic.index);
128 }
129
130 if let Some(ll) = self.ipv6_link_local {
131 if has_link_local {
132 joined |= try_join_v6(&self.sock_v6, ll, nic.index);
133 }
134 }
135 }
136
137 if !joined {
138 return Err(Error::MulticastJoinFail);
139 }
140
141 Ok(())
142 }
143
144 async fn send_to_once(&self, buf: &[u8], opts: &MulticastOpts) -> std::io::Result<usize> {
145 std::future::poll_fn(|cx| {
149 let sock;
150 let mcast_addr_s: SocketAddr;
151
152 match opts {
153 MulticastOpts::V4 {
154 interface_addr,
155 mcast_addr,
156 } => {
157 sock = &self.sock_v4;
158 mcast_addr_s = (*mcast_addr).into();
159 if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v4(interface_addr)
160 {
161 debug!(addr=%interface_addr, "error calling set_multicast_if_v4: {e:#}");
162 return Poll::Ready(Err(e));
163 }
164 }
165 MulticastOpts::V6 {
166 interface_id,
167 mcast_addr,
168 ..
169 } => {
170 sock = &self.sock_v6;
171 mcast_addr_s = (*mcast_addr).into();
172 if let Err(e) = SockRef::from(sock.socket()).set_multicast_if_v6(*interface_id)
173 {
174 debug!(
175 oif_id = interface_id,
176 "error calling set_multicast_if_v6: {e:#}"
177 );
178 return Poll::Ready(Err(e));
179 }
180 }
181 }
182
183 sock.poll_send_to(cx, buf, mcast_addr_s)
184 })
185 .await
186 }
187
188 pub async fn try_send_mcast_everywhere(
189 &self,
190 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
191 ) {
192 let _ = self.sock_v6.socket().writable().await;
198
199 let sent = Mutex::new(HashSet::new());
200 let sent = &sent;
201
202 let port = self.sock_v4.bind_addr().port();
203
204 let futs = self
205 .nics
206 .iter()
207 .flat_map(|ni| ni.addr.iter().map(move |a| (ni.index, a.ip())))
208 .filter_map(|(ifidx, ifaddr)| {
209 let ipv6_link_local = self
210 .ipv6_link_local
211 .filter(|_| matches!(ifaddr, IpAddr::V6(v6) if ipv6_is_link_local(v6)));
212 let opts = match (ifaddr, ipv6_link_local) {
213 (IpAddr::V4(a), _) if !a.is_loopback() && a.is_private() => MulticastOpts::V4 {
214 interface_addr: a,
215 mcast_addr: SocketAddrV4::new(self.ipv4_addr, port),
216 },
217 (IpAddr::V6(a), Some(mlocal)) if !a.is_loopback() => MulticastOpts::V6 {
218 interface_id: ifidx,
219 interface_addr: a,
220 mcast_addr: SocketAddrV6::new(mlocal, port, 0, ifidx),
221 },
222 (IpAddr::V6(a), None) if !a.is_loopback() => MulticastOpts::V6 {
223 interface_id: ifidx,
224 interface_addr: a,
225 mcast_addr: SocketAddrV6::new(self.ipv6_site_local, port, 0, ifidx),
226 },
227 _ => {
228 trace!(oif_id=ifidx, addr=%ifaddr, "ignoring address");
229 return None;
230 }
231 };
232 Some(opts)
233 })
234 .filter_map(|opts| {
235 let payload = get_payload(&opts)?;
236 let fut = async move {
237 if !sent
238 .lock()
239 .unwrap()
240 .insert((payload.clone(), opts.uniq_key()))
241 {
242 trace!(?opts, "not sending duplicate payload");
243 return;
244 }
245
246 match self.send_to_once(payload.as_bytes(), &opts).await {
247 Ok(sz) => trace!(?opts, size=sz, payload=?payload, "sent"),
248 Err(e) => {
249 debug!(?opts, payload=?payload, "error sending: {e:#}")
250 }
251 };
252 };
253 Some(fut)
254 });
255
256 futures::future::join_all(futs).await;
257 }
258}
259
260fn try_join_v4(sock: &UdpSocket, addr: Ipv4Addr, iface: Ipv4Addr) -> bool {
261 trace!(multiaddr=?addr, interface=?iface, "joining multicast v4 group");
262 if let Err(e) = sock.socket().join_multicast_v4(addr, iface) {
263 debug!(multiaddr=?addr, interface=?iface, "error joining multicast v4 group: {e:#}");
264 return false;
265 }
266 true
267}
268
269fn try_join_v6(sock: &UdpSocket, addr: Ipv6Addr, ifindex: u32) -> bool {
270 trace!(multiaddr=?addr, interface=?ifindex, "joining multicast v6 group");
271 if let Err(e) = sock.socket().join_multicast_v6(&addr, ifindex) {
272 debug!(multiaddr=?addr, interface=?ifindex, "error joining multicast v6 group: {e:#}");
273 return false;
274 }
275 true
276}
277
/// Returns `true` if `ip` is a link-local unicast address in the strict
/// `fe80::/64` form, i.e. the first 64 bits are exactly `fe80:0:0:0`.
fn ipv6_is_link_local(ip: Ipv6Addr) -> bool {
    matches!(ip.segments(), [0xfe80, 0, 0, 0, _, _, _, _])
}
284
/// Returns `true` if `ip` is a multicast address with link-local scope
/// (`ffX2::`, flags nibble ignored) whose bits 16..64 are all zero.
fn ipv6_is_link_local_mcast(ip: Ipv6Addr) -> bool {
    let bytes = ip.octets();
    let is_multicast = bytes[0] == 0xff;
    let link_local_scope = (bytes[1] & 0x0f) == 0x02;
    let zero_upper_bits = bytes[2..8].iter().all(|&b| b == 0);
    is_multicast && link_local_scope && zero_upper_bits
}
291
/// Returns `true` if `ip` is a multicast address with site-local scope
/// (`ffX5::`, flags nibble ignored) whose bits 16..64 are all zero.
fn ipv6_is_site_local_mcast(ip: Ipv6Addr) -> bool {
    match ip.segments() {
        [head, 0, 0, 0, ..] => (head & 0xff0f) == 0xff05,
        _ => false,
    }
}
298
/// Describes one multicast send: which local interface to send from and
/// which group address/port to send to.
#[derive(Debug, Hash, Clone, Copy, PartialEq, Eq)]
pub enum MulticastOpts {
    /// Send over IPv4.
    V4 {
        /// Local interface address, used as the outgoing multicast interface.
        interface_addr: Ipv4Addr,
        /// Destination multicast group and port.
        mcast_addr: SocketAddrV4,
    },
    /// Send over IPv6.
    V6 {
        /// Interface index, used as the outgoing multicast interface.
        interface_id: u32,
        /// Local interface address this send was derived from.
        interface_addr: Ipv6Addr,
        /// Destination multicast group, port and scope id.
        mcast_addr: SocketAddrV6,
    },
}
311
312impl MulticastOpts {
313 pub fn iface_ip(&self) -> IpAddr {
314 match self {
315 MulticastOpts::V4 { interface_addr, .. } => (*interface_addr).into(),
316 MulticastOpts::V6 { interface_addr, .. } => (*interface_addr).into(),
317 }
318 }
319
320 pub fn mcast_addr(&self) -> SocketAddr {
321 match self {
322 MulticastOpts::V4 { mcast_addr, .. } => (*mcast_addr).into(),
323 MulticastOpts::V6 { mcast_addr, .. } => (*mcast_addr).into(),
324 }
325 }
326
327 fn uniq_key(&self) -> (Option<u32>, Option<Ipv4Addr>, SocketAddr) {
328 match self {
329 MulticastOpts::V4 {
330 interface_addr,
331 mcast_addr,
332 } => (None, Some(*interface_addr), (*mcast_addr).into()),
333 MulticastOpts::V6 {
334 interface_id,
335 mcast_addr,
336 ..
337 } => (Some(*interface_id), None, (*mcast_addr).into()),
338 }
339 }
340}
341
/// Callback invoked with each received datagram's bytes and its sender.
pub type HandlerFn = dyn Fn(&[u8], SocketAddr) + Sync + Send + 'static;
/// Owned, type-erased [`HandlerFn`] as stored by [`SharedMulticastUdpSocket`].
pub type Handler = Box<HandlerFn>;
344
345pub struct SharedMulticastUdpSocket {
348 sock: MulticastUdpSocket,
349 handlers: RwLock<Vec<Handler>>,
350}
351
352impl SharedMulticastUdpSocket {
353 pub fn new(sock: MulticastUdpSocket) -> crate::Result<Arc<Self>> {
354 let sock = Arc::new(Self {
355 sock,
356 handlers: Default::default(),
357 });
358 Ok(sock)
359 }
360
361 pub fn add_handler(self: &Arc<Self>, handler: Handler) {
362 self.handlers.write().push(handler);
363 }
364
365 pub async fn send_to(&self, buf: &[u8], addr: SocketAddr) -> std::io::Result<usize> {
366 self.sock.send_to(buf, addr).await
367 }
368
369 pub async fn task_listen_forever(self: Arc<Self>) -> std::io::Result<()> {
370 let mut buf = [0u8; 4096];
371 loop {
372 let (sz, addr) = self.sock.recv_from(&mut buf).await?;
373 for handler in self.handlers.read().iter() {
374 handler(&buf[..sz], addr);
375 }
376 }
377 }
378
379 pub async fn try_send_mcast_everywhere(
380 &self,
381 get_payload: &impl Fn(&MulticastOpts) -> Option<String>,
382 ) {
383 self.sock.try_send_mcast_everywhere(get_payload).await
384 }
385}