#![allow(unused_imports)]
#[cfg(feature = "cuda")]
use trueno_gpu::driver::{CudaContext, CudaModule, CudaStream, GpuBuffer, LaunchConfig};
#[cfg(feature = "cuda")]
use trueno_gpu::kernels::{BatchedSoftmaxKernel, Kernel};
/// CPU reference softmax, numerically stabilized by shifting by the row max
/// before exponentiation (so `exp` never overflows for finite inputs).
///
/// Returns a vector the same length as `input`; for non-empty input the
/// entries are positive and sum to ~1.0. An empty slice yields an empty vec.
fn cpu_softmax(input: &[f32]) -> Vec<f32> {
    // Max-shift: softmax(x) == softmax(x - c) mathematically, but the
    // shifted form is stable in f32.
    let max_val = input.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let mut probs: Vec<f32> = input.iter().map(|&x| (x - max_val).exp()).collect();
    let total: f32 = probs.iter().sum();
    // Normalize in place so the exponentials become probabilities.
    for p in &mut probs {
        *p /= total;
    }
    probs
}
#[test]
fn test_cpu_softmax_sanity() {
    // Sanity-check the CPU reference itself: the probabilities must sum to
    // 1, each lie strictly in (0, 1), and preserve the input ordering.
    let output = cpu_softmax(&[1.0, 2.0, 3.0, 4.0]);

    let total: f32 = output.iter().sum();
    assert!(
        (total - 1.0).abs() < 1e-6,
        "Softmax sum should be 1.0, got {}",
        total
    );

    for (i, &v) in output.iter().enumerate() {
        assert!(
            v > 0.0 && v < 1.0,
            "Softmax[{}] = {} should be in (0, 1)",
            i,
            v
        );
    }

    // Monotonicity: each adjacent pair must be strictly increasing, since
    // the inputs were strictly increasing.
    assert!(output.windows(2).all(|pair| pair[1] > pair[0]));
}
#[test]
#[cfg(feature = "cuda")]
// GPU vs CPU comparison for a single 32-element row — the row exactly fills
// one 32-thread block, so this is the simplest end-to-end kernel path.
fn test_batched_softmax_short_row() {
// Skip silently (pass) when no CUDA device/driver is available, e.g. in CI.
let ctx = match CudaContext::new(0) {
Ok(ctx) => ctx,
Err(_) => return, };
let row_size = 32u32;
let total_rows = 1u32;
// Inputs 1.0..=32.0 — strictly increasing, so the softmax output is too.
let input: Vec<f32> = (1..=row_size).map(|x| x as f32).collect();
let expected = cpu_softmax(&input);
let input_buf = GpuBuffer::from_host(&ctx, &input).expect("Upload failed");
let output_buf: GpuBuffer<f32> = GpuBuffer::new(&ctx, row_size as usize).expect("Alloc failed");
// The kernel is generated per (rows, row_size) configuration and compiled
// from PTX at test time.
let kernel = BatchedSoftmaxKernel::new(total_rows, row_size);
let ptx = kernel.emit_ptx();
let mut module = CudaModule::from_ptx(&ctx, &ptx).expect("Compile failed");
let stream = CudaStream::new(&ctx).expect("Stream failed");
// One block per row, 32 threads (one warp) per block.
// NOTE(review): shared_mem = 72 bytes presumably matches the kernel's
// reduction scratch requirement — confirm against BatchedSoftmaxKernel.
let config = LaunchConfig {
grid: (total_rows, 1, 1),
block: (32, 1, 1),
shared_mem: 72,
};
// Kernel arguments are passed as pointers to these locals; they must stay
// alive (and unmoved) until the launch has been synchronized below.
let input_ptr = input_buf.as_ptr();
let output_ptr = output_buf.as_ptr();
let mut args: Vec<*mut std::ffi::c_void> = vec![
std::ptr::addr_of!(input_ptr) as *mut _,
std::ptr::addr_of!(output_ptr) as *mut _,
std::ptr::addr_of!(total_rows) as *mut _,
std::ptr::addr_of!(row_size) as *mut _,
];
// SAFETY (assumed): `args` matches the kernel's parameter list
// (in ptr, out ptr, row count, row size) and every pointee outlives the
// synchronized launch — verify against the kernel signature.
unsafe {
stream
.launch_kernel(&mut module, kernel.name(), &config, &mut args)
.expect("Launch failed");
}
stream.synchronize().expect("Sync failed");
let mut output = vec![0.0f32; row_size as usize];
output_buf
.copy_to_host(&mut output)
.expect("Download failed");
// The GPU row must sum to 1 and agree with the CPU reference element-wise
// to within an absolute 1e-5.
let sum: f32 = output.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-5,
"Short row: softmax sum should be 1.0, got {} (delta={})",
sum,
(sum - 1.0).abs()
);
for (i, (&gpu, &cpu)) in output.iter().zip(expected.iter()).enumerate() {
let delta: f32 = (gpu - cpu).abs();
assert!(
delta < 1e-5,
"Short row [{}]: GPU={} vs CPU={}, delta={}",
i,
gpu,
cpu,
delta
);
}
eprintln!("✓ Short row (32 elements) softmax PASSED");
}
#[test]
#[cfg(feature = "cuda")]
// Regression test for a single 1500-element row: the row is much longer than
// the 32-thread block, so each thread must iterate over multiple elements in
// every pass. The extra eprintln diagnostics exist to debug that path.
fn test_batched_softmax_long_row_1500() {
// Skip silently (pass) when no CUDA device/driver is available, e.g. in CI.
let ctx = match CudaContext::new(0) {
Ok(ctx) => ctx,
Err(_) => return, };
let row_size = 1500u32;
let total_rows = 1u32;
// Linear ramp from -5.0 to +5.0 across the row (strictly increasing).
let input: Vec<f32> = (0..row_size)
.map(|i| -5.0 + 10.0 * (i as f32 / (row_size - 1) as f32))
.collect();
let expected = cpu_softmax(&input);
// Diagnostics: print row edges and the CPU reference so a failure log is
// self-contained.
eprintln!(
"Input: first 5 = {:?}, last 5 = {:?}",
&input[..5],
&input[row_size as usize - 5..]
);
eprintln!(
"CPU expected: first 5 = {:?}, last 5 = {:?}",
&expected[..5],
&expected[row_size as usize - 5..]
);
eprintln!("CPU expected sum = {}", expected.iter().sum::<f32>());
let input_buf = GpuBuffer::from_host(&ctx, &input).expect("Upload failed");
let output_buf: GpuBuffer<f32> = GpuBuffer::new(&ctx, row_size as usize).expect("Alloc failed");
let kernel = BatchedSoftmaxKernel::new(total_rows, row_size);
let ptx = kernel.emit_ptx();
// Diagnostic: check that the emitted PTX contains the expected loop labels
// (a missing label would indicate the long-row looping path was not
// generated at all).
eprintln!("PTX has {} lines", ptx.lines().count());
eprintln!("PTX contains max_loop: {}", ptx.contains("max_loop:"));
eprintln!("PTX contains sum_loop: {}", ptx.contains("sum_loop:"));
eprintln!("PTX contains write_loop: {}", ptx.contains("write_loop:"));
let mut module = CudaModule::from_ptx(&ctx, &ptx).expect("Compile failed");
let stream = CudaStream::new(&ctx).expect("Stream failed");
// One block per row, 32 threads (one warp) per block.
// NOTE(review): shared_mem = 72 bytes presumably matches the kernel's
// reduction scratch requirement — confirm against BatchedSoftmaxKernel.
let config = LaunchConfig {
grid: (total_rows, 1, 1),
block: (32, 1, 1),
shared_mem: 72,
};
// Kernel arguments are passed as pointers to these locals; they must stay
// alive (and unmoved) until the launch has been synchronized below.
let input_ptr = input_buf.as_ptr();
let output_ptr = output_buf.as_ptr();
let mut args: Vec<*mut std::ffi::c_void> = vec![
std::ptr::addr_of!(input_ptr) as *mut _,
std::ptr::addr_of!(output_ptr) as *mut _,
std::ptr::addr_of!(total_rows) as *mut _,
std::ptr::addr_of!(row_size) as *mut _,
];
// SAFETY (assumed): `args` matches the kernel's parameter list and every
// pointee outlives the synchronized launch — verify against the kernel.
unsafe {
stream
.launch_kernel(&mut module, kernel.name(), &config, &mut args)
.expect("Launch failed");
}
stream.synchronize().expect("Sync failed");
let mut output = vec![0.0f32; row_size as usize];
output_buf
.copy_to_host(&mut output)
.expect("Download failed");
eprintln!(
"GPU output: first 5 = {:?}, last 5 = {:?}",
&output[..5],
&output[row_size as usize - 5..]
);
let sum: f32 = output.iter().sum();
eprintln!("GPU sum = {}", sum);
// Primary invariant: the row still normalizes to 1 (looser 1e-4 tolerance
// than the short row, since 1500 f32 additions accumulate more error).
assert!(
(sum - 1.0).abs() < 1e-4,
"LONG ROW BUG: softmax sum should be 1.0, got {} (delta={})",
sum,
(sum - 1.0).abs()
);
// Diagnostic extrema (min over strictly-positive entries only, so exact
// zeros from underflow don't mask the smallest real probability).
let gpu_max = output.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let gpu_min = output
.iter()
.cloned()
.filter(|&x| x > 0.0)
.fold(f32::INFINITY, f32::min);
let cpu_max = expected.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let cpu_min = expected
.iter()
.cloned()
.filter(|&x| x > 0.0)
.fold(f32::INFINITY, f32::min);
eprintln!("GPU max={}, min={}", gpu_max, gpu_min);
eprintln!("CPU max={}, min={}", cpu_max, cpu_min);
// The most-probable element must be the same on both devices.
let gpu_argmax = output
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0;
let cpu_argmax = expected
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
.unwrap()
.0;
assert_eq!(
gpu_argmax, cpu_argmax,
"Argmax mismatch: GPU={} vs CPU={}",
gpu_argmax, cpu_argmax
);
// Spot-check a spread of indices (ends, a warp boundary, middle) rather
// than all 1500; accept 10% relative error OR 1e-6 absolute (the tail
// probabilities are tiny, where relative error is meaningless).
let test_indices = [0, 32, 100, 500, 1000, 1499];
for &i in &test_indices {
let delta = (output[i] - expected[i]).abs();
let rel_delta = if expected[i].abs() > 1e-10 {
delta / expected[i].abs()
} else {
delta
};
assert!(
rel_delta < 0.1 || delta < 1e-6,
"Long row [{}]: GPU={} vs CPU={}, delta={}, rel_delta={}",
i,
output[i],
expected[i],
delta,
rel_delta
);
}
eprintln!("✓ Long row (1500 elements) softmax PASSED");
}
#[test]
#[cfg(feature = "cuda")]
// Batched softmax: 6 independent rows of 1500 elements, launched as one
// block per row. Verifies every row normalizes to 1.0 AND matches the CPU
// reference element-wise. (Fix: the original computed `expected` but never
// compared against it — only the row sums were asserted.)
fn test_batched_softmax_6_rows_of_1500() {
    // Skip silently (pass) when no CUDA device/driver is available.
    let ctx = match CudaContext::new(0) {
        Ok(ctx) => ctx,
        Err(_) => return,
    };
    let row_size = 1500u32;
    let total_rows = 6u32;
    // Row r holds the -5..=5 linear ramp shifted up by 0.1 * r, so rows are
    // distinct (shift invariance means their softmax outputs are near-equal,
    // but any row/offset indexing bug in the kernel would still show).
    let mut input: Vec<f32> = Vec::with_capacity((total_rows * row_size) as usize);
    for row in 0..total_rows {
        for i in 0..row_size {
            let base = -5.0 + 10.0 * (i as f32 / (row_size - 1) as f32);
            input.push(base + 0.1 * row as f32);
        }
    }
    // CPU reference: softmax applied independently to each row.
    let mut expected: Vec<f32> = Vec::with_capacity((total_rows * row_size) as usize);
    for row in 0..total_rows {
        let start = (row * row_size) as usize;
        let end = start + row_size as usize;
        let row_softmax = cpu_softmax(&input[start..end]);
        expected.extend(row_softmax);
    }
    let input_buf = GpuBuffer::from_host(&ctx, &input).expect("Upload failed");
    let output_buf: GpuBuffer<f32> =
        GpuBuffer::new(&ctx, (total_rows * row_size) as usize).expect("Alloc failed");
    let kernel = BatchedSoftmaxKernel::new(total_rows, row_size);
    let ptx = kernel.emit_ptx();
    let mut module = CudaModule::from_ptx(&ctx, &ptx).expect("Compile failed");
    let stream = CudaStream::new(&ctx).expect("Stream failed");
    // One block per row, 32 threads (one warp) per block.
    // NOTE(review): shared_mem = 72 bytes matches the other softmax tests in
    // this file; presumably the kernel's reduction scratch size — confirm
    // against BatchedSoftmaxKernel.
    let config = LaunchConfig {
        grid: (total_rows, 1, 1),
        block: (32, 1, 1),
        shared_mem: 72,
    };
    // Kernel arguments are passed as pointers to these locals; they must
    // stay alive (and unmoved) until the launch is synchronized below.
    let input_ptr = input_buf.as_ptr();
    let output_ptr = output_buf.as_ptr();
    let mut args: Vec<*mut std::ffi::c_void> = vec![
        std::ptr::addr_of!(input_ptr) as *mut _,
        std::ptr::addr_of!(output_ptr) as *mut _,
        std::ptr::addr_of!(total_rows) as *mut _,
        std::ptr::addr_of!(row_size) as *mut _,
    ];
    // SAFETY (assumed): `args` matches the kernel's parameter list
    // (in ptr, out ptr, row count, row size) and every pointee outlives the
    // synchronized launch.
    unsafe {
        stream
            .launch_kernel(&mut module, kernel.name(), &config, &mut args)
            .expect("Launch failed");
    }
    stream.synchronize().expect("Sync failed");
    let mut output = vec![0.0f32; (total_rows * row_size) as usize];
    output_buf
        .copy_to_host(&mut output)
        .expect("Download failed");
    for row in 0..total_rows {
        let start = (row * row_size) as usize;
        let end = start + row_size as usize;
        // Invariant 1: each row normalizes to 1 (1e-4 tolerance, matching
        // the single long-row test: 1500 f32 additions accumulate error).
        let row_sum: f32 = output[start..end].iter().sum();
        assert!(
            (row_sum - 1.0).abs() < 1e-4,
            "Row {}: softmax sum should be 1.0, got {} (delta={})",
            row,
            row_sum,
            (row_sum - 1.0).abs()
        );
        // Invariant 2 (previously missing): element-wise agreement with the
        // CPU reference, using the same tolerance scheme as the long-row
        // test — 10% relative OR 1e-6 absolute (tails underflow toward 0).
        for i in start..end {
            let delta = (output[i] - expected[i]).abs();
            let rel_delta = if expected[i].abs() > 1e-10 {
                delta / expected[i].abs()
            } else {
                delta
            };
            assert!(
                rel_delta < 0.1 || delta < 1e-6,
                "Row {} [{}]: GPU={} vs CPU={}, delta={}, rel_delta={}",
                row,
                i - start,
                output[i],
                expected[i],
                delta,
                rel_delta
            );
        }
    }
    eprintln!("✓ Batched softmax (6 rows × 1500 elements) PASSED");
}
#[cfg(not(feature = "cuda"))]
#[test]
// Placeholder so the test name exists (and trivially passes) in builds
// without the `cuda` feature, keeping test listings stable across features.
fn test_batched_softmax_short_row() {
}
#[cfg(not(feature = "cuda"))]
#[test]
// Placeholder so the test name exists (and trivially passes) in builds
// without the `cuda` feature, keeping test listings stable across features.
fn test_batched_softmax_long_row_1500() {
}
#[cfg(not(feature = "cuda"))]
#[test]
// Placeholder so the test name exists (and trivially passes) in builds
// without the `cuda` feature, keeping test listings stable across features.
fn test_batched_softmax_6_rows_of_1500() {
}