use cudarc::driver::PushKernelArg;
use cudarc::driver::safe::{CudaContext, CudaStream};
use std::sync::Arc;
use super::loader::{
BLOCK_SIZE, elementwise_launch_config, get_kernel_function, get_or_load_module, kernel_name,
kernel_names, launch_config,
};
use crate::dtype::DType;
use crate::error::{Error, Result};
/// A scalar fill value already converted to the bit representation that the
/// matching `fill_*` CUDA kernel expects for its dtype.
///
/// Produced by [`FillValue::from_f64`]; consumed by `launch_fill`.
#[derive(Debug, Clone, Copy)]
pub enum FillValue {
    F32(f32),
    F64(f64),
    I32(i32),
    I64(i64),
    /// Also used for `DType::Bool` (stored as a single byte).
    U8(u8),
    /// Raw IEEE-754 half-precision bits (from `half::f16::to_bits`).
    #[cfg(feature = "f16")]
    F16(u16),
    /// Raw bfloat16 bits (from `half::bf16::to_bits`).
    #[cfg(feature = "f16")]
    BF16(u16),
    /// Raw FP8 E4M3 bits.
    FP8E4M3(u8),
    /// Raw FP8 E5M2 bits.
    FP8E5M2(u8),
}
impl FillValue {
    /// Converts an `f64` scalar into the typed fill value matching `dtype`.
    ///
    /// `Bool` is stored as `U8`. `U32`/`U64` have no dedicated fill variant
    /// here, so they are bit-cast into the same-width signed fill
    /// (`I32`/`I64`); the bit pattern written to the buffer is identical to
    /// the unsigned value. Any remaining dtype falls back to an `F64` fill.
    pub fn from_f64(value: f64, dtype: DType) -> Self {
        match dtype {
            DType::F32 => FillValue::F32(value as f32),
            DType::F64 => FillValue::F64(value),
            DType::I32 => FillValue::I32(value as i32),
            DType::I64 => FillValue::I64(value as i64),
            DType::U8 | DType::Bool => FillValue::U8(value as u8),
            // Same element width as the unsigned target: filling through the
            // signed kernel writes exactly the bits of `value as uN`.
            // (Previously these dtypes hit the `_` arm and launched an
            // 8-byte f64 fill, which has the wrong element width for U32.)
            DType::U32 => FillValue::I32((value as u32) as i32),
            DType::U64 => FillValue::I64((value as u64) as i64),
            #[cfg(feature = "f16")]
            DType::F16 => FillValue::F16(half::f16::from_f64(value).to_bits()),
            #[cfg(feature = "f16")]
            DType::BF16 => FillValue::BF16(half::bf16::from_f64(value).to_bits()),
            DType::FP8E4M3 => {
                FillValue::FP8E4M3(crate::dtype::fp8::FP8E4M3::from_f64(value).to_bits())
            }
            DType::FP8E5M2 => {
                FillValue::FP8E5M2(crate::dtype::fp8::FP8E5M2::from_f64(value).to_bits())
            }
            // NOTE(review): any dtype still reaching this arm is filled as
            // f64, which assumes 8-byte elements — confirm no other dtype
            // can land here with a different element size.
            _ => FillValue::F64(value),
        }
    }
    /// The dtype whose `fill_*` kernel variant matches this value's payload.
    fn kernel_dtype(&self) -> DType {
        match self {
            FillValue::F32(_) => DType::F32,
            FillValue::F64(_) => DType::F64,
            FillValue::I32(_) => DType::I32,
            FillValue::I64(_) => DType::I64,
            FillValue::U8(_) => DType::U8,
            #[cfg(feature = "f16")]
            FillValue::F16(_) => DType::F16,
            #[cfg(feature = "f16")]
            FillValue::BF16(_) => DType::BF16,
            FillValue::FP8E4M3(_) => DType::FP8E4M3,
            FillValue::FP8E5M2(_) => DType::FP8E5M2,
        }
    }
}
/// Launches the `fill_<dtype>` utility kernel, writing `value` into every
/// one of the `numel` elements at `out_ptr`.
///
/// The kernel variant is chosen from `value.kernel_dtype()`, not from
/// `_dtype`; callers are expected to pass a `value` already converted for
/// the target dtype (see [`FillValue::from_f64`]).
///
/// # Errors
/// Returns `Error::Internal` if the module/function lookup or the kernel
/// launch fails.
///
/// # Safety
/// `out_ptr` must be a valid device pointer on `stream`'s device with room
/// for `numel` elements of the value's dtype.
pub unsafe fn launch_fill(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    _dtype: DType,
    value: FillValue,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = kernel_name("fill", value.kernel_dtype());
    let func = get_kernel_function(&module, &func_name)?;
    let grid = elementwise_launch_config(numel);
    let block = (BLOCK_SIZE, 1, 1);
    let n = numel as u32;
    let cfg = launch_config(grid, block, 0);
    // Every variant launches with the identical (out_ptr, value, n)
    // argument layout; only the scalar's Rust type differs. A local macro
    // replaces the nine previously duplicated match arms. (Defined inside
    // the function so it can refer to the locals above.)
    macro_rules! launch_scalar {
        ($v:expr) => {{
            let v = $v;
            let mut builder = stream.launch_builder(&func);
            builder.arg(&out_ptr);
            builder.arg(&v);
            builder.arg(&n);
            // SAFETY: caller guarantees out_ptr is valid for numel elements.
            unsafe { builder.launch(cfg) }
        }};
    }
    let launch_result = match value {
        FillValue::F32(v) => launch_scalar!(v),
        FillValue::F64(v) => launch_scalar!(v),
        FillValue::I32(v) => launch_scalar!(v),
        FillValue::I64(v) => launch_scalar!(v),
        FillValue::U8(v) => launch_scalar!(v),
        #[cfg(feature = "f16")]
        FillValue::F16(v) => launch_scalar!(v),
        #[cfg(feature = "f16")]
        FillValue::BF16(v) => launch_scalar!(v),
        FillValue::FP8E4M3(v) => launch_scalar!(v),
        FillValue::FP8E5M2(v) => launch_scalar!(v),
    };
    launch_result.map_err(|e| {
        Error::Internal(format!(
            "CUDA fill kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
/// Convenience wrapper around [`launch_fill`]: converts `value` into the
/// typed [`FillValue`] for `dtype`, then delegates.
///
/// # Safety
/// Same contract as [`launch_fill`]: `out_ptr` must be valid device memory
/// for `numel` elements of `dtype` on `stream`'s device.
pub unsafe fn launch_fill_with_f64(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    value: f64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let fill_value = FillValue::from_f64(value, dtype);
    // SAFETY: forwarded directly; the caller upholds launch_fill's contract.
    unsafe { launch_fill(context, stream, device_index, dtype, fill_value, out_ptr, numel) }
}
/// Launches the `rand_<dtype>` kernel, filling `out_ptr` with `numel`
/// random values derived from `seed` (distribution is defined by the
/// kernel; presumably uniform — confirm against the kernel source).
///
/// # Safety
/// `out_ptr` must be a valid device pointer on `stream`'s device with room
/// for `numel` elements of `dtype`.
pub unsafe fn launch_rand(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    seed: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = kernel_name("rand", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let n = numel as u32;
    // One thread per output element.
    let cfg = launch_config(elementwise_launch_config(numel), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&out_ptr);
    builder.arg(&seed);
    builder.arg(&n);
    // SAFETY: caller guarantees out_ptr is valid for numel elements.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA rand kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
/// Launches the `randn_<dtype>` kernel, filling `out_ptr` with `numel`
/// normally distributed values derived from `seed`.
///
/// # Safety
/// `out_ptr` must be a valid device pointer on `stream`'s device with room
/// for `numel` elements of `dtype`.
pub unsafe fn launch_randn(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    seed: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = kernel_name("randn", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    // Grid is sized for ceil(numel / 2) threads — presumably each thread
    // writes a pair of values (e.g. Box-Muller); confirm against the kernel.
    let thread_count = (numel + 1) / 2;
    let n = numel as u32;
    let cfg = launch_config(elementwise_launch_config(thread_count), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&out_ptr);
    builder.arg(&seed);
    builder.arg(&n);
    // SAFETY: caller guarantees out_ptr is valid for numel elements.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA randn kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
/// Launches the `randint_<dtype>` kernel, filling `out_ptr` with `numel`
/// random integers; `low` and `range` are forwarded to the kernel
/// (presumably yielding values in `[low, low + range)` — confirm against
/// the kernel source).
///
/// # Safety
/// `out_ptr` must be a valid device pointer on `stream`'s device with room
/// for `numel` elements of `dtype`.
pub unsafe fn launch_randint(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    low: i64,
    range: i64,
    seed: u64,
    out_ptr: u64,
    numel: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = kernel_name("randint", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let n = numel as u32;
    // One thread per output element.
    let cfg = launch_config(elementwise_launch_config(numel), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&out_ptr);
    builder.arg(&low);
    builder.arg(&range);
    builder.arg(&seed);
    builder.arg(&n);
    // SAFETY: caller guarantees out_ptr is valid for numel elements.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA randint kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
pub unsafe fn launch_arange(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
start: f64,
step: f64,
out_ptr: u64,
numel: usize,
) -> Result<()> {
let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
let func_name = kernel_name("arange", dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(numel);
let block = (BLOCK_SIZE, 1, 1);
let n = numel as u32;
let cfg = launch_config(grid, block, 0);
match dtype {
DType::F32 => unsafe {
let start_f32 = start as f32;
let step_f32 = step as f32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_f32);
builder.arg(&step_f32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::F64 => unsafe {
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start);
builder.arg(&step);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => unsafe {
let start_f32 = start as f32;
let step_f32 = step as f32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_f32);
builder.arg(&step_f32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::I32 => unsafe {
let start_i32 = start as i32;
let step_i32 = step as i32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_i32);
builder.arg(&step_i32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::I64 => unsafe {
let start_i64 = start as i64;
let step_i64 = step as i64;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_i64);
builder.arg(&step_i64);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::U32 => unsafe {
let start_u32 = start as u32;
let step_i32 = step as i32; let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_u32);
builder.arg(&step_i32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::U64 => unsafe {
let start_u64 = start as u64;
let step_i64 = step as i64; let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_u64);
builder.arg(&step_i64);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
#[cfg(feature = "fp8")]
DType::FP8E4M3 | DType::FP8E5M2 => unsafe {
let start_f32 = start as f32;
let step_f32 = step as f32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_f32);
builder.arg(&step_f32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA arange kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
_ => {
return Err(Error::UnsupportedDType {
dtype,
op: "arange",
});
}
}
Ok(())
}
pub unsafe fn launch_linspace(
context: &Arc<CudaContext>,
stream: &CudaStream,
device_index: usize,
dtype: DType,
start: f64,
stop: f64,
out_ptr: u64,
steps: usize,
) -> Result<()> {
let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
let func_name = kernel_name("linspace", dtype);
let func = get_kernel_function(&module, &func_name)?;
let grid = elementwise_launch_config(steps);
let block = (BLOCK_SIZE, 1, 1);
let n = steps as u32;
let cfg = launch_config(grid, block, 0);
match dtype {
DType::F32 => unsafe {
let start_f32 = start as f32;
let stop_f32 = stop as f32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_f32);
builder.arg(&stop_f32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA linspace kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::F64 => unsafe {
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start);
builder.arg(&stop);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA linspace kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
#[cfg(feature = "f16")]
DType::F16 | DType::BF16 => unsafe {
let start_f32 = start as f32;
let stop_f32 = stop as f32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_f32);
builder.arg(&stop_f32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA linspace kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
DType::I32 | DType::I64 | DType::U32 | DType::U64 => unsafe {
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start); builder.arg(&stop);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA linspace kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
#[cfg(feature = "fp8")]
DType::FP8E4M3 | DType::FP8E5M2 => unsafe {
let start_f32 = start as f32;
let stop_f32 = stop as f32;
let mut builder = stream.launch_builder(&func);
builder.arg(&out_ptr);
builder.arg(&start_f32);
builder.arg(&stop_f32);
builder.arg(&n);
builder.launch(cfg).map_err(|e| {
Error::Internal(format!(
"CUDA linspace kernel '{}' launch failed: {:?}",
func_name, e
))
})?;
},
_ => {
return Err(Error::UnsupportedDType {
dtype,
op: "linspace",
});
}
}
Ok(())
}
/// Launches the `eye_<dtype>` kernel over an `n` x `m` output buffer
/// (presumably writing ones on the main diagonal and zeros elsewhere —
/// defined by the kernel).
///
/// # Safety
/// `out_ptr` must be a valid device pointer on `stream`'s device with room
/// for `n * m` elements of `dtype`.
pub unsafe fn launch_eye(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    n: usize,
    m: usize,
    out_ptr: u64,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = kernel_name("eye", dtype);
    let func = get_kernel_function(&module, &func_name)?;
    let n_u32 = n as u32;
    let m_u32 = m as u32;
    // One thread per matrix element.
    let cfg = launch_config(elementwise_launch_config(n * m), (BLOCK_SIZE, 1, 1), 0);
    let mut builder = stream.launch_builder(&func);
    builder.arg(&out_ptr);
    builder.arg(&n_u32);
    builder.arg(&m_u32);
    // SAFETY: caller guarantees out_ptr is valid for n * m elements.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA eye kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
/// Launches the with-replacement multinomial sampling kernel: draws
/// `num_samples` category indices per distribution from the probability
/// table at `probs_ptr` into `out_ptr`, seeded by `seed`.
///
/// Grid is sized for one thread per output sample
/// (`num_distributions * num_samples`).
///
/// # Safety
/// `probs_ptr` must be valid device memory holding
/// `num_distributions * num_categories` probabilities of `dtype`, and
/// `out_ptr` must have room for `num_distributions * num_samples` outputs.
pub unsafe fn launch_multinomial_with_replacement(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    probs_ptr: u64,
    out_ptr: u64,
    seed: u64,
    num_distributions: usize,
    num_categories: usize,
    num_samples: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = format!("multinomial_with_replacement_{}", dtype_suffix(dtype)?);
    let func = get_kernel_function(&module, &func_name)?;
    let total_samples = num_distributions * num_samples;
    let cfg = launch_config(
        elementwise_launch_config(total_samples),
        (BLOCK_SIZE, 1, 1),
        0,
    );
    let dists = num_distributions as u32;
    let cats = num_categories as u32;
    let samples = num_samples as u32;
    let mut builder = stream.launch_builder(&func);
    builder.arg(&probs_ptr);
    builder.arg(&out_ptr);
    builder.arg(&seed);
    builder.arg(&dists);
    builder.arg(&cats);
    builder.arg(&samples);
    // SAFETY: caller guarantees both device pointers are valid and sized
    // as documented above.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA multinomial_with_replacement kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
/// Launches the without-replacement multinomial sampling kernel: one CUDA
/// block per distribution, with `num_categories * size_of::<f64>()` bytes
/// of dynamic shared memory per block (presumably a working copy of the
/// probability row — confirm against the kernel; also note the shared-mem
/// size must fit the device's per-block limit and the `u32` cast).
///
/// # Safety
/// `probs_ptr` must be valid device memory holding
/// `num_distributions * num_categories` probabilities of `dtype`, and
/// `out_ptr` must have room for `num_distributions * num_samples` outputs.
pub unsafe fn launch_multinomial_without_replacement(
    context: &Arc<CudaContext>,
    stream: &CudaStream,
    device_index: usize,
    dtype: DType,
    probs_ptr: u64,
    out_ptr: u64,
    seed: u64,
    num_distributions: usize,
    num_categories: usize,
    num_samples: usize,
) -> Result<()> {
    let module = get_or_load_module(context, device_index, kernel_names::UTILITY_MODULE)?;
    let func_name = format!("multinomial_without_replacement_{}", dtype_suffix(dtype)?);
    let func = get_kernel_function(&module, &func_name)?;
    // One block per distribution; BLOCK_SIZE threads cooperate within it.
    let shared_mem = num_categories * std::mem::size_of::<f64>();
    let cfg = launch_config(
        (num_distributions as u32, 1, 1),
        (BLOCK_SIZE, 1, 1),
        shared_mem as u32,
    );
    let dists = num_distributions as u32;
    let cats = num_categories as u32;
    let samples = num_samples as u32;
    let mut builder = stream.launch_builder(&func);
    builder.arg(&probs_ptr);
    builder.arg(&out_ptr);
    builder.arg(&seed);
    builder.arg(&dists);
    builder.arg(&cats);
    builder.arg(&samples);
    // SAFETY: caller guarantees both device pointers are valid and sized
    // as documented above.
    unsafe { builder.launch(cfg) }.map_err(|e| {
        Error::Internal(format!(
            "CUDA multinomial_without_replacement kernel '{}' launch failed: {:?}",
            func_name, e
        ))
    })?;
    Ok(())
}
fn dtype_suffix(dtype: DType) -> Result<&'static str> {
match dtype {
DType::F32 => Ok("f32"),
DType::F64 => Ok("f64"),
#[cfg(feature = "f16")]
DType::F16 => Ok("f16"),
#[cfg(feature = "f16")]
DType::BF16 => Ok("bf16"),
_ => Err(Error::UnsupportedDType {
dtype,
op: "multinomial",
}),
}
}