use std::ffi::c_void;
use std::sync::Arc;
use futures_util::future::BoxFuture;
use futures_util::FutureExt;
use tokio::sync::oneshot;
use crate::error::GpuError;
use super::CompletionStrategy;
/// Stateless completion strategy whose [`CompletionStrategy`] implementation
/// waits on a stream by enqueuing a host-function callback.
///
/// The type carries no data, so cloning it or building it through
/// [`Default`] yields an equivalent value.
#[derive(Clone, Debug, Default)]
pub struct HostFnCompletion;

impl HostFnCompletion {
    /// Creates a new strategy value.
    ///
    /// Equivalent to [`HostFnCompletion::default`]; declared `const` so it can
    /// also be used to initialise constants and statics.
    pub const fn new() -> Self {
        Self
    }
}
unsafe extern "C" fn wake_trampoline(data: *mut c_void) {
if data.is_null() {
return;
}
let slot: Box<oneshot::Sender<Result<(), GpuError>>> = Box::from_raw(data.cast());
let _ = slot.send(Ok(()));
}
impl CompletionStrategy for HostFnCompletion {
    /// Enqueues `wake_trampoline` on `stream` as a host function and returns a
    /// future that resolves once that callback has signalled a oneshot channel.
    ///
    /// The enqueue happens eagerly, during this call; only the wait is
    /// deferred to the returned future. The future does not retain a
    /// reference to `stream`.
    ///
    /// # Errors
    ///
    /// The returned future yields `GpuError::Driver` when:
    /// * `launch_host_function` reports a failure (the callback was not
    ///   enqueued), or
    /// * the channel's sending half is dropped without ever sending.
    fn await_completion(
        &self,
        stream: &Arc<cudarc::driver::CudaStream>,
    ) -> BoxFuture<'static, Result<(), GpuError>> {
        let (tx, rx) = oneshot::channel::<Result<(), GpuError>>();

        // Ownership of the sender crosses the FFI boundary as a raw pointer.
        // Exactly one party reclaims it: `wake_trampoline` when the callback
        // runs, or the error path below when the enqueue fails.
        let sender_ptr = Box::into_raw(Box::new(tx));

        // SAFETY: `sender_ptr` comes straight from `Box::into_raw`, which is
        // the pointer shape `wake_trampoline` requires, and it is handed to
        // the driver exactly once.
        let launched = unsafe {
            cudarc::driver::result::stream::launch_host_function(
                stream.cu_stream(),
                wake_trampoline,
                sender_ptr.cast::<c_void>(),
            )
        };

        if let Err(e) = launched {
            // SAFETY: assumes a failed enqueue never invokes the callback, so
            // this function is still the sole owner of the allocation behind
            // `sender_ptr` and may free it.
            drop(unsafe { Box::from_raw(sender_ptr) });
            let msg = format!("cuLaunchHostFunc failed: {e}");
            return async move { Err(GpuError::Driver(msg)) }.boxed();
        }

        async move {
            // A receive error means the sender was dropped without sending.
            rx.await.unwrap_or_else(|_| {
                Err(GpuError::Driver(
                    "host-function callback dropped without firing".into(),
                ))
            })
        }
        .boxed()
    }
}