use std::{
fmt::Debug,
mem::MaybeUninit,
net::{
Ipv4Addr,
Ipv6Addr,
SocketAddr,
SocketAddrV4,
SocketAddrV6,
},
pin::Pin,
rc::Rc,
time::Duration,
};
use windows::{
core::PSTR,
Win32::{
Foundation::{
BOOL,
ERROR_NOT_FOUND,
FALSE,
HANDLE,
TRUE,
},
Networking::WinSock::{
bind,
closesocket,
listen,
shutdown,
tcp_keepalive,
WSAGetLastError,
WSARecvFrom,
WSASendTo,
FROM_PROTOCOL_INFO,
INVALID_SOCKET,
IPPROTO_TCP,
LINGER,
SD_BOTH,
SIO_KEEPALIVE_VALS,
SOCKADDR,
SOCKADDR_IN,
SOCKADDR_IN6,
SOCKADDR_INET,
SOCKADDR_STORAGE,
SOCKET,
SOCKET_ERROR,
SOL_SOCKET,
SO_KEEPALIVE,
SO_LINGER,
SO_PROTOCOL_INFOW,
SO_UPDATE_ACCEPT_CONTEXT,
SO_UPDATE_CONNECT_CONTEXT,
TCP_NODELAY,
WSABUF,
WSAEINVAL,
WSAPROTOCOL_INFOW,
WSA_FLAG_OVERLAPPED,
},
System::IO::{
CancelIoEx,
OVERLAPPED,
},
},
};
use crate::{
catnap::transport::{
error::{
expect_last_wsa_error,
get_overlapped_api_result,
},
overlapped::{
IoCompletionPort,
OverlappedResult,
},
winsock::{
SocketExtensions,
WinsockRuntime,
},
},
runtime::{
fail::Fail,
memory::DemiBuffer,
network::socket::option::TcpSocketOptions,
},
};
/// Bytes reserved for one address slot in the AcceptEx output buffer.
///
/// The extra 16 bytes follow AcceptEx's documented requirement that each address slot be at least 16 bytes larger
/// than the maximum address length of the transport.
const SOCKADDR_BUF_SIZE: usize = 16 + std::mem::size_of::<SOCKADDR_STORAGE>();
/// Total AcceptEx output buffer size: one slot for the local address followed by one for the remote address.
const ACCEPT_BUFFER_LEN: usize = 2 * SOCKADDR_BUF_SIZE;
/// A Winsock socket driven through overlapped I/O, paired with the Winsock extension function table
/// (AcceptEx, ConnectEx, DisconnectEx, GetAcceptExSockaddrs) used to operate it.
///
/// The underlying handle is closed when the value is dropped.
pub struct Socket {
    /// Raw Winsock socket handle owned by this wrapper.
    s: SOCKET,
    /// Extension function pointers shared with other sockets created from the same runtime.
    extensions: Rc<SocketExtensions>,
}
/// State carried across an overlapped AcceptEx operation.
pub struct AcceptState {
    /// Socket that becomes the accepted connection; set once AcceptEx has been issued successfully and taken when
    /// the accept completes.
    new_socket: Option<Socket>,
    /// AcceptEx output buffer holding the local and remote address slots.
    buffer: [u8; ACCEPT_BUFFER_LEN],
}
/// State carried across an overlapped receive (`WSARecvFrom`) operation.
pub struct PopState {
    /// Buffer that receives the incoming bytes.
    buffer: DemiBuffer,
    /// Storage for the sender's address, written by `WSARecvFrom`.
    address: MaybeUninit<SOCKADDR_STORAGE>,
    /// Length in bytes of the address stored in `address`; passed to `WSARecvFrom` as its address-length argument.
    addr_len: i32,
}
/// Per-operation state that must stay pinned while an overlapped socket operation is in flight.
pub enum SocketOpState {
    /// An AcceptEx operation and its address buffer.
    Accept(AcceptState),
    /// A ConnectEx operation; no extra state is needed.
    Connect,
    /// A receive operation with its data buffer and sender-address storage.
    Pop(PopState),
    /// A send operation and the buffer being sent.
    Push(DemiBuffer),
    /// A disconnect/close operation; no extra state is needed.
    Close,
}
impl AcceptState {
pub fn new() -> Self {
Self {
new_socket: None,
buffer: [0u8; ACCEPT_BUFFER_LEN],
}
}
}
impl PopState {
pub fn new(buffer: DemiBuffer) -> PopState {
Self {
buffer,
address: MaybeUninit::zeroed(),
addr_len: 0,
}
}
}
impl Socket {
    /// Wraps the raw socket `s`, applies `options`, and associates it with `iocp` using completion key 0.
    ///
    /// `protocol` is the protocol the socket was created with. Keepalive and Nagle settings are applied only when
    /// it equals `IPPROTO_TCP`. If any step fails, the error is returned and the wrapper is dropped, which closes `s`.
    pub(super) fn new(
        s: SOCKET,
        protocol: libc::c_int,
        options: &TcpSocketOptions,
        extensions: Rc<SocketExtensions>,
        iocp: &IoCompletionPort<SocketOpState>,
    ) -> Result<Socket, Fail> {
        let s: Socket = Socket { s, extensions };
        s.setup_socket(protocol, options)?;
        iocp.associate_socket(s.s, 0)?;
        Ok(s)
    }

    /// Converts `addr` to a Winsock `SOCKADDR_INET`, paired with the byte length of the family-specific address.
    fn translate_address(addr: SocketAddr) -> (SOCKADDR_INET, i32) {
        match addr {
            SocketAddr::V4(addr) => (addr.into(), std::mem::size_of::<SOCKADDR_IN>() as i32),
            SocketAddr::V6(addr) => (addr.into(), std::mem::size_of::<SOCKADDR_IN6>() as i32),
        }
    }

    /// Applies the linger option and, for TCP sockets, the keepalive and Nagle options from `options`.
    fn setup_socket(&self, protocol: libc::c_int, options: &TcpSocketOptions) -> Result<(), Fail> {
        self.set_linger(options.get_linger())?;
        if protocol == IPPROTO_TCP.0 {
            self.set_tcp_keepalive(&options.get_keepalive())?;
            // `set_nagle` takes "Nagle enabled", which is the inverse of the `nodelay` setting.
            self.set_nagle(!options.get_nodelay())?;
        }
        Ok(())
    }

    /// Sets SO_LINGER. `None` disables lingering; `Some(t)` enables it with a timeout of `t` in whole seconds.
    ///
    /// The timeout saturates at `u16::MAX` seconds, the largest value the `LINGER` structure can hold.
    pub fn set_linger(&self, linger_time: Option<Duration>) -> Result<(), Fail> {
        let l: LINGER = LINGER {
            l_onoff: u16::from(linger_time.is_some()),
            // Saturate rather than truncate with `as`. Wrapping e.g. 65536s to 0s would turn a lingering close into
            // an abortive one (lingering enabled with a zero timeout).
            l_linger: linger_time.map_or(0, |t: Duration| u16::try_from(t.as_secs()).unwrap_or(u16::MAX)),
        };
        unsafe { WinsockRuntime::do_setsockopt(self.s, SOL_SOCKET, SO_LINGER, Some(&l)) }?;
        Ok(())
    }

    /// Reads SO_LINGER. Returns `None` when lingering is disabled, otherwise the timeout.
    pub fn get_linger(&self) -> Result<Option<Duration>, Fail> {
        let l: LINGER = unsafe { WinsockRuntime::do_getsockopt(self.s, SOL_SOCKET, SO_LINGER) }?;
        Ok((l.l_onoff != 0).then(|| Duration::from_secs(l.l_linger.into())))
    }

    /// Returns the (IPv4) address of the connected peer.
    pub fn getpeername(&self) -> Result<SocketAddrV4, Fail> {
        WinsockRuntime::getpeername(self.s)
    }

    /// Enables or disables SO_KEEPALIVE according to `keepalive_params.onoff`.
    ///
    /// When keepalive is enabled, the timing parameters are also applied via the SIO_KEEPALIVE_VALS ioctl.
    pub fn set_tcp_keepalive(&self, keepalive_params: &tcp_keepalive) -> Result<(), Fail> {
        unsafe { WinsockRuntime::do_setsockopt(self.s, SOL_SOCKET, SO_KEEPALIVE, Some(&keepalive_params.onoff)) }?;
        if keepalive_params.onoff != 0 {
            unsafe {
                WinsockRuntime::do_ioctl::<tcp_keepalive, ()>(self.s, SIO_KEEPALIVE_VALS, Some(keepalive_params), None)
            }?;
        }
        Ok(())
    }

    /// Reports whether SO_KEEPALIVE is enabled.
    ///
    /// SO_KEEPALIVE is a BOOL-sized option, so only `onoff` can be read back. The idle time and probe interval
    /// cannot be queried through this option and are reported as 0.
    pub fn get_tcp_keepalive(&self) -> Result<tcp_keepalive, Fail> {
        let enabled: BOOL = unsafe { WinsockRuntime::do_getsockopt(self.s, SOL_SOCKET, SO_KEEPALIVE) }?;
        Ok(tcp_keepalive {
            onoff: u32::from(enabled.as_bool()),
            keepalivetime: 0,
            keepaliveinterval: 0,
        })
    }

    /// Enables (`true`) or disables (`false`) Nagle's algorithm by clearing or setting TCP_NODELAY.
    pub fn set_nagle(&self, enabled: bool) -> Result<(), Fail> {
        let value: BOOL = if enabled { FALSE } else { TRUE };
        unsafe { WinsockRuntime::do_setsockopt(self.s, IPPROTO_TCP.0, TCP_NODELAY, Some(&value)) }?;
        Ok(())
    }

    /// Returns whether Nagle's algorithm is enabled, i.e. whether TCP_NODELAY is cleared.
    ///
    /// This mirrors `set_nagle`: `set_nagle(x)` followed by `get_nagle()` yields `x`.
    pub fn get_nagle(&self) -> Result<bool, Fail> {
        let nodelay: BOOL = unsafe { WinsockRuntime::do_getsockopt(self.s, IPPROTO_TCP.0, TCP_NODELAY) }?;
        Ok(!nodelay.as_bool())
    }

    /// Creates a new overlapped socket with the same protocol info as `template`, sharing its extension table.
    ///
    /// The new socket is neither configured nor associated with an I/O completion port.
    pub fn new_like(template: &Socket) -> Result<Socket, Fail> {
        let protocol: WSAPROTOCOL_INFOW =
            unsafe { WinsockRuntime::do_getsockopt(template.s, SOL_SOCKET, SO_PROTOCOL_INFOW) }?;
        let extensions: Rc<SocketExtensions> = template.extensions.clone();
        let s: SOCKET = unsafe {
            WinsockRuntime::raw_socket(
                FROM_PROTOCOL_INFO,
                FROM_PROTOCOL_INFO,
                FROM_PROTOCOL_INFO,
                Some(&protocol),
                WSA_FLAG_OVERLAPPED,
            )
        }?;
        Ok(Socket { s, extensions })
    }

    /// Issues an overlapped DisconnectEx on this socket using `overlapped`.
    pub fn start_disconnect(&self, overlapped: *mut OVERLAPPED) -> Result<(), Fail> {
        let disconnectex = self.extensions.disconnectex.expect("DisconnectEx extension must be loaded");
        let result: bool = unsafe { disconnectex(self.s, overlapped, 0, 0) }.as_bool();
        get_overlapped_api_result(result)
    }

    /// Completes a disconnect by shutting down both directions.
    ///
    /// A shutdown error takes precedence over the overlapped result.
    pub fn finish_disconnect(&self, result: OverlappedResult) -> Result<(), Fail> {
        let shutdown_result: Result<(), Fail> = self.shutdown();
        shutdown_result.and(result.ok())
    }

    /// Shuts down both send and receive directions (SD_BOTH).
    pub fn shutdown(&self) -> Result<(), Fail> {
        if unsafe { shutdown(self.s, SD_BOTH) } == 0 {
            Ok(())
        } else {
            Err(expect_last_wsa_error().into())
        }
    }

    /// Binds the socket to `local`.
    pub fn bind(&self, local: SocketAddr) -> Result<(), Fail> {
        let sockaddr: socket2::SockAddr = local.into();
        let result: i32 = unsafe { bind(self.s, sockaddr.as_ptr().cast(), sockaddr.len()) };
        if result == 0 {
            Ok(())
        } else {
            Err(expect_last_wsa_error().into())
        }
    }

    /// Marks the socket as listening. `backlog` saturates at `i32::MAX`.
    pub fn listen(&self, backlog: usize) -> Result<(), Fail> {
        let backlog: i32 = i32::try_from(backlog).unwrap_or(i32::MAX);
        if unsafe { listen(self.s, backlog) } == 0 {
            Ok(())
        } else {
            Err(expect_last_wsa_error().into())
        }
    }

    /// Requests cancellation of the overlapped operation identified by `overlapped`.
    ///
    /// When CancelIoEx reports ERROR_NOT_FOUND, this returns an `EINPROGRESS` failure. Other errors are converted
    /// as-is.
    pub fn cancel_io(&self, overlapped: *mut OVERLAPPED) -> Result<(), Fail> {
        unsafe { CancelIoEx(HANDLE(self.s.0 as isize), Some(overlapped)) }.map_err(|win_err| {
            if win_err.code() == ERROR_NOT_FOUND.into() {
                Fail::new(libc::EINPROGRESS, "cannot cancel this operation")
            } else {
                win_err.into()
            }
        })
    }

    /// Issues an overlapped AcceptEx on this listening socket.
    ///
    /// A fresh socket is created with `new_like`. It is stored in the accept state only when
    /// `get_overlapped_api_result` reports success; otherwise it is dropped (and closed) here.
    pub fn start_accept(&self, state: Pin<&mut SocketOpState>, overlapped: *mut OVERLAPPED) -> Result<(), Fail> {
        let accept_state: &mut AcceptState = match state.get_mut() {
            SocketOpState::Accept(accept_state) => accept_state,
            _ => unreachable!("must be an accept operation"),
        };
        let new_socket: Socket = Socket::new_like(self)?;
        let acceptex = self.extensions.acceptex.expect("AcceptEx extension must be loaded");
        let mut bytes_out: u32 = 0;
        let success: bool = unsafe {
            acceptex(
                self.s,
                new_socket.s,
                accept_state.buffer.as_mut_ptr().cast(),
                0,
                SOCKADDR_BUF_SIZE as u32,
                SOCKADDR_BUF_SIZE as u32,
                &mut bytes_out,
                overlapped,
            )
        }
        .as_bool();
        get_overlapped_api_result(success)?;
        accept_state.new_socket = Some(new_socket);
        Ok(())
    }

    /// Completes an accept.
    ///
    /// Steps:
    /// - Checks the overlapped result.
    /// - Updates the accept context of the new socket.
    /// - Associates the new socket with `iocp`.
    /// - Extracts the local and remote addresses from the AcceptEx buffer.
    ///
    /// Returns `(new_socket, local_addr, remote_addr)`.
    pub fn finish_accept(
        &self,
        state: Pin<&mut SocketOpState>,
        iocp: &IoCompletionPort<SocketOpState>,
        result: OverlappedResult,
    ) -> Result<(Socket, SocketAddr, SocketAddr), Fail> {
        result.ok()?;
        let accept_state: &mut AcceptState = match state.get_mut() {
            SocketOpState::Accept(accept_state) => accept_state,
            _ => unreachable!("must be an accept operation"),
        };
        let new_socket: Socket = accept_state
            .new_socket
            .take()
            .ok_or_else(|| Fail::new(libc::EINVAL, "invalid state"))?;
        unsafe { WinsockRuntime::do_setsockopt(new_socket.s, SOL_SOCKET, SO_UPDATE_ACCEPT_CONTEXT, Some(&self.s)) }?;
        iocp.associate_socket(new_socket.s, 0)?;

        let get_acceptex_sockaddrs = self
            .extensions
            .get_acceptex_sockaddrs
            .expect("GetAcceptExSockaddrs extension must be loaded");
        let mut local_ptr: *mut windows_sys::Win32::Networking::WinSock::SOCKADDR_STORAGE = std::ptr::null_mut();
        let mut local_len: i32 = 0;
        let mut remote_ptr: *mut windows_sys::Win32::Networking::WinSock::SOCKADDR_STORAGE = std::ptr::null_mut();
        let mut remote_len: i32 = 0;
        unsafe {
            get_acceptex_sockaddrs(
                accept_state.buffer.as_ptr().cast(),
                0,
                SOCKADDR_BUF_SIZE as u32,
                SOCKADDR_BUF_SIZE as u32,
                std::ptr::addr_of_mut!(local_ptr).cast(),
                &mut local_len,
                std::ptr::addr_of_mut!(remote_ptr).cast(),
                &mut remote_len,
            )
        };
        if local_ptr.is_null() || remote_ptr.is_null() {
            return Err(Fail::new(libc::EINVAL, "accept did not report socket addresses"));
        }

        // SAFETY: both pointers are non-null and were produced by GetAcceptExSockaddrs for the reported lengths.
        let local_addr: SocketAddr = unsafe { Self::copy_acceptex_sockaddr(local_ptr, local_len) }
            .as_socket()
            .ok_or_else(|| Fail::new(libc::EAFNOSUPPORT, "bad local socket address from accept"))?;
        let remote_addr: SocketAddr = unsafe { Self::copy_acceptex_sockaddr(remote_ptr, remote_len) }
            .as_socket()
            .ok_or_else(|| Fail::new(libc::EAFNOSUPPORT, "bad remote socket address from accept"))?;
        Ok((new_socket, local_addr, remote_addr))
    }

    /// Copies a socket address reported by GetAcceptExSockaddrs into an owned `socket2::SockAddr`.
    ///
    /// The copy is byte-wise and limited to `len` bytes (clamped to the size of `SOCKADDR_STORAGE`). The source
    /// pointer therefore need not be aligned, and nothing past the reported address is read.
    ///
    /// # Safety
    /// `ptr` must be non-null and valid for reads of `len` bytes.
    unsafe fn copy_acceptex_sockaddr(
        ptr: *const windows_sys::Win32::Networking::WinSock::SOCKADDR_STORAGE,
        len: i32,
    ) -> socket2::SockAddr {
        let mut storage: windows_sys::Win32::Networking::WinSock::SOCKADDR_STORAGE = unsafe { std::mem::zeroed() };
        let copy_len: usize = usize::try_from(len).unwrap_or(0).min(std::mem::size_of_val(&storage));
        unsafe {
            std::ptr::copy_nonoverlapping(
                ptr.cast::<u8>(),
                std::ptr::addr_of_mut!(storage).cast::<u8>(),
                copy_len,
            );
            socket2::SockAddr::new(storage, copy_len as i32)
        }
    }

    /// Issues an overlapped ConnectEx to `remote`.
    ///
    /// ConnectEx operates on a bound socket, so the socket is first bound to the wildcard address of the matching
    /// family. A WSAEINVAL from that bind (the socket is already bound) is tolerated.
    pub fn start_connect(&self, remote: SocketAddr, overlapped: *mut OVERLAPPED) -> Result<(), Fail> {
        const IN6ADDR_ANY: SocketAddrV6 = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, 0, 0, 0);
        const INADDR_ANY: SocketAddrV4 = SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, 0);
        let (remote_addr, remote_len): (SOCKADDR_INET, i32) = Self::translate_address(remote);
        let (local_addr, local_len): (SOCKADDR_INET, i32) = match remote {
            SocketAddr::V4(_) => Self::translate_address(INADDR_ANY.into()),
            SocketAddr::V6(_) => Self::translate_address(IN6ADDR_ANY.into()),
        };
        if unsafe { bind(self.s, (&local_addr as *const SOCKADDR_INET).cast(), local_len) } != 0
            && unsafe { WSAGetLastError() } != WSAEINVAL
        {
            return Err(expect_last_wsa_error());
        }
        let connectex = self.extensions.connectex.expect("ConnectEx extension must be loaded");
        let success: bool = unsafe {
            connectex(
                self.s,
                (&remote_addr as *const SOCKADDR_INET).cast(),
                remote_len,
                std::ptr::null(),
                0,
                std::ptr::null_mut(),
                overlapped,
            )
        }
        .as_bool();
        get_overlapped_api_result(success)
    }

    /// Completes a connect by checking the overlapped result and updating the connect context of the socket.
    pub fn finish_connect(&self, result: OverlappedResult) -> Result<(), Fail> {
        result.ok()?;
        unsafe { WinsockRuntime::do_setsockopt::<()>(self.s, SOL_SOCKET, SO_UPDATE_CONNECT_CONTEXT, None) }
    }

    /// Issues an overlapped WSARecvFrom into the pop state's buffer. The sender address goes into its address storage.
    pub fn start_pop(&self, state: Pin<&mut SocketOpState>, overlapped: *mut OVERLAPPED) -> Result<(), Fail> {
        let pop_state: &mut PopState = match state.get_mut() {
            SocketOpState::Pop(pop_state) => pop_state,
            _ => unreachable!("must be a pop operation"),
        };
        // WSARecvFrom treats the address length as in/out. On input it must hold the capacity of the address
        // storage, otherwise datagram receives fail with WSAEFAULT.
        pop_state.addr_len = std::mem::size_of::<SOCKADDR_STORAGE>() as i32;
        let mut bytes_transferred: u32 = 0;
        let mut flags: u32 = 0;
        let success: bool = unsafe {
            let wsa_buffer: WSABUF = WSABUF {
                len: u32::try_from(pop_state.buffer.len()).unwrap_or(u32::MAX),
                buf: PSTR::from_raw(pop_state.buffer.as_mut_ptr()),
            };
            let result: i32 = WSARecvFrom(
                self.s,
                std::slice::from_ref(&wsa_buffer),
                Some(&mut bytes_transferred),
                &mut flags,
                Some(pop_state.address.as_mut_ptr() as *mut SOCKADDR),
                Some(&mut pop_state.addr_len),
                Some(overlapped),
                None,
            );
            result != SOCKET_ERROR
        };
        get_overlapped_api_result(success)
    }

    /// Completes a receive.
    ///
    /// Returns the number of bytes received, plus the sender address when one was reported and is a valid IPv4/IPv6
    /// address.
    pub fn finish_pop(
        &self,
        state: Pin<&mut SocketOpState>,
        result: OverlappedResult,
    ) -> Result<(usize, Option<SocketAddr>), Fail> {
        result.ok()?;
        let pop_state: &mut PopState = match state.get_mut() {
            SocketOpState::Pop(pop_state) => pop_state,
            _ => unreachable!("must be a pop operation"),
        };
        let addr: Option<SocketAddr> = if pop_state.addr_len > 0 {
            // SAFETY: `address` is zero-initialized when the state is created, so it is always initialized. The
            // `windows` and `windows_sys` SOCKADDR_STORAGE types describe the same C layout.
            unsafe {
                socket2::SockAddr::new(
                    std::mem::transmute(std::mem::take(pop_state.address.assume_init_mut())),
                    pop_state.addr_len,
                )
            }
            .as_socket()
        } else {
            None
        };
        Ok((result.bytes_transferred as usize, addr))
    }

    /// Issues an overlapped WSASendTo of the push buffer, to `addr` when given (e.g. for datagram sockets).
    pub fn start_push(
        &self,
        state: Pin<&mut SocketOpState>,
        addr: Option<SocketAddr>,
        overlapped: *mut OVERLAPPED,
    ) -> Result<(), Fail> {
        let buffer: &mut DemiBuffer = match state.get_mut() {
            SocketOpState::Push(buffer) => buffer,
            _ => unreachable!("must be a push operation"),
        };
        let addr: Option<socket2::SockAddr> = addr.map(socket2::SockAddr::from);
        let addr_ptr: Option<*const SOCKADDR> = addr
            .as_ref()
            .map(|addr: &socket2::SockAddr| -> *const SOCKADDR { addr.as_ptr().cast() });
        let addr_len: i32 = addr
            .as_ref()
            .map_or(0, |addr: &socket2::SockAddr| -> i32 { addr.len() });
        let mut bytes_transferred: u32 = 0;
        let success: bool = unsafe {
            let wsa_buffer: WSABUF = WSABUF {
                len: u32::try_from(buffer.len()).unwrap_or(u32::MAX),
                buf: PSTR::from_raw(buffer.as_mut_ptr()),
            };
            let result: i32 = WSASendTo(
                self.s,
                std::slice::from_ref(&wsa_buffer),
                Some(&mut bytes_transferred),
                0,
                addr_ptr,
                addr_len,
                Some(overlapped),
                None,
            );
            result != SOCKET_ERROR
        };
        get_overlapped_api_result(success)
    }

    /// Completes a send and returns the number of bytes transferred.
    pub fn finish_push(&self, result: OverlappedResult) -> Result<usize, Fail> {
        result.ok()?;
        Ok(result.bytes_transferred as usize)
    }
}
impl Drop for Socket {
fn drop(&mut self) {
unsafe { closesocket(self.s) };
self.s = INVALID_SOCKET;
}
}
impl Debug for Socket {
    /// Formats as `Socket { s: .. }`, exposing only the raw handle (the extension table is omitted).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder: std::fmt::DebugStruct<'_, '_> = f.debug_struct("Socket");
        builder.field("s", &self.s);
        builder.finish()
    }
}