use crate::error::{CudaError, CudaResult};
use crate::stream::Stream;
/// Direction of a memory copy recorded in a graph node.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Host-to-device copy; displayed as `HtoD`.
    HostToDevice,
    /// Device-to-host copy; displayed as `DtoH`.
    DeviceToHost,
    /// Device-to-device copy; displayed as `DtoD`.
    DeviceToDevice,
}

impl std::fmt::Display for MemcpyDirection {
    /// Writes the short tag for this direction (`HtoD`, `DtoH` or `DtoD`).
    ///
    /// Width/alignment flags on the formatter are not applied; the tag is
    /// written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let tag = match *self {
            Self::HostToDevice => "HtoD",
            Self::DeviceToHost => "DtoH",
            Self::DeviceToDevice => "DtoD",
        };
        f.write_str(tag)
    }
}
/// One operation stored in a graph.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GraphNode {
    /// A kernel launch described by name and launch configuration.
    KernelLaunch {
        /// Name of the kernel function, stored as an owned string.
        function_name: String,
        /// Grid dimensions (three components; presumably x, y, z — confirm).
        grid: (u32, u32, u32),
        /// Block dimensions (three components; presumably x, y, z — confirm).
        block: (u32, u32, u32),
        /// Shared-memory amount for the launch (presumably bytes — confirm).
        shared_mem: u32,
    },
    /// A memory copy of `size` bytes in the given direction.
    Memcpy {
        /// Which way the data moves.
        direction: MemcpyDirection,
        /// Number of bytes copied.
        size: usize,
    },
    /// A fill of `size` bytes with a single byte value.
    Memset {
        /// Number of bytes written.
        size: usize,
        /// Byte value written.
        value: u8,
    },
    /// A node that carries no operation of its own.
    Empty,
}
impl std::fmt::Display for GraphNode {
    /// Renders a one-line, human-readable summary of the node.
    ///
    /// Kernel nodes show the name, grid, block and shared-memory value; copies
    /// show direction and byte count; memsets show byte count and the fill
    /// value as two lowercase hex digits.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Empty => f.write_str("Empty"),
            Self::Memset { size, value } => {
                write!(f, "Memset({size} bytes, value=0x{value:02x})")
            }
            Self::Memcpy { direction, size } => {
                write!(f, "Memcpy({direction}, {size} bytes)")
            }
            Self::KernelLaunch {
                function_name,
                grid: (grid_a, grid_b, grid_c),
                block: (block_a, block_b, block_c),
                shared_mem,
            } => write!(
                f,
                "Kernel({}, grid=({},{},{}), block=({},{},{}), smem={})",
                function_name, grid_a, grid_b, grid_c, block_a, block_b, block_c, shared_mem,
            ),
        }
    }
}
/// A collection of [`GraphNode`]s plus directed dependency edges between them.
///
/// Nodes are addressed by their insertion index. Each dependency is a
/// `(from, to)` pair meaning `from` is ordered before `to`.
#[derive(Debug, Clone)]
pub struct Graph {
    /// Nodes in insertion order; a node's index is its position here.
    nodes: Vec<GraphNode>,
    /// Dependency edges as `(from, to)` node-index pairs.
    dependencies: Vec<(usize, usize)>,
}
impl Default for Graph {
fn default() -> Self {
Self::new()
}
}
impl Graph {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
dependencies: Vec::new(),
}
}
pub fn add_kernel_node(
&mut self,
function_name: &str,
grid: (u32, u32, u32),
block: (u32, u32, u32),
shared_mem: u32,
) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::KernelLaunch {
function_name: function_name.to_owned(),
grid,
block,
shared_mem,
});
idx
}
pub fn add_memcpy_node(&mut self, direction: MemcpyDirection, size: usize) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::Memcpy { direction, size });
idx
}
pub fn add_memset_node(&mut self, size: usize, value: u8) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::Memset { size, value });
idx
}
pub fn add_empty_node(&mut self) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::Empty);
idx
}
pub fn add_dependency(&mut self, from: usize, to: usize) -> CudaResult<()> {
if from >= self.nodes.len() || to >= self.nodes.len() {
return Err(CudaError::InvalidValue);
}
if from == to {
return Err(CudaError::InvalidValue);
}
self.dependencies.push((from, to));
Ok(())
}
#[inline]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
#[inline]
pub fn dependency_count(&self) -> usize {
self.dependencies.len()
}
#[inline]
pub fn nodes(&self) -> &[GraphNode] {
&self.nodes
}
#[inline]
pub fn dependencies(&self) -> &[(usize, usize)] {
&self.dependencies
}
pub fn get_node(&self, index: usize) -> Option<&GraphNode> {
self.nodes.get(index)
}
pub fn topological_sort(&self) -> CudaResult<Vec<usize>> {
let n = self.nodes.len();
let mut in_degree = vec![0u32; n];
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
for &(from, to) in &self.dependencies {
adj[from].push(to);
in_degree[to] = in_degree[to].saturating_add(1);
}
let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
let mut result = Vec::with_capacity(n);
while let Some(node) = queue.pop() {
result.push(node);
for &next in &adj[node] {
in_degree[next] = in_degree[next].saturating_sub(1);
if in_degree[next] == 0 {
queue.push(next);
}
}
}
if result.len() != n {
return Err(CudaError::InvalidValue);
}
Ok(result)
}
pub fn instantiate(&self) -> CudaResult<GraphExec> {
let execution_order = self.topological_sort()?;
let (raw_graph, raw_exec) = match self.build_driver_graph() {
Ok(handles) => handles,
Err(
CudaError::NotInitialized
| CudaError::NotSupported
| CudaError::InvalidContext
| CudaError::NoDevice
| CudaError::InvalidDevice
| CudaError::Deinitialized,
) => (None, None),
Err(other) => return Err(other),
};
Ok(GraphExec {
graph: self.clone(),
execution_order,
raw_graph,
raw_exec,
})
}
fn build_driver_graph(
&self,
) -> CudaResult<(Option<crate::ffi::CUgraph>, Option<crate::ffi::CUgraphExec>)> {
use crate::ffi::{CUgraph, CUgraphExec, CUgraphNode};
let api = crate::loader::try_driver()?;
let create = api.cu_graph_create.ok_or(CudaError::NotSupported)?;
let add_empty = api.cu_graph_add_empty_node.ok_or(CudaError::NotSupported)?;
let destroy = api.cu_graph_destroy.ok_or(CudaError::NotSupported)?;
let order = self.topological_sort()?;
let mut raw_graph = CUgraph::default();
crate::error::check(unsafe { create(&mut raw_graph, 0) })?;
let build = || -> CudaResult<CUgraphExec> {
let mut driver_nodes: Vec<Option<CUgraphNode>> = vec![None; self.nodes.len()];
for &node_idx in &order {
let mut deps: Vec<CUgraphNode> = Vec::new();
for &(from, to) in &self.dependencies {
if to == node_idx {
let handle = driver_nodes
.get(from)
.copied()
.flatten()
.ok_or(CudaError::InvalidValue)?;
deps.push(handle);
}
}
let dep_ptr = if deps.is_empty() {
std::ptr::null()
} else {
deps.as_ptr()
};
let mut driver_node = CUgraphNode::default();
crate::error::check(unsafe {
add_empty(&mut driver_node, raw_graph, dep_ptr, deps.len())
})?;
driver_nodes[node_idx] = Some(driver_node);
}
self.instantiate_driver_graph(api, raw_graph)
};
match build() {
Ok(raw_exec) => Ok((Some(raw_graph), Some(raw_exec))),
Err(e) => {
let rc = unsafe { destroy(raw_graph) };
if rc != 0 {
tracing::warn!(
cuda_error = rc,
"cuGraphDestroy failed while unwinding a failed instantiation"
);
}
Err(e)
}
}
}
fn instantiate_driver_graph(
&self,
api: &crate::loader::DriverApi,
raw_graph: crate::ffi::CUgraph,
) -> CudaResult<crate::ffi::CUgraphExec> {
use crate::ffi::CUgraphExec;
let mut raw_exec = CUgraphExec::default();
if let Some(instantiate_flags) = api.cu_graph_instantiate_with_flags {
crate::error::check(unsafe { instantiate_flags(&mut raw_exec, raw_graph, 0) })?;
return Ok(raw_exec);
}
let instantiate = api.cu_graph_instantiate.ok_or(CudaError::NotSupported)?;
crate::error::check(unsafe {
instantiate(
&mut raw_exec,
raw_graph,
std::ptr::null_mut(),
std::ptr::null_mut(),
0,
)
})?;
Ok(raw_exec)
}
}
impl std::fmt::Display for Graph {
    /// Formats the graph as `Graph(<nodes> nodes, <edges> deps)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let node_total = self.nodes.len();
        let dep_total = self.dependencies.len();
        write!(f, "Graph({} nodes, {} deps)", node_total, dep_total)
    }
}
/// An instantiated [`Graph`]: the graph itself, its execution order and, when
/// a driver graph could be built, the driver handles backing it.
pub struct GraphExec {
    /// Copy of the graph this executable was instantiated from.
    graph: Graph,
    /// Node indices in the topological order computed at instantiation.
    execution_order: Vec<usize>,
    /// Driver graph handle; `None` when no driver graph was built.
    raw_graph: Option<crate::ffi::CUgraph>,
    /// Driver executable-graph handle; `None` when no driver graph was built.
    raw_exec: Option<crate::ffi::CUgraphExec>,
}
impl GraphExec {
pub fn launch(&self, stream: &Stream) -> CudaResult<()> {
let api = crate::loader::try_driver()?;
let raw_exec = self.raw_exec.ok_or(CudaError::NotSupported)?;
let launch = api.cu_graph_launch.ok_or(CudaError::NotSupported)?;
crate::error::check(unsafe { launch(raw_exec, stream.raw()) })
}
#[inline]
pub fn graph(&self) -> &Graph {
&self.graph
}
#[inline]
pub fn execution_order(&self) -> &[usize] {
&self.execution_order
}
#[inline]
pub fn node_count(&self) -> usize {
self.graph.node_count()
}
#[inline]
pub fn is_driver_backed(&self) -> bool {
self.raw_exec.is_some()
}
}
impl std::fmt::Debug for GraphExec {
    /// Prints the graph, the execution order and a `driver_backed` flag; the
    /// raw driver handles themselves are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let driver_backed = self.is_driver_backed();
        let mut out = f.debug_struct("GraphExec");
        out.field("graph", &self.graph);
        out.field("execution_order", &self.execution_order);
        out.field("driver_backed", &driver_backed);
        out.finish()
    }
}
impl Drop for GraphExec {
    /// Releases the driver handles, executable first and graph second.
    ///
    /// Nothing is done when the driver cannot be obtained. Failures from the
    /// destroy calls are logged as warnings and otherwise ignored, since a
    /// destructor has no way to report them.
    fn drop(&mut self) {
        let api = match crate::loader::try_driver() {
            Ok(api) => api,
            Err(_) => return,
        };

        if let Some(exec) = self.raw_exec {
            if let Some(destroy) = api.cu_graph_exec_destroy {
                // SAFETY: `exec` is the handle stored at instantiation and
                // this destructor is the only place that destroys it.
                let rc = unsafe { destroy(exec) };
                if rc != 0 {
                    tracing::warn!(cuda_error = rc, "cuGraphExecDestroy failed during drop");
                }
            }
        }

        if let Some(graph) = self.raw_graph {
            if let Some(destroy) = api.cu_graph_destroy {
                // SAFETY: `graph` is the handle stored at instantiation and
                // this destructor is the only place that destroys it.
                let rc = unsafe { destroy(graph) };
                if rc != 0 {
                    tracing::warn!(cuda_error = rc, "cuGraphDestroy failed during drop");
                }
            }
        }
    }
}
/// Records a sequence of operations that is later turned into a [`Graph`].
pub struct StreamCapture {
    /// Operations recorded so far, in recording order.
    nodes: Vec<GraphNode>,
    /// `true` while the capture accepts new operations.
    active: bool,
}
impl StreamCapture {
pub fn begin(_stream: &Stream) -> CudaResult<Self> {
let _api = crate::loader::try_driver()?;
Ok(Self {
nodes: Vec::new(),
active: true,
})
}
pub fn record_kernel(
&mut self,
function_name: &str,
grid: (u32, u32, u32),
block: (u32, u32, u32),
shared_mem: u32,
) {
if self.active {
self.nodes.push(GraphNode::KernelLaunch {
function_name: function_name.to_owned(),
grid,
block,
shared_mem,
});
}
}
pub fn record_memcpy(&mut self, direction: MemcpyDirection, size: usize) {
if self.active {
self.nodes.push(GraphNode::Memcpy { direction, size });
}
}
pub fn record_memset(&mut self, size: usize, value: u8) {
if self.active {
self.nodes.push(GraphNode::Memset { size, value });
}
}
#[inline]
pub fn recorded_count(&self) -> usize {
self.nodes.len()
}
#[inline]
pub fn is_active(&self) -> bool {
self.active
}
pub fn end(mut self) -> CudaResult<Graph> {
if !self.active {
return Err(CudaError::StreamCaptureUnmatched);
}
self.active = false;
let mut graph = Graph::new();
let mut prev_idx: Option<usize> = None;
for node in self.nodes.drain(..) {
let idx = graph.nodes.len();
graph.nodes.push(node);
if let Some(prev) = prev_idx {
graph.dependencies.push((prev, idx));
}
prev_idx = Some(idx);
}
Ok(graph)
}
}
// Unit tests for graph construction, ordering, instantiation and display.
// Tests that touch the driver are written to pass both with and without a
// usable driver/device: they return early or relax their assertions when the
// driver-backed path is unavailable.
#[cfg(test)]
mod tests {
use super::*;
// A freshly created graph has no nodes and no dependencies.
#[test]
fn graph_new_is_empty() {
let g = Graph::new();
assert_eq!(g.node_count(), 0);
assert_eq!(g.dependency_count(), 0);
assert!(g.nodes().is_empty());
assert!(g.dependencies().is_empty());
}
// `Default` produces an empty graph as well.
#[test]
fn graph_default_is_empty() {
let g = Graph::default();
assert_eq!(g.node_count(), 0);
}
// Node indices are handed out in insertion order starting at zero.
#[test]
fn add_kernel_node_returns_sequential_indices() {
let mut g = Graph::new();
let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let n1 = g.add_kernel_node("k1", (2, 1, 1), (64, 1, 1), 128);
assert_eq!(n0, 0);
assert_eq!(n1, 1);
assert_eq!(g.node_count(), 2);
}
// A memcpy node stores the direction and size it was created with.
#[test]
fn add_memcpy_node_records_direction_and_size() {
let mut g = Graph::new();
let idx = g.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
assert_eq!(idx, 0);
let node = g.get_node(0);
assert!(node.is_some());
if let Some(GraphNode::Memcpy { direction, size }) = node {
assert_eq!(*direction, MemcpyDirection::HostToDevice);
assert_eq!(*size, 4096);
} else {
panic!("expected Memcpy node");
}
}
// A memset node stores the size and fill value it was created with.
#[test]
fn add_memset_node_records_size_and_value() {
let mut g = Graph::new();
let idx = g.add_memset_node(8192, 0xAB);
assert_eq!(idx, 0);
if let Some(GraphNode::Memset { size, value }) = g.get_node(idx) {
assert_eq!(*size, 8192);
assert_eq!(*value, 0xAB);
} else {
panic!("expected Memset node");
}
}
// An empty node is stored as `GraphNode::Empty`.
#[test]
fn add_empty_node_works() {
let mut g = Graph::new();
let idx = g.add_empty_node();
assert_eq!(idx, 0);
assert_eq!(g.get_node(idx), Some(&GraphNode::Empty));
}
// A dependency between two existing, distinct nodes is recorded as given.
#[test]
fn add_dependency_valid() {
let mut g = Graph::new();
let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
assert!(g.add_dependency(n0, n1).is_ok());
assert_eq!(g.dependency_count(), 1);
assert_eq!(g.dependencies()[0], (0, 1));
}
// An index past the end of the node list is rejected.
#[test]
fn add_dependency_out_of_bounds() {
let mut g = Graph::new();
let _n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let result = g.add_dependency(0, 5);
assert_eq!(result, Err(CudaError::InvalidValue));
}
// A node may not depend on itself.
#[test]
fn add_dependency_self_loop() {
let mut g = Graph::new();
let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let result = g.add_dependency(n0, n0);
assert_eq!(result, Err(CudaError::InvalidValue));
}
// In a chain n0 -> n1 -> n2 the sort keeps each node after its predecessor.
#[test]
fn topological_sort_linear_chain() {
let mut g = Graph::new();
let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
g.add_dependency(n0, n1).ok();
g.add_dependency(n1, n2).ok();
let order = g.topological_sort();
assert!(order.is_ok());
let order = order.ok();
assert!(order.is_some());
let order = order.unwrap_or_default();
// Position of a node in the order; a missing node maps to `usize::MAX`.
let pos = |n: usize| -> usize { order.iter().position(|&x| x == n).unwrap_or(usize::MAX) };
assert!(pos(n0) < pos(n1));
assert!(pos(n1) < pos(n2));
}
// A two-node cycle makes the sort fail with `InvalidValue`.
#[test]
fn topological_sort_detects_cycle() {
let mut g = Graph::new();
let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
g.add_dependency(n0, n1).ok();
g.add_dependency(n1, n0).ok();
let result = g.topological_sort();
assert_eq!(result, Err(CudaError::InvalidValue));
}
// Without dependencies every node still appears in the result.
#[test]
fn topological_sort_no_deps() {
let mut g = Graph::new();
g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
let order = g.topological_sort();
assert!(order.is_ok());
let order = order.unwrap_or_default();
assert_eq!(order.len(), 3);
}
// A valid chain instantiates and reports all of its nodes.
#[test]
fn instantiate_valid_graph() {
let mut g = Graph::new();
let n0 = g.add_memcpy_node(MemcpyDirection::HostToDevice, 1024);
let n1 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 1024);
g.add_dependency(n0, n1).ok();
g.add_dependency(n1, n2).ok();
let exec = g.instantiate();
assert!(exec.is_ok());
let exec = exec.ok();
assert!(exec.is_some());
if let Some(exec) = exec {
assert_eq!(exec.node_count(), 3);
assert_eq!(exec.execution_order().len(), 3);
}
}
// A cyclic graph cannot be instantiated.
#[test]
fn instantiate_cyclic_graph_fails() {
let mut g = Graph::new();
let n0 = g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
g.add_dependency(n0, n1).ok();
g.add_dependency(n1, n0).ok();
let result = g.instantiate();
assert!(result.is_err());
}
// The graph's display text mentions its node and dependency counts.
#[test]
fn graph_display() {
let mut g = Graph::new();
g.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
g.add_memcpy_node(MemcpyDirection::HostToDevice, 512);
let disp = format!("{g}");
assert!(disp.contains("2 nodes"));
assert!(disp.contains("0 deps"));
}
// Each node kind includes its identifying detail in its display text.
#[test]
fn node_display() {
let node = GraphNode::KernelLaunch {
function_name: "foo".to_owned(),
grid: (4, 1, 1),
block: (256, 1, 1),
shared_mem: 0,
};
let disp = format!("{node}");
assert!(disp.contains("foo"));
let node = GraphNode::Memcpy {
direction: MemcpyDirection::DeviceToHost,
size: 1024,
};
let disp = format!("{node}");
assert!(disp.contains("DtoH"));
let node = GraphNode::Memset {
size: 256,
value: 0xFF,
};
let disp = format!("{node}");
assert!(disp.contains("0xff"));
let node = GraphNode::Empty;
let disp = format!("{node}");
assert!(disp.contains("Empty"));
}
// Each copy direction has a fixed short tag.
#[test]
fn memcpy_direction_display() {
assert_eq!(format!("{}", MemcpyDirection::HostToDevice), "HtoD");
assert_eq!(format!("{}", MemcpyDirection::DeviceToHost), "DtoH");
assert_eq!(format!("{}", MemcpyDirection::DeviceToDevice), "DtoD");
}
// Looking up a node in an empty graph yields `None`.
#[test]
fn graph_get_node_out_of_bounds() {
let g = Graph::new();
assert!(g.get_node(0).is_none());
assert!(g.get_node(100).is_none());
}
// Diamond n0 -> {n1, n2} -> n3: ordering respects all four edges and the
// graph instantiates.
#[test]
fn graph_diamond_dag() {
let mut g = Graph::new();
let n0 = g.add_empty_node();
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
let n3 = g.add_empty_node();
g.add_dependency(n0, n1).ok();
g.add_dependency(n0, n2).ok();
g.add_dependency(n1, n3).ok();
g.add_dependency(n2, n3).ok();
let order = g.topological_sort().unwrap_or_default();
assert_eq!(order.len(), 4);
// Position of a node in the order; a missing node maps to `usize::MAX`.
let pos = |n: usize| -> usize { order.iter().position(|&x| x == n).unwrap_or(usize::MAX) };
assert!(pos(n0) < pos(n1));
assert!(pos(n0) < pos(n2));
assert!(pos(n1) < pos(n3));
assert!(pos(n2) < pos(n3));
let exec = g.instantiate();
assert!(exec.is_ok());
}
// The debug output names the type and exposes the `driver_backed` flag.
#[test]
fn graph_exec_debug() {
let mut g = Graph::new();
g.add_empty_node();
let exec = g.instantiate().ok();
assert!(exec.is_some());
if let Some(exec) = exec {
let dbg = format!("{exec:?}");
assert!(dbg.contains("GraphExec"));
assert!(dbg.contains("driver_backed"));
}
}
// Helper: whether the loader can hand out a driver in this environment.
fn driver_present() -> bool {
crate::loader::try_driver().is_ok()
}
// An empty graph instantiates; without a driver it must not be driver-backed.
// With a driver either outcome is accepted.
#[test]
fn instantiate_empty_graph_driver_state() {
let g = Graph::new();
let exec = g.instantiate().expect("empty graph instantiates");
assert_eq!(exec.node_count(), 0);
if driver_present() {
let _ = exec.is_driver_backed();
} else {
assert!(!exec.is_driver_backed());
}
}
// A three-node chain instantiates with all nodes present in the order.
#[test]
fn instantiate_chain_preserves_topology() {
let mut g = Graph::new();
let n0 = g.add_memset_node(256, 0);
let n1 = g.add_kernel_node("k", (1, 1, 1), (32, 1, 1), 0);
let n2 = g.add_memcpy_node(MemcpyDirection::DeviceToHost, 256);
g.add_dependency(n0, n1).ok();
g.add_dependency(n1, n2).ok();
let exec = g.instantiate().expect("chain instantiates");
assert_eq!(exec.node_count(), 3);
assert_eq!(exec.execution_order().len(), 3);
if !driver_present() {
assert!(!exec.is_driver_backed());
}
}
// A diamond instantiates; without a driver the result has no driver backing.
#[test]
fn instantiate_diamond_without_driver_is_clean() {
let mut g = Graph::new();
let n0 = g.add_empty_node();
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
let n3 = g.add_empty_node();
g.add_dependency(n0, n1).ok();
g.add_dependency(n0, n2).ok();
g.add_dependency(n1, n3).ok();
g.add_dependency(n2, n3).ok();
let exec = g.instantiate();
assert!(exec.is_ok(), "diamond DAG must instantiate cleanly");
if !driver_present() {
if let Ok(exec) = exec {
assert!(!exec.is_driver_backed());
}
}
}
// Calling the driver build directly: without a driver it reports
// `NotInitialized`; with one, both handles are present or both absent, and
// any driver error is tolerated.
#[test]
fn build_driver_graph_absent_driver_is_clean() {
let mut g = Graph::new();
g.add_empty_node();
let result = g.build_driver_graph();
if driver_present() {
match result {
Ok((raw_graph, raw_exec)) => {
assert_eq!(raw_graph.is_some(), raw_exec.is_some());
}
Err(_) => { }
}
} else {
assert_eq!(result.err(), Some(CudaError::NotInitialized));
}
}
// Dropping an executable must not panic, whatever the driver state.
#[test]
fn graph_exec_drop_without_driver_is_safe() {
let mut g = Graph::new();
g.add_empty_node();
g.add_empty_node();
let exec = g.instantiate().expect("instantiates");
drop(exec);
}
// A cycle is reported as `InvalidValue` regardless of driver availability.
#[test]
fn instantiate_cycle_fails_before_driver() {
let mut g = Graph::new();
let n0 = g.add_empty_node();
let n1 = g.add_empty_node();
g.add_dependency(n0, n1).ok();
g.add_dependency(n1, n0).ok();
assert_eq!(g.instantiate().err(), Some(CudaError::InvalidValue));
}
// Hardware test: instantiate a diamond and launch it once. Returns early when
// no device, context or stream can be created.
#[test]
fn real_graph_instantiate_and_launch() {
use crate::context::Context;
use crate::device::Device;
let device = match Device::get(0) {
Ok(d) => d,
Err(_) => return,
};
let ctx = match Context::new(&device) {
Ok(c) => std::sync::Arc::new(c),
Err(_) => return,
};
let stream = match Stream::new(&ctx) {
Ok(s) => s,
Err(_) => return,
};
let mut g = Graph::new();
let n0 = g.add_empty_node();
let n1 = g.add_kernel_node("k1", (1, 1, 1), (32, 1, 1), 0);
let n2 = g.add_kernel_node("k2", (1, 1, 1), (32, 1, 1), 0);
let n3 = g.add_empty_node();
g.add_dependency(n0, n1).ok();
g.add_dependency(n0, n2).ok();
g.add_dependency(n1, n3).ok();
g.add_dependency(n2, n3).ok();
let exec = g.instantiate().expect("diamond DAG instantiates");
assert_eq!(exec.node_count(), 4);
// Launch only when a driver graph actually backs the executable.
if exec.is_driver_backed() {
exec.launch(&stream)
.expect("cuGraphLaunch on a real graph succeeds");
stream
.synchronize()
.expect("stream synchronises after graph launch");
}
}
// Hardware test: the same executable can be launched several times in a row.
// Returns early when no device, context or stream can be created.
#[test]
fn real_graph_repeated_launch() {
use crate::context::Context;
use crate::device::Device;
let device = match Device::get(0) {
Ok(d) => d,
Err(_) => return,
};
let ctx = match Context::new(&device) {
Ok(c) => std::sync::Arc::new(c),
Err(_) => return,
};
let stream = match Stream::new(&ctx) {
Ok(s) => s,
Err(_) => return,
};
let mut g = Graph::new();
let a = g.add_empty_node();
let b = g.add_empty_node();
g.add_dependency(a, b).ok();
let exec = g.instantiate().expect("chain instantiates");
if exec.is_driver_backed() {
for _ in 0..8 {
exec.launch(&stream)
.expect("repeated cuGraphLaunch succeeds");
}
stream.synchronize().expect("stream synchronises");
}
}
}