use cudarc::driver::sys as driver_sys;
use crate::error::GpuError;
use crate::graph::GraphHandle;
const LIB: &str = "graph";
/// Coarse classification of an in-place executable-graph update performed by
/// [`exec_update`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphExecUpdateOutcome {
    /// The driver call succeeded and reported update-result code `0`
    /// (`CU_GRAPH_EXEC_UPDATE_SUCCESS`).
    Success,
    /// The driver rejected the update. Either it returned
    /// `CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE`, or it reported an update-result
    /// code in `2..=8`. Despite the name, this variant groups every such
    /// rejection reason, not only topology changes.
    TopologyChanged,
    /// An update-result code outside the recognised ranges (e.g. `1`, or any
    /// value above `8`).
    Other,
}
impl From<driver_sys::CUgraphExecUpdateResult> for GraphExecUpdateOutcome {
    /// Buckets a raw driver update-result code into a coarse outcome.
    ///
    /// Classification is done on the numeric discriminant, not on named
    /// enumerators, so it does not depend on which variants a particular
    /// binding version declares:
    ///
    /// * `0` → [`GraphExecUpdateOutcome::Success`]
    /// * `2..=8` → [`GraphExecUpdateOutcome::TopologyChanged`]
    /// * anything else (including `1`) → [`GraphExecUpdateOutcome::Other`]
    ///
    /// NOTE(review): per the CUDA driver headers, `1` is the generic
    /// `CU_GRAPH_EXEC_UPDATE_ERROR` and `2..=8` are the specific rejection
    /// reasons. Confirm against the bound CUDA version.
    fn from(code: driver_sys::CUgraphExecUpdateResult) -> Self {
        let raw = code as u32;
        if raw == 0 {
            Self::Success
        } else if (2..=8).contains(&raw) {
            Self::TopologyChanged
        } else {
            Self::Other
        }
    }
}
pub fn exec_update(
exec: &GraphHandle,
new_graph_cu: driver_sys::CUgraph,
) -> Result<GraphExecUpdateOutcome, GpuError> {
let exec_handle = exec.cu_graph_exec();
let probe = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut info = driver_sys::CUgraphExecUpdateResultInfo_st {
result: driver_sys::CUgraphExecUpdateResult_enum::CU_GRAPH_EXEC_UPDATE_SUCCESS,
errorNode: std::ptr::null_mut(),
errorFromNode: std::ptr::null_mut(),
};
let s = unsafe {
driver_sys::cuGraphExecUpdate_v2(exec_handle, new_graph_cu, &mut info as *mut _)
};
(s, info.result)
}));
match probe {
Ok((s, result)) => {
if s == driver_sys::cudaError_enum::CUDA_SUCCESS {
Ok(GraphExecUpdateOutcome::from(result))
} else if s == driver_sys::cudaError_enum::CUDA_ERROR_GRAPH_EXEC_UPDATE_FAILURE {
Ok(GraphExecUpdateOutcome::TopologyChanged)
} else {
Err(GpuError::LibraryError {
lib: LIB,
msg: format!("cuGraphExecUpdate_v2: {s:?}"),
})
}
}
Err(_) => Err(GpuError::Unrecoverable(
"exec_update: CUDA driver not loadable".into(),
)),
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the numeric bucketing done by the `From` impl.
    ///
    /// Named driver enumerators are used rather than transmuted integers, so
    /// the test needs no `unsafe`.
    #[test]
    fn outcome_classification_round_trip() {
        use driver_sys::CUgraphExecUpdateResult_enum::*;

        assert_eq!(
            GraphExecUpdateOutcome::from(CU_GRAPH_EXEC_UPDATE_SUCCESS),
            GraphExecUpdateOutcome::Success
        );

        // Several specific rejection reasons (codes inside 2..=8) must all
        // collapse into `TopologyChanged`.
        for rejected in [
            CU_GRAPH_EXEC_UPDATE_ERROR_TOPOLOGY_CHANGED,
            CU_GRAPH_EXEC_UPDATE_ERROR_NODE_TYPE_CHANGED,
            CU_GRAPH_EXEC_UPDATE_ERROR_PARAMETERS_CHANGED,
            CU_GRAPH_EXEC_UPDATE_ERROR_NOT_SUPPORTED,
        ] {
            assert_eq!(
                GraphExecUpdateOutcome::from(rejected),
                GraphExecUpdateOutcome::TopologyChanged,
                "{rejected:?}"
            );
        }

        // Code 1 (the generic error) falls outside 2..=8 and maps to `Other`.
        assert_eq!(
            GraphExecUpdateOutcome::from(CU_GRAPH_EXEC_UPDATE_ERROR),
            GraphExecUpdateOutcome::Other
        );
    }

    /// Smoke test: calling `exec_update` on a synthetic handle with a null
    /// replacement graph must not panic. It must produce `Ok`, `Unrecoverable`
    /// (driver unavailable) or `LibraryError`, and no other error variant.
    #[test]
    fn param_rebind_round_trip() {
        let exec = GraphHandle::synthetic_for_tests();
        let r = exec_update(&exec, std::ptr::null_mut());
        match r {
            Ok(_) => {}
            Err(GpuError::Unrecoverable(_)) => {}
            Err(GpuError::LibraryError { .. }) => {}
            other => panic!("unexpected: {other:?}"),
        }
    }
}