use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::ir::PtxType;
use oxicuda_ptx::templates::softmax::{
MULTI_BLOCK_DEFAULT_STRIDE, MULTI_BLOCK_THREADS, MultiBlockSoftmaxPtx, SoftmaxTemplate,
generate_multi_block_softmax_ptx,
};
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
/// Compiles the single-pass softmax kernel for `ptx_type` elements and rows of
/// `row_size` elements, targeting the SM version reported by `handle`.
///
/// Returns the loaded [`Kernel`] together with its entry-point name.
///
/// # Errors
///
/// * `BlasError::PtxGeneration` if the template rejects the configuration.
/// * `BlasError::LaunchFailed` if the PTX module fails to load or the entry
///   point cannot be found in it.
fn build_softmax_kernel(
    handle: &BlasHandle,
    ptx_type: oxicuda_ptx::ir::PtxType,
    row_size: u32,
) -> BlasResult<(Kernel, String)> {
    let target = handle.sm_version();
    let template = SoftmaxTemplate {
        precision: ptx_type,
        target,
        row_size,
    };
    let name = template.kernel_name();

    let ptx = match template.generate() {
        Ok(source) => source,
        Err(e) => {
            return Err(BlasError::PtxGeneration(format!(
                "softmax (row_size={row_size}): {e}"
            )));
        }
    };

    // Kernels keep the module alive through a shared handle.
    let module = Module::from_ptx(&ptx)
        .map(Arc::new)
        .map_err(|e| BlasError::LaunchFailed(format!("module load for softmax: {e}")))?;

    let kernel = Kernel::from_module(module, &name).map_err(|e| {
        BlasError::LaunchFailed(format!("kernel lookup for {name}: {e}"))
    })?;
    Ok((kernel, name))
}
/// Computes a row-wise softmax over a `rows x cols` matrix on the GPU.
///
/// `input` and `output` must each hold at least `rows * cols` elements
/// (assumed to be densely packed rows of `cols` elements — confirm against the
/// generated kernels).
///
/// The kernel strategy is chosen from `cols`:
/// * `cols <= 32`: 256-thread blocks with one warp per row (8 rows per block).
/// * `33..=1024`: one block per row, with the block size rounded up to the next
///   power of two.
/// * `cols > 1024`: a two-pass multi-block path, currently `f32` only.
///
/// The single-block paths only enqueue work on `handle.stream()`. The
/// multi-block path synchronizes the stream before returning.
///
/// # Errors
///
/// * `InvalidDimension` if `rows` or `cols` is zero, or `rows * cols` does not
///   fit in `usize`.
/// * `BufferTooSmall` if `input` or `output` is shorter than `rows * cols`.
/// * `UnsupportedOperation` if `cols > 1024` and `T` is not `f32`.
/// * `PtxGeneration` / `LaunchFailed` / `Cuda` from kernel generation, loading,
///   launch or synchronization.
pub fn softmax<T: GpuFloat>(
    handle: &BlasHandle,
    rows: u32,
    cols: u32,
    input: &DeviceBuffer<T>,
    output: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if rows == 0 || cols == 0 {
        return Err(BlasError::InvalidDimension(
            "softmax requires rows > 0 and cols > 0".to_string(),
        ));
    }
    // `u32 * u32` can exceed a 32-bit `usize`. A wrapped product would let
    // undersized buffers slip past the length checks below.
    let total_elements = (rows as usize).checked_mul(cols as usize).ok_or_else(|| {
        BlasError::InvalidDimension(format!(
            "softmax size overflow: rows={rows} * cols={cols} exceeds usize"
        ))
    })?;
    if input.len() < total_elements {
        return Err(BlasError::BufferTooSmall {
            expected: total_elements,
            actual: input.len(),
        });
    }
    if output.len() < total_elements {
        return Err(BlasError::BufferTooSmall {
            expected: total_elements,
            actual: output.len(),
        });
    }
    if cols > 1024 {
        if !matches!(T::PTX_TYPE, PtxType::F32) {
            return Err(BlasError::UnsupportedOperation(format!(
                "multi-block softmax (cols > 1024) currently supports only f32, \
                 got {}",
                T::PTX_TYPE.as_ptx_str()
            )));
        }
        return softmax_multi_block(handle, rows, cols, input, output);
    }
    let (kernel, _) = build_softmax_kernel(handle, T::PTX_TYPE, cols)?;
    let (grid, block) = if cols <= 32 {
        // Warp-per-row: each 256-thread block covers 8 rows.
        let block_size: u32 = 256;
        let warps_per_block = block_size / 32;
        let num_blocks = grid_size_for(rows, warps_per_block);
        (num_blocks, block_size)
    } else {
        // Block-per-row: cols <= 1024 here, so the block never exceeds 1024 threads.
        let block_size = cols.next_power_of_two();
        (rows, block_size)
    };
    let params = LaunchParams::new(grid, block);
    let args = (input.as_device_ptr(), output.as_device_ptr(), rows);
    kernel
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("softmax: {e}")))?;
    Ok(())
}
/// Two-pass softmax for rows too wide for the single-block kernel (`cols > 1024`).
///
/// The reduce pass runs on a `num_blocks_per_row x rows` grid. It receives
/// `input` and a temporary `f32` scratch buffer sized for two values per
/// (row, block) pair. The finalize pass runs one block per row and receives
/// `input`, `output` and that scratch buffer. Only `f32` PTX is generated, so
/// the caller must already have rejected other element types.
///
/// The stream is synchronized before returning, on success and on launch
/// failure alike. This ensures the scratch buffer is not freed while a
/// previously enqueued kernel may still reference it.
///
/// # Errors
///
/// * `InvalidDimension` if `rows` exceeds the CUDA grid y-dimension limit
///   (65 535) or the scratch size overflows.
/// * `PtxGeneration` / `LaunchFailed` if PTX generation, module loading or a
///   kernel launch fails.
/// * `Cuda` if scratch allocation or stream synchronization fails.
fn softmax_multi_block<T: GpuFloat>(
    handle: &BlasHandle,
    rows: u32,
    cols: u32,
    input: &DeviceBuffer<T>,
    output: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    // The reduce grid puts rows on the y dimension, which CUDA caps at 65 535.
    // Fail fast with a descriptive error rather than generating PTX and
    // allocating scratch only to have the launch rejected.
    const MAX_GRID_DIM_Y: u32 = 65_535;
    if rows > MAX_GRID_DIM_Y {
        return Err(BlasError::InvalidDimension(format!(
            "multi-block softmax supports at most {max} rows \
             (CUDA grid y-dimension limit), got rows={rows}",
            max = MAX_GRID_DIM_Y
        )));
    }

    let plan: MultiBlockSoftmaxPtx = generate_multi_block_softmax_ptx(
        cols,
        MULTI_BLOCK_DEFAULT_STRIDE,
        MULTI_BLOCK_THREADS,
        PtxType::F32,
        handle.sm_version(),
    )
    .map_err(|e| {
        BlasError::PtxGeneration(format!(
            "multi-block softmax (rows={rows}, cols={cols}): {e}"
        ))
    })?;
    let reduce_kernel = build_kernel_from_ptx(&plan.reduce_ptx, &plan.reduce_kernel_name())?;
    let finalize_kernel = build_kernel_from_ptx(&plan.finalize_ptx, &plan.finalize_kernel_name())?;

    // Two f32 slots per (row, block) pair. The arithmetic is checked so that
    // huge inputs produce an error instead of an undersized allocation.
    let scratch_pairs = rows.checked_mul(plan.num_blocks_per_row).ok_or_else(|| {
        BlasError::InvalidDimension(format!(
            "softmax scratch overflow: rows={rows} * num_blocks_per_row={}",
            plan.num_blocks_per_row
        ))
    })?;
    let scratch_floats = (scratch_pairs as usize).checked_mul(2).ok_or_else(|| {
        BlasError::InvalidDimension(format!(
            "softmax scratch overflow: pairs={scratch_pairs} * 2 floats"
        ))
    })?;
    let scratch = DeviceBuffer::<f32>::alloc(scratch_floats).map_err(BlasError::Cuda)?;

    let reduce_params = LaunchParams::new(
        Dim3::xy(plan.num_blocks_per_row, rows),
        Dim3::x(plan.threads_per_block),
    );
    let reduce_args = (input.as_device_ptr(), scratch.as_device_ptr(), rows);
    let finalize_params = LaunchParams::new(Dim3::x(rows), Dim3::x(plan.threads_per_block));
    let finalize_args = (
        input.as_device_ptr(),
        output.as_device_ptr(),
        scratch.as_device_ptr(),
        rows,
    );

    // Enqueue both passes, but hold on to any launch error instead of
    // returning immediately. If the finalize launch fails, the reduce kernel
    // may still be in flight using `scratch`, so the stream must be drained
    // before the buffer is freed.
    let launched = reduce_kernel
        .launch(&reduce_params, handle.stream(), &reduce_args)
        .map_err(|e| BlasError::LaunchFailed(format!("softmax multi-block reduce: {e}")))
        .and_then(|_| {
            finalize_kernel
                .launch(&finalize_params, handle.stream(), &finalize_args)
                .map_err(|e| {
                    BlasError::LaunchFailed(format!("softmax multi-block finalize: {e}"))
                })
        });
    let synced = handle.stream().synchronize().map_err(BlasError::Cuda);
    // Scratch is released only after every kernel that references it is done.
    drop(scratch);

    // A launch failure is more informative than the sync error it may cause,
    // so report it first.
    launched?;
    synced?;
    Ok(())
}
/// Loads `ptx_source` as a CUDA module and resolves the entry point `kernel_name`.
///
/// # Errors
///
/// Returns `BlasError::LaunchFailed` if the module fails to load or the named
/// kernel is not present in it. Both messages include `kernel_name`.
fn build_kernel_from_ptx(ptx_source: &str, kernel_name: &str) -> BlasResult<Kernel> {
    let module = match Module::from_ptx(ptx_source) {
        Ok(loaded) => Arc::new(loaded),
        Err(e) => {
            return Err(BlasError::LaunchFailed(format!(
                "module load for {kernel_name}: {e}"
            )));
        }
    };
    Kernel::from_module(module, kernel_name).map_err(|err| {
        BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {err}"))
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::ir::PtxType;
    use oxicuda_ptx::templates::softmax::SoftmaxTemplate;

    /// Single-pass f32 softmax template targeting sm_80 for the given row width.
    fn f32_sm80_template(row_size: u32) -> SoftmaxTemplate {
        SoftmaxTemplate {
            precision: PtxType::F32,
            target: SmVersion::Sm80,
            row_size,
        }
    }

    /// Multi-block f32 plan on sm_80 using the default stride and thread count.
    fn f32_sm80_plan(cols: u32) -> MultiBlockSoftmaxPtx {
        generate_multi_block_softmax_ptx(
            cols,
            MULTI_BLOCK_DEFAULT_STRIDE,
            MULTI_BLOCK_THREADS,
            PtxType::F32,
            SmVersion::Sm80,
        )
        .expect("multi-block softmax PTX should generate")
    }

    #[test]
    fn ptx_template_generates_softmax_warp_f32() {
        let template = f32_sm80_template(32);
        let ptx = template
            .generate()
            .expect("warp softmax PTX should generate");
        assert!(ptx.contains("softmax_f32_r32"));
        // Warp-level rows are expected to reduce with shuffles.
        assert!(ptx.contains("shfl.sync"));
    }

    #[test]
    fn ptx_template_generates_softmax_block_f32() {
        let template = f32_sm80_template(128);
        let ptx = template
            .generate()
            .expect("block softmax PTX should generate");
        assert!(ptx.contains("softmax_f32_r128"));
    }

    #[test]
    fn ptx_template_rejects_large_row_size() {
        // Rows wider than 1024 belong to the multi-block path.
        let template = f32_sm80_template(2048);
        let generated = template.generate();
        assert!(generated.is_err());
    }

    #[test]
    fn warp_launch_config() {
        const THREADS_PER_BLOCK: u32 = 256;
        let rows_per_block = THREADS_PER_BLOCK / 32;
        assert_eq!(grid_size_for(100, rows_per_block), 13);
    }

    #[test]
    fn block_launch_config() {
        assert_eq!(100u32.next_power_of_two(), 128);
    }

    #[test]
    fn softmax_warp_small_row() {
        let template = f32_sm80_template(8);
        let ptx = template
            .generate()
            .expect("small warp softmax should generate");
        assert!(ptx.contains("softmax_f32_r8"));
    }

    #[test]
    fn multi_block_dispatch_layout_2048() {
        let plan = f32_sm80_plan(2048);
        assert_eq!(plan.num_blocks_per_row, 2);
        assert_eq!(plan.threads_per_block, MULTI_BLOCK_THREADS);

        let reduce_entry = format!(".entry {}", plan.reduce_kernel_name());
        let finalize_entry = format!(".entry {}", plan.finalize_kernel_name());
        assert!(plan.reduce_ptx.contains(&reduce_entry));
        assert!(plan.finalize_ptx.contains(&finalize_entry));
    }

    #[test]
    fn multi_block_dispatch_scratch_for_8192() {
        let plan = f32_sm80_plan(8192);
        assert_eq!(plan.num_blocks_per_row, 8);
        // 8 blocks * (2 f32 per block) * 4 bytes.
        assert_eq!(plan.scratch_bytes_per_row, 8 * 2 * 4);

        let rows: u32 = 16;
        let scratch_floats = (rows * plan.num_blocks_per_row) as usize * 2;
        assert_eq!(scratch_floats, 16 * 8 * 2);
    }

    #[test]
    fn multi_block_rejects_non_f32_dtype_in_template() {
        let result = generate_multi_block_softmax_ptx(
            2048,
            MULTI_BLOCK_DEFAULT_STRIDE,
            MULTI_BLOCK_THREADS,
            PtxType::F64,
            SmVersion::Sm80,
        );
        assert!(result.is_err());
    }
}