use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket};
#[cfg(feature = "network-discovery")]
use super::port::buffer_defaults;
use super::port::{
BoundSocket, EndpointConfigError, EndpointPortConfig, IpMode, PortBinding, PortConfigResult,
PortRetryBehavior, SocketOptions,
};
/// Rejects ports in the privileged range (below 1024).
///
/// Port 0 falls in that range too and is therefore rejected as well; callers
/// wanting an OS-assigned port use `PortBinding::OsAssigned` instead.
///
/// # Errors
/// `EndpointConfigError::PermissionDenied(port)` when `port < 1024`.
fn validate_port(port: u16) -> PortConfigResult<()> {
    const FIRST_UNPRIVILEGED_PORT: u16 = 1024;
    if port >= FIRST_UNPRIVILEGED_PORT {
        Ok(())
    } else {
        Err(EndpointConfigError::PermissionDenied(port))
    }
}
/// Checks that `start..=end` is a well-formed, unprivileged port range.
///
/// The range must contain at least two ports (`start < end`), and its first
/// port must be 1024 or higher. Ordering is checked before privilege.
///
/// # Errors
/// * `EndpointConfigError::InvalidConfig` when `start >= end`.
/// * `EndpointConfigError::PermissionDenied(start)` when `start < 1024`.
fn validate_port_range(start: u16, end: u16) -> PortConfigResult<()> {
    let ordered = start < end;
    let unprivileged = start >= 1024;
    match (ordered, unprivileged) {
        (false, _) => Err(EndpointConfigError::InvalidConfig(format!(
            "Invalid port range: start ({}) must be less than end ({})",
            start, end
        ))),
        (true, false) => Err(EndpointConfigError::PermissionDenied(start)),
        (true, true) => Ok(()),
    }
}
#[cfg(feature = "network-discovery")]
mod socket2_impl {
    //! Socket construction backed by `socket2`, which exposes pre-bind
    //! options (`IPV6_V6ONLY`, `SO_REUSEADDR`, buffer sizes) that `std::net`
    //! does not.
    use super::*;

    /// Maps an I/O error raised while creating or configuring a socket.
    fn setup_failed(err: std::io::Error) -> EndpointConfigError {
        EndpointConfigError::BindFailed(err.to_string())
    }

    /// Translates a `bind` failure on `port` into the matching config error.
    fn map_bind_error(err: std::io::Error, port: u16) -> EndpointConfigError {
        match err.kind() {
            std::io::ErrorKind::AddrInUse => EndpointConfigError::PortInUse(port),
            std::io::ErrorKind::PermissionDenied => EndpointConfigError::PermissionDenied(port),
            _ => EndpointConfigError::BindFailed(err.to_string()),
        }
    }

    /// Applies a buffer size, halving the request each time it is rejected.
    ///
    /// Candidates are `requested, requested / 2, ...` while they stay at or
    /// above `buffer_defaults::MIN_BUFFER_SIZE` (and above zero), followed by
    /// `MIN_BUFFER_SIZE` itself. On the first accepted candidate the size
    /// reported back by `get` is returned; the OS may round or clamp it.
    ///
    /// # Errors
    /// The last `set` error when every candidate was rejected, or the `get`
    /// error when reading back the effective size fails.
    fn apply_buffer_size(
        kind: &str,
        requested: usize,
        set: impl Fn(usize) -> std::io::Result<()>,
        get: impl Fn() -> std::io::Result<usize>,
    ) -> std::io::Result<usize> {
        let min = buffer_defaults::MIN_BUFFER_SIZE;
        // The `size > 0` guard keeps halving finite even if MIN_BUFFER_SIZE were 0.
        let halvings = std::iter::successors(Some(requested), |&size| Some(size / 2))
            .take_while(|&size| size >= min && size > 0);
        let mut last_err = None;
        for size in halvings.chain(std::iter::once(min)) {
            match set(size) {
                Ok(()) => return get(),
                Err(e) => {
                    tracing::debug!("{} buffer size {} rejected: {}", kind, size, e);
                    last_err = Some(e);
                }
            }
        }
        match last_err {
            Some(e) => Err(e),
            // Not reachable in practice: at least `min` is always attempted.
            None => get(),
        }
    }

    /// Applies the optional send/receive buffer sizes from `opts`.
    ///
    /// Best-effort: a failure is logged and the socket keeps working with
    /// whatever buffer size the OS already had.
    fn apply_buffer_sizes(socket: &socket2::Socket, opts: &SocketOptions) {
        if let Some(size) = opts.send_buffer_size {
            let applied = apply_buffer_size(
                "Send",
                size,
                |n| socket.set_send_buffer_size(n),
                || socket.send_buffer_size(),
            );
            if let Err(e) = applied {
                tracing::warn!(
                    "Failed to set send buffer to {} bytes: {}. Using OS default.",
                    size,
                    e
                );
            }
        }
        if let Some(size) = opts.recv_buffer_size {
            let applied = apply_buffer_size(
                "Recv",
                size,
                |n| socket.set_recv_buffer_size(n),
                || socket.recv_buffer_size(),
            );
            if let Err(e) = applied {
                tracing::warn!(
                    "Failed to set recv buffer to {} bytes: {}. Using OS default.",
                    size,
                    e
                );
            }
        }
    }

    /// Applies the options shared by every socket built here: non-blocking
    /// mode, optional `SO_REUSEADDR`, and buffer sizes.
    ///
    /// `SO_REUSEPORT` is never enabled; a request for it is only logged.
    fn configure(socket: &socket2::Socket, opts: &SocketOptions) -> PortConfigResult<()> {
        socket.set_nonblocking(true).map_err(setup_failed)?;
        if opts.reuse_address {
            socket.set_reuse_address(true).map_err(setup_failed)?;
        }
        if opts.reuse_port {
            tracing::debug!("SO_REUSEPORT requested but skipped for compatibility");
        }
        apply_buffer_sizes(socket, opts);
        Ok(())
    }

    /// Creates a single IPv6 UDP socket on `[::]:port` with `IPV6_V6ONLY`
    /// cleared, so it also receives IPv4 traffic (as IPv4-mapped addresses).
    ///
    /// # Errors
    /// `BindFailed` if the socket cannot be created, made dual-stack or
    /// configured; `PortInUse` / `PermissionDenied` / `BindFailed` from `bind`.
    pub fn create_dual_stack_socket(
        port: u16,
        opts: &SocketOptions,
    ) -> PortConfigResult<UdpSocket> {
        use std::net::{Ipv6Addr, SocketAddrV6};
        let socket = socket2::Socket::new(
            socket2::Domain::IPV6,
            socket2::Type::DGRAM,
            Some(socket2::Protocol::UDP),
        )
        .map_err(setup_failed)?;
        socket.set_only_v6(false).map_err(|e| {
            EndpointConfigError::BindFailed(format!("Failed to enable dual-stack: {e}"))
        })?;
        configure(&socket, opts)?;
        let addr = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0);
        socket
            .bind(&socket2::SockAddr::from(addr))
            .map_err(|e| map_bind_error(e, port))?;
        Ok(socket.into())
    }

    /// Creates a UDP socket of the address family of `addr` and binds it there.
    ///
    /// IPv6 sockets get `IPV6_V6ONLY` set (best-effort). Where the OS default
    /// is dual-stack (e.g. Linux with `bindv6only=0`), a wildcard IPv6 bind
    /// would otherwise also claim the IPv4 port. That would make the
    /// separate-socket dual-stack fallback collide with its own IPv4 socket,
    /// and "IPv6-only" sockets would accept IPv4 traffic.
    ///
    /// # Errors
    /// `BindFailed` if the socket cannot be created or configured;
    /// `PortInUse` / `PermissionDenied` / `BindFailed` from `bind`.
    pub fn create_socket(addr: &SocketAddr, opts: &SocketOptions) -> PortConfigResult<UdpSocket> {
        let domain = if addr.is_ipv4() {
            socket2::Domain::IPV4
        } else {
            socket2::Domain::IPV6
        };
        let socket = socket2::Socket::new(domain, socket2::Type::DGRAM, Some(socket2::Protocol::UDP))
            .map_err(setup_failed)?;
        if addr.is_ipv6() {
            if let Err(e) = socket.set_only_v6(true) {
                tracing::warn!("Failed to set IPV6_V6ONLY on {}: {}", addr, e);
            }
        }
        configure(&socket, opts)?;
        socket
            .bind(&socket2::SockAddr::from(*addr))
            .map_err(|e| map_bind_error(e, addr.port()))?;
        Ok(socket.into())
    }
}
#[cfg(not(feature = "network-discovery"))]
mod std_impl {
    //! Fallback socket construction using only `std::net`.
    use super::*;

    /// Binds a non-blocking UDP socket on `addr` using plain `std::net`.
    ///
    /// `std::net` offers no way to set options before binding, so `opts` is
    /// accepted only for signature parity with the `socket2` backend and is
    /// otherwise ignored.
    ///
    /// # Errors
    /// `PortInUse` / `PermissionDenied` for those bind failures, and
    /// `BindFailed` for any other bind error or if non-blocking mode fails.
    pub fn create_socket(addr: &SocketAddr, opts: &SocketOptions) -> PortConfigResult<UdpSocket> {
        let _ = opts;
        let port = addr.port();
        let socket = UdpSocket::bind(addr).map_err(|err| match err.kind() {
            std::io::ErrorKind::AddrInUse => EndpointConfigError::PortInUse(port),
            std::io::ErrorKind::PermissionDenied => EndpointConfigError::PermissionDenied(port),
            _ => EndpointConfigError::BindFailed(err.to_string()),
        })?;
        match socket.set_nonblocking(true) {
            Ok(()) => Ok(socket),
            Err(err) => Err(EndpointConfigError::BindFailed(err.to_string())),
        }
    }
}
fn create_socket(addr: &SocketAddr, opts: &SocketOptions) -> PortConfigResult<UdpSocket> {
#[cfg(feature = "network-discovery")]
{
socket2_impl::create_socket(addr, opts)
}
#[cfg(not(feature = "network-discovery"))]
{
std_impl::create_socket(addr, opts)
}
}
fn bind_single_socket(
port: u16,
ip_mode: &IpMode,
socket_opts: &SocketOptions,
) -> PortConfigResult<Vec<SocketAddr>> {
match ip_mode {
IpMode::IPv4Only => {
let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
let socket = create_socket(&addr, socket_opts)?;
let local_addr = socket
.local_addr()
.map_err(|e| EndpointConfigError::BindFailed(e.to_string()))?;
std::mem::forget(socket);
Ok(vec![local_addr])
}
IpMode::IPv6Only => {
let addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let socket = create_socket(&addr, socket_opts)?;
let local_addr = socket
.local_addr()
.map_err(|e| EndpointConfigError::BindFailed(e.to_string()))?;
std::mem::forget(socket);
Ok(vec![local_addr])
}
IpMode::DualStack => {
#[cfg(feature = "network-discovery")]
{
match socket2_impl::create_dual_stack_socket(port, socket_opts) {
Ok(socket) => {
let local_addr = socket
.local_addr()
.map_err(|e| EndpointConfigError::BindFailed(e.to_string()))?;
tracing::info!(
"Created true dual-stack socket on {} (accepts IPv4 and IPv6)",
local_addr
);
std::mem::forget(socket);
return Ok(vec![local_addr]);
}
Err(e) => {
tracing::debug!(
"True dual-stack socket failed: {:?}, falling back to separate sockets",
e
);
}
}
}
let v4_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::UNSPECIFIED), port);
let v6_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port);
let v4_socket = create_socket(&v4_addr, socket_opts)?;
let v4_local = v4_socket
.local_addr()
.map_err(|e| EndpointConfigError::BindFailed(e.to_string()))?;
match create_socket(&v6_addr, socket_opts) {
Ok(v6_socket) => {
let v6_local = v6_socket
.local_addr()
.map_err(|e| EndpointConfigError::BindFailed(e.to_string()))?;
tracing::info!(
"Created separate IPv4 ({}) and IPv6 ({}) sockets (fallback mode)",
v4_local,
v6_local
);
std::mem::forget(v4_socket);
std::mem::forget(v6_socket);
Ok(vec![v4_local, v6_local])
}
Err(e) => {
tracing::debug!(
"IPv6 socket creation failed ({:?}), using IPv4-only mode",
e
);
tracing::info!(
"Created IPv4-only socket on {} (IPv6 not available on this system)",
v4_local
);
std::mem::forget(v4_socket);
Ok(vec![v4_local])
}
}
}
IpMode::DualStackSeparate {
ipv4_port,
ipv6_port,
} => {
let mut addrs = Vec::new();
let v4_addrs = bind_with_port_binding(ipv4_port, &IpMode::IPv4Only, socket_opts)?;
addrs.extend(v4_addrs);
let v6_addrs = bind_with_port_binding(ipv6_port, &IpMode::IPv6Only, socket_opts)?;
addrs.extend(v6_addrs);
Ok(addrs)
}
}
}
/// Binds `ip_mode` according to a single `PortBinding`.
///
/// * `OsAssigned` binds port 0 so the OS picks a port.
/// * `Explicit(p)` requires `p >= 1024`, then binds it.
/// * `Range(start, end)` validates the range, then tries ports in ascending
///   order. Ports that are in use are skipped; any other error aborts the scan.
///
/// # Errors
/// Validation errors, the first non-`PortInUse` bind error, or
/// `NoPortInRange(start, end)` when every port in the range is taken.
fn bind_with_port_binding(
    port_binding: &PortBinding,
    ip_mode: &IpMode,
    socket_opts: &SocketOptions,
) -> PortConfigResult<Vec<SocketAddr>> {
    let (first, last) = match *port_binding {
        PortBinding::OsAssigned => return bind_single_socket(0, ip_mode, socket_opts),
        PortBinding::Explicit(port) => {
            validate_port(port)?;
            return bind_single_socket(port, ip_mode, socket_opts);
        }
        PortBinding::Range(start, end) => (start, end),
    };

    validate_port_range(first, last)?;
    for candidate in first..=last {
        match bind_single_socket(candidate, ip_mode, socket_opts) {
            Err(EndpointConfigError::PortInUse(_)) => {}
            outcome => return outcome,
        }
    }
    Err(EndpointConfigError::NoPortInRange(first, last))
}
/// Binds the endpoint described by `config` and returns the bound addresses
/// together with a copy of the configuration.
///
/// For `PortBinding::Explicit`, `config.retry_behavior` decides what happens
/// when the port is already in use:
/// * `FallbackToOsAssigned` retries on an OS-assigned port.
/// * `FailFast` and `TryNext` report `PortInUse`.
///
/// OS-assigned and range bindings are delegated to `bind_with_port_binding`,
/// which validates the range once and probes successive ports itself.
///
/// # Errors
/// * `PermissionDenied` for ports below 1024.
/// * `InvalidConfig` for a malformed range.
/// * `NoPortInRange` when every port of a range is taken.
/// * Any socket creation/bind error.
pub fn bind_endpoint(config: &EndpointPortConfig) -> PortConfigResult<BoundSocket> {
    let ip_mode = &config.ip_mode;
    let opts = &config.socket_options;

    let addrs = match &config.port {
        PortBinding::Explicit(port) => {
            validate_port(*port)?;
            match bind_single_socket(*port, ip_mode, opts) {
                Err(EndpointConfigError::PortInUse(_)) => match config.retry_behavior {
                    PortRetryBehavior::FallbackToOsAssigned => {
                        tracing::warn!("Port {} in use, falling back to OS-assigned", port);
                        bind_single_socket(0, ip_mode, opts)?
                    }
                    // NOTE(review): `TryNext` presumably targets range bindings;
                    // a single explicit port has no "next" candidate, so it
                    // fails exactly like `FailFast` — confirm this is intended.
                    PortRetryBehavior::FailFast | PortRetryBehavior::TryNext => {
                        return Err(EndpointConfigError::PortInUse(*port));
                    }
                },
                result => result?,
            }
        }
        PortBinding::OsAssigned | PortBinding::Range(..) => {
            bind_with_port_binding(&config.port, ip_mode, opts)?
        }
    };

    Ok(BoundSocket {
        addrs,
        config: config.clone(),
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a UDP port that was free a moment ago.
    ///
    /// The probe socket is dropped before returning, which releases the port
    /// again. This avoids hard-coding ports that another process may already
    /// hold. NOTE: assumes the OS ephemeral range lies above 1024 (true for
    /// the default Linux, macOS and Windows ranges), so the port passes
    /// `validate_port`.
    fn free_udp_port() -> u16 {
        UdpSocket::bind((Ipv4Addr::UNSPECIFIED, 0))
            .and_then(|probe| probe.local_addr())
            .expect("an ephemeral UDP port should be available")
            .port()
    }

    #[test]
    fn test_validate_port_privileged() {
        assert!(matches!(
            validate_port(80),
            Err(EndpointConfigError::PermissionDenied(80))
        ));
        assert!(matches!(
            validate_port(443),
            Err(EndpointConfigError::PermissionDenied(443))
        ));
        assert!(matches!(
            validate_port(1023),
            Err(EndpointConfigError::PermissionDenied(1023))
        ));
    }

    #[test]
    fn test_validate_port_valid() {
        assert!(validate_port(1024).is_ok());
        assert!(validate_port(9000).is_ok());
        assert!(validate_port(65535).is_ok());
    }

    #[test]
    fn test_validate_port_range_invalid() {
        assert!(validate_port_range(9000, 9000).is_err());
        assert!(validate_port_range(9010, 9000).is_err());
        assert!(validate_port_range(80, 90).is_err());
    }

    #[test]
    fn test_validate_port_range_valid() {
        assert!(validate_port_range(9000, 9010).is_ok());
        assert!(validate_port_range(1024, 2048).is_ok());
    }

    #[test]
    fn test_bind_os_assigned_ipv4() {
        let config = EndpointPortConfig {
            port: PortBinding::OsAssigned,
            ip_mode: IpMode::IPv4Only,
            ..Default::default()
        };
        let bound = bind_endpoint(&config).expect("bind_endpoint should succeed");
        assert_eq!(bound.addrs.len(), 1);
        assert!(bound.addrs[0].is_ipv4());
        // The OS must have replaced the requested port 0 with a real port.
        assert_ne!(bound.addrs[0].port(), 0);
    }

    #[test]
    fn test_bind_explicit_port() {
        let port = free_udp_port();
        let config = EndpointPortConfig {
            port: PortBinding::Explicit(port),
            ip_mode: IpMode::IPv4Only,
            ..Default::default()
        };
        let bound = bind_endpoint(&config).expect("bind_endpoint should succeed");
        assert_eq!(bound.addrs.len(), 1);
        assert_eq!(bound.addrs[0].port(), port);
    }

    #[test]
    fn test_bind_privileged_port_fails() {
        let config = EndpointPortConfig {
            port: PortBinding::Explicit(80),
            ip_mode: IpMode::IPv4Only,
            ..Default::default()
        };
        let result = bind_endpoint(&config);
        assert!(matches!(
            result,
            Err(EndpointConfigError::PermissionDenied(80))
        ));
    }

    #[test]
    fn test_bind_port_conflict() {
        // bind_single_socket leaks the bound socket, so the first endpoint's
        // port stays occupied for the remainder of the test.
        let first = EndpointPortConfig {
            port: PortBinding::OsAssigned,
            ip_mode: IpMode::IPv4Only,
            retry_behavior: PortRetryBehavior::FailFast,
            ..Default::default()
        };
        let bound1 = bind_endpoint(&first).expect("First bind should succeed");
        let held = bound1.addrs[0].port();

        let second = EndpointPortConfig {
            port: PortBinding::Explicit(held),
            ip_mode: IpMode::IPv4Only,
            retry_behavior: PortRetryBehavior::FailFast,
            ..Default::default()
        };
        let result2 = bind_endpoint(&second);
        assert!(matches!(
            result2,
            Err(EndpointConfigError::PortInUse(p)) if p == held
        ));
    }

    #[test]
    fn test_bind_fallback_to_os_assigned() {
        let first = EndpointPortConfig {
            port: PortBinding::OsAssigned,
            ip_mode: IpMode::IPv4Only,
            ..Default::default()
        };
        let bound1 = bind_endpoint(&first).expect("First bind should succeed");
        let held = bound1.addrs[0].port();

        let second = EndpointPortConfig {
            port: PortBinding::Explicit(held),
            ip_mode: IpMode::IPv4Only,
            retry_behavior: PortRetryBehavior::FallbackToOsAssigned,
            ..Default::default()
        };
        let bound2 =
            bind_endpoint(&second).expect("bind_endpoint with fallback should succeed");
        // The fallback must have moved to a different, OS-assigned port.
        assert_ne!(bound2.addrs[0].port(), held);
        assert_ne!(bound2.addrs[0].port(), 0);
    }

    #[test]
    fn test_bind_port_range() {
        let config = EndpointPortConfig {
            port: PortBinding::Range(45000, 45010),
            ip_mode: IpMode::IPv4Only,
            ..Default::default()
        };
        let bound = bind_endpoint(&config).expect("bind_endpoint should succeed");
        let port = bound.addrs[0].port();
        assert!((45000..=45010).contains(&port));
    }

    #[test]
    fn test_bound_socket_primary_addr() {
        let config = EndpointPortConfig::default();
        let bound = bind_endpoint(&config).expect("bind_endpoint should succeed");
        assert!(bound.primary_addr().is_some());
        assert_eq!(bound.primary_addr(), bound.addrs.first().copied());
    }

    #[test]
    fn test_bound_socket_all_addrs() {
        let config = EndpointPortConfig::default();
        let bound = bind_endpoint(&config).expect("bind_endpoint should succeed");
        assert!(!bound.all_addrs().is_empty());
        assert_eq!(bound.all_addrs(), &bound.addrs[..]);
    }
}