use std::io;
use std::net::SocketAddr;
use socket2::{Domain, Protocol, Socket, Type};
use tokio::net::UdpSocket;
/// Creates a UDP socket, applies the requested options, binds it to `addr`,
/// and wraps it in a Tokio [`UdpSocket`].
///
/// * `addr` - local address to bind; its family selects an IPv4 or IPv6 socket.
/// * `recv_buffer_size` - when `Some`, requested receive buffer size.
/// * `send_buffer_size` - when `Some`, requested send buffer size.
/// * `reuse_address` - when `true`, address reuse is enabled before binding.
///
/// For IPv6 addresses the v6-only option is switched off before binding.
///
/// # Errors
///
/// Returns the underlying [`io::Error`] if creating the socket, switching off
/// v6-only, enabling address reuse, entering non-blocking mode, binding, or
/// converting into the Tokio socket fails. Failures while applying either
/// buffer size are deliberately ignored.
pub(crate) async fn bind_udp_socket(
    addr: SocketAddr,
    recv_buffer_size: Option<usize>,
    send_buffer_size: Option<usize>,
    reuse_address: bool,
) -> io::Result<UdpSocket> {
    // Derive both the socket family and the "is this v6?" flag from the address variant.
    let (family, is_v6) = match addr {
        SocketAddr::V4(_) => (Domain::IPV4, false),
        SocketAddr::V6(_) => (Domain::IPV6, true),
    };

    let raw = Socket::new(family, Type::DGRAM, Some(Protocol::UDP))?;

    if is_v6 {
        raw.set_only_v6(false)?;
    }

    if reuse_address {
        raw.set_reuse_address(true)?;
    }

    // Buffer sizes are best-effort: a rejected size must not prevent the bind.
    if let Some(bytes) = recv_buffer_size {
        let _ = raw.set_recv_buffer_size(bytes);
    }
    if let Some(bytes) = send_buffer_size {
        let _ = raw.set_send_buffer_size(bytes);
    }

    // The socket is switched to non-blocking mode before it is handed to Tokio.
    raw.set_nonblocking(true)?;
    raw.bind(&addr.into())?;

    let std_socket: std::net::UdpSocket = raw.into();
    UdpSocket::from_std(std_socket)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each test binds to port 0 and then checks that the reported local
    // address has the expected family and a non-zero port.

    /// Binding to the IPv4 loopback yields an IPv4 local address.
    #[tokio::test]
    async fn test_bind_udp_socket_ipv4() {
        let requested: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let bound = bind_udp_socket(requested, None, None, true)
            .await
            .unwrap()
            .local_addr()
            .unwrap();

        assert!(bound.is_ipv4());
        assert_ne!(bound.port(), 0);
    }

    /// Binding to the IPv6 loopback yields an IPv6 local address.
    #[tokio::test]
    async fn test_bind_udp_socket_ipv6() {
        let requested: SocketAddr = "[::1]:0".parse().unwrap();
        let bound = bind_udp_socket(requested, None, None, true)
            .await
            .unwrap()
            .local_addr()
            .unwrap();

        assert!(bound.is_ipv6());
        assert_ne!(bound.port(), 0);
    }

    /// Requesting a 1 MiB receive buffer still produces a bound IPv4 socket.
    #[tokio::test]
    async fn test_bind_udp_socket_with_buffer_size() {
        let requested: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let recv_bytes = Some(1024 * 1024);
        let bound = bind_udp_socket(requested, recv_bytes, None, true)
            .await
            .unwrap()
            .local_addr()
            .unwrap();

        assert!(bound.is_ipv4());
        assert_ne!(bound.port(), 0);
    }
}