use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
kernel_names, launch_binary_kernel, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::cuda::{CudaDevice, CudaRuntime};
use crate::tensor::Tensor;
/// Dispatches the element-wise binary kernel selected by `op`/`dtype`,
/// delegating to [`launch_binary_kernel`] with the shared binary module.
///
/// Returns whatever error `launch_binary_kernel` produces on failure.
///
/// # Safety
/// Same contract as [`launch_binary_kernel`]: `a_ptr`, `b_ptr`, and
/// `out_ptr` must be valid device pointers for `numel` elements of
/// `dtype` on the device owning `stream`, and must stay alive until the
/// launched kernel has completed.
pub unsafe fn launch_binary_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // SAFETY: forwarded verbatim; the caller upholds this function's own
    // safety contract, which is identical to `launch_binary_kernel`'s.
    unsafe {
        launch_binary_kernel(
            context,
            stream,
            device_index,
            kernel_names::BINARY_MODULE,
            op,
            dtype,
            a_ptr,
            b_ptr,
            out_ptr,
            numel,
        )
    }
}
/// Launches the element-wise `logical_and_u8` kernel on `stream`.
///
/// `a_ptr`, `b_ptr`, and `out_ptr` are raw device addresses for buffers of
/// `numel` u8 elements each (the exact output encoding is defined by the
/// CUDA kernel itself). Returns `Error::Internal` if the launch fails.
///
/// # Safety
/// The caller must guarantee all three pointers are valid device
/// allocations of at least `numel` bytes on the device owning `stream`,
/// and that they remain alive until the kernel completes.
pub unsafe fn launch_logical_and_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // Nothing to do for empty tensors; this also avoids launching with a
    // degenerate grid, matching the guard in `launch_broadcast_binary_op`.
    if numel == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
        let func = get_kernel_function(&module, "logical_and_u8")?;
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let n = numel as u32;
        let cfg = launch_config(grid, block, 0);
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA logical_and kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}
/// Launches the element-wise `logical_or_u8` kernel on `stream`.
///
/// `a_ptr`, `b_ptr`, and `out_ptr` are raw device addresses for buffers of
/// `numel` u8 elements each (the exact output encoding is defined by the
/// CUDA kernel itself). Returns `Error::Internal` if the launch fails.
///
/// # Safety
/// The caller must guarantee all three pointers are valid device
/// allocations of at least `numel` bytes on the device owning `stream`,
/// and that they remain alive until the kernel completes.
pub unsafe fn launch_logical_or_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // Nothing to do for empty tensors; this also avoids launching with a
    // degenerate grid, matching the guard in `launch_broadcast_binary_op`.
    if numel == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
        let func = get_kernel_function(&module, "logical_or_u8")?;
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let n = numel as u32;
        let cfg = launch_config(grid, block, 0);
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA logical_or kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}
/// Launches the element-wise `logical_xor_u8` kernel on `stream`.
///
/// `a_ptr`, `b_ptr`, and `out_ptr` are raw device addresses for buffers of
/// `numel` u8 elements each (the exact output encoding is defined by the
/// CUDA kernel itself). Returns `Error::Internal` if the launch fails.
///
/// # Safety
/// The caller must guarantee all three pointers are valid device
/// allocations of at least `numel` bytes on the device owning `stream`,
/// and that they remain alive until the kernel completes.
pub unsafe fn launch_logical_xor_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // Nothing to do for empty tensors; this also avoids launching with a
    // degenerate grid, matching the guard in `launch_broadcast_binary_op`.
    if numel == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
        let func = get_kernel_function(&module, "logical_xor_u8")?;
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let n = numel as u32;
        let cfg = launch_config(grid, block, 0);
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!("CUDA logical_xor kernel launch failed: {:?}", e))
        })?;
        Ok(())
    }
}
/// Computes per-dimension element strides (as `u32`) for reading a tensor of
/// `input_shape` as if it were broadcast to `output_shape`, with dimensions
/// aligned to the right.
///
/// A dimension that is missing from the input, or has size 1, gets stride 0
/// so the same input element is reused along that output axis; all other
/// dimensions get the input's row-major (C-contiguous) stride.
///
/// # Panics
/// Panics if `input_shape` has more dimensions than `output_shape`.
pub fn compute_broadcast_strides(input_shape: &[usize], output_shape: &[usize]) -> Vec<u32> {
    let input_ndim = input_shape.len();
    let output_ndim = output_shape.len();
    // Guard the rank difference explicitly: previously `output_ndim -
    // input_ndim` could underflow (panicking in debug builds, wrapping in
    // release builds and silently zeroing every stride). Fail loudly instead.
    assert!(
        input_ndim <= output_ndim,
        "input rank {} exceeds output rank {} in broadcast",
        input_ndim,
        output_ndim
    );
    // Row-major strides of the (unbroadcast) input, in elements.
    let mut input_strides = vec![1usize; input_ndim];
    for i in (0..input_ndim.saturating_sub(1)).rev() {
        input_strides[i] = input_strides[i + 1] * input_shape[i + 1];
    }
    // Number of leading output dimensions the input does not have.
    let offset = output_ndim - input_ndim;
    (0..output_ndim)
        .map(|i| {
            if i < offset {
                // Dimension absent from the input: broadcast with stride 0.
                0
            } else {
                let input_idx = i - offset;
                if input_shape[input_idx] == 1 {
                    // Size-1 dimension: broadcast with stride 0.
                    0
                } else {
                    input_strides[input_idx] as u32
                }
            }
        })
        .collect()
}
/// Launches the broadcast variant of the binary kernel `op` for `dtype`.
///
/// Builds right-aligned broadcast strides for `a_shape` and `b_shape`
/// against `out_shape`, uploads strides and shape to the device as small
/// `u32` tensors, and launches `"{op}_broadcast_{dtype}"` from the binary
/// module. Returns early with `Ok(())` when the output is empty, and
/// `Error::Internal` if the kernel launch fails.
///
/// # Safety
/// The caller must guarantee `a_ptr`, `b_ptr`, and `out_ptr` are valid
/// device pointers for their respective shapes on the device owning
/// `stream`, alive until the kernel completes.
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_broadcast_binary_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    device: &CudaDevice,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    b_ptr: u64,
    out_ptr: u64,
    a_shape: &[usize],
    b_shape: &[usize],
    out_shape: &[usize],
) -> Result<()> {
    // Empty output: nothing to launch (also avoids a zero-sized grid).
    let numel: usize = out_shape.iter().product();
    if numel == 0 {
        return Ok(());
    }
    let ndim = out_shape.len();
    // Per-operand read strides; 0 marks a broadcast (reused) dimension.
    let a_strides = compute_broadcast_strides(a_shape, out_shape);
    let b_strides = compute_broadcast_strides(b_shape, out_shape);
    // Row-major strides of the (dense) output.
    let out_strides: Vec<u32> = {
        let mut s = vec![1u32; ndim];
        for i in (0..ndim.saturating_sub(1)).rev() {
            s[i] = s[i + 1] * out_shape[i + 1] as u32;
        }
        s
    };
    let shape_u32: Vec<u32> = out_shape.iter().map(|&x| x as u32).collect();
    // Upload strides/shape as device tensors so the kernel can index them.
    // NOTE(review): these temporaries are dropped when this function
    // returns, while the launch below is enqueued asynchronously — this is
    // only sound if Tensor's drop synchronizes with (or is ordered after)
    // work on `stream`. Confirm against Tensor/CudaRuntime semantics.
    let a_strides_tensor = Tensor::<CudaRuntime>::from_slice(&a_strides, &[ndim], device);
    let b_strides_tensor = Tensor::<CudaRuntime>::from_slice(&b_strides, &[ndim], device);
    let out_strides_tensor = Tensor::<CudaRuntime>::from_slice(&out_strides, &[ndim], device);
    let shape_tensor = Tensor::<CudaRuntime>::from_slice(&shape_u32, &[ndim], device);
    let a_strides_ptr = a_strides_tensor.ptr();
    let b_strides_ptr = b_strides_tensor.ptr();
    let out_strides_ptr = out_strides_tensor.ptr();
    let shape_ptr = shape_tensor.ptr();
    let module = get_or_load_module(context, device_index, kernel_names::BINARY_MODULE)?;
    // Kernel naming: "{op}_broadcast_{dtype-suffix}"; `kernel_name("", dtype)`
    // yields "_{suffix}", so strip the leading underscore.
    let func_name = format!(
        "{}_broadcast_{}",
        op,
        kernel_name("", dtype).trim_start_matches('_')
    );
    let func = get_kernel_function(&module, &func_name)?;
    let grid = elementwise_launch_config(numel);
    let block = (BLOCK_SIZE, 1, 1);
    let n = numel as u32;
    let ndim_u32 = ndim as u32;
    let cfg = launch_config(grid, block, 0);
    unsafe {
        // Argument order must match the kernel's parameter list exactly.
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&b_ptr);
        builder.arg(&out_ptr);
        builder.arg(&a_strides_ptr);
        builder.arg(&b_strides_ptr);
        builder.arg(&out_strides_ptr);
        builder.arg(&shape_ptr);
        builder.arg(&ndim_u32);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA broadcast binary kernel '{}' launch failed: {:?}",
                func_name, e
            ))
        })?;
    }
    Ok(())
}