use std::sync::Arc;
use oxicuda_backend::{BackendError, BackendResult, BinaryOp, ReduceOp, UnaryOp};
use oxicuda_driver::ffi::CUdeviceptr;
use oxicuda_driver::{Device, Module};
use oxicuda_launch::{Dim3, Kernel, LaunchParams, grid_size_for};
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::builder::KernelBuilder;
use oxicuda_ptx::cache::{PtxCache, PtxCacheKey};
use oxicuda_ptx::ir::PtxType;
use oxicuda_ptx::templates::elementwise::{ElementwiseOp, ElementwiseTemplate};
use super::CudaBackend;
/// Threads per block (x dimension) used when launching element-wise kernels.
/// Also folded into the PTX cache key for those kernels.
const ELEMENTWISE_BLOCK: u32 = 256;
/// Threads per block (x dimension) used when launching axis-reduction kernels.
/// Also passed to the PTX builder as the reduce kernel's `max_threads_per_block`
/// and folded into the PTX cache key.
const REDUCE_BLOCK: u32 = 256;
/// Resolves the PTX target architecture (`SmVersion`) of `device`.
///
/// # Errors
///
/// * `BackendError::DeviceError` if the compute capability query fails.
/// * `BackendError::Unsupported` if `SmVersion` has no entry for the reported
///   `major.minor` compute capability.
fn device_sm(device: Device) -> BackendResult<SmVersion> {
    let (major, minor) = match device.compute_capability() {
        Ok(capability) => capability,
        Err(e) => {
            return Err(BackendError::DeviceError(format!(
                "compute capability query failed: {e}"
            )));
        }
    };

    match SmVersion::from_compute_capability(major, minor) {
        Some(sm) => Ok(sm),
        None => Err(BackendError::Unsupported(format!(
            "unsupported compute capability {major}.{minor} for PTX generation"
        ))),
    }
}
/// Creates a `PtxCache` handle via `PtxCache::new`.
///
/// # Errors
///
/// Returns `BackendError::DeviceError` if the cache cannot be initialised.
fn ptx_cache() -> BackendResult<PtxCache> {
    match PtxCache::new() {
        Ok(cache) => Ok(cache),
        Err(e) => Err(BackendError::DeviceError(format!(
            "PTX cache init failed: {e}"
        ))),
    }
}
/// Loads `ptx` as a module and looks up the entry point `kernel_name` in it.
///
/// The module is wrapped in an `Arc` and handed to `Kernel::from_module`.
///
/// # Errors
///
/// Returns `BackendError::DeviceError` if the PTX fails to load or the kernel
/// name cannot be resolved in the loaded module.
fn build_kernel(ptx: &str, kernel_name: &str) -> BackendResult<Kernel> {
    let shared_module = match Module::from_ptx(ptx) {
        Ok(module) => Arc::new(module),
        Err(e) => {
            return Err(BackendError::DeviceError(format!(
                "PTX module load failed: {e}"
            )));
        }
    };

    match Kernel::from_module(shared_module, kernel_name) {
        Ok(kernel) => Ok(kernel),
        Err(e) => Err(BackendError::DeviceError(format!(
            "kernel lookup failed: {e}"
        ))),
    }
}
/// Translates a backend-level unary operation into the matching PTX
/// element-wise template operation (a one-to-one variant mapping).
fn map_unary(op: UnaryOp) -> ElementwiseOp {
    type Ew = ElementwiseOp;

    match op {
        // Sign / magnitude.
        UnaryOp::Abs => Ew::Abs,
        UnaryOp::Neg => Ew::Neg,
        // Transcendental and root functions.
        UnaryOp::Exp => Ew::Exp,
        UnaryOp::Log => Ew::Log,
        UnaryOp::Sqrt => Ew::Sqrt,
        // Activations.
        UnaryOp::Relu => Ew::Relu,
        UnaryOp::Sigmoid => Ew::Sigmoid,
        UnaryOp::Tanh => Ew::Tanh,
    }
}
/// Translates a backend-level binary operation into the matching PTX
/// element-wise template operation (a one-to-one variant mapping).
fn map_binary(op: BinaryOp) -> ElementwiseOp {
    type Ew = ElementwiseOp;

    match op {
        // Arithmetic.
        BinaryOp::Add => Ew::Add,
        BinaryOp::Sub => Ew::Sub,
        BinaryOp::Mul => Ew::Mul,
        BinaryOp::Div => Ew::Div,
        // Selection.
        BinaryOp::Min => Ew::Min,
        BinaryOp::Max => Ew::Max,
    }
}
/// Returns the short lowercase tag for a reduction operation.
///
/// The tag is embedded in the reduce kernel name and used as the hashing tag
/// for the reduce PTX cache key.
fn reduce_op_name(op: ReduceOp) -> &'static str {
    match op {
        ReduceOp::Min => "min",
        ReduceOp::Max => "max",
        ReduceOp::Mean => "mean",
        ReduceOp::Sum => "sum",
    }
}
/// Generates the PTX source of the f32 axis-reduction kernel for `op`, targeting `sm`.
///
/// The kernel is named by `reduce_kernel_name(op)` and declares these parameters, in order:
/// `in_ptr: u64`, `out_ptr: u64`, `n_out: u32`, `axis_len: u32`, `inner: u32`, `inv: f32`.
///
/// Each thread whose id (from `global_thread_id_x()`) is below `n_out` produces one output
/// element. It starts at flat input index
/// `(tid / inner) * axis_len * inner + (tid % inner)`, reads `axis_len` values spaced `inner`
/// elements apart, folds them with `add.f32` / `max.f32` / `min.f32`, and stores the result
/// at `out_ptr + tid * 4`. For `ReduceOp::Mean` the accumulated sum is multiplied by the
/// `inv` parameter before the store (the host side in this file passes `1 / axis_len`).
///
/// NOTE(review): all index arithmetic is done in 32-bit registers (`mul.lo.u32`, `add.u32`)
/// and only widened to 64 bits for the byte offset, so flat input indices wrap once the
/// input holds more elements than a u32 can address.
///
/// NOTE(review): the `%r_*`, `%f_*`, `%rd_*` and `%p_done` registers and the `%param_*`
/// names used in the raw PTX are not declared in this function; presumably the builder
/// takes care of that — confirm against `KernelBuilder::raw_ptx`.
///
/// # Errors
///
/// Returns whatever error `KernelBuilder::build` reports.
fn generate_reduce_ptx(op: ReduceOp, sm: SmVersion) -> Result<String, oxicuda_ptx::PtxGenError> {
let kernel_name = reduce_kernel_name(op);
// Starting accumulator value as a PTX single-precision hex literal:
// 0f00000000 = 0.0 (sum/mean), 0fFF800000 = -inf (max), 0f7F800000 = +inf (min).
let identity = match op {
ReduceOp::Sum | ReduceOp::Mean => "0f00000000", ReduceOp::Max => "0fFF800000", ReduceOp::Min => "0f7F800000", };
// PTX instruction used to fold each loaded value into the accumulator.
let combine = match op {
ReduceOp::Sum | ReduceOp::Mean => "add.f32",
ReduceOp::Max => "max.f32",
ReduceOp::Min => "min.f32",
};
// Mean is computed as a sum followed by one multiply with `inv`.
let is_mean = matches!(op, ReduceOp::Mean);
KernelBuilder::new(&kernel_name)
.target(sm)
.param("in_ptr", PtxType::U64)
.param("out_ptr", PtxType::U64)
.param("n_out", PtxType::U32)
.param("axis_len", PtxType::U32)
.param("inner", PtxType::U32)
.param("inv", PtxType::F32)
.max_threads_per_block(REDUCE_BLOCK)
.body(move |b| {
let tid = b.global_thread_id_x();
// Keep the register's textual name: `tid` itself is moved into `if_lt_u32` below,
// but the raw PTX still needs to refer to it.
let tid_name = tid.to_string();
let n_out = b.load_param_u32("n_out");
// Guard: only threads with tid < n_out execute the reduction body.
b.if_lt_u32(tid, n_out, move |b| {
let in_ptr = b.load_param_u64("in_ptr");
let out_ptr = b.load_param_u64("out_ptr");
// Prologue: load axis_len/inner, split tid into (outer_idx, inner_idx), compute the
// flat index of the first element on the reduced axis, and initialise the accumulator
// and the loop counter k.
b.raw_ptx(&format!(
"ld.param.u32 %r_axis, [%param_axis_len];\n \
ld.param.u32 %r_inner, [%param_inner];\n \
div.u32 %r_outer_idx, {tid_name}, %r_inner;\n \
rem.u32 %r_inner_idx, {tid_name}, %r_inner;\n \
mul.lo.u32 %r_base, %r_outer_idx, %r_axis;\n \
mul.lo.u32 %r_base, %r_base, %r_inner;\n \
add.u32 %r_base, %r_base, %r_inner_idx;\n \
mov.f32 %f_acc, {identity};\n \
mov.u32 %r_k, 0;"
));
// Loop over k in 0..axis_len: load input[base + k * inner] (4-byte elements) and fold
// it into the accumulator.
b.raw_ptx(&format!(
"$RED_K_LOOP:\n \
setp.ge.u32 %p_done, %r_k, %r_axis;\n \
@%p_done bra $RED_K_END;\n \
mul.lo.u32 %r_idx, %r_k, %r_inner;\n \
add.u32 %r_idx, %r_idx, %r_base;\n \
cvt.u64.u32 %rd_off, %r_idx;\n \
mul.lo.u64 %rd_off, %rd_off, 4;\n \
add.u64 %rd_addr, {in_ptr}, %rd_off;\n \
ld.global.f32 %f_v, [%rd_addr];\n \
{combine} %f_acc, %f_acc, %f_v;\n \
add.u32 %r_k, %r_k, 1;\n \
bra $RED_K_LOOP;\n \
$RED_K_END:"
));
// Mean only: scale the accumulated sum by the `inv` kernel parameter.
if is_mean {
b.raw_ptx(
"ld.param.f32 %f_inv, [%param_inv];\n \
mul.f32 %f_acc, %f_acc, %f_inv;",
);
}
// Epilogue: store the accumulator to output[tid] (byte offset tid * 4).
b.raw_ptx(&format!(
"cvt.u64.u32 %rd_ooff, {tid_name};\n \
mul.lo.u64 %rd_ooff, %rd_ooff, 4;\n \
add.u64 %rd_oaddr, {out_ptr}, %rd_ooff;\n \
st.global.f32 [%rd_oaddr], %f_acc;"
));
});
b.ret();
})
.build()
}
/// Builds the entry-point name of the f32 axis-reduction kernel for `op`,
/// in the form `reduce_axis_<op>_f32`.
fn reduce_kernel_name(op: ReduceOp) -> String {
    let op_tag = reduce_op_name(op);
    format!("reduce_axis_{op_tag}_f32")
}
/// Derives the `params_hash` component of a PTX cache key.
///
/// Feeds, in this order, the operation `tag`, the launch `block` size and the
/// PTX spelling of the f32 element type into a `DefaultHasher` and returns
/// the resulting 64-bit digest.
///
/// NOTE(review): the standard library does not promise that `DefaultHasher`
/// output is stable across Rust releases; if the PTX cache outlives a single
/// toolchain, keys may stop matching — confirm this is acceptable.
fn params_hash(tag: &str, block: u32) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let element_type = PtxType::F32.as_ptx_str();

    let mut state = DefaultHasher::new();
    tag.hash(&mut state);
    block.hash(&mut state);
    element_type.hash(&mut state);
    state.finish()
}
/// Applies the unary operation `op` to `n` f32 elements on the GPU.
///
/// `input_ptr` and `output_ptr` are device addresses passed straight to the
/// kernel as `CUdeviceptr` arguments, followed by the element count as `u32`.
/// The kernel PTX comes from `ElementwiseTemplate` through the PTX cache
/// (keyed by kernel name, parameter hash and SM version), is loaded as a
/// module, and is launched with `ELEMENTWISE_BLOCK` threads per block on a
/// 1-D grid sized by `grid_size_for`.
///
/// An empty input (`n == 0`) is a no-op: the function returns `Ok(())` without
/// touching the GPU, rather than asking the driver for a zero-sized grid.
///
/// # Errors
///
/// * `BackendError::InvalidArgument` if `n` does not fit in `u32`.
/// * `BackendError::DeviceError` / `BackendError::Unsupported` if GPU
///   activation, the compute-capability lookup, PTX generation, module
///   loading or the launch itself fails.
pub(super) fn unary_elementwise(
    backend: &CudaBackend,
    op: UnaryOp,
    input_ptr: u64,
    output_ptr: u64,
    n: usize,
) -> BackendResult<()> {
    // Nothing to compute for an empty tensor.
    if n == 0 {
        return Ok(());
    }
    // Validate arguments before touching the device.
    let n_u = u32::try_from(n)
        .map_err(|_| BackendError::InvalidArgument("unary element count exceeds u32".into()))?;

    let device = backend.activate_gpu()?;
    let sm = device_sm(device)?;

    let ew_op = map_unary(op);
    let template = ElementwiseTemplate::new(ew_op, PtxType::F32, sm);
    let kernel_name = template.kernel_name();

    let cache = ptx_cache()?;
    let key = PtxCacheKey {
        kernel_name: kernel_name.clone(),
        params_hash: params_hash(ew_op.as_str(), ELEMENTWISE_BLOCK),
        sm_version: sm,
    };
    let ptx = cache
        .get_or_generate(&key, || template.generate())
        .map_err(|e| BackendError::DeviceError(format!("unary PTX generation failed: {e}")))?;
    let kernel = build_kernel(&ptx, &kernel_name)?;

    let grid = grid_size_for(n_u, ELEMENTWISE_BLOCK);
    let params = LaunchParams::builder()
        .grid(Dim3::new(grid, 1, 1))
        .block(Dim3::new(ELEMENTWISE_BLOCK, 1, 1))
        .build();
    // Argument order: input, output, element count.
    let args = (input_ptr as CUdeviceptr, output_ptr as CUdeviceptr, n_u);
    launch_with(backend, &kernel, &params, &args)
}
/// Applies the binary operation `op` element-wise to two f32 buffers of `n`
/// elements on the GPU.
///
/// `a_ptr`, `b_ptr` and `output_ptr` are device addresses passed straight to
/// the kernel as `CUdeviceptr` arguments, followed by the element count as
/// `u32`. The kernel PTX comes from `ElementwiseTemplate` through the PTX
/// cache (keyed by kernel name, parameter hash and SM version), is loaded as
/// a module, and is launched with `ELEMENTWISE_BLOCK` threads per block on a
/// 1-D grid sized by `grid_size_for`.
///
/// An empty input (`n == 0`) is a no-op: the function returns `Ok(())` without
/// touching the GPU, rather than asking the driver for a zero-sized grid.
///
/// # Errors
///
/// * `BackendError::InvalidArgument` if `n` does not fit in `u32`.
/// * `BackendError::DeviceError` / `BackendError::Unsupported` if GPU
///   activation, the compute-capability lookup, PTX generation, module
///   loading or the launch itself fails.
pub(super) fn binary_elementwise(
    backend: &CudaBackend,
    op: BinaryOp,
    a_ptr: u64,
    b_ptr: u64,
    output_ptr: u64,
    n: usize,
) -> BackendResult<()> {
    // Nothing to compute for an empty tensor.
    if n == 0 {
        return Ok(());
    }
    // Validate arguments before touching the device.
    let n_u = u32::try_from(n)
        .map_err(|_| BackendError::InvalidArgument("binary element count exceeds u32".into()))?;

    let device = backend.activate_gpu()?;
    let sm = device_sm(device)?;

    let ew_op = map_binary(op);
    let template = ElementwiseTemplate::new(ew_op, PtxType::F32, sm);
    let kernel_name = template.kernel_name();

    let cache = ptx_cache()?;
    let key = PtxCacheKey {
        kernel_name: kernel_name.clone(),
        params_hash: params_hash(ew_op.as_str(), ELEMENTWISE_BLOCK),
        sm_version: sm,
    };
    let ptx = cache
        .get_or_generate(&key, || template.generate())
        .map_err(|e| BackendError::DeviceError(format!("binary PTX generation failed: {e}")))?;
    let kernel = build_kernel(&ptx, &kernel_name)?;

    let grid = grid_size_for(n_u, ELEMENTWISE_BLOCK);
    let params = LaunchParams::builder()
        .grid(Dim3::new(grid, 1, 1))
        .block(Dim3::new(ELEMENTWISE_BLOCK, 1, 1))
        .build();
    // Argument order: lhs, rhs, output, element count.
    let args = (
        a_ptr as CUdeviceptr,
        b_ptr as CUdeviceptr,
        output_ptr as CUdeviceptr,
        n_u,
    );
    launch_with(backend, &kernel, &params, &args)
}
/// Reduces an f32 tensor of the given `shape` along `axis` on the GPU.
///
/// The tensor is viewed as `[outer, axis_len, inner]`, where `outer` is the
/// product of the dimensions before `axis` and `inner` the product of those
/// after it. One GPU thread produces each of the `outer * inner` output
/// elements. `input_ptr` and `output_ptr` are device addresses passed to the
/// kernel as `CUdeviceptr`, followed by `n_out`, `axis_len`, `inner` (all
/// `u32`) and `1 / axis_len` as `f32` (used by the mean reduction).
///
/// If any dimension other than the reduced one is zero, the output is empty
/// and the function returns `Ok(())` without touching the GPU.
///
/// # Errors
///
/// * `BackendError::InvalidArgument` if `shape` is empty, `axis` is out of
///   bounds, the reduced axis has length zero, a size computation overflows,
///   or any of the output size, axis length, inner size or total input
///   element count does not fit in `u32` (the kernel computes flat input
///   indices in 32-bit arithmetic).
/// * `BackendError::DeviceError` / `BackendError::Unsupported` if GPU
///   activation, the compute-capability lookup, PTX generation, module
///   loading or the launch itself fails.
pub(super) fn reduce_axis(
    backend: &CudaBackend,
    op: ReduceOp,
    input_ptr: u64,
    output_ptr: u64,
    shape: &[usize],
    axis: usize,
) -> BackendResult<()> {
    if shape.is_empty() {
        return Err(BackendError::InvalidArgument(
            "reduce shape must not be empty".into(),
        ));
    }
    if axis >= shape.len() {
        return Err(BackendError::InvalidArgument(format!(
            "reduce axis {axis} out of bounds for {}-D shape",
            shape.len()
        )));
    }
    let axis_len: usize = shape[axis];
    if axis_len == 0 {
        return Err(BackendError::InvalidArgument(
            "reduce axis length must be > 0".into(),
        ));
    }
    // The reduced axis is non-zero here, so a zero anywhere else means the
    // output has no elements: nothing to launch.
    if shape.contains(&0) {
        return Ok(());
    }

    // Overflow-checked product of a run of dimensions.
    let checked_product = |dims: &[usize]| -> BackendResult<usize> {
        dims.iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
            .ok_or_else(|| BackendError::InvalidArgument("reduce output size overflow".into()))
    };
    let outer = checked_product(&shape[..axis])?;
    let inner = checked_product(&shape[axis + 1..])?;
    let n_out = outer
        .checked_mul(inner)
        .ok_or_else(|| BackendError::InvalidArgument("reduce output size overflow".into()))?;

    let n_out_u = u32::try_from(n_out)
        .map_err(|_| BackendError::InvalidArgument("reduce output size exceeds u32".into()))?;
    let axis_len_u = u32::try_from(axis_len)
        .map_err(|_| BackendError::InvalidArgument("reduce axis length exceeds u32".into()))?;
    let inner_u = u32::try_from(inner)
        .map_err(|_| BackendError::InvalidArgument("reduce inner size exceeds u32".into()))?;

    // The kernel forms flat input indices (up to n_out * axis_len - 1) with
    // 32-bit multiplies and adds; reject inputs whose indices would wrap.
    let input_fits_u32 = n_out
        .checked_mul(axis_len)
        .map_or(false, |total| u32::try_from(total).is_ok());
    if !input_fits_u32 {
        return Err(BackendError::InvalidArgument(
            "reduce input size exceeds u32".into(),
        ));
    }

    let device = backend.activate_gpu()?;
    let sm = device_sm(device)?;
    let kernel_name = reduce_kernel_name(op);

    let cache = ptx_cache()?;
    let key = PtxCacheKey {
        kernel_name: kernel_name.clone(),
        params_hash: params_hash(reduce_op_name(op), REDUCE_BLOCK),
        sm_version: sm,
    };
    let ptx = cache
        .get_or_generate(&key, || generate_reduce_ptx(op, sm))
        .map_err(|e| BackendError::DeviceError(format!("reduce PTX generation failed: {e}")))?;
    let kernel = build_kernel(&ptx, &kernel_name)?;

    let grid = grid_size_for(n_out_u, REDUCE_BLOCK);
    let params = LaunchParams::builder()
        .grid(Dim3::new(grid, 1, 1))
        .block(Dim3::new(REDUCE_BLOCK, 1, 1))
        .build();
    // Reciprocal of the axis length; the mean kernel multiplies its sum by it.
    let inv = 1.0f32 / axis_len as f32;
    // Argument order must match the parameter order declared by the kernel.
    let args = (
        input_ptr as CUdeviceptr,
        output_ptr as CUdeviceptr,
        n_out_u,
        axis_len_u,
        inner_u,
        inv,
    );
    launch_with(backend, &kernel, &params, &args)
}
/// Launches `kernel` with `params` and `args` on a freshly created stream and waits for
/// the stream to synchronize before returning.
///
/// A new context token and a new stream are created on every call.
///
/// # Errors
///
/// Returns `BackendError::DeviceError` if context creation, stream creation, the launch
/// or the synchronize call fails, and propagates any error from `backend.activate_gpu()`.
fn launch_with<A: oxicuda_launch::KernelArgs>(
backend: &CudaBackend,
kernel: &Kernel,
params: &LaunchParams,
args: &A,
) -> BackendResult<()> {
let device = backend.activate_gpu()?;
// `Stream::new` takes a shared context handle, so one is built from the device here.
let token = std::sync::Arc::new(oxicuda_driver::Context::new(&device).map_err(|e| {
BackendError::DeviceError(format!("stream context token creation failed: {e}"))
})?);
// NOTE(review): the GPU is activated a second time right after `Context::new`,
// presumably because creating a context changes which context is current — confirm
// against the driver wrapper before reordering or removing this call.
backend.activate_gpu()?;
let stream = oxicuda_driver::Stream::new(&token)
.map_err(|e| BackendError::DeviceError(format!("stream creation failed: {e}")))?;
kernel
.launch(params, &stream, args)
.map_err(|e| BackendError::DeviceError(format!("kernel launch failed: {e}")))?;
// The synchronize result is the function's return value.
stream
.synchronize()
.map_err(|e| BackendError::DeviceError(format!("stream synchronize failed: {e}")))
}