use super::stream::CudaStream;
use crate::cuda::cudaSuccess;
use crate::error::{BackendError, BackendResult};
use cuda_sys::cudart::*;
use std::collections::HashMap;
use std::ffi::c_void;
use std::ptr;
use std::sync::{Arc, Mutex};
pub struct CudaGraph {
graph: cudaGraph_t,
instance: Option<cudaGraphExec_t>,
captured_kernels: usize,
}
impl std::fmt::Debug for CudaGraph {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CudaGraph")
.field("graph", &(self.graph as usize))
.field("instance", &self.instance.map(|i| i as usize))
.field("captured_kernels", &self.captured_kernels)
.finish()
}
}
impl CudaGraph {
pub fn new() -> BackendResult<Self> {
let mut graph: cudaGraph_t = ptr::null_mut();
unsafe {
let result = cudaGraphCreate(&mut graph as *mut _, 0);
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to create CUDA graph: {:?}",
result
)));
}
}
Ok(Self {
graph,
instance: None,
captured_kernels: 0,
})
}
pub fn begin_capture(stream: &CudaStream) -> BackendResult<()> {
unsafe {
let result = cudaStreamBeginCapture(
stream.stream(),
cudaStreamCaptureMode::cudaStreamCaptureModeGlobal,
);
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to begin CUDA graph capture: {:?}",
result
)));
}
}
Ok(())
}
pub fn end_capture(stream: &CudaStream) -> BackendResult<Self> {
let mut graph: cudaGraph_t = ptr::null_mut();
unsafe {
let result = cudaStreamEndCapture(stream.stream(), &mut graph as *mut _);
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to end CUDA graph capture: {:?}",
result
)));
}
}
let mut num_nodes: usize = 0;
unsafe {
let result = cudaGraphGetNodes(graph, ptr::null_mut(), &mut num_nodes as *mut _);
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to get graph node count: {:?}",
result
)));
}
}
Ok(Self {
graph,
instance: None,
captured_kernels: num_nodes,
})
}
pub fn instantiate(&mut self) -> BackendResult<()> {
if self.instance.is_some() {
return Ok(()); }
let mut instance: cudaGraphExec_t = ptr::null_mut();
unsafe {
let result = cudaGraphInstantiate(
&mut instance as *mut _,
self.graph,
ptr::null_mut(),
ptr::null_mut(),
0,
);
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to instantiate CUDA graph: {:?}",
result
)));
}
}
self.instance = Some(instance);
Ok(())
}
pub fn launch(&self, stream: &CudaStream) -> BackendResult<()> {
let instance = self
.instance
.ok_or_else(|| BackendError::CudaError("Graph not instantiated".to_string()))?;
unsafe {
let result = cudaGraphLaunch(instance, stream.stream());
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to launch CUDA graph: {:?}",
result
)));
}
}
Ok(())
}
pub fn update(&mut self, updates: &GraphUpdate) -> BackendResult<()> {
let _ = updates;
Ok(())
}
pub fn kernel_count(&self) -> usize {
self.captured_kernels
}
pub fn clone_graph(&self) -> BackendResult<Self> {
let mut cloned_graph: cudaGraph_t = ptr::null_mut();
unsafe {
let result = cudaGraphClone(&mut cloned_graph as *mut _, self.graph);
if result != cudaSuccess {
return Err(BackendError::CudaError(format!(
"Failed to clone CUDA graph: {:?}",
result
)));
}
}
Ok(Self {
graph: cloned_graph,
instance: None,
captured_kernels: self.captured_kernels,
})
}
}
impl Drop for CudaGraph {
fn drop(&mut self) {
unsafe {
if let Some(instance) = self.instance {
cudaGraphExecDestroy(instance);
}
cudaGraphDestroy(self.graph);
}
}
}
pub struct GraphUpdate {
pub node_id: usize,
pub new_params: Vec<f32>,
}
#[derive(Debug)]
pub struct GraphCaptureContext {
stream: Arc<CudaStream>,
capturing: bool,
}
impl GraphCaptureContext {
pub fn new(stream: Arc<CudaStream>) -> Self {
Self {
stream,
capturing: false,
}
}
pub fn start(&mut self) -> BackendResult<()> {
if self.capturing {
return Err(BackendError::CudaError(
"Already capturing a graph".to_string(),
));
}
CudaGraph::begin_capture(&self.stream)?;
self.capturing = true;
Ok(())
}
pub fn end(&mut self) -> BackendResult<CudaGraph> {
if !self.capturing {
return Err(BackendError::CudaError("Not capturing a graph".to_string()));
}
let graph = CudaGraph::end_capture(&self.stream)?;
self.capturing = false;
Ok(graph)
}
pub fn is_capturing(&self) -> bool {
self.capturing
}
}
#[derive(Debug)]
pub struct GraphCache {
cache: Mutex<HashMap<String, Arc<Mutex<CudaGraph>>>>,
}
impl GraphCache {
pub fn new() -> Self {
Self {
cache: Mutex::new(HashMap::new()),
}
}
pub fn get_or_create<F>(&self, key: &str, creator: F) -> BackendResult<Arc<Mutex<CudaGraph>>>
where
F: FnOnce() -> BackendResult<CudaGraph>,
{
let mut cache = self.cache.lock().expect("lock should not be poisoned");
if let Some(graph) = cache.get(key) {
return Ok(graph.clone());
}
let graph = creator()?;
let graph_arc = Arc::new(Mutex::new(graph));
cache.insert(key.to_string(), graph_arc.clone());
Ok(graph_arc)
}
pub fn remove(&self, key: &str) {
let mut cache = self.cache.lock().expect("lock should not be poisoned");
cache.remove(key);
}
pub fn clear(&self) {
let mut cache = self.cache.lock().expect("lock should not be poisoned");
cache.clear();
}
pub fn size(&self) -> usize {
let cache = self.cache.lock().expect("lock should not be poisoned");
cache.len()
}
}
#[macro_export]
macro_rules! cuda_graph_capture {
($stream:expr, $code:block) => {{
let mut capture_ctx = GraphCaptureContext::new($stream.clone());
capture_ctx.start()?;
$code
capture_ctx.end()
}};
}
#[derive(Debug, Clone)]
pub struct GraphOptimizationHints {
pub optimize_memory: bool,
pub enable_fusion: bool,
pub prefer_shared_memory: bool,
pub target_sm: Option<(u32, u32)>,
}
impl Default for GraphOptimizationHints {
fn default() -> Self {
Self {
optimize_memory: true,
enable_fusion: true,
prefer_shared_memory: true,
target_sm: None,
}
}
}
pub fn optimize_graph(graph: &mut CudaGraph, hints: &GraphOptimizationHints) -> BackendResult<()> {
let _ = hints;
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_graph_creation() {
if cuda_available() {
let graph = CudaGraph::new();
assert!(graph.is_ok());
}
}
#[test]
fn test_graph_cache() {
let cache = GraphCache::new();
assert_eq!(cache.size(), 0);
}
fn cuda_available() -> bool {
unsafe {
let mut device_count: i32 = 0;
let result = cudaGetDeviceCount(&mut device_count as *mut _);
result == cudaSuccess && device_count > 0
}
}
}