use cuda_rust_wasm::runtime::{Grid, Block, thread, block, grid};
use cuda_rust_wasm::memory::{DeviceBuffer, SharedMemory};
use cuda_rust_wasm::kernel::launch_kernel;
/// Sums `input[..n]` into one partial sum per thread block.
///
/// Every thread first accumulates a private sum over the elements
/// `gid, gid + G, gid + 2G, ...` below `n`, where `gid` is the thread's global
/// index and `G = grid::dim().x * block::dim().x`. The per-thread sums of a
/// block are then combined pairwise through the `#[shared]` scratch array, and
/// thread 0 stores the block total in `output[block::index().x]`. The caller
/// is expected to add up the entries of `output` to obtain the final result.
///
/// # Parameters
/// * `input`  - values to add; indices `0..n` are read.
/// * `output` - receives one partial sum per block, indexed by block index.
/// * `n`      - number of leading elements of `input` to include.
///
/// # Limits
/// The scratch array has 256 slots and is indexed by the thread index, so the
/// kernel must be launched with at most 256 threads per block. The block size
/// does not have to be a power of two.
#[kernel]
pub fn reduction(
    input: &[f32],
    output: &mut [f32],
    n: u32,
) {
    #[shared]
    static mut PARTIAL_SUMS: [f32; 256] = [0.0; 256];

    let tid = thread::index().x;
    let block_size = block::dim().x;
    let gid = block::index().x * block_size + tid;
    // Distance between two consecutive elements handled by the same thread;
    // it is loop-invariant, so compute it once.
    let grid_stride = grid::dim().x * block_size;

    // Phase 1: private accumulation over this thread's share of the input.
    let mut sum = 0.0f32;
    let mut i = gid;
    while i < n {
        sum += input[i as usize];
        i += grid_stride;
    }

    // SAFETY: each thread writes only the slot matching its own `tid`.
    // NOTE(review): assumes `sync_threads` is a block-wide barrier (like
    // CUDA's `__syncthreads`) that orders these writes before the reads
    // below - confirm against the runtime.
    unsafe {
        PARTIAL_SUMS[tid as usize] = sum;
    }
    cuda_rust_wasm::runtime::sync_threads();

    // Phase 2: pairwise tree reduction. Starting from the next power of two
    // (instead of `block_size / 2`) together with the `tid + stride` bound
    // check keeps every slot reachable when the block size is not a power of
    // two; for power-of-two sizes the schedule is the same as plain halving.
    let mut stride = block_size.next_power_of_two() / 2;
    while stride > 0 {
        if tid < stride && tid + stride < block_size {
            // SAFETY: in this round a thread writes slot `tid` (< stride) and
            // reads slot `tid + stride` (>= stride), so no slot is both read
            // and written within the same round.
            unsafe {
                PARTIAL_SUMS[tid as usize] += PARTIAL_SUMS[(tid + stride) as usize];
            }
        }
        // Reached by every thread of the block, including those that skipped
        // the addition above.
        cuda_rust_wasm::runtime::sync_threads();
        stride /= 2;
    }

    // Slot 0 now holds the sum of all per-thread values of this block.
    if tid == 0 {
        // SAFETY: the final barrier above has completed before this read.
        output[block::index().x as usize] = unsafe { PARTIAL_SUMS[0] };
    }
}
/// Sums `input[..n]` into one partial sum per thread block using warp shuffles.
///
/// Every thread first accumulates a private sum over the elements
/// `gid, gid + G, gid + 2G, ...` below `n`, where `gid` is the thread's global
/// index and `G = grid::dim().x * block::dim().x`. The values are then
/// combined in two stages:
///
/// 1. within each group of 32 consecutive threads via `warp_shuffle_down`
///    with offsets 16, 8, 4, 2, 1; lane 0 of each group stores the result in
///    the `#[shared]` array `WARP_SUMS`;
/// 2. the first group reads those per-group results and combines them the
///    same way; thread 0 writes the block total to `output[block::index().x]`.
///
/// The caller is expected to add up the entries of `output`.
///
/// # Parameters
/// * `input`  - values to add; indices `0..n` are read.
/// * `output` - receives one partial sum per block, indexed by block index.
/// * `n`      - number of leading elements of `input` to include.
///
/// # Limits
/// `WARP_SUMS` has 32 slots, so at most 32 * 32 = 1024 threads per block are
/// supported.
///
/// NOTE(review): the number of warps is rounded up so a trailing partial warp
/// is no longer ignored in stage 2. Whether stage 1 is correct for a partial
/// warp depends on what `warp_shuffle_down` returns for a source lane that
/// does not exist - confirm against the runtime; block sizes that are a
/// multiple of 32 are unaffected.
///
/// The camel-case name is kept because callers launch the kernel by name.
#[kernel]
pub fn reductionWarp(
    input: &[f32],
    output: &mut [f32],
    n: u32,
) {
    #[shared]
    static mut WARP_SUMS: [f32; 32] = [0.0; 32];

    let tid = thread::index().x;
    let lane_id = tid & 31; // position inside the 32-thread group
    let warp_id = tid >> 5; // index of the 32-thread group inside the block
    let block_size = block::dim().x;
    let gid = block::index().x * block_size + tid;
    // Distance between two consecutive elements handled by the same thread;
    // it is loop-invariant, so compute it once.
    let grid_stride = grid::dim().x * block_size;

    // Private accumulation over this thread's share of the input.
    let mut sum = 0.0f32;
    let mut i = gid;
    while i < n {
        sum += input[i as usize];
        i += grid_stride;
    }

    // Stage 1: combine within the warp, halving the shuffle distance each step
    // (16, 8, 4, 2, 1).
    let mut offset = 16;
    while offset > 0 {
        sum += cuda_rust_wasm::runtime::warp_shuffle_down(sum, offset);
        offset /= 2;
    }

    if lane_id == 0 {
        // SAFETY: only lane 0 of each warp writes, and each warp writes the
        // slot matching its own `warp_id`.
        unsafe {
            WARP_SUMS[warp_id as usize] = sum;
        }
    }
    // NOTE(review): assumes `sync_threads` is a block-wide barrier that makes
    // the writes above visible to the reads below - confirm against the runtime.
    cuda_rust_wasm::runtime::sync_threads();

    // Stage 2: the first warp combines the per-warp results.
    if warp_id == 0 {
        // Round up: a plain `block_size >> 5` would drop the result of a
        // trailing partial warp and yield 0 warps for blocks under 32 threads.
        let warp_count = (block_size + 31) >> 5;

        // Inside warp 0, `lane_id` equals `tid`.
        sum = if lane_id < warp_count {
            // SAFETY: read-only access after the barrier; `lane_id` is below 32.
            unsafe { WARP_SUMS[lane_id as usize] }
        } else {
            0.0
        };

        let mut offset = 16;
        while offset > 0 {
            sum += cuda_rust_wasm::runtime::warp_shuffle_down(sum, offset);
            offset /= 2;
        }

        if tid == 0 {
            output[block::index().x as usize] = sum;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cuda_rust_wasm::runtime::CudaRuntime;

    /// Launches `reduction` with one thread per element over one million ones
    /// and checks that the per-block partial sums add up to the element count.
    #[test]
    fn test_reduction() {
        // Never read afterwards, but bound to a name so it stays alive until
        // the end of the test.
        // NOTE(review): presumably the runtime must exist while buffers and
        // kernels are in use - confirm.
        let _runtime = CudaRuntime::new().unwrap();

        let n = 1_000_000;
        let data: Vec<f32> = (0..n).map(|_| 1.0).collect();
        let d_input = DeviceBuffer::from_slice(&data).unwrap();

        let block_size = 256;
        // Ceiling division: enough blocks to give every element its own thread.
        let grid_size = (n + block_size - 1) / block_size;
        let mut d_output = DeviceBuffer::new(grid_size).unwrap();

        launch_kernel!(
            reduction<<<grid_size, block_size>>>(
                d_input.as_slice(),
                d_output.as_mut_slice(),
                n as u32
            )
        );

        // The kernel yields one partial sum per block; finish the sum on the host.
        let mut partial_sums = vec![0.0f32; grid_size];
        d_output.copy_to_host(&mut partial_sums).unwrap();
        let total: f32 = partial_sums.iter().sum();

        assert!((total - n as f32).abs() < 1e-3,
            "Expected {}, got {}", n as f32, total);
    }

    /// Launches `reductionWarp` over the repeating pattern 0..=9 and compares
    /// the result with a sum computed on the host.
    #[test]
    fn test_reduction_warp() {
        // See `test_reduction` for why this binding is kept.
        let _runtime = CudaRuntime::new().unwrap();

        let n = 100_000;
        let data: Vec<f32> = (0..n).map(|i| (i % 10) as f32).collect();
        let expected_sum: f32 = data.iter().sum();
        let d_input = DeviceBuffer::from_slice(&data).unwrap();

        let block_size = 256;
        // A quarter of the blocks needed for one thread per element, so the
        // kernel's strided accumulation loop runs several times per thread.
        let grid_size = (n + block_size * 4 - 1) / (block_size * 4);
        let mut d_output = DeviceBuffer::new(grid_size).unwrap();

        launch_kernel!(
            reductionWarp<<<grid_size, block_size>>>(
                d_input.as_slice(),
                d_output.as_mut_slice(),
                n as u32
            )
        );

        // The kernel yields one partial sum per block; finish the sum on the host.
        let mut partial_sums = vec![0.0f32; grid_size];
        d_output.copy_to_host(&mut partial_sums).unwrap();
        let total: f32 = partial_sums.iter().sum();

        assert!((total - expected_sum).abs() < 1e-3,
            "Expected {}, got {}", expected_sum, total);
    }
}