use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::reduction::{ReductionOp as PtxReductionOp, ReductionTemplate};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
// Threads per block for the reduction kernels. Also caps the supported input
// size for the two-phase scheme in `max`: at most REDUCE_BLOCK_SIZE^2 elements
// can be reduced without a third phase.
const REDUCE_BLOCK_SIZE: u32 = 256;
/// Builds (generates PTX for, loads, and looks up) the max-reduction kernel
/// for the given element precision, targeting the handle's SM version.
///
/// Returns the kernel together with its mangled name so callers can reference
/// it in diagnostics.
///
/// # Errors
/// `PtxGeneration` if template expansion fails; `LaunchFailed` if the module
/// cannot be loaded or the kernel symbol is not found.
fn build_max_kernel(
    handle: &BlasHandle,
    ptx_type: oxicuda_ptx::ir::PtxType,
) -> BlasResult<(Kernel, String)> {
    // Describe the exact kernel variant we need for this precision/target.
    let spec = ReductionTemplate {
        op: PtxReductionOp::Max,
        precision: ptx_type,
        target: handle.sm_version(),
        block_size: REDUCE_BLOCK_SIZE,
    };
    let kernel_name = spec.kernel_name();
    let source = spec
        .generate()
        .map_err(|e| BlasError::PtxGeneration(format!("reduce_max: {e}")))?;
    let module = Module::from_ptx(&source)
        .map_err(|e| BlasError::LaunchFailed(format!("module load for reduce_max: {e}")))?;
    let kernel = Kernel::from_module(Arc::new(module), &kernel_name)
        .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {e}")))?;
    Ok((kernel, kernel_name))
}
/// Computes the maximum of the first `n` elements of `input` on the device,
/// leaving the final value in `result[0]`.
///
/// Uses a two-phase block reduction: phase 1 reduces `input` into one partial
/// maximum per block (written into `result`), and phase 2 — needed only when
/// phase 1 used more than one block — reduces those partials in place.
/// Inputs requiring more than two phases (n > REDUCE_BLOCK_SIZE^2, assuming
/// `grid_size_for` is a ceiling division — TODO confirm) are rejected.
///
/// # Errors
/// - `BufferTooSmall` if `input` holds fewer than `n` elements, or `result`
///   cannot hold the partials phase 1 produces.
/// - `UnsupportedOperation` if the input would need a third reduction phase.
/// - `PtxGeneration` / `LaunchFailed` if kernel build or launch fails.
pub fn max<T: GpuFloat>(
    handle: &BlasHandle,
    n: u32,
    input: &DeviceBuffer<T>,
    result: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    // Nothing to reduce; `result` is left untouched.
    if n == 0 {
        return Ok(());
    }
    let n_usize = n as usize;
    if input.len() < n_usize {
        return Err(BlasError::BufferTooSmall {
            expected: n_usize,
            actual: input.len(),
        });
    }
    if result.is_empty() {
        return Err(BlasError::BufferTooSmall {
            expected: 1,
            actual: 0,
        });
    }
    let num_blocks = grid_size_for(n, REDUCE_BLOCK_SIZE);
    if num_blocks > 1 {
        // Phase 1 writes one partial per block into `result`, so it must be
        // large enough to hold them all.
        let partials_needed = num_blocks as usize;
        if result.len() < partials_needed {
            return Err(BlasError::BufferTooSmall {
                expected: partials_needed,
                actual: result.len(),
            });
        }
        // Validate the phase count up front, BEFORE launching anything, so an
        // unsupported size does not clobber `result` with partials first.
        if grid_size_for(num_blocks, REDUCE_BLOCK_SIZE) != 1 {
            return Err(BlasError::UnsupportedOperation(format!(
                "max: input size {n} requires more than two reduction phases"
            )));
        }
    }
    // Both phases use the identical kernel, so build it exactly once
    // (the original rebuilt it per phase).
    let (kernel, _) = build_max_kernel(handle, T::PTX_TYPE)?;
    // Phase 1: reduce `input` into per-block partials in `result`. When
    // num_blocks == 1 this already yields the final maximum in result[0].
    let params = LaunchParams::new(num_blocks, REDUCE_BLOCK_SIZE);
    let args = (input.as_device_ptr(), result.as_device_ptr(), n);
    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("max phase 1: {e}")))?;
    if num_blocks > 1 {
        // Phase 2: reduce the `num_blocks` partials in place; a single block
        // suffices (guaranteed by the phase-count check above).
        let params2 = LaunchParams::new(1u32, REDUCE_BLOCK_SIZE);
        let args2 = (result.as_device_ptr(), result.as_device_ptr(), num_blocks);
        kernel
            .launch(&params2, handle.stream(), &args2)
            .map_err(|e| BlasError::LaunchFailed(format!("max phase 2: {e}")))?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The reduction template should emit f32 max PTX containing both the
    /// expected mangled kernel name and the max.f32 instruction.
    #[test]
    fn ptx_template_generates_max_f32() {
        let tpl = ReductionTemplate {
            op: PtxReductionOp::Max,
            precision: oxicuda_ptx::ir::PtxType::F32,
            target: oxicuda_ptx::arch::SmVersion::Sm80,
            block_size: 256,
        };
        let ptx = tpl.generate().expect("max PTX generation should succeed");
        for needle in ["reduce_max_f32_bs256", "max.f32"] {
            assert!(ptx.contains(needle), "missing {needle}");
        }
    }
}