Function reduce

pub fn reduce<R: Runtime, P: ReducePrecision, Out: Numeric, Inst: ReduceFamily>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<'_, R>,
    output: TensorHandleRef<'_, R>,
    axis: usize,
    strategy: Option<ReduceStrategy>,
    inst_config: Inst::Config,
) -> Result<(), ReduceError>

Reduce the given axis of the input tensor using the instruction Inst and write the result into output.

An optional ReduceStrategy can be provided to force the reduction to use a specific algorithm. If omitted, the function makes a best-effort choice of the strategy best suited to the provided client.

Returns an error if strategy is Some(strategy) and the specified strategy is not supported by the client. Also returns an error if the axis is larger than the input rank or if the shape of output is invalid. The shape of output must be the same as that of input, except with a value of 1 for the given axis.
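The shape rule above can be sketched in plain Rust. The helper `expected_output_shape` is hypothetical and not part of cubecl_reduce; it only illustrates the shape the caller is expected to give to output. As an assumption, it also rejects an axis equal to the rank, since such an index does not address any dimension.

```rust
/// Hypothetical helper (not part of cubecl_reduce): returns the shape that
/// `output` is expected to have when reducing `input_shape` along `axis`.
fn expected_output_shape(input_shape: &[usize], axis: usize) -> Option<Vec<usize>> {
    // Assumption: an axis is valid only if it indexes an existing dimension.
    if axis >= input_shape.len() {
        return None;
    }
    // Same shape as the input, except for a 1 along the reduced axis.
    let mut shape = input_shape.to_vec();
    shape[axis] = 1;
    Some(shape)
}
```

For a 2 x 2 input, `expected_output_shape(&[2, 2], 0)` gives `Some(vec![1, 2])` and `expected_output_shape(&[2, 2], 1)` gives `Some(vec![2, 1])`.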

Example

This example shows how to sum the rows of a small 2 x 2 matrix into a 1 x 2 vector. For more details, see the CubeCL documentation.

use cubecl_reduce::instructions::Sum;

let client = /* ... */;
let size_f32 = std::mem::size_of::<f32>();
let axis = 0; // 0 for rows, 1 for columns in the case of a matrix.

// Create input and output handles.
let input_handle = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0]));
let input = unsafe {
    TensorHandleRef::<R>::from_raw_parts(
        &input_handle,
        &[2, 1],
        &[2, 2],
        size_f32,
    )
};

let output_handle = client.empty(2 * size_f32);
let output = unsafe {
    TensorHandleRef::<R>::from_raw_parts(
        &output_handle,
        &[2, 1], // Strides of a contiguous 1 x 2 tensor.
        &[1, 2], // Shape: same as the input, with 1 along the reduced axis.
        size_f32,
    )
};

// Here `R` is a `cubecl::Runtime` and `inst_config` is the `Inst::Config` value for `Sum`.
let inst_config = /* ... */;
let result = reduce::<R, f32, f32, Sum>(&client, input, output, axis, None, inst_config);

if result.is_ok() {
    let binding = output_handle.binding();
    let bytes = client.read_one(binding);
    let output_values = f32::from_bytes(&bytes);
    // Summing along axis 0 adds the rows [0, 1] and [2, 3] element-wise.
    println!("Output = {:?}", output_values); // Should print [2.0, 4.0].
}
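To check the values expected from such a reduction without a GPU client, a CPU reference can help. The function below is a minimal sketch, not part of cubecl_reduce: `sum_axis_reference` is a hypothetical name, and it assumes strides are counted in elements rather than bytes.

```rust
/// CPU reference (hypothetical helper, not part of cubecl_reduce): sums `data`
/// along `axis`, addressing elements through `shape` and `strides`.
/// Assumption: strides are expressed in elements, not bytes.
/// The result is returned in contiguous row-major order.
fn sum_axis_reference(data: &[f32], shape: &[usize], strides: &[usize], axis: usize) -> Vec<f32> {
    assert_eq!(shape.len(), strides.len());
    assert!(axis < shape.len());

    // Output shape: same as the input, with 1 along the reduced axis.
    let mut out_shape = shape.to_vec();
    out_shape[axis] = 1;
    let out_len: usize = out_shape.iter().product();

    let mut out = vec![0.0f32; out_len];
    for (out_index, slot) in out.iter_mut().enumerate() {
        // Decode the row-major output index into coordinates and compute
        // the offset of the first input element to accumulate.
        let mut remainder = out_index;
        let mut base = 0;
        for dim in (0..shape.len()).rev() {
            let coord = remainder % out_shape[dim];
            remainder /= out_shape[dim];
            base += coord * strides[dim];
        }
        // Walk along the reduced axis and accumulate.
        for k in 0..shape[axis] {
            *slot += data[base + k * strides[axis]];
        }
    }
    out
}
```

With the data `[0.0, 1.0, 2.0, 3.0]`, shape `[2, 2]` and strides `[2, 1]`, reducing axis 0 adds the two rows element-wise and yields `[2.0, 4.0]`, while reducing axis 1 adds the entries within each row and yields `[1.0, 5.0]`.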