use std::{
io::{self, ErrorKind},
net::{Ipv4Addr, Ipv6Addr, SocketAddr},
sync::LazyLock,
};
use cfg_if::cfg_if;
use log::{debug, warn};
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use tokio::net::TcpSocket;
use super::ConnectOpts;
// Pull in the platform-specific socket helpers and re-export everything public from the selected
// submodule, so callers use this module's path regardless of target.
// No branch is provided for targets that are neither Unix nor Windows.
cfg_if! {
if #[cfg(unix)] {
mod unix;
pub use self::unix::*;
} else if #[cfg(windows)] {
mod windows;
pub use self::windows::*;
}
}
/// Applies the pre-connect options from `opts` to the outbound TCP `socket` that will connect to
/// `addr`.
///
/// When `opts.bind_local_addr` is set, the socket is bound to it first. If the address family
/// differs from the target's, the address is translated:
/// - IPv4 local address, IPv6 target: bind the IPv4-mapped IPv6 form (`::ffff:a.b.c.d`).
/// - IPv6 local address, IPv4 target: bind the embedded IPv4 address. This only works when the
///   local address is IPv4-mapped.
///
/// After binding, the configured send and receive buffer sizes from `opts.tcp` (if any) are
/// applied.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when an IPv6 `bind_local_addr` that is not IPv4-mapped is
/// paired with an IPv4 target. Otherwise, propagates any error from `bind` or the buffer-size
/// setters.
fn set_common_sockopt_for_connect(addr: SocketAddr, socket: &TcpSocket, opts: &ConnectOpts) -> io::Result<()> {
    if let Some(local_addr) = opts.bind_local_addr {
        let bind_addr = match (local_addr, addr) {
            // Same family on both sides: bind exactly what was configured.
            (SocketAddr::V4(..), SocketAddr::V4(..)) | (SocketAddr::V6(..), SocketAddr::V6(..)) => local_addr,
            (SocketAddr::V4(v4), SocketAddr::V6(..)) => {
                SocketAddr::new(v4.ip().to_ipv6_mapped().into(), v4.port())
            }
            (SocketAddr::V6(v6), SocketAddr::V4(..)) => {
                let Some(v4_ip) = v6.ip().to_ipv4_mapped() else {
                    return Err(io::Error::new(
                        ErrorKind::InvalidInput,
                        "bind_local_addr is not a valid IPv4-mapped IPv6 address",
                    ));
                };
                SocketAddr::new(v4_ip.into(), v6.port())
            }
        };
        socket.bind(bind_addr)?;
    }

    if let Some(size) = opts.tcp.send_buffer_size {
        socket.set_send_buffer_size(size)?;
    }
    if let Some(size) = opts.tcp.recv_buffer_size {
        socket.set_recv_buffer_size(size)?;
    }

    Ok(())
}
/// Fallback for targets that are neither Unix nor Windows.
///
/// Applies no platform-specific options to the connected stream and always returns `Ok(())`.
#[cfg(not(any(unix, windows)))]
#[inline]
fn set_common_sockopt_after_connect_sys(_stream: &tokio::net::TcpStream, _opts: &ConnectOpts) -> io::Result<()> {
    Ok(())
}
/// Binds a socket owned by the caller to `addr`, applying the dual-stack rules of
/// `socket_bind_dual_stack_inner`.
///
/// `socket` only has to expose its raw file descriptor. This function never takes ownership of
/// that descriptor and never closes it.
///
/// # Errors
///
/// Propagates any error reported by `socket_bind_dual_stack_inner`.
#[cfg(unix)]
pub fn socket_bind_dual_stack<S>(socket: &S, addr: &SocketAddr, ipv6_only: bool) -> io::Result<()>
where
    S: std::os::unix::io::AsRawFd,
{
    use std::{mem::ManuallyDrop, os::unix::io::FromRawFd};

    let fd = socket.as_raw_fd();
    // SAFETY: `fd` is an open descriptor owned by `socket`, which is borrowed for the whole call,
    // so it stays valid while the temporary `Socket` exists. Wrapping it in `ManuallyDrop` means
    // the `Socket` destructor never runs. The caller's descriptor is therefore not closed here,
    // not even if `socket_bind_dual_stack_inner` panics (the previous `into_raw_fd()` approach
    // only protected the normal return path).
    let borrowed = ManuallyDrop::new(unsafe { Socket::from_raw_fd(fd) });
    socket_bind_dual_stack_inner(&borrowed, addr, ipv6_only)
}
/// Binds a socket owned by the caller to `addr`, applying the dual-stack rules of
/// `socket_bind_dual_stack_inner`.
///
/// `socket` only has to expose its raw socket handle. This function never takes ownership of
/// that handle and never closes it.
///
/// # Errors
///
/// Propagates any error reported by `socket_bind_dual_stack_inner`.
#[cfg(windows)]
pub fn socket_bind_dual_stack<S>(socket: &S, addr: &SocketAddr, ipv6_only: bool) -> io::Result<()>
where
    S: std::os::windows::io::AsRawSocket,
{
    use std::{mem::ManuallyDrop, os::windows::io::FromRawSocket};

    let handle = socket.as_raw_socket();
    // SAFETY: `handle` is an open socket owned by `socket`, which is borrowed for the whole call,
    // so it stays valid while the temporary `Socket` exists. Wrapping it in `ManuallyDrop` means
    // the `Socket` destructor never runs. The caller's handle is therefore not closed here, not
    // even if `socket_bind_dual_stack_inner` panics (the previous `into_raw_socket()` approach
    // only protected the normal return path).
    let borrowed = ManuallyDrop::new(unsafe { Socket::from_raw_socket(handle) });
    socket_bind_dual_stack_inner(&borrowed, addr, ipv6_only)
}
fn socket_bind_dual_stack_inner(socket: &Socket, addr: &SocketAddr, ipv6_only: bool) -> io::Result<()> {
let saddr = SockAddr::from(*addr);
if ipv6_only {
socket.set_only_v6(true)?;
socket.bind(&saddr)?;
} else {
if let Err(err) = socket.set_only_v6(false) {
warn!("failed to set IPV6_V6ONLY: false for socket, error: {}", err);
}
match socket.bind(&saddr) {
Ok(..) => {}
Err(ref err) if err.kind() == ErrorKind::AddrInUse => {
debug!(
"0.0.0.0:{} may have already been occupied, retry with IPV6_V6ONLY",
addr.port()
);
if let Err(err) = socket.set_only_v6(true) {
warn!("failed to set IPV6_V6ONLY: true for socket, error: {}", err);
}
socket.bind(&saddr)?;
}
Err(err) => return Err(err),
}
}
Ok(())
}
/// Which IP protocol families the host's network stack was found to support.
///
/// Obtained via [`get_ip_stack_capabilities`]. `Default` yields a value with every flag `false`,
/// meaning "not supported".
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct IpStackCapabilities {
    /// An IPv4 TCP socket could be created.
    pub support_ipv4: bool,
    /// An IPv6 TCP socket could be created, switched to `IPV6_V6ONLY`, and bound to `[::1]:0`.
    pub support_ipv6: bool,
    /// An IPv6 TCP socket with `IPV6_V6ONLY` cleared could be bound to the IPv4-mapped loopback
    /// address `[::ffff:127.0.0.1]:0`.
    pub support_ipv4_mapped_ipv6: bool,
}
static IP_STACK_CAPABILITIES: LazyLock<IpStackCapabilities> = LazyLock::new(|| {
let mut caps = IpStackCapabilities {
support_ipv4: false,
support_ipv6: false,
support_ipv4_mapped_ipv6: false,
};
if Socket::new(Domain::IPV4, Type::STREAM, Some(Protocol::TCP)).is_ok() {
caps.support_ipv4 = true;
debug!("IpStackCapability support_ipv4=true");
}
if let Ok(ipv6_socket) = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))
&& ipv6_socket.set_only_v6(true).is_ok() {
let local_host = SockAddr::from(SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 0));
if ipv6_socket.bind(&local_host).is_ok() {
caps.support_ipv6 = true;
debug!("IpStackCapability support_ipv6=true");
}
}
if check_ipv4_mapped_ipv6_capability().is_ok() {
caps.support_ipv4_mapped_ipv6 = true;
debug!("IpStackCapability support_ipv4_mapped_ipv6=true");
}
caps
});
/// Probes whether IPv6 sockets can carry IPv4 traffic through IPv4-mapped addresses.
///
/// Creates an IPv6 TCP socket, clears `IPV6_V6ONLY`, and binds it to `[::ffff:127.0.0.1]:0`. The
/// probe socket is dropped on return.
///
/// # Errors
///
/// Returns the first error from socket creation, clearing `IPV6_V6ONLY`, or binding.
fn check_ipv4_mapped_ipv6_capability() -> io::Result<()> {
    let mapped_loopback = SockAddr::from(SocketAddr::new(Ipv4Addr::LOCALHOST.to_ipv6_mapped().into(), 0));

    let probe = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP))?;
    probe.set_only_v6(false)?;
    probe.bind(&mapped_loopback)
}
pub fn get_ip_stack_capabilities() -> &'static IpStackCapabilities {
&IP_STACK_CAPABILITIES
}