use std::sync::Arc;
use crate::driver::{result, sys};
use super::{CudaStream, DriverError};
/// A captured and instantiated CUDA graph produced by [`CudaStream::end_capture`].
///
/// Owns both the captured graph and its executable instantiation; both handles
/// are released when this value is dropped.
pub struct CudaGraph {
    /// Graph handle returned when stream capture ended.
    cu_graph: sys::CUgraph,
    /// Executable instantiated from `cu_graph`; this is the handle that
    /// `launch` and `upload` submit.
    cu_graph_exec: sys::CUgraphExec,
    /// Stream the graph was captured on and is launched/uploaded on. Holding
    /// the `Arc` also keeps `stream.ctx` reachable for error reporting in `Drop`.
    stream: Arc<CudaStream>,
}
impl Drop for CudaGraph {
    /// Releases the executable graph and then the captured graph.
    ///
    /// Each handle is swapped out for null before it is destroyed, so a handle
    /// is never released twice and null (never-initialized) handles are skipped.
    /// Destroy failures cannot be returned from `drop`, so they are handed to
    /// the context's `record_err` instead.
    fn drop(&mut self) {
        let exec_handle = std::mem::replace(&mut self.cu_graph_exec, std::ptr::null_mut());
        let graph_handle = std::mem::replace(&mut self.cu_graph, std::ptr::null_mut());
        let context = &self.stream.ctx;

        // Release the executable before the graph it was instantiated from.
        if !exec_handle.is_null() {
            // SAFETY: `exec_handle` was taken out of `self` and nulled there, so
            // this is the only release of this handle.
            context.record_err(unsafe { result::graph::exec_destroy(exec_handle) });
        }
        if !graph_handle.is_null() {
            // SAFETY: as above, ownership of `graph_handle` was moved out of `self`.
            context.record_err(unsafe { result::graph::destroy(graph_handle) });
        }
    }
}
impl CudaStream {
    /// Begins stream capture on this stream using the given capture `mode`.
    ///
    /// Binds this stream's context to the calling thread first. Per the CUDA
    /// driver docs for `cuStreamBeginCapture`, work enqueued on the stream after
    /// this call is recorded into a graph rather than executed, until
    /// [`CudaStream::end_capture`] is called.
    ///
    /// # Errors
    /// Returns a [`DriverError`] if binding the context or starting capture fails.
    pub fn begin_capture(&self, mode: sys::CUstreamCaptureMode) -> Result<(), DriverError> {
        self.ctx.bind_to_thread()?;
        // SAFETY: `cu_stream` is this stream's own handle and its context was
        // just bound to the current thread.
        unsafe { result::stream::begin_capture(self.cu_stream, mode) }
    }

    /// Ends stream capture and instantiates the captured graph with `flags`.
    ///
    /// Returns `Ok(None)` if the driver hands back a null graph. Otherwise it
    /// returns a [`CudaGraph`] that holds a clone of this stream's `Arc`, which
    /// is the stream it will later launch/upload on.
    ///
    /// # Errors
    /// Returns a [`DriverError`] if binding the context, ending capture, or
    /// instantiating the graph fails. If instantiation fails, the captured graph
    /// is destroyed instead of leaked.
    pub fn end_capture(
        self: &Arc<Self>,
        flags: sys::CUgraphInstantiate_flags,
    ) -> Result<Option<CudaGraph>, DriverError> {
        self.ctx.bind_to_thread()?;
        // SAFETY: `cu_stream` is this stream's own handle and its context is bound.
        let cu_graph = unsafe { result::stream::end_capture(self.cu_stream) }?;
        if cu_graph.is_null() {
            return Ok(None);
        }

        // Take ownership of the graph immediately. If instantiation fails below,
        // `?` drops `graph`, and `Drop` destroys `cu_graph` (it skips the
        // still-null executable). Previously the graph handle leaked on that path.
        let mut graph = CudaGraph {
            cu_graph,
            cu_graph_exec: std::ptr::null_mut(),
            stream: self.clone(),
        };
        // SAFETY: `cu_graph` is a non-null graph just returned by end_capture.
        graph.cu_graph_exec = unsafe { result::graph::instantiate(cu_graph, flags) }?;
        Ok(Some(graph))
    }

    /// Returns this stream's current capture status.
    ///
    /// Binds this stream's context to the calling thread first, then queries the
    /// driver via `result::stream::is_capturing`.
    ///
    /// # Errors
    /// Returns a [`DriverError`] if binding the context or the query fails.
    pub fn capture_status(&self) -> Result<sys::CUstreamCaptureStatus, DriverError> {
        self.ctx.bind_to_thread()?;
        // SAFETY: `cu_stream` is this stream's own handle and its context is bound.
        unsafe { result::stream::is_capturing(self.cu_stream) }
    }
}
impl CudaGraph {
    /// Launches the executable graph on the stream it was captured from.
    ///
    /// Binds that stream's context to the calling thread before launching.
    ///
    /// # Errors
    /// Returns a [`DriverError`] if binding the context or the launch fails.
    pub fn launch(&self) -> Result<(), DriverError> {
        let stream = &self.stream;
        stream.ctx.bind_to_thread()?;
        let exec = self.cu_graph_exec;
        // SAFETY: `exec` comes from a successful instantiate in `end_capture` and
        // is only nulled in `Drop`, so it is a live handle while `self` exists.
        unsafe { result::graph::launch(exec, stream.cu_stream) }
    }

    /// Uploads the executable graph on the stream it was captured from.
    ///
    /// Binds that stream's context to the calling thread first. Per the CUDA
    /// driver docs for `cuGraphUpload`, this prepares the executable on the
    /// device so a later launch can skip that work.
    ///
    /// # Errors
    /// Returns a [`DriverError`] if binding the context or the upload fails.
    pub fn upload(&self) -> Result<(), DriverError> {
        let stream = &self.stream;
        stream.ctx.bind_to_thread()?;
        let exec = self.cu_graph_exec;
        // SAFETY: see `launch`; the executable handle is live for the lifetime of `self`.
        unsafe { result::graph::upload(exec, stream.cu_stream) }
    }

    /// Returns the raw captured-graph handle.
    ///
    /// The handle is still owned by this `CudaGraph` and is destroyed when it
    /// is dropped, so it must not be destroyed by the caller.
    pub fn cu_graph(&self) -> sys::CUgraph {
        let handle = self.cu_graph;
        handle
    }

    /// Returns the raw executable-graph handle.
    ///
    /// The handle is still owned by this `CudaGraph` and is destroyed when it
    /// is dropped, so it must not be destroyed by the caller.
    pub fn cu_graph_exec(&self) -> sys::CUgraphExec {
        let handle = self.cu_graph_exec;
        handle
    }
}