use std::sync::Arc;
use std::time::{Duration, Instant};
use cudarc::driver::CudaSlice;
use futures_util::FutureExt;
use tokio::sync::oneshot;
use tracing::warn;
use crate::completion::CompletionStrategy;
use crate::error::GpuError;
use crate::gpu_ref::GpuRef;
/// Resolves two [`GpuRef`]s to shared handles on their device buffers.
///
/// The references are resolved in argument order. The first `access()` that
/// fails ends the call.
///
/// # Errors
///
/// Returns the `GpuError` produced by the first failing `access()`.
pub fn access_all_2<A, B>(
    a: &GpuRef<A>,
    b: &GpuRef<B>,
) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>), GpuError> {
    // One statement per access: any temporary returned by `access()` is
    // released before the next reference is resolved.
    // NOTE(review): only significant if `access()` returns a guard — confirm.
    let slice_a = a.access()?.clone();
    let slice_b = b.access()?.clone();
    Ok((slice_a, slice_b))
}
pub fn access_all_3<A, B, C>(
a: &GpuRef<A>,
b: &GpuRef<B>,
c: &GpuRef<C>,
) -> Result<(Arc<CudaSlice<A>>, Arc<CudaSlice<B>>, Arc<CudaSlice<C>>), GpuError> {
let a_s = a.access()?.clone();
let b_s = b.access()?.clone();
let c_s = c.access()?.clone();
Ok((a_s, b_s, c_s))
}
pub fn access_all_4<A, B, C, D>(
a: &GpuRef<A>,
b: &GpuRef<B>,
c: &GpuRef<C>,
d: &GpuRef<D>,
) -> Result<
(
Arc<CudaSlice<A>>,
Arc<CudaSlice<B>>,
Arc<CudaSlice<C>>,
Arc<CudaSlice<D>>,
),
GpuError,
> {
let a_s = a.access()?.clone();
let b_s = b.access()?.clone();
let c_s = c.access()?.clone();
let d_s = d.access()?.clone();
Ok((a_s, b_s, c_s, d_s))
}
/// Metadata describing one kernel launch, passed by reference to every
/// [`KernelTrace`] hook.
///
/// Derives `PartialEq`/`Eq`/`Hash` in addition to `Copy`. Trace sinks can then
/// compare snapshots or use them as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct KernelInfo<'a> {
    /// Operation name (e.g. `"sgemm"`). Envelopes default this to the library tag.
    pub op_name: &'a str,
    /// Library tag the kernel was issued through (e.g. `"cublas"`).
    pub library: &'a str,
    /// Stream identifier. The envelope derives it from the raw stream handle value.
    pub stream_id: u64,
    /// Optional element-type label (e.g. `"f32"`).
    pub dtype: Option<&'a str>,
}
/// Observer notified at each stage of a kernel launch made through
/// `KernelEnvelope::run_kernel`.
///
/// Every hook defaults to doing nothing, so an implementation overrides only
/// the stages it cares about. The `Send + Sync + 'static` bounds let one sink
/// be shared behind an `Arc` and called from the spawned completion task.
pub trait KernelTrace: Send + Sync + 'static {
    /// Invoked just before the enqueue closure runs.
    #[allow(unused_variables)]
    fn before_enqueue(&self, info: &KernelInfo<'_>) {}

    /// Invoked after the enqueue closure returns. `result` carries the
    /// library-annotated error when enqueueing failed.
    #[allow(unused_variables)]
    fn after_enqueue(&self, info: &KernelInfo<'_>, result: Result<(), &GpuError>) {}

    /// Invoked from the completion task, before waiting on the stream.
    /// Not called when enqueueing failed.
    #[allow(unused_variables)]
    fn before_complete(&self, info: &KernelInfo<'_>) {}

    /// Invoked once the completion wait finishes, before the caller receives
    /// its reply. `latency` is the time spent awaiting completion.
    #[allow(unused_variables)]
    fn after_complete(
        &self,
        info: &KernelInfo<'_>,
        result: Result<(), &GpuError>,
        latency: Duration,
    ) {
    }
}
/// Per-launch configuration: which library issued the kernel, how it is
/// labelled for tracing, and optional trace/NVTX instrumentation.
///
/// Construct with `KernelEnvelope::new` and refine with the `with_*` builders.
#[derive(Clone)]
pub struct KernelEnvelope {
    /// Library tag. It is reported to traces and used to annotate driver errors.
    lib_tag: &'static str,
    /// Operation name reported to traces.
    op_name: &'static str,
    /// Optional element-type label reported to traces.
    dtype: Option<&'static str>,
    /// Optional NVTX range name wrapped around the enqueue step.
    nvtx_range_name: Option<&'static str>,
    /// Optional trace sink receiving lifecycle hooks.
    trace: Option<Arc<dyn KernelTrace>>,
}
impl KernelEnvelope {
    /// Creates an envelope for kernels issued through the library `lib_tag`.
    ///
    /// The op name defaults to `lib_tag`. No dtype, trace sink or NVTX range is
    /// attached until the corresponding `with_*` builder is called.
    pub fn new(lib_tag: &'static str) -> Self {
        Self {
            lib_tag,
            op_name: lib_tag,
            dtype: None,
            trace: None,
            nvtx_range_name: None,
        }
    }

    /// Sets the operation name reported as [`KernelInfo::op_name`].
    #[must_use]
    pub fn with_op_name(mut self, op_name: &'static str) -> Self {
        self.op_name = op_name;
        self
    }

    /// Sets the element-type label reported as [`KernelInfo::dtype`].
    #[must_use]
    pub fn with_dtype(mut self, dtype: &'static str) -> Self {
        self.dtype = Some(dtype);
        self
    }

    /// Attaches a trace sink whose hooks fire around enqueue and completion.
    #[must_use]
    pub fn with_trace(mut self, trace: Arc<dyn KernelTrace>) -> Self {
        self.trace = Some(trace);
        self
    }

    /// Sets the NVTX range name that wraps the enqueue closure.
    ///
    /// The range only takes effect when the crate is built with the `nvtx`
    /// feature; otherwise the name is ignored.
    #[must_use]
    pub fn with_nvtx(mut self, name: &'static str) -> Self {
        self.nvtx_range_name = Some(name);
        self
    }

    /// Builds the [`KernelInfo`] snapshot handed to every trace hook.
    fn info(&self, stream_id: u64) -> KernelInfo<'_> {
        KernelInfo {
            op_name: self.op_name,
            library: self.lib_tag,
            stream_id,
            dtype: self.dtype,
        }
    }

    /// Enqueues a kernel on `stream` and delivers its outcome through `reply`.
    ///
    /// `enqueue` runs synchronously on the calling thread. When the `nvtx`
    /// feature is enabled and a range name was set, it runs inside that NVTX
    /// range.
    ///
    /// - If `enqueue` fails, the error goes through `annotate_error`, which
    ///   tags `GpuError::Driver` with this envelope's library. The result is
    ///   sent on `reply` immediately and no completion task is spawned.
    /// - On success, a Tokio task awaits `completion.await_completion(stream)`.
    ///   It then sends `Ok(output)` or the completion error, unmodified, on
    ///   `reply`.
    ///
    /// The value returned by `enqueue` is kept alive until after the reply has
    /// been sent. Sends ignore a dropped receiver.
    ///
    /// Trace hooks fire in this order: `before_enqueue`, `after_enqueue`, then
    /// `before_complete` and `after_complete` (only if enqueue succeeded).
    /// `after_complete` runs before the reply is sent. Its `latency` covers
    /// only the await of the completion future.
    ///
    /// # Panics
    ///
    /// On the success path this calls `tokio::spawn`, which panics when
    /// invoked outside a Tokio runtime.
    pub fn run_kernel<O, KA, F>(
        self,
        stream: &Arc<cudarc::driver::CudaStream>,
        completion: &Arc<dyn CompletionStrategy>,
        output: O,
        reply: oneshot::Sender<Result<O, GpuError>>,
        enqueue: F,
    ) where
        O: Send + 'static,
        KA: Send + 'static,
        F: FnOnce() -> Result<KA, GpuError>,
    {
        // The raw stream handle value doubles as a per-stream id for tracing.
        let stream_id = stream.cu_stream() as usize as u64;
        let info = self.info(stream_id);
        if let Some(t) = self.trace.as_deref() {
            t.before_enqueue(&info);
        }
        let enqueue_result = {
            #[cfg(feature = "nvtx")]
            let _nvtx_guard = self.nvtx_range_name.map(cudarc::nvtx::safe::scoped_range);
            #[cfg(not(feature = "nvtx"))]
            let _ = self.nvtx_range_name;
            enqueue()
        };
        let keep_alive = match enqueue_result {
            Ok(ka) => {
                if let Some(t) = self.trace.as_deref() {
                    t.after_enqueue(&info, Ok(()));
                }
                ka
            }
            Err(e) => {
                let annotated = annotate_error(e, self.lib_tag);
                if let Some(t) = self.trace.as_deref() {
                    t.after_enqueue(&info, Err(&annotated));
                }
                let _ = reply.send(Err(annotated));
                return;
            }
        };
        let fut = completion.await_completion(stream).boxed();
        // The envelope is owned and all of its fields are `'static`, so move it
        // wholesale into the task. This avoids copying fields out one by one or
        // re-cloning the trace `Arc`. The completion-side `KernelInfo` is then
        // built by the same `info()` used for the enqueue-side hooks.
        tokio::spawn(async move {
            let info = self.info(stream_id);
            let trace = self.trace.as_deref();
            if let Some(t) = trace {
                t.before_complete(&info);
            }
            let started = Instant::now();
            let result = fut.await;
            let latency = started.elapsed();
            match result {
                Ok(()) => {
                    if let Some(t) = trace {
                        t.after_complete(&info, Ok(()), latency);
                    }
                    let _ = reply.send(Ok(output));
                }
                Err(e) => {
                    warn!(
                        lib = self.lib_tag,
                        op = self.op_name,
                        error = %e,
                        "kernel completion failed"
                    );
                    if let Some(t) = trace {
                        t.after_complete(&info, Err(&e), latency);
                    }
                    let _ = reply.send(Err(e));
                }
            }
            // Release whatever `enqueue` handed back only once the stream work
            // has completed and the caller has been answered.
            drop(keep_alive);
        });
    }
}
impl std::fmt::Debug for KernelEnvelope {
    /// Formats the envelope's metadata. The trace sink is shown as a
    /// placeholder because `KernelTrace` does not require `Debug`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let trace_placeholder = self.trace.as_ref().map(|_| "<dyn KernelTrace>");
        let mut dbg = f.debug_struct("KernelEnvelope");
        dbg.field("lib_tag", &self.lib_tag);
        dbg.field("op_name", &self.op_name);
        dbg.field("dtype", &self.dtype);
        dbg.field("nvtx_range_name", &self.nvtx_range_name);
        dbg.field("trace", &trace_placeholder);
        dbg.finish()
    }
}
pub fn run_kernel<O, KA, F>(
lib_tag: &'static str,
stream: &Arc<cudarc::driver::CudaStream>,
completion: &Arc<dyn CompletionStrategy>,
output: O,
reply: oneshot::Sender<Result<O, GpuError>>,
enqueue: F,
) where
O: Send + 'static,
KA: Send + 'static,
F: FnOnce() -> Result<KA, GpuError>,
{
let keep_alive = match enqueue() {
Ok(ka) => ka,
Err(e) => {
let _ = reply.send(Err(annotate_error(e, lib_tag)));
return;
}
};
let fut = completion.await_completion(stream).boxed();
tokio::spawn(async move {
let result = fut.await;
match result {
Ok(()) => {
let _ = reply.send(Ok(output));
}
Err(e) => {
warn!(lib = lib_tag, error = %e, "kernel completion failed");
let _ = reply.send(Err(e));
}
}
drop(keep_alive);
});
}
/// Attributes a bare driver failure to the library that issued the kernel.
///
/// `GpuError::Driver(msg)` becomes `GpuError::LibraryError { lib: lib_tag, msg }`.
/// Every other variant is returned unchanged, so typed errors such as
/// out-of-memory or stale-reference errors keep their meaning.
fn annotate_error(e: GpuError, lib_tag: &'static str) -> GpuError {
    if let GpuError::Driver(msg) = e {
        GpuError::LibraryError { lib: lib_tag, msg }
    } else {
        e
    }
}
// Unit tests for error annotation and trace-hook ordering. None of them create
// a CUDA stream or a Tokio runtime; the hook-ordering tests go through the
// synchronous `drive_envelope_trace` helper below.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Mutex;
// `Driver` errors must be re-tagged with the issuing library.
#[test]
fn annotate_error_tags_driver_failures() {
let e = annotate_error(GpuError::Driver("oops".into()), "cudnn");
match e {
GpuError::LibraryError { lib, msg } => {
assert_eq!(lib, "cudnn");
assert_eq!(msg, "oops");
}
other => panic!("expected LibraryError, got {other:?}"),
}
}
// Non-`Driver` variants must pass through `annotate_error` untouched.
#[test]
fn annotate_error_passes_through_typed_variants() {
let e = annotate_error(GpuError::OutOfMemory("alloc".into()), "cudnn");
assert!(matches!(e, GpuError::OutOfMemory(_)));
let e = annotate_error(GpuError::GpuRefStale("stale"), "cudnn");
assert!(matches!(e, GpuError::GpuRefStale(_)));
}
// Checks that the enqueue closure runs exactly once and surfaces its error.
// NOTE(review): despite the name, this never calls `run_kernel`. The oneshot
// channel is created and dropped unused, so "bypasses completion" is not
// actually exercised here.
#[test]
fn pre_enqueue_error_bypasses_completion() {
let (tx, rx) = oneshot::channel::<Result<u32, GpuError>>();
let mut bumped = AtomicU32::new(0);
let enqueue = || -> Result<(), GpuError> {
bumped.fetch_add(1, Ordering::Relaxed);
Err(GpuError::OutOfMemory("forced".into()))
};
let res = enqueue();
assert!(matches!(res, Err(GpuError::OutOfMemory(_))));
assert_eq!(*bumped.get_mut(), 1);
drop(tx);
drop(rx);
}
// Trace sink that records each hook name in call order, the metadata seen by
// `before_enqueue`, and how many enqueues succeeded or failed.
#[derive(Default)]
struct RecordingTrace {
events: Mutex<Vec<&'static str>>,
last_dtype: Mutex<Option<String>>,
last_op: Mutex<Option<String>>,
last_lib: Mutex<Option<String>>,
enqueue_ok: AtomicU32,
enqueue_err: AtomicU32,
}
impl KernelTrace for RecordingTrace {
fn before_enqueue(&self, info: &KernelInfo<'_>) {
self.events.lock().unwrap().push("before_enqueue");
*self.last_op.lock().unwrap() = Some(info.op_name.to_string());
*self.last_lib.lock().unwrap() = Some(info.library.to_string());
*self.last_dtype.lock().unwrap() = info.dtype.map(str::to_string);
}
fn after_enqueue(&self, _info: &KernelInfo<'_>, result: Result<(), &GpuError>) {
self.events.lock().unwrap().push("after_enqueue");
match result {
Ok(()) => {
self.enqueue_ok.fetch_add(1, Ordering::Relaxed);
}
Err(_) => {
self.enqueue_err.fetch_add(1, Ordering::Relaxed);
}
}
}
fn before_complete(&self, _info: &KernelInfo<'_>) {
self.events.lock().unwrap().push("before_complete");
}
fn after_complete(
&self,
_info: &KernelInfo<'_>,
_result: Result<(), &GpuError>,
_latency: Duration,
) {
self.events.lock().unwrap().push("after_complete");
}
}
// Synchronously replays the hook sequence of `KernelEnvelope::run_kernel`.
// It uses a fixed fake stream id (0xDEAD_BEEF) and a fake 1us latency, and
// fires the completion hooks only when enqueue succeeds.
// Returns (raw enqueue result, library-annotated copy of it).
// NOTE(review): this is a hand-written copy of the production sequence, not a
// call into it. Changes to `run_kernel` must be mirrored here or these tests
// stop reflecting real behaviour.
fn drive_envelope_trace<F>(
env: &KernelEnvelope,
enqueue: F,
) -> (Result<(), GpuError>, Result<(), GpuError>)
where
F: FnOnce() -> Result<(), GpuError>,
{
let info = env.info(0xDEAD_BEEF);
if let Some(t) = env.trace.as_deref() {
t.before_enqueue(&info);
}
let enqueue_result = enqueue();
let enqueue_report = match &enqueue_result {
Ok(()) => Ok(()),
Err(e) => Err(annotate_error_clone(e, env.lib_tag)),
};
if let Some(t) = env.trace.as_deref() {
match &enqueue_report {
Ok(()) => t.after_enqueue(&info, Ok(())),
Err(e) => t.after_enqueue(&info, Err(e)),
}
}
if enqueue_report.is_ok() {
if let Some(t) = env.trace.as_deref() {
t.before_complete(&info);
t.after_complete(&info, Ok(()), Duration::from_micros(1));
}
}
(enqueue_result, enqueue_report)
}
// Borrowing counterpart of `annotate_error`, needed because the helper above
// returns both the raw and the annotated error.
// NOTE(review): the catch-all arm converts any variant not listed into a
// `LibraryError` via `to_string()`, whereas `annotate_error` passes such
// variants through unchanged.
fn annotate_error_clone(e: &GpuError, lib_tag: &'static str) -> GpuError {
match e {
GpuError::Driver(msg) => GpuError::LibraryError {
lib: lib_tag,
msg: msg.clone(),
},
GpuError::OutOfMemory(msg) => GpuError::OutOfMemory(msg.clone()),
GpuError::ContextPoisoned(msg) => GpuError::ContextPoisoned(msg.clone()),
GpuError::Unrecoverable(msg) => GpuError::Unrecoverable(msg.clone()),
GpuError::GpuRefStale(s) => GpuError::GpuRefStale(s),
GpuError::LibraryError { lib, msg } => GpuError::LibraryError {
lib,
msg: msg.clone(),
},
other => GpuError::LibraryError {
lib: lib_tag,
msg: other.to_string(),
},
}
}
// `KernelEnvelope::new` must leave every optional field unset and default the
// op name to the library tag.
#[test]
fn envelope_default_is_traceless_and_nvtxless() {
let env = KernelEnvelope::new("cublas");
assert!(env.trace.is_none());
assert!(env.nvtx_range_name.is_none());
assert_eq!(env.lib_tag, "cublas");
assert_eq!(env.op_name, "cublas");
assert!(env.dtype.is_none());
}
// Each `with_*` builder must set its corresponding field.
#[test]
fn envelope_builder_sets_metadata() {
let trace = Arc::new(RecordingTrace::default()) as Arc<dyn KernelTrace>;
let env = KernelEnvelope::new("cublas")
.with_op_name("sgemm")
.with_dtype("f32")
.with_trace(trace)
.with_nvtx("blas/sgemm");
assert_eq!(env.op_name, "sgemm");
assert_eq!(env.dtype, Some("f32"));
assert_eq!(env.nvtx_range_name, Some("blas/sgemm"));
assert!(env.trace.is_some());
}
// On success, all four hooks fire in lifecycle order and see the envelope's metadata.
#[test]
fn trace_hooks_fire_in_order_on_success() {
let trace = Arc::new(RecordingTrace::default());
let env = KernelEnvelope::new("cublas")
.with_op_name("sgemm")
.with_dtype("f32")
.with_trace(trace.clone() as Arc<dyn KernelTrace>);
let (enqueue_res, _) = drive_envelope_trace(&env, || Ok(()));
assert!(enqueue_res.is_ok());
let events = trace.events.lock().unwrap().clone();
assert_eq!(
events,
vec![
"before_enqueue",
"after_enqueue",
"before_complete",
"after_complete",
]
);
assert_eq!(trace.enqueue_ok.load(Ordering::Relaxed), 1);
assert_eq!(trace.enqueue_err.load(Ordering::Relaxed), 0);
assert_eq!(trace.last_op.lock().unwrap().as_deref(), Some("sgemm"));
assert_eq!(trace.last_lib.lock().unwrap().as_deref(), Some("cublas"));
assert_eq!(trace.last_dtype.lock().unwrap().as_deref(), Some("f32"));
}
// On enqueue failure, only the enqueue hooks fire, and the reported error is
// tagged with the library.
#[test]
fn trace_hooks_skip_completion_on_enqueue_error() {
let trace = Arc::new(RecordingTrace::default());
let env = KernelEnvelope::new("cudnn")
.with_op_name("conv2d_forward")
.with_trace(trace.clone() as Arc<dyn KernelTrace>);
let (enqueue_res, report) =
drive_envelope_trace(&env, || Err(GpuError::Driver("forced".into())));
assert!(enqueue_res.is_err());
match report {
Err(GpuError::LibraryError { lib, msg }) => {
assert_eq!(lib, "cudnn");
assert_eq!(msg, "forced");
}
other => panic!("expected LibraryError, got {other:?}"),
}
let events = trace.events.lock().unwrap().clone();
assert_eq!(events, vec!["before_enqueue", "after_enqueue"]);
assert_eq!(trace.enqueue_ok.load(Ordering::Relaxed), 0);
assert_eq!(trace.enqueue_err.load(Ordering::Relaxed), 1);
}
// An envelope with no trace sink must run the sequence without touching any hooks.
#[test]
fn envelope_without_trace_is_silent() {
let env = KernelEnvelope::new("cufft");
let (res, _) = drive_envelope_trace(&env, || Ok(()));
assert!(res.is_ok());
}
}