use socket2::{Domain, Protocol, Socket, Type};
use std::{
net::{Ipv4Addr, SocketAddr},
os::windows::prelude::AsRawSocket,
ptr,
};
use windows::{
core::{PCSTR, PSTR},
Win32::Networking::WinSock::{
gethostbyname, gethostname, WSAData, WSAGetLastError, WSAIoctl, WSAStartup, IN_ADDR,
RCVALL_ON, SIO_RCVALL, SOCKET, SOCKET_ERROR, WSA_ERROR,
},
};
pub fn recv_all_socket() -> Result<Socket, SocketError> {
type E = SocketError;
const WINSOCK_VERSION: u16 = 2 << 8 | 2;
let mut wsa_data = WSAData::default();
unsafe { WSAStartup(WINSOCK_VERSION, &mut wsa_data as *mut _) };
let mut hostname = [0u8; 100];
if unsafe { gethostname(PSTR(hostname.as_mut_ptr()), hostname.len() as i32) } == SOCKET_ERROR {
return Err(E::win_sock("failed to get hostname"));
}
let local = unsafe { gethostbyname(PCSTR(hostname.as_mut_ptr())) };
if local.is_null() {
return Err(E::win_sock("failed to get local address"));
}
let h_addr = unsafe { *((*local).h_addr_list) };
if h_addr.is_null() {
return Err(E::win_sock("failed to find host"));
}
let ip_addr = unsafe { (*(h_addr as *const IN_ADDR)).S_un.S_addr.to_be() };
let addr = SocketAddr::new(Ipv4Addr::from(ip_addr).into(), 0);
let socket = Socket::new(Domain::IPV4, Type::RAW, None)?;
{
let socket0 = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::UDP))?;
socket0.bind(&addr.into())?;
}
socket.bind(&addr.into())?;
let win_sock = SOCKET(socket.as_raw_socket() as usize);
let mut in_ = 0u32;
let res = unsafe {
WSAIoctl(
win_sock,
SIO_RCVALL,
&RCVALL_ON.0 as *const _ as *const _,
4,
ptr::null_mut(),
0,
&mut in_ as *mut _ as *mut _,
ptr::null_mut(),
None,
)
};
if res == SOCKET_ERROR {
return Err(E::win_sock("failed to set socket option SIO_RCVALL"));
}
Ok(socket)
}
#[derive(thiserror::Error, Debug)]
pub enum SocketError {
#[error("Io Error {0}")]
Io(#[from] std::io::Error),
#[error("{0}. {1:?}")]
WinSock(&'static str, WSA_ERROR),
}
impl SocketError {
fn win_sock(msg: &'static str) -> Self {
Self::WinSock(msg, unsafe { WSAGetLastError() })
}
}