use std::sync::Arc;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::{ir::PtxType, templates::softmax::SoftmaxTemplate};
use crate::context::CudaContext;
use crate::error::CudaDispatchError;
/// Computes a softmax over the last dimension of `data` by generating a PTX
/// kernel from [`SoftmaxTemplate`], launching it, and copying the result back.
///
/// * `ctx`   - context supplying the SM version and the stream (via `ctx.dnn`).
/// * `data`  - flat `f32` input; its length must equal the product of `shape`.
/// * `shape` - tensor dimensions; the last entry is the row length.
///
/// Returns `Ok(Some(out))` with `out.len() == data.len()` when the kernel ran.
/// Returns `Ok(None)` when this path declines the request:
///
/// * `shape` is empty,
/// * the last dimension is larger than 1024,
/// * the element count implied by `shape` overflows, is zero, or does not
///   match `data.len()`,
/// * the number of rows does not fit in a `u32`.
///
/// NOTE(review): `Ok(None)` is presumably the caller's cue to use another
/// implementation — confirm against the call site.
///
/// # Errors
///
/// * [`CudaDispatchError::Ptx`] if PTX generation fails.
/// * [`CudaDispatchError::Driver`] if module loading, kernel lookup, the
///   launch, or the stream synchronisation fails.
/// * Whatever the buffer allocation / copy errors convert into via `?`.
pub fn cuda_softmax(
    ctx: &CudaContext,
    data: &[f32],
    shape: &[usize],
) -> Result<Option<Vec<f32>>, CudaDispatchError> {
    // Longest row this path accepts; anything larger is declined.
    const MAX_ROW_SIZE: usize = 1024;

    let (last_dim, leading_dims) = match shape.split_last() {
        Some((&last, leading)) => (last, leading),
        None => return Ok(None),
    };

    // Compare in `usize` *before* narrowing: casting first would let a huge
    // dimension wrap into the accepted range.
    if last_dim > MAX_ROW_SIZE {
        return Ok(None);
    }
    // Lossless: `last_dim <= 1024`.
    let row_size = last_dim as u32;

    // Row count and total element count, both with overflow detection.
    let rows = match leading_dims
        .iter()
        .try_fold(1_usize, |acc, &dim| acc.checked_mul(dim))
    {
        Some(rows) => rows,
        None => return Ok(None),
    };
    let expected_len = match rows.checked_mul(last_dim) {
        Some(len) => len,
        None => return Ok(None),
    };

    // The kernel is handed `rows` and a compiled-in `row_size`, so the buffers
    // must hold exactly `rows * row_size` elements. An empty tensor is declined
    // as well: there is nothing to launch (a zero-sized grid is expected to be
    // rejected by the driver — TODO confirm).
    let n = data.len();
    if n == 0 || n != expected_len {
        return Ok(None);
    }

    // The row count is passed to the launch API and the kernel as a `u32`.
    if rows > u32::MAX as usize {
        return Ok(None);
    }
    // Lossless: checked against `u32::MAX` just above.
    let batch_size = rows as u32;

    // Generate and load a kernel specialised for this row length and target.
    let template = SoftmaxTemplate {
        precision: PtxType::F32,
        target: ctx.dnn.sm_version(),
        row_size,
    };
    let kernel_name = template.kernel_name();
    let ptx = template
        .generate()
        .map_err(|e| CudaDispatchError::Ptx(e.to_string()))?;
    let module = Arc::new(Module::from_ptx(&ptx).map_err(CudaDispatchError::Driver)?);
    let kernel = Kernel::from_module(module, &kernel_name).map_err(CudaDispatchError::Driver)?;

    // Stage the input on the device and reserve space for the output.
    let mut d_input: DeviceBuffer<f32> = DeviceBuffer::alloc(n)?;
    d_input.copy_from_host(data)?;
    let d_output: DeviceBuffer<f32> = DeviceBuffer::alloc(n)?;

    // Block size: the row length rounded up to a power of two, kept within
    // [32, 256]. Same values as the original `<= 32` / `min(256)` branches.
    let block_threads = row_size.next_power_of_two().clamp(32, 256);
    let params = LaunchParams::new(Dim3::from(batch_size), Dim3::from(block_threads));

    let stream = ctx.dnn.stream();
    let args = (
        d_input.as_device_ptr(),
        d_output.as_device_ptr(),
        batch_size,
    );
    kernel
        .launch(&params, stream, &args)
        .map_err(CudaDispatchError::Driver)?;
    // Wait for the stream before reading the output back to the host.
    stream.synchronize().map_err(CudaDispatchError::Driver)?;

    let mut out = vec![0.0_f32; n];
    d_output.copy_to_host(&mut out)?;
    Ok(Some(out))
}