use std::io::{self, Read, Write};
use std::mem::{self, MaybeUninit};
use std::net::TcpListener;
use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
use std::sync::Once;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use windows_sys::Win32::Networking::WinSock::{
INVALID_SOCKET, SOCKET, WSA_FLAG_OVERLAPPED, WSADATA, WSADuplicateSocketW, WSAGetLastError,
WSAPROTOCOL_INFOW, WSASocketW, WSAStartup,
};
/// Calls `WSAStartup` requesting Winsock 2.2, at most once per process.
///
/// # Panics
///
/// Panics if `WSAStartup` reports a non-zero error code. Because the work runs
/// inside a [`Once`], a failed first attempt poisons it, and later calls panic too.
fn ensure_winsock_initialized() {
    static WINSOCK_STARTUP: Once = Once::new();
    WINSOCK_STARTUP.call_once(|| {
        let mut startup_data = MaybeUninit::<WSADATA>::uninit();
        // SAFETY: `startup_data` provides writable storage for one WSADATA, which
        // WSAStartup fills in. The contents are never read afterwards.
        match unsafe { WSAStartup(0x0202, startup_data.as_mut_ptr()) } {
            0 => {}
            code => panic!("WSAStartup failed with error: {}", code),
        }
    });
}
const PROTOCOL_INFO_SIZE: usize = mem::size_of::<WSAPROTOCOL_INFOW>();
pub async fn send_tcp_listener<S>(
stream: &mut S,
listener: &TcpListener,
target_pid: u32,
) -> io::Result<()>
where
S: AsyncWriteExt + Unpin,
{
let socket = listener.as_raw_socket() as SOCKET;
let protocol_info = duplicate_socket(socket, target_pid)?;
let bytes = unsafe {
std::slice::from_raw_parts(&protocol_info as *const _ as *const u8, PROTOCOL_INFO_SIZE)
};
stream.write_all(bytes).await?;
Ok(())
}
/// Reads one serialized `WSAPROTOCOL_INFOW` (`PROTOCOL_INFO_SIZE` bytes) from
/// `stream` and turns it into an owned [`TcpListener`] in this process.
///
/// # Errors
///
/// Returns any error from reading `stream`, including `UnexpectedEof` if fewer
/// than `PROTOCOL_INFO_SIZE` bytes arrive. Also returns the Winsock error if the
/// socket cannot be created from the received protocol info.
pub async fn recv_tcp_listener<S>(stream: &mut S) -> io::Result<TcpListener>
where
    S: AsyncReadExt + Unpin,
{
    let mut buf = [0u8; PROTOCOL_INFO_SIZE];
    stream.read_exact(&mut buf).await?;
    // SAFETY: `buf` holds exactly size_of::<WSAPROTOCOL_INFOW>() bytes. A byte
    // array only guarantees alignment 1, so an aligned `ptr::read` would be UB;
    // `read_unaligned` has no alignment requirement. WSAPROTOCOL_INFOW is a plain
    // Win32 struct of integer fields and integer arrays, so any bytes form a valid value.
    let protocol_info: WSAPROTOCOL_INFOW =
        unsafe { std::ptr::read_unaligned(buf.as_ptr().cast::<WSAPROTOCOL_INFOW>()) };
    let socket = create_socket_from_info(&protocol_info)?;
    // SAFETY: `socket` was just returned by WSASocketW and has no other owner;
    // the TcpListener takes ownership and closes it on drop.
    Ok(unsafe { TcpListener::from_raw_socket(socket as RawSocket) })
}
pub fn send_tcp_listener_sync<S>(
stream: &mut S,
listener: &TcpListener,
target_pid: u32,
) -> io::Result<()>
where
S: Write,
{
let socket = listener.as_raw_socket() as SOCKET;
let protocol_info = duplicate_socket(socket, target_pid)?;
let bytes = unsafe {
std::slice::from_raw_parts(&protocol_info as *const _ as *const u8, PROTOCOL_INFO_SIZE)
};
stream.write_all(bytes)?;
Ok(())
}
/// Blocking counterpart of `recv_tcp_listener`.
///
/// Reads one serialized `WSAPROTOCOL_INFOW` (`PROTOCOL_INFO_SIZE` bytes) from
/// `stream` and turns it into an owned [`TcpListener`] in this process.
///
/// # Errors
///
/// Returns any error from reading `stream`, including `UnexpectedEof` if fewer
/// than `PROTOCOL_INFO_SIZE` bytes are available. Also returns the Winsock error
/// if the socket cannot be created from the received protocol info.
pub fn recv_tcp_listener_sync<S>(stream: &mut S) -> io::Result<TcpListener>
where
    S: Read,
{
    let mut buf = [0u8; PROTOCOL_INFO_SIZE];
    stream.read_exact(&mut buf)?;
    // SAFETY: `buf` holds exactly size_of::<WSAPROTOCOL_INFOW>() bytes. A byte
    // array only guarantees alignment 1, so an aligned `ptr::read` would be UB;
    // `read_unaligned` has no alignment requirement. WSAPROTOCOL_INFOW is a plain
    // Win32 struct of integer fields and integer arrays, so any bytes form a valid value.
    let protocol_info: WSAPROTOCOL_INFOW =
        unsafe { std::ptr::read_unaligned(buf.as_ptr().cast::<WSAPROTOCOL_INFOW>()) };
    let socket = create_socket_from_info(&protocol_info)?;
    // SAFETY: `socket` was just returned by WSASocketW and has no other owner;
    // the TcpListener takes ownership and closes it on drop.
    Ok(unsafe { TcpListener::from_raw_socket(socket as RawSocket) })
}
/// Asks Winsock to prepare `socket` for use by process `target_pid`. Returns the
/// filled-in `WSAPROTOCOL_INFOW` that the target needs in order to open the socket.
///
/// # Errors
///
/// If `WSADuplicateSocketW` reports failure, returns the code from `WSAGetLastError`.
fn duplicate_socket(socket: SOCKET, target_pid: u32) -> io::Result<WSAPROTOCOL_INFOW> {
    ensure_winsock_initialized();
    let mut info = MaybeUninit::<WSAPROTOCOL_INFOW>::uninit();
    // SAFETY: `info` provides writable storage for one WSAPROTOCOL_INFOW.
    let status = unsafe { WSADuplicateSocketW(socket, target_pid, info.as_mut_ptr()) };
    if status == 0 {
        // SAFETY: a zero return means WSADuplicateSocketW filled in the structure.
        Ok(unsafe { info.assume_init() })
    } else {
        // SAFETY: WSAGetLastError has no preconditions.
        let code = unsafe { WSAGetLastError() };
        Err(io::Error::from_raw_os_error(code))
    }
}
/// Opens an overlapped socket in this process from `protocol_info`. The address
/// family, socket type and protocol are taken from the struct itself.
///
/// # Errors
///
/// If `WSASocketW` returns `INVALID_SOCKET`, returns the code from `WSAGetLastError`.
fn create_socket_from_info(protocol_info: &WSAPROTOCOL_INFOW) -> io::Result<SOCKET> {
    ensure_winsock_initialized();
    let (family, socket_type, protocol) = (
        protocol_info.iAddressFamily,
        protocol_info.iSocketType,
        protocol_info.iProtocol,
    );
    // SAFETY: `protocol_info` is a valid reference for the duration of the call.
    // No socket group is requested (g = 0).
    let handle = unsafe {
        WSASocketW(
            family,
            socket_type,
            protocol,
            protocol_info as *const _ as *mut _,
            0,
            WSA_FLAG_OVERLAPPED,
        )
    };
    if handle != INVALID_SOCKET {
        return Ok(handle);
    }
    // SAFETY: WSAGetLastError has no preconditions.
    let code = unsafe { WSAGetLastError() };
    Err(io::Error::from_raw_os_error(code))
}
/// Returns the OS-assigned identifier of the calling process.
///
/// The test suite passes this value as `target_pid` when sending a listener to
/// the same process.
pub fn current_pid() -> u32 {
    // The standard library wraps GetCurrentProcessId on Windows, so no unsafe FFI
    // call is needed here.
    std::process::id()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_protocol_info_size() {
        // The wire payload is one fixed-size struct; keep a sanity bound on it.
        assert!(PROTOCOL_INFO_SIZE > 0);
        assert!(PROTOCOL_INFO_SIZE < 1024);
    }

    #[test]
    fn test_current_pid() {
        let pid = current_pid();
        assert!(pid > 0);
    }

    #[test]
    fn test_roundtrip_same_process() {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("local_addr");
        let mut buffer = Vec::new();
        let pid = current_pid();
        send_tcp_listener_sync(&mut buffer, &listener, pid).expect("send");
        // Exactly one serialized WSAPROTOCOL_INFOW goes on the wire.
        assert_eq!(buffer.len(), PROTOCOL_INFO_SIZE);
        let mut cursor = Cursor::new(buffer);
        let received = recv_tcp_listener_sync(&mut cursor).expect("recv");
        let received_addr = received.local_addr().expect("received local_addr");
        assert_eq!(addr, received_addr);
        // The receiver owns a separate descriptor for the same underlying socket.
        assert_ne!(listener.as_raw_socket(), received.as_raw_socket());
    }

    #[test]
    fn test_recv_truncated_input_is_error() {
        let mut cursor = Cursor::new(vec![0u8; PROTOCOL_INFO_SIZE - 1]);
        let err = recv_tcp_listener_sync(&mut cursor).expect_err("truncated payload must fail");
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
    }
}