use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::binary::compute_broadcast_strides;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
kernel_names, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::runtime::cuda::{CudaDevice, CudaRuntime};
use crate::tensor::Tensor;
/// Launches the elementwise `where` kernel over `numel` contiguous elements:
/// for each index `i`, the output selects between `x` and `y` based on the
/// condition buffer. All buffers share the element layout the
/// `where_<dtype>` kernel expects.
///
/// Returns `Ok(())` immediately when `numel == 0` (nothing to launch).
///
/// # Errors
///
/// Returns [`Error::Internal`] if the ternary module cannot be loaded, the
/// kernel function cannot be found, or the launch itself fails.
///
/// # Safety
///
/// `cond_ptr`, `x_ptr`, `y_ptr`, and `out_ptr` must be valid device pointers
/// on the device identified by `device_index`, each backing at least `numel`
/// elements, and must remain valid until the launched kernel has completed
/// on `stream`.
pub unsafe fn launch_where_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    cond_ptr: u64,
    x_ptr: u64,
    y_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // Empty launch is a no-op; this also keeps the contiguous path consistent
    // with the broadcast variants, which already short-circuit on zero elements.
    if numel == 0 {
        return Ok(());
    }
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::TERNARY_MODULE)?;
        let func_name = kernel_name("where", dtype);
        let func = get_kernel_function(&module, &func_name)?;
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let n = numel as u32;
        let cfg = launch_config(grid, block, 0);
        // Argument order must match the kernel's parameter list exactly.
        let mut builder = stream.launch_builder(&func);
        builder.arg(&cond_ptr);
        builder.arg(&x_ptr);
        builder.arg(&y_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA where kernel '{}' launch failed: {:?}",
                func_name, e
            ))
        })?;
        Ok(())
    }
}
/// Launches the broadcasting `where` kernel over `out_shape`: each output
/// element selects between `x` and `y` based on `cond`, with every input
/// addressed through broadcast strides computed against the output shape.
///
/// Returns `Ok(())` immediately when `out_shape` contains zero elements.
///
/// # Errors
///
/// Returns [`Error::Internal`] if module/function lookup or the kernel
/// launch fails.
///
/// # Safety
///
/// The four raw pointers must be valid device allocations on the device
/// identified by `device_index`, each large enough for its corresponding
/// shape, and must remain valid until the launched kernel has completed on
/// `stream`.
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_where_broadcast_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    device: &CudaDevice,
    dtype: DType,
    cond_ptr: u64,
    x_ptr: u64,
    y_ptr: u64,
    out_ptr: u64,
    cond_shape: &[usize],
    x_shape: &[usize],
    y_shape: &[usize],
    out_shape: &[usize],
) -> Result<()> {
    let total: usize = out_shape.iter().product();
    if total == 0 {
        return Ok(());
    }
    let rank = out_shape.len();

    // Broadcast strides map an output coordinate back into each (possibly
    // smaller) input; the kernel reads them from device memory.
    let strides_c = compute_broadcast_strides(cond_shape, out_shape);
    let strides_x = compute_broadcast_strides(x_shape, out_shape);
    let strides_y = compute_broadcast_strides(y_shape, out_shape);
    let dims: Vec<u32> = out_shape.iter().map(|&d| d as u32).collect();

    // These tensors own the device-side copies of the stride/shape arrays;
    // they are bound to locals so they stay alive across the launch below.
    let strides_c_dev = Tensor::<CudaRuntime>::from_slice(&strides_c, &[rank], device);
    let strides_x_dev = Tensor::<CudaRuntime>::from_slice(&strides_x, &[rank], device);
    let strides_y_dev = Tensor::<CudaRuntime>::from_slice(&strides_y, &[rank], device);
    let dims_dev = Tensor::<CudaRuntime>::from_slice(&dims, &[rank], device);
    let strides_c_ptr = strides_c_dev.ptr();
    let strides_x_ptr = strides_x_dev.ptr();
    let strides_y_ptr = strides_y_dev.ptr();
    let dims_ptr = dims_dev.ptr();

    let module = get_or_load_module(context, device_index, kernel_names::TERNARY_MODULE)?;
    let func_name = format!("where_broadcast_{}", super::loader::dtype_suffix(dtype));
    let func = get_kernel_function(&module, &func_name)?;

    let n = total as u32;
    let rank_u32 = rank as u32;
    let cfg = launch_config(elementwise_launch_config(total), (BLOCK_SIZE, 1, 1), 0);

    unsafe {
        // Argument order must match the kernel's parameter list exactly.
        let mut builder = stream.launch_builder(&func);
        builder.arg(&cond_ptr);
        builder.arg(&x_ptr);
        builder.arg(&y_ptr);
        builder.arg(&out_ptr);
        builder.arg(&strides_c_ptr);
        builder.arg(&strides_x_ptr);
        builder.arg(&strides_y_ptr);
        builder.arg(&dims_ptr);
        builder.arg(&rank_u32);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA where_broadcast kernel '{}' launch failed: {:?}",
                func_name, e
            ))
        })?;
    }
    Ok(())
}
/// Launches the mixed-dtype `where` kernel `where_cond_<cond>_<out>` over
/// `numel` contiguous elements: the condition buffer has dtype `cond_dtype`
/// while `x`, `y`, and the output share dtype `dtype`.
///
/// Returns `Ok(())` immediately when `numel == 0` (nothing to launch).
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when no kernel variant exists for the
/// requested `(cond_dtype, dtype)` pair, and [`Error::Internal`] when the
/// launch itself fails.
///
/// # Safety
///
/// `cond_ptr`, `x_ptr`, `y_ptr`, and `out_ptr` must be valid device pointers
/// on the device identified by `device_index`, each backing at least `numel`
/// elements, and must remain valid until the launched kernel has completed
/// on `stream`.
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_where_generic_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    cond_dtype: DType,
    dtype: DType,
    cond_ptr: u64,
    x_ptr: u64,
    y_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // Empty launch is a no-op; keeps this path consistent with the broadcast
    // variants, which already short-circuit on zero elements.
    if numel == 0 {
        return Ok(());
    }
    let cond_suffix = super::loader::dtype_suffix(cond_dtype);
    let out_suffix = super::loader::dtype_suffix(dtype);
    let func_name = format!("where_cond_{}_{}", cond_suffix, out_suffix);
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::TERNARY_MODULE)?;
        // A missing symbol means this (cond, out) dtype combination was never
        // compiled into the ternary module; surface that as UnsupportedDType.
        let func =
            get_kernel_function(&module, &func_name).map_err(|_| Error::UnsupportedDType {
                dtype: cond_dtype,
                op: "where_cond (condition dtype)",
            })?;
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let n = numel as u32;
        let cfg = launch_config(grid, block, 0);
        // Argument order must match the kernel's parameter list exactly.
        let mut builder = stream.launch_builder(&func);
        builder.arg(&cond_ptr);
        builder.arg(&x_ptr);
        builder.arg(&y_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA where_cond kernel '{}' launch failed: {:?}",
                func_name, e
            ))
        })?;
        Ok(())
    }
}
/// Launches the broadcasting mixed-dtype `where` kernel
/// `where_broadcast_cond_<cond>_<out>`: each output element (over
/// `out_shape`) selects between `x` and `y` based on `cond`, with every
/// input addressed through broadcast strides computed against the output
/// shape. The condition buffer has dtype `cond_dtype`; `x`, `y`, and the
/// output share dtype `dtype`.
///
/// Returns `Ok(())` immediately when `out_shape` contains zero elements.
///
/// # Errors
///
/// Returns [`Error::UnsupportedDType`] when no kernel variant exists for the
/// requested `(cond_dtype, dtype)` pair, and [`Error::Internal`] when the
/// launch itself fails.
///
/// # Safety
///
/// The four raw pointers must be valid device allocations on the device
/// identified by `device_index`, each large enough for its corresponding
/// shape, and must remain valid until the launched kernel has completed on
/// `stream`.
#[allow(clippy::too_many_arguments)]
pub unsafe fn launch_where_broadcast_generic_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
device: &CudaDevice,
cond_dtype: DType,
dtype: DType,
cond_ptr: u64,
x_ptr: u64,
y_ptr: u64,
out_ptr: u64,
cond_shape: &[usize],
x_shape: &[usize],
y_shape: &[usize],
out_shape: &[usize],
) -> Result<()> {
let numel: usize = out_shape.iter().product();
// Nothing to launch for an empty output.
if numel == 0 {
return Ok(());
}
let ndim = out_shape.len();
// Broadcast strides map an output coordinate back into each (possibly
// smaller) input; the kernel reads them from device memory.
let cond_strides = compute_broadcast_strides(cond_shape, out_shape);
let x_strides = compute_broadcast_strides(x_shape, out_shape);
let y_strides = compute_broadcast_strides(y_shape, out_shape);
let shape_u32: Vec<u32> = out_shape.iter().map(|&x| x as u32).collect();
// Device-side copies of the stride/shape arrays. They are bound to locals
// so they outlive the launch call below.
// NOTE(review): the launch is asynchronous, so these tensors may be dropped
// before the kernel finishes — presumably Tensor's deallocation is
// stream-ordered; confirm against the CudaRuntime allocator.
let cond_strides_tensor = Tensor::<CudaRuntime>::from_slice(&cond_strides, &[ndim], device);
let x_strides_tensor = Tensor::<CudaRuntime>::from_slice(&x_strides, &[ndim], device);
let y_strides_tensor = Tensor::<CudaRuntime>::from_slice(&y_strides, &[ndim], device);
let shape_tensor = Tensor::<CudaRuntime>::from_slice(&shape_u32, &[ndim], device);
let cond_strides_ptr = cond_strides_tensor.ptr();
let x_strides_ptr = x_strides_tensor.ptr();
let y_strides_ptr = y_strides_tensor.ptr();
let shape_ptr = shape_tensor.ptr();
let cond_suffix = super::loader::dtype_suffix(cond_dtype);
let out_suffix = super::loader::dtype_suffix(dtype);
let func_name = format!("where_broadcast_cond_{}_{}", cond_suffix, out_suffix);
let module = get_or_load_module(context, device_index, kernel_names::TERNARY_MODULE)?;
// A missing symbol means this (cond, out) dtype combination was never
// compiled into the ternary module; surface that as UnsupportedDType.
let func = get_kernel_function(&module, &func_name).map_err(|_| Error::UnsupportedDType {
dtype: cond_dtype,
op: "where_cond broadcast (condition dtype)",
})?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let ndim_u32 = ndim as u32;
let cfg = launch_config(grid, block, 0);
unsafe {
// Argument order must match the kernel's parameter list exactly.
let mut builder = stream.launch_builder(&func);
builder.arg(&cond_ptr);
builder.arg(&x_ptr);
builder.arg(&y_ptr);
builder.arg(&out_ptr);
builder.arg(&cond_strides_ptr);
builder.arg(&x_strides_ptr);
builder.arg(&y_strides_ptr);
builder.arg(&shape_ptr);
builder.arg(&ndim_u32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA where_broadcast_cond kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
}
Ok(())
}