use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::elementwise::{ElementwiseOp as PtxElementwiseOp, ElementwiseTemplate};
use oxicuda_ptx::templates::reduction::{ReductionOp as PtxReductionOp, ReductionTemplate};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
/// Thread-block size used by both reduction phases; also fixes how many
/// phase-1 partial sums `mean` stores into its `result` scratch buffer.
const REDUCE_BLOCK_SIZE: u32 = 256;
/// Encodes `1 / n` as the raw bit pattern the scale kernel consumes.
///
/// The bits are widened to `u64` so one argument slot covers both
/// precisions: `f32` occupies the low 32 bits, `f64` the full word.
/// Any other precision yields `BlasError::UnsupportedOperation`.
fn compute_inv_n<T: GpuFloat>(n: u32) -> BlasResult<u64> {
    use oxicuda_ptx::ir::PtxType;

    let reciprocal = 1.0_f64 / f64::from(n);
    match T::PTX_TYPE {
        // Round down to f32 first, then widen the 32-bit pattern.
        PtxType::F32 => Ok(u64::from((reciprocal as f32).to_bits())),
        PtxType::F64 => Ok(reciprocal.to_bits()),
        unsupported => Err(BlasError::UnsupportedOperation(format!(
            "mean: unsupported precision {} for scalar division",
            unsupported.as_ptx_str()
        ))),
    }
}
/// Builds the block-wise sum-reduction kernel used by `mean`, returning
/// the loaded kernel together with its generated entry-point name.
fn build_sum_kernel(
    handle: &BlasHandle,
    ptx_type: oxicuda_ptx::ir::PtxType,
) -> BlasResult<(Kernel, String)> {
    // Describe the reduction we want; the template derives the name.
    let template = ReductionTemplate {
        op: PtxReductionOp::Sum,
        precision: ptx_type,
        target: handle.sm_version(),
        block_size: REDUCE_BLOCK_SIZE,
    };
    let kernel_name = template.kernel_name();

    // Generate PTX, load it as a driver module, resolve the entry point.
    let ptx_source = match template.generate() {
        Ok(src) => src,
        Err(e) => {
            return Err(BlasError::PtxGeneration(format!(
                "reduce_sum (for mean): {e}"
            )))
        }
    };
    let module = Module::from_ptx(&ptx_source)
        .map_err(|e| BlasError::LaunchFailed(format!("module load for mean/sum: {e}")))?;
    let kernel = Kernel::from_module(Arc::new(module), &kernel_name)
        .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {e}")))?;

    Ok((kernel, kernel_name))
}
/// Builds the elementwise scale kernel `mean` uses for its final
/// multiply-by-`1/n` step, returning the kernel and its entry name.
fn build_scale_kernel(
    handle: &BlasHandle,
    ptx_type: oxicuda_ptx::ir::PtxType,
) -> BlasResult<(Kernel, String)> {
    let template =
        ElementwiseTemplate::new(PtxElementwiseOp::Scale, ptx_type, handle.sm_version());
    let kernel_name = template.kernel_name();

    // Generate PTX, load it as a driver module, resolve the entry point.
    let ptx_source = match template.generate() {
        Ok(src) => src,
        Err(e) => {
            return Err(BlasError::PtxGeneration(format!(
                "scale (for mean): {e}"
            )))
        }
    };
    let module = Module::from_ptx(&ptx_source)
        .map_err(|e| BlasError::LaunchFailed(format!("module load for mean/scale: {e}")))?;
    let kernel = Kernel::from_module(Arc::new(module), &kernel_name)
        .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {e}")))?;

    Ok((kernel, kernel_name))
}
/// Computes the arithmetic mean of the first `n` elements of `input`,
/// writing the scalar result into `result[0]`.
///
/// `result` doubles as scratch space for the phase-1 per-block partial
/// sums, so it must hold at least `ceil(n / REDUCE_BLOCK_SIZE)` elements.
///
/// # Errors
/// - `InvalidArgument` when `n == 0`.
/// - `BufferTooSmall` when `input` or `result` is too short.
/// - `UnsupportedOperation` when `n` would need more than two reduction
///   phases, or the precision has no scalar-division encoding.
/// - `PtxGeneration` / `LaunchFailed` on kernel build or launch failure.
pub fn mean<T: GpuFloat>(
    handle: &BlasHandle,
    n: u32,
    input: &DeviceBuffer<T>,
    result: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if n == 0 {
        return Err(BlasError::InvalidArgument(
            "mean requires n > 0".to_string(),
        ));
    }
    let n_usize = n as usize;
    if input.len() < n_usize {
        return Err(BlasError::BufferTooSmall {
            expected: n_usize,
            actual: input.len(),
        });
    }
    if result.is_empty() {
        return Err(BlasError::BufferTooSmall {
            expected: 1,
            actual: 0,
        });
    }
    // `result` must also hold one partial sum per phase-1 block.
    let num_blocks = grid_size_for(n, REDUCE_BLOCK_SIZE);
    let partials_needed = num_blocks as usize;
    if result.len() < partials_needed {
        return Err(BlasError::BufferTooSmall {
            expected: partials_needed,
            actual: result.len(),
        });
    }

    // Phase 1: per-block partial sums of `input` written into `result`.
    let (sum_kernel, _) = build_sum_kernel(handle, T::PTX_TYPE)?;
    let params1 = LaunchParams::new(num_blocks, REDUCE_BLOCK_SIZE);
    let args1 = (input.as_device_ptr(), result.as_device_ptr(), n);
    sum_kernel
        // BUGFIX: was `¶ms1` (HTML-entity mojibake of `&params1`),
        // which is not a valid Rust token.
        .launch(&params1, handle.stream(), &args1)
        .map_err(|e| BlasError::LaunchFailed(format!("mean/sum phase 1: {e}")))?;

    // Phase 2 (only when phase 1 produced more than one partial):
    // reduce the partials in `result` down to a single scalar in-place.
    if num_blocks > 1 {
        let phase2_blocks = grid_size_for(num_blocks, REDUCE_BLOCK_SIZE);
        if phase2_blocks > 1 {
            return Err(BlasError::UnsupportedOperation(format!(
                "mean: input size {n} requires more than two reduction phases"
            )));
        }
        // Reuse the already-built sum kernel instead of regenerating the
        // PTX and reloading the module a second time.
        let params2 = LaunchParams::new(1u32, REDUCE_BLOCK_SIZE);
        let args2 = (result.as_device_ptr(), result.as_device_ptr(), num_blocks);
        sum_kernel
            .launch(&params2, handle.stream(), &args2)
            .map_err(|e| BlasError::LaunchFailed(format!("mean/sum phase 2: {e}")))?;
    }

    // Final step: scale the single sum in `result[0]` by 1/n (passed as
    // a raw bit pattern matching the kernel's precision).
    let inv_n_bits = compute_inv_n::<T>(n)?;
    let (scale_kernel, _) = build_scale_kernel(handle, T::PTX_TYPE)?;
    let scale_params = LaunchParams::new(1u32, REDUCE_BLOCK_SIZE);
    let args3 = (
        result.as_device_ptr(),
        result.as_device_ptr(),
        inv_n_bits,
        1u32,
    );
    scale_kernel
        .launch(&scale_params, handle.stream(), &args3)
        .map_err(|e| BlasError::LaunchFailed(format!("mean/scale: {e}")))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the sum-reduction template for f32 generates PTX that
    // contains the expected mangled kernel name.
    #[test]
    fn ptx_template_generates_sum_for_mean() {
        let template = ReductionTemplate {
            op: PtxReductionOp::Sum,
            precision: oxicuda_ptx::ir::PtxType::F32,
            target: oxicuda_ptx::arch::SmVersion::Sm80,
            block_size: 256,
        };
        let ptx = template
            .generate()
            .expect("sum PTX generation should succeed");
        assert!(ptx.contains("reduce_sum_f32_bs256"));
    }

    // Smoke test: the elementwise scale template for f32 generates PTX
    // that contains the expected mangled kernel name.
    #[test]
    fn ptx_template_generates_scale_for_mean() {
        let template = ElementwiseTemplate::new(
            PtxElementwiseOp::Scale,
            oxicuda_ptx::ir::PtxType::F32,
            oxicuda_ptx::arch::SmVersion::Sm80,
        );
        let ptx = template
            .generate()
            .expect("scale PTX generation should succeed");
        assert!(ptx.contains("elementwise_scale_f32"));
    }

    // NOTE(review): this recomputes 1/n locally rather than calling
    // `compute_inv_n::<f32>`, so it only checks the arithmetic, not the
    // function's bit-encoding path — consider calling the real helper.
    #[test]
    fn inv_n_computation_f32() {
        let n: u32 = 100;
        let inv: f32 = 1.0 / (n as f32);
        assert!((inv - 0.01).abs() < 1e-6);
    }

    // NOTE(review): same caveat as above for the f64 path.
    #[test]
    fn inv_n_computation_f64() {
        let n: u32 = 100;
        let inv: f64 = 1.0 / (n as f64);
        assert!((inv - 0.01).abs() < 1e-12);
    }
}