cubek_reduce/lib.rs
1//! This provides different implementations of the reduce algorithm which
2//! can run on multiple GPU backends using CubeCL.
3//!
//! A reduction is a tensor operation mapping a rank `R` tensor to a rank `R - 1` tensor
//! by aggregating all elements along a given axis with some binary operator.
6//! This is often also called folding.
7//!
//! This crate provides a main entrypoint, the [`reduce`] function, which automatically
//! performs a reduction for a given instruction implementing the [`ReduceInstruction`] trait and a given [`ReduceStrategy`].
//! It also provides implementations of the [`ReduceInstruction`] trait for common operations in the [`instructions`] module.
//! Finally, it provides many reusable primitives to perform different general reduction algorithms in the [`primitives`] module.
12
13pub mod components;
14pub mod launch;
15pub mod routines;
16
17mod error;
18
19pub use crate::launch::ReduceStrategy;
20use crate::{components::instructions::ReduceOperationConfig, launch::launch_reduce};
21pub use components::{
22 args::init_tensors,
23 config::*,
24 instructions::{ReduceFamily, ReduceInstruction},
25 precision::ReducePrecision,
26};
27use cubecl::prelude::*;
28pub use error::*;
29pub use launch::{ReduceDtypes, reduce_kernel};
30pub use routines::shared_sum::shared_sum;
31
32/// Reduce the given `axis` of the `input` tensor using the instruction `Inst` and write the result into `output`.
33///
34/// An optional [`ReduceStrategy`] can be provided to force the reduction to use a specific algorithm. If omitted, a best effort
35/// is done to try and pick the best strategy supported for the provided `client`.
36///
37/// Return an error if `strategy` is `Some(strategy)` and the specified strategy is not supported by the `client`.
38/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
39/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
40///
41///
42/// # Example
43///
44/// This examples show how to sum the rows of a small `2 x 2` matrix into a `1 x 2` vector.
45/// For more details, see the CubeCL documentation.
46///
47/// ```ignore
48/// use cubecl_reduce::instructions::Sum;
49///
50/// let client = /* ... */;
51/// let size_f32 = std::mem::size_of::<f32>();
52/// let axis = 0; // 0 for rows, 1 for columns in the case of a matrix.
53///
54/// // Create input and output handles.
55/// let input_handle = client.create(f32::as_bytes(&[0, 1, 2, 3]));
56/// let input = unsafe {
57/// TensorHandleRef::from_raw_parts(
58/// &input_handle,
59/// &[2, 1],
60/// &[2, 2],
61/// size_f32,
62/// )
63/// };
64///
65/// let output_handle = client.empty(2 * size_f32);
66/// let output = unsafe {
67/// TensorHandleRef::from_raw_parts(
68/// &output_handle,
69/// &output_stride,
70/// &output_shape,
71/// size_f32,
72/// )
73/// };
74///
75/// // Here `R` is a `cubecl::Runtime`.
76/// let result = reduce::<R, f32, f32, Sum>(&client, input, output, axis, None);
77///
78/// if result.is_ok() {
79/// let binding = output_handle.binding();
80/// let bytes = client.read_one(binding);
81/// let output_values = f32::from_bytes(&bytes);
82/// println!("Output = {:?}", output_values); // Should print [1, 5].
83/// }
84/// ```
85pub fn reduce<R: Runtime>(
86 client: &ComputeClient<R>,
87 input: TensorHandleRef<R>,
88 output: TensorHandleRef<R>,
89 axis: usize,
90 strategy: ReduceStrategy,
91 operation: ReduceOperationConfig,
92 dtypes: ReduceDtypes,
93) -> Result<(), ReduceError> {
94 validate_axis(input.shape.len(), axis)?;
95 valid_output_shape(input.shape, output.shape, axis)?;
96
97 launch_reduce::<R>(client, input, output, axis, strategy, dtypes, operation)
98}
99
100// Check that the given axis is less than the rank of the input.
101fn validate_axis(rank: usize, axis: usize) -> Result<(), ReduceError> {
102 if axis > rank {
103 return Err(ReduceError::InvalidAxis { axis, rank });
104 }
105 Ok(())
106}
107
108// Check that the output shape match the input shape with the given axis set to 1.
109fn valid_output_shape(
110 input_shape: &[usize],
111 output_shape: &[usize],
112 axis: usize,
113) -> Result<(), ReduceError> {
114 let mut expected_shape = input_shape.to_vec();
115 expected_shape[axis] = 1;
116 if output_shape != expected_shape {
117 return Err(ReduceError::MismatchShape {
118 expected_shape,
119 output_shape: output_shape.to_vec(),
120 });
121 }
122 Ok(())
123}