use std::io;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use socket2::{Domain, Protocol, Socket, Type};
use tokio::net::UdpSocket;
/// IPv4 multicast group (`224.0.0.128`) that [`bind_loopback_mcast`] joins.
///
/// NOTE(review): the address falls inside `224.0.0.0/24`, which RFC 5771 designates as the
/// link-local "Local Network Control Block" — confirm this choice is intentional.
pub const ORIGIN_TAG_MCAST_GROUP: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 128);
/// Builds a non-blocking IPv4 UDP socket bound to `0.0.0.0:port`, joins
/// [`ORIGIN_TAG_MCAST_GROUP`] through the loopback interface, and hands it to tokio.
///
/// Passing `port == 0` lets the OS choose an ephemeral port.
///
/// Configuration applied, in order:
/// - `SO_REUSEADDR` on non-Windows targets and `SO_REUSEPORT` on Unix targets, set before
///   binding so that several sockets can share the same port;
/// - on Linux only, a best-effort attempt to turn `IP_MULTICAST_ALL` off (a failure here is
///   deliberately ignored);
/// - non-blocking mode, which tokio requires of sockets it adopts;
/// - group membership, multicast loopback enabled, multicast TTL of 1, and the outgoing
///   multicast interface set to `127.0.0.1`.
///
/// # Errors
///
/// Returns the first [`io::Error`] reported while creating, configuring, or binding the
/// socket, or while registering it with tokio.
///
/// # Panics
///
/// Per tokio's documentation, [`UdpSocket::from_std`] panics when called outside a runtime
/// that has I/O enabled, so this function must run inside such a runtime.
pub fn bind_loopback_mcast(port: u16) -> io::Result<UdpSocket> {
    let raw = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;

    // Reuse options only take effect if they are in place before `bind`.
    #[cfg(not(windows))]
    raw.set_reuse_address(true)?;
    #[cfg(unix)]
    raw.set_reuse_port(true)?;

    #[cfg(target_os = "linux")]
    {
        // Best effort: the result is intentionally discarded.
        let _ = raw.set_multicast_all_v4(false);
    }

    raw.set_nonblocking(true)?;

    // Bind to the wildcard address on the requested port.
    let wildcard = SocketAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port));
    raw.bind(&wildcard.into())?;

    // Membership and outgoing multicast traffic are both pinned to the loopback interface.
    let loopback = Ipv4Addr::LOCALHOST;
    raw.join_multicast_v4(&ORIGIN_TAG_MCAST_GROUP, &loopback)?;
    raw.set_multicast_loop_v4(true)?;
    raw.set_multicast_ttl_v4(1)?;
    raw.set_multicast_if_v4(&loopback)?;

    UdpSocket::from_std(std::net::UdpSocket::from(raw))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Upper bound on how long any test waits for a datagram.
    const RECV_DEADLINE: Duration = Duration::from_secs(2);

    /// Destination address for the multicast group on the given port.
    fn group_dest(port: u16) -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(ORIGIN_TAG_MCAST_GROUP, port))
    }

    /// Waits up to [`RECV_DEADLINE`] for one datagram on `sock`, storing it in `buf`.
    ///
    /// Returns the number of bytes received. Panics with `on_timeout` if the deadline passes
    /// and with `on_error` if the receive itself fails.
    async fn recv_len(sock: &UdpSocket, buf: &mut [u8], on_timeout: &str, on_error: &str) -> usize {
        let (len, _peer) = tokio::time::timeout(RECV_DEADLINE, sock.recv_from(buf))
            .await
            .expect(on_timeout)
            .expect(on_error);
        len
    }

    /// Binding with port 0 yields a wildcard IPv4 address with an OS-assigned port.
    #[tokio::test]
    async fn bind_succeeds_and_reports_wildcard_local_addr() {
        let sock = bind_loopback_mcast(0).expect("bind");
        let local = sock.local_addr().expect("local_addr");
        let v4 = match local {
            SocketAddr::V4(v4) => v4,
            SocketAddr::V6(_) => panic!("expected V4 local addr, got {local}"),
        };
        assert!(v4.ip().is_unspecified(), "expected 0.0.0.0, got {v4}");
        assert!(v4.port() != 0, "ephemeral port must be assigned");
    }

    /// A datagram sent to the group is looped back to the sending socket itself.
    #[tokio::test]
    async fn self_loopback_send_is_received() {
        let sock = bind_loopback_mcast(0).expect("bind");
        let dest = group_dest(sock.local_addr().unwrap().port());

        sock.send_to(b"origin-tag-loop", dest).await.expect("send");

        let mut buf = [0u8; 64];
        let n = recv_len(&sock, &mut buf, "recv timeout", "recv ok").await;
        assert_eq!(&buf[..n], b"origin-tag-loop");
    }

    /// Two sockets sharing one port each receive a copy of a single multicast datagram.
    #[cfg(unix)]
    #[tokio::test]
    async fn two_listeners_both_receive() {
        let a = bind_loopback_mcast(0).expect("bind a");
        let port = a.local_addr().unwrap().port();
        let b = bind_loopback_mcast(port).expect("bind b shares port");

        a.send_to(b"shared", group_dest(port)).await.expect("send");

        let mut buf_a = [0u8; 32];
        let mut buf_b = [0u8; 32];
        let na = recv_len(&a, &mut buf_a, "a recv timeout", "a recv ok").await;
        let nb = recv_len(&b, &mut buf_b, "b recv timeout", "b recv ok").await;

        assert_eq!(&buf_a[..na], b"shared");
        assert_eq!(&buf_b[..nb], b"shared");
    }
}