use std::ffi::c_void;
use std::sync::Arc;
use oxicuda_driver::error::CudaResult;
use oxicuda_driver::loader::try_driver;
use oxicuda_driver::module::{Function, Module};
use oxicuda_driver::stream::Stream;
use crate::params::LaunchParams;
use crate::trace::KernelSpanGuard;
/// A bundle of kernel arguments that can be exposed as a list of untyped pointers.
///
/// [`Kernel::launch`] calls [`as_param_ptrs`](KernelArgs::as_param_ptrs) and passes the
/// buffer of the returned vector to the driver's `cu_launch_kernel` entry point.
///
/// # Safety
///
/// NOTE(review): the exact contract is not written down in this file. Presumably every
/// returned pointer must stay valid for as long as `self` is borrowed, and the pointed-to
/// values must match, in order and type, what the launched kernel expects — confirm and
/// document the precise requirements.
pub unsafe trait KernelArgs {
/// Returns one untyped pointer per argument held by `self`.
fn as_param_ptrs(&self) -> Vec<*mut c_void>;
}
/// The unit type stands for "no kernel arguments".
// SAFETY: no pointers are produced, so there is nothing that could dangle.
unsafe impl KernelArgs for () {
    /// Always yields an empty pointer list.
    #[inline]
    fn as_param_ptrs(&self) -> Vec<*mut c_void> {
        vec![]
    }
}
/// Generates an `unsafe impl KernelArgs` for one tuple arity.
///
/// Invoke with a comma-separated list of `index: TypeParam` pairs, for example
/// `impl_kernel_args_tuple!(0: A, 1: B);` for two-element tuples. Every type
/// parameter is bounded by `Copy`. The generated `as_param_ptrs` returns the
/// address of each tuple field, in field order, as a `*mut c_void`.
macro_rules! impl_kernel_args_tuple {
    ($($idx:tt: $T:ident),+) => {
        // SAFETY: each pointer is the address of a field of `self`, so it is valid
        // for as long as the `&self` borrow that produced it.
        unsafe impl<$($T: Copy),+> KernelArgs for ($($T,)+) {
            #[inline]
            fn as_param_ptrs(&self) -> Vec<*mut c_void> {
                // The pointers are derived from a shared borrow; they are typed
                // `*mut` only because that is the element type of the return value.
                Vec::from([
                    $((&self.$idx as *const $T).cast::<c_void>() as *mut c_void,)+
                ])
            }
        }
    };
}
// Implement `KernelArgs` for tuples of every arity from 1 through 28. Each line adds
// one more `index: TypeParam` pair; the type-parameter names run A..Z, then AA and BB.
impl_kernel_args_tuple!(0: A);
impl_kernel_args_tuple!(0: A, 1: B);
impl_kernel_args_tuple!(0: A, 1: B, 2: C);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y, 25: Z);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y, 25: Z, 26: AA);
impl_kernel_args_tuple!(0: A, 1: B, 2: C, 3: D, 4: E, 5: F, 6: G, 7: H, 8: I, 9: J, 10: K, 11: L, 12: M, 13: N, 14: O, 15: P, 16: Q, 17: R, 18: S, 19: T, 20: U, 21: V, 22: W, 23: X, 24: Y, 25: Z, 26: AA, 27: BB);
/// A named kernel entry point looked up from a loaded [`Module`].
///
/// Built with [`Kernel::from_module`] and launched with [`Kernel::launch`].
///
/// Fields are dropped in declaration order, so `function` is dropped before the
/// `Arc<Module>` held in `_module`.
pub struct Kernel {
/// Function handle returned by `Module::get_function`.
function: Function,
/// Shared handle to the module the function was looked up in. It is never read in
/// this file; presumably it is held only so the module outlives `function` —
/// TODO confirm.
_module: Arc<Module>,
/// The name the function was looked up under.
name: String,
}
impl Kernel {
pub fn from_module(module: Arc<Module>, name: &str) -> CudaResult<Self> {
let function = module.get_function(name)?;
Ok(Self {
function,
_module: module,
name: name.to_owned(),
})
}
pub fn launch<A: KernelArgs>(
&self,
params: &LaunchParams,
stream: &Stream,
args: &A,
) -> CudaResult<()> {
let _span = KernelSpanGuard::enter(
&self.name,
(params.grid.x, params.grid.y, params.grid.z),
(params.block.x, params.block.y, params.block.z),
);
let driver = try_driver()?;
let mut param_ptrs = args.as_param_ptrs();
oxicuda_driver::error::check(unsafe {
(driver.cu_launch_kernel)(
self.function.raw(),
params.grid.x,
params.grid.y,
params.grid.z,
params.block.x,
params.block.y,
params.block.z,
params.shared_mem_bytes,
stream.raw(),
param_ptrs.as_mut_ptr(),
std::ptr::null_mut(),
)
})
}
#[inline]
pub fn name(&self) -> &str {
&self.name
}
#[inline]
pub fn function(&self) -> &Function {
&self.function
}
pub fn max_active_blocks_per_sm(
&self,
block_size: i32,
dynamic_smem: usize,
) -> CudaResult<i32> {
self.function
.max_active_blocks_per_sm(block_size, dynamic_smem)
}
pub fn optimal_block_size(&self, dynamic_smem: usize) -> CudaResult<(i32, i32)> {
self.function.optimal_block_size(dynamic_smem)
}
}
/// Debug output lists `name` and `function`. The module handle is left out, and the
/// omission is signalled by the trailing `..` that `finish_non_exhaustive` emits.
impl std::fmt::Debug for Kernel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Kernel");
        out.field("name", &self.name);
        out.field("function", &self.function);
        out.finish_non_exhaustive()
    }
}
/// Displays as `Kernel(<name>)`.
impl std::fmt::Display for Kernel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = &self.name;
        f.write_fmt(format_args!("Kernel({})", label))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `()` carries no arguments, so no pointers are produced.
    #[test]
    fn unit_args_empty() {
        assert!(().as_param_ptrs().is_empty());
    }

    /// A one-element tuple yields one pointer that reads back the stored value.
    #[test]
    fn single_arg_ptr_valid() {
        let args = (42u32,);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 1);
        // SAFETY: `ptrs[0]` is the address of `args.0`, a live `u32`.
        let stored = unsafe { ptrs[0].cast::<u32>().read() };
        assert_eq!(stored, 42u32);
    }

    /// Pointers for a two-element tuple come back in field order.
    #[test]
    fn two_args_ptr_valid() {
        let args = (10u32, 20u64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 2);
        // SAFETY: each pointer is the address of the matching field of `args`.
        let (first, second) =
            unsafe { (ptrs[0].cast::<u32>().read(), ptrs[1].cast::<u64>().read()) };
        assert_eq!(first, 10u32);
        assert_eq!(second, 20u64);
    }

    /// Mixed integer and float fields are each reachable through their pointer.
    #[test]
    fn four_args_ptr_valid() {
        let args = (1u32, 2u64, 3.0f32, 4.0f64);
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 4);
        // SAFETY: each pointer is the address of the matching field of `args`.
        let (a, b, c, d) = unsafe {
            (
                ptrs[0].cast::<u32>().read(),
                ptrs[1].cast::<u64>().read(),
                ptrs[2].cast::<f32>().read(),
                ptrs[3].cast::<f64>().read(),
            )
        };
        assert_eq!(a, 1u32);
        assert_eq!(b, 2u64);
        assert!((c - 3.0f32).abs() < f32::EPSILON);
        assert!((d - 4.0f64).abs() < f64::EPSILON);
    }

    /// A twelve-element tuple yields twelve pointers holding 1..=12 in order.
    #[test]
    fn twelve_args_count() {
        let args = (
            1u32, 2u32, 3u32, 4u32, 5u32, 6u32, 7u32, 8u32, 9u32, 10u32, 11u32, 12u32,
        );
        let ptrs = args.as_param_ptrs();
        assert_eq!(ptrs.len(), 12);
        for (ptr, expected) in ptrs.iter().zip(1u32..) {
            // SAFETY: every pointer is the address of a `u32` field of `args`.
            let val = unsafe { ptr.cast::<u32>().read() };
            assert_eq!(val, expected);
        }
    }

    /// 1 Mi elements at 256 threads per block need exactly 4096 blocks.
    #[test]
    fn launch_params_grid_calculation_e2e() {
        const ELEMENTS: u32 = 1_048_576;
        const BLOCK: u32 = 256;
        let grid = crate::grid::grid_size_for(ELEMENTS, BLOCK);
        assert_eq!(
            grid, 4096,
            "grid_size_for(1M, 256) must be 4096, got {grid}"
        );
        let remainder = ELEMENTS % BLOCK;
        assert_eq!(remainder, 0, "n must be exactly divisible by block_size");
    }

    /// `LaunchParams::new` keeps the grid and block it was given.
    #[test]
    fn launch_params_stores_grid_and_block() {
        let params = LaunchParams::new(4096u32, 256u32);
        let (grid_x, block_x) = (params.grid.x, params.block.x);
        assert_eq!(grid_x, 4096, "grid.x must be 4096, got {}", grid_x);
        assert_eq!(block_x, 256, "block.x must be 256, got {}", block_x);
        assert_eq!(params.shared_mem_bytes, 0);
        assert_eq!(params.total_threads(), 1_048_576);
    }

    /// Two chained `add` calls register two arguments and build two pointers.
    #[test]
    fn named_args_builder_chain() {
        use crate::named_args::ArgBuilder;

        let a: u32 = 1;
        let b: f32 = 2.0;
        let mut builder = ArgBuilder::new();
        builder.add("a", &a).add("b", &b);

        let registered = builder.arg_count();
        assert_eq!(registered, 2, "ArgBuilder with 2 pushes must have length 2");

        let ptrs = builder.build();
        assert_eq!(ptrs.len(), 2, "build() must return 2 pointers");
    }
}