use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
kernel_names, launch_config, launch_unary_kernel,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// Launches the elementwise unary kernel `op` for `dtype` from the unary module.
///
/// This is a thin wrapper over `launch_unary_kernel` that pins the module to
/// `kernel_names::UNARY_MODULE`. Kernel-name resolution, launch geometry and error
/// reporting are all decided by that loader helper.
///
/// # Safety
///
/// `a_ptr` and `out_ptr` are raw device addresses that are forwarded unchecked.
/// The caller must uphold whatever contract `launch_unary_kernel` requires.
/// Presumably both must point to live allocations on device `device_index` holding at
/// least `numel` elements of `dtype` (TODO confirm against the loader). They must stay
/// valid until the work queued on `stream` has completed.
///
/// # Errors
///
/// Propagates any error returned by `launch_unary_kernel`.
pub unsafe fn launch_unary_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    op: &str,
    dtype: DType,
    a_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module_name = kernel_names::UNARY_MODULE;
    // SAFETY: pointer validity is the caller's responsibility (see `# Safety`);
    // this function only forwards the arguments unchanged.
    unsafe {
        launch_unary_kernel(
            context, stream, device_index, module_name, op, dtype, a_ptr, out_ptr, numel,
        )
    }
}
/// Launches the `isnan` kernel for `input_dtype` from the unary module.
///
/// The kernel is resolved with `kernel_name("isnan", input_dtype)`. It is launched with
/// `BLOCK_SIZE`-wide 1-D blocks, a grid from `elementwise_launch_config(numel)`, and no
/// dynamic shared memory. It receives the arguments `(a_ptr, out_ptr, n)`, where `n` is
/// the element count as `u32`.
///
/// # Safety
///
/// `a_ptr` and `out_ptr` are raw device addresses passed to the kernel unchecked.
/// - `a_ptr` must point to at least `numel` elements of `input_dtype` on device
///   `device_index`.
/// - `out_ptr` must point to enough space for `numel` outputs of the type the kernel
///   writes. That type is not visible here; presumably `u8`/bool, so confirm against
///   the kernel source.
/// - Both must remain valid until the launch queued on `stream` has completed.
///
/// # Errors
///
/// - `Error::Internal` if `numel` does not fit in the kernel's `u32` element count.
/// - Any error from loading the unary module or looking up the kernel function.
/// - `Error::Internal` if the kernel launch fails.
pub unsafe fn launch_isnan_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    input_dtype: DType,
    a_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // The kernel takes its element count as `u32`. A plain `as` cast would silently
    // truncate, so the kernel would process only `numel mod 2^32` elements while the
    // launch still reported success.
    let n = u32::try_from(numel).map_err(|_| {
        Error::Internal(format!(
            "CUDA isnan kernel: numel {} exceeds u32::MAX",
            numel
        ))
    })?;
    // SAFETY: pointer validity is the caller's responsibility (see `# Safety`).
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::UNARY_MODULE)?;
        let func_name = kernel_name("isnan", input_dtype);
        let func = get_kernel_function(&module, &func_name)?;
        // Nothing to compute. The lookup above still runs, so an unsupported dtype is
        // reported even for empty input. Returning here avoids relying on how
        // `elementwise_launch_config` handles zero elements: CUDA rejects a zero-sized
        // grid.
        if n == 0 {
            return Ok(());
        }
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let cfg = launch_config(grid, block, 0);
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA isnan kernel '{}' launch failed: {:?}",
                func_name, e
            ))
        })?;
        Ok(())
    }
}
/// Launches the `isinf` kernel for `input_dtype` from the unary module.
///
/// The kernel is resolved with `kernel_name("isinf", input_dtype)`. It is launched with
/// `BLOCK_SIZE`-wide 1-D blocks, a grid from `elementwise_launch_config(numel)`, and no
/// dynamic shared memory. It receives the arguments `(a_ptr, out_ptr, n)`, where `n` is
/// the element count as `u32`.
///
/// # Safety
///
/// `a_ptr` and `out_ptr` are raw device addresses passed to the kernel unchecked.
/// - `a_ptr` must point to at least `numel` elements of `input_dtype` on device
///   `device_index`.
/// - `out_ptr` must point to enough space for `numel` outputs of the type the kernel
///   writes. That type is not visible here; presumably `u8`/bool, so confirm against
///   the kernel source.
/// - Both must remain valid until the launch queued on `stream` has completed.
///
/// # Errors
///
/// - `Error::Internal` if `numel` does not fit in the kernel's `u32` element count.
/// - Any error from loading the unary module or looking up the kernel function.
/// - `Error::Internal` if the kernel launch fails.
pub unsafe fn launch_isinf_op(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    input_dtype: DType,
    a_ptr: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    // The kernel takes its element count as `u32`. A plain `as` cast would silently
    // truncate, so the kernel would process only `numel mod 2^32` elements while the
    // launch still reported success.
    let n = u32::try_from(numel).map_err(|_| {
        Error::Internal(format!(
            "CUDA isinf kernel: numel {} exceeds u32::MAX",
            numel
        ))
    })?;
    // SAFETY: pointer validity is the caller's responsibility (see `# Safety`).
    unsafe {
        let module = get_or_load_module(context, device_index, kernel_names::UNARY_MODULE)?;
        let func_name = kernel_name("isinf", input_dtype);
        let func = get_kernel_function(&module, &func_name)?;
        // Nothing to compute. The lookup above still runs, so an unsupported dtype is
        // reported even for empty input. Returning here avoids relying on how
        // `elementwise_launch_config` handles zero elements: CUDA rejects a zero-sized
        // grid.
        if n == 0 {
            return Ok(());
        }
        let grid = elementwise_launch_config(numel);
        let block = (BLOCK_SIZE, 1, 1);
        let cfg = launch_config(grid, block, 0);
        let mut builder = stream.launch_builder(&func);
        builder.arg(&a_ptr);
        builder.arg(&out_ptr);
        builder.arg(&n);
        builder.launch(cfg).map_err(|e| {
            Error::Internal(format!(
                "CUDA isinf kernel '{}' launch failed: {:?}",
                func_name, e
            ))
        })?;
        Ok(())
    }
}
pub unsafe fn launch_logical_not_op(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
a_ptr: u64,
out_ptr: u64,
numel: usize,
) -> Result<()> {
unsafe {
let module = get_or_load_module(context, device_index, kernel_names::UNARY_MODULE)?;
let func_name = "logical_not_u8";
let func = get_kernel_function(&module, func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
let mut builder = stream.launch_builder(&func);
builder.arg(&a_ptr);
builder.arg(&out_ptr);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!("CUDA logical_not kernel launch failed: {:?}", e))
})?;
Ok(())
}
}