use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::elementwise::{ElementwiseOp as PtxOp, ElementwiseTemplate};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
/// Threads per block used for every elementwise kernel launch in this module;
/// the grid size is derived from it via `grid_size_for(n, BLOCK_SIZE)`.
const BLOCK_SIZE: u32 = 256;
/// Check that all three device buffers hold at least `n` elements.
///
/// Buffers are checked in argument order (`a`, then `b`, then `c`) and the
/// first undersized one is reported via `BlasError::BufferTooSmall`.
fn validate_binary_buffers<T: Copy>(
    n: u32,
    a: &DeviceBuffer<T>,
    b: &DeviceBuffer<T>,
    c: &DeviceBuffer<T>,
) -> BlasResult<()> {
    let required = n as usize;
    // Same check for each operand; order matches the original a -> b -> c.
    for actual in [a.len(), b.len(), c.len()] {
        if actual < required {
            return Err(BlasError::BufferTooSmall {
                expected: required,
                actual,
            });
        }
    }
    Ok(())
}
/// Build a launchable kernel for a two-input elementwise op.
///
/// Renders the PTX template for `ptx_op`/`ptx_type` at the handle's SM
/// version, loads it as a module, and resolves the kernel entry point.
/// Returns the kernel together with its generated name.
fn build_binary_kernel(
    handle: &BlasHandle,
    ptx_op: PtxOp,
    ptx_type: oxicuda_ptx::ir::PtxType,
) -> BlasResult<(Kernel, String)> {
    let template = ElementwiseTemplate::new(ptx_op, ptx_type, handle.sm_version());
    let kernel_name = template.kernel_name();
    // PTX generation failure is a template problem, not a launch problem.
    let ptx_source = match template.generate() {
        Ok(src) => src,
        Err(e) => {
            return Err(BlasError::PtxGeneration(format!("{}: {e}", ptx_op.as_str())));
        }
    };
    let module = Module::from_ptx(&ptx_source).map_err(|e| {
        BlasError::LaunchFailed(format!("module load for {}: {e}", ptx_op.as_str()))
    })?;
    // Resolve the entry point by the template-generated name.
    let kernel = match Kernel::from_module(Arc::new(module), &kernel_name) {
        Ok(k) => k,
        Err(e) => {
            return Err(BlasError::LaunchFailed(format!(
                "kernel lookup for {kernel_name}: {e}"
            )));
        }
    };
    Ok((kernel, kernel_name))
}
pub fn add<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Add, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("add: {e}")))?;
Ok(())
}
pub fn mul<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Mul, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("mul: {e}")))?;
Ok(())
}
pub fn sub<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Sub, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("sub: {e}")))?;
Ok(())
}
pub fn div<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Div, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("div: {e}")))?;
Ok(())
}
pub fn pow<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Pow, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("pow: {e}")))?;
Ok(())
}
pub fn min<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Min, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("min: {e}")))?;
Ok(())
}
pub fn max<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Max, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("max: {e}")))?;
Ok(())
}
pub fn cmp_eq<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpEq, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("cmp_eq: {e}")))?;
Ok(())
}
pub fn cmp_ne<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpNe, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("cmp_ne: {e}")))?;
Ok(())
}
pub fn cmp_lt<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpLt, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("cmp_lt: {e}")))?;
Ok(())
}
pub fn cmp_gt<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpGt, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("cmp_gt: {e}")))?;
Ok(())
}
pub fn cmp_le<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpLe, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("cmp_le: {e}")))?;
Ok(())
}
pub fn cmp_ge<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::CmpGe, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("cmp_ge: {e}")))?;
Ok(())
}
pub fn or_max<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::OrMax, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("or_max: {e}")))?;
Ok(())
}
pub fn or_prob_sum<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::OrProbSum, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("or_prob_sum: {e}")))?;
Ok(())
}
pub fn nand<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Nand, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("nand: {e}")))?;
Ok(())
}
pub fn nor<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Nor, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("nor: {e}")))?;
Ok(())
}
pub fn xor<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::Xor, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("xor: {e}")))?;
Ok(())
}
pub fn fused_add_relu<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
a: &DeviceBuffer<T>,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedAddRelu, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let args = (a.as_device_ptr(), b.as_device_ptr(), c.as_device_ptr(), n);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("fused_add_relu: {e}")))?;
Ok(())
}
pub fn fused_scale_add<T: GpuFloat>(
handle: &BlasHandle,
n: u32,
alpha: T,
a: &DeviceBuffer<T>,
beta: T,
b: &DeviceBuffer<T>,
c: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
if n == 0 {
return Ok(());
}
validate_binary_buffers(n, a, b, c)?;
let (kernel, _name) = build_binary_kernel(handle, PtxOp::FusedScaleAdd, T::PTX_TYPE)?;
let grid = grid_size_for(n, BLOCK_SIZE);
let params = LaunchParams::new(grid, BLOCK_SIZE);
let alpha_bits = alpha.to_bits_u64();
let beta_bits = beta.to_bits_u64();
let args = (
a.as_device_ptr(),
b.as_device_ptr(),
c.as_device_ptr(),
alpha_bits,
beta_bits,
n,
);
kernel
.launch(¶ms, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("fused_scale_add: {e}")))?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Render the PTX for `op` at the given scalar type on SM 8.0, panicking
    /// with the op name on failure. Shared by every template smoke test.
    fn gen_ptx(op: PtxOp, ty: oxicuda_ptx::ir::PtxType) -> String {
        ElementwiseTemplate::new(op, ty, oxicuda_ptx::arch::SmVersion::Sm80)
            .generate()
            .unwrap_or_else(|e| panic!("{} PTX generation should succeed: {e}", op.as_str()))
    }

    /// f32 shortcut for the common case in these tests.
    fn gen_f32(op: PtxOp) -> String {
        gen_ptx(op, oxicuda_ptx::ir::PtxType::F32)
    }

    /// Assert that `ptx` contains every expected fragment, naming the
    /// missing fragment on failure.
    fn assert_contains(ptx: &str, fragments: &[&str]) {
        for frag in fragments {
            assert!(ptx.contains(frag), "PTX missing fragment {frag:?}");
        }
    }

    #[test]
    fn block_size_is_power_of_two() {
        assert!(BLOCK_SIZE.is_power_of_two());
        // Warp size is 32; anything smaller wastes lanes.
        const { assert!(BLOCK_SIZE >= 32) };
    }

    #[test]
    fn ptx_template_generates_add_f32() {
        assert_contains(&gen_f32(PtxOp::Add), &["elementwise_add_f32"]);
    }

    #[test]
    fn ptx_template_generates_mul_f64() {
        assert_contains(
            &gen_ptx(PtxOp::Mul, oxicuda_ptx::ir::PtxType::F64),
            &["elementwise_mul_f64"],
        );
    }

    #[test]
    fn ptx_template_generates_fused_add_relu_f32() {
        assert_contains(&gen_f32(PtxOp::FusedAddRelu), &["elementwise_fused_add_relu_f32"]);
    }

    #[test]
    fn ptx_template_generates_fused_scale_add_f32() {
        assert_contains(&gen_f32(PtxOp::FusedScaleAdd), &["elementwise_fused_scale_add_f32"]);
    }

    #[test]
    fn ptx_template_generates_sub_f32() {
        assert_contains(&gen_f32(PtxOp::Sub), &["elementwise_sub_f32", "sub.f32"]);
    }

    #[test]
    fn ptx_template_generates_div_f32() {
        // Division must use the round-to-nearest form.
        assert_contains(&gen_f32(PtxOp::Div), &["elementwise_div_f32", "div.rn.f32"]);
    }

    #[test]
    fn ptx_template_generates_pow_f32() {
        // pow lowers to exp2(b * log2(a)) via the approx transcendentals.
        assert_contains(
            &gen_f32(PtxOp::Pow),
            &["elementwise_pow_f32", "lg2.approx.f32", "ex2.approx.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_min_f32() {
        assert_contains(&gen_f32(PtxOp::Min), &["elementwise_min_f32", "min.f32"]);
    }

    #[test]
    fn ptx_template_generates_max_f32() {
        assert_contains(&gen_f32(PtxOp::Max), &["elementwise_max_f32", "max.f32"]);
    }

    #[test]
    fn ptx_template_generates_cmp_eq_f32() {
        // Compares lower to a predicate (setp) plus a float select (selp).
        assert_contains(
            &gen_f32(PtxOp::CmpEq),
            &["elementwise_cmp_eq_f32", "setp.eq.f32", "selp.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_cmp_ne_f32() {
        assert_contains(&gen_f32(PtxOp::CmpNe), &["elementwise_cmp_ne_f32", "setp.ne.f32"]);
    }

    #[test]
    fn ptx_template_generates_cmp_lt_f32() {
        assert_contains(&gen_f32(PtxOp::CmpLt), &["elementwise_cmp_lt_f32", "setp.lt.f32"]);
    }

    #[test]
    fn ptx_template_generates_cmp_gt_f32() {
        assert_contains(&gen_f32(PtxOp::CmpGt), &["elementwise_cmp_gt_f32", "setp.gt.f32"]);
    }

    #[test]
    fn ptx_template_generates_cmp_le_f32() {
        assert_contains(&gen_f32(PtxOp::CmpLe), &["elementwise_cmp_le_f32", "setp.le.f32"]);
    }

    #[test]
    fn ptx_template_generates_cmp_ge_f32() {
        assert_contains(&gen_f32(PtxOp::CmpGe), &["elementwise_cmp_ge_f32", "setp.ge.f32"]);
    }

    #[test]
    fn ptx_template_generates_or_max_f32() {
        assert_contains(&gen_f32(PtxOp::OrMax), &["elementwise_or_max_f32", "max.f32"]);
    }

    #[test]
    fn ptx_template_generates_or_prob_sum_f32() {
        assert_contains(
            &gen_f32(PtxOp::OrProbSum),
            &["elementwise_or_prob_sum_f32", "mul.f32", "sub.f32", "add.f32"],
        );
    }

    #[test]
    fn ptx_template_generates_nand_f32() {
        assert_contains(&gen_f32(PtxOp::Nand), &["elementwise_nand_f32", "mul.f32", "sub.f32"]);
    }

    #[test]
    fn ptx_template_generates_nor_f32() {
        assert_contains(&gen_f32(PtxOp::Nor), &["elementwise_nor_f32", "mul.f32", "add.f32"]);
    }

    #[test]
    fn ptx_template_generates_xor_f32() {
        // 0f40000000 is the PTX hex literal for 2.0f32.
        assert_contains(
            &gen_f32(PtxOp::Xor),
            &["elementwise_xor_f32", "mul.f32", "add.f32", "0f40000000"],
        );
    }
}