cubecl_reduce/
lib.rs

1//! This provides different implementations of the reduce algorithm which
2//! can run on multiple GPU backends using CubeCL.
3//!
//! A reduction is a tensor operation mapping a rank `R` tensor to a rank `R - 1` tensor
//! by agglomerating all elements along a given axis with some binary operator.
//! This is often also called folding. In this crate, the reduced axis is kept in the output with size 1.
7//!
//! This crate provides a main entrypoint, the [`reduce`] function, which automatically
//! performs a reduction for a given instruction family implementing the [`ReduceFamily`] trait and a given [`ReduceStrategy`].
//! It also provides implementations of the [`ReduceInstruction`] trait for common operations in the [`instructions`] module.
//! Finally, it provides many reusable primitives to perform different general reduction algorithms in the [`primitives`] module.
12
13pub mod args;
14pub mod instructions;
15pub mod primitives;
16pub mod tune_key;
17
18mod config;
19mod error;
20mod launch;
21mod precision;
22mod shared_sum;
23mod strategy;
24
25pub use config::*;
26pub use error::*;
27pub use instructions::ReduceFamily;
28pub use instructions::ReduceInstruction;
29pub use precision::ReducePrecision;
30pub use shared_sum::*;
31pub use strategy::*;
32
33use launch::*;
34
35pub use args::init_tensors;
36pub use launch::{ReduceParams, reduce_kernel, reduce_kernel_virtual};
37
38#[cfg(feature = "export_tests")]
39pub mod test;
40
41use cubecl_core::prelude::*;
42
43/// Reduce the given `axis` of the `input` tensor using the instruction `Inst` and write the result into `output`.
44///
45/// An optional [`ReduceStrategy`] can be provided to force the reduction to use a specific algorithm. If omitted, a best effort
46/// is done to try and pick the best strategy supported for the provided `client`.
47///
48/// Return an error if `strategy` is `Some(strategy)` and the specified strategy is not supported by the `client`.
49/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
50/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
51///
52///
53/// # Example
54///
55/// This examples show how to sum the rows of a small `2 x 2` matrix into a `1 x 2` vector.
56/// For more details, see the CubeCL documentation.
57///
58/// ```ignore
59/// use cubecl_reduce::instructions::Sum;
60///
61/// let client = /* ... */;
62/// let size_f32 = std::mem::size_of::<f32>();
63/// let axis = 0; // 0 for rows, 1 for columns in the case of a matrix.
64///
65/// // Create input and output handles.
66/// let input_handle = client.create(f32::as_bytes(&[0, 1, 2, 3]));
67/// let input = unsafe {
68///     TensorHandleRef::<R>::from_raw_parts(
69///         &input_handle,
70///         &[2, 1],
71///         &[2, 2],
72///         size_f32,
73///     )
74/// };
75///
76/// let output_handle = client.empty(2 * size_f32);
77/// let output = unsafe {
78///     TensorHandleRef::<R>::from_raw_parts(
79///         &output_handle,
80///         &output_stride,
81///         &output_shape,
82///         size_f32,
83///     )
84/// };
85///
86/// // Here `R` is a `cubecl::Runtime`.
87/// let result = reduce::<R, f32, f32, Sum>(&client, input, output, axis, None);
88///
89/// if result.is_ok() {
90///        let binding = output_handle.binding();
91///        let bytes = client.read_one(binding);
92///        let output_values = f32::from_bytes(&bytes);
93///        println!("Output = {:?}", output_values); // Should print [1, 5].
94/// }
95/// ```
96pub fn reduce<R: Runtime, P: ReducePrecision, Out: Numeric, Inst: ReduceFamily>(
97    client: &ComputeClient<R::Server, R::Channel>,
98    input: TensorHandleRef<R>,
99    output: TensorHandleRef<R>,
100    axis: usize,
101    strategy: Option<ReduceStrategy>,
102    inst_config: Inst::Config,
103) -> Result<(), ReduceError> {
104    validate_axis(input.shape.len(), axis)?;
105    valid_output_shape(input.shape, output.shape, axis)?;
106    let strategy = strategy
107        .map(|s| s.validate::<R>(client))
108        .unwrap_or(Ok(ReduceStrategy::new::<R>(client, true)))?;
109    let config = ReduceConfig::generate::<R, P::EI>(client, &input, &output, axis, &strategy);
110
111    if let CubeCount::Static(x, y, z) = config.cube_count {
112        let (max_x, max_y, max_z) = R::max_cube_count();
113        if x > max_x || y > max_y || z > max_z {
114            return Err(ReduceError::CubeCountTooLarge);
115        }
116    }
117
118    launch_reduce::<R, P, Out, Inst>(
119        client,
120        input,
121        output,
122        axis as u32,
123        config,
124        strategy,
125        inst_config,
126    );
127    Ok(())
128}
129
130// Check that the given axis is less than the rank of the input.
131fn validate_axis(rank: usize, axis: usize) -> Result<(), ReduceError> {
132    if axis > rank {
133        return Err(ReduceError::InvalidAxis { axis, rank });
134    }
135    Ok(())
136}
137
138// Check that the output shape match the input shape with the given axis set to 1.
139fn valid_output_shape(
140    input_shape: &[usize],
141    output_shape: &[usize],
142    axis: usize,
143) -> Result<(), ReduceError> {
144    let mut expected_shape = input_shape.to_vec();
145    expected_shape[axis] = 1;
146    if output_shape != expected_shape {
147        return Err(ReduceError::MismatchShape {
148            expected_shape,
149            output_shape: output_shape.to_vec(),
150        });
151    }
152    Ok(())
153}