use crate::device_operation::{DeviceOp, ExecutionContext, GraphNode};
use crate::error::DeviceError;
use cuda_core::{stream, sys, CudaStream, IntoResult};
use std::mem::MaybeUninit;
use std::sync::Arc;
/// Capture mode passed to `stream::begin_capture` by `CudaGraph::capture` and
/// `CudaGraph::scope`.
///
/// NOTE(review): the literal `2` is assumed to match the driver's
/// `CU_STREAM_CAPTURE_MODE_RELAXED` enumerator — confirm against the `sys` bindings.
const CU_STREAM_CAPTURE_MODE_RELAXED: sys::CUstreamCaptureMode = 2;
/// An instantiated CUDA graph captured from work issued on a stream, plus the value the
/// captured operation produced.
///
/// Created by `CudaGraph::capture` or `CudaGraph::scope` and replayed with
/// `CudaGraph::launch`. Both driver handles are destroyed in `Drop`.
pub struct CudaGraph<T> {
    /// Stream the graph was captured on; also used for upload, launch and `update`.
    stream: Arc<CudaStream>,
    /// Template graph returned when capture ended; destroyed in `Drop`.
    cu_graph: sys::CUgraph,
    /// Executable graph instantiated from `cu_graph`; this is what `launch` runs.
    cu_graph_exec: sys::CUgraphExec,
    /// Value returned by the captured operation, until taken with `take_output`.
    output: Option<T>,
}
impl<T: Send> CudaGraph<T> {
    /// Captures `op` into a CUDA graph on `stream`, then instantiates and uploads it.
    ///
    /// The stream's context is bound to the calling thread and `stream` is put into
    /// relaxed capture mode while `op` runs against an [`ExecutionContext`] for it. The
    /// captured graph is instantiated, uploaded to `stream`, and the stream is synchronized
    /// before returning. The value produced by `op` is stored and can be retrieved with
    /// [`CudaGraph::take_output`].
    ///
    /// # Errors
    ///
    /// - the error returned by `op`, if it fails (it takes precedence over an end-capture
    ///   error);
    /// - `DeviceError::Driver` if ending capture, instantiation, or upload fails;
    /// - `DeviceError::Internal` if ending capture succeeds but yields a null graph;
    /// - any error from binding the context, beginning capture, or synchronizing.
    ///
    /// Graph handles created before the failure are destroyed on every error path.
    ///
    /// # Panics
    ///
    /// Re-raises a panic from `op` after ending capture on `stream` and destroying any
    /// graph that ending capture returned, so the stream is not left capturing.
    pub fn capture(
        stream: Arc<CudaStream>,
        op: impl DeviceOp<Output = T>,
    ) -> Result<Self, DeviceError> {
        stream.context().bind_to_thread()?;
        unsafe {
            stream::begin_capture(stream.cu_stream(), CU_STREAM_CAPTURE_MODE_RELAXED)?;
        }
        let exec_ctx = ExecutionContext::new(stream.clone());
        // Catch a panic from `op` so capture can be ended before unwinding continues;
        // otherwise the stream would remain in capture mode.
        let op_result =
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| unsafe {
                op.execute(&exec_ctx)
            })) {
                Ok(result) => result,
                Err(panic_payload) => {
                    if let Ok(cu_graph) = unsafe { stream::end_capture(stream.cu_stream()) } {
                        if !cu_graph.is_null() {
                            unsafe {
                                let _ = sys::cuGraphDestroy(cu_graph).result();
                            }
                        }
                    }
                    std::panic::resume_unwind(panic_payload);
                }
            };
        // Always end capture, even when `op` failed, so the stream leaves capture mode.
        let end_result = unsafe { stream::end_capture(stream.cu_stream()) };
        let (output, cu_graph) = match (op_result, end_result) {
            (Err(op_err), Ok(cu_graph)) => {
                if !cu_graph.is_null() {
                    unsafe {
                        let _ = sys::cuGraphDestroy(cu_graph).result();
                    }
                }
                return Err(op_err);
            }
            // The op error is the root cause; the capture error is a consequence.
            (Err(op_err), Err(_)) => {
                return Err(op_err);
            }
            (Ok(_), Err(capture_err)) => {
                return Err(DeviceError::Driver(capture_err));
            }
            (Ok(output), Ok(cu_graph)) => {
                if cu_graph.is_null() {
                    return Err(DeviceError::Internal(
                        "cuStreamEndCapture returned null graph".into(),
                    ));
                }
                (output, cu_graph)
            }
        };
        let cu_graph_exec = unsafe {
            let mut cu_graph_exec = MaybeUninit::<sys::CUgraphExec>::uninit();
            match sys::cuGraphInstantiateWithFlags(cu_graph_exec.as_mut_ptr(), cu_graph, 0).result()
            {
                // Only read the out-pointer once the driver reported success.
                Ok(()) => cu_graph_exec.assume_init(),
                Err(e) => {
                    let _ = sys::cuGraphDestroy(cu_graph).result();
                    return Err(DeviceError::Driver(e));
                }
            }
        };
        // Hand both handles to `Self` now, so `Drop` releases them if upload or
        // synchronization below fails instead of leaking them.
        let graph = Self {
            stream,
            cu_graph,
            cu_graph_exec,
            output: Some(output),
        };
        if let Err(e) =
            unsafe { sys::cuGraphUpload(graph.cu_graph_exec, graph.stream.cu_stream()).result() }
        {
            return Err(DeviceError::Driver(e));
        }
        graph.stream.synchronize()?;
        Ok(graph)
    }

    /// Takes the value produced by the captured operation, leaving `None` behind.
    ///
    /// Returns `None` on every call after the first.
    pub fn take_output(&mut self) -> Option<T> {
        self.output.take()
    }

    /// Executes `op` directly on this graph's stream (not under capture) and returns its
    /// output.
    ///
    /// This does not modify `cu_graph` or `cu_graph_exec`. NOTE(review): presumably used to
    /// refresh buffers that the captured graph reads before the next `launch` — confirm.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the stream's context fails, or whatever error `op`
    /// produces.
    pub fn update<O: Send>(&self, op: impl DeviceOp<Output = O>) -> Result<O, DeviceError> {
        // Bind the stream's context like `capture` does before running an op, so the op
        // does not depend on whichever context this thread last made current.
        self.stream.context().bind_to_thread()?;
        let ctx = ExecutionContext::new(self.stream.clone());
        unsafe { op.execute(&ctx) }
    }

    /// Launches the instantiated graph on its stream and waits for it to finish.
    ///
    /// # Errors
    ///
    /// Returns an error if binding the context, launching, or synchronizing fails.
    pub fn launch(&self) -> Result<(), DeviceError> {
        // Make the stream's context current first, as `capture` and `Drop` do for their
        // driver calls.
        self.stream.context().bind_to_thread()?;
        unsafe {
            sys::cuGraphLaunch(self.cu_graph_exec, self.stream.cu_stream()).result()?;
        }
        self.stream.synchronize()?;
        Ok(())
    }

    /// Returns the stream this graph was captured on and launches on.
    pub fn stream(&self) -> &Arc<CudaStream> {
        &self.stream
    }
}
impl<T> Drop for CudaGraph<T> {
    /// Destroys the executable graph, then the template graph, if either is non-null.
    ///
    /// The stream's context is bound first so the destroy calls run against it. Drop
    /// cannot fail, so every driver error (including a failed bind) is handed to the
    /// context's `record_err` instead. Each handle field is swapped to null before its
    /// destroy call.
    fn drop(&mut self) {
        let context = self.stream.context();
        let bind_status = context.bind_to_thread();
        context.record_err(bind_status);

        // Executable graph first.
        let exec_handle = std::mem::replace(&mut self.cu_graph_exec, std::ptr::null_mut());
        if !exec_handle.is_null() {
            let status = unsafe { sys::cuGraphExecDestroy(exec_handle).result() };
            context.record_err(status);
        }

        // Then the template graph it was instantiated from.
        let graph_handle = std::mem::replace(&mut self.cu_graph, std::ptr::null_mut());
        if !graph_handle.is_null() {
            let status = unsafe { sys::cuGraphDestroy(graph_handle).result() };
            context.record_err(status);
        }
    }
}
/// Recording handle passed to the closure given to `CudaGraph::scope`; operations are
/// issued through [`Scope::record`] while the stream is being captured.
///
/// The raw-pointer `PhantomData` marker opts `Scope` out of the `Send` and `Sync` auto
/// traits, so the handle cannot be moved to or shared with another thread.
pub struct Scope {
    /// Execution context for the stream under capture.
    ctx: ExecutionContext,
    /// Marker field only; carries no data.
    _not_send: std::marker::PhantomData<*const ()>,
}
impl Scope {
pub fn record<T: Send>(
&self,
op: impl GraphNode + DeviceOp<Output = T>,
) -> Result<T, DeviceError> {
unsafe { op.execute(&self.ctx) }
}
}
impl CudaGraph<()> {
    /// Builds a graph from the operations `f` records through a [`Scope`] on `stream`.
    ///
    /// The execution lock from `device_operation` is acquired before capture and released
    /// afterwards, including when `f` panics. In that case the panic is caught, the lock is
    /// released, and the panic is resumed.
    ///
    /// # Errors
    ///
    /// Returns an error if the lock cannot be acquired. Otherwise returns any error
    /// produced while capturing, instantiating, uploading or synchronizing, or the error
    /// returned by `f`.
    pub fn scope<F>(stream: &Arc<CudaStream>, f: F) -> Result<Self, DeviceError>
    where
        F: FnOnce(&Scope) -> Result<(), DeviceError>,
    {
        crate::device_operation::acquire_execution_lock()?;
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            Self::scope_inner(stream, f)
        }));
        // Release on both the normal and the panic path before propagating.
        crate::device_operation::release_execution_lock();
        match result {
            Ok(inner) => inner,
            Err(payload) => std::panic::resume_unwind(payload),
        }
    }

    /// Captures `f` on `stream`, then instantiates, uploads and synchronizes the graph.
    ///
    /// Every error path destroys the graph handles created so far. If `f` panics, capture
    /// is ended and any graph it returned is destroyed before the panic is resumed.
    fn scope_inner<F>(stream: &Arc<CudaStream>, f: F) -> Result<Self, DeviceError>
    where
        F: FnOnce(&Scope) -> Result<(), DeviceError>,
    {
        stream.context().bind_to_thread()?;
        unsafe {
            stream::begin_capture(stream.cu_stream(), CU_STREAM_CAPTURE_MODE_RELAXED)?;
        }
        let scope = Scope {
            ctx: ExecutionContext::new(stream.clone()),
            _not_send: std::marker::PhantomData,
        };
        let scope_result =
            match std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| f(&scope))) {
                Ok(result) => result,
                Err(panic_payload) => {
                    // Leave capture mode and free any partial graph instead of leaking it.
                    if let Ok(cu_graph) = unsafe { stream::end_capture(stream.cu_stream()) } {
                        if !cu_graph.is_null() {
                            unsafe {
                                let _ = sys::cuGraphDestroy(cu_graph).result();
                            }
                        }
                    }
                    std::panic::resume_unwind(panic_payload);
                }
            };
        // Always end capture, even when `f` failed, so the stream leaves capture mode.
        let end_result = unsafe { stream::end_capture(stream.cu_stream()) };
        let cu_graph = match (scope_result, end_result) {
            (Err(scope_err), Ok(cu_graph)) => {
                if !cu_graph.is_null() {
                    unsafe {
                        let _ = sys::cuGraphDestroy(cu_graph).result();
                    }
                }
                return Err(scope_err);
            }
            // The closure's error is the root cause; the capture error is a consequence.
            (Err(scope_err), Err(_)) => {
                return Err(scope_err);
            }
            (Ok(_), Err(capture_err)) => {
                return Err(DeviceError::Driver(capture_err));
            }
            (Ok(()), Ok(cu_graph)) => {
                if cu_graph.is_null() {
                    return Err(DeviceError::Internal(
                        "cuStreamEndCapture returned null graph".into(),
                    ));
                }
                cu_graph
            }
        };
        let cu_graph_exec = unsafe {
            let mut cu_graph_exec = MaybeUninit::<sys::CUgraphExec>::uninit();
            match sys::cuGraphInstantiateWithFlags(cu_graph_exec.as_mut_ptr(), cu_graph, 0).result()
            {
                // Only read the out-pointer once the driver reported success.
                Ok(()) => cu_graph_exec.assume_init(),
                Err(e) => {
                    let _ = sys::cuGraphDestroy(cu_graph).result();
                    return Err(DeviceError::Driver(e));
                }
            }
        };
        // Hand both handles to the graph now, so `Drop` releases them if upload or
        // synchronization below fails instead of leaking them.
        let graph = CudaGraph {
            stream: Arc::clone(stream),
            cu_graph,
            cu_graph_exec,
            output: Some(()),
        };
        if let Err(e) =
            unsafe { sys::cuGraphUpload(graph.cu_graph_exec, graph.stream.cu_stream()).result() }
        {
            return Err(DeviceError::Driver(e));
        }
        graph.stream.synchronize()?;
        Ok(graph)
    }
}
/// A computation with a fallible `forward` step that may mutate its own state.
///
/// Both the input and output types are required to be `Send`.
pub trait Module {
    /// Value consumed by [`Module::forward`].
    type Input: Send;
    /// Value produced by [`Module::forward`].
    type Output: Send;
    /// Runs the module on `input` and returns its output.
    ///
    /// # Errors
    ///
    /// Returns a [`DeviceError`] if the computation fails.
    fn forward(&mut self, input: Self::Input) -> Result<Self::Output, DeviceError>;
}