Skip to main content

burn_cubecl/kernel/reduce/
base.rs

1#[cfg(feature = "autotune")]
2use super::{autotune_reduce, autotune_sum};
3use crate::{
4    CubeRuntime,
5    ops::numeric::{empty_device_contiguous_dtype, zeros_client},
6    tensor::CubeTensor,
7};
8use burn_backend::{DType, TensorMetadata};
9use burn_std::Metadata;
10use cubecl::{
11    AutotuneKey,
12    client::ComputeClient,
13    features::AtomicUsage,
14    ir::{StorageType, Type},
15};
16use cubek::reduce::{
17    ReduceDtypes, ReduceError, ReduceStrategy,
18    components::instructions::ReduceOperationConfig,
19    launch::{RoutineStrategy, VectorizationStrategy},
20    routines::{BlueprintStrategy, unit::UnitStrategy},
21    shared_sum,
22};
23use serde::{Deserialize, Serialize};
24
25#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
26/// Autotune key representative of sum versions
27pub struct SumAutotuneKey {
28    /// The type of the tensor
29    dtype: burn_backend::DType,
30    /// The anchored length of the tensor
31    #[autotune(anchor)]
32    length: usize,
33}
34
35/// Check if the client supports atomic add for the given element type.
36fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
37    client
38        .properties()
39        .atomic_type_usage(Type::new(StorageType::Atomic(dtype.into())))
40        .contains(AtomicUsage::Add)
41}
42
43/// [Sum](sum) with fallback when `client` doesn't support atomic add for the type `E`.
44pub fn sum_fallback<R: CubeRuntime>(
45    tensor: CubeTensor<R>,
46    mut strategy: SumStrategy,
47) -> Result<CubeTensor<R>, ReduceError> {
48    // Early check before creating output and fallback
49    if matches!(strategy, SumStrategy::OneShot(_))
50        && !supports_atomic_add(&tensor.client, tensor.dtype)
51    {
52        strategy = SumStrategy::Chained(Default::default());
53    }
54    sum(tensor, strategy)
55}
56
57/// Specialize reduce function to compute the sum of all elements of the `input` tensor and return
58/// the value into a single-element tensor of shape `1 x 1 x 1 x ...` with the same rank as `input`.
59///
60/// This is expected to be faster for larger tensors than calling [reduce] with the `Sum` instruction.
61///
62/// Return an error if the `client` doesn't support atomic add for the type `E`.
63pub fn sum<Run: CubeRuntime>(
64    tensor: CubeTensor<Run>,
65    strategy: SumStrategy,
66) -> Result<CubeTensor<Run>, ReduceError> {
67    let client = tensor.client.clone();
68    let device = tensor.device.clone();
69
70    match strategy {
71        SumStrategy::OneShot(cube_count) => {
72            let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);
73            let dtype = tensor.dtype;
74
75            shared_sum::<Run>(
76                &client,
77                tensor.binding(),
78                output.clone().binding(),
79                cube_count,
80                dtype.into(),
81            )?;
82
83            Ok(output)
84        }
85        SumStrategy::Chained(strategy) => {
86            reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
87        }
88        #[cfg(feature = "autotune")]
89        SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
90    }
91}
92
93/// Select a strategy to perform a sum.
94pub enum SumStrategy {
95    /// Run a single kernel with many cubes working in parallel to sum all elements.
96    /// The provided value is the number of elements summed per unit (up-to-rounding )
97    OneShot(u32),
98    /// Use multiple kernels
99    Chained(KernelReduceStrategy),
100    /// Use autotune to find the best cube count given the hardware and the input.
101    #[cfg(feature = "autotune")]
102    Autotune,
103}
104
105impl Default for SumStrategy {
106    fn default() -> Self {
107        #[cfg(feature = "autotune")]
108        return Self::Autotune;
109
110        #[cfg(not(feature = "autotune"))]
111        return Self::OneShot(4);
112    }
113}
114
115/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
116///
117/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
118///
119/// If there is no error, the output is a tensor with decreasing strides
120/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
121pub fn reduce<Run: CubeRuntime>(
122    mut tensor: CubeTensor<Run>,
123    output_dtype: Option<DType>,
124    strategy: KernelReduceStrategy,
125    config: ReduceOperationConfig,
126) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
127    // In practice, it looks like starting by the axis with the smallest shape
128    // and going in increasing order lead to the fastest calculation.
129    let sorted_axis = argsort(tensor.meta.shape());
130    for axis in sorted_axis {
131        tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;
132    }
133    // reshape to scalar tensor
134    *tensor.meta = Metadata::new([1], [1]);
135    Ok(tensor)
136}
137
/// Return the axis indices of `shape`, ordered from the shortest axis to the longest.
///
/// The sort is stable: axes of equal length keep their original relative order.
fn argsort(shape: &[usize]) -> Vec<usize> {
    let mut order: Vec<usize> = (0..shape.len()).collect();
    order.sort_by(|&lhs, &rhs| shape[lhs].cmp(&shape[rhs]));
    order
}
143
144/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
145///
146/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
147/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
148///
149/// If there is no error, the output is a tensor with decreasing strides
150/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
151pub fn reduce_dim<Run: CubeRuntime>(
152    input: CubeTensor<Run>,
153    output_dtype: Option<DType>,
154    dim: usize,
155    strategy: KernelReduceStrategy,
156    config: ReduceOperationConfig,
157) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
158    debug_assert!(
159        !matches!(
160            config,
161            ReduceOperationConfig::ArgMax
162                | ReduceOperationConfig::ArgMin
163                | ReduceOperationConfig::ArgTopK(_)
164        ) || output_dtype.is_some(),
165        "The `output_dtype` has to be `Some` only when the `config` is `ArgMax`, `ArgMin` or `ArgTopK`.
166        "
167    );
168
169    let accumulator_len = match config {
170        ReduceOperationConfig::ArgTopK(k) => k,
171        ReduceOperationConfig::TopK(k) => k,
172        _ => 1,
173    };
174    let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
175    let client = input.client.clone();
176    let output = init_reduce_output::<Run>(&input, dim, &dtypes, accumulator_len).ok_or(
177        cubek::reduce::ReduceError::InvalidAxis {
178            axis: dim,
179            rank: input.meta.num_dims(),
180        },
181    )?;
182
183    let result = match strategy {
184        KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
185            &client,
186            input.binding(),
187            output.clone().binding(),
188            dim,
189            ReduceStrategy {
190                routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
191                vectorization: VectorizationStrategy {
192                    parallel_output_vectorization: false,
193                },
194            },
195            config,
196            dtypes,
197        ),
198        KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
199            &client,
200            input.binding(),
201            output.clone().binding(),
202            dim,
203            strategy,
204            config,
205            dtypes,
206        ),
207        #[cfg(feature = "autotune")]
208        KernelReduceStrategy::Autotune => {
209            autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
210            Ok(())
211        }
212    };
213    result.map(|_| output)
214}
215
216/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input`
217/// or return `None` if `axis` is out-of-bound.
218pub fn init_reduce_output<Run: CubeRuntime>(
219    input: &CubeTensor<Run>,
220    dim: usize,
221    dtypes: &ReduceDtypes,
222    accumulator_len: usize,
223) -> Option<CubeTensor<Run>> {
224    (dim < input.meta.num_dims()).then(|| {
225        let mut shape_out = input.shape();
226        shape_out[dim] = accumulator_len;
227        empty_device_contiguous_dtype(
228            input.client.clone(),
229            input.device.clone(),
230            shape_out,
231            dtypes.output.elem_type().into(),
232        )
233    })
234}
235
236/// Select a strategy to perform a reduction.
237#[derive(Clone, Debug)]
238pub enum KernelReduceStrategy {
239    /// Use a best-effort strategy based on the hardware capacity.
240    /// This differs from Autotune as it doesn't try and compare many strategies to select the best.
241    Unspecified,
242    /// Fix the exact strategy for the reduction.
243    Specific(cubek::reduce::launch::ReduceStrategy),
244    /// Use autotune to find the best strategy given the hardware and the inputs.
245    #[cfg(feature = "autotune")]
246    Autotune,
247}
248
249impl Default for KernelReduceStrategy {
250    fn default() -> Self {
251        #[cfg(feature = "autotune")]
252        return Self::Autotune;
253
254        #[cfg(not(feature = "autotune"))]
255        return Self::Unspecified;
256    }
257}