use crate::{
compat::*,
cuda::*,
error::{CudaKernelError, Result},
kernel::Kernels,
kernels::macros::ops,
source::Source,
};
// Registers the unary kernel names with the `ops!` macro imported from
// `crate::kernels::macros`.
// NOTE(review): presumably the macro emits one `Kernel` handle per identifier,
// whose inner name is what `load_function` looks up in `Source::OpsUnary` —
// confirm against the macro definition before adding or renaming entries.
ops!(
neg,
abs,
sign,
square,
sqrt,
recip,
relu,
sigmoid,
tanh,
gelu,
softplus,
silu,
mish,
sin,
cos,
tan,
exp,
exp2,
exp10,
ln,
log2,
log10,
logical_not,
add_scalar,
sub_scalar,
mul_scalar,
div_scalar,
pow_scalar,
maximum_scalar,
minimum_scalar,
leaky_relu,
elu,
prelu,
eq_scalar,
ne_scalar,
lt_scalar,
le_scalar,
gt_scalar,
ge_scalar
);
/// Launches a unary kernel from `Source::OpsUnary` over `input`, writing to `output`.
///
/// The kernel receives three arguments, in order: `input`, `output`, and a
/// device copy of `metadata`. The launch uses a 1-D grid of 256-thread blocks
/// sized from `metadata[0]`, which is treated as the element count (at least
/// one block is always launched).
///
/// # Arguments
///
/// * `kernel` - handle whose inner name selects the function to load.
/// * `kernels` - kernel cache used to load the function for `context`.
/// * `context` - CUDA context; its default stream is used for the copy and launch.
/// * `input` / `output` - device buffers passed to the kernel.
/// * `metadata` - host-side descriptor; index 0 is the element count. The
///   meaning of any further entries is defined by the kernel, not by this function.
///
/// # Errors
///
/// * Propagates any error from `Kernels::load_function`.
/// * `CudaKernelError::LaunchError` if `metadata` is empty, if the required
///   number of blocks does not fit in a `u32`, or if the launch itself fails.
/// * `CudaKernelError::MemoryError` if copying `metadata` to the device fails.
pub fn call_ops_unary<I, O>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    input: &CudaSlice<I>,
    output: &mut CudaSlice<O>,
    metadata: &[usize],
) -> Result<()>
where
    I: cudarc::driver::DeviceRepr,
    O: cudarc::driver::DeviceRepr,
{
    const BLOCK_SIZE: u32 = 256;

    let func = kernels.load_function(context, Source::OpsUnary, kernel.0)?;

    // Checked access: an empty descriptor is a caller error, not a reason to panic.
    let num_els = metadata.first().copied().ok_or_else(|| {
        CudaKernelError::LaunchError(format!(
            "Invalid metadata: expected at least 1 entry (element count), got {}",
            metadata.len()
        ))
    })?;

    // Compute the block count in `usize` first so element counts above
    // `u32::MAX` are not truncated before the division.
    let num_blocks = num_els.div_ceil(BLOCK_SIZE as usize).max(1);
    let grid_size = u32::try_from(num_blocks).map_err(|_| {
        CudaKernelError::LaunchError(format!(
            "Element count {} needs {} blocks, which exceeds the u32 grid dimension",
            num_els, num_blocks
        ))
    })?;

    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = context.default_stream();
    let metadata_dev = stream
        .memcpy_stod(metadata)
        .map_err(|e| CudaKernelError::MemoryError(format!("Failed to copy metadata: {:?}", e)))?;

    // SAFETY: the argument order (input, output, metadata) must match the
    // kernel's signature, and the buffers must be large enough for the layout
    // described by `metadata`. Neither is verified here; callers are
    // responsible for passing consistent buffers and metadata.
    unsafe {
        func.launch(&stream, cfg, |args| {
            args.arg(input).arg(output).arg(&metadata_dev);
        })
        .map_err(|e| CudaKernelError::LaunchError(format!("Failed to launch kernel: {:?}", e)))?;
    }

    Ok(())
}
/// Launches a unary-with-scalar kernel from `Source::OpsUnary` over `input`,
/// writing to `output`.
///
/// The kernel receives four arguments, in order: `input`, `output`, a device
/// copy of `metadata`, and a one-element device buffer holding `scalar_val`.
/// The launch uses a 1-D grid of 256-thread blocks sized from `metadata[0]`,
/// which is treated as the element count (at least one block is always launched).
///
/// # Arguments
///
/// * `kernel` - handle whose inner name selects the function to load.
/// * `kernels` - kernel cache used to load the function for `context`.
/// * `context` - CUDA context; its default stream is used for the copies and launch.
/// * `input` / `output` - device buffers passed to the kernel.
/// * `metadata` - host-side descriptor; index 0 is the element count. The
///   meaning of any further entries is defined by the kernel, not by this function.
/// * `scalar_val` - scalar operand, uploaded to the device before the launch.
///
/// # Errors
///
/// * Propagates any error from `Kernels::load_function`.
/// * `CudaKernelError::LaunchError` if `metadata` is empty, if the required
///   number of blocks does not fit in a `u32`, or if the launch itself fails.
/// * `CudaKernelError::MemoryError` if copying `metadata` or the scalar to the
///   device fails.
pub fn call_ops_unary_scalar<I, O>(
    kernel: crate::kernels::macros::Kernel,
    kernels: &Kernels,
    context: &Arc<CudaContext>,
    input: &CudaSlice<I>,
    output: &mut CudaSlice<O>,
    metadata: &[usize],
    scalar_val: I,
) -> Result<()>
where
    I: cudarc::driver::DeviceRepr + Clone,
    O: cudarc::driver::DeviceRepr,
{
    const BLOCK_SIZE: u32 = 256;

    let func = kernels.load_function(context, Source::OpsUnary, kernel.0)?;

    // Checked access: an empty descriptor is a caller error, not a reason to panic.
    let num_els = metadata.first().copied().ok_or_else(|| {
        CudaKernelError::LaunchError(format!(
            "Invalid metadata: expected at least 1 entry (element count), got {}",
            metadata.len()
        ))
    })?;

    // Compute the block count in `usize` first so element counts above
    // `u32::MAX` are not truncated before the division.
    let num_blocks = num_els.div_ceil(BLOCK_SIZE as usize).max(1);
    let grid_size = u32::try_from(num_blocks).map_err(|_| {
        CudaKernelError::LaunchError(format!(
            "Element count {} needs {} blocks, which exceeds the u32 grid dimension",
            num_els, num_blocks
        ))
    })?;

    let cfg = LaunchConfig {
        grid_dim: (grid_size, 1, 1),
        block_dim: (BLOCK_SIZE, 1, 1),
        shared_mem_bytes: 0,
    };

    let stream = context.default_stream();
    let metadata_dev = stream
        .memcpy_stod(metadata)
        .map_err(|e| CudaKernelError::MemoryError(format!("Failed to copy metadata: {:?}", e)))?;
    // The scalar is passed to the kernel through a one-element device buffer.
    let scalar_dev = stream
        .memcpy_stod(&[scalar_val])
        .map_err(|e| CudaKernelError::MemoryError(format!("Failed to copy scalar: {:?}", e)))?;

    // SAFETY: the argument order (input, output, metadata, scalar) must match
    // the kernel's signature, and the buffers must be large enough for the
    // layout described by `metadata`. Neither is verified here; callers are
    // responsible for passing consistent buffers and metadata.
    unsafe {
        func.launch(&stream, cfg, |args| {
            args.arg(input).arg(output).arg(&metadata_dev).arg(&scalar_dev);
        })
        .map_err(|e| CudaKernelError::LaunchError(format!("Failed to launch kernel: {:?}", e)))?;
    }

    Ok(())
}