use std::{
net::{
Ipv4Addr,
SocketAddrV4,
},
collections::HashMap,
mem,
mem::MaybeUninit,
rc::{
Rc,
Weak,
},
};
use crate::{
catnap::transport::{
error::expect_last_wsa_error,
overlapped::IoCompletionPort,
socket::{
Socket,
SocketOpState,
},
},
runtime::{
fail::Fail,
network::socket::option::TcpSocketOptions,
},
};
use windows::{
core::{
GUID,
PSTR,
},
Win32::Networking::WinSock::{
closesocket,
getsockopt,
setsockopt,
getpeername,
WSACleanup,
WSAIoctl,
WSASocketW,
WSAStartup,
INVALID_SOCKET,
LPFN_ACCEPTEX,
LPFN_CONNECTEX,
LPFN_DISCONNECTEX,
LPFN_GETACCEPTEXSOCKADDRS,
RIO_EXTENSION_FUNCTION_TABLE,
SIO_GET_EXTENSION_FUNCTION_POINTER,
SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
SOCKET,
SOCKADDR,
SOCKADDR_IN,
IN_ADDR_0_0,
SOL_SOCKET,
SO_PROTOCOL_INFOW,
WSADATA,
WSAID_ACCEPTEX,
WSAID_CONNECTEX,
WSAID_DISCONNECTEX,
WSAID_GETACCEPTEXSOCKADDRS,
WSAPROTOCOL_INFOW,
WSA_FLAG_OVERLAPPED,
},
};
const WSAID_MULTIPLE_RIO: ::windows::core::GUID =
::windows::core::GUID::from_u128(0x8509e081_96dd_4005_b165_9e2ee8c79e3f);
/// Winsock extension entry points resolved for one socket provider.
///
/// Populated by `SocketExtensions::new`, which looks up each function pointer via
/// `SIO_GET_EXTENSION_FUNCTION_POINTER` and the RIO table via
/// `SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER`.
#[derive(Default, Clone, Copy)]
pub(super) struct SocketExtensions {
// `AcceptEx` extension function (looked up with `WSAID_ACCEPTEX`).
pub acceptex: LPFN_ACCEPTEX,
// `GetAcceptExSockaddrs` extension function (looked up with `WSAID_GETACCEPTEXSOCKADDRS`).
pub get_acceptex_sockaddrs: LPFN_GETACCEPTEXSOCKADDRS,
// `ConnectEx` extension function (looked up with `WSAID_CONNECTEX`).
pub connectex: LPFN_CONNECTEX,
// `DisconnectEx` extension function (looked up with `WSAID_DISCONNECTEX`).
pub disconnectex: LPFN_DISCONNECTEX,
// Registered I/O function table. It is resolved eagerly but not read anywhere in this chunk,
// which is presumably why the unused lint is silenced; confirm before removing the allow.
#[allow(unused)]
pub rio_fns: RIO_EXTENSION_FUNCTION_TABLE,
}
/// Holds the Winsock initialization performed by `WinsockRuntime::new` (undone by `WSACleanup` in
/// its `Drop` impl) and caches extension function tables per Winsock provider.
pub struct WinsockRuntime {
// Keyed by `WSAPROTOCOL_INFOW::ProviderId`. The values are `Weak`, so a provider's table is freed once
// every `Rc` handed out by `get_or_init_extensions` is dropped; it is re-resolved on the next request.
extensions_by_provider: HashMap<GUID, Weak<SocketExtensions>>,
}
impl SocketExtensions {
pub fn new(s: SOCKET) -> Result<Rc<SocketExtensions>, Fail> {
Ok(Rc::new(SocketExtensions {
acceptex: Self::lookup_single_fn(s, &WSAID_ACCEPTEX)?,
get_acceptex_sockaddrs: Self::lookup_single_fn(s, &WSAID_GETACCEPTEXSOCKADDRS)?,
connectex: Self::lookup_single_fn(s, &WSAID_CONNECTEX)?,
disconnectex: Self::lookup_single_fn(s, &WSAID_DISCONNECTEX)?,
rio_fns: Self::resolve_rio_fn_table(s)?,
}))
}
fn resolve_rio_fn_table(s: SOCKET) -> Result<RIO_EXTENSION_FUNCTION_TABLE, Fail> {
let mut result: RIO_EXTENSION_FUNCTION_TABLE = RIO_EXTENSION_FUNCTION_TABLE::default();
result.cbSize = std::mem::size_of::<RIO_EXTENSION_FUNCTION_TABLE>() as u32;
unsafe {
WinsockRuntime::do_ioctl(
s,
SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
Some(&WSAID_MULTIPLE_RIO),
Some(&mut result),
)
}?;
if result.cbSize != std::mem::size_of::<RIO_EXTENSION_FUNCTION_TABLE>() as u32 {
Err(Fail::new(
libc::EFAULT,
"Winsock did not return enough data for RIO_EXTENSION_FUNCTION_TABLE",
))
} else {
Ok(result)
}
}
fn lookup_single_fn<T>(s: SOCKET, guid: &GUID) -> Result<Option<T>, Fail> {
let mut fn_ptr: Option<T> = None;
unsafe {
WinsockRuntime::do_ioctl(s, SIO_GET_EXTENSION_FUNCTION_POINTER, Some(guid), Some(&mut fn_ptr)).map_err(
|err| {
let msg: String = format!("{} for function lookup {:?}", err.cause, guid);
Fail::new(err.errno, msg.as_str())
},
)?;
}
if fn_ptr.is_some() {
Ok(fn_ptr)
} else {
Err(Fail::new(libc::ENOTSUP, "Winsock extension not supported"))
}
}
}
impl WinsockRuntime {
pub fn new() -> Result<Self, Fail> {
let mut data: WSADATA = WSADATA::default();
if unsafe { WSAStartup(0x202u16, &mut data as *mut WSADATA) } != 0 {
return Err(expect_last_wsa_error());
}
Ok(WinsockRuntime {
extensions_by_provider: HashMap::new(),
})
}
pub(super) unsafe fn do_ioctl<T, U>(
s: SOCKET,
control_code: u32,
input: Option<&T>,
output: Option<&mut U>,
) -> Result<(), Fail> {
let input: Option<*const libc::c_void> = input.map(|t: &T| -> *const libc::c_void { (t as *const T).cast() });
let input_size: usize = input.map(|_| std::mem::size_of::<T>()).unwrap_or(0);
let output: Option<*mut libc::c_void> = output.map(|u: &mut U| -> *mut libc::c_void { (u as *mut U).cast() });
let output_size: usize = output.map(|_| std::mem::size_of::<U>()).unwrap_or(0);
if input_size > u32::MAX as usize {
return Err(Fail::new(
libc::E2BIG,
"\"input_size\" parameter to WSAIoctl parameter is too big",
));
}
if output_size > u32::MAX as usize {
return Err(Fail::new(
libc::E2BIG,
"\"output_size\" parameter to WSAIoctl parameter is too big",
));
}
let mut bytes_returned: u32 = 0;
let ret: i32 = unsafe {
WSAIoctl(
s,
control_code,
input,
input_size as u32,
output,
output_size as u32,
&mut bytes_returned,
None,
None,
)
};
if ret == 0 {
if bytes_returned == output_size as u32 {
Ok(())
} else {
let s: String = format!("WSAIoctl returned {} bytes; expected {}", bytes_returned, output_size);
Err(Fail::new(libc::EFAULT, s.as_str()))
}
} else {
Err(expect_last_wsa_error())
}
}
#[allow(unused)]
pub unsafe fn ioctl<T, U>(
&self,
s: SOCKET,
control_code: u32,
input: Option<&T>,
output: Option<&mut U>,
) -> Result<(), Fail> {
Self::do_ioctl(s, control_code, input, output)
}
pub(super) unsafe fn do_setsockopt<'a, T>(s: SOCKET, level: i32, opt: i32, val: Option<&'a T>) -> Result<(), Fail> {
let val: Option<&'a [u8]> = match val {
Some(val) => {
Some(unsafe { std::slice::from_raw_parts((val as *const T).cast(), std::mem::size_of::<T>()) })
},
None => None,
};
if unsafe { setsockopt(s, level, opt, val) } == 0 {
Ok(())
} else {
Err(expect_last_wsa_error())
}
}
#[allow(unused)]
pub unsafe fn setsockopt<'a, T>(&self, s: SOCKET, level: i32, opt: i32, val: Option<&'a T>) -> Result<(), Fail> {
Self::do_setsockopt(s, level, opt, val)
}
pub(super) unsafe fn do_getsockopt<T>(s: SOCKET, level: i32, optname: i32) -> Result<T, Fail> {
let mut out: MaybeUninit<T> = MaybeUninit::zeroed();
let optval: PSTR = PSTR::from_raw(out.as_mut_ptr().cast());
let mut optlen: i32 =
i32::try_from(std::mem::size_of::<T>()).map_err(|_| Fail::new(libc::E2BIG, "option type too large"))?;
if unsafe { getsockopt(s, level, optname, optval, &mut optlen) } == 0 {
Ok(unsafe { out.assume_init() })
} else {
Err(expect_last_wsa_error())
}
}
pub unsafe fn getsockopt<T>(&self, s: SOCKET, level: i32, optname: i32) -> Result<T, Fail> {
Self::do_getsockopt(s, level, optname)
}
pub fn getpeername(s: SOCKET) -> Result<SocketAddrV4, Fail> {
let mut sockaddr_in: SOCKADDR_IN = SOCKADDR_IN::default();
let sockaddr_ptr: &mut SOCKADDR = &mut unsafe { mem::transmute::<SOCKADDR_IN, SOCKADDR>(sockaddr_in) };
let mut namelen: i32 = std::mem::size_of::<SOCKADDR>() as i32;
if unsafe { getpeername(s, sockaddr_ptr, &mut namelen) } == 0 {
sockaddr_in = unsafe { mem::transmute::<SOCKADDR, SOCKADDR_IN>(*sockaddr_ptr) };
let port: u16 = sockaddr_in.sin_port;
let addr: IN_ADDR_0_0 = unsafe { sockaddr_in.sin_addr.S_un.S_un_b };
let addrv4: SocketAddrV4 = SocketAddrV4::new(
Ipv4Addr::new(addr.s_b1, addr.s_b2, addr.s_b3, addr.s_b4),
port);
Ok(addrv4)
} else {
Err(expect_last_wsa_error())
}
}
fn get_or_init_extensions(&mut self, s: SOCKET) -> Result<Rc<SocketExtensions>, Fail> {
let protocol: WSAPROTOCOL_INFOW = unsafe { self.getsockopt(s, SOL_SOCKET, SO_PROTOCOL_INFOW) }?;
let extensions: &mut Weak<SocketExtensions> = self
.extensions_by_provider
.entry(protocol.ProviderId)
.or_insert(Weak::default());
if let Some(extensions) = extensions.upgrade() {
return Ok(extensions);
}
let new_extensions: Rc<SocketExtensions> = SocketExtensions::new(s)?;
*extensions = Rc::downgrade(&new_extensions);
Ok(new_extensions)
}
pub(super) unsafe fn raw_socket(
domain: libc::c_int,
typ: libc::c_int,
protocol: libc::c_int,
protocol_info: Option<&WSAPROTOCOL_INFOW>,
flags: u32,
) -> Result<SOCKET, Fail> {
let protocol_info: Option<*const WSAPROTOCOL_INFOW> = protocol_info.map(|i| i as *const WSAPROTOCOL_INFOW);
match unsafe { WSASocketW(domain, typ, protocol, protocol_info, 0, flags) } {
INVALID_SOCKET => Err(expect_last_wsa_error()),
socket => Ok(socket),
}
}
pub fn socket(
&mut self,
domain: libc::c_int,
typ: libc::c_int,
protocol: libc::c_int,
options: &TcpSocketOptions,
iocp: &IoCompletionPort<SocketOpState>,
) -> Result<Socket, Fail> {
let s: SOCKET = unsafe { Self::raw_socket(domain, typ, protocol, None, WSA_FLAG_OVERLAPPED) }?;
self.get_or_init_extensions(s)
.and_then(|extensions: Rc<SocketExtensions>| Socket::new(s, protocol, options, extensions, iocp))
.or_else(|err: Fail| {
unsafe { closesocket(s) };
Err(err)
})
}
}
impl Drop for WinsockRuntime {
    /// Discards cached extension handles, then balances the `WSAStartup` performed in
    /// [`WinsockRuntime::new`].
    fn drop(&mut self) {
        // Empty the cache first so no entries remain once Winsock is torn down.
        self.extensions_by_provider.clear();
        // `drop` cannot report failure, so the WSACleanup status is deliberately discarded.
        let _ = unsafe { WSACleanup() };
    }
}