cubecl_reduce/lib.rs
//! This crate provides different implementations of the reduce algorithm that
//! can run on multiple GPU backends using CubeCL.
3//!
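//! A reduction is a tensor operation mapping a rank `R` tensor to a rank `R - 1` tensor
//! by agglomerating all elements along a given axis with some binary operator.
//! This is often also called folding. In this crate, the reduced axis is kept in the output
//! with a size of 1, so the output tensor has the same rank as the input tensor.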
7//!
//! This crate provides a main entrypoint, the [`reduce`] function, which automatically performs
//! a reduction for a given instruction family implementing the [`ReduceFamily`] trait and an optional [`ReduceStrategy`].
//! It also provides implementations of the [`ReduceInstruction`] trait for common operations in the [`instructions`] module.
//! Finally, it provides many reusable primitives to perform different general reduction algorithms in the [`primitives`] module.
12
13pub mod args;
14pub mod instructions;
15pub mod primitives;
16pub mod tune_key;
17
18mod config;
19mod error;
20mod launch;
21mod precision;
22mod shared_sum;
23mod strategy;
24
25pub use config::*;
26pub use error::*;
27pub use instructions::ReduceFamily;
28pub use instructions::ReduceInstruction;
29pub use precision::ReducePrecision;
30pub use shared_sum::*;
31pub use strategy::*;
32
33use launch::*;
34
35pub use args::init_tensors;
36pub use launch::{ReduceDtypes, ReduceParams, reduce_kernel, reduce_kernel_virtual};
37
38#[cfg(feature = "export_tests")]
39pub mod test;
40
41#[cfg(feature = "export_tests")]
42pub mod test_shuffle;
43
44use cubecl_core::prelude::*;
45
46/// Reduce the given `axis` of the `input` tensor using the instruction `Inst` and write the result into `output`.
47///
48/// An optional [`ReduceStrategy`] can be provided to force the reduction to use a specific algorithm. If omitted, a best effort
49/// is done to try and pick the best strategy supported for the provided `client`.
50///
51/// Return an error if `strategy` is `Some(strategy)` and the specified strategy is not supported by the `client`.
52/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
53/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
54///
55///
56/// # Example
57///
58/// This examples show how to sum the rows of a small `2 x 2` matrix into a `1 x 2` vector.
59/// For more details, see the CubeCL documentation.
60///
61/// ```ignore
62/// use cubecl_reduce::instructions::Sum;
63///
64/// let client = /* ... */;
65/// let size_f32 = std::mem::size_of::<f32>();
66/// let axis = 0; // 0 for rows, 1 for columns in the case of a matrix.
67///
68/// // Create input and output handles.
69/// let input_handle = client.create(f32::as_bytes(&[0, 1, 2, 3]));
70/// let input = unsafe {
71/// TensorHandleRef::from_raw_parts(
72/// &input_handle,
73/// &[2, 1],
74/// &[2, 2],
75/// size_f32,
76/// )
77/// };
78///
79/// let output_handle = client.empty(2 * size_f32);
80/// let output = unsafe {
81/// TensorHandleRef::from_raw_parts(
82/// &output_handle,
83/// &output_stride,
84/// &output_shape,
85/// size_f32,
86/// )
87/// };
88///
89/// // Here `R` is a `cubecl::Runtime`.
90/// let result = reduce::<R, f32, f32, Sum>(&client, input, output, axis, None);
91///
92/// if result.is_ok() {
93/// let binding = output_handle.binding();
94/// let bytes = client.read_one(binding);
95/// let output_values = f32::from_bytes(&bytes);
96/// println!("Output = {:?}", output_values); // Should print [1, 5].
97/// }
98/// ```
99pub fn reduce<R: Runtime, Inst: ReduceFamily>(
100 client: &ComputeClient<R>,
101 input: TensorHandleRef<R>,
102 output: TensorHandleRef<R>,
103 axis: usize,
104 strategy: Option<ReduceStrategy>,
105 inst_config: Inst::Config,
106 dtypes: ReduceDtypes,
107) -> Result<(), ReduceError> {
108 validate_axis(input.shape.len(), axis)?;
109 valid_output_shape(input.shape, output.shape, axis)?;
110 let strategy = strategy
111 .map(|s| s.validate(client))
112 .unwrap_or(Ok(ReduceStrategy::new(client, true)))?;
113 let config = ReduceConfig::generate(client, &input, &output, axis, &strategy, dtypes.input);
114
115 if let CubeCount::Static(x, y, z) = config.cube_count {
116 let (max_x, max_y, max_z) = R::max_cube_count();
117 if x > max_x || y > max_y || z > max_z {
118 return Err(ReduceError::CubeCountTooLarge);
119 }
120 }
121
122 let result = launch_reduce::<R, Inst>(
123 client,
124 input,
125 output,
126 axis as u32,
127 config,
128 strategy,
129 dtypes,
130 inst_config,
131 );
132
133 match result {
134 Ok(_) => Ok(()),
135 Err(err) => Err(ReduceError::Launch(err)),
136 }
137}
138
139// Check that the given axis is less than the rank of the input.
140fn validate_axis(rank: usize, axis: usize) -> Result<(), ReduceError> {
141 if axis > rank {
142 return Err(ReduceError::InvalidAxis { axis, rank });
143 }
144 Ok(())
145}
146
147// Check that the output shape match the input shape with the given axis set to 1.
148fn valid_output_shape(
149 input_shape: &[usize],
150 output_shape: &[usize],
151 axis: usize,
152) -> Result<(), ReduceError> {
153 let mut expected_shape = input_shape.to_vec();
154 expected_shape[axis] = 1;
155 if output_shape != expected_shape {
156 return Err(ReduceError::MismatchShape {
157 expected_shape,
158 output_shape: output_shape.to_vec(),
159 });
160 }
161 Ok(())
162}