use crate::convert::SocketState;
use socket2::Socket;
use std::{
ffi::c_int,
fs,
mem,
net::{Ipv4Addr, UdpSocket},
io,
os::windows::{
fs::OpenOptionsExt,
io::{AsRawHandle, AsRawSocket},
},
path::Path,
sync::Once,
};
use windows_sys::Win32::{
Foundation::{
HANDLE_FLAG_INHERIT,
INVALID_HANDLE_VALUE,
SetHandleInformation,
},
Networking::WinSock::{
getsockopt,
SO_ACCEPTCONN,
SO_PROTOCOL_INFOW,
SOL_SOCKET,
WSAPROTOCOL_INFOW,
},
Storage::FileSystem::{
FILE_ATTRIBUTE_REPARSE_POINT,
FILE_FLAG_BACKUP_SEMANTICS,
FILE_FLAG_OPEN_REPARSE_POINT,
FILE_ATTRIBUTE_TAG_INFO,
FileAttributeTagInfo,
GetFileInformationByHandleEx,
},
System::Console::{GetStdHandle, STD_INPUT_HANDLE},
System::SystemServices::IO_REPARSE_TAG_AF_UNIX,
};
pub use std::os::windows::io::{
BorrowedSocket,
OwnedSocket,
RawSocket,
};
pub fn make_socket_inheritable(
socket: &Socket,
inheritable: bool,
) -> io::Result<RawSocket> {
let handle = socket.as_raw_socket();
let success = unsafe {
SetHandleInformation(
handle as _,
HANDLE_FLAG_INHERIT,
match inheritable {
true => HANDLE_FLAG_INHERIT,
false => 0,
},
)
};
if success == 0 {
Err(io::Error::last_os_error())
}
else {
Ok(handle)
}
}
pub fn is_unix_socket(path: &Path) -> io::Result<bool> {
let file: fs::File =
fs::OpenOptions::new()
.access_mode(0)
.custom_flags(FILE_FLAG_BACKUP_SEMANTICS | FILE_FLAG_OPEN_REPARSE_POINT)
.open(path)?;
let mut file_attrs: FILE_ATTRIBUTE_TAG_INFO = unsafe {
mem::zeroed()
};
let file_attrs_len = mem::size_of_val(&file_attrs).try_into().unwrap();
let get_result = unsafe {
GetFileInformationByHandleEx(
file.as_raw_handle() as _,
FileAttributeTagInfo,
&mut file_attrs as *mut FILE_ATTRIBUTE_TAG_INFO as *mut _,
file_attrs_len,
)
};
if get_result == 0 {
return Err(io::Error::last_os_error());
}
let is_unix_socket: bool =
(file_attrs.FileAttributes & FILE_ATTRIBUTE_REPARSE_POINT) != 0 &&
file_attrs.ReparseTag == IO_REPARSE_TAG_AF_UNIX;
Ok(is_unix_socket)
}
pub fn startup_socket_api() {
static ONCE: Once = Once::new();
ONCE.call_once(|| {
let _ = UdpSocket::bind((Ipv4Addr::LOCALHOST, 0u16));
});
}
pub fn get_stdin_as_socket() -> io::Result<RawSocket> {
let maybe_socket = unsafe {
GetStdHandle(STD_INPUT_HANDLE)
};
if maybe_socket == INVALID_HANDLE_VALUE {
return Err(io::Error::last_os_error());
}
Ok(maybe_socket as RawSocket)
}
pub(crate) fn get_socket_state(socket: &Socket) -> io::Result<SocketState> {
let mut protocol_info: WSAPROTOCOL_INFOW = unsafe {
mem::zeroed()
};
let mut protocol_info_len: c_int = mem::size_of_val(&protocol_info).try_into().unwrap();
let getsockopt_result = unsafe {
getsockopt(
socket.as_raw_socket() as _,
SOL_SOCKET,
SO_PROTOCOL_INFOW,
&mut protocol_info as *mut WSAPROTOCOL_INFOW as *mut _,
&mut protocol_info_len,
)
};
if getsockopt_result != 0 {
return Err(io::Error::last_os_error());
}
let r#type = socket2::Type::from(protocol_info.iSocketType);
let protocol = Some(socket2::Protocol::from(protocol_info.iProtocol));
let mut is_listening_dword: u32 = 0;
let mut is_listening_dword_len: c_int = mem::size_of_val(&is_listening_dword).try_into().unwrap();
let getsockopt_result = unsafe {
getsockopt(
socket.as_raw_socket() as _,
SOL_SOCKET,
SO_ACCEPTCONN,
&mut is_listening_dword as *mut u32 as *mut _,
&mut is_listening_dword_len,
)
};
if getsockopt_result != 0 {
return Err(io::Error::last_os_error());
}
let is_listening = Some(is_listening_dword != 0);
Ok(SocketState { r#type, protocol, is_listening })
}