Skip to main content

tokio_multicast/
lib.rs

1mod builder;
2mod config;
3mod diagnostics;
4mod error;
5mod interface;
6mod membership;
7mod packet;
8mod raw;
9mod socket;
10mod stream;
11mod sys;
12
13pub use builder::MulticastSocketBuilder;
14pub use config::MulticastConfig;
15pub use diagnostics::{
16    diagnose_multicast, diagnose_multicast_with_config, MulticastDiagnosticConfig,
17    MulticastDiagnostics, ProbeErrorKind, ProbeResult, ProbeStages,
18};
19pub use error::{MulticastError, Result};
20pub use interface::{Interface, InterfaceId};
21pub use membership::Membership;
22pub use packet::{Datagram, RecvMeta};
23pub use socket::MulticastSocket;
24pub use stream::MulticastReceiver;
25
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket as StdUdpSocket};

    use socket2::{Domain, Protocol, Socket, Type};
    use tokio::time::{timeout, Duration};

    use crate::{
        diagnose_multicast, diagnose_multicast_with_config, Interface, Membership,
        MulticastDiagnosticConfig, MulticastError, MulticastReceiver, MulticastSocket,
    };

    /// Asks the OS for an ephemeral UDP port on IPv4 loopback and returns its number.
    ///
    /// The probe socket is dropped before returning, so the port is only "probably free";
    /// another process could take it before the test binds. That race is accepted here.
    fn free_port() -> u16 {
        StdUdpSocket::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .and_then(|probe| probe.local_addr())
            .expect("binding an ephemeral IPv4 loopback UDP port should succeed")
            .port()
    }

    /// IPv6 counterpart of [`free_port`]; returns `None` when IPv6 loopback cannot be bound
    /// (e.g. IPv6 disabled), so callers can skip instead of failing.
    fn free_port_v6() -> Option<u16> {
        StdUdpSocket::bind(SocketAddr::from((Ipv6Addr::LOCALHOST, 0)))
            .and_then(|probe| probe.local_addr())
            .ok()
            .map(|addr| addr.port())
    }

    /// Interface index of the IPv6 loopback interface, as reported by `crate::sys`.
    ///
    /// One definition covers every target: the per-OS variants this replaces all delegated
    /// to the same `crate::sys::loopback_interface_v6` helper.
    fn loopback_ifindex_v6() -> Option<u32> {
        crate::sys::loopback_interface_v6()
    }

    /// Build errors that the IPv6 tests treat as "environment does not support this" and
    /// therefore skip on, rather than fail.
    fn is_ipv6_unsupported(err: &MulticastError) -> bool {
        matches!(
            err,
            MulticastError::UnsupportedOption(_)
                | MulticastError::Io(_)
                | MulticastError::BindFailed { .. }
        )
    }

    #[cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))]
    #[test]
    fn reuse_port_socket_option_smoke_test() {
        let socket = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap();

        crate::sys::set_reuse_port(&socket, true).unwrap();
        crate::sys::set_reuse_port(&socket, false).unwrap();
    }

    #[tokio::test(flavor = "current_thread")]
    async fn multicast_round_trip_ipv4() {
        let port = free_port();
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, port)))
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)))
            .build()
            .await
            .unwrap();

        socket.send_to_group(b"ping").await.unwrap();

        let mut buf = [0_u8; 64];
        let (size, _) = timeout(Duration::from_secs(2), socket.recv_from(&mut buf))
            .await
            .unwrap()
            .unwrap();

        assert_eq!(&buf[..size], b"ping");
        assert!(socket.memberships().contains(&Membership::any_source(IpAddr::V4(
            Ipv4Addr::new(224, 0, 0, 251),
        ))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn recv_datagram_includes_basic_metadata() {
        let port = free_port();
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, port)))
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 252)))
            .build()
            .await
            .unwrap();

        socket.send_to_group(b"meta").await.unwrap();
        let datagram = timeout(Duration::from_secs(2), socket.recv_datagram(32))
            .await
            .unwrap()
            .unwrap();

        assert_eq!(datagram.payload.as_ref(), b"meta");
        assert_eq!(datagram.meta.group, Some(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 252))));
        assert!(datagram.meta.local_addr.is_some());
        assert!(datagram.meta.timestamp.is_some());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn builder_requires_bind_port() {
        let err = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, 0)))
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)))
            .build()
            .await
            .unwrap_err();

        assert!(matches!(err, MulticastError::BindAddressRequired));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn builder_requires_membership() {
        let err = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .build()
            .await
            .unwrap_err();

        assert!(matches!(err, MulticastError::NoMembershipsConfigured));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn builder_rejects_non_multicast_group() {
        let err = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .join(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)))
            .build()
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            MulticastError::InvalidGroupAddress(IpAddr::V4(addr)) if addr == Ipv4Addr::new(127, 0, 0, 1)
        ));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn duplicate_join_does_not_duplicate_membership_state() {
        let group = IpAddr::V4(Ipv4Addr::new(224, 0, 0, 253));
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .join(group)
            .build()
            .await
            .unwrap();

        socket.join(Membership::any_source(group)).await.unwrap();

        let memberships = socket.memberships();
        assert_eq!(memberships.len(), 1);
        assert!(memberships.contains(&Membership::any_source(group)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn leave_absent_membership_is_noop() {
        let joined = IpAddr::V4(Ipv4Addr::new(224, 0, 0, 254));
        let absent = Membership::any_source(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 200)));
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .join(joined)
            .build()
            .await
            .unwrap();

        socket.leave(&absent).await.unwrap();

        let memberships = socket.memberships();
        assert_eq!(memberships.len(), 1);
        assert!(memberships.contains(&Membership::any_source(joined)));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn dynamic_source_specific_join_is_rejected() {
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)))
            .build()
            .await
            .unwrap();

        let err = socket
            .join(Membership::source_specific(
                IpAddr::V4(Ipv4Addr::new(232, 1, 1, 1)),
                IpAddr::V4(Ipv4Addr::new(192, 168, 1, 10)),
            ))
            .await
            .unwrap_err();

        assert!(matches!(
            err,
            MulticastError::UnsupportedOption("dynamic source-specific membership")
        ));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn leave_existing_membership_removes_it_from_state() {
        let group = IpAddr::V4(Ipv4Addr::new(224, 0, 0, 155));
        let membership = Membership::any_source(group);
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .join(group)
            .build()
            .await
            .unwrap();

        socket.leave(&membership).await.unwrap();

        let memberships = socket.memberships();
        assert!(!memberships.contains(&membership));
        assert!(memberships.is_empty());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn send_to_unicast_target_works() {
        let receiver = tokio::net::UdpSocket::bind(SocketAddr::from((Ipv4Addr::LOCALHOST, 0)))
            .await
            .unwrap();
        let target = receiver.local_addr().unwrap();

        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port())))
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 156)))
            .build()
            .await
            .unwrap();

        socket.send_to(b"direct", target).await.unwrap();

        let mut buf = [0_u8; 64];
        let (size, from) = timeout(Duration::from_secs(2), receiver.recv_from(&mut buf))
            .await
            .unwrap()
            .unwrap();

        assert_eq!(&buf[..size], b"direct");
        assert_eq!(from.ip(), IpAddr::V4(Ipv4Addr::LOCALHOST));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn receiver_wrapper_reads_datagram() {
        let port = free_port();
        let socket = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv4Addr::UNSPECIFIED, port)))
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 157)))
            .build()
            .await
            .unwrap();

        let receiver = MulticastReceiver::new(&socket, 64);
        socket.send_to_group(b"stream").await.unwrap();

        let datagram = timeout(Duration::from_secs(2), receiver.recv())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(datagram.payload.as_ref(), b"stream");
        assert_eq!(datagram.meta.group, Some(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 157))));
    }

    #[tokio::test(flavor = "current_thread")]
    async fn config_and_local_addr_reflect_builder_values() {
        let bind_addr = SocketAddr::from((Ipv4Addr::UNSPECIFIED, free_port()));
        let socket = MulticastSocket::builder()
            .bind(bind_addr)
            .join(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 158)))
            .reuse_addr(false)
            .reuse_port(false)
            .loopback(false)
            .ttl(16)
            .build()
            .await
            .unwrap();

        let config = socket.config();
        assert_eq!(config.bind_addr, bind_addr);
        assert!(!config.reuse_addr);
        assert!(!config.reuse_port);
        assert!(!config.loopback);
        assert_eq!(config.ttl, Some(16));
        assert_eq!(socket.local_addr().unwrap().port(), bind_addr.port());
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ipv6_builder_path_smoke_test_if_supported() {
        let Some(port) = free_port_v6() else {
            return;
        };
        let Some(ifindex) = loopback_ifindex_v6() else {
            return;
        };
        let group: Ipv6Addr = "ff02::114".parse().unwrap();
        let bind_addr = SocketAddr::from((Ipv6Addr::UNSPECIFIED, port));

        let result = MulticastSocket::builder()
            .bind(bind_addr)
            .join(IpAddr::V6(group))
            .inbound_interface(Interface::V6(ifindex))
            .outbound_interface(Interface::V6(ifindex))
            .build()
            .await;

        match result {
            Ok(socket) => {
                assert_eq!(socket.config().bind_addr, bind_addr);
                assert!(socket
                    .memberships()
                    .contains(&Membership::any_source(IpAddr::V6(group))));
            }
            // Some environments expose IPv6 sockets but still refuse the bind or the
            // link-local multicast join, even when the loopback interface index is given
            // explicitly. Treat that as "unsupported" and skip.
            Err(err) if is_ipv6_unsupported(&err) => {}
            Err(err) => panic!("unexpected IPv6 build result: {err}"),
        }
    }

    #[tokio::test(flavor = "current_thread")]
    async fn ipv6_round_trip_with_interface_index_if_supported() {
        let Some(port) = free_port_v6() else {
            return;
        };
        let Some(ifindex) = loopback_ifindex_v6() else {
            return;
        };
        let group: Ipv6Addr = "ff02::114".parse().unwrap();

        let result = MulticastSocket::builder()
            .bind(SocketAddr::from((Ipv6Addr::UNSPECIFIED, port)))
            .join(IpAddr::V6(group))
            .inbound_interface(Interface::V6(ifindex))
            .outbound_interface(Interface::V6(ifindex))
            .build()
            .await;

        // Skip (rather than fail) on environments that cannot set up IPv6 multicast.
        let socket = match result {
            Ok(socket) => socket,
            Err(err) if is_ipv6_unsupported(&err) => return,
            Err(err) => panic!("unexpected IPv6 round-trip build result: {err}"),
        };

        socket.send_to_group(b"ipv6").await.unwrap();

        let mut buf = [0_u8; 64];
        let (size, _) = timeout(Duration::from_secs(2), socket.recv_from(&mut buf))
            .await
            .unwrap()
            .unwrap();

        assert_eq!(&buf[..size], b"ipv6");
    }

    #[test]
    fn diagnostics_returns_structured_report() {
        let report = diagnose_multicast();

        assert!(
            !report.ipv4.label.is_empty() && !report.ipv6.label.is_empty(),
            "probe labels should be populated"
        );
        assert!(
            report.ipv4.details.is_some() || report.ipv4.error.is_some(),
            "ipv4 probe should explain its outcome"
        );
    }

    #[test]
    fn diagnostics_accepts_custom_config() {
        let config = MulticastDiagnosticConfig {
            timeout: Duration::from_millis(200),
            ..MulticastDiagnosticConfig::default()
        };
        let report = diagnose_multicast_with_config(&config);

        assert!(
            report.ipv4.details.as_deref().unwrap_or_default().contains("timeout")
                || report.ipv4.error.is_some(),
            "ipv4 probe should either mention the configured timeout or report an error"
        );
    }
}