use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
kernel_names, launch_binary_kernel, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::cuda::CudaDevice;
/// Launches the element-wise binary kernel `op` for `dtype` from the binary kernel module.
///
/// This is a thin wrapper around `launch_binary_kernel` that pins the module to
/// `kernel_names::BINARY_MODULE`; every other argument is forwarded unchanged.
///
/// # Safety
///
/// `a_ptr`, `b_ptr` and `out_ptr` are raw device addresses handed to the kernel unchecked.
/// The caller must uphold whatever contract `launch_binary_kernel` places on them
/// (presumably live allocations on `device_index` holding at least `numel` elements of
/// `dtype` — confirm against the loader).
///
/// # Errors
///
/// Propagates any error returned by `launch_binary_kernel`.
pub unsafe fn launch_binary_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module_name = kernel_names::BINARY_MODULE;
    // SAFETY: the pointer/length contract is this function's own `# Safety` contract,
    // forwarded verbatim to the loader.
    unsafe {
        launch_binary_kernel(
            context,
            stream,
            device_index,
            module_name,
            op,
            dtype,
            a_ptr,
            b_ptr,
            out_ptr,
            numel,
        )
    }
}
pub unsafe fn launch_logical_and_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
a_ptr: u64,
b_ptr: u64,
out_ptr: u64,
numel: usize,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
let func_name = "logical_and_u8";
let func = get_kernel_function(&module, func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&a_ptr);
builder.arg(&b_ptr);
builder.arg(&out_ptr);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!("CUDA logical_and kernel launch failed: {:?}", e))
})?;
Ok(())
}
}
pub unsafe fn launch_logical_or_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
a_ptr: u64,
b_ptr: u64,
out_ptr: u64,
numel: usize,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
let func_name = "logical_or_u8";
let func = get_kernel_function(&module, func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&a_ptr);
builder.arg(&b_ptr);
builder.arg(&out_ptr);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!("CUDA logical_or kernel launch failed: {:?}", e))
})?;
Ok(())
}
}
pub unsafe fn launch_logical_xor_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
a_ptr: u64,
b_ptr: u64,
out_ptr: u64,
numel: usize,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
let func_name = "logical_xor_u8";
let func = get_kernel_function(&module, func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&a_ptr);
builder.arg(&b_ptr);
builder.arg(&out_ptr);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!("CUDA logical_xor kernel launch failed: {:?}", e))
})?;
Ok(())
}
}
/// Per-axis element strides for reading an `input_shape` tensor while walking `output_shape`.
///
/// The input is treated as a row-major contiguous buffer, and axes are right-aligned. Two
/// kinds of output axis get stride `0`, so the same input element is reused along them:
/// leading output axes the input does not have, and input axes of extent 1. The result has
/// exactly one entry per output axis.
///
/// Non-1 input extents are not checked against `output_shape`. `input_shape` must not have
/// more axes than `output_shape`, because the rank difference uses plain subtraction.
/// Strides are narrowed to `u32` with `as`, so they truncate for inputs whose strides exceed
/// `u32::MAX`.
pub fn compute_broadcast_strides(input_shape: &[usize], output_shape: &[usize]) -> Vec<u32> {
    let in_rank = input_shape.len();
    let out_rank = output_shape.len();

    // Contiguous row-major strides of the input: the innermost axis is 1, and each outer axis
    // is the next-inner stride times the next-inner extent.
    let mut contiguous = vec![1usize; in_rank];
    for axis in (1..in_rank).rev() {
        contiguous[axis - 1] = contiguous[axis] * input_shape[axis];
    }

    // Number of leading output axes with no counterpart in the input.
    let lead = out_rank - in_rank;

    (0..out_rank)
        .map(|axis| match axis.checked_sub(lead) {
            Some(src) if input_shape[src] != 1 => contiguous[src] as u32,
            _ => 0,
        })
        .collect()
}
/// Maximum output rank accepted by `launch_broadcast_binary_op`.
///
/// That function pushes exactly this many per-dimension stride/shape/magic/shift values
/// per array to the general broadcast kernel. The kernel's parameter list therefore
/// presumably has to change in lockstep with this constant — confirm against the kernel
/// source.
pub const MAX_BROADCAST_DIMS: usize = 8;
/// Precomputes a `(magic, shift)` pair for dividing 32-bit values by the constant `d`.
///
/// - `d` is 0 or 1: returns `(0, 0)`.
/// - `d` is a power of two: returns `(0, log2(d))`, so a plain right shift suffices.
/// - Otherwise, with `p = floor(log2(d))`, returns `(ceil(2^(32 + p) / d), p)`. The magic
///   value always fits in 32 bits for such `d`. `(x * magic) >> 32 >> p` then equals
///   `x / d`, presumably for every index the kernels feed it. The exactness bound depends on
///   `d` and is at least `2^31` — confirm against the kernel's usage.
pub fn compute_magic_divisor(d: u32) -> (u32, u32) {
    match d {
        0 | 1 => (0, 0),
        _ if d.is_power_of_two() => (0, d.trailing_zeros()),
        _ => {
            let shift = d.ilog2();
            // Round-up reciprocal; 32 + shift <= 63, so the power of two fits in u64.
            let magic = (1u64 << (32 + shift)).div_ceil(u64::from(d));
            debug_assert!(magic <= u64::from(u32::MAX), "magic overflow for d={d}");
            (magic as u32, shift)
        }
    }
}
/// Detects whether a broadcast binary op can use the "fast trailing" kernel path.
///
/// Returns `Some(b_numel)` when both of these hold:
/// * `a_shape == out_shape`, so `a` is not broadcast;
/// * after dropping `b`'s leading size-1 axes, the remaining axes of `b` equal the trailing
///   axes of `out_shape` when right-aligned.
///
/// `b_numel` is the element count of that remaining block; it is `1` when `b` is a scalar or
/// all ones. `b` then repeats every `b_numel` output elements. Presumably the fast kernel
/// indexes it as `out_index % b_numel` via the divisor from `compute_magic_divisor` —
/// confirm against the kernel source.
///
/// Returns `None` otherwise. That includes the case where `b` has more axes than
/// `out_shape`, which cannot be right-aligned, and the case where the block is empty.
pub fn detect_fast_trailing_broadcast(
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> Option<usize> {
    if a_shape != out_shape {
        return None;
    }
    // Right alignment needs b's rank <= out's rank. Otherwise there is no valid offset, and
    // indexing `out_shape` past its end would panic.
    let offset = out_shape.len().checked_sub(b_shape.len())?;

    // Skip b's leading broadcast (size-1) axes. What remains must match the output exactly.
    let b_start = b_shape
        .iter()
        .position(|&dim| dim != 1)
        .unwrap_or(b_shape.len());
    let b_tail = &b_shape[b_start..];
    let out_tail = &out_shape[offset + b_start..];
    if b_tail != out_tail {
        return None;
    }

    let b_numel: usize = b_tail.iter().product();
    (b_numel != 0).then_some(b_numel)
}
/// Launches a broadcasting element-wise binary kernel `op` for `dtype`.
///
/// Both candidate kernels come from `kernel_names::BINARY_MODULE`:
///
/// * `{op}_broadcast_fast_trailing_{dtype}` is tried first when
///   `detect_fast_trailing_broadcast` accepts the shapes. `b` is addressed through a
///   magic-number divisor of its trailing block size. If this kernel cannot be found in the
///   module, the general path below is used instead.
/// * `{op}_broadcast_{dtype}_inline` is the general path. It receives per-dimension
///   strides, output extents and magic divisors for up to `MAX_BROADCAST_DIMS` axes, each
///   passed as an individual scalar argument, followed by the real rank and element count.
///
/// An output with zero elements launches nothing. `_device` is currently unused.
///
/// # Safety
///
/// `a_ptr`, `b_ptr` and `out_ptr` are raw device addresses passed to the kernel unchecked.
/// Presumably each must reference a live allocation on `device_index` holding a row-major
/// contiguous tensor of `dtype` shaped `a_shape`, `b_shape` and `out_shape` respectively —
/// confirm against the kernel source.
///
/// # Errors
///
/// Returns `Error::Internal` when any of these holds:
/// * `out_shape` has more than `MAX_BROADCAST_DIMS` axes;
/// * `a_shape` or `b_shape` has more axes than `out_shape`;
/// * the output element count does not fit the kernels' `u32` length parameter;
/// * a kernel launch fails.
///
/// Module/kernel lookup errors are propagated.
// The parameters map one-to-one onto what the launch needs; bundling them would only move the noise.
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_broadcast_binary_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    _device: &CudaDevice,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> Result<()> {
    let numel: usize = out_shape.iter().product();
    if numel == 0 {
        return Ok(());
    }
    let ndim = out_shape.len();
    if ndim > MAX_BROADCAST_DIMS {
        return Err(Error::Internal(format!(
            "launch_broadcast_binary_op: ndim={ndim} exceeds MAX_BROADCAST_DIMS={MAX_BROADCAST_DIMS}"
        )));
    }
    // Right-aligned broadcasting never gives an input more axes than the output. Such
    // shapes would make compute_broadcast_strides underflow its rank difference.
    if a_shape.len() > ndim || b_shape.len() > ndim {
        return Err(Error::Internal(format!(
            "launch_broadcast_binary_op: input ranks a={} b={} exceed output rank {ndim}",
            a_shape.len(),
            b_shape.len()
        )));
    }
    // The kernels take their length as u32. A truncated value would disagree with the grid
    // (sized from the full `numel`) and leave part of the output unprocessed.
    let n = u32::try_from(numel).map_err(|_| {
        Error::Internal(format!(
            "launch_broadcast_binary_op: numel={numel} exceeds the kernel's u32 length parameter"
        ))
    })?;

    let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
    let dtype_str = kernel_name("", dtype).trim_start_matches('_').to_owned();
    let grid = elementwise_launch_config(numel);
    let block = (BLOCK_SIZE, 1, 1);
    let cfg = launch_config(grid, block, 0);

    if let Some(b_numel) = detect_fast_trailing_broadcast(a_shape, b_shape, out_shape) {
        let func_name = format!("{}_broadcast_fast_trailing_{}", op, dtype_str);
        // Best-effort: if this op/dtype has no fast-trailing variant in the module, fall
        // through to the general kernel instead of failing.
        if let Ok(func) = get_kernel_function(&module, &func_name) {
            // b's block equals the output's trailing axes, and every output extent is
            // non-zero here, so b_numel <= numel, which already fits in u32.
            let b_numel_u32 = b_numel as u32;
            let (b_magic, b_shift) = compute_magic_divisor(b_numel_u32);

            let mut builder = stream.launch_builder(&func);
            builder.arg(&a_ptr);
            builder.arg(&b_ptr);
            builder.arg(&out_ptr);
            builder.arg(&b_magic);
            builder.arg(&b_shift);
            builder.arg(&b_numel_u32);
            builder.arg(&n);
            // SAFETY: the device pointers are forwarded under this function's `# Safety`
            // contract. Arguments are pushed in the order a, b, out, magic, shift,
            // b_numel, n.
            unsafe { builder.launch(cfg) }.map_err(|e| {
                Error::Internal(format!(
                    "CUDA broadcast fast-trailing kernel '{}' launch failed: {:?}",
                    func_name, e
                ))
            })?;
            return Ok(());
        }
    }

    // General path: pack per-axis parameters into fixed-size arrays. Slots at or beyond
    // `ndim` stay zero; the kernel receives the real rank separately.
    let a_strides_vec = compute_broadcast_strides(a_shape, out_shape);
    let b_strides_vec = compute_broadcast_strides(b_shape, out_shape);
    let mut a_strides = [0u32; MAX_BROADCAST_DIMS];
    let mut b_strides = [0u32; MAX_BROADCAST_DIMS];
    let mut shape = [0u32; MAX_BROADCAST_DIMS];
    let mut magic = [0u32; MAX_BROADCAST_DIMS];
    let mut pshift = [0u32; MAX_BROADCAST_DIMS];
    a_strides[..ndim].copy_from_slice(&a_strides_vec);
    b_strides[..ndim].copy_from_slice(&b_strides_vec);
    for (axis, &extent) in out_shape.iter().enumerate() {
        // All extents are non-zero here, so each is <= numel, which fits in u32.
        let extent = extent as u32;
        shape[axis] = extent;
        (magic[axis], pshift[axis]) = compute_magic_divisor(extent);
    }

    let func_name = format!("{}_broadcast_{}_inline", op, dtype_str);
    let func = get_kernel_function(&module, &func_name)?;
    let ndim_u32 = ndim as u32; // ndim <= MAX_BROADCAST_DIMS, checked above

    let mut builder = stream.launch_builder(&func);
    builder.arg(&a_ptr);
    builder.arg(&b_ptr);
    builder.arg(&out_ptr);
    // MAX_BROADCAST_DIMS scalars per array, grouped as a_strides, b_strides, shape, magic, pshift.
    for per_axis in [&a_strides, &b_strides, &shape, &magic, &pshift] {
        for value in per_axis {
            builder.arg(value);
        }
    }
    builder.arg(&ndim_u32);
    builder.arg(&n);
    // SAFETY: the device pointers are forwarded under this function's `# Safety` contract.
    // All pushed argument references outlive the launch.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA broadcast binary kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}