use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl};
use cubecl_runtime::TypeUsage;
use crate::ReduceError;
/// Launches [`shared_sum_kernel`] to sum every element of `input` into `output[0]`.
///
/// Each cube reduces its share of the input and then performs a single atomic
/// addition on the first element of `output`. Nothing in this function resets
/// `output` beforehand, so the result is accumulated on top of whatever value
/// `output[0]` already holds.
///
/// `cube_count` is the number of cubes launched along the first axis; every cube
/// is launched with 32 x 8 units.
///
/// # Errors
///
/// Returns [`ReduceError::MissingAtomicAdd`] when the client does not report
/// [`TypeUsage::AtomicAdd`] for `Atomic<N>`.
///
/// # Panics
///
/// Panics if `cube_count` is `0`, because the per-unit workload is obtained by
/// dividing by the total number of launched units.
pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
    client: &ComputeClient<R::Server>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
    cube_count: u32,
) -> Result<(), ReduceError> {
    // The kernel finishes with an atomic add, so bail out early when the
    // client cannot perform one on `Atomic<N>`.
    let supports_atomic_add = client
        .properties()
        .type_usage(Atomic::<N>::as_type_native_unchecked())
        .contains(TypeUsage::AtomicAdd);
    if !supports_atomic_add {
        return Err(ReduceError::MissingAtomicAdd(N::as_type_native_unchecked()));
    }

    // Total number of scalar elements in the input tensor.
    let total_elems: u32 = input.shape.iter().map(|&dim| dim as u32).product();

    // Widest line size offered by the runtime that divides the input evenly;
    // falls back to scalar access (1) when none does.
    let line_size = R::io_optimized_line_sizes_unchecked(size_of::<N>())
        .filter(|&candidate| total_elems % candidate as u32 == 0)
        .max()
        .unwrap_or(1) as u32;

    // NOTE: the in-cube reduction halves the number of active units each round,
    // so the number of units per cube must stay a power of two.
    let cube_dim = CubeDim::new_2d(32, 8);
    let units_per_cube = cube_dim.num_elems();
    let total_units = cube_count * units_per_cube;

    // Number of lines each unit accumulates, rounded up so the whole input is covered.
    let lines_per_unit = total_elems.div_ceil(total_units * line_size);

    // SAFETY: the kernel clamps its read range to `input.len()`, indexes shared
    // memory with `UNIT_POS` only (the buffer holds one line per unit of the cube)
    // and writes nothing but `output[0]`. This assumes `output` holds at least
    // one element, which the caller must guarantee.
    unsafe {
        shared_sum_kernel::launch_unchecked::<N, R>(
            client,
            CubeCount::new_1d(cube_count),
            cube_dim,
            input.as_tensor_arg(line_size as u8),
            output.as_tensor_arg(1),
            units_per_cube,
            line_size,
            lines_per_unit,
        );
    }

    Ok(())
}
/// Sums the lines of `input` and atomically adds the total to `output[0]`.
///
/// Every unit accumulates the `num_lines_per_unit` consecutive lines starting at
/// `ABSOLUTE_POS * num_lines_per_unit` (clamped to `input.len()`) into its own
/// shared-memory slot. The slots of a cube are then combined with
/// [`sum_shared_memory`], the lanes of the resulting line are added together,
/// and the unit with `UNIT_POS == 0` adds that scalar to `output[0]`.
///
/// * `shared_memory_size` - number of lines in the shared buffer; it is indexed
///   by `UNIT_POS`, so it has to cover every unit of the cube.
/// * `line_size` - number of elements per line of `input`.
/// * `num_lines_per_unit` - how many lines a single unit accumulates.
#[cube(launch_unchecked)]
fn shared_sum_kernel<N: Numeric>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Atomic<N>>,
    #[comptime] shared_memory_size: u32,
    #[comptime] line_size: u32,
    #[comptime] num_lines_per_unit: u32,
) {
    // One accumulator line per unit, starting from zero.
    let mut partials = SharedMemory::new_lined(shared_memory_size, line_size);
    partials[UNIT_POS] = Line::empty(line_size).fill(N::from_int(0));

    // Range of lines owned by this unit, clamped so that units past the end of
    // the input iterate over an empty range.
    let raw_first = ABSOLUTE_POS * num_lines_per_unit;
    let raw_last = raw_first + num_lines_per_unit;
    let first = select(raw_first < input.len(), raw_first, input.len());
    let last = select(raw_last < input.len(), raw_last, input.len());

    for line_idx in first..last {
        partials[UNIT_POS] += input[line_idx];
    }

    // Combine the per-unit accumulators of the cube into a single line.
    let cube_line = sum_shared_memory(&mut partials);

    // Fold the lanes of that line into one scalar.
    let total = RuntimeCell::<N>::new(N::from_int(0));
    #[unroll]
    for lane in 0..line_size {
        let partial = cube_line[lane] + total.read();
        total.store(partial);
    }

    // A single unit per cube publishes the cube's contribution.
    if UNIT_POS == 0 {
        Atomic::add(&output[0], total.consume());
    }
}
/// Adds up the first `CUBE_DIM` lines of `accumulator` and returns the total.
///
/// The sum is built in place by pairwise rounds: in each round the number of
/// participating units is halved and the distance between the two summed slots
/// is doubled, so the total ends up in `accumulator[0]`.
///
/// Because the participant count is halved with integer division, `CUBE_DIM`
/// has to be a power of two; otherwise some slots would never be added.
/// The function calls `sync_cube()` before the first round and after every
/// round, so all units of the cube must call it together.
#[cube]
fn sum_shared_memory<N: Numeric>(accumulator: &mut SharedMemory<Line<N>>) -> Line<N> {
    // Wait until every unit has written its slot before reading any of them.
    sync_cube();

    let mut active_units = CUBE_DIM;
    let mut stride = 1;

    while active_units > 1 {
        active_units /= 2;

        // Slot that receives the sum and the neighbouring slot folded into it.
        let dst = stride * 2 * UNIT_POS;
        let src = stride * (2 * UNIT_POS + 1);

        if UNIT_POS < active_units {
            let incoming = accumulator[src];
            accumulator[dst] += incoming;
        }

        stride *= 2;
        // The next round reads the slots written in this one.
        sync_cube();
    }

    accumulator[0]
}