burn_cubecl/kernel/reduce/
base.rs

#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
    CubeRuntime,
    ops::numeric::{empty_device_dtype, zeros_client},
    tensor::CubeTensor,
};
use burn_backend::{DType, Shape};
use cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType};
use cubek::reduce::{
    ReduceDtypes, ReduceError, ReduceStrategy,
    components::instructions::ReduceOperationConfig,
    launch::{LineSizeStrategy, RoutineStrategy},
    routines::{BlueprintStrategy, unit::UnitStrategy},
    shared_sum,
};
use serde::{Deserialize, Serialize};

#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of sum versions
pub struct SumAutotuneKey {
    /// The type of the tensor
    pub dtype: burn_backend::DType,
    /// The anchored length of the tensor
    #[autotune(anchor)]
    pub length: usize,
}

/// Check if the client supports atomic add for the given element type.
fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
    client
        .properties()
        .type_usage(StorageType::Atomic(dtype.into()))
        .contains(TypeUsage::AtomicAdd)
}

/// [Sum](sum) with a fallback to a chained reduction when the tensor's client doesn't support
/// atomic add for its dtype.
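///
/// A minimal usage sketch (the `WgpuRuntime` runtime and the way `tensor` is built are
/// assumptions, not part of this module):
///
/// ```ignore
/// // Falls back to a chained reduction if the device has no atomic add for the dtype.
/// let total = sum_fallback::<WgpuRuntime>(tensor, SumStrategy::OneShot(4))?;
/// ```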
pub fn sum_fallback<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    mut strategy: SumStrategy,
) -> Result<CubeTensor<R>, ReduceError> {
    // Check support up front and fall back before any output buffer is created.
    if matches!(strategy, SumStrategy::OneShot(_))
        && !supports_atomic_add(&tensor.client, tensor.dtype)
    {
        strategy = SumStrategy::Chained(Default::default());
    }
    sum(tensor, strategy)
}

/// Specialized reduce function to compute the sum of all elements of the `tensor` and return
/// the value in a single-element tensor of shape `[1]`.
///
/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` operation.
///
/// Return an error if the strategy is [`OneShot`](SumStrategy::OneShot) and the tensor's client
/// doesn't support atomic add for its dtype.
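///
/// A hedged usage sketch (runtime type and tensor construction are assumed and not shown):
///
/// ```ignore
/// // Sum every element of `tensor` into a one-element tensor.
/// let total = sum::<WgpuRuntime>(tensor, SumStrategy::default())?;
/// ```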
pub fn sum<Run: CubeRuntime>(
    tensor: CubeTensor<Run>,
    strategy: SumStrategy,
) -> Result<CubeTensor<Run>, ReduceError> {
    let client = tensor.client.clone();
    let device = tensor.device.clone();

    match strategy {
        SumStrategy::OneShot(cube_count) => {
            let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);
            shared_sum::<Run>(
                &client,
                tensor.as_handle_ref(),
                output.as_handle_ref(),
                cube_count,
                tensor.dtype.into(),
            )?;

            Ok(output)
        }
        SumStrategy::Chained(strategy) => {
            reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
        }
        #[cfg(feature = "autotune")]
        SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
    }
}

/// Select a strategy to perform a sum.
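///
/// A selection sketch (which variant is fastest depends on the hardware and input size):
///
/// ```ignore
/// // Sum roughly four elements per unit in a single kernel relying on atomic add.
/// let one_shot = SumStrategy::OneShot(4);
/// // Or chain per-axis reduction kernels, which doesn't need atomic add.
/// let chained = SumStrategy::Chained(Default::default());
/// ```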
pub enum SumStrategy {
    /// Run a single kernel with many cubes working in parallel to sum all elements.
    /// The provided value is the number of elements summed per unit (up to rounding).
    OneShot(u32),
    /// Use multiple kernels, reducing one axis at a time.
    Chained(KernelReduceStrategy),
    /// Use autotune to find the best cube count given the hardware and the input.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for SumStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        return Self::OneShot(4);
    }
}

/// Reduce all elements of the `tensor` using the operation given by `config` and the provided
/// [strategy](KernelReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
///
/// If there is no error, the output is a single-element tensor of shape `[1]` holding the result.
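///
/// A usage sketch (runtime type and tensor construction are assumed and omitted):
///
/// ```ignore
/// // Reduce every axis in turn; the result is a single-element tensor.
/// let total = reduce::<WgpuRuntime>(
///     tensor,
///     None,
///     KernelReduceStrategy::default(),
///     ReduceOperationConfig::Sum,
/// )?;
/// ```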
pub fn reduce<Run: CubeRuntime>(
    mut tensor: CubeTensor<Run>,
    output_dtype: Option<DType>,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
    // In practice, starting with the axis that has the smallest extent and reducing axes in
    // increasing order of size leads to the fastest calculation.
    let sorted_axis = argsort(&tensor.shape);
    for axis in sorted_axis {
        tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;
    }
    // reshape to scalar tensor
    tensor.shape = Shape::new([1]);
    tensor.strides = vec![1];
    Ok(tensor)
}

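/// Return the indices that sort `shape` in increasing order of extent.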
fn argsort(shape: &[usize]) -> Vec<usize> {
    let mut indices = (0..shape.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| shape[i]);
    indices
}

/// Reduce the given `dim` of the `input` tensor using the operation given by `config` and the
/// provided [strategy](KernelReduceStrategy).
///
/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
/// Also return an error if `dim` is out of bounds for the rank of `input`.
///
/// If there is no error, the output is a tensor with decreasing strides, where the reduced
/// dimension has size 1 and all other dimensions keep the same size as the input.
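///
/// A usage sketch (runtime type and tensor construction are assumptions and omitted):
///
/// ```ignore
/// // Sum along axis 1 of a `[2, 3, 4]` tensor; the output has shape `[2, 1, 4]`.
/// let reduced = reduce_dim::<WgpuRuntime>(
///     tensor,
///     None,
///     1,
///     KernelReduceStrategy::default(),
///     ReduceOperationConfig::Sum,
/// )?;
/// ```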
pub fn reduce_dim<Run: CubeRuntime>(
    input: CubeTensor<Run>,
    output_dtype: Option<DType>,
    dim: usize,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
    debug_assert!(
        !matches!(
            config,
            ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin
        ) || output_dtype.is_some(),
        "The `output_dtype` has to be `Some` when the `config` is `ArgMax` or `ArgMin`."
    );

    let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
    let client = input.client.clone();
    let output = init_reduce_output::<Run>(&input, dim, &dtypes).ok_or(
        cubek::reduce::ReduceError::InvalidAxis {
            axis: dim,
            rank: input.shape.num_dims(),
        },
    )?;

    let result = match strategy {
        KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
            &client,
            input.as_handle_ref(),
            output.as_handle_ref(),
            dim,
            ReduceStrategy {
                routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
                line_size: LineSizeStrategy {
                    parallel_output_vectorization: false,
                },
            },
            config,
            dtypes,
        ),
        KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
            &client,
            input.as_handle_ref(),
            output.as_handle_ref(),
            dim,
            strategy,
            config,
            dtypes,
        ),
        #[cfg(feature = "autotune")]
        KernelReduceStrategy::Autotune => {
            autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
            Ok(())
        }
    };
    result.map(|_| output)
}

/// Create an empty output tensor with the proper shape and decreasing strides for reducing the
/// given `dim` of `input`, or return `None` if `dim` is out of bounds.
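///
/// A shape-only sketch (how `input` and `dtypes` are obtained is assumed):
///
/// ```ignore
/// // For an input of shape [2, 3, 4], reducing dim 1 yields an output of shape [2, 1, 4].
/// let output = init_reduce_output::<WgpuRuntime>(&input, 1, &dtypes).expect("dim is in bounds");
/// ```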
pub fn init_reduce_output<Run: CubeRuntime>(
    input: &CubeTensor<Run>,
    dim: usize,
    dtypes: &ReduceDtypes,
) -> Option<CubeTensor<Run>> {
    (dim < input.shape.num_dims()).then(|| {
        let mut shape_out = input.shape.clone();
        shape_out.dims[dim] = 1;
        empty_device_dtype::<Run>(
            input.client.clone(),
            input.device.clone(),
            shape_out,
            dtypes.output.elem_type().into(),
        )
    })
}

/// Select a strategy to perform a reduction.
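///
/// A construction sketch for the `Specific` variant, mirroring the default used in [reduce_dim]:
///
/// ```ignore
/// let strategy = KernelReduceStrategy::Specific(ReduceStrategy {
///     routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
///     line_size: LineSizeStrategy {
///         parallel_output_vectorization: false,
///     },
/// });
/// ```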
#[derive(Clone, Debug)]
pub enum KernelReduceStrategy {
    /// Use a best-effort strategy based on the hardware capabilities.
    /// This differs from Autotune as it doesn't benchmark multiple strategies to select the best.
    Unspecified,
    /// Fix the exact strategy for the reduction.
    Specific(cubek::reduce::launch::ReduceStrategy),
    /// Use autotune to find the best strategy given the hardware and the inputs.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for KernelReduceStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        return Self::Unspecified;
    }
}