use cubecl::{
features::TypeUsage,
std::tensor::layout::{
linear::{LinearLayout, LinearLayoutArgs, LinearView, LinearViewLaunch},
plain::PlainLayoutLaunch,
},
};
use cubecl::{ir::ElemType, std::tensor::layout::linear::linear_view};
use cubecl::{prelude::*, tensor_line_size_parallel};
use crate::ReduceError;
/// Sums every element of `input` into `output[0]`, finishing with one
/// atomic add per cube.
///
/// The tensor is read through a vectorized linear view; each unit reduces a
/// contiguous chunk of lines into shared memory before a cube-wide tree
/// reduction and a single `fetch_add` on the output.
///
/// # Errors
/// - [`ReduceError::MissingAtomicAdd`] if the runtime does not support
///   atomic addition for `input_elem`.
/// - [`ReduceError::Launch`] if the kernel launch fails.
pub fn shared_sum<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
    cube_count: u32,
    input_elem: ElemType,
) -> Result<(), ReduceError> {
    // Bail out early when the runtime cannot atomically add this element type.
    if !client
        .properties()
        .type_usage(StorageType::Atomic(input_elem))
        .contains(TypeUsage::AtomicAdd)
    {
        return Err(ReduceError::MissingAtomicAdd(input_elem.into()));
    }

    let input_len = input.shape.iter().product::<usize>();
    // When the handle's byte size exactly matches shape-product * elem_size,
    // the buffer has no padding and can be read as one flat array.
    let contiguous_buffer = input_len * input.elem_size == input.handle.size() as usize;

    // Choose the widest line (vector) size the memory layout allows.
    let line_size = if contiguous_buffer {
        client
            .io_optimized_line_sizes(input.elem_size)
            .filter(|line_size| input_len.is_multiple_of(*line_size))
            .max()
            .unwrap_or(1)
    } else {
        tensor_line_size_parallel(
            client.io_optimized_line_sizes(input.elem_size),
            input.shape,
            input.strides,
            input.shape.len() - 1,
        )
    };

    // A contiguous buffer skips stride bookkeeping with a plain linear layout;
    // otherwise fall back to the generic strided linear view.
    let input_view = if contiguous_buffer {
        let layout = LinearLayoutArgs::Plain(PlainLayoutLaunch::new(ScalarArg::new(
            input_len / line_size,
        )));
        // SAFETY: the `contiguous_buffer` check above guarantees the handle
        // holds exactly `input_len` elements of `elem_size` bytes.
        let buffer = unsafe {
            ArrayArg::from_raw_parts_and_size(input.handle, input_len, line_size, input.elem_size)
        };
        LinearViewLaunch::new::<LinearLayout>(buffer, layout)
    } else {
        linear_view(client, &input, line_size)
    };

    let cube_dim = CubeDim::new_2d(32, 8);
    let num_units = cube_count * cube_dim.num_elems();
    // Round up so every line of the input is assigned to some unit.
    let num_lines_per_unit = input_len.div_ceil(num_units as usize * line_size);
    let cube_count = CubeCount::new_1d(cube_count);
    let address_type = input
        .required_address_type()
        .max(output.required_address_type());

    // SAFETY: bounds validation is skipped; the kernel clamps each unit's
    // [start, end) range to the view length before reading.
    let result = unsafe {
        shared_sum_kernel::launch_unchecked(
            client,
            cube_count,
            cube_dim,
            address_type,
            input_view,
            output.as_tensor_arg(1),
            cube_dim.num_elems() as usize,
            line_size,
            num_lines_per_unit,
            input_elem,
        )
    };
    // Propagate launch failures without re-wrapping via an explicit match.
    result.map_err(ReduceError::Launch)
}
/// Cube kernel: each unit accumulates its assigned slice of `input` lines into
/// shared memory, the cube tree-reduces that shared memory, and unit 0 adds the
/// cube's total to `output[0]` with a single atomic fetch-add.
#[cube(launch_unchecked, address_type = "dynamic")]
fn shared_sum_kernel<N: Numeric>(
    input: &LinearView<Line<N>>,
    output: &mut Tensor<Atomic<N>>,
    #[comptime] shared_memory_size: usize,
    #[comptime] line_size: LineSize,
    #[comptime] num_lines_per_unit: usize,
    #[define(N)] _dtype: ElemType,
) {
    // One line-wide accumulator slot per unit, zero-initialized.
    let mut shared_memory = SharedMemory::new_lined(shared_memory_size, line_size);
    shared_memory[UNIT_POS as usize] = Line::empty(line_size).fill(N::from_int(0));
    // Each unit owns the contiguous range [start, end) of lines.
    let start = ABSOLUTE_POS * num_lines_per_unit;
    let end = start + num_lines_per_unit;
    // Clamp both bounds to the view length; out-of-range units get an empty range.
    let start = select(start < input.shape(), start, input.shape());
    let end = select(end < input.shape(), end, input.shape());
    for k in start..end {
        shared_memory[UNIT_POS as usize] += input[k];
    }
    // Tree-reduce all per-unit accumulators down to shared_memory[0].
    let line = sum_shared_memory(&mut shared_memory);
    // Collapse the reduced line to a scalar by summing its lanes.
    let sum = RuntimeCell::<N>::new(N::from_int(0));
    #[unroll]
    for k in 0..line_size {
        let update = line[k] + sum.read();
        sum.store(update);
    }
    // Exactly one unit per cube commits the cube's partial sum atomically.
    if UNIT_POS == 0 {
        output[0].fetch_add(sum.consume());
    }
}
/// Tree-reduces all `CUBE_DIM` line accumulators in shared memory into slot 0
/// and returns that line. Halves the active units each round, pairing slot
/// `jump*2*i` with slot `jump*(2*i+1)`; assumes CUBE_DIM is a power of two
/// (true for the 32x8 dim used by the launcher) — TODO confirm for other callers.
#[cube]
fn sum_shared_memory<N: Numeric>(accumulator: &mut SharedMemory<Line<N>>) -> Line<N> {
    // Ensure every unit's initial accumulation is visible before reducing.
    sync_cube();
    let mut num_active_units = CUBE_DIM;
    let mut jump = 1;
    while num_active_units > 1 {
        num_active_units /= 2;
        // Pairwise combine: slot `origin` is folded into slot `destination`.
        let destination = jump * 2 * UNIT_POS;
        let origin = jump * (2 * UNIT_POS + 1);
        if UNIT_POS < num_active_units {
            let element = accumulator[origin as usize];
            accumulator[destination as usize] += element;
        }
        jump *= 2;
        // All partial sums of this round must land before the next round reads them.
        sync_cube();
    }
    accumulator[0]
}