use crate::kernel::Kernel;
use crate::params::LaunchParams;
/// A single recorded kernel launch: a kernel name paired with the launch
/// parameters it was recorded with.
///
/// Both fields are private; read them through [`LaunchRecord::kernel_name`]
/// and [`LaunchRecord::params`].
#[derive(Debug, Clone)]
pub struct LaunchRecord {
/// Owned name of the kernel this record refers to.
kernel_name: String,
/// Launch parameters stored by value alongside the name.
params: LaunchParams,
}
impl LaunchRecord {
    /// Returns the recorded kernel name as a borrowed string slice.
    #[inline]
    pub fn kernel_name(&self) -> &str {
        self.kernel_name.as_str()
    }

    /// Returns a shared reference to the launch parameters held by this record.
    #[inline]
    pub fn params(&self) -> &LaunchParams {
        let Self { params, .. } = self;
        params
    }
}
impl std::fmt::Display for LaunchRecord {
    /// Formats the record as
    /// `LaunchRecord(kernel=<name>, grid=<x>x<y>x<z>, block=<x>x<y>x<z>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Borrow the two dimension triples once so the argument list stays short.
        let grid = &self.params.grid;
        let block = &self.params.block;
        write!(
            f,
            "LaunchRecord(kernel={}, grid={}x{}x{}, block={}x{}x{})",
            self.kernel_name, grid.x, grid.y, grid.z, block.x, block.y, block.z,
        )
    }
}
/// Collects [`LaunchRecord`]s while a capture is in progress.
///
/// A capture is created with [`GraphLaunchCapture::begin`], filled through
/// the `record_*` methods, and consumed by [`GraphLaunchCapture::end`],
/// which returns the collected records.
#[derive(Debug)]
pub struct GraphLaunchCapture {
/// Records in the order they were pushed.
stream_nodes: Vec<LaunchRecord>,
/// While `false`, the `record_*` methods return without storing anything.
active: bool,
}
impl GraphLaunchCapture {
    /// Starts a new capture: no records yet, and the active flag set.
    pub fn begin() -> Self {
        let stream_nodes = Vec::new();
        Self {
            stream_nodes,
            active: true,
        }
    }

    /// Appends a record built from `kernel`'s name and a copy of `params`.
    ///
    /// Does nothing when the capture is not active.
    pub fn record_launch(&mut self, kernel: &Kernel, params: &LaunchParams) {
        if self.active {
            let kernel_name = kernel.name().to_owned();
            let params = *params;
            self.stream_nodes.push(LaunchRecord {
                kernel_name,
                params,
            });
        }
    }

    /// Appends a record from an explicit kernel name and parameter set,
    /// without needing a [`Kernel`] value.
    ///
    /// Does nothing when the capture is not active.
    pub fn record_raw(&mut self, kernel_name: impl Into<String>, params: LaunchParams) {
        if self.active {
            let kernel_name = kernel_name.into();
            self.stream_nodes.push(LaunchRecord {
                kernel_name,
                params,
            });
        }
    }

    /// Finishes the capture, consuming it and returning every record in
    /// the order it was recorded.
    pub fn end(mut self) -> Vec<LaunchRecord> {
        // Clear the flag before handing the records over; `self` is consumed
        // by this call, so the capture cannot be used afterwards.
        self.active = false;
        std::mem::take(&mut self.stream_nodes)
    }

    /// Number of records collected so far.
    #[inline]
    pub fn len(&self) -> usize {
        self.stream_nodes.len()
    }

    /// Returns `true` when no records have been collected.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.stream_nodes.is_empty()
    }

    /// Returns `true` while the capture still accepts new records.
    #[inline]
    pub fn is_active(&self) -> bool {
        self.active
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::grid::Dim3;

    /// A freshly begun capture is active and holds no records.
    #[test]
    fn capture_begin_is_active() {
        let fresh = GraphLaunchCapture::begin();
        assert!(fresh.is_active());
        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    /// Ending a capture that recorded nothing yields an empty vector.
    #[test]
    fn capture_end_returns_empty_vec() {
        let collected = GraphLaunchCapture::begin().end();
        assert!(collected.is_empty());
    }

    /// The `Display` output contains the kernel name and both dimension triples.
    #[test]
    fn launch_record_display() {
        let rendered = LaunchRecord {
            kernel_name: String::from("vector_add"),
            params: LaunchParams::new(Dim3::x(4), Dim3::x(256)),
        }
        .to_string();

        for needle in ["vector_add", "4x1x1", "256x1x1"] {
            assert!(rendered.contains(needle));
        }
    }

    /// Accessors expose the values the record was built with.
    #[test]
    fn launch_record_accessors() {
        let entry = LaunchRecord {
            kernel_name: String::from("my_kernel"),
            params: LaunchParams::new(8u32, 128u32),
        };
        let stored = entry.params();

        assert_eq!(entry.kernel_name(), "my_kernel");
        assert_eq!(stored.grid.x, 8);
        assert_eq!(stored.block.x, 128);
    }

    /// The derived `Debug` output names the type and shows the active flag.
    #[test]
    fn capture_debug() {
        let rendered = format!("{:?}", GraphLaunchCapture::begin());
        assert!(rendered.contains("GraphLaunchCapture"));
        assert!(rendered.contains("active: true"));
    }

    /// A cloned record reports the same name and grid width as its source.
    #[test]
    fn launch_record_clone() {
        let original = LaunchRecord {
            kernel_name: String::from("clone_test"),
            params: LaunchParams::new(2u32, 64u32),
        };
        let duplicate = original.clone();

        assert_eq!(duplicate.kernel_name(), original.kernel_name());
        assert_eq!(duplicate.params().grid.x, original.params().grid.x);
    }

    /// Recording one launch moves the capture from empty to length one.
    #[test]
    fn graph_capture_records_launches() {
        let mut session = GraphLaunchCapture::begin();
        assert!(session.is_empty());
        assert_eq!(session.len(), 0);

        let launch = LaunchParams::new(Dim3::x(4), Dim3::x(256));
        session.record_raw("vector_add", launch);

        assert!(!session.is_empty());
        assert_eq!(session.len(), 1);
    }

    /// Every grid and block component survives being stored in a record.
    #[test]
    fn graph_record_contains_params() {
        let entry = LaunchRecord {
            kernel_name: String::from("my_kernel"),
            params: LaunchParams::new(Dim3::new(8, 2, 1), Dim3::new(32, 8, 1)),
        };

        let grid = &entry.params().grid;
        assert_eq!(grid.x, 8);
        assert_eq!(grid.y, 2);
        assert_eq!(grid.z, 1);

        let block = &entry.params().block;
        assert_eq!(block.x, 32);
        assert_eq!(block.y, 8);
        assert_eq!(block.z, 1);
    }

    /// Records come back from `end` in the order they were recorded.
    #[test]
    fn graph_replay_count() {
        let expected = ["kernel_a", "kernel_b", "kernel_c"];
        let launch = LaunchParams::new(Dim3::x(4), Dim3::x(128));

        let mut session = GraphLaunchCapture::begin();
        for name in expected {
            session.record_raw(name, launch);
        }
        assert_eq!(session.len(), 3);

        let collected = session.end();
        assert_eq!(collected.len(), 3);

        let names: Vec<&str> = collected.iter().map(|entry| entry.kernel_name()).collect();
        assert_eq!(names, expected);
    }
}