cubecl_reduce/lib.rs
//! This crate provides different implementations of the reduce algorithm that
//! can run on multiple GPU backends using CubeCL.
//!
//! A reduction is a tensor operation mapping a rank `R` tensor to a rank `R - 1` tensor
//! by combining all elements along a given axis with some binary operator.
//! This is often also called folding.
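//!
//! For example, summing a `2 x 3` matrix along its second axis (`axis = 1`) collapses each
//! row to a single value:
//!
//! ```text
//! [[1, 2, 3],
//!  [4, 5, 6]]  --sum over axis 1-->  [6, 15]
//! ```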
//!
//! This crate provides a main entrypoint, the [`reduce`] function, which automatically
//! performs a reduction for a given instruction implementing the [`ReduceInstruction`] trait and a given [`ReduceStrategy`].
//! It also provides implementations of the [`ReduceInstruction`] trait for common operations in the [`instructions`] module.
//! Finally, it provides many reusable primitives to implement other reduction algorithms in the [`primitives`] module.

pub mod instructions;
pub mod primitives;

mod config;
mod error;
mod launch;
mod strategy;

pub use config::*;
pub use error::*;
pub use instructions::Reduce;
pub use instructions::ReduceInstruction;
pub use strategy::*;

use launch::*;

#[cfg(feature = "export_tests")]
pub mod test;

use cubecl_core::prelude::*;

/// Reduce the given `axis` of the `input` tensor using the instruction `Inst` and write the result into `output`.
///
/// An optional [`ReduceStrategy`] can be provided to force the reduction to use a specific algorithm.
/// If omitted, a best effort is made to pick the best strategy supported by the provided `client`.
///
/// Returns an error if `strategy` is `Some(strategy)` and the specified strategy is not supported by the `client`.
/// Also returns an error if `axis` is out of bounds for the rank of `input` or if the shape of `output` is invalid.
/// The shape of `output` must be the same as the shape of `input`, except with a value of 1 for the given `axis`.
///
/// # Example
///
/// This example shows how to sum a small `2 x 2` matrix over its first axis into a `1 x 2` vector.
/// For more details, see the CubeCL documentation.
///
/// ```ignore
/// use cubecl_reduce::instructions::Sum;
///
/// let client = /* ... */;
/// let size_f32 = std::mem::size_of::<f32>();
/// let axis = 0; // 0 to reduce over the rows, 1 to reduce over the columns of a matrix.
///
/// // Create input and output handles.
/// let input_handle = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0]));
/// let input = unsafe {
///     TensorHandleRef::<R>::from_raw_parts(
///         &input_handle,
///         &[2, 1], // strides
///         &[2, 2], // shape
///         size_f32,
///     )
/// };
///
/// let output_handle = client.empty(2 * size_f32);
/// let output = unsafe {
///     TensorHandleRef::<R>::from_raw_parts(
///         &output_handle,
///         &[2, 1], // strides
///         &[1, 2], // shape: same as the input, with the reduced axis set to 1.
///         size_f32,
///     )
/// };
///
/// // Here `R` is a `cubecl::Runtime`.
/// let result = reduce::<R, f32, f32, Sum>(&client, input, output, axis, None);
///
/// if result.is_ok() {
///     let binding = output_handle.binding();
///     let bytes = client.read_one(binding);
///     let output_values = f32::from_bytes(&bytes);
///     println!("Output = {:?}", output_values); // Should print [2.0, 4.0].
/// }
/// ```
pub fn reduce<R: Runtime, In: Numeric, Out: Numeric, Inst: Reduce>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
    axis: usize,
    strategy: Option<ReduceStrategy>,
) -> Result<(), ReduceError> {
    validate_axis(input.shape.len(), axis)?;
    validate_output_shape(input.shape, output.shape, axis)?;
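    // Validate the caller-provided strategy against the client, or pick a supported fallback when none is given.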
    let strategy = strategy
        .map(|s| s.validate::<R>(client))
        .unwrap_or(Ok(ReduceStrategy::fallback_strategy::<R>(client)))?;
    let config = ReduceConfig::generate::<R, In>(client, &input, &output, axis, &strategy);

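    // Abort before launching if the requested cube count exceeds the runtime's limits.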
    if let CubeCount::Static(x, y, z) = config.cube_count {
        let (max_x, max_y, max_z) = R::max_cube_count();
        if x > max_x || y > max_y || z > max_z {
            return Err(ReduceError::CubeCountTooLarge);
        }
    }

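    // All checks passed; dispatch the reduction kernel.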
    launch_reduce::<R, In, Out, Inst>(client, input, output, axis as u32, config, strategy);
    Ok(())
}

// Check that the given axis is less than the rank of the input.
fn validate_axis(rank: usize, axis: usize) -> Result<(), ReduceError> {
    if axis >= rank {
        return Err(ReduceError::InvalidAxis { axis, rank });
    }
    Ok(())
}

// Check that the output shape matches the input shape with the given axis set to 1.
fn validate_output_shape(
    input_shape: &[usize],
    output_shape: &[usize],
    axis: usize,
) -> Result<(), ReduceError> {
    let mut expected_shape = input_shape.to_vec();
    expected_shape[axis] = 1;
    if output_shape != expected_shape {
        return Err(ReduceError::MismatchShape {
            expected_shape,
            output_shape: output_shape.to_vec(),
        });
    }
    Ok(())
}