use crate::error::{CudaError, CudaResult};
use crate::stream::Stream;
/// Direction of a memory copy described by a graph node.
///
/// The `Display` implementation renders each variant as a short tag
/// (`HtoD`, `DtoH`, `DtoD`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum MemcpyDirection {
    /// Copy from host to device.
    HostToDevice,
    /// Copy from device to host.
    DeviceToHost,
    /// Copy from device to device.
    DeviceToDevice,
}
impl std::fmt::Display for MemcpyDirection {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::HostToDevice => write!(f, "HtoD"),
Self::DeviceToHost => write!(f, "DtoH"),
Self::DeviceToDevice => write!(f, "DtoD"),
}
}
}
/// One operation stored in a [`Graph`].
///
/// Nodes are plain descriptions: they carry the parameters of an operation
/// but no device pointers or handles.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum GraphNode {
    /// A kernel launch description.
    KernelLaunch {
        /// Name of the function to launch.
        function_name: String,
        /// Grid dimensions as a 3-tuple (presumably x, y, z -- confirm with callers).
        grid: (u32, u32, u32),
        /// Block dimensions as a 3-tuple (presumably x, y, z -- confirm with callers).
        block: (u32, u32, u32),
        /// Shared memory amount (presumably in bytes -- confirm with callers).
        shared_mem: u32,
    },
    /// A memory copy description.
    Memcpy {
        /// Direction of the transfer.
        direction: MemcpyDirection,
        /// Size of the transfer; `Display` reports it in bytes.
        size: usize,
    },
    /// A memset description.
    Memset {
        /// Size of the region; `Display` reports it in bytes.
        size: usize,
        /// Byte value associated with the memset.
        value: u8,
    },
    /// A node that carries no operation parameters.
    Empty,
}
impl std::fmt::Display for GraphNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::KernelLaunch {
function_name,
grid,
block,
shared_mem,
} => write!(
f,
"Kernel({}, grid=({},{},{}), block=({},{},{}), smem={})",
function_name, grid.0, grid.1, grid.2, block.0, block.1, block.2, shared_mem,
),
Self::Memcpy { direction, size } => {
write!(f, "Memcpy({direction}, {size} bytes)")
}
Self::Memset { size, value } => {
write!(f, "Memset({size} bytes, value=0x{value:02x})")
}
Self::Empty => write!(f, "Empty"),
}
}
}
/// A collection of [`GraphNode`]s plus directed dependency edges between them.
///
/// Nodes are identified by their index in insertion order.
#[derive(Clone, Debug)]
pub struct Graph {
    /// Nodes in insertion order; a node's index is its identifier.
    nodes: Vec<GraphNode>,
    /// Edges as `(from, to)` index pairs: `from` is ordered before `to`.
    dependencies: Vec<(usize, usize)>,
}
impl Default for Graph {
fn default() -> Self {
Self::new()
}
}
impl Graph {
pub fn new() -> Self {
Self {
nodes: Vec::new(),
dependencies: Vec::new(),
}
}
pub fn add_kernel_node(
&mut self,
function_name: &str,
grid: (u32, u32, u32),
block: (u32, u32, u32),
shared_mem: u32,
) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::KernelLaunch {
function_name: function_name.to_owned(),
grid,
block,
shared_mem,
});
idx
}
pub fn add_memcpy_node(&mut self, direction: MemcpyDirection, size: usize) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::Memcpy { direction, size });
idx
}
pub fn add_memset_node(&mut self, size: usize, value: u8) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::Memset { size, value });
idx
}
pub fn add_empty_node(&mut self) -> usize {
let idx = self.nodes.len();
self.nodes.push(GraphNode::Empty);
idx
}
pub fn add_dependency(&mut self, from: usize, to: usize) -> CudaResult<()> {
if from >= self.nodes.len() || to >= self.nodes.len() {
return Err(CudaError::InvalidValue);
}
if from == to {
return Err(CudaError::InvalidValue);
}
self.dependencies.push((from, to));
Ok(())
}
#[inline]
pub fn node_count(&self) -> usize {
self.nodes.len()
}
#[inline]
pub fn dependency_count(&self) -> usize {
self.dependencies.len()
}
#[inline]
pub fn nodes(&self) -> &[GraphNode] {
&self.nodes
}
#[inline]
pub fn dependencies(&self) -> &[(usize, usize)] {
&self.dependencies
}
pub fn get_node(&self, index: usize) -> Option<&GraphNode> {
self.nodes.get(index)
}
pub fn topological_sort(&self) -> CudaResult<Vec<usize>> {
let n = self.nodes.len();
let mut in_degree = vec![0u32; n];
let mut adj: Vec<Vec<usize>> = vec![Vec::new(); n];
for &(from, to) in &self.dependencies {
adj[from].push(to);
in_degree[to] = in_degree[to].saturating_add(1);
}
let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
let mut result = Vec::with_capacity(n);
while let Some(node) = queue.pop() {
result.push(node);
for &next in &adj[node] {
in_degree[next] = in_degree[next].saturating_sub(1);
if in_degree[next] == 0 {
queue.push(next);
}
}
}
if result.len() != n {
return Err(CudaError::InvalidValue);
}
Ok(result)
}
pub fn instantiate(&self) -> CudaResult<GraphExec> {
let execution_order = self.topological_sort()?;
Ok(GraphExec {
graph: self.clone(),
execution_order,
})
}
}
/// Summarizes the graph as `Graph(<n> nodes, <m> deps)`.
impl std::fmt::Display for Graph {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "Graph({} nodes, {} deps)",
            self.node_count(),
            self.dependency_count()
        )
    }
}
/// An instantiated graph: a snapshot of a [`Graph`] paired with a
/// precomputed execution order.
///
/// Created by [`Graph::instantiate`].
pub struct GraphExec {
    /// Clone of the graph taken at instantiation time.
    graph: Graph,
    /// Node indices in the order produced by [`Graph::topological_sort`].
    execution_order: Vec<usize>,
}
impl GraphExec {
    /// Launches the instantiated graph.
    ///
    /// As written, this only looks up the driver through
    /// `crate::loader::try_driver` and then reports success. The stream
    /// argument and the stored execution order are not used here.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the driver lookup, if any.
    pub fn launch(&self, _stream: &Stream) -> CudaResult<()> {
        let _driver = crate::loader::try_driver()?;
        Ok(())
    }

    /// The graph snapshot this executable was built from.
    #[inline]
    pub fn graph(&self) -> &Graph {
        &self.graph
    }

    /// Node indices in the order computed at instantiation.
    #[inline]
    pub fn execution_order(&self) -> &[usize] {
        self.execution_order.as_slice()
    }

    /// Number of nodes in the underlying graph.
    #[inline]
    pub fn node_count(&self) -> usize {
        self.graph().node_count()
    }
}
/// Debug output lists the `graph` and `execution_order` fields under the
/// struct name `GraphExec`.
impl std::fmt::Debug for GraphExec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("GraphExec");
        builder.field("graph", &self.graph);
        builder.field("execution_order", &self.execution_order);
        builder.finish()
    }
}
/// Recorder that accumulates [`GraphNode`]s and is turned into a [`Graph`]
/// by [`StreamCapture::end`].
///
/// `Debug` is derived so that this public type can be logged, used in
/// assertion messages, and embedded in other `Debug` types, like the other
/// public types in this module.
#[derive(Debug)]
pub struct StreamCapture {
    /// Operations recorded so far, in recording order.
    nodes: Vec<GraphNode>,
    /// Whether recording calls still append nodes.
    active: bool,
}
impl StreamCapture {
pub fn begin(_stream: &Stream) -> CudaResult<Self> {
let _api = crate::loader::try_driver()?;
Ok(Self {
nodes: Vec::new(),
active: true,
})
}
pub fn record_kernel(
&mut self,
function_name: &str,
grid: (u32, u32, u32),
block: (u32, u32, u32),
shared_mem: u32,
) {
if self.active {
self.nodes.push(GraphNode::KernelLaunch {
function_name: function_name.to_owned(),
grid,
block,
shared_mem,
});
}
}
pub fn record_memcpy(&mut self, direction: MemcpyDirection, size: usize) {
if self.active {
self.nodes.push(GraphNode::Memcpy { direction, size });
}
}
pub fn record_memset(&mut self, size: usize, value: u8) {
if self.active {
self.nodes.push(GraphNode::Memset { size, value });
}
}
#[inline]
pub fn recorded_count(&self) -> usize {
self.nodes.len()
}
#[inline]
pub fn is_active(&self) -> bool {
self.active
}
pub fn end(mut self) -> CudaResult<Graph> {
if !self.active {
return Err(CudaError::StreamCaptureUnmatched);
}
self.active = false;
let mut graph = Graph::new();
let mut prev_idx: Option<usize> = None;
for node in self.nodes.drain(..) {
let idx = graph.nodes.len();
graph.nodes.push(node);
if let Some(prev) = prev_idx {
graph.dependencies.push((prev, idx));
}
prev_idx = Some(idx);
}
Ok(graph)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Grid used by tests that do not care about launch geometry.
    const GRID: (u32, u32, u32) = (1, 1, 1);
    /// Block used by tests that do not care about launch geometry.
    const BLOCK: (u32, u32, u32) = (32, 1, 1);

    /// Adds a kernel node with the default test geometry and no shared memory.
    fn kernel(graph: &mut Graph, name: &str) -> usize {
        graph.add_kernel_node(name, GRID, BLOCK, 0)
    }

    /// Position of `node` within `order`, or `usize::MAX` when it is absent.
    fn position_of(order: &[usize], node: usize) -> usize {
        order
            .iter()
            .position(|&entry| entry == node)
            .unwrap_or(usize::MAX)
    }

    /// A freshly constructed graph has no nodes and no edges.
    #[test]
    fn graph_new_is_empty() {
        let graph = Graph::new();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.dependency_count(), 0);
        assert!(graph.nodes().is_empty());
        assert!(graph.dependencies().is_empty());
    }

    /// `Default` yields an empty graph as well.
    #[test]
    fn graph_default_is_empty() {
        let graph = Graph::default();
        assert_eq!(graph.node_count(), 0);
    }

    /// Node indices are handed out sequentially starting at zero.
    #[test]
    fn add_kernel_node_returns_sequential_indices() {
        let mut graph = Graph::new();
        let first = graph.add_kernel_node("k0", (1, 1, 1), (32, 1, 1), 0);
        let second = graph.add_kernel_node("k1", (2, 1, 1), (64, 1, 1), 128);
        assert_eq!(first, 0);
        assert_eq!(second, 1);
        assert_eq!(graph.node_count(), 2);
    }

    /// A memcpy node stores the direction and size it was created with.
    #[test]
    fn add_memcpy_node_records_direction_and_size() {
        let mut graph = Graph::new();
        let index = graph.add_memcpy_node(MemcpyDirection::HostToDevice, 4096);
        assert_eq!(index, 0);
        let stored = graph.get_node(0);
        assert!(stored.is_some());
        match stored {
            Some(GraphNode::Memcpy { direction, size }) => {
                assert_eq!(*direction, MemcpyDirection::HostToDevice);
                assert_eq!(*size, 4096);
            }
            _ => panic!("expected Memcpy node"),
        }
    }

    /// A memset node stores the size and value it was created with.
    #[test]
    fn add_memset_node_records_size_and_value() {
        let mut graph = Graph::new();
        let index = graph.add_memset_node(8192, 0xAB);
        assert_eq!(index, 0);
        match graph.get_node(index) {
            Some(GraphNode::Memset { size, value }) => {
                assert_eq!(*size, 8192);
                assert_eq!(*value, 0xAB);
            }
            _ => panic!("expected Memset node"),
        }
    }

    /// An empty node can be added and read back.
    #[test]
    fn add_empty_node_works() {
        let mut graph = Graph::new();
        let index = graph.add_empty_node();
        assert_eq!(index, 0);
        assert_eq!(graph.get_node(index), Some(&GraphNode::Empty));
    }

    /// A dependency between two existing, distinct nodes is stored.
    #[test]
    fn add_dependency_valid() {
        let mut graph = Graph::new();
        let first = kernel(&mut graph, "k0");
        let second = kernel(&mut graph, "k1");
        assert!(graph.add_dependency(first, second).is_ok());
        assert_eq!(graph.dependency_count(), 1);
        assert_eq!(graph.dependencies().first(), Some(&(0, 1)));
    }

    /// Referring to a node index that does not exist is rejected.
    #[test]
    fn add_dependency_out_of_bounds() {
        let mut graph = Graph::new();
        let _only = kernel(&mut graph, "k0");
        let outcome = graph.add_dependency(0, 5);
        assert_eq!(outcome, Err(CudaError::InvalidValue));
    }

    /// A node cannot depend on itself.
    #[test]
    fn add_dependency_self_loop() {
        let mut graph = Graph::new();
        let only = kernel(&mut graph, "k0");
        let outcome = graph.add_dependency(only, only);
        assert_eq!(outcome, Err(CudaError::InvalidValue));
    }

    /// A three-node chain is sorted in chain order.
    #[test]
    fn topological_sort_linear_chain() {
        let mut graph = Graph::new();
        let first = kernel(&mut graph, "k0");
        let second = kernel(&mut graph, "k1");
        let third = kernel(&mut graph, "k2");
        graph.add_dependency(first, second).ok();
        graph.add_dependency(second, third).ok();

        let sorted = graph.topological_sort();
        assert!(sorted.is_ok());
        let order = sorted.unwrap_or_default();
        assert!(position_of(&order, first) < position_of(&order, second));
        assert!(position_of(&order, second) < position_of(&order, third));
    }

    /// Two nodes that depend on each other make the sort fail.
    #[test]
    fn topological_sort_detects_cycle() {
        let mut graph = Graph::new();
        let first = kernel(&mut graph, "k0");
        let second = kernel(&mut graph, "k1");
        graph.add_dependency(first, second).ok();
        graph.add_dependency(second, first).ok();
        let outcome = graph.topological_sort();
        assert_eq!(outcome, Err(CudaError::InvalidValue));
    }

    /// Without dependencies every node still appears in the result.
    #[test]
    fn topological_sort_no_deps() {
        let mut graph = Graph::new();
        kernel(&mut graph, "k0");
        kernel(&mut graph, "k1");
        kernel(&mut graph, "k2");
        let sorted = graph.topological_sort();
        assert!(sorted.is_ok());
        let order = sorted.unwrap_or_default();
        assert_eq!(order.len(), 3);
    }

    /// An acyclic copy -> kernel -> copy graph instantiates with all nodes.
    #[test]
    fn instantiate_valid_graph() {
        let mut graph = Graph::new();
        let upload = graph.add_memcpy_node(MemcpyDirection::HostToDevice, 1024);
        let compute = kernel(&mut graph, "k0");
        let download = graph.add_memcpy_node(MemcpyDirection::DeviceToHost, 1024);
        graph.add_dependency(upload, compute).ok();
        graph.add_dependency(compute, download).ok();

        let built = graph.instantiate();
        assert!(built.is_ok());
        if let Ok(exec) = built {
            assert_eq!(exec.node_count(), 3);
            assert_eq!(exec.execution_order().len(), 3);
        }
    }

    /// A cyclic graph cannot be instantiated.
    #[test]
    fn instantiate_cyclic_graph_fails() {
        let mut graph = Graph::new();
        let first = kernel(&mut graph, "k0");
        let second = kernel(&mut graph, "k1");
        graph.add_dependency(first, second).ok();
        graph.add_dependency(second, first).ok();
        assert!(graph.instantiate().is_err());
    }

    /// The graph summary mentions node and dependency counts.
    #[test]
    fn graph_display() {
        let mut graph = Graph::new();
        kernel(&mut graph, "k0");
        graph.add_memcpy_node(MemcpyDirection::HostToDevice, 512);
        let rendered = graph.to_string();
        assert!(rendered.contains("2 nodes"));
        assert!(rendered.contains("0 deps"));
    }

    /// Each node variant renders its distinguishing detail.
    #[test]
    fn node_display() {
        let cases = [
            (
                GraphNode::KernelLaunch {
                    function_name: "foo".to_owned(),
                    grid: (4, 1, 1),
                    block: (256, 1, 1),
                    shared_mem: 0,
                },
                "foo",
            ),
            (
                GraphNode::Memcpy {
                    direction: MemcpyDirection::DeviceToHost,
                    size: 1024,
                },
                "DtoH",
            ),
            (
                GraphNode::Memset {
                    size: 256,
                    value: 0xFF,
                },
                "0xff",
            ),
            (GraphNode::Empty, "Empty"),
        ];
        for (node, needle) in &cases {
            assert!(node.to_string().contains(*needle));
        }
    }

    /// Copy directions render as their short tags.
    #[test]
    fn memcpy_direction_display() {
        let expected = [
            (MemcpyDirection::HostToDevice, "HtoD"),
            (MemcpyDirection::DeviceToHost, "DtoH"),
            (MemcpyDirection::DeviceToDevice, "DtoD"),
        ];
        for (direction, tag) in &expected {
            assert_eq!(direction.to_string(), *tag);
        }
    }

    /// Looking up a node in an empty graph yields `None`.
    #[test]
    fn graph_get_node_out_of_bounds() {
        let graph = Graph::new();
        assert!(graph.get_node(0).is_none());
        assert!(graph.get_node(100).is_none());
    }

    /// A diamond (source -> {left, right} -> sink) sorts and instantiates.
    #[test]
    fn graph_diamond_dag() {
        let mut graph = Graph::new();
        let source = graph.add_empty_node();
        let left = kernel(&mut graph, "k1");
        let right = kernel(&mut graph, "k2");
        let sink = graph.add_empty_node();

        let edges = [(source, left), (source, right), (left, sink), (right, sink)];
        for &(from, to) in &edges {
            graph.add_dependency(from, to).ok();
        }

        let order = graph.topological_sort().unwrap_or_default();
        assert_eq!(order.len(), 4);
        // Every edge must point forward in the computed order.
        for &(from, to) in &edges {
            assert!(position_of(&order, from) < position_of(&order, to));
        }

        assert!(graph.instantiate().is_ok());
    }

    /// The debug representation names the type.
    #[test]
    fn graph_exec_debug() {
        let mut graph = Graph::new();
        graph.add_empty_node();
        let built = graph.instantiate();
        assert!(built.is_ok());
        if let Ok(exec) = built {
            let rendered = format!("{exec:?}");
            assert!(rendered.contains("GraphExec"));
        }
    }
}