use socket2::{Domain, Protocol, Socket, Type as SockType};
use std::io;
use std::net::{SocketAddr, TcpListener, TcpStream};
use std::time::Duration;
use crate::tcp::TcpConfig;
#[cfg(unix)]
use std::os::fd::AsRawFd;
#[cfg(unix)]
use nix::poll::{PollFd, PollFlags, PollTimeout, poll};
/// A TCP socket built on [`socket2::Socket`].
///
/// Exposes option setters/getters and conversions into the standard library's
/// [`TcpStream`] and [`TcpListener`].
#[derive(Debug)]
pub struct TcpSocket {
    /// The wrapped OS socket.
    socket: Socket,
    /// Blocking mode this wrapper was constructed with (`false` for `new` and
    /// `from_socket`).
    ///
    /// Only the Windows `connect_timeout` reads it, because the current `FIONBIO`
    /// state cannot be queried there. On Unix the live `O_NONBLOCK` flag is read
    /// instead, so the field is otherwise unused on non-Windows targets.
    #[cfg_attr(not(windows), allow(dead_code))]
    nonblocking: bool,
}

impl TcpSocket {
    /// Creates a socket from a [`TcpConfig`], applying every option that is set.
    ///
    /// The config is validated first. Some options (`reuseport`, `tclass_v6`,
    /// `bind_device`) are only applied on the platforms listed in their `cfg`
    /// attributes and are skipped elsewhere. If `bind_addr` is set, the socket is
    /// bound as the final step.
    ///
    /// # Errors
    /// Returns the first error from validation, socket creation, or any option call.
    pub fn from_config(config: &TcpConfig) -> io::Result<Self> {
        config.validate()?;
        let socket = Socket::new(
            config.socket_family.to_domain(),
            config.socket_type.to_sock_type(),
            Some(Protocol::TCP),
        )?;
        socket.set_nonblocking(config.nonblocking)?;
        if let Some(flag) = config.reuseaddr {
            socket.set_reuse_address(flag)?;
        }
        #[cfg(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "fuchsia",
            target_os = "ios",
            target_os = "linux",
            target_os = "macos",
            target_os = "netbsd",
            target_os = "openbsd",
            target_os = "tvos",
            target_os = "visionos",
            target_os = "watchos"
        ))]
        if let Some(flag) = config.reuseport {
            socket.set_reuse_port(flag)?;
        }
        if let Some(flag) = config.nodelay {
            socket.set_nodelay(flag)?;
        }
        if let Some(dur) = config.linger {
            socket.set_linger(Some(dur))?;
        }
        if let Some(ttl) = config.ttl {
            socket.set_ttl(ttl)?;
        }
        if let Some(hoplimit) = config.hoplimit {
            socket.set_unicast_hops_v6(hoplimit)?;
        }
        if let Some(keepalive) = config.keepalive {
            socket.set_keepalive(keepalive)?;
        }
        if let Some(timeout) = config.read_timeout {
            socket.set_read_timeout(Some(timeout))?;
        }
        if let Some(timeout) = config.write_timeout {
            socket.set_write_timeout(Some(timeout))?;
        }
        if let Some(size) = config.recv_buffer_size {
            socket.set_recv_buffer_size(size)?;
        }
        if let Some(size) = config.send_buffer_size {
            socket.set_send_buffer_size(size)?;
        }
        if let Some(tos) = config.tos {
            socket.set_tos(tos)?;
        }
        #[cfg(any(
            target_os = "android",
            target_os = "dragonfly",
            target_os = "freebsd",
            target_os = "fuchsia",
            target_os = "ios",
            target_os = "linux",
            target_os = "macos",
            target_os = "netbsd",
            target_os = "openbsd",
            target_os = "tvos",
            target_os = "visionos",
            target_os = "watchos"
        ))]
        if let Some(tclass) = config.tclass_v6 {
            socket.set_tclass_v6(tclass)?;
        }
        if let Some(only_v6) = config.only_v6 {
            socket.set_only_v6(only_v6)?;
        }
        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
        if let Some(iface) = &config.bind_device {
            socket.bind_device(Some(iface.as_bytes()))?;
        }
        if let Some(addr) = config.bind_addr {
            socket.bind(&addr.into())?;
        }
        Ok(Self {
            socket,
            nonblocking: config.nonblocking,
        })
    }

    /// Creates a blocking socket with the given domain and type and protocol TCP.
    pub fn new(domain: Domain, sock_type: SockType) -> io::Result<Self> {
        let socket = Socket::new(domain, sock_type, Some(Protocol::TCP))?;
        socket.set_nonblocking(false)?;
        Ok(Self {
            socket,
            nonblocking: false,
        })
    }

    /// Creates a blocking IPv4 stream socket.
    pub fn v4_stream() -> io::Result<Self> {
        Self::new(Domain::IPV4, SockType::STREAM)
    }

    /// Creates a blocking IPv6 stream socket.
    pub fn v6_stream() -> io::Result<Self> {
        Self::new(Domain::IPV6, SockType::STREAM)
    }

    /// Creates a blocking IPv4 raw socket with protocol TCP.
    pub fn raw_v4() -> io::Result<Self> {
        Self::new(Domain::IPV4, SockType::RAW)
    }

    /// Creates a blocking IPv6 raw socket with protocol TCP.
    pub fn raw_v6() -> io::Result<Self> {
        Self::new(Domain::IPV6, SockType::RAW)
    }

    /// Binds the socket to `addr`.
    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
        self.socket.bind(&addr.into())
    }

    /// Connects the socket to `addr` using the socket's current blocking mode.
    pub fn connect(&self, addr: SocketAddr) -> io::Result<()> {
        self.socket.connect(&addr.into())
    }

    /// Connects to `target` and fails with [`io::ErrorKind::TimedOut`] if the
    /// connection is not established within `timeout`.
    ///
    /// The connect runs on a duplicate of this socket (`try_clone`) that is briefly
    /// switched to non-blocking mode. Duplicated descriptors share the `O_NONBLOCK`
    /// flag, so the mode observed on entry is restored on every path, whether the
    /// call succeeds or fails. The returned stream refers to the same underlying
    /// connection as `self`.
    ///
    /// # Errors
    /// Returns the connect error (e.g. connection refused),
    /// [`io::ErrorKind::TimedOut`] on timeout, or any error from `poll` or from
    /// reading or changing the blocking mode.
    #[cfg(unix)]
    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
        let socket = self.socket.try_clone()?;
        let was_nonblocking = Self::is_fd_nonblocking(&socket)?;
        socket.set_nonblocking(true)?;
        let connected = Self::wait_connected(&socket, target, timeout);
        // Restore before propagating any connect error, so the caller's socket
        // leaves this call in the mode it entered with.
        let restored = socket.set_nonblocking(was_nonblocking);
        connected?;
        restored?;
        Ok(socket.into())
    }

    /// Reads the `O_NONBLOCK` file status flag of `socket`'s descriptor.
    #[cfg(unix)]
    fn is_fd_nonblocking(socket: &Socket) -> io::Result<bool> {
        // SAFETY: `F_GETFL` only reads descriptor flags; the descriptor stays open
        // because `socket` is borrowed for the whole call.
        let flags = unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_GETFL) };
        if flags == -1 {
            return Err(io::Error::last_os_error());
        }
        Ok((flags & libc::O_NONBLOCK) != 0)
    }

    /// Starts a connect on `socket` and waits up to `timeout` for it to finish.
    ///
    /// `socket` must already be in non-blocking mode.
    #[cfg(unix)]
    fn wait_connected(socket: &Socket, target: SocketAddr, timeout: Duration) -> io::Result<()> {
        use std::os::fd::AsFd;
        use std::time::Instant;

        match socket.connect(&target.into()) {
            Ok(()) => {}
            // EINPROGRESS is the expected result of a non-blocking connect.
            Err(err)
                if err.kind() == io::ErrorKind::WouldBlock
                    || err.raw_os_error() == Some(libc::EINPROGRESS) => {}
            Err(err) => return Err(err),
        }

        let mut fds = [PollFd::new(socket.as_fd(), PollFlags::POLLOUT)];
        // `None` only if `timeout` is too large to add to "now"; every wait then
        // simply uses `timeout` itself.
        let deadline = Instant::now().checked_add(timeout);
        loop {
            let remaining =
                deadline.map_or(timeout, |d| d.saturating_duration_since(Instant::now()));
            let poll_timeout = PollTimeout::try_from(remaining).unwrap_or(PollTimeout::MAX);
            match poll(&mut fds, poll_timeout) {
                Ok(0) => {
                    return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
                }
                Ok(_) => break,
                Err(errno) => {
                    let err = io::Error::from(errno);
                    // A signal interrupted the wait: retry with the time that is left.
                    if err.kind() != io::ErrorKind::Interrupted {
                        return Err(err);
                    }
                }
            }
        }

        if let Some(err) = socket.take_error()? {
            return Err(err);
        }
        // A hang-up with no pending SO_ERROR still means the connection is unusable.
        // Report it instead of returning a dead stream (std's connect_timeout does
        // the same).
        let hung_up = fds[0]
            .revents()
            .is_some_and(|flags| flags.contains(PollFlags::POLLHUP));
        if hung_up {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                "no error set after POLLHUP",
            ));
        }
        Ok(())
    }

    /// Connects to `target` and fails with [`io::ErrorKind::TimedOut`] if the
    /// connection is not established within `timeout`.
    ///
    /// The connect runs on a duplicate of this socket that is briefly switched to
    /// non-blocking mode. Windows cannot report the current `FIONBIO` state, so the
    /// mode this wrapper was constructed with is restored on every path, whether the
    /// call succeeds or fails.
    ///
    /// # Errors
    /// Returns the connect error, [`io::ErrorKind::TimedOut`] on timeout, or any
    /// error from `WSAPoll`, `getsockopt(SO_ERROR)`, or changing the blocking mode.
    #[cfg(windows)]
    pub fn connect_timeout(&self, target: SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
        let socket = self.socket.try_clone()?;
        socket.set_nonblocking(true)?;
        let connected = Self::wait_connected(&socket, target, timeout);
        let restored = socket.set_nonblocking(self.nonblocking);
        connected?;
        restored?;
        Ok(socket.into())
    }

    /// Starts a connect on `socket` and waits up to `timeout` for it to finish.
    ///
    /// `socket` must already be in non-blocking mode.
    #[cfg(windows)]
    fn wait_connected(socket: &Socket, target: SocketAddr, timeout: Duration) -> io::Result<()> {
        use std::os::windows::io::AsRawSocket;
        use windows_sys::Win32::Networking::WinSock::{
            POLLWRNORM, SOCKET, SOCKET_ERROR, WSAPOLLFD, WSAPoll,
        };

        match socket.connect(&target.into()) {
            Ok(()) => {}
            // 10035 is WSAEWOULDBLOCK, the expected result of a non-blocking connect.
            Err(e) if e.kind() == io::ErrorKind::WouldBlock || e.raw_os_error() == Some(10035) => {}
            Err(e) => return Err(e),
        }

        let mut fds = [WSAPOLLFD {
            fd: socket.as_raw_socket() as SOCKET,
            events: POLLWRNORM,
            revents: 0,
        }];
        let timeout_ms = i32::try_from(timeout.as_millis()).unwrap_or(i32::MAX);
        // SAFETY: `fds` is a live array whose length is passed alongside its pointer.
        let ready = unsafe { WSAPoll(fds.as_mut_ptr(), fds.len() as u32, timeout_ms) };
        if ready == SOCKET_ERROR {
            return Err(io::Error::last_os_error());
        }
        if ready == 0 {
            return Err(io::Error::new(io::ErrorKind::TimedOut, "connect timed out"));
        }
        // `take_error` reads SO_ERROR. If getsockopt itself fails, its error is
        // propagated rather than turned into a bogus "error 0".
        match socket.take_error()? {
            Some(err) => Err(err),
            None => Ok(()),
        }
    }

    /// Marks the socket as listening with the given backlog.
    pub fn listen(&self, backlog: i32) -> io::Result<()> {
        self.socket.listen(backlog)
    }

    /// Accepts one incoming connection.
    ///
    /// # Errors
    /// Propagates the accept error, or returns [`io::ErrorKind::InvalidData`] if the
    /// peer address is not an IP socket address.
    pub fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        let (stream, addr) = self.socket.accept()?;
        let addr = addr
            .as_socket()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
        Ok((stream.into(), addr))
    }

    /// Consumes the wrapper and returns it as a std [`TcpStream`].
    pub fn to_tcp_stream(self) -> io::Result<TcpStream> {
        Ok(self.socket.into())
    }

    /// Consumes the wrapper and returns it as a std [`TcpListener`].
    pub fn to_tcp_listener(self) -> io::Result<TcpListener> {
        Ok(self.socket.into())
    }

    /// Sends `buf` to `target`, returning the number of bytes sent.
    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        self.socket.send_to(buf, &target.into())
    }

    /// Receives into `buf`, returning the byte count and the sender's address.
    ///
    /// # Errors
    /// Propagates the receive error, or returns [`io::ErrorKind::InvalidData`] if
    /// the sender address is not an IP socket address.
    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        // SAFETY: `MaybeUninit<u8>` has the same layout as `u8`. socket2 documents
        // that `recv_from` never writes uninitialised bytes into the buffer, so this
        // already-initialised slice stays initialised.
        let buf_maybe = unsafe {
            std::slice::from_raw_parts_mut(
                buf.as_mut_ptr() as *mut std::mem::MaybeUninit<u8>,
                buf.len(),
            )
        };
        let (n, addr) = self.socket.recv_from(buf_maybe)?;
        let addr = addr
            .as_socket()
            .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidData, "invalid address format"))?;
        Ok((n, addr))
    }

    /// Shuts down the read half, the write half, or both.
    pub fn shutdown(&self, how: std::net::Shutdown) -> io::Result<()> {
        self.socket.shutdown(how)
    }

    /// Sets `SO_REUSEADDR`.
    pub fn set_reuseaddr(&self, on: bool) -> io::Result<()> {
        self.socket.set_reuse_address(on)
    }

    /// Returns the current `SO_REUSEADDR` value.
    pub fn reuseaddr(&self) -> io::Result<bool> {
        self.socket.reuse_address()
    }

    /// Sets `SO_REUSEPORT`.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "ios",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "tvos",
        target_os = "visionos",
        target_os = "watchos"
    ))]
    pub fn set_reuseport(&self, on: bool) -> io::Result<()> {
        self.socket.set_reuse_port(on)
    }

    /// Returns the current `SO_REUSEPORT` value.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "ios",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "tvos",
        target_os = "visionos",
        target_os = "watchos"
    ))]
    pub fn reuseport(&self) -> io::Result<bool> {
        self.socket.reuse_port()
    }

    /// Enables or disables Nagle's algorithm (`TCP_NODELAY`).
    pub fn set_nodelay(&self, on: bool) -> io::Result<()> {
        self.socket.set_nodelay(on)
    }

    /// Returns the current `TCP_NODELAY` value.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.socket.nodelay()
    }

    /// Sets `SO_LINGER`; `None` disables lingering.
    pub fn set_linger(&self, dur: Option<Duration>) -> io::Result<()> {
        self.socket.set_linger(dur)
    }

    /// Sets the IPv4 time-to-live.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.socket.set_ttl(ttl)
    }

    /// Returns the IPv4 time-to-live.
    pub fn ttl(&self) -> io::Result<u32> {
        self.socket.ttl()
    }

    /// Sets the IPv6 unicast hop limit.
    pub fn set_hoplimit(&self, hops: u32) -> io::Result<()> {
        self.socket.set_unicast_hops_v6(hops)
    }

    /// Returns the IPv6 unicast hop limit.
    pub fn hoplimit(&self) -> io::Result<u32> {
        self.socket.unicast_hops_v6()
    }

    /// Sets `SO_KEEPALIVE`.
    pub fn set_keepalive(&self, on: bool) -> io::Result<()> {
        self.socket.set_keepalive(on)
    }

    /// Returns the current `SO_KEEPALIVE` value.
    pub fn keepalive(&self) -> io::Result<bool> {
        self.socket.keepalive()
    }

    /// Sets `SO_RCVBUF`.
    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
        self.socket.set_recv_buffer_size(size)
    }

    /// Returns `SO_RCVBUF`.
    pub fn recv_buffer_size(&self) -> io::Result<usize> {
        self.socket.recv_buffer_size()
    }

    /// Sets `SO_SNDBUF`.
    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
        self.socket.set_send_buffer_size(size)
    }

    /// Returns `SO_SNDBUF`.
    pub fn send_buffer_size(&self) -> io::Result<usize> {
        self.socket.send_buffer_size()
    }

    /// Sets the IPv4 type-of-service field.
    pub fn set_tos(&self, tos: u32) -> io::Result<()> {
        self.socket.set_tos(tos)
    }

    /// Returns the IPv4 type-of-service field.
    pub fn tos(&self) -> io::Result<u32> {
        self.socket.tos()
    }

    /// Sets the IPv6 traffic class.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "ios",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "tvos",
        target_os = "visionos",
        target_os = "watchos"
    ))]
    pub fn set_tclass_v6(&self, tclass: u32) -> io::Result<()> {
        self.socket.set_tclass_v6(tclass)
    }

    /// Returns the IPv6 traffic class.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "ios",
        target_os = "linux",
        target_os = "macos",
        target_os = "netbsd",
        target_os = "openbsd",
        target_os = "tvos",
        target_os = "visionos",
        target_os = "watchos"
    ))]
    pub fn tclass_v6(&self) -> io::Result<u32> {
        self.socket.tclass_v6()
    }

    /// Sets `IPV6_V6ONLY`.
    pub fn set_only_v6(&self, only_v6: bool) -> io::Result<()> {
        self.socket.set_only_v6(only_v6)
    }

    /// Returns the current `IPV6_V6ONLY` value.
    pub fn only_v6(&self) -> io::Result<bool> {
        self.socket.only_v6()
    }

    /// Binds the socket to a network interface by name (`SO_BINDTODEVICE`).
    ///
    /// # Errors
    /// Returns [`io::ErrorKind::Unsupported`] on platforms other than Linux, Android
    /// and Fuchsia.
    pub fn set_bind_device(&self, iface: &str) -> io::Result<()> {
        #[cfg(any(target_os = "linux", target_os = "android", target_os = "fuchsia"))]
        return self.socket.bind_device(Some(iface.as_bytes()));
        #[cfg(not(any(target_os = "linux", target_os = "android", target_os = "fuchsia")))]
        {
            let _ = iface;
            Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "bind_device is not supported on this platform",
            ))
        }
    }

    /// Returns the local IP address and port the socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.socket
            .local_addr()?
            .as_socket()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "failed to retrieve local address"))
    }

    /// Returns the raw file descriptor without transferring ownership.
    #[cfg(unix)]
    pub fn as_raw_fd(&self) -> std::os::unix::io::RawFd {
        self.socket.as_raw_fd()
    }

    /// Returns the raw socket handle without transferring ownership.
    #[cfg(windows)]
    pub fn as_raw_socket(&self) -> std::os::windows::io::RawSocket {
        use std::os::windows::io::AsRawSocket;
        self.socket.as_raw_socket()
    }

    /// Wraps an existing socket.
    ///
    /// The socket's actual blocking mode is not inspected; the wrapper records it as
    /// blocking.
    pub fn from_socket(socket: Socket) -> Self {
        Self {
            socket,
            nonblocking: false,
        }
    }

    /// Borrows the underlying `socket2` socket.
    pub fn socket(&self) -> &Socket {
        &self.socket
    }

    /// Consumes the wrapper and returns the underlying `socket2` socket.
    pub fn into_socket(self) -> Socket {
        self.socket
    }
}
#[cfg(all(test, unix))]
mod tests {
    use super::*;
    use std::net::TcpListener as StdTcpListener;

    /// Reads `O_NONBLOCK` directly from the descriptor, bypassing any cached state.
    fn socket_is_nonblocking(socket: &Socket) -> bool {
        // SAFETY: F_GETFL only reads flags of a descriptor that `socket` keeps open.
        let status = unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_GETFL) };
        assert!(status >= 0, "F_GETFL failed: {}", io::Error::last_os_error());
        status & libc::O_NONBLOCK != 0
    }

    /// A connect that fails immediately (IPv6 target on an IPv4 socket) must leave a
    /// non-blocking socket non-blocking.
    #[test]
    fn connect_timeout_does_not_mutate_original_nonblocking_state_after_invalid_input() {
        let tcp = TcpSocket::v4_stream().expect("socket");
        tcp.socket.set_nonblocking(true).expect("set nonblocking");
        let v6_target: SocketAddr = "[::1]:80".parse().unwrap();
        let outcome = tcp.connect_timeout(v6_target, Duration::from_secs(1));
        assert!(outcome.is_err());
        assert!(socket_is_nonblocking(&tcp.socket));
    }

    /// A successful connect must leave a blocking socket blocking.
    #[test]
    fn connect_timeout_does_not_mutate_original_blocking_state_after_success() {
        let server = StdTcpListener::bind("127.0.0.1:0").expect("listener");
        let server_addr = server.local_addr().expect("local addr");
        let acceptor = std::thread::spawn(move || server.accept().expect("accept"));

        let tcp = TcpSocket::v4_stream().expect("socket");
        tcp.socket.set_nonblocking(false).expect("set blocking");
        // Keep the stream alive until after the assertion.
        let _stream = tcp
            .connect_timeout(server_addr, Duration::from_secs(1))
            .expect("connect");
        assert!(!socket_is_nonblocking(&tcp.socket));

        let _ = acceptor.join();
    }
}