use std::ffi::c_void;
use std::panic::{self, AssertUnwindSafe};
use windows::core::PWSTR;
use windows::Win32::System::HostComputeSystem::{
HcsCreateOperation, HcsGetOperationResult, HCS_OPERATION,
};
use crate::error::{HcsError, HcsResult};
use crate::handle::{OwnedOperation, SendHandle};
/// Sending half of the oneshot channel used to deliver an operation's outcome.
///
/// `run_operation` boxes one of these and passes it as the callback context;
/// `completion_trampoline` reclaims the box and sends the result through it.
type CompletionSender = tokio::sync::oneshot::Sender<HcsResult<String>>;
/// Copies a wide string into an owned `String`.
///
/// A null pointer yields an empty string. Ill-formed UTF-16 is never an
/// error: unpaired surrogates are replaced with U+FFFD.
///
/// This function only reads from `pwstr`; it does not free it.
///
/// # Safety
///
/// `pwstr` must either be null or satisfy the requirements of
/// `PWSTR::as_wide` (presumably: point to a NUL-terminated UTF-16 buffer that
/// stays valid for reads for the duration of this call — confirm against the
/// `windows` crate documentation).
unsafe fn pwstr_to_string(pwstr: PWSTR) -> String {
    if pwstr.is_null() {
        return String::new();
    }
    // SAFETY: `pwstr` is non-null here, and the caller guarantees it is a
    // readable wide string per this function's contract.
    let units = unsafe { pwstr.as_wide() };
    // Lossy decoding returns exactly the strict decoding when the input is
    // well-formed, so a separate strict attempt is not needed.
    String::from_utf16_lossy(units)
}
/// Completion callback registered with `HcsCreateOperation`.
///
/// Reclaims the boxed [`CompletionSender`] passed as `context`, fetches the
/// operation's result via `HcsGetOperationResult`, and sends the outcome
/// through the channel:
///
/// - on success, `Ok` with the result document text;
/// - on failure, an error built by `HcsError::from_hresult` whose message is
///   the result document if one was returned, otherwise the error's own
///   message;
/// - if the result-fetching code panics, `HcsError::CallbackPanic` carrying
///   the panic message (the panic is also logged).
///
/// A failed send (receiver already gone) is ignored.
///
/// # Safety
///
/// `context` must be a pointer produced by `Box::into_raw` on a
/// `Box<CompletionSender>` that has not been reclaimed elsewhere, and this
/// function must be invoked at most once for that pointer. `operation` must
/// be a handle that `HcsGetOperationResult` may be called on.
unsafe extern "system" fn completion_trampoline(operation: HCS_OPERATION, context: *const c_void) {
    // Declared locally so that releasing the result document does not depend
    // on the signature of the `windows` crate's `LocalFree`/`HLOCAL` bindings.
    #[link(name = "kernel32")]
    unsafe extern "system" {
        fn LocalFree(hmem: *mut c_void) -> *mut c_void;
    }

    // SAFETY: per this function's contract, `context` came from
    // `Box::into_raw` on a `Box<CompletionSender>` and is reclaimed only here.
    let sender_box: Box<CompletionSender> =
        unsafe { Box::from_raw(context.cast::<CompletionSender>().cast_mut()) };
    let sender = *sender_box;

    // A panic must not unwind out of an `extern "system"` function, so the
    // fallible part runs under `catch_unwind` and is reported as an error.
    let payload = panic::catch_unwind(AssertUnwindSafe(|| {
        let mut result_doc: PWSTR = PWSTR::null();
        // SAFETY: `result_doc` is a live local that the call may write to.
        let hr_result = unsafe { HcsGetOperationResult(operation, Some(&raw mut result_doc)) };
        // SAFETY: `result_doc` is either still null or was set by
        // `HcsGetOperationResult` to a NUL-terminated wide string.
        let json = unsafe { pwstr_to_string(result_doc) };
        if !result_doc.is_null() {
            // The result document is allocated by HCS and owned by the caller,
            // who must release it with `LocalFree`; without this every
            // completed operation leaks its document.
            // SAFETY: the buffer was handed to us by `HcsGetOperationResult`,
            // has been fully copied into `json`, and is not used again.
            let _ = unsafe { LocalFree(result_doc.as_ptr().cast::<c_void>()) };
        }
        match hr_result {
            Ok(()) => Ok(json),
            Err(err) => {
                let hr = err.code();
                // Prefer the (usually richer) result document over the generic
                // system message for the HRESULT.
                let message = if json.is_empty() { err.message() } else { json };
                Err(HcsError::from_hresult(hr, message))
            }
        }
    }));

    let outcome = match payload {
        Ok(outcome) => outcome,
        Err(panic_payload) => {
            // Panic payloads are typically `&'static str` or `String`.
            let msg = if let Some(s) = panic_payload.downcast_ref::<&'static str>() {
                (*s).to_string()
            } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                s.clone()
            } else {
                "unknown panic payload".to_string()
            };
            tracing::error!(panic = %msg, "HCS completion callback panicked");
            Err(HcsError::CallbackPanic(msg))
        }
    };

    // The receiver may already have been dropped; nothing useful can be done
    // about that from inside the callback.
    let _ = sender.send(outcome);
}
/// Runs a single HCS call to completion and returns its result document.
///
/// Creates an HCS operation whose completion callback is
/// `completion_trampoline`, passes the operation handle to `f` (which is
/// expected to start the actual HCS call), and then awaits the outcome that
/// the callback sends back.
///
/// # Errors
///
/// - `HcsError::Other` if `HcsCreateOperation` returns an invalid handle.
/// - The error built by `HcsError::from_hresult` if `f` returns a failing
///   `HRESULT`.
/// - Whatever error the completion callback reports.
/// - `HcsError::Other` if the sender is dropped without a value being sent.
///
/// NOTE(review): if this future is dropped while awaiting, the
/// `OwnedOperation` is dropped while the HCS call may still be in flight —
/// confirm what `OwnedOperation`'s drop does and whether the callback (which
/// owns the boxed sender) is still guaranteed to run.
pub async fn run_operation<F>(f: F) -> HcsResult<String>
where
    F: FnOnce(HCS_OPERATION) -> windows::core::HRESULT,
{
    let (sender, receiver) = tokio::sync::oneshot::channel::<HcsResult<String>>();

    // The sender is leaked into a raw pointer that serves as the callback
    // context; it is reclaimed either by the callback or by the early-exit
    // paths below.
    let context: *mut CompletionSender = Box::into_raw(Box::new(sender));

    let raw_operation = SendHandle(unsafe {
        HcsCreateOperation(
            Some(context.cast::<c_void>().cast_const()),
            Some(completion_trampoline),
        )
    });

    if raw_operation.0.is_invalid() {
        // No operation exists, so the callback can never run: take the
        // sender back to avoid leaking it.
        drop(unsafe { Box::from_raw(context) });
        return Err(HcsError::Other {
            hresult: 0,
            message: "HcsCreateOperation returned an invalid handle".to_string(),
        });
    }

    let operation = unsafe { OwnedOperation::from_raw(raw_operation.0) };

    let start_hr = f(*operation.as_raw());
    if start_hr.is_err() {
        // The call never started; reclaim the sender here as well.
        drop(unsafe { Box::from_raw(context) });
        return Err(HcsError::from_hresult(
            start_hr,
            "HCS call failed to start",
        ));
    }

    // A closed channel means the sender was dropped without sending.
    receiver.await.unwrap_or_else(|_closed| {
        Err(HcsError::Other {
            hresult: 0,
            message: "HCS callback dropped the completion sender without sending".to_string(),
        })
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reproduces the receive-side fallback used by `run_operation`: when the
    /// sender is dropped without sending, the awaited result must be mapped to
    /// `HcsError::Other` with a zero hresult and a "dropped" message.
    #[tokio::test]
    async fn dropped_sender_surfaces_other_error() {
        let (sender, receiver) = tokio::sync::oneshot::channel::<HcsResult<String>>();
        drop(sender);

        let outcome = receiver.await.unwrap_or_else(|_closed| {
            Err(HcsError::Other {
                hresult: 0,
                message: "HCS callback dropped the completion sender without sending".to_string(),
            })
        });

        match outcome {
            Err(HcsError::Other { hresult, message }) => {
                assert_eq!(hresult, 0);
                assert!(message.contains("dropped"));
            }
            other => panic!("expected Other error, got {other:?}"),
        }
    }
}