use std::collections::HashMap;
use std::fmt;
use std::time::Instant;
use crate::grid::Dim3;
use crate::kernel::KernelArgs;
use crate::params::LaunchParams;
/// Type tag recorded alongside a serialized kernel argument.
///
/// The fixed variants correspond to Rust primitive types; `Custom` carries an
/// arbitrary type name that is displayed verbatim.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArgType {
    /// Unsigned 8-bit integer (`u8`).
    U8,
    /// Unsigned 16-bit integer (`u16`).
    U16,
    /// Unsigned 32-bit integer (`u32`).
    U32,
    /// Unsigned 64-bit integer (`u64`).
    U64,
    /// Signed 8-bit integer (`i8`).
    I8,
    /// Signed 16-bit integer (`i16`).
    I16,
    /// Signed 32-bit integer (`i32`).
    I32,
    /// Signed 64-bit integer (`i64`).
    I64,
    /// 32-bit float (`f32`).
    F32,
    /// 64-bit float (`f64`).
    F64,
    /// Pointer-sized value; this module tags both `usize` and `isize` with it.
    Ptr,
    /// Caller-named type; the string is the name to display.
    Custom(String),
}
impl fmt::Display for ArgType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::U8 => write!(f, "u8"),
Self::U16 => write!(f, "u16"),
Self::U32 => write!(f, "u32"),
Self::U64 => write!(f, "u64"),
Self::I8 => write!(f, "i8"),
Self::I16 => write!(f, "i16"),
Self::I32 => write!(f, "i32"),
Self::I64 => write!(f, "i64"),
Self::F32 => write!(f, "f32"),
Self::F64 => write!(f, "f64"),
Self::Ptr => write!(f, "ptr"),
Self::Custom(name) => write!(f, "{name}"),
}
}
}
/// A single kernel argument captured for logging.
///
/// It holds an optional name, a type tag, the value already rendered as text,
/// and the argument's size in bytes.
#[derive(Debug, Clone)]
pub struct SerializedArg {
    // Optional argument name (the tuple impls in this module use "arg0", "arg1", ...).
    name: Option<String>,
    // Type tag printed before the value.
    arg_type: ArgType,
    // Pre-formatted value text, e.g. "42" or "0x1000".
    value_repr: String,
    // Byte size as reported by whoever produced this argument.
    size_bytes: usize,
}
impl SerializedArg {
#[inline]
pub fn new(
name: Option<String>,
arg_type: ArgType,
value_repr: String,
size_bytes: usize,
) -> Self {
Self {
name,
arg_type,
value_repr,
size_bytes,
}
}
#[inline]
pub fn name(&self) -> Option<&str> {
self.name.as_deref()
}
#[inline]
pub fn arg_type(&self) -> &ArgType {
&self.arg_type
}
#[inline]
pub fn value_repr(&self) -> &str {
&self.value_repr
}
#[inline]
pub fn size_bytes(&self) -> usize {
self.size_bytes
}
pub fn total_size(args: &[Self]) -> usize {
args.iter().map(|a| a.size_bytes).sum()
}
}
impl fmt::Display for SerializedArg {
    /// Formats as `name: type = value` when a name is present.
    /// Otherwise it formats as `type: value`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (ty, value) = (&self.arg_type, &self.value_repr);
        if let Some(label) = self.name.as_deref() {
            write!(f, "{label}: {ty} = {value}")
        } else {
            write!(f, "{ty}: {value}")
        }
    }
}
/// Record of one kernel launch.
///
/// It holds the kernel name, grid and block dimensions, shared-memory size in
/// bytes, the serialized arguments, and the instant the record was created.
pub struct LaunchLog {
    // Name of the launched kernel; also the grouping key in `LaunchLogger::summary`.
    kernel_name: String,
    // Grid dimensions.
    grid: Dim3,
    // Block dimensions.
    block: Dim3,
    // Shared-memory size in bytes.
    shared_mem: u32,
    // Serialized arguments, in the order they were supplied.
    args: Vec<SerializedArg>,
    // Captured with `Instant::now()` when the record is constructed.
    timestamp: Instant,
}
impl LaunchLog {
pub fn new(
kernel_name: String,
grid: Dim3,
block: Dim3,
shared_mem: u32,
args: Vec<SerializedArg>,
) -> Self {
Self {
kernel_name,
grid,
block,
shared_mem,
args,
timestamp: Instant::now(),
}
}
pub fn from_params(
kernel_name: String,
params: &LaunchParams,
args: Vec<SerializedArg>,
) -> Self {
Self::new(
kernel_name,
params.grid,
params.block,
params.shared_mem_bytes,
args,
)
}
#[inline]
pub fn kernel_name(&self) -> &str {
&self.kernel_name
}
#[inline]
pub fn grid(&self) -> Dim3 {
self.grid
}
#[inline]
pub fn block(&self) -> Dim3 {
self.block
}
#[inline]
pub fn shared_mem(&self) -> u32 {
self.shared_mem
}
#[inline]
pub fn args(&self) -> &[SerializedArg] {
&self.args
}
#[inline]
pub fn timestamp(&self) -> Instant {
self.timestamp
}
#[inline]
pub fn total_threads(&self) -> u64 {
self.grid.total() as u64 * self.block.total() as u64
}
}
impl fmt::Display for LaunchLog {
    /// Renders the launch as `name<<<(gx,gy,gz), (bx,by,bz), smem>>>( args )`.
    ///
    /// The argument list is padded by one space on each side, so a launch with
    /// no arguments renders as `(  )`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (g, b) = (&self.grid, &self.block);
        write!(
            f,
            "{name}<<<({gx},{gy},{gz}), ({bx},{by},{bz}), {smem}>>>( {args} )",
            name = self.kernel_name,
            gx = g.x,
            gy = g.y,
            gz = g.z,
            bx = b.x,
            by = b.y,
            bz = b.z,
            smem = self.shared_mem,
            args = format_args_inner(&self.args),
        )
    }
}
impl fmt::Debug for LaunchLog {
    /// Debug view of the launch.
    ///
    /// The timestamp is omitted, and the arguments are reported only as
    /// `args_count`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let arg_count = self.args.len();
        let mut builder = f.debug_struct("LaunchLog");
        builder.field("kernel_name", &self.kernel_name);
        builder.field("grid", &self.grid);
        builder.field("block", &self.block);
        builder.field("shared_mem", &self.shared_mem);
        builder.field("args_count", &arg_count);
        builder.finish()
    }
}
/// In-memory, append-only list of `LaunchLog` records, kept in insertion
/// order.
#[derive(Debug)]
pub struct LaunchLogger {
    // Recorded launches, oldest first.
    entries: Vec<LaunchLog>,
}
impl LaunchLogger {
    /// Creates an empty logger.
    #[inline]
    pub fn new() -> Self {
        Self {
            entries: Vec::new(),
        }
    }

    /// Appends a launch record after all previously logged ones.
    #[inline]
    pub fn log(&mut self, record: LaunchLog) {
        self.entries.push(record);
    }

    /// All recorded launches, oldest first.
    #[inline]
    pub fn entries(&self) -> &[LaunchLog] {
        &self.entries
    }

    /// Removes every recorded launch.
    #[inline]
    pub fn clear(&mut self) {
        self.entries.clear();
    }

    /// Number of recorded launches.
    #[inline]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no launches have been recorded.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Aggregates the recorded launches, grouped by kernel name.
    ///
    /// Each kernel's stats hold three totals over its launches:
    /// - the launch count,
    /// - the sum of `LaunchLog::total_threads()`,
    /// - the sum of shared-memory sizes (widened to `u64`).
    ///
    /// `total_launches` counts every entry.
    pub fn summary(&self) -> LaunchSummary {
        let mut per_kernel: HashMap<String, KernelLaunchStats> = HashMap::new();
        for entry in &self.entries {
            let threads = entry.total_threads();
            let shared = u64::from(entry.shared_mem);
            // Probe with the borrowed name first. The owned `String` key (and the
            // copy stored inside the stats) is then only allocated the first time
            // a kernel is seen, instead of on every logged launch.
            if let Some(stats) = per_kernel.get_mut(entry.kernel_name.as_str()) {
                stats.launch_count += 1;
                stats.total_threads += threads;
                stats.total_shared_mem += shared;
            } else {
                per_kernel.insert(
                    entry.kernel_name.clone(),
                    KernelLaunchStats {
                        kernel_name: entry.kernel_name.clone(),
                        launch_count: 1,
                        total_threads: threads,
                        total_shared_mem: shared,
                    },
                );
            }
        }
        LaunchSummary {
            total_launches: self.entries.len(),
            per_kernel,
        }
    }
}
impl Default for LaunchLogger {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Totals for one kernel name, produced by `LaunchLogger::summary`.
#[derive(Debug, Clone)]
pub struct KernelLaunchStats {
    // Kernel these totals belong to.
    kernel_name: String,
    // Number of logged launches of this kernel.
    launch_count: usize,
    // Sum of per-launch thread counts.
    total_threads: u64,
    // Sum of per-launch shared-memory sizes, in bytes.
    total_shared_mem: u64,
}
impl KernelLaunchStats {
    /// Kernel name these stats were aggregated for.
    #[inline]
    pub fn kernel_name(&self) -> &str {
        self.kernel_name.as_str()
    }

    /// How many launches of this kernel were logged.
    #[inline]
    pub fn launch_count(&self) -> usize {
        self.launch_count
    }

    /// Threads summed over all launches of this kernel.
    #[inline]
    pub fn total_threads(&self) -> u64 {
        self.total_threads
    }

    /// Shared-memory bytes summed over all launches of this kernel.
    #[inline]
    pub fn total_shared_mem(&self) -> u64 {
        self.total_shared_mem
    }
}
impl fmt::Display for KernelLaunchStats {
    /// One-line summary: `name: N launches, T total threads, S bytes shared mem`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            kernel_name,
            launch_count,
            total_threads,
            total_shared_mem,
        } = self;
        write!(
            f,
            "{kernel_name}: {launch_count} launches, {total_threads} total threads, {total_shared_mem} bytes shared mem"
        )
    }
}
/// Aggregate view of a `LaunchLogger`: the overall launch count plus
/// per-kernel totals.
#[derive(Debug)]
pub struct LaunchSummary {
    // Number of launch records that were summarized.
    total_launches: usize,
    // Totals keyed by kernel name.
    per_kernel: HashMap<String, KernelLaunchStats>,
}
impl LaunchSummary {
    /// Number of launches that were summarized.
    #[inline]
    pub fn total_launches(&self) -> usize {
        self.total_launches
    }

    /// Per-kernel totals keyed by kernel name (unordered).
    #[inline]
    pub fn per_kernel(&self) -> &HashMap<String, KernelLaunchStats> {
        &self.per_kernel
    }

    /// Number of distinct kernel names seen.
    #[inline]
    pub fn unique_kernels(&self) -> usize {
        let distinct = self.per_kernel.len();
        distinct
    }

    /// Totals for `name`, or `None` if that kernel was never launched.
    #[inline]
    pub fn kernel_stats(&self, name: &str) -> Option<&KernelLaunchStats> {
        let stats = self.per_kernel.get(name);
        stats
    }
}
impl fmt::Display for LaunchSummary {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "LaunchSummary: {} total launches", self.total_launches)?;
let mut names: Vec<&String> = self.per_kernel.keys().collect();
names.sort();
for name in names {
if let Some(stats) = self.per_kernel.get(name) {
writeln!(f, " {stats}")?;
}
}
Ok(())
}
}
/// Kernel argument bundles that can describe themselves for launch logging.
///
/// This module implements it for `()` and for tuples of 1 to 12 elements whose
/// element types implement `SerializeArg`.
///
/// # Safety
///
/// NOTE(review): the trait is `unsafe`, but the contract implementors must
/// uphold is not stated in this file. Presumably it mirrors whatever
/// `KernelArgs` requires; confirm before adding new implementations.
pub unsafe trait SerializableKernelArgs: KernelArgs {
    /// Serialized form of the arguments. The impls in this module return one
    /// entry per tuple element, in element order.
    fn serialize_args(&self) -> Vec<SerializedArg>;
}
// NOTE(review): no SAFETY justification is visible here; presumably soundness
// follows from `()`'s `KernelArgs` impl. Confirm.
unsafe impl SerializableKernelArgs for () {
    /// The unit type carries no arguments, so the list is empty.
    fn serialize_args(&self) -> Vec<SerializedArg> {
        vec![]
    }
}
/// A single `Copy` value that can be rendered as a `SerializedArg`.
pub trait SerializeArg: Copy {
    /// Type tag for this Rust type.
    fn arg_type() -> ArgType;
    /// Textual rendering of this value.
    fn value_repr(&self) -> String;
    /// Size of this type in bytes.
    fn size_bytes() -> usize;

    /// Packs the type tag, rendered value and size into a `SerializedArg`
    /// with the given optional `name`.
    fn to_serialized(&self, name: Option<String>) -> SerializedArg {
        let arg_type = Self::arg_type();
        let value_repr = self.value_repr();
        let size_bytes = Self::size_bytes();
        SerializedArg::new(name, arg_type, value_repr, size_bytes)
    }
}
// Implements `SerializeArg` for a primitive integer type `$ty`.
// The type is tagged with `ArgType::$variant`, its value is rendered through
// `Display` (decimal), and its size is `size_of::<$ty>()`.
macro_rules! impl_serialize_arg_int {
    ($ty:ty, $variant:ident) => {
        impl SerializeArg for $ty {
            #[inline]
            fn arg_type() -> ArgType {
                ArgType::$variant
            }

            #[inline]
            fn value_repr(&self) -> String {
                format!("{}", self)
            }

            #[inline]
            fn size_bytes() -> usize {
                ::std::mem::size_of::<Self>()
            }
        }
    };
}
// Each fixed-width integer maps onto the `ArgType` variant of the same name.
impl_serialize_arg_int!(u8, U8);
impl_serialize_arg_int!(i8, I8);
impl_serialize_arg_int!(u16, U16);
impl_serialize_arg_int!(i16, I16);
impl_serialize_arg_int!(u32, U32);
impl_serialize_arg_int!(i32, I32);
impl_serialize_arg_int!(u64, U64);
impl_serialize_arg_int!(i64, I64);
impl SerializeArg for f32 {
    /// Tagged as `ArgType::F32`.
    #[inline]
    fn arg_type() -> ArgType {
        ArgType::F32
    }

    /// Renders the value with `Display`.
    ///
    /// Finite whole numbers get exactly one decimal place, so they read as
    /// floats (`1.0` rather than `1`). NaN and infinities use plain `Display`.
    #[inline]
    fn value_repr(&self) -> String {
        let is_whole = self.is_finite() && self.fract() == 0.0;
        if is_whole {
            format!("{:.1}", self)
        } else {
            self.to_string()
        }
    }

    /// Always 4 bytes.
    #[inline]
    fn size_bytes() -> usize {
        std::mem::size_of::<f32>()
    }
}
impl SerializeArg for f64 {
    /// Tagged as `ArgType::F64`.
    #[inline]
    fn arg_type() -> ArgType {
        ArgType::F64
    }

    /// Renders the value with `Display`.
    ///
    /// Finite whole numbers get exactly one decimal place, so they read as
    /// floats (`3.0` rather than `3`). NaN and infinities use plain `Display`.
    #[inline]
    fn value_repr(&self) -> String {
        let is_whole = self.is_finite() && self.fract() == 0.0;
        if is_whole {
            format!("{:.1}", self)
        } else {
            self.to_string()
        }
    }

    /// Always 8 bytes.
    #[inline]
    fn size_bytes() -> usize {
        std::mem::size_of::<f64>()
    }
}
impl SerializeArg for usize {
    /// Every `usize` is tagged as `ArgType::Ptr`.
    ///
    /// NOTE(review): this also applies to counts and lengths passed as
    /// `usize`. Presumably device addresses travel as `usize` in this crate;
    /// confirm.
    #[inline]
    fn arg_type() -> ArgType {
        ArgType::Ptr
    }

    /// Lower-case hexadecimal with a `0x` prefix, e.g. `0x1000`.
    #[inline]
    fn value_repr(&self) -> String {
        format!("{:#x}", self)
    }

    /// Pointer width of the host target.
    #[inline]
    fn size_bytes() -> usize {
        std::mem::size_of::<Self>()
    }
}
impl SerializeArg for isize {
    /// Every `isize` is tagged as `ArgType::Ptr`.
    #[inline]
    fn arg_type() -> ArgType {
        ArgType::Ptr
    }

    /// Lower-case hexadecimal with a `0x` prefix.
    ///
    /// Negative values print their two's-complement bit pattern (e.g. `-1` on
    /// a 64-bit target is `0xffffffffffffffff`).
    #[inline]
    fn value_repr(&self) -> String {
        format!("{:#x}", self)
    }

    /// Pointer width of the host target.
    #[inline]
    fn size_bytes() -> usize {
        std::mem::size_of::<Self>()
    }
}
// Implements `SerializableKernelArgs` for a tuple whose elements all implement
// `SerializeArg`. Element `self.N` is serialized under the name "argN", and the
// resulting vector follows tuple order.
macro_rules! impl_serializable_kernel_args_tuple {
    ($($idx:tt: $T:ident),+) => {
        unsafe impl<$($T: Copy + SerializeArg),+> SerializableKernelArgs for ($($T,)+) {
            fn serialize_args(&self) -> Vec<SerializedArg> {
                Vec::from([
                    $(
                        SerializeArg::to_serialized(
                            &self.$idx,
                            Some(String::from(concat!("arg", stringify!($idx)))),
                        ),
                    )+
                ])
            }
        }
    };
}
// Tuple arities 1 through 12.
impl_serializable_kernel_args_tuple!(0: A);
impl_serializable_kernel_args_tuple!(0: A, 1: B);
impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C);
impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D);
impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E);
impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
impl_serializable_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
impl_serializable_kernel_args_tuple!(
    0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H
);
impl_serializable_kernel_args_tuple!(
    0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I
);
impl_serializable_kernel_args_tuple!(
    0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J
);
impl_serializable_kernel_args_tuple!(
    0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K
);
impl_serializable_kernel_args_tuple!(
    0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L
);
/// Formats launch geometry and shared-memory size on one line, e.g.
/// `grid=(4,2,1) block=(256,1,1) smem=4096`.
pub fn format_launch_params(params: &LaunchParams) -> String {
    let grid = &params.grid;
    let block = &params.block;
    format!(
        "grid=({gx},{gy},{gz}) block=({bx},{by},{bz}) smem={smem}",
        gx = grid.x,
        gy = grid.y,
        gz = grid.z,
        bx = block.x,
        by = block.y,
        bz = block.z,
        smem = params.shared_mem_bytes,
    )
}
/// Renders arguments as a comma-separated list of their `Display` forms, e.g.
/// `a: u64 = 0x1000, n: u32 = 1024`.
///
/// Returns an empty string for an empty slice.
pub fn format_args(args: &[SerializedArg]) -> String {
    format_args_inner(args)
}
/// Joins the `Display` form of each argument with `", "`.
///
/// An empty slice yields an empty string.
fn format_args_inner(args: &[SerializedArg]) -> String {
    let mut out = String::new();
    for (position, arg) in args.iter().enumerate() {
        if position != 0 {
            out.push_str(", ");
        }
        out.push_str(&arg.to_string());
    }
    out
}
#[cfg(test)]
mod tests {
    //! Unit tests for argument serialization, launch records and aggregation.
    use super::*;
    use crate::params::LaunchParams;

    #[test]
    fn arg_type_display() {
        assert_eq!(format!("{}", ArgType::U32), "u32");
        assert_eq!(format!("{}", ArgType::F64), "f64");
        assert_eq!(format!("{}", ArgType::Ptr), "ptr");
        assert_eq!(format!("{}", ArgType::Custom("MyType".into())), "MyType");
    }

    #[test]
    fn arg_type_equality() {
        assert_eq!(ArgType::U32, ArgType::U32);
        assert_ne!(ArgType::U32, ArgType::U64);
        assert_eq!(ArgType::Custom("Foo".into()), ArgType::Custom("Foo".into()));
    }

    #[test]
    fn serialized_arg_new_and_accessors() {
        let arg = SerializedArg::new(Some("count".into()), ArgType::U32, "42".into(), 4);
        assert_eq!(arg.name(), Some("count"));
        assert_eq!(*arg.arg_type(), ArgType::U32);
        assert_eq!(arg.value_repr(), "42");
        assert_eq!(arg.size_bytes(), 4);
    }

    #[test]
    fn serialized_arg_no_name() {
        let arg = SerializedArg::new(None, ArgType::F32, "3.14".into(), 4);
        assert_eq!(arg.name(), None);
        assert_eq!(format!("{arg}"), "f32: 3.14");
    }

    #[test]
    fn serialized_arg_with_name_display() {
        let arg = SerializedArg::new(Some("x".into()), ArgType::I64, "-100".into(), 8);
        assert_eq!(format!("{arg}"), "x: i64 = -100");
    }

    #[test]
    fn serialized_arg_total_size() {
        let args = vec![
            SerializedArg::new(None, ArgType::U32, "1".into(), 4),
            SerializedArg::new(None, ArgType::U64, "2".into(), 8),
            SerializedArg::new(None, ArgType::F32, "3.0".into(), 4),
        ];
        assert_eq!(SerializedArg::total_size(&args), 16);
        assert_eq!(SerializedArg::total_size(&[]), 0);
    }

    #[test]
    fn launch_log_creation_and_accessors() {
        let log = LaunchLog::new(
            "vector_add".into(),
            Dim3::x(4),
            Dim3::x(256),
            1024,
            vec![SerializedArg::new(None, ArgType::U32, "42".into(), 4)],
        );
        assert_eq!(log.kernel_name(), "vector_add");
        assert_eq!(log.grid(), Dim3::x(4));
        assert_eq!(log.block(), Dim3::x(256));
        assert_eq!(log.shared_mem(), 1024);
        assert_eq!(log.args().len(), 1);
        assert_eq!(log.total_threads(), 1024);
    }

    #[test]
    fn launch_log_from_params() {
        let params = LaunchParams::new(Dim3::xy(2, 2), Dim3::x(128)).with_shared_mem(512);
        // `&params` — the original text here was mis-encoded as `¶ms`.
        let log = LaunchLog::from_params("matmul".into(), &params, vec![]);
        assert_eq!(log.kernel_name(), "matmul");
        assert_eq!(log.grid(), Dim3::xy(2, 2));
        assert_eq!(log.shared_mem(), 512);
    }

    #[test]
    fn launch_log_display() {
        let log = LaunchLog::new(
            "my_kernel".into(),
            Dim3::x(4),
            Dim3::x(256),
            0,
            vec![
                SerializedArg::new(Some("a".into()), ArgType::U64, "0x1000".into(), 8),
                SerializedArg::new(Some("n".into()), ArgType::U32, "1024".into(), 4),
            ],
        );
        let s = format!("{log}");
        assert!(s.contains("my_kernel<<<"));
        assert!(s.contains("(4,1,1)"));
        assert!(s.contains("(256,1,1)"));
        assert!(s.contains("a: u64 = 0x1000"));
        assert!(s.contains("n: u32 = 1024"));
    }

    #[test]
    fn launch_log_debug() {
        let log = LaunchLog::new("kern".into(), Dim3::x(1), Dim3::x(1), 0, vec![]);
        let dbg = format!("{log:?}");
        assert!(dbg.contains("LaunchLog"));
        assert!(dbg.contains("kern"));
    }

    #[test]
    fn launch_logger_basic_workflow() {
        let mut logger = LaunchLogger::new();
        assert!(logger.is_empty());
        assert_eq!(logger.len(), 0);
        logger.log(LaunchLog::new(
            "kern_a".into(),
            Dim3::x(4),
            Dim3::x(256),
            0,
            vec![],
        ));
        logger.log(LaunchLog::new(
            "kern_b".into(),
            Dim3::x(8),
            Dim3::x(128),
            512,
            vec![],
        ));
        assert_eq!(logger.len(), 2);
        assert!(!logger.is_empty());
        assert_eq!(logger.entries()[0].kernel_name(), "kern_a");
        assert_eq!(logger.entries()[1].kernel_name(), "kern_b");
        logger.clear();
        assert!(logger.is_empty());
    }

    #[test]
    fn launch_logger_default() {
        let logger = LaunchLogger::default();
        assert!(logger.is_empty());
    }

    #[test]
    fn launch_summary_aggregation() {
        let mut logger = LaunchLogger::new();
        logger.log(LaunchLog::new(
            "kern_a".into(),
            Dim3::x(4),
            Dim3::x(256),
            0,
            vec![],
        ));
        logger.log(LaunchLog::new(
            "kern_a".into(),
            Dim3::x(8),
            Dim3::x(256),
            1024,
            vec![],
        ));
        logger.log(LaunchLog::new(
            "kern_b".into(),
            Dim3::x(1),
            Dim3::x(128),
            0,
            vec![],
        ));
        let summary = logger.summary();
        assert_eq!(summary.total_launches(), 3);
        assert_eq!(summary.unique_kernels(), 2);
        let a_stats = summary
            .kernel_stats("kern_a")
            .expect("kern_a stats should exist in test");
        assert_eq!(a_stats.kernel_name(), "kern_a");
        assert_eq!(a_stats.launch_count(), 2);
        assert_eq!(a_stats.total_threads(), 4 * 256 + 8 * 256);
        assert_eq!(a_stats.total_shared_mem(), 1024);
        let b_stats = summary
            .kernel_stats("kern_b")
            .expect("kern_b stats should exist in test");
        assert_eq!(b_stats.launch_count(), 1);
        assert!(summary.kernel_stats("missing").is_none());
    }

    #[test]
    fn launch_summary_display() {
        let mut logger = LaunchLogger::new();
        logger.log(LaunchLog::new(
            "kern".into(),
            Dim3::x(1),
            Dim3::x(1),
            0,
            vec![],
        ));
        let summary = logger.summary();
        let s = format!("{summary}");
        assert!(s.contains("LaunchSummary"));
        assert!(s.contains("1 total launches"));
        assert!(s.contains("kern"));
    }

    #[test]
    fn serialize_arg_trait_scalars() {
        let v: u32 = 42;
        let sa = v.to_serialized(Some("n".into()));
        assert_eq!(*sa.arg_type(), ArgType::U32);
        assert_eq!(sa.value_repr(), "42");
        assert_eq!(sa.size_bytes(), 4);
        let v: f64 = 3.15;
        let sa = v.to_serialized(None);
        assert_eq!(*sa.arg_type(), ArgType::F64);
        assert_eq!(sa.value_repr(), "3.15");
        assert_eq!(sa.size_bytes(), 8);
        let v: f32 = 1.0;
        let sa = v.to_serialized(None);
        assert_eq!(sa.value_repr(), "1.0");
    }

    #[test]
    fn serialize_arg_pointer_and_non_finite() {
        let p: usize = 0x1000;
        let sa = p.to_serialized(None);
        assert_eq!(*sa.arg_type(), ArgType::Ptr);
        assert_eq!(sa.value_repr(), "0x1000");
        assert_eq!(f64::INFINITY.value_repr(), "inf");
        assert_eq!(f32::NAN.value_repr(), "NaN");
        assert_eq!((-2.0f64).value_repr(), "-2.0");
    }

    #[test]
    fn serializable_kernel_args_unit() {
        let args = ();
        let serialized = args.serialize_args();
        assert!(serialized.is_empty());
    }

    #[test]
    fn serializable_kernel_args_tuple() {
        let args = (42u32, 3.15f64);
        let serialized = args.serialize_args();
        assert_eq!(serialized.len(), 2);
        assert_eq!(serialized[0].name(), Some("arg0"));
        assert_eq!(*serialized[0].arg_type(), ArgType::U32);
        assert_eq!(serialized[0].value_repr(), "42");
        assert_eq!(serialized[1].name(), Some("arg1"));
        assert_eq!(*serialized[1].arg_type(), ArgType::F64);
        assert_eq!(serialized[1].value_repr(), "3.15");
    }

    #[test]
    fn format_launch_params_output() {
        let params = LaunchParams::new(Dim3::xy(4, 2), Dim3::x(256)).with_shared_mem(4096);
        // `&params` — the original text here was mis-encoded as `¶ms`.
        let s = format_launch_params(&params);
        assert!(s.contains("grid=(4,2,1)"));
        assert!(s.contains("block=(256,1,1)"));
        assert!(s.contains("smem=4096"));
    }

    #[test]
    fn format_args_output() {
        let args = vec![
            SerializedArg::new(Some("a".into()), ArgType::U64, "0x1000".into(), 8),
            SerializedArg::new(Some("n".into()), ArgType::U32, "1024".into(), 4),
        ];
        let s = format_args(&args);
        assert!(s.contains("a: u64 = 0x1000"));
        assert!(s.contains("n: u32 = 1024"));
    }

    #[test]
    fn format_args_empty() {
        let s = format_args(&[]);
        assert!(s.is_empty());
    }

    #[test]
    fn kernel_launch_stats_display() {
        let stats = KernelLaunchStats {
            kernel_name: "matmul".into(),
            launch_count: 5,
            total_threads: 1_000_000,
            total_shared_mem: 4096,
        };
        let s = format!("{stats}");
        assert!(s.contains("matmul"));
        assert!(s.contains("5 launches"));
        assert!(s.contains("1000000 total threads"));
    }
}