use std::sync::Arc;
use cudarc::driver::CudaSlice;
use futures_util::FutureExt;
use tokio::sync::oneshot;
use tracing::warn;
use crate::completion::CompletionStrategy;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
/// Resolves two [`GpuRef`]s into owned `Arc<CudaSlice<_>>` handles.
///
/// `access()` is called on `a` first, then on `b`; each successful result is
/// cloned so the returned handles are independent of the borrows passed in.
///
/// # Errors
///
/// Returns the first error produced by `GpuRef::access`. If `a` fails, `b` is
/// never accessed.
pub fn access_all_2<A, B>(
    a: &GpuRef<A>,
    b: &GpuRef<B>,
) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>), GpuError> {
    // Tuple operands are evaluated left to right, so `?` short-circuits in
    // argument order.
    Ok((a.access()?.clone(), b.access()?.clone()))
}
/// Resolves three [`GpuRef`]s into owned `Arc<CudaSlice<_>>` handles.
///
/// `access()` is called on `a`, `b`, `c` in that order; each successful result
/// is cloned so the returned handles are independent of the borrows passed in.
///
/// # Errors
///
/// Returns the first error produced by `GpuRef::access`. Refs after the
/// failing one are never accessed.
pub fn access_all_3<A, B, C>(
    a: &GpuRef<A>,
    b: &GpuRef<B>,
    c: &GpuRef<C>,
) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>, Arc<CudaSlice<C>>), GpuError> {
    // Tuple operands are evaluated left to right, so `?` short-circuits in
    // argument order.
    Ok((
        a.access()?.clone(),
        b.access()?.clone(),
        c.access()?.clone(),
    ))
}
/// Resolves four [`GpuRef`]s into owned `Arc<CudaSlice<_>>` handles.
///
/// `access()` is called on `a`, `b`, `c`, `d` in that order; each successful
/// result is cloned so the returned handles are independent of the borrows
/// passed in.
///
/// # Errors
///
/// Returns the first error produced by `GpuRef::access`. Refs after the
/// failing one are never accessed.
pub fn access_all_4<A, B, C, D>(
    a: &GpuRef<A>,
    b: &GpuRef<B>,
    c: &GpuRef<C>,
    d: &GpuRef<D>,
) -> Result<
    (
        Arc<CudaSlice<A>>,
        Arc<CudaSlice<B>>,
        Arc<CudaSlice<C>>,
        Arc<CudaSlice<D>>,
    ),
    GpuError,
> {
    // Tuple operands are evaluated left to right, so `?` short-circuits in
    // argument order.
    Ok((
        a.access()?.clone(),
        b.access()?.clone(),
        c.access()?.clone(),
        d.access()?.clone(),
    ))
}
/// Enqueues a kernel and reports its completion over a oneshot channel.
///
/// `enqueue` is invoked synchronously on the calling thread. What happens next
/// depends on its result:
///
/// * **Enqueue failed** – the error is passed through [`annotate_error`] with
///   `lib_tag` and sent on `reply`. Neither `completion` nor `stream` is
///   touched and no task is spawned.
/// * **Enqueue succeeded** – the returned keep-alive value is held while a
///   Tokio task awaits `completion.await_completion(stream)`. On success the
///   task sends `Ok(output)`; on failure it logs a warning and forwards the
///   error. The keep-alive value is dropped only after the reply was sent.
///
/// Failures of `reply.send` (receiver already gone) are deliberately ignored.
///
/// # Panics
///
/// The success path calls `tokio::spawn`, which panics when invoked outside a
/// Tokio runtime context.
pub fn run_kernel<O, KA, F>(
    lib_tag: &'static str,
    stream: &Arc<cudarc::driver::CudaStream>,
    completion: &Arc<dyn CompletionStrategy>,
    output: O,
    reply: oneshot::Sender<Result<O, GpuError>>,
    enqueue: F,
) where
    O: Send + 'static,
    KA: Send + 'static,
    F: FnOnce() -> Result<KA, GpuError>,
{
    // Run the enqueue step first; a failure here is reported immediately and
    // short-circuits everything below.
    let keep_alive = match enqueue().map_err(|e| annotate_error(e, lib_tag)) {
        Ok(guard) => guard,
        Err(tagged) => {
            let _ = reply.send(Err(tagged));
            return;
        }
    };

    // Build the completion future before spawning so the task does not need to
    // capture the borrowed `stream` / `completion` arguments.
    let pending = completion.await_completion(stream).boxed();

    tokio::spawn(async move {
        // NOTE(review): completion errors are forwarded as-is, without
        // `annotate_error` — presumably intentional; confirm.
        let outcome = match pending.await {
            Ok(()) => Ok(output),
            Err(e) => {
                warn!(lib = lib_tag, error = %e, "kernel completion failed");
                Err(e)
            }
        };
        let _ = reply.send(outcome);

        // Release whatever `enqueue` asked us to keep alive only after the
        // completion future has resolved and the reply has been sent.
        drop(keep_alive);
    });
}
/// Attaches a library tag to untyped driver failures.
///
/// A `GpuError::Driver(msg)` is rewrapped as
/// `GpuError::LibraryError { lib: lib_tag, msg }`; every other variant is
/// returned unchanged.
fn annotate_error(e: GpuError, lib_tag: &'static str) -> GpuError {
    if let GpuError::Driver(msg) = e {
        GpuError::LibraryError { lib: lib_tag, msg }
    } else {
        e
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A raw driver failure is rewrapped as a `LibraryError` carrying the tag
    /// and the original message.
    #[test]
    fn annotate_error_tags_driver_failures() {
        let (lib, msg) = match annotate_error(GpuError::Driver("oops".into()), "cudnn") {
            GpuError::LibraryError { lib, msg } => (lib, msg),
            other => panic!("expected LibraryError, got {other:?}"),
        };
        assert_eq!(lib, "cudnn");
        assert_eq!(msg, "oops");
    }

    /// Variants other than `Driver` come back with their variant unchanged.
    #[test]
    fn annotate_error_passes_through_typed_variants() {
        let oom = annotate_error(GpuError::OutOfMemory("alloc".into()), "cudnn");
        assert!(matches!(oom, GpuError::OutOfMemory(_)));

        let stale = annotate_error(GpuError::GpuRefStale("stale"), "cudnn");
        assert!(matches!(stale, GpuError::GpuRefStale(_)));
    }

    /// Checks that a failing enqueue closure runs exactly once and yields its
    /// error.
    ///
    /// This drives the closure directly; it does not call `run_kernel`, and
    /// the reply channel is only created and dropped.
    #[test]
    fn pre_enqueue_error_bypasses_completion() {
        let (tx, rx) = oneshot::channel::<Result<u32, GpuError>>();
        let calls = AtomicU32::new(0);
        let failing_enqueue = || -> Result<(), GpuError> {
            calls.fetch_add(1, Ordering::Relaxed);
            Err(GpuError::OutOfMemory("forced".into()))
        };

        assert!(matches!(
            failing_enqueue(),
            Err(GpuError::OutOfMemory(_))
        ));
        assert_eq!(calls.load(Ordering::Relaxed), 1);

        drop(tx);
        drop(rx);
    }
}