#![cfg(target_os = "windows")]
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{oneshot, Mutex};
use crate::error::{GcsError, GcsResult};
use crate::frame::{self, RpcMessageType, HEADER_LEN, MSG_TYPE_MASK, MSG_TYPE_RESPONSE};
use crate::protocol::{
NegotiateProtocolRequest, NegotiateProtocolResponse, ProtocolSupport, RequestBase, ResponseBase,
};
use crate::transport::{HvSockListener, HvSockStream, GCS_SERVICE_GUID};
/// All-zero container id used for the cold-start and service-level RPCs that
/// do not target a specific container.
const NULL_CONTAINER_ID: &str = "00000000-0000-0000-0000-000000000000";
/// Debug-only trace logging to stderr.
///
/// With the `windows-debug` feature enabled, prints the formatted message
/// prefixed by `[t=+<N>us] `, where `<N>` comes from
/// `$crate::diagnostics::ts_us()`. Without the feature the whole invocation
/// is compiled out, so the arguments are not evaluated at all.
macro_rules! gcs_debug {
    ($template:literal $(, $value:expr)* $(,)?) => {
        #[cfg(feature = "windows-debug")]
        {
            ::std::eprintln!(
                concat!("[t=+{}us] ", $template),
                $crate::diagnostics::ts_us(),
                $($value),*
            );
        }
    };
}
/// GCS protocol version requested by the host. It is sent as both the minimum
/// and the maximum during negotiation, so the guest must select exactly this value.
pub const GCS_PROTOCOL_VERSION: u32 = 4;
/// Renders a human-readable description of a failed GCS response.
///
/// The signed `result` is reinterpreted bit-for-bit as `u32`, so failure
/// codes print as `0x8…` hex. When `error_records` is not JSON `null` and
/// serializes successfully, the JSON is appended as ` error_records=<json>`.
/// A serialization failure is silently omitted from the message.
fn describe_response_error(stage: &str, base: &ResponseBase) -> String {
    let hresult = u32::from_ne_bytes(base.result.to_ne_bytes());
    let records_json = if base.error_records.is_null() {
        None
    } else {
        serde_json::to_string(&base.error_records).ok()
    };
    let records = records_json
        .map(|json| format!(" error_records={json}"))
        .unwrap_or_default();
    format!(
        "{stage} returned HRESULT {hresult:#x}: {}{records}",
        base.error_message
    )
}
/// Wraps `err` as a `GcsError::Negotiation` whose message is prefixed with
/// the bring-up `stage` that produced it.
fn stage_err(stage: &str, err: &GcsError) -> GcsError {
    let message = format!("{stage}: {err}");
    GcsError::Negotiation(message)
}
/// Sender half used to deliver a response `(message type, payload bytes)`
/// back to the task that issued the matching request.
type ResponseSender = oneshot::Sender<(u32, Vec<u8>)>;

/// Request bookkeeping shared between RPC callers and the reader task.
#[derive(Debug)]
struct PendingState {
    /// Set once the reader task has exited; new requests are then rejected.
    closed: bool,
    /// In-flight requests keyed by message id, awaiting their response frame.
    waiters: HashMap<u64, ResponseSender>,
}

/// Shared, async-lockable handle to the pending-request table.
type PendingMap = Arc<Mutex<PendingState>>;
/// Host-side handle to the in-guest GCS RPC channel.
///
/// Clones share the message-id counter and the pending-request table (both
/// behind `Arc`). The stream is duplicated via `HvSockStream::clone`.
#[derive(Clone, Debug)]
pub struct GcsBridge {
    /// Connection to the guest; a clone of it is handed to the reader task.
    stream: HvSockStream,
    /// Source of per-request message ids.
    next_id: Arc<AtomicU64>,
    /// Requests still waiting for a response frame.
    pending: PendingMap,
}
impl GcsBridge {
    /// Binds an hvsock listener for the GCS service on the given VM.
    ///
    /// Call `accept` on the returned [`PendingGcsBridge`] to obtain a usable
    /// bridge once the in-guest GCS connects.
    ///
    /// # Errors
    /// Propagates any error from `HvSockListener::bind`.
    pub async fn listen(vm_id: windows::core::GUID) -> GcsResult<PendingGcsBridge> {
        let listener = HvSockListener::bind(vm_id, GCS_SERVICE_GUID).await?;
        Ok(PendingGcsBridge { listener })
    }

    /// Performs the `NegotiateProtocol` RPC. Both the minimum and the maximum
    /// version are pinned to [`GCS_PROTOCOL_VERSION`].
    ///
    /// Returns the capability set advertised by the guest.
    ///
    /// # Errors
    /// `GcsError::Negotiation` in three cases: the RPC itself fails, the guest
    /// returns a non-zero result, or the guest selects any other version.
    pub async fn negotiate_protocol(&self) -> GcsResult<ProtocolSupport> {
        let req = NegotiateProtocolRequest {
            base: RequestBase {
                activity_id: uuid::Uuid::new_v4(),
                container_id: String::new(),
            },
            minimum_version: GCS_PROTOCOL_VERSION,
            maximum_version: GCS_PROTOCOL_VERSION,
        };
        let resp: NegotiateProtocolResponse = self
            .send_rpc_json(RpcMessageType::NegotiateProtocol, &req)
            .await
            .map_err(|e| stage_err("negotiate", &e))?;
        if resp.base.result != 0 {
            return Err(GcsError::Negotiation(describe_response_error(
                "negotiate",
                &resp.base,
            )));
        }
        if resp.version != GCS_PROTOCOL_VERSION {
            return Err(GcsError::Negotiation(format!(
                "guest chose version {} (host wanted {GCS_PROTOCOL_VERSION})",
                resp.version
            )));
        }
        Ok(resp.capabilities)
    }

    /// Sends the cold-start `Create` RPC for the null container id, and then a
    /// `Start` RPC. Each is sent only if the guest's capabilities ask for it.
    ///
    /// The `ContainerConfig` is a JSON *string* embedding `SystemType` and
    /// `TimeZoneInformation`. The time zone is `host_tz` when provided;
    /// otherwise it uses "Coordinated Universal Time" names with empty dates.
    /// Setting `ZLAYER_GCS_COLDSTART_DELAY_MS` to a positive integer inserts
    /// that many milliseconds of delay before `Create`.
    ///
    /// # Errors
    /// `GcsError::Negotiation` if either RPC fails or returns a non-zero
    /// result. JSON serialization errors are propagated as well.
    async fn cold_start_create_start(
        &self,
        caps: &ProtocolSupport,
        host_tz: Option<&serde_json::Value>,
    ) -> GcsResult<()> {
        if !caps.send_host_create_message {
            return Ok(());
        }
        let settle_ms = std::env::var("ZLAYER_GCS_COLDSTART_DELAY_MS")
            .ok()
            .and_then(|v| v.parse::<u64>().ok())
            .unwrap_or(0);
        if settle_ms > 0 {
            tokio::time::sleep(Duration::from_millis(settle_ms)).await;
        }
        let timezone_information = host_tz.map_or_else(
            || {
                serde_json::json!({
                    "StandardName": "Coordinated Universal Time",
                    "DaylightName": "Coordinated Universal Time",
                    "StandardDate": {},
                    "DaylightDate": {},
                })
            },
            Clone::clone,
        );
        // The guest expects ContainerConfig as a serialized JSON string, not a
        // nested object.
        let uvm_config_str = serde_json::to_string(&serde_json::json!({
            "SystemType": "Container",
            "TimeZoneInformation": timezone_information,
        }))?;
        let create_body = serde_json::json!({
            "ActivityId": uuid::Uuid::new_v4().to_string(),
            "ContainerId": NULL_CONTAINER_ID,
            "ContainerConfig": uvm_config_str,
        });
        let create_resp: ResponseBase = self
            .send_rpc_json(RpcMessageType::Create, &create_body)
            .await
            .map_err(|e| stage_err("cold-start Create", &e))?;
        if create_resp.result != 0 {
            return Err(GcsError::Negotiation(describe_response_error(
                "cold-start Create",
                &create_resp,
            )));
        }
        if !caps.send_host_start_message {
            return Ok(());
        }
        let start_body = serde_json::json!({
            "ActivityId": uuid::Uuid::new_v4().to_string(),
            "ContainerId": NULL_CONTAINER_ID,
        });
        let start_resp: ResponseBase = self
            .send_rpc_json(RpcMessageType::Start, &start_body)
            .await
            .map_err(|e| stage_err("cold-start Start", &e))?;
        if start_resp.result != 0 {
            return Err(GcsError::Negotiation(describe_response_error(
                "cold-start Start",
                &start_resp,
            )));
        }
        Ok(())
    }

    /// Debug builds only: asks the guest to start log forwarding. It does this
    /// with a `ModifyServiceSettings` RPC (`PropertyType: LogForwardService`).
    ///
    /// # Errors
    /// `GcsError::Negotiation` if the RPC fails or returns a non-zero result.
    #[cfg(feature = "windows-debug")]
    async fn start_log_forwarding(&self) -> GcsResult<()> {
        let body = serde_json::json!({
            "ActivityId": uuid::Uuid::new_v4().to_string(),
            "ContainerId": NULL_CONTAINER_ID,
            "PropertyType": "LogForwardService",
            "Settings": {
                "RPCType": "StartLogForwarding",
                "Settings": "",
            },
        });
        gcs_debug!(
            "gcs-logfwd: issuing StartLogForwarding ModifyServiceSettings RPC: {}",
            body,
        );
        let resp: ResponseBase = self
            .send_rpc_json(RpcMessageType::ModifyServiceSettings, &body)
            .await
            .map_err(|e| stage_err("StartLogForwarding", &e))?;
        if resp.result != 0 {
            return Err(GcsError::Negotiation(describe_response_error(
                "StartLogForwarding",
                &resp,
            )));
        }
        Ok(())
    }

    /// Sends `req` as a JSON-encoded request frame of kind `rpc`, then waits
    /// for the response frame with the same message id.
    ///
    /// A oneshot waiter is registered *before* writing. This ensures the
    /// reader task cannot deliver the response before anyone is listening.
    ///
    /// # Errors
    /// - serialization or deserialization failures of the JSON payloads;
    /// - `GcsError::Closed` if the reader task has already exited, or exits
    ///   before the response arrives;
    /// - any error from writing the frame to the stream;
    /// - `GcsError::Protocol` if the response frame type is not the one
    ///   expected for `rpc`.
    pub async fn send_rpc_json<Req, Resp>(&self, rpc: RpcMessageType, req: &Req) -> GcsResult<Resp>
    where
        Req: serde::Serialize + Sync,
        Resp: serde::de::DeserializeOwned,
    {
        // Ids only need to be unique, so no ordering with other memory is required.
        let message_id = self.next_id.fetch_add(1, Ordering::Relaxed);
        let payload = serde_json::to_vec(req)?;
        // Named `wire` (not `frame`) so it does not shadow the `frame::` module path.
        let mut wire = Vec::with_capacity(HEADER_LEN + payload.len());
        frame::encode_frame(rpc.as_request_type(), message_id, &payload, &mut wire);
        gcs_debug!(
            "gcs-bridge-send: rpc={:?} msg_id={} frame_size={} payload_size={} payload={}",
            rpc,
            message_id,
            wire.len(),
            payload.len(),
            std::str::from_utf8(&payload).unwrap_or("<non-utf8>"),
        );
        let (tx, rx) = oneshot::channel();
        {
            let mut guard = self.pending.lock().await;
            if guard.closed {
                return Err(GcsError::Closed);
            }
            guard.waiters.insert(message_id, tx);
        }
        if let Err(e) = self.stream.write_all(&wire).await {
            // No response will ever come for a frame we failed to send.
            self.pending.lock().await.waiters.remove(&message_id);
            return Err(e);
        }
        // The sender is dropped when the reader exits and clears the waiters.
        let (resp_type, resp_payload) = rx.await.map_err(|_| GcsError::Closed)?;
        let expected = rpc.as_response_type();
        if resp_type != expected {
            return Err(GcsError::Protocol(format!(
                "unexpected response type {resp_type:#x} (expected {expected:#x}) for message {message_id}"
            )));
        }
        let resp: Resp = serde_json::from_slice(&resp_payload)?;
        Ok(resp)
    }

    /// Spawns the background task that reads frames from the guest. Response
    /// frames are routed to the waiter registered by `send_rpc_json`.
    ///
    /// Frames that are not responses are read fully and then discarded. The
    /// task stops on any of these conditions:
    /// - a header read fails;
    /// - a header fails to decode;
    /// - a header's size is smaller than the header itself;
    /// - a body read fails.
    ///
    /// On exit the task marks the pending state `closed` and drops every
    /// outstanding waiter, so in-flight RPCs fail with `GcsError::Closed`.
    fn spawn_reader(&self) {
        let stream = self.stream.clone();
        let pending = Arc::clone(&self.pending);
        tokio::spawn(async move {
            gcs_debug!("gcs-bridge-reader: started");
            #[cfg(feature = "windows-debug")]
            let mut frames_seen: u32 = 0;
            loop {
                let mut hdr_buf = [0u8; HEADER_LEN];
                #[cfg_attr(not(feature = "windows-debug"), allow(unused_variables))]
                if let Err(e) = stream.read_exact(&mut hdr_buf).await {
                    gcs_debug!(
                        "gcs-bridge-reader: header read failed after {} frame(s): {}",
                        frames_seen,
                        e,
                    );
                    break;
                }
                let hdr = match frame::decode_header(&hdr_buf) {
                    Ok(h) => h,
                    #[cfg_attr(not(feature = "windows-debug"), allow(unused_variables))]
                    Err(e) => {
                        gcs_debug!(
                            "gcs-bridge-reader: header decode failed (bytes={:02x?}): {}",
                            hdr_buf,
                            e,
                        );
                        break;
                    }
                };
                gcs_debug!(
                    "gcs-bridge-reader: frame#{} type=0x{:08x} size={} msg_id={}",
                    frames_seen,
                    hdr.r#type,
                    hdr.size,
                    hdr.message_id,
                );
                // `hdr.size` includes the header. A malformed size smaller than
                // HEADER_LEN must end the reader. A plain subtraction would panic
                // (debug) or wrap to a huge allocation (release).
                let body_len = match (hdr.size as usize).checked_sub(HEADER_LEN) {
                    Some(len) => len,
                    None => {
                        gcs_debug!(
                            "gcs-bridge-reader: frame size {} is smaller than the {}-byte header",
                            hdr.size,
                            HEADER_LEN,
                        );
                        break;
                    }
                };
                let mut body = vec![0u8; body_len];
                if body_len > 0 {
                    #[cfg_attr(not(feature = "windows-debug"), allow(unused_variables))]
                    if let Err(e) = stream.read_exact(&mut body).await {
                        gcs_debug!(
                            "gcs-bridge-reader: body read failed (need {} bytes): {}",
                            body_len,
                            e,
                        );
                        break;
                    }
                    #[cfg(feature = "windows-debug")]
                    {
                        let cap = body.len().min(512);
                        gcs_debug!(
                            "gcs-bridge-reader: body[..{}]={:?}",
                            cap,
                            String::from_utf8_lossy(&body[..cap]),
                        );
                    }
                }
                #[cfg(feature = "windows-debug")]
                {
                    frames_seen = frames_seen.saturating_add(1);
                }
                // Only responses have waiters; other frame kinds are dropped here.
                if hdr.r#type & MSG_TYPE_MASK != MSG_TYPE_RESPONSE {
                    continue;
                }
                let waiter = {
                    let mut guard = pending.lock().await;
                    guard.waiters.remove(&hdr.message_id)
                };
                if let Some(tx) = waiter {
                    // The caller may have given up; a failed send is harmless.
                    let _ = tx.send((hdr.r#type, body));
                }
            }
            {
                let mut g = pending.lock().await;
                gcs_debug!(
                    "gcs-bridge-reader: exiting; dropping {} pending waiters",
                    g.waiters.len(),
                );
                g.closed = true;
                g.waiters.clear();
            }
        });
    }
}
/// A bound hvsock listener that is waiting for the in-guest GCS to connect.
///
/// Produced by `GcsBridge::listen`. Consume it with `accept` to get a
/// negotiated [`GcsBridge`].
pub struct PendingGcsBridge {
    /// Listener bound to the VM id and `GCS_SERVICE_GUID`.
    listener: HvSockListener,
}
impl PendingGcsBridge {
    /// Waits for the guest connection, then brings the bridge up.
    ///
    /// Steps, in order:
    /// 1. Accept the connection, bounded by `timeout`.
    /// 2. Start the reader task.
    /// 3. Negotiate the protocol.
    /// 4. In `windows-debug` builds only, request log forwarding. This is
    ///    best-effort and skipped if the guest does not advertise support.
    /// 5. Send the cold-start `Create`/`Start` RPCs, passing `host_tz` through.
    ///
    /// # Errors
    /// `GcsError::Hvsock` if nothing connects within `timeout`. Otherwise,
    /// whatever the accept, negotiation, or cold-start steps report.
    pub async fn accept(
        self,
        timeout: Duration,
        host_tz: Option<serde_json::Value>,
    ) -> GcsResult<GcsBridge> {
        let accepted = tokio::time::timeout(timeout, self.listener.accept()).await;
        let stream = match accepted {
            Ok(inner) => inner?,
            Err(_) => {
                return Err(GcsError::Hvsock(format!(
                    "timed out after {timeout:?} waiting for in-guest GCS to connect"
                )));
            }
        };

        let initial_state = PendingState {
            closed: false,
            waiters: HashMap::new(),
        };
        let bridge = GcsBridge {
            stream,
            next_id: Arc::new(AtomicU64::new(1)),
            pending: Arc::new(Mutex::new(initial_state)),
        };
        // The reader must be running before any RPC, or responses would never arrive.
        bridge.spawn_reader();

        let caps = bridge.negotiate_protocol().await?;

        #[cfg(feature = "windows-debug")]
        {
            if caps.modify_service_settings_supported {
                // Log forwarding is a diagnostics aid; failure must not abort bring-up.
                if let Err(e) = bridge.start_log_forwarding().await {
                    gcs_debug!(
                        "gcs-logfwd: StartLogForwarding RPC failed (continuing): {}",
                        e
                    );
                }
            } else {
                gcs_debug!(
                    "gcs-logfwd: guest did not advertise ModifyServiceSettingsSupported; \
                     skipping StartLogForwarding RPC (unsupported on this guest)",
                );
            }
        }

        bridge
            .cold_start_create_start(&caps, host_tz.as_ref())
            .await?;
        Ok(bridge)
    }
}
#[cfg(test)]
mod tests {
    use super::GcsBridge;

    /// Compile-time guarantee: the bridge can be cloned and shared across
    /// tasks and threads.
    #[test]
    fn bridge_is_clone_send_sync() {
        fn require_clone_send_sync<B: Clone + Send + Sync>() {}
        require_clone_send_sync::<GcsBridge>();
    }
}