use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Kernel, LaunchParams, grid_size_for};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::templates::softmax::SoftmaxTemplate;
use crate::error::{BlasError, BlasResult};
use crate::handle::BlasHandle;
use crate::types::GpuFloat;
/// Generates softmax PTX for the requested precision and row size, loads it
/// into a module, and resolves the kernel entry point.
///
/// Returns the ready-to-launch kernel together with its generated name.
///
/// # Errors
/// - `PtxGeneration` when the template cannot produce PTX for `row_size`.
/// - `LaunchFailed` when module loading or the kernel lookup fails.
fn build_softmax_kernel(
    handle: &BlasHandle,
    ptx_type: oxicuda_ptx::ir::PtxType,
    row_size: u32,
) -> BlasResult<(Kernel, String)> {
    // The template specializes the kernel for the target SM and row size,
    // so each distinct `row_size` yields its own compiled entry point.
    let tmpl = SoftmaxTemplate {
        precision: ptx_type,
        target: handle.sm_version(),
        row_size,
    };
    let kernel_name = tmpl.kernel_name();

    let ptx = tmpl.generate().map_err(|e| {
        BlasError::PtxGeneration(format!("softmax (row_size={row_size}): {e}"))
    })?;

    let module = Module::from_ptx(&ptx)
        .map_err(|e| BlasError::LaunchFailed(format!("module load for softmax: {e}")))?;

    // The kernel keeps the module alive via the shared Arc.
    let kernel = Kernel::from_module(Arc::new(module), &kernel_name)
        .map_err(|e| BlasError::LaunchFailed(format!("kernel lookup for {kernel_name}: {e}")))?;

    Ok((kernel, kernel_name))
}
/// Computes a row-wise softmax over a `rows x cols` matrix stored row-major
/// in `input`, writing the result to `output`.
///
/// A specialized kernel is generated per `cols` value: rows of at most 32
/// elements are reduced warp-per-row (several rows per block), while wider
/// rows get one block per row sized to the next power of two.
///
/// # Errors
/// - `InvalidDimension` if `rows` or `cols` is zero.
/// - `UnsupportedOperation` if `cols > 1024` (multi-block softmax is not
///   implemented yet).
/// - `BufferTooSmall` if either buffer holds fewer than `rows * cols`
///   elements.
/// - `PtxGeneration` / `LaunchFailed` if kernel build or launch fails.
pub fn softmax<T: GpuFloat>(
    handle: &BlasHandle,
    rows: u32,
    cols: u32,
    input: &DeviceBuffer<T>,
    output: &mut DeviceBuffer<T>,
) -> BlasResult<()> {
    if rows == 0 || cols == 0 {
        return Err(BlasError::InvalidDimension(
            "softmax requires rows > 0 and cols > 0".to_string(),
        ));
    }
    // Single-block reduction caps the supported row width at 1024 lanes.
    if cols > 1024 {
        return Err(BlasError::UnsupportedOperation(format!(
            "softmax: cols={cols} exceeds the current limit of 1024; \
            multi-block softmax not yet implemented"
        )));
    }
    let total_elements = rows as usize * cols as usize;
    if input.len() < total_elements {
        return Err(BlasError::BufferTooSmall {
            expected: total_elements,
            actual: input.len(),
        });
    }
    if output.len() < total_elements {
        return Err(BlasError::BufferTooSmall {
            expected: total_elements,
            actual: output.len(),
        });
    }
    // `cols` is baked into the kernel as the row size, so only `rows` is
    // passed as a runtime argument below.
    let (kernel, _) = build_softmax_kernel(handle, T::PTX_TYPE, cols)?;
    let (grid, block) = if cols <= 32 {
        // Warp-per-row: 256 threads = 8 warps, so each block handles 8 rows.
        let block_size: u32 = 256;
        let warps_per_block = block_size / 32;
        let num_blocks = grid_size_for(rows, warps_per_block);
        (num_blocks, block_size)
    } else {
        // Block-per-row: round the row width up to a power of two so the
        // tree reduction inside the kernel stays simple.
        let block_size = cols.next_power_of_two();
        (rows, block_size)
    };
    let params = LaunchParams::new(grid, block);
    let args = (input.as_device_ptr(), output.as_device_ptr(), rows);
    kernel
        // Fixed: this previously read `¶ms` (HTML-entity mojibake for
        // `&params`), which is not valid Rust and broke compilation.
        .launch(&params, handle.stream(), &args)
        .map_err(|e| BlasError::LaunchFailed(format!("softmax: {e}")))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use oxicuda_ptx::arch::SmVersion;
    use oxicuda_ptx::ir::PtxType;
    use oxicuda_ptx::templates::softmax::SoftmaxTemplate;

    /// Builds an f32 / SM80 softmax template for the given row size.
    fn f32_sm80_template(row_size: u32) -> SoftmaxTemplate {
        SoftmaxTemplate {
            precision: PtxType::F32,
            target: SmVersion::Sm80,
            row_size,
        }
    }

    #[test]
    fn ptx_template_generates_softmax_warp_f32() {
        // row_size <= 32 should take the warp-shuffle reduction path.
        let ptx = f32_sm80_template(32)
            .generate()
            .expect("warp softmax PTX should generate");
        assert!(ptx.contains("softmax_f32_r32"));
        assert!(ptx.contains("shfl.sync"));
    }

    #[test]
    fn ptx_template_generates_softmax_block_f32() {
        // Wider rows use the block-level variant.
        let ptx = f32_sm80_template(128)
            .generate()
            .expect("block softmax PTX should generate");
        assert!(ptx.contains("softmax_f32_r128"));
    }

    #[test]
    fn ptx_template_rejects_large_row_size() {
        // 2048 exceeds the single-block limit, so generation must fail.
        assert!(f32_sm80_template(2048).generate().is_err());
    }

    #[test]
    fn warp_launch_config() {
        // 256 threads = 8 warps per block; 100 rows need ceil(100 / 8) = 13 blocks.
        let warps_per_block = 256u32 / 32;
        assert_eq!(grid_size_for(100, warps_per_block), 13);
    }

    #[test]
    fn block_launch_config() {
        // Block size for a 100-wide row rounds up to the next power of two.
        assert_eq!(100u32.next_power_of_two(), 128);
    }

    #[test]
    fn softmax_warp_small_row() {
        // Row sizes well below a warp still generate a valid warp kernel.
        let ptx = f32_sm80_template(8)
            .generate()
            .expect("small warp softmax should generate");
        assert!(ptx.contains("softmax_f32_r8"));
    }
}