use std::io;
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};
use std::sync::Arc;
use socket2::{Domain, Protocol, Socket, Type};
use tokio::net::UdpSocket;
use super::iface_map::{IfaceInfo, IfaceMap};
/// Metadata describing one datagram received by [`AsyncUdpV4::recv_with_meta`].
#[derive(Debug, Clone, Copy)]
pub struct RecvMeta {
    /// Number of bytes copied into the caller's buffer. This can be smaller
    /// than the datagram if the caller's buffer was too short (the excess is
    /// dropped).
    pub n: usize,
    /// Address of the sender.
    pub src: SocketAddr,
    /// Destination IP the receive path attributes to the datagram; `None`
    /// when it is not known.
    pub dst_ip: Option<Ipv4Addr>,
    /// Index of the interface whose socket received the datagram; `None` when
    /// the recorded index is 0 (treated as "unknown").
    pub ifindex: Option<u32>,
    /// IPv4 address of the interface whose socket received the datagram.
    pub iface_ip: Ipv4Addr,
}
/// One UDP socket bound on behalf of a single IPv4 interface.
pub struct NicSocket {
    /// The underlying tokio socket, shared so receive futures can own a handle.
    pub sock: Arc<UdpSocket>,
    /// Address of the interface this socket belongs to. For receive-only
    /// broadcast sockets this is still the interface's unicast address, not
    /// the address the socket is bound to.
    pub iface_ip: Ipv4Addr,
    /// Interface index; 0 is treated as "unknown" by the lookup and receive
    /// paths.
    pub ifindex: u32,
    /// Interface netmask, used for subnet matching when picking a send socket.
    pub netmask: Ipv4Addr,
    /// Interface broadcast address, if the interface has one.
    pub broadcast: Option<Ipv4Addr>,
    /// `true` when `iface_ip` is a loopback address.
    pub is_loopback: bool,
    /// `true` for the extra socket bound to the interface's broadcast address.
    /// Such sockets are only used for receiving and are never picked for sends.
    pub rx_only_bcast: bool,
}
impl std::fmt::Debug for NicSocket {
    /// Prints the interface metadata together with the socket's bound address.
    /// If the bound address cannot be queried, `0.0.0.0:0` is shown instead.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let placeholder = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0));
        let bound = self.sock.local_addr().unwrap_or(placeholder);

        let mut out = f.debug_struct("NicSocket");
        out.field("iface_ip", &self.iface_ip);
        out.field("ifindex", &self.ifindex);
        out.field("netmask", &self.netmask);
        out.field("broadcast", &self.broadcast);
        out.field("is_loopback", &self.is_loopback);
        out.field("rx_only_bcast", &self.rx_only_bcast);
        out.field("local_addr", &bound);
        out.finish()
    }
}
/// A set of IPv4 UDP sockets, one (or, for broadcast reception, two) per
/// network interface, driven together as if they were a single socket.
pub struct AsyncUdpV4 {
    // Every socket this set manages, including receive-only broadcast sockets.
    sockets: Vec<NicSocket>,
}
impl AsyncUdpV4 {
pub fn bind(port: u16, broadcast: bool) -> io::Result<Self> {
Self::bind_with_map(&IfaceMap::new(), port, broadcast)
}
pub fn bind_non_loopback(port: u16, broadcast: bool) -> io::Result<Self> {
Self::bind_with_map_filtered(&IfaceMap::new(), port, broadcast, true)
}
pub fn bind_with_map(map: &IfaceMap, port: u16, broadcast: bool) -> io::Result<Self> {
Self::bind_with_map_filtered(map, port, broadcast, false)
}
fn bind_with_map_filtered(
map: &IfaceMap,
port: u16,
broadcast: bool,
skip_loopback: bool,
) -> io::Result<Self> {
let ifaces = map.all();
let mut sockets = Vec::with_capacity(ifaces.len() * 2);
for info in ifaces {
if skip_loopback && info.ip.is_loopback() {
continue;
}
match bind_one(&info, port, broadcast) {
Ok(nic) => sockets.push(nic),
Err(e) => {
tracing::debug!(
target: "epics_base_rs::net",
iface = %info.ip,
port,
error = %e,
"skipping NIC: bind failed"
);
}
}
#[cfg(not(target_os = "windows"))]
if let Some(bcast) = info.broadcast {
if !info.ip.is_loopback() && !bcast.is_unspecified() {
match bind_one_at(&info, bcast, port, broadcast, true) {
Ok(nic) => sockets.push(nic),
Err(e) => {
tracing::debug!(
target: "epics_base_rs::net",
iface = %info.ip,
bcast = %bcast,
port,
error = %e,
"skipping NIC bcast bind"
);
}
}
}
}
}
if sockets.is_empty() {
return Err(io::Error::new(
io::ErrorKind::AddrNotAvailable,
"AsyncUdpV4: failed to bind any interface",
));
}
Ok(Self { sockets })
}
pub fn bind_ephemeral_same_port(broadcast: bool) -> io::Result<Self> {
Self::bind_ephemeral_same_port_with_map(&IfaceMap::new(), broadcast)
}
pub fn bind_ephemeral_same_port_with_map(map: &IfaceMap, broadcast: bool) -> io::Result<Self> {
let ifaces = map.all();
let mut up_first: Vec<IfaceInfo> = Vec::with_capacity(ifaces.len());
for i in &ifaces {
if i.up_non_loopback {
up_first.push(i.clone());
}
}
for i in &ifaces {
if !i.up_non_loopback {
up_first.push(i.clone());
}
}
let mut iter = up_first.into_iter();
let first_info = iter
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::AddrNotAvailable, "no IPv4 NICs"))?;
let first = bind_one(&first_info, 0, broadcast)?;
let chosen_port = first
.sock
.local_addr()
.ok()
.map(|sa| sa.port())
.ok_or_else(|| {
io::Error::new(io::ErrorKind::Other, "could not read chosen UDP port")
})?;
let mut sockets = vec![first];
for info in iter {
match bind_one(&info, chosen_port, broadcast) {
Ok(nic) => sockets.push(nic),
Err(e) => {
tracing::debug!(
target: "epics_base_rs::net",
iface = %info.ip,
port = chosen_port,
error = %e,
"skipping NIC: same-port bind failed"
);
}
}
}
Ok(Self { sockets })
}
pub fn bind_single(iface_ip: Ipv4Addr, port: u16, broadcast: bool) -> io::Result<Self> {
let map = IfaceMap::new();
let info = map
.all()
.into_iter()
.find(|i| i.ip == iface_ip)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
format!("AsyncUdpV4: iface {iface_ip} not found"),
)
})?;
let nic = bind_one(&info, port, broadcast)?;
Ok(Self { sockets: vec![nic] })
}
pub fn ifaces(&self) -> &[NicSocket] {
&self.sockets
}
pub fn local_addrs(&self) -> Vec<SocketAddr> {
self.sockets
.iter()
.filter_map(|n| n.sock.local_addr().ok())
.collect()
}
pub async fn send_to(&self, buf: &[u8], dest: SocketAddr) -> io::Result<usize> {
let v4 = match dest {
SocketAddr::V4(v) => v,
SocketAddr::V6(_) => {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"AsyncUdpV4 is IPv4-only",
));
}
};
let nic = self.pick_nic(*v4.ip())?;
nic.sock.send_to(buf, dest).await
}
pub async fn send_via(
&self,
buf: &[u8],
dest: SocketAddr,
iface_ip: Ipv4Addr,
) -> io::Result<usize> {
let nic = self
.sockets
.iter()
.find(|n| n.iface_ip == iface_ip && !n.rx_only_bcast)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
format!("AsyncUdpV4: no socket bound to {iface_ip}"),
)
})?;
nic.sock.send_to(buf, dest).await
}
pub async fn send_via_ifindex(
&self,
buf: &[u8],
dest: SocketAddr,
ifindex: u32,
) -> io::Result<usize> {
let nic = self
.sockets
.iter()
.find(|n| n.ifindex == ifindex && n.ifindex != 0 && !n.rx_only_bcast)
.ok_or_else(|| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
format!("AsyncUdpV4: no socket with ifindex {ifindex}"),
)
})?;
nic.sock.send_to(buf, dest).await
}
pub async fn fanout_to(&self, buf: &[u8], dest: SocketAddr) -> io::Result<usize> {
let mut ok_count = 0usize;
let mut last_err: Option<io::Error> = None;
for nic in &self.sockets {
if nic.is_loopback || nic.rx_only_bcast {
continue;
}
match nic.sock.send_to(buf, dest).await {
Ok(_) => ok_count += 1,
Err(e) => {
tracing::debug!(
target: "epics_base_rs::net",
iface_ip = %nic.iface_ip,
%dest,
error = %e,
"fanout send failed"
);
last_err = Some(e);
}
}
}
if ok_count == 0 {
return Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"AsyncUdpV4: fanout had no eligible NICs",
)
}));
}
Ok(ok_count)
}
pub async fn recv_with_meta(&self, buf: &mut [u8]) -> io::Result<RecvMeta> {
let mut futures = Vec::with_capacity(self.sockets.len());
for nic in &self.sockets {
let sock = nic.sock.clone();
let info = (nic.iface_ip, nic.ifindex);
futures.push(Box::pin(async move {
let mut local = vec![0u8; 65535];
let r = sock.recv_from(&mut local).await;
(r, info, local)
}));
}
if futures.is_empty() {
return Err(io::Error::new(
io::ErrorKind::NotConnected,
"AsyncUdpV4: no NIC sockets",
));
}
let ((res, info, local), _idx, _rest) = select_all_owned(futures).await;
let (n, src) = res?;
let copy_len = n.min(buf.len());
buf[..copy_len].copy_from_slice(&local[..copy_len]);
let (iface_ip, ifindex) = info;
Ok(RecvMeta {
n: copy_len,
src,
dst_ip: Some(iface_ip),
ifindex: if ifindex == 0 { None } else { Some(ifindex) },
iface_ip,
})
}
pub async fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
let m = self.recv_with_meta(buf).await?;
Ok((m.n, m.src))
}
pub fn pick_nic(&self, dest: Ipv4Addr) -> io::Result<&NicSocket> {
let send_eligible = || self.sockets.iter().filter(|n| !n.rx_only_bcast);
for nic in send_eligible() {
if subnet_contains(nic.iface_ip, nic.netmask, dest) {
return Ok(nic);
}
}
for nic in send_eligible() {
if Some(dest) == nic.broadcast {
return Ok(nic);
}
}
if dest.is_loopback() {
if let Some(nic) = send_eligible().find(|n| n.is_loopback) {
return Ok(nic);
}
}
if let Some(nic) = send_eligible().find(|n| !n.is_loopback) {
return Ok(nic);
}
send_eligible().next().ok_or_else(|| {
io::Error::new(
io::ErrorKind::AddrNotAvailable,
"AsyncUdpV4: no NIC sockets",
)
})
}
pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
let mut ok = 0usize;
let mut last_err: Option<io::Error> = None;
for nic in &self.sockets {
let sref = socket_ref(&nic.sock);
match sref.set_recv_buffer_size(size) {
Ok(()) => ok += 1,
Err(e) => {
tracing::debug!(
target: "epics_base_rs::net",
iface_ip = %nic.iface_ip,
size,
error = %e,
"set_recv_buffer_size failed"
);
last_err = Some(e);
}
}
}
if ok == 0 {
return Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"AsyncUdpV4: set_recv_buffer_size had no eligible NICs",
)
}));
}
Ok(())
}
pub fn join_multicast_v4(&self, group: Ipv4Addr) -> io::Result<()> {
let mut ok = 0usize;
let mut last_err: Option<io::Error> = None;
for nic in &self.sockets {
if nic.is_loopback || nic.rx_only_bcast {
continue;
}
match nic.sock.join_multicast_v4(group, nic.iface_ip) {
Ok(()) => ok += 1,
Err(e) => {
tracing::debug!(
target: "epics_base_rs::net",
iface_ip = %nic.iface_ip,
%group,
error = %e,
"join_multicast_v4 failed"
);
last_err = Some(e);
}
}
}
if ok == 0 {
return Err(last_err.unwrap_or_else(|| {
io::Error::new(
io::ErrorKind::Other,
"AsyncUdpV4: join_multicast_v4 had no eligible NICs",
)
}));
}
Ok(())
}
}
/// Borrows `sock` as a `socket2::SockRef`, giving access to socket options
/// tokio does not expose, without taking ownership of the descriptor.
fn socket_ref(sock: &UdpSocket) -> socket2::SockRef<'_> {
    sock.into()
}
/// Binds a send-capable unicast socket to the interface's own address.
fn bind_one(info: &IfaceInfo, port: u16, broadcast: bool) -> io::Result<NicSocket> {
    let own_addr = info.ip;
    let receive_only = false;
    bind_one_at(info, own_addr, port, broadcast, receive_only)
}
/// Creates a non-blocking UDP socket bound to `bind_ip:port` and wraps it as a
/// [`NicSocket`] describing the interface `info`.
///
/// The port may be shared (address reuse on non-Windows targets, port reuse
/// on Unix). `broadcast` enables `SO_BROADCAST`. `rx_only_bcast` is recorded
/// verbatim on the result.
fn bind_one_at(
    info: &IfaceInfo,
    bind_ip: Ipv4Addr,
    port: u16,
    broadcast: bool,
    rx_only_bcast: bool,
) -> io::Result<NicSocket> {
    let raw = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP))?;
    #[cfg(not(windows))]
    raw.set_reuse_address(true)?;
    #[cfg(unix)]
    raw.set_reuse_port(true)?;
    if broadcast {
        raw.set_broadcast(true)?;
    }
    // Best effort: a failure here is deliberately ignored.
    #[cfg(target_os = "linux")]
    let _ = raw.set_multicast_all_v4(false);
    // tokio requires the wrapped std socket to already be non-blocking.
    raw.set_nonblocking(true)?;

    let local = SocketAddr::from(SocketAddrV4::new(bind_ip, port));
    raw.bind(&local.into())?;

    let tokio_sock = UdpSocket::from_std(std::net::UdpSocket::from(raw))?;
    Ok(NicSocket {
        sock: Arc::new(tokio_sock),
        iface_ip: info.ip,
        ifindex: info.index,
        netmask: info.netmask,
        broadcast: info.broadcast,
        is_loopback: info.ip.is_loopback(),
        rx_only_bcast,
    })
}
/// Returns `true` when `candidate` lies in the subnet `ip`/`mask`.
///
/// An all-zero mask never matches. Otherwise it would claim every address,
/// which would make the subnet step of NIC selection meaningless.
fn subnet_contains(ip: Ipv4Addr, mask: Ipv4Addr, candidate: Ipv4Addr) -> bool {
    match u32::from(mask) {
        0 => false,
        // Addresses share a subnet iff they differ only in host bits.
        bits => ((u32::from(ip) ^ u32::from(candidate)) & bits) == 0,
    }
}
/// Resolves with the output of the first future in `futures` found ready,
/// that future's index, and the futures that are still pending.
///
/// Every wake-up polls the futures in vector order, so when several are ready
/// at once the lowest index wins. The returned remainder is in `swap_remove`
/// order: the last element takes the winner's slot, so it is not the original
/// order.
///
/// `futures` must be non-empty; with nothing to poll the returned future would
/// never complete. This precondition is checked with `debug_assert!`.
async fn select_all_owned<F, T>(
    mut futures: Vec<std::pin::Pin<Box<F>>>,
) -> (T, usize, Vec<std::pin::Pin<Box<F>>>)
where
    F: std::future::Future<Output = T> + ?Sized,
{
    // `Future` (needed for `.poll`) is only in the prelude from edition 2024,
    // so import it explicitly.
    use std::future::{poll_fn, Future};
    use std::task::Poll;
    debug_assert!(
        !futures.is_empty(),
        "select_all_owned: an empty future set would never resolve"
    );
    let (out, idx) = poll_fn(|cx| {
        for (i, fut) in futures.iter_mut().enumerate() {
            if let Poll::Ready(v) = fut.as_mut().poll(cx) {
                return Poll::Ready((v, i));
            }
        }
        Poll::Pending
    })
    .await;
    // The winner has already produced its output and must not be polled again.
    let _completed = futures.swap_remove(idx);
    (out, idx, futures)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bound address of the first loopback socket in `set`, if there is one.
    fn loopback_local_addr(set: &AsyncUdpV4) -> Option<SocketAddr> {
        set.ifaces()
            .iter()
            .find(|nic| nic.is_loopback)
            .map(|nic| nic.sock.local_addr().unwrap())
    }

    #[tokio::test]
    async fn loopback_send_and_recv() {
        let sender = AsyncUdpV4::bind(0, false).expect("sender bind");
        let receiver = AsyncUdpV4::bind(0, false).expect("receiver bind");
        let target = loopback_local_addr(&receiver).expect("loopback NIC must exist");

        let payload = b"libca-fanout";
        sender.send_to(payload, target).await.expect("send to lo");

        let mut buf = [0u8; 64];
        let deadline = std::time::Duration::from_secs(2);
        let meta = tokio::time::timeout(deadline, receiver.recv_with_meta(&mut buf))
            .await
            .expect("recv timeout")
            .expect("recv ok");

        assert_eq!(meta.n, payload.len());
        assert_eq!(&buf[..meta.n], payload);
        assert!(
            meta.iface_ip.is_loopback(),
            "expected loopback iface_ip, got {:?}",
            meta.iface_ip
        );
    }

    #[tokio::test]
    async fn send_via_loopback_iface_ip() {
        let sock = AsyncUdpV4::bind(0, false).expect("bind");
        let lo_nic = sock
            .ifaces()
            .iter()
            .find(|nic| nic.is_loopback)
            .expect("loopback NIC must exist");
        let lo_iface = lo_nic.iface_ip;

        let receiver = AsyncUdpV4::bind(0, false).expect("recv bind");
        let dest = loopback_local_addr(&receiver).unwrap();

        let sent = sock.send_via(b"x", dest, lo_iface).await.expect("send_via");
        assert_eq!(sent, 1);
    }

    #[tokio::test]
    async fn bind_ephemeral_same_port_uses_one_port_across_nics() {
        let sock = AsyncUdpV4::bind_ephemeral_same_port(false).expect("bind same-port");
        let ports: Vec<u16> = sock.local_addrs().iter().map(SocketAddr::port).collect();

        assert!(!ports.is_empty(), "at least one bound port");
        let first = ports[0];
        for &port in &ports {
            assert_eq!(port, first, "all NIC sockets must share one port");
        }
        assert!(first != 0, "ephemeral port must be non-zero");
    }

    #[tokio::test]
    async fn send_via_unknown_iface_returns_addr_not_available() {
        let sock = AsyncUdpV4::bind(0, false).expect("bind");
        // 203.0.113.0/24 is TEST-NET-3, so no real interface should own it.
        let bogus = Ipv4Addr::new(203, 0, 113, 99);
        let dest = SocketAddr::from(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 9999));

        let err = sock
            .send_via(b"x", dest, bogus)
            .await
            .expect_err("unknown iface must fail");
        assert_eq!(err.kind(), io::ErrorKind::AddrNotAvailable);
    }

    #[tokio::test]
    async fn pick_nic_loopback() {
        let sock = AsyncUdpV4::bind(0, false).expect("bind");
        let chosen = sock.pick_nic(Ipv4Addr::LOCALHOST).expect("pick");
        assert!(chosen.is_loopback || chosen.iface_ip.is_loopback());
    }

    #[test]
    fn subnet_contains_basic() {
        let ip = Ipv4Addr::new(10, 0, 0, 5);
        let mask = Ipv4Addr::new(255, 255, 255, 0);
        assert!(subnet_contains(ip, mask, Ipv4Addr::new(10, 0, 0, 99)));
        assert!(!subnet_contains(ip, mask, Ipv4Addr::new(10, 0, 1, 1)));
        // A zero mask must never match.
        assert!(!subnet_contains(
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::UNSPECIFIED,
            Ipv4Addr::new(8, 8, 8, 8)
        ));
    }
}