#![cfg(target_os = "windows")]
#![allow(unsafe_code)]
use std::mem::size_of;
use std::sync::{Arc, Once};
use windows::Win32::Networking::WinSock::{
accept, bind, closesocket, connect, listen, recv, send, shutdown, WSAGetLastError, WSASocketW,
WSAStartup, INVALID_SOCKET, SD_BOTH, SEND_RECV_FLAGS, SOCKADDR, SOCKET, SOCKET_ERROR,
SOCK_STREAM, WINSOCK_SOCKET_TYPE, WSADATA, WSA_FLAG_OVERLAPPED,
};
use crate::error::{GcsError, GcsResult};
/// GUID `acef5661-84a1-4e44-856b-6245e69f4620`.
/// NOTE(review): presumably the hvsock service id of the GCS endpoint — confirm against the caller.
pub const GCS_SERVICE_GUID: windows::core::GUID =
windows::core::GUID::from_u128(0xacef_5661_84a1_4e44_856b_6245_e69f_4620);
/// GUID `e0e16197-dd56-4a10-9195-5ee7a155a838`; used as the `vm_id` by
/// `HvSockStream::connect_loopback`.
pub const HV_GUID_LOOPBACK: windows::core::GUID =
windows::core::GUID::from_u128(0xe0e1_6197_dd56_4a10_9195_5ee7_a155_a838);
/// All-zero GUID.
/// NOTE(review): presumably matches any partition when used as a `vm_id` in `bind` — confirm.
pub const HV_GUID_WILDCARD: windows::core::GUID =
windows::core::GUID::from_u128(0x0000_0000_0000_0000_0000_0000_0000_0000);
/// GUID `894cc2d6-9d79-424f-93fe-42969ae6d8d1`.
/// NOTE(review): meaning inferred from the name only (a host-side id used with Windows GCS) — confirm.
pub const WINDOWS_GCS_HV_HOST_ID: windows::core::GUID =
windows::core::GUID::from_u128(0x894c_c2d6_9d79_424f_93fe_4296_9ae6_d8d1);
/// GUID `172dad59-976d-45f2-8b6c-6d1b13f2ac4d`.
/// NOTE(review): presumably the hvsock service id of a logging channel — confirm.
pub const WINDOWS_LOGGING_HVSOCK_SERVICE_ID: windows::core::GUID =
windows::core::GUID::from_u128(0x172d_ad59_976d_45f2_8b6c_6d1b_13f2_ac4d);
// Address family stored in `SockAddrHv::family` and passed to `WSASocketW`.
const AF_HYPERV: u16 = 34;
// Protocol argument passed to `WSASocketW` when creating a socket.
const HV_PROTOCOL_RAW: i32 = 1;
// Backlog argument passed to `listen` by `HvSockListener::bind`.
const LISTEN_BACKLOG: i32 = 8;
/// Socket address for the `AF_HYPERV` address family.
///
/// The struct is `#[repr(C)]` because a pointer to it is cast to `*const SOCKADDR`
/// and handed to `connect` / `bind`; field order and sizes therefore must not change.
/// The layout test in this file pins it at 36 bytes with 4-byte alignment.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct SockAddrHv {
/// Address family; `new` always sets this to `AF_HYPERV`.
pub family: u16,
/// Reserved field; `new` always sets this to 0.
pub reserved: u16,
/// GUID identifying the VM / partition side of the address.
pub vm_id: windows::core::GUID,
/// GUID identifying the service side of the address.
pub service_id: windows::core::GUID,
}
impl SockAddrHv {
/// Builds an address for `service_id` on `vm_id`, filling in the `AF_HYPERV`
/// family and a zeroed reserved field.
#[must_use]
pub const fn new(vm_id: windows::core::GUID, service_id: windows::core::GUID) -> Self {
Self {
family: AF_HYPERV,
reserved: 0,
vm_id,
service_id,
}
}
}
/// Initializes Winsock (requested version 2.2) at most once per process.
///
/// The first caller runs `WSAStartup`; its return code is cached so that every
/// later call reports the same outcome without calling `WSAStartup` again.
///
/// # Errors
/// Returns [`GcsError::Hvsock`] carrying the `WSAStartup` return code when the
/// one-time initialization failed.
fn ensure_winsock_started() -> GcsResult<()> {
    static INIT: Once = Once::new();
    // Return code of the single `WSAStartup` call (0 means success). An atomic
    // replaces a `static mut`, so no `unsafe` is needed to publish or read it.
    static STARTUP_ERR: std::sync::atomic::AtomicI32 = std::sync::atomic::AtomicI32::new(0);

    INIT.call_once(|| {
        let mut data = WSADATA::default();
        // SAFETY: `data` is a live, writable `WSADATA` for the duration of the call.
        let rc = unsafe { WSAStartup(0x0202, &raw mut data) };
        STARTUP_ERR.store(rc, std::sync::atomic::Ordering::Release);
    });

    match STARTUP_ERR.load(std::sync::atomic::Ordering::Acquire) {
        0 => Ok(()),
        rc => Err(GcsError::Hvsock(format!("WSAStartup failed: rc={rc}"))),
    }
}
fn wsa_err(ctx: &str) -> GcsError {
let code = unsafe { WSAGetLastError() };
GcsError::Hvsock(format!("WSA error {code:?}: {ctx}", code = code.0))
}
/// Owner of one raw Winsock socket handle.
///
/// The handle is stored as its underlying `usize` value and closed by the
/// `Drop` implementation.
#[derive(Debug)]
struct HvSocketInner {
    /// Raw handle value, i.e. the `.0` field of the wrapped `SOCKET`.
    raw: usize,
}

impl HvSocketInner {
    /// Takes ownership of `s` by recording its raw handle value.
    const fn from_socket(s: SOCKET) -> Self {
        let raw = s.0;
        Self { raw }
    }

    /// Rebuilds a `SOCKET` value from the stored raw handle.
    const fn socket(&self) -> SOCKET {
        let raw = self.raw;
        SOCKET(raw)
    }
}
impl Drop for HvSocketInner {
    /// Shuts down both directions and closes the socket, unless the stored
    /// handle is `INVALID_SOCKET`. Failures of either call are ignored because
    /// nothing can be reported from `drop`.
    fn drop(&mut self) {
        if self.raw == INVALID_SOCKET.0 {
            return;
        }
        let sock = self.socket();
        // SAFETY: `sock` is the handle owned by this value and has not been closed yet.
        let _ = unsafe { shutdown(sock, SD_BOTH) };
        // SAFETY: same handle as above; this is the only place it is closed.
        let _ = unsafe { closesocket(sock) };
    }
}
/// Creates a new, unbound `AF_HYPERV` stream socket.
///
/// Winsock is initialized first (once per process). The socket is created with
/// `HV_PROTOCOL_RAW` and the `WSA_FLAG_OVERLAPPED` flag.
///
/// # Errors
/// Returns [`GcsError::Hvsock`] when Winsock startup failed or `WSASocketW`
/// reports an error.
fn new_hvsock() -> GcsResult<SOCKET> {
    ensure_winsock_started()?;
    let family = i32::from(AF_HYPERV);
    // SAFETY: all arguments are plain values; no protocol-info pointer is passed.
    let created = unsafe {
        WSASocketW(
            family,
            SOCK_STREAM.0,
            HV_PROTOCOL_RAW,
            None,
            0,
            WSA_FLAG_OVERLAPPED,
        )
    };
    match created {
        Ok(sock) => Ok(sock),
        Err(e) => Err(GcsError::Hvsock(format!("WSASocketW: {e}"))),
    }
}
/// Cloneable handle to a Hyper-V stream socket.
///
/// Clones share one underlying socket through an `Arc`. The socket is shut down
/// and closed when the last reference is dropped; blocking tasks spawned by the
/// I/O methods hold their own reference while they run.
#[derive(Debug, Clone)]
pub struct HvSockStream {
// Shared owner of the raw socket handle.
inner: Arc<HvSocketInner>,
}
impl HvSockStream {
pub async fn connect(
vm_id: windows::core::GUID,
service_id: windows::core::GUID,
) -> GcsResult<Self> {
let socket = new_hvsock()?;
let inner = Arc::new(HvSocketInner::from_socket(socket));
let inner_for_blocking = inner.clone();
let join = tokio::task::spawn_blocking(move || -> GcsResult<()> {
let addr = SockAddrHv::new(vm_id, service_id);
let addr_ptr: *const SOCKADDR = std::ptr::from_ref(&addr).cast();
let addr_len = i32::try_from(size_of::<SockAddrHv>())
.map_err(|e| GcsError::Hvsock(format!("addr size overflow: {e}")))?;
let rc = unsafe { connect(inner_for_blocking.socket(), addr_ptr, addr_len) };
if rc == SOCKET_ERROR {
Err(wsa_err("connect"))
} else {
Ok(())
}
})
.await
.map_err(|e| GcsError::Hvsock(format!("connect join: {e}")))?;
join?;
Ok(Self { inner })
}
/// Connects to `service_id` using [`HV_GUID_LOOPBACK`] as the VM id.
///
/// # Errors
/// Propagates every error of [`HvSockStream::connect`].
pub async fn connect_loopback(service_id: windows::core::GUID) -> GcsResult<Self> {
    let stream = Self::connect(HV_GUID_LOOPBACK, service_id).await?;
    Ok(stream)
}
pub async fn read_exact(&self, buf: &mut [u8]) -> GcsResult<()> {
if buf.is_empty() {
return Ok(());
}
let inner = self.inner.clone();
let len = buf.len();
let filled = tokio::task::spawn_blocking(move || -> GcsResult<Vec<u8>> {
let mut out = vec![0u8; len];
let mut filled = 0usize;
while filled < len {
let n = unsafe { recv(inner.socket(), &mut out[filled..], SEND_RECV_FLAGS(0)) };
if n == SOCKET_ERROR {
return Err(wsa_err("recv"));
}
if n == 0 {
return Err(GcsError::Closed);
}
let advanced =
usize::try_from(n).map_err(|e| GcsError::Hvsock(format!("recv count: {e}")))?;
filled += advanced;
}
Ok(out)
})
.await
.map_err(|e| GcsError::Hvsock(format!("read_exact join: {e}")))??;
buf.copy_from_slice(&filled);
Ok(())
}
pub async fn read_some(&self, max: usize) -> GcsResult<Vec<u8>> {
if max == 0 {
return Ok(Vec::new());
}
let inner = self.inner.clone();
let bytes = tokio::task::spawn_blocking(move || -> GcsResult<Vec<u8>> {
let mut out = vec![0u8; max];
let n = unsafe { recv(inner.socket(), &mut out, SEND_RECV_FLAGS(0)) };
if n == SOCKET_ERROR {
return Err(wsa_err("recv"));
}
let got =
usize::try_from(n).map_err(|e| GcsError::Hvsock(format!("recv count: {e}")))?;
out.truncate(got);
Ok(out)
})
.await
.map_err(|e| GcsError::Hvsock(format!("read_some join: {e}")))??;
Ok(bytes)
}
pub async fn write_all(&self, buf: &[u8]) -> GcsResult<()> {
if buf.is_empty() {
return Ok(());
}
let inner = self.inner.clone();
let owned = buf.to_vec();
tokio::task::spawn_blocking(move || -> GcsResult<()> {
let mut sent = 0usize;
while sent < owned.len() {
let n = unsafe { send(inner.socket(), &owned[sent..], SEND_RECV_FLAGS(0)) };
if n == SOCKET_ERROR {
return Err(wsa_err("send"));
}
if n == 0 {
return Err(GcsError::Closed);
}
let advanced =
usize::try_from(n).map_err(|e| GcsError::Hvsock(format!("send count: {e}")))?;
sent += advanced;
}
Ok(())
})
.await
.map_err(|e| GcsError::Hvsock(format!("write_all join: {e}")))??;
Ok(())
}
pub async fn shutdown(self) -> GcsResult<()> {
let inner = self.inner.clone();
tokio::task::spawn_blocking(move || -> GcsResult<()> {
let rc = unsafe { shutdown(inner.socket(), SD_BOTH) };
if rc == SOCKET_ERROR {
Err(wsa_err("shutdown"))
} else {
Ok(())
}
})
.await
.map_err(|e| GcsError::Hvsock(format!("shutdown join: {e}")))??;
Ok(())
}
}
/// Cloneable handle to a listening Hyper-V socket created by `bind`.
///
/// Clones share one underlying socket through an `Arc`; the socket is shut down
/// and closed when the last reference is dropped.
#[derive(Debug, Clone)]
pub struct HvSockListener {
// Shared owner of the raw listening socket handle.
inner: Arc<HvSocketInner>,
}
impl HvSockListener {
pub async fn bind(
vm_id: windows::core::GUID,
service_id: windows::core::GUID,
) -> GcsResult<Self> {
let socket = new_hvsock()?;
let inner = Arc::new(HvSocketInner::from_socket(socket));
let inner_for_blocking = inner.clone();
tokio::task::spawn_blocking(move || -> GcsResult<()> {
let addr = SockAddrHv::new(vm_id, service_id);
let addr_ptr: *const SOCKADDR = std::ptr::from_ref(&addr).cast();
let addr_len = i32::try_from(size_of::<SockAddrHv>())
.map_err(|e| GcsError::Hvsock(format!("addr size overflow: {e}")))?;
let bind_rc = unsafe { bind(inner_for_blocking.socket(), addr_ptr, addr_len) };
if bind_rc == SOCKET_ERROR {
return Err(wsa_err("bind"));
}
let listen_rc = unsafe { listen(inner_for_blocking.socket(), LISTEN_BACKLOG) };
if listen_rc == SOCKET_ERROR {
return Err(wsa_err("listen"));
}
Ok(())
})
.await
.map_err(|e| GcsError::Hvsock(format!("bind join: {e}")))??;
Ok(Self { inner })
}
pub async fn accept(&self) -> GcsResult<HvSockStream> {
let raw_socket = self.inner.socket();
let accepted = tokio::task::spawn_blocking(move || -> GcsResult<SOCKET> {
let s = unsafe { accept(raw_socket, None, None) }
.map_err(|e| GcsError::Hvsock(format!("accept: {e}")))?;
Ok(s)
})
.await
.map_err(|e| GcsError::Hvsock(format!("accept join: {e}")))??;
Ok(HvSockStream {
inner: Arc::new(HvSocketInner::from_socket(accepted)),
})
}
}
// Compile-time check that `SOCK_STREAM` has type `WINSOCK_SOCKET_TYPE`.
// NOTE(review): presumably also here to keep the `WINSOCK_SOCKET_TYPE` import in use — confirm.
#[allow(dead_code)]
const _: WINSOCK_SOCKET_TYPE = SOCK_STREAM;
// Unit tests for the address layout, the wildcard GUID, and accept/timeout behavior.
#[cfg(test)]
mod tests {
use super::{SockAddrHv, AF_HYPERV, GCS_SERVICE_GUID, HV_GUID_LOOPBACK, HV_GUID_WILDCARD};
// `SockAddrHv::new` fills every field as expected and the struct keeps the
// 36-byte, 4-byte-aligned layout that is passed to Winsock as a `SOCKADDR`.
#[test]
fn sockaddr_hv_layout() {
let addr = SockAddrHv::new(HV_GUID_LOOPBACK, GCS_SERVICE_GUID);
assert_eq!(addr.family, AF_HYPERV);
assert_eq!(addr.reserved, 0);
assert_eq!(addr.vm_id, HV_GUID_LOOPBACK);
assert_eq!(addr.service_id, GCS_SERVICE_GUID);
assert_eq!(std::mem::size_of::<SockAddrHv>(), 36);
assert_eq!(std::mem::align_of::<SockAddrHv>(), 4);
}
// The wildcard GUID is all zeros.
#[test]
fn wildcard_guid_is_zero() {
assert_eq!(HV_GUID_WILDCARD.to_u128(), 0);
}
// A pending `accept` that times out must not keep the test from finishing
// once the listener is dropped.
#[tokio::test]
async fn accept_timeout_does_not_wedge_runtime() {
use std::time::{Duration, Instant};
// Service GUID derived from the process id, so different processes use different ids.
let svc = windows::core::GUID::from_u128(
0xdead_beef_0000_4000_8000_0000_0000_0001_u128 ^ u128::from(std::process::id()),
);
// If binding fails, the test returns early and passes without checking anything.
let Ok(listener) = super::HvSockListener::bind(HV_GUID_WILDCARD, svc).await else {
return;
};
let started = Instant::now();
// No peer connects, so the 200 ms timeout is expected to fire first.
let res = tokio::time::timeout(Duration::from_millis(200), listener.accept()).await;
assert!(
res.is_err(),
"accept should hit the timeout (no peer dials)"
);
// Drop the listener before measuring the elapsed time.
drop(listener);
let total = started.elapsed();
assert!(
total < Duration::from_secs(5),
"accept_timeout test took {total:?}; regression: blocking pool wedged"
);
}
}