pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
client: &ComputeClient<R::Server, R::Channel>,
input: TensorHandleRef<'_, R>,
output: TensorHandleRef<'_, R>,
cube_count: u32,
) -> Result<(), ReduceError>
Expand description
Sum all the elements of the input tensor distributed over cube_count
cubes.
This is an optimized version for summing large tensors using multiple cubes. For summing a single axis, the regular [reduce] entry point is preferred.
Returns an error if atomic addition is not supported for the type `N`.
§Important
This doesn’t set the value of output to 0 before computing the sums. It is the responsibility of the caller to ensure that output is set to the proper value. Basically, the behavior of this kernel is akin to the AddAssign operator, as it updates the output instead of overwriting it.
§Example
This example shows how to sum all the elements of a small 2 x 2
matrix.
For more details, see the CubeCL documentation.
ⓘ
let client = /* ... */;
let size_f32 = std::mem::size_of::<f32>();
// Create input and output handles.
let input_handle = client.create(f32::as_bytes(&[0, 1, 2, 3]));
let output_handle = client.empty(size_of::<F>());
let input = unsafe {
TensorHandleRef::<R>::from_raw_parts(
&input_handle,
&[2, 1],
&[2, 2],
size_f32,
)
};
let output = unsafe {
TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_of::<F>())
};
// Here `R` is a `cubecl::Runtime`.
let result = shared_sum::<R, f32>(&client, input, output, cube_count);
if result.is_ok() {
let binding = output_handle.binding();
let bytes = client.read_one(binding);
let output_values = f32::from_bytes(&bytes);
println!("Output = {:?}", output_values); // Should print [6].
}