/// Launch parameters for a single compute kernel.
///
/// Create one with [`KernelConfig::new`] and refine it with the chained
/// `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct KernelConfig {
    /// Kernel name, as passed to [`KernelConfig::new`].
    pub name: String,
    /// Grid extent as `(x, y, z)`; `(1, 1, 1)` by default.
    pub grid_dim: (u32, u32, u32),
    /// Block extent as `(x, y, z)`; `(32, 1, 1)` by default.
    pub block_dim: (u32, u32, u32),
    /// Shared memory size in bytes; `0` by default.
    pub shared_memory_bytes: usize,
    /// Optional stream identifier; `None` by default.
    pub stream_id: Option<usize>,
}

impl KernelConfig {
    /// Creates a configuration named `name` with the default launch
    /// parameters: a `(1, 1, 1)` grid, a `(32, 1, 1)` block, no shared memory
    /// and no stream.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            grid_dim: (1, 1, 1),
            block_dim: (32, 1, 1),
            shared_memory_bytes: 0,
            stream_id: None,
        }
    }

    /// Returns the configuration with its grid extent replaced by `(x, y, z)`.
    #[must_use]
    pub fn with_grid(mut self, x: u32, y: u32, z: u32) -> Self {
        self.grid_dim = (x, y, z);
        self
    }

    /// Returns the configuration with its block extent replaced by `(x, y, z)`.
    #[must_use]
    pub fn with_block(mut self, x: u32, y: u32, z: u32) -> Self {
        self.block_dim = (x, y, z);
        self
    }

    /// Returns the configuration with its shared memory size set to `bytes`.
    #[must_use]
    pub fn with_shared_memory(mut self, bytes: usize) -> Self {
        self.shared_memory_bytes = bytes;
        self
    }

    /// All six grid and block extents, in `grid x, y, z, block x, y, z` order.
    fn dims(&self) -> [u32; 6] {
        let (gx, gy, gz) = self.grid_dim;
        let (bx, by, bz) = self.block_dim;
        [gx, gy, gz, bx, by, bz]
    }

    /// Total number of threads: the product of all grid and block extents.
    ///
    /// The product saturates at `u32::MAX` instead of overflowing, so large
    /// launches neither panic in debug builds nor wrap in release builds.
    /// Use [`KernelConfig::checked_total_threads`] to detect that case.
    /// Any extent of zero yields `0`.
    #[must_use]
    pub fn total_threads(&self) -> u32 {
        // `saturating_mul` maps `u32::MAX * 0` to 0, so a zero extent that
        // appears after an oversized partial product still gives 0.
        self.dims()
            .iter()
            .fold(1u32, |acc, &extent| acc.saturating_mul(extent))
    }

    /// Total number of threads, or `None` when the product does not fit in a
    /// `u32`.
    ///
    /// Any extent of zero yields `Some(0)`.
    #[must_use]
    pub fn checked_total_threads(&self) -> Option<u32> {
        let dims = self.dims();
        // Handle zero up front: a checked fold would otherwise report
        // overflow for an oversized prefix even though the true product is 0.
        if dims.contains(&0) {
            return Some(0);
        }
        dims.iter()
            .try_fold(1u32, |acc, &extent| acc.checked_mul(extent))
    }

    /// Number of blocks of `threads_per_block` threads needed to cover
    /// `n_elements`, rounding up.
    ///
    /// Returns `0` when `threads_per_block` is `0` rather than dividing by
    /// zero.
    #[must_use]
    pub fn blocks_needed(n_elements: u32, threads_per_block: u32) -> u32 {
        if threads_per_block == 0 {
            return 0;
        }
        n_elements.div_ceil(threads_per_block)
    }
}
/// Data record pairing a kernel name with grid/block dimensions and an
/// optional elapsed time.
///
/// NOTE(review): presumably filled in by whatever launches the kernel; the
/// producer is not visible here, so confirm how each field is populated.
#[derive(Debug, Clone)]
pub struct KernelLaunchResult {
/// Name of the kernel this record refers to.
pub kernel_name: String,
/// Optional elapsed time; the `_us` suffix suggests microseconds — TODO
/// confirm the unit and when this is `None` against the producer.
pub elapsed_us: Option<u64>,
/// Grid dimensions as a 3-tuple; presumably `(x, y, z)` — TODO confirm.
pub grid_dim: (u32, u32, u32),
/// Block dimensions as a 3-tuple; presumably `(x, y, z)` — TODO confirm.
pub block_dim: (u32, u32, u32),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed config carries the default launch parameters.
    #[test]
    fn test_kernel_config_new() {
        let cfg = KernelConfig::new("elementwise_add");

        assert_eq!(cfg.name, "elementwise_add");
        assert_eq!(
            (cfg.grid_dim, cfg.block_dim),
            ((1, 1, 1), (32, 1, 1))
        );
        assert_eq!(cfg.shared_memory_bytes, 0);
        assert_eq!(cfg.stream_id, None);
    }

    /// Each builder method overrides exactly the field it is named after.
    #[test]
    fn test_kernel_config_builder() {
        let base = KernelConfig::new("matmul");
        let with_grid = base.with_grid(64, 64, 1);
        let with_block = with_grid.with_block(16, 16, 1);
        let cfg = with_block.with_shared_memory(2048);

        assert_eq!(
            (cfg.grid_dim, cfg.block_dim, cfg.shared_memory_bytes),
            ((64, 64, 1), (16, 16, 1), 2048)
        );
    }

    /// The thread count is the product of every grid and block extent.
    #[test]
    fn test_total_threads() {
        let cfg = KernelConfig::new("test")
            .with_grid(4, 2, 1)
            .with_block(32, 1, 1);

        // 4 * 2 * 1 blocks of 32 * 1 * 1 threads each.
        let expected = 256;
        assert_eq!(cfg.total_threads(), expected);
    }

    /// Block counts round up, and degenerate inputs produce zero.
    #[test]
    fn test_blocks_needed() {
        // (n_elements, threads_per_block, expected block count)
        let cases = [
            (1024, 32, 32),
            (1025, 32, 33),
            (0, 32, 0),
            (32, 32, 1),
            (100, 0, 0),
        ];

        for &(n_elements, threads_per_block, expected) in cases.iter() {
            assert_eq!(
                KernelConfig::blocks_needed(n_elements, threads_per_block),
                expected,
                "n_elements={}, threads_per_block={}",
                n_elements,
                threads_per_block
            );
        }
    }

    /// A launch result keeps the values it was constructed with.
    #[test]
    fn test_kernel_launch_result() {
        let launch = KernelLaunchResult {
            kernel_name: String::from("test_kernel"),
            elapsed_us: Some(42),
            grid_dim: (8, 1, 1),
            block_dim: (128, 1, 1),
        };

        assert_eq!(launch.kernel_name, "test_kernel");
        assert_eq!(launch.elapsed_us, Some(42));
        assert_eq!(
            (launch.grid_dim, launch.block_dim),
            ((8, 1, 1), (128, 1, 1))
        );
    }
}