use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::reduction::{PerAxisReductionTemplate, ReductionOp as PtxReductionOp};
use super::ops::ReductionOp;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
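
/// Threads per block for the per-axis reduction kernels. Kept a power of two
/// and at least one warp (32 threads); see `axis_block_size_is_valid`.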
const AXIS_BLOCK_SIZE: u32 = 256;
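
/// Generates PTX for the per-axis reduction kernel, loads it into a module,
/// and resolves the kernel entry point. Returns the kernel handle together
/// with its generated name.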
fn build_axis_kernel(
handle: &BlasHandle,
ptx_op: PtxReductionOp,
ptx_type: oxicuda_ptx::ir::PtxType,
) -> BlasResult<(Kernel, String)> {
let template = PerAxisReductionTemplate {
op: ptx_op,
precision: ptx_type,
target: handle.sm_version(),
block_size: AXIS_BLOCK_SIZE,
};
let kernel_name = template.kernel_name();
let ptx_source = template
.generate()
.map_err(|e| BlasError::PtxGeneration(format!("reduce_axis_{}: {e}", ptx_op.as_str())))?;
let module = Arc::new(
Module::from_ptx(&ptx_source)
.map_err(|e| BlasError::LaunchFailed(format!("module load for {kernel_name}: {e}")))?,
);
let kernel = Kernel::from_module(module, &kernel_name)
.map_err(|e| BlasError::LaunchFailed(format!("kernel lookup {kernel_name}: {e}")))?;
Ok((kernel, kernel_name))
}
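
/// Reduces the middle axis of a row-major `[outer, axis_len, inner]` tensor
/// into an `[outer, inner]` output, applying `op` along the axis. One thread
/// block of `AXIS_BLOCK_SIZE` threads is launched per output element.
///
/// # Example
///
/// A minimal sketch of a call site; `BlasHandle::new` and the `DeviceBuffer`
/// constructors below are hypothetical and stand in for whatever setup the
/// surrounding `oxicuda` crates actually provide:
///
/// ```ignore
/// // Sum a [2, 4, 3] f32 tensor over its middle axis, yielding [2, 3].
/// let handle = BlasHandle::new()?;
/// let input: DeviceBuffer<f32> = DeviceBuffer::from_slice(&host_data)?;
/// let mut output: DeviceBuffer<f32> = DeviceBuffer::zeroed(2 * 3)?;
/// reduce_axis(&handle, ReductionOp::Sum, 2, 4, 3, &input, &mut output)?;
/// ```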
pub fn reduce_axis<T: GpuFloat>(
handle: &BlasHandle,
op: ReductionOp,
outer: u32,
axis_len: u32,
inner: u32,
input: &DeviceBuffer<T>,
output: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
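    // An empty reduction axis is treated as a no-op: the output buffer is
    // left untouched and the mean path below never computes 1.0 / 0.0.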
if axis_len == 0 {
return Ok(());
}
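    // Validate buffer lengths with checked arithmetic so oversized shapes
    // surface as InvalidArgument instead of a silent overflow.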
let input_needed = (outer as usize)
.checked_mul(axis_len as usize)
.and_then(|x| x.checked_mul(inner as usize))
.ok_or_else(|| {
BlasError::InvalidArgument(
"reduce_axis: outer * axis_len * inner overflows usize".into(),
)
})?;
let output_needed = (outer as usize)
.checked_mul(inner as usize)
.ok_or_else(|| {
BlasError::InvalidArgument("reduce_axis: outer * inner overflows usize".into())
})?;
if input.len() < input_needed {
return Err(BlasError::BufferTooSmall {
expected: input_needed,
actual: input.len(),
});
}
if output.len() < output_needed {
return Err(BlasError::BufferTooSmall {
expected: output_needed,
actual: output.len(),
});
}
let ptx_op = op.to_ptx_op();
let (kernel, _) = build_axis_kernel(handle, ptx_op, T::PTX_TYPE)?;
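    // One thread block per output element: the grid spans all (outer, inner)
    // pairs, and the kernel receives `axis_len` and `inner` to address its
    // input slice.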
let grid = outer.checked_mul(inner).ok_or_else(|| {
BlasError::InvalidArgument("reduce_axis: outer * inner grid overflows u32".into())
})?;
let params = LaunchParams::new(grid, AXIS_BLOCK_SIZE);
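    // The mean kernel takes a precomputed reciprocal of the axis length as an
    // extra parameter (`param_inv_axis_len` in the generated PTX), so its
    // argument tuple differs from the other reduction ops.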
if op == ReductionOp::Mean {
let inv_axis_len = 1.0_f32 / axis_len as f32;
let args = (
input.as_device_ptr(),
output.as_device_ptr(),
axis_len,
inner,
inv_axis_len,
);
kernel
            .launch(&params, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("reduce_axis mean: {e}")))?;
} else {
let args = (
input.as_device_ptr(),
output.as_device_ptr(),
axis_len,
inner,
);
kernel
            .launch(&params, handle.stream(), &args)
.map_err(|e| BlasError::LaunchFailed(format!("reduce_axis {}: {e}", op.as_str())))?;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use oxicuda_ptx::arch::SmVersion;
use oxicuda_ptx::ir::PtxType;
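
    // These tests exercise PTX template generation and host-side shape
    // arithmetic only; no GPU or driver context is required.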
#[test]
fn axis_block_size_is_valid() {
assert!(AXIS_BLOCK_SIZE.is_power_of_two());
        assert!(AXIS_BLOCK_SIZE >= 32);
}
#[test]
fn ptx_generates_reduce_axis_sum_f32() {
let template = PerAxisReductionTemplate {
op: PtxReductionOp::Sum,
precision: PtxType::F32,
target: SmVersion::Sm80,
block_size: 256,
};
let ptx = template.generate().expect("generate must not fail");
assert!(ptx.contains("reduce_axis_sum_f32_bs256"));
assert!(ptx.contains("param_axis_len"));
assert!(ptx.contains("param_inner"));
}
#[test]
fn ptx_generates_reduce_axis_mean_f32() {
let template = PerAxisReductionTemplate {
op: PtxReductionOp::Mean,
precision: PtxType::F32,
target: SmVersion::Sm80,
block_size: 256,
};
let ptx = template.generate().expect("generate must not fail");
assert!(ptx.contains("param_inv_axis_len"));
}
#[test]
fn overflow_protection_outer_axis_inner() {
let big: u32 = u32::MAX;
let result = (big as usize)
.checked_mul(big as usize)
.and_then(|x| x.checked_mul(big as usize));
        assert!(
            result.is_none(),
            "3-way u32::MAX multiply must overflow usize"
        );
}
}