use std::ffi::c_void;
use std::panic::{self, AssertUnwindSafe};
use futures_core::Stream;
use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender};
use tokio_stream::wrappers::UnboundedReceiverStream;
use windows::Win32::System::HostComputeSystem::{
HcsEventOptionNone, HcsSetComputeSystemCallback, HCS_EVENT, HCS_EVENT_OPTIONS, HCS_SYSTEM,
};
use crate::error::{HcsError, HcsResult};
#[derive(Debug, Clone)]
pub struct HcsEvent {
pub kind: HcsEventKind,
pub detail_json: String,
}
/// Classification of the raw `HCS_EVENT_TYPE` value carried by an HCS event.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HcsEventKind {
    /// Raw type `0x0000_0001`.
    SystemExited,
    /// Raw type `0x0000_0002`.
    SystemCrashInitiated,
    /// Raw type `0x0000_0003`.
    SystemCrashReport,
    /// Raw type `0x0200_0000`.
    ServiceDisconnect,
    /// Any other raw type, preserved verbatim.
    Other(i32),
}

impl HcsEventKind {
    /// Maps a raw event-type integer onto a variant; unrecognized values become `Other`.
    fn from_raw(ty: i32) -> Self {
        // Literal values mirror the documented HCS_EVENT_TYPE constants
        // (HcsEventSystemExited, ..., HcsEventServiceDisconnect).
        match ty {
            0x0000_0001 => Self::SystemExited,
            0x0000_0002 => Self::SystemCrashInitiated,
            0x0000_0003 => Self::SystemCrashReport,
            0x0200_0000 => Self::ServiceDisconnect,
            unknown => Self::Other(unknown),
        }
    }
}
/// Handle returned by `subscribe`, recording the compute system a callback was registered
/// on and the heap-allocated sender that was handed to HCS as the callback context.
///
/// NOTE(review): no `Drop` impl is visible in this section of the file. As far as can be told
/// from here, dropping this value neither unregisters the HCS callback nor frees the boxed
/// sender. The field name suggests the leak is deliberate, presumably so that a late callback
/// can never observe a freed context. A side effect is that the sender is never dropped, so
/// the paired event stream never reports end-of-stream on its own. Confirm this lifetime model.
#[derive(Debug)]
pub struct EventSubscription {
    // Compute system the callback was registered on; only stored, not read, in the code shown.
    _system: HCS_SYSTEM,
    // Pointer produced by `Box::into_raw` in `subscribe` and passed to HCS as the callback
    // context; it is not reclaimed anywhere visible in this section.
    _sender_box_leak: *mut UnboundedSender<HcsEvent>,
}
// SAFETY (author's reasoning as inferred — confirm): the struct only stores a handle value and
// a pointer to a tokio `UnboundedSender<HcsEvent>` (itself `Send + Sync` for `HcsEvent: Send`);
// nothing visible here dereferences the pointer through this struct.
unsafe impl Send for EventSubscription {}
unsafe impl Sync for EventSubscription {}
/// Registers an event callback on `system` and returns a subscription handle together with
/// a stream that yields every event HCS delivers for it.
///
/// # Errors
///
/// Returns the error produced by `HcsError::from_hresult` when `HcsSetComputeSystemCallback`
/// fails; in that case the heap-allocated sender is reclaimed before returning.
pub fn subscribe(
    system: HCS_SYSTEM,
) -> HcsResult<(EventSubscription, impl Stream<Item = HcsEvent>)> {
    let (sender, receiver) = mpsc::unbounded_channel::<HcsEvent>();

    // Move the sender to the heap so its address stays fixed; HCS hands this pointer back
    // to `event_trampoline` as the callback context.
    let context: *mut UnboundedSender<HcsEvent> = Box::into_raw(Box::new(sender));
    let options = HCS_EVENT_OPTIONS(HcsEventOptionNone.0);

    // SAFETY: `context` points to a live, heap-allocated sender whose ownership is transferred
    // to the registration on success; `event_trampoline` matches the expected callback ABI.
    let registration = unsafe {
        HcsSetComputeSystemCallback(
            system,
            options,
            Some(context.cast::<c_void>().cast_const()),
            Some(event_trampoline),
        )
    };

    match registration {
        Ok(_) => {
            let handle = EventSubscription {
                _system: system,
                _sender_box_leak: context,
            };
            Ok((handle, UnboundedReceiverStream::new(receiver)))
        }
        Err(err) => {
            // SAFETY: registration failed, so HCS never retained `context`; this function is
            // still its sole owner and may reclaim the box.
            unsafe { drop(Box::from_raw(context)) };
            Err(HcsError::from_hresult(
                err.code(),
                "HcsSetComputeSystemCallback",
            ))
        }
    }
}
/// C-ABI callback registered by `subscribe`. It forwards each HCS event into the channel
/// whose sender was supplied as `context`.
///
/// # Safety
///
/// Meant to be invoked only by HCS. `context` must be null or the still-live pointer produced
/// by `Box::into_raw` in `subscribe`. `event` must be null or point to an `HCS_EVENT` that
/// stays valid for the duration of the call.
///
/// Panics are caught and logged via `tracing` so they never unwind across the FFI boundary.
unsafe extern "system" fn event_trampoline(event: *const HCS_EVENT, context: *const c_void) {
    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        if context.is_null() || event.is_null() {
            return;
        }
        // SAFETY: a non-null `context` is the boxed sender leaked by `subscribe`, which is
        // never freed by the code in this module.
        let sender: &UnboundedSender<HcsEvent> =
            unsafe { &*context.cast::<UnboundedSender<HcsEvent>>() };
        // SAFETY: checked non-null above; assumed valid for the duration of the callback.
        let ev = unsafe { &*event };
        let detail = if ev.EventData.is_null() {
            String::new()
        } else {
            // SAFETY: `EventData` is non-null; assumes HCS supplies a NUL-terminated UTF-16
            // string that outlives this call (the same precondition `to_string` relies on).
            // Lossy decoding keeps the payload if it contains invalid UTF-16, instead of
            // silently replacing the whole payload with an empty string.
            String::from_utf16_lossy(unsafe { ev.EventData.as_wide() })
        };
        let kind = HcsEventKind::from_raw(ev.Type.0);
        // Best effort: a send error only means the receiving stream has been dropped.
        let _ = sender.send(HcsEvent {
            kind,
            detail_json: detail,
        });
    }));
    if let Err(payload) = result {
        let msg = describe_panic_payload(&*payload);
        tracing::error!(panic = %msg, "HCS event trampoline panicked");
    }
}

/// Renders a `catch_unwind` payload as text (handles `&'static str` and `String` payloads).
fn describe_panic_payload(payload: &(dyn std::any::Any + Send)) -> String {
    if let Some(s) = payload.downcast_ref::<&'static str>() {
        (*s).to_string()
    } else if let Some(s) = payload.downcast_ref::<String>() {
        s.clone()
    } else {
        "unknown panic payload".to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classifies_known_event_types() {
        let cases = [
            (1, HcsEventKind::SystemExited),
            (2, HcsEventKind::SystemCrashInitiated),
            (3, HcsEventKind::SystemCrashReport),
            (33_554_432, HcsEventKind::ServiceDisconnect),
        ];
        for (raw, expected) in cases {
            assert_eq!(HcsEventKind::from_raw(raw), expected, "raw event type {raw}");
        }
    }

    #[test]
    fn unknown_event_type_falls_through_to_other() {
        for raw in [4, 0x7FFF_FFFF] {
            assert_eq!(HcsEventKind::from_raw(raw), HcsEventKind::Other(raw));
        }
    }
}