pub mod sock_protocol;
pub mod sock_domain;
pub mod sock_type;
pub mod sock_msg_flags;
pub mod sock_opts;
use std::io::{ErrorKind, IoSlice};
use std::mem::MaybeUninit;
use std::net::{Ipv4Addr, Shutdown};
use std::ptr::null_mut;
use std::time::{Duration, Instant};
use std::{io, mem, process, ptr};
use std::os::windows::io::{AsRawSocket, FromRawSocket, OwnedSocket, RawSocket};
use std::sync::LazyLock;
pub(crate) use sock_opts::{getsockopt, setsockopt};
pub use sock_domain::*;
pub use sock_msg_flags::*;
pub use sock_protocol::*;
use windows_sys::Win32::Foundation::{HANDLE, HANDLE_FLAG_INHERIT, SetHandleInformation};
use windows_sys::Win32::Networking::WinSock::
{
FIONBIO, INVALID_SOCKET, MSG_PEEK, POLLERR, POLLHUP, POLLRDNORM, POLLWRNORM, SD_BOTH, SD_RECEIVE, SD_SEND,
SOCKET_ERROR, WSACleanup, WSADATA, WSADuplicateSocketW, WSAEMSGSIZE, WSAESHUTDOWN, WSAPOLLFD, WSAPROTOCOL_INFOW,
WSAPoll, WSARecv, WSARecvFrom, WSASend, WSASendTo, WSASocketW, WSAStartup, accept, bind, connect, getpeername,
getsockname, ioctlsocket, listen, recv, recvfrom, send, sendto, shutdown
};
pub use windows_sys::Win32::Networking::WinSock::
{
TIMEVAL as timeval, LINGER as linger,
SOCKADDR_STORAGE as sockaddr_storage, socklen_t, SOCKADDR as sockaddr, ADDRESS_FAMILY as sa_family_t,
SOCKADDR_UN as sockaddr_un, CMSGHDR as msghdr, SOCKADDR_IN as sockaddr_in, SOCKADDR_IN6 as sockaddr_in6,
IN_ADDR as in_addr, IN6_ADDR as in6_addr, IP_MREQ_SOURCE as ip_mreq_source, IP_MREQ as ip_mreqn,
IPV6_MREQ as ipv6_mreq, AF_UNIX, AF_INET, AF_INET6
};
pub use windows_sys::Win32::Networking::WinSock::
{
TCP_KEEPCNT, TCP_NODELAY, TCP_KEEPINTVL, TCP_KEEPIDLE, IPPROTO_TCP, POLLIN, SO_SNDTIMEO, SOCK_STREAM, SOCK_DGRAM,
SOCK_RAW, IPPROTO_ICMP, IPPROTO_ICMPV6, IPPROTO_UDP, IPPROTO_SCTP, IPPROTO_IPV6, IP_HDRINCL, IPV6_ADD_MEMBERSHIP,
IPV6_DROP_MEMBERSHIP, IPV6_MULTICAST_HOPS, IPV6_MULTICAST_IF, IPV6_MULTICAST_LOOP, IPV6_UNICAST_HOPS, IPV6_V6ONLY,
IPV6_RECVTCLASS, IPPROTO_IP, IP_TTL, IP_TOS, IP_RECVTOS, IP_MULTICAST_TTL, IP_MULTICAST_LOOP, IP_MULTICAST_IF,
IP_ADD_SOURCE_MEMBERSHIP, IP_DROP_SOURCE_MEMBERSHIP, IP_ADD_MEMBERSHIP, IP_DROP_MEMBERSHIP, SOL_IP,
SO_ORIGINAL_DST, SO_TYPE, SOL_SOCKET, IP6T_SO_ORIGINAL_DST, SOL_IPV6, SO_ACCEPTCONN, SO_BROADCAST, SO_DONTROUTE,
SO_ERROR, SO_KEEPALIVE, SO_LINGER, SO_RCVBUF, SO_RCVLOWAT, SO_RCVTIMEO, SO_REUSEADDR, SO_SNDBUF, SO_SNDLOWAT,
SO_OOBINLINE
};
/// Seconds component of a `timeval`.
///
/// Winsock's `TIMEVAL` stores its fields as a C `long`, which is 32 bits wide on Windows.
/// The lowercase C-style name is kept to mirror the POSIX type name.
#[allow(non_camel_case_types)]
pub type time_t = i32;

/// Microseconds component of a `timeval`; same 32-bit width as [`time_t`].
#[allow(non_camel_case_types)]
pub type suseconds_t = i32;
use crate::{LocalFrom, SockOptMarker, SockTypeFromRaw, Socket9ExtSo, SocketTypeImps};
use crate::address::{So9AddrIntoRaw, So9SocketAddr};
pub use sock_domain::So9DomainRange;
pub use sock_type::{So9SockDwFlags, So9SockType};
/// Selects the local interface used for an IPv4 multicast group membership.
#[derive(Debug, Clone, Copy)]
pub enum So9IfInder
{
    /// Identify the interface by one of its local IPv4 addresses.
    Address(Ipv4Addr),
    /// Identify the interface by its numeric interface index.
    IfInder(u32)
}

impl So9IfInder
{
    /// Builds the Winsock `IP_MREQ` (aliased here as `ip_mreqn`) used to join or leave the
    /// multicast group `multiaddr` on the interface described by `self`.
    ///
    /// Winsock's `IP_MREQ::imr_interface` accepts either a local interface address or an
    /// interface index encoded as the address `0.0.0.<index>` (index 1 is `0.0.0.1`).
    /// The `IfInder` variant is therefore encoded that way. It previously fell back to
    /// `0.0.0.0`, which silently selected the default multicast interface and ignored the
    /// requested index.
    pub(crate)
    fn into_ip_mreqn(&self, multiaddr: &Ipv4Addr) -> ip_mreqn
    {
        let interface =
            match *self
            {
                Self::Address(addr) => addr,
                // `Ipv4Addr::from(u32)` maps e.g. 3 to 0.0.0.3, the documented index encoding.
                Self::IfInder(idx) => Ipv4Addr::from(idx),
            };

        ip_mreqn
        {
            imr_multiaddr: <in_addr as LocalFrom<&Ipv4Addr>>::from(multiaddr),
            imr_interface: <in_addr as LocalFrom<&Ipv4Addr>>::from(&interface),
        }
    }
}
/// Token standing for a successful `WSAStartup` call. Dropping it calls `WSACleanup`.
///
/// NOTE(review): the only instance visible in this file lives inside the `WSA_STARTUP` static.
/// Rust never runs destructors for statics, so in practice this `Drop` impl is not expected
/// to fire.
#[derive(Debug)]
struct WsaStartup;

impl Drop for WsaStartup
{
    fn drop(&mut self)
    {
        // The status code is deliberately discarded: a destructor has no way to report it.
        let _ = unsafe { WSACleanup() };
    }
}
/// Lazily performed, process-wide Winsock initialisation.
///
/// `new_sock` dereferences this before creating a socket, so `WSAStartup` runs exactly once,
/// on first use.
///
/// # Panics
/// Panics if `WSAStartup` reports an error, because no socket call can work without it.
static WSA_STARTUP: LazyLock<WsaStartup> =
    LazyLock::new(
        ||
        {
            // MAKEWORD(2, 2): the low byte is the major version and the high byte the minor
            // version. Winsock 2.2 is therefore 0x0202. The previous 0x2020 asked for "32.32".
            const WINSOCK_VERSION_2_2: u16 = 0x0202;

            let mut data: WSADATA = unsafe { mem::zeroed() };
            // SAFETY: `data` is a valid, writable WSADATA that outlives the call.
            let res = unsafe { WSAStartup(WINSOCK_VERSION_2_2, &mut data) };
            // WSAStartup returns its error code directly (0 on success).
            assert_eq!(res, 0, "WSAStartup failed with error code {res}");
            WsaStartup
        }
    );
// `OwnedSocket` opts into the crate's raw-handle construction trait and its socket-option
// extension trait without overriding anything. The extension trait is presumably where the
// `get_so_*` helpers used by the `SocketTypeImps` impl come from — confirm.
impl SockTypeFromRaw for OwnedSocket {}
impl Socket9ExtSo<OwnedSocket> for OwnedSocket {}
/// Winsock-backed implementation of the crate's socket primitives for `OwnedSocket`.
///
/// Each method passes `self.as_raw_socket()` (a `u64` `RawSocket`) to Winsock as a `SOCKET`
/// (`usize`). Winsock failures are surfaced through `io::Error::last_os_error()`.
impl SocketTypeImps for OwnedSocket
{
    type RawType = RawSocket;

    /// Returns the raw Winsock handle without transferring ownership.
    fn get_raw(&self) -> Self::RawType
    {
        self.as_raw_socket()
    }

    /// Creates a socket with `WSASocketW`, initialising Winsock first if needed.
    ///
    /// `dwflags` is forwarded as the `dwFlags` argument. No protocol-info structure or socket
    /// group is supplied.
    fn new_sock(so_dom: So9SockDomain, so_type: So9SockType, so_proto: So9SockProtocol, dwflags: So9SockDwFlags) -> io::Result<Self>
    where
        Self: Sized
    {
        // Winsock must be started before any other Winsock call in this process.
        let _ = &*WSA_STARTUP;
        let socket =
            unsafe
            {
                WSASocketW(
                    so_dom.0 as i32,
                    so_type.0,
                    so_proto.0,
                    ptr::null_mut(),
                    0,
                    dwflags.bits()
                )
            };
        if socket == INVALID_SOCKET
        {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: `socket` is a freshly created, valid handle that nothing else owns.
        Ok(unsafe { Self::from_raw_socket(socket as u64) })
    }

    /// Binds the socket to `addr`.
    fn bind_sock<A: So9AddrIntoRaw>(&self, addr: &A) -> io::Result<()>
    {
        let addr_raw = addr.into_raw_addr();
        let (addr_ptr, addr_sz) = addr_raw.get_sockaddr_ptr();
        // SAFETY: `addr_raw` owns the sockaddr and outlives the call.
        let res = unsafe { bind(self.as_raw_socket() as usize, addr_ptr, addr_sz) };
        if res == 0
        {
            return Ok(());
        }
        Err(io::Error::last_os_error())
    }

    /// Marks the socket as passive, with a pending-connection queue of `backlog`.
    fn listen_sock(&self, backlog: i32) -> io::Result<()>
    {
        let res = unsafe { listen(self.as_raw_socket() as usize, backlog) };
        if res == 0
        {
            return Ok(());
        }
        Err(io::Error::last_os_error())
    }

    /// Accepts one pending connection and returns it with the peer address.
    ///
    /// `_flags` is ignored: the plain `accept` call used here takes no flags.
    fn accept_with_flags<A, CAST>(&self, _flags: i32) -> io::Result<(CAST, A)>
    where
        Self: Sized,
        A: TryFrom<So9SocketAddr, Error = io::Error>,
        CAST: SocketTypeImps
    {
        let mut sa = So9SocketAddr::default();
        let res =
            unsafe
            {
                accept(self.as_raw_socket() as usize, sa.get_sockaddr_ptr(), sa.get_capacity_mut())
            };
        if res != INVALID_SOCKET
        {
            // SAFETY: `res` is a newly accepted socket handle owned by no one else.
            let accepted = unsafe { CAST::from_raw_socket(res as u64) };
            return Ok((accepted, sa.try_into()?));
        }
        Err(io::Error::last_os_error())
    }

    /// Connects the socket to `addr`.
    fn connect_sock<A>(&self, addr: &A) -> io::Result<()>
    where
        A: So9AddrIntoRaw
    {
        let sa = addr.into_raw_addr();
        let sa_raw = sa.get_sockaddr_ptr();
        // SAFETY: `sa` owns the sockaddr and outlives the call.
        let res = unsafe { connect(self.as_raw_socket() as usize, sa_raw.0, sa_raw.1) };
        if res == 0
        {
            return Ok(());
        }
        Err(io::Error::last_os_error())
    }

    /// Waits up to `timeout` for the socket to become readable or writable (used after a
    /// connect attempt). See `poll_sock` for the result mapping.
    fn poll_connect_sock(&self, timeout: Duration) -> io::Result<()>
    {
        self.poll_sock(POLLRDNORM | POLLWRNORM, timeout)
    }

    /// Waits with `WSAPoll` until one of the events in `ev` is reported or `timeout` elapses.
    ///
    /// Results:
    /// * `Ok(())` when an event is reported without `POLLERR`/`POLLHUP`.
    /// * A `TimedOut` error when the deadline passes. This includes a zero `timeout`, in
    ///   which case `WSAPoll` is never called.
    /// * When `POLLERR`/`POLLHUP` is reported, the error from `get_so_error()`. If that yields
    ///   nothing, a `TimedOut` error is returned instead.
    ///
    /// Interrupted polls are retried with the remaining time budget.
    fn poll_sock(&self, ev: i16, timeout: Duration) -> io::Result<()>
    {
        let mut pollfd =
            WSAPOLLFD
            {
                fd: self.as_raw_socket() as usize,
                events: ev,
                revents: 0,
            };
        let start = Instant::now();
        loop
        {
            let elapsed = start.elapsed();
            if elapsed >= timeout
            {
                break;
            }
            // Remaining budget in ms. The floor of 1 keeps a sub-millisecond remainder from
            // turning into a 0 ms (non-blocking) poll.
            let remaining_ms =
                (timeout - elapsed).as_millis().clamp(1, i32::MAX as u128) as i32;
            let res = unsafe { WSAPoll(&mut pollfd, 1, remaining_ms) };
            if res == SOCKET_ERROR
            {
                let err = io::Error::last_os_error();
                if err.kind() == ErrorKind::Interrupted
                {
                    continue;
                }
                return Err(err);
            }
            if res == 0
            {
                return Err(io::Error::new(ErrorKind::TimedOut, "poll timeout!"));
            }
            if (pollfd.revents & (POLLERR | POLLHUP) as i16) != 0
            {
                return Err(
                    self
                        .get_so_error()
                        .unwrap_or(
                            io::Error::new(ErrorKind::TimedOut, "poll failed with no error")
                        )
                );
            }
            return Ok(());
        }
        Err(io::Error::new(ErrorKind::TimedOut, "poll timeout!"))
    }

    /// Returns the address of the connected peer (`getpeername`).
    fn get_socket_peer_addr<A>(&self) -> io::Result<A>
    where
        A: TryFrom<So9SocketAddr, Error = io::Error>
    {
        let mut sa = So9SocketAddr::default();
        let res =
            unsafe
            {
                getpeername(
                    self.as_raw_socket() as usize,
                    sa.get_sockaddr_storage_ptr().cast(),
                    sa.get_capacity_mut()
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        A::try_from(sa)
    }

    /// Returns the local address the socket is bound to (`getsockname`).
    fn get_socket_local_addr<A>(&self) -> io::Result<A>
    where
        A: TryFrom<So9SocketAddr, Error = io::Error>
    {
        let mut sa = So9SocketAddr::default();
        let res =
            unsafe
            {
                getsockname(
                    self.as_raw_socket() as usize,
                    sa.get_sockaddr_storage_ptr().cast(),
                    sa.get_capacity_mut()
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        A::try_from(sa)
    }

    /// Returns the address family of the socket's local address, via `getsockname`.
    fn get_socket_addr_type(&self) -> io::Result<So9SockDomain>
    {
        let mut sa = So9SocketAddr::default();
        let res =
            unsafe
            {
                getsockname(
                    self.as_raw_socket() as usize,
                    sa.get_sockaddr_storage_ptr().cast(),
                    sa.get_capacity_mut()
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(So9SockDomain::from(sa.get_sa_fam()))
    }

    /// Returns the socket protocol as reported by `get_so_protocol`, always wrapped in `Some`.
    fn get_socket_proto(&self) -> io::Result<Option<<crate::op_sol_socket::SoProtocol as crate::SockOptMarker>::DataType>>
    {
        self.get_so_protocol().map(Some)
    }

    /// Sends bytes from `data` and returns how many were sent.
    ///
    /// At most `i32::MAX` bytes are offered per call.
    fn send_with_flags(&self, data: &[u8], flags: So9MsgFlags) -> Result<usize, io::Error>
    {
        let res =
            unsafe
            {
                send(
                    self.as_raw_socket() as usize,
                    data.as_ptr().cast(),
                    usize::min(data.len(), i32::MAX as usize) as i32,
                    flags.bits(),
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Scatter-receives into `buffers` with `WSARecvFrom` and also returns the source address.
    ///
    /// * `_recv_from` is ignored; the address is always returned.
    /// * `WSAESHUTDOWN` is reported as a 0-byte read.
    /// * `WSAEMSGSIZE` (datagram larger than the buffers) is reported as success. The buffers
    ///   hold the first part of the datagram, so the filled byte count is returned together
    ///   with `MSG_TRUNC`.
    /// * On plain success the returned flags are empty.
    fn recv_vect_from_flags<A>(&self, buffers: &mut[io::IoSliceMut], _recv_from: bool, rcv_flags: i32) -> io::Result<(usize, So9MsgFlags, Option<A>)>
    where
        A: TryFrom<So9SocketAddr, Error = io::Error>,
    {
        let mut bytes_received: u32 = 0;
        // In/out flags word for WSARecvFrom.
        let mut io_flags = rcv_flags as u32;
        let mut sa = So9SocketAddr::default();
        // SAFETY: `IoSliceMut` is ABI-compatible with `WSABUF` on Windows, and every pointer
        // passed here outlives this synchronous (non-overlapped) call.
        let res =
            unsafe
            {
                WSARecvFrom(
                    self.as_raw_socket() as usize,
                    buffers.as_mut_ptr().cast(),
                    usize::min(buffers.len(), u32::MAX as usize) as u32,
                    &mut bytes_received,
                    &mut io_flags,
                    sa.get_sockaddr_ptr(),
                    sa.get_capacity_mut(),
                    ptr::null_mut(),
                    None,
                )
            };
        if res != SOCKET_ERROR
        {
            return Ok((bytes_received as usize, So9MsgFlags::empty(), Some(A::try_from(sa)?)));
        }

        let err = io::Error::last_os_error();
        if err.raw_os_error() == Some(WSAESHUTDOWN)
        {
            return Ok((0, So9MsgFlags::empty(), Some(A::try_from(sa)?)));
        }
        if err.raw_os_error() == Some(WSAEMSGSIZE)
        {
            return Ok((bytes_received as usize, So9MsgFlags::MSG_TRUNC, Some(A::try_from(sa)?)));
        }
        Err(err)
    }

    /// Receives into `buf` with `recv` and returns the byte count.
    ///
    /// `WSAESHUTDOWN` is reported as a 0-byte read. At most `i32::MAX` bytes are requested.
    fn recv_with_flags(&self, buf: &mut [u8], flags: i32) -> io::Result<usize>
    {
        let res =
            unsafe
            {
                recv(
                    self.as_raw_socket() as usize,
                    buf.as_mut_ptr().cast(),
                    usize::min(buf.len(), i32::MAX as usize) as i32,
                    flags,
                )
            };
        if res != SOCKET_ERROR
        {
            return Ok(res as usize);
        }
        let err = io::Error::last_os_error();
        if err.raw_os_error() == Some(WSAESHUTDOWN)
        {
            return Ok(0);
        }
        Err(err)
    }

    /// Receives into `buf` with `recvfrom` and returns the byte count and source address.
    ///
    /// `WSAESHUTDOWN` is reported as a 0-byte read.
    fn recv_from<A>(&self, buf: &mut [u8], rcv_flags: i32) -> io::Result<(usize, A)>
    where
        A: TryFrom<So9SocketAddr, Error = io::Error>
    {
        let mut sa = So9SocketAddr::default();
        let res =
            unsafe
            {
                recvfrom(
                    self.as_raw_socket() as usize,
                    buf.as_mut_ptr().cast(),
                    usize::min(buf.len(), i32::MAX as usize) as i32,
                    rcv_flags,
                    sa.get_sockaddr_ptr(),
                    sa.get_capacity_mut()
                )
            };
        if res != SOCKET_ERROR
        {
            return Ok((res as usize, A::try_from(sa)?));
        }
        let err = io::Error::last_os_error();
        if err.raw_os_error() == Some(WSAESHUTDOWN)
        {
            return Ok((0, A::try_from(sa)?));
        }
        Err(err)
    }

    /// Peeks at the sender address of the next queued message without consuming it.
    ///
    /// Uses a zero-length `MSG_PEEK` `recvfrom`. Both `WSAEMSGSIZE` (the message did not fit
    /// the empty buffer) and `WSAESHUTDOWN` are treated as success.
    /// NOTE(review): after `WSAESHUTDOWN`, the address is whatever `recvfrom` left in `sa`,
    /// possibly the default value. Confirm callers tolerate that.
    fn recv_addr_from<A>(&self) -> io::Result<A>
    where
        A: TryFrom<So9SocketAddr, Error = io::Error>
    {
        let mut sa = So9SocketAddr::default();
        let res =
            unsafe
            {
                recvfrom(
                    self.as_raw_socket() as usize,
                    ptr::null_mut(),
                    0,
                    MSG_PEEK,
                    sa.get_sockaddr_ptr(),
                    sa.get_capacity_mut()
                )
            };
        if res != SOCKET_ERROR
        {
            return A::try_from(sa);
        }
        let err = io::Error::last_os_error();
        if err.raw_os_error() == Some(WSAEMSGSIZE) || err.raw_os_error() == Some(WSAESHUTDOWN)
        {
            return A::try_from(sa);
        }
        Err(err)
    }

    /// Scatter-receives into `bufs` with `WSARecv` and returns the number of bytes received.
    ///
    /// `WSARecv` returns 0 on success and writes the byte count through
    /// `lpNumberOfBytesRecvd`. That out-parameter, not the return value, is what gets
    /// reported. (The previous code returned the return value and so always reported 0
    /// bytes.)
    ///
    /// `WSAESHUTDOWN` is reported as a 0-byte read. `WSAEMSGSIZE` is reported as the filled
    /// byte count plus `MSG_TRUNC`.
    fn recv_vectored_flags(&self, bufs: &mut [io::IoSliceMut<'_>], flags: i32) -> io::Result<(usize, So9MsgFlags)>
    {
        let mut bytes_received: u32 = 0;
        // In/out flags word for WSARecv.
        let mut io_flags = flags as u32;
        // SAFETY: `IoSliceMut` is ABI-compatible with `WSABUF` on Windows, and every pointer
        // passed here outlives this synchronous (non-overlapped) call.
        let res =
            unsafe
            {
                WSARecv(
                    self.as_raw_socket() as usize,
                    bufs.as_mut_ptr().cast(),
                    usize::min(bufs.len(), u32::MAX as usize) as u32,
                    &mut bytes_received,
                    &mut io_flags,
                    ptr::null_mut(),
                    None,
                )
            };
        if res != SOCKET_ERROR
        {
            return Ok((bytes_received as usize, So9MsgFlags::empty()));
        }

        let err = io::Error::last_os_error();
        if err.raw_os_error() == Some(WSAESHUTDOWN)
        {
            return Ok((0, So9MsgFlags::empty()));
        }
        if err.raw_os_error() == Some(WSAEMSGSIZE)
        {
            return Ok((bytes_received as usize, So9MsgFlags::MSG_TRUNC));
        }
        Err(err)
    }

    /// Gather-sends `bufs` with no flags. See `send_vectored_flags`.
    fn send_vectored(&self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize>
    {
        self.send_vectored_flags(bufs, 0)
    }

    /// Gather-sends `bufs` with `WSASend` and returns the number of bytes sent.
    fn send_vectored_flags(&self, bufs: &[IoSlice<'_>], snd_flags: i32) -> io::Result<usize>
    {
        let mut bytes_sent: u32 = 0;
        // SAFETY: `IoSlice` is ABI-compatible with `WSABUF` on Windows, and every pointer
        // passed here outlives this synchronous (non-overlapped) call.
        let res =
            unsafe
            {
                WSASend(
                    self.as_raw_socket() as usize,
                    bufs.as_ptr().cast(),
                    usize::min(bufs.len(), u32::MAX as usize) as u32,
                    &mut bytes_sent,
                    snd_flags as u32,
                    null_mut(),
                    None
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(bytes_sent as usize)
    }

    /// Sends `buf` to `addr` with `sendto` and returns the number of bytes sent.
    fn send_to_sock<A>(&self, addr: &A, buf: &[u8], flags: i32) -> io::Result<usize>
    where
        A: So9AddrIntoRaw
    {
        let sa = addr.into_raw_addr();
        let sa_raw = sa.get_sockaddr_ptr();
        let res =
            unsafe
            {
                sendto(
                    self.as_raw_socket() as usize,
                    buf.as_ptr().cast(),
                    usize::min(buf.len(), i32::MAX as usize) as i32,
                    flags,
                    sa_raw.0,
                    sa_raw.1
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(res as usize)
    }

    /// Gather-sends `bufs` to `addr` with `WSASendTo` and returns the number of bytes sent.
    fn send_to_sock_vect<A>(&self, addr: &A, bufs: &[IoSlice<'_>], flags: i32) -> io::Result<usize>
    where
        A: So9AddrIntoRaw
    {
        let mut bytes_sent: u32 = 0;
        let sa = addr.into_raw_addr();
        let sa_raw = sa.get_sockaddr_ptr();
        // SAFETY: `IoSlice` is ABI-compatible with `WSABUF` on Windows. `sa` and `bufs`
        // outlive this synchronous (non-overlapped) call.
        let res =
            unsafe
            {
                WSASendTo(
                    self.as_raw_socket() as usize,
                    bufs.as_ptr() as *mut _,
                    bufs.len().min(u32::MAX as usize) as u32,
                    &mut bytes_sent,
                    flags as u32,
                    sa_raw.0,
                    sa_raw.1,
                    ptr::null_mut(),
                    None,
                )
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(bytes_sent as usize)
    }

    /// Duplicates the socket and returns an independent `OwnedSocket`.
    ///
    /// `WSADuplicateSocketW` first fills a protocol-info record for the current process. A
    /// new socket is then opened from that record with `WSASocketW`, using `dwflags`.
    fn try_clone_sock(&self, dwflags: So9SockDwFlags) -> io::Result<Self>
    where
        Self: Sized
    {
        let mut info = MaybeUninit::<WSAPROTOCOL_INFOW>::zeroed();
        let res =
            unsafe
            {
                WSADuplicateSocketW(self.as_raw_socket() as usize, process::id(), info.as_mut_ptr())
            };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: zero-initialised and then filled by a successful WSADuplicateSocketW.
        let mut proto_info = unsafe { info.assume_init() };
        let dup =
            unsafe
            {
                WSASocketW(
                    proto_info.iAddressFamily,
                    proto_info.iSocketType,
                    proto_info.iProtocol,
                    &mut proto_info,
                    0,
                    dwflags.bits()
                )
            };
        if dup == INVALID_SOCKET
        {
            return Err(io::Error::last_os_error());
        }
        // SAFETY: `dup` is a freshly created handle that nothing else owns.
        Ok(unsafe { Self::from_raw_socket(dup as u64) })
    }

    /// Switches non-blocking mode on or off with `ioctlsocket(FIONBIO)`.
    fn set_nonblocking_sock(&self, nonblk: bool) -> io::Result<()>
    {
        let mut blk_flag = nonblk as u32;
        let res = unsafe { ioctlsocket(self.as_raw_socket() as usize, FIONBIO, &mut blk_flag) };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }

    /// Windows analogue of close-on-exec: controls whether child processes inherit the
    /// socket handle.
    ///
    /// `cloexec == true` clears `HANDLE_FLAG_INHERIT`; `false` sets it.
    fn set_cloexec_sock(&self, cloexec: bool) -> io::Result<()>
    {
        let inherit_flags = if cloexec { 0 } else { HANDLE_FLAG_INHERIT };
        let res =
            unsafe
            {
                SetHandleInformation(
                    self.as_raw_socket() as HANDLE,
                    HANDLE_FLAG_INHERIT,
                    inherit_flags,
                )
            };
        // SetHandleInformation returns a nonzero BOOL on success.
        if res != 0
        {
            return Ok(());
        }
        Err(io::Error::last_os_error())
    }

    /// Shuts down the read half, the write half, or both halves of the connection.
    fn shutdown_sock(&self, how: Shutdown) -> io::Result<()>
    {
        let sd_how =
            match how
            {
                Shutdown::Write => SD_SEND,
                Shutdown::Read => SD_RECEIVE,
                Shutdown::Both => SD_BOTH,
            };
        let res = unsafe { shutdown(self.as_raw_socket() as usize, sd_how) };
        if res == SOCKET_ERROR
        {
            return Err(io::Error::last_os_error());
        }
        Ok(())
    }
}