use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::elementwise::{ElementwiseOp as PtxOp, ElementwiseTemplate};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
/// Threads per block for every elementwise kernel launch in this module.
const BLOCK_SIZE: u32 = 256;
/// Checks that each of the three device buffers can hold at least `n`
/// elements.
///
/// # Errors
///
/// Returns [`BlasError::BufferTooSmall`] describing the first buffer
/// (checked in `a`, `b`, `c` order) that is shorter than `n`.
fn validate_binary_buffers<T: Copy>(
    n: u32,
    a: &DeviceBuffer<T>,
    b: &DeviceBuffer<T>,
    c: &DeviceBuffer<T>,
) -> BlasResult<()> {
    let expected = n as usize;
    // The same check applies to all three operands; report the first one
    // that is too small, matching the original a -> b -> c order.
    for actual in [a.len(), b.len(), c.len()] {
        if actual < expected {
            return Err(BlasError::BufferTooSmall { expected, actual });
        }
    }
    Ok(())
}
/// Instantiates the elementwise PTX template for `ptx_op`/`ptx_type`,
/// loads the generated PTX as a driver module, and resolves the kernel.
///
/// Returns the kernel together with its template-generated name.
///
/// # Errors
///
/// * [`BlasError::PtxGeneration`] when the template fails to generate PTX.
/// * [`BlasError::LaunchFailed`] when module loading or kernel lookup fails.
fn build_binary_kernel(
    handle: &BlasHandle,
    ptx_op: PtxOp,
    ptx_type: oxicuda_ptx::ir::PtxType,
) -> BlasResult<(Kernel, String)> {
    let template = ElementwiseTemplate::new(ptx_op, ptx_type, handle.sm_version());
    let kernel_name = template.kernel_name();

    // Template failure is a generation error, not a launch error.
    let ptx_source = match template.generate() {
        Ok(source) => source,
        Err(e) => {
            return Err(BlasError::PtxGeneration(format!(
                "{}: {e}",
                ptx_op.as_str()
            )))
        }
    };

    // The module is shared behind an Arc so the kernel can own a handle to it.
    let module = Module::from_ptx(&ptx_source).map(Arc::new).map_err(|e| {
        BlasError::LaunchFailed(format!("module load for {}: {e}", ptx_op.as_str()))
    })?;

    match Kernel::from_module(module, &kernel_name) {
        Ok(kernel) => Ok((kernel, kernel_name)),
        Err(e) => Err(BlasError::LaunchFailed(format!(
            "kernel lookup for {kernel_name}: {e}"
        ))),
    }
}
/// Launches an elementwise `Add` kernel over the first `n` elements of `a`
/// and `b`, writing results into `c`.
///
/// A zero-length request returns immediately without touching the device.
///
/// # Errors
///
/// Returns [`BlasError::BufferTooSmall`] if any buffer holds fewer than `n`
/// elements, [`BlasError::PtxGeneration`] if the kernel template fails, or
/// [`BlasError::LaunchFailed`] if module load, kernel lookup, or the launch
/// itself fails.
pub fn add<T: GpuFloat>(
    handle: &BlasHandle,
    n: u32,
    a: &DeviceBuffer<T>,
    b: &DeviceBuffer<T>,
    c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if n == 0 {
        return Ok(());
    }
    validate_binary_buffers(n, a, b, c)?;
    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Add, T::PTX_TYPE)?;
    let grid = grid_size_for(n, BLOCK_SIZE);
    let params = LaunchParams::new(grid, BLOCK_SIZE);
    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
    kernel
        // Fix: `¶ms` was a mojibake of `&params` and did not compile.
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("add: {e}")))?;
    Ok(())
}
/// Launches an elementwise `Mul` kernel over the first `n` elements of `a`
/// and `b`, writing results into `c`.
///
/// A zero-length request returns immediately without touching the device.
///
/// # Errors
///
/// Returns [`BlasError::BufferTooSmall`] if any buffer holds fewer than `n`
/// elements, [`BlasError::PtxGeneration`] if the kernel template fails, or
/// [`BlasError::LaunchFailed`] if module load, kernel lookup, or the launch
/// itself fails.
pub fn mul<T: GpuFloat>(
    handle: &BlasHandle,
    n: u32,
    a: &DeviceBuffer<T>,
    b: &DeviceBuffer<T>,
    c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if n == 0 {
        return Ok(());
    }
    validate_binary_buffers(n, a, b, c)?;
    let (kernel, _name) = build_binary_kernel(handle, PtxOp::Mul, T::PTX_TYPE)?;
    let grid = grid_size_for(n, BLOCK_SIZE);
    let params = LaunchParams::new(grid, BLOCK_SIZE);
    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
    kernel
        // Fix: `¶ms` was a mojibake of `&params` and did not compile.
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("mul: {e}")))?;
    Ok(())
}
/// Launches the fused `Add` + ReLU elementwise kernel over the first `n`
/// elements of `a` and `b`, writing results into `c`.
///
/// A zero-length request returns immediately without touching the device.
///
/// # Errors
///
/// Returns [`BlasError::BufferTooSmall`] if any buffer holds fewer than `n`
/// elements, [`BlasError::PtxGeneration`] if the kernel template fails, or
/// [`BlasError::LaunchFailed`] if module load, kernel lookup, or the launch
/// itself fails.
pub fn fused_add_relu<T: GpuFloat>(
    handle: &BlasHandle,
    n: u32,
    a: &DeviceBuffer<T>,
    b: &DeviceBuffer<T>,
    c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if n == 0 {
        return Ok(());
    }
    validate_binary_buffers(n, a, b, c)?;
    let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedAddRelu, T::PTX_TYPE)?;
    let grid = grid_size_for(n, BLOCK_SIZE);
    let params = LaunchParams::new(grid, BLOCK_SIZE);
    let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
    kernel
        // Fix: `¶ms` was a mojibake of `&params` and did not compile.
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("fused_add_relu: {e}")))?;
    Ok(())
}
/// Launches the fused scale-and-add elementwise kernel over the first `n`
/// elements of `a` and `b`, writing results into `c`. The scalars `alpha`
/// and `beta` are passed to the kernel as raw bit patterns (via
/// `to_bits_u64`), so the PTX side reinterprets them at the proper width.
///
/// A zero-length request returns immediately without touching the device.
///
/// # Errors
///
/// Returns [`BlasError::BufferTooSmall`] if any buffer holds fewer than `n`
/// elements, [`BlasError::PtxGeneration`] if the kernel template fails, or
/// [`BlasError::LaunchFailed`] if module load, kernel lookup, or the launch
/// itself fails.
pub fn fused_scale_add<T: GpuFloat>(
    handle: &BlasHandle,
    n: u32,
    alpha: T,
    a: &DeviceBuffer<T>,
    beta: T,
    b: &DeviceBuffer<T>,
    c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if n == 0 {
        return Ok(());
    }
    validate_binary_buffers(n, a, b, c)?;
    let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedScaleAdd, T::PTX_TYPE)?;
    let grid = grid_size_for(n, BLOCK_SIZE);
    let params = LaunchParams::new(grid, BLOCK_SIZE);
    // Scalars cross the launch boundary as u64 bit patterns, not values.
    let alpha_bits = alpha.to_bits_u64();
    let beta_bits = beta.to_bits_u64();
    let args = (
        a.as_device_ptr(),
        b.as_device_ptr(),
        c.as_device_ptr(),
        alpha_bits,
        beta_bits,
        n,
    );
    kernel
        // Fix: `¶ms` was a mojibake of `&params` and did not compile.
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("fused_scale_add: {e}")))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates PTX at SM 8.0 for the given op/type, panicking with `msg`
    /// on template failure. Shared by all the template-generation tests.
    fn generated_ptx(op: PtxOp, ty: oxicuda_ptx::ir::PtxType, msg: &str) -> String {
        let template = ElementwiseTemplate::new(op, ty, oxicuda_ptx::arch::SmVersion::Sm80);
        template.generate().expect(msg)
    }

    #[test]
    fn block_size_is_power_of_two() {
        assert!(BLOCK_SIZE.is_power_of_two());
        // Compile-time guard: a block must cover at least one full warp.
        const { assert!(BLOCK_SIZE >= 32) };
    }

    #[test]
    fn ptx_template_generates_add_f32() {
        let ptx = generated_ptx(
            PtxOp::Add,
            oxicuda_ptx::ir::PtxType::F32,
            "add PTX generation should succeed",
        );
        assert!(ptx.contains("elementwise_add_f32"));
    }

    #[test]
    fn ptx_template_generates_mul_f64() {
        let ptx = generated_ptx(
            PtxOp::Mul,
            oxicuda_ptx::ir::PtxType::F64,
            "mul PTX generation should succeed",
        );
        assert!(ptx.contains("elementwise_mul_f64"));
    }

    #[test]
    fn ptx_template_generates_fused_add_relu_f32() {
        let ptx = generated_ptx(
            PtxOp::FusedAddRelu,
            oxicuda_ptx::ir::PtxType::F32,
            "fused_add_relu PTX generation should succeed",
        );
        assert!(ptx.contains("elementwise_fused_add_relu_f32"));
    }

    #[test]
    fn ptx_template_generates_fused_scale_add_f32() {
        let ptx = generated_ptx(
            PtxOp::FusedScaleAdd,
            oxicuda_ptx::ir::PtxType::F32,
            "fused_scale_add PTX generation should succeed",
        );
        assert!(ptx.contains("elementwise_fused_scale_add_f32"));
    }
}