Function shared_sum

pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<'_, R>,
    output: TensorHandleRef<'_, R>,
    cube_count: u32,
) -> Result<(), ReduceError>

Sum all the elements of the input tensor distributed over cube_count cubes.

This is an optimized version for summing large tensors using multiple cubes. For summing a single axis, the regular [reduce] entry point is preferred.

Returns an error if atomic addition is not supported for the type N.
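
For instance, the result can be matched to detect an unsupported runtime. This is a minimal sketch reusing the names from the example below; it assumes ReduceError implements Debug, which is standard for error enums:

match shared_sum::<R, f32>(&client, input, output, cube_count) {
    Ok(()) => { /* the sum has been accumulated into the output tensor */ }
    Err(error) => eprintln!("shared_sum is unavailable on this runtime: {error:?}"),
}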

§Important

This doesn’t set the value of output to 0 before computing the sums. It is the responsibility of the caller to ensure that output is set to the proper value. In other words, this kernel behaves like the AddAssign operator: it updates the output instead of overwriting it.
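
For instance, the output can be created already holding 0.0 so that the next call produces the plain sum. This is a minimal sketch reusing the names and helpers from the example below:

// Pre-fill the single-element output with 0.0 so the upcoming sum starts from zero.
let output_handle = client.create(f32::as_bytes(&[0.0]));
let output = unsafe {
    TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_f32)
};
// A second call with the same, non-reset output would add the new sum
// on top of the previous one.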

§Example

This example shows how to sum all the elements of a small 2 x 2 matrix. For more details, see the CubeCL documentation.

let client = /* ... */;
let size_f32 = std::mem::size_of::<f32>();

// Create input and output handles.
let input_handle = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0]));
// Create the output pre-filled with 0.0, since the kernel accumulates into it.
let output_handle = client.create(f32::as_bytes(&[0.0]));
let input = unsafe {
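    // After the handle come the strides ([2, 1]) and the shape ([2, 2]) of the
    // row-major 2 x 2 matrix, followed by the element size in bytes.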
    TensorHandleRef::<R>::from_raw_parts(
        &input_handle,
        &[2, 1],
        &[2, 2],
        size_f32,
    )
};
let output = unsafe {
    TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_f32)
};

// Here `R` is a `cubecl::Runtime`, and `cube_count` is the number of cubes used
// to split the work; any small value is fine for such a tiny tensor.
let cube_count = 2;
let result = shared_sum::<R, f32>(&client, input, output, cube_count);

if result.is_ok() {
    let binding = output_handle.binding();
    let bytes = client.read_one(binding);
    let output_values = f32::from_bytes(&bytes);
    println!("Output = {:?}", output_values); // Should print [6.0].
}