burn_jit/kernel/reduce/base.rs
1use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
2
3use super::autotune_reduce;
4
5pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
6
7/// Reduce all elements of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
8///
9/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
10/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
11/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
12///
13/// If there is no error, the output is a tensor with decreasing strides
14/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
15pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
16    mut input: JitTensor<Run>,
17    strategy: ReduceStrategy,
18) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
19    input.shape = input.shape.flatten();
20    input.strides = vec![1];
21    reduce_dim::<Run, In, Out, Rd>(input, 0, strategy)
22}
23
24/// Reduce the given `axis` of the `input` tensor using the instruction `Rd` and the given [Strategy](ReduceStrategy).
25///
26/// Return an error if `strategy` is `Specific(strategy)` and the specified strategy is not supported by the `client`.
27/// Also returns an error if the `axis` is larger than the `input` rank or if the shape of `output` is invalid.
28/// The shape of `output` must be the same as input except with a value of 1 for the given `axis`.
29///
30/// If there is no error, the output is a tensor with decreasing strides
31/// where the shape of reduced dim is set to 1 but all shape are similar to the input.
32pub fn reduce_dim<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
33    input: JitTensor<Run>,
34    dim: usize,
35    strategy: ReduceStrategy,
36) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
37    let client = input.client.clone();
38    let output = init_reduce_output::<Run, In, Out>(&input, dim).ok_or(
39        cubecl::reduce::ReduceError::InvalidAxis {
40            axis: dim,
41            rank: input.shape.num_dims(),
42        },
43    )?;
44    let result = match strategy {
45        ReduceStrategy::Unspecified => cubecl::reduce::reduce::<Run, In, Out, Rd>(
46            &client,
47            input.as_handle_ref(),
48            output.as_handle_ref(),
49            dim,
50            None,
51        ),
52        ReduceStrategy::Specific(strategy) => cubecl::reduce::reduce::<Run, In, Out, Rd>(
53            &client,
54            input.as_handle_ref(),
55            output.as_handle_ref(),
56            dim,
57            Some(strategy),
58        ),
59        #[cfg(feature = "autotune")]
60        ReduceStrategy::Autotune => {
61            autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim)
62        }
63    };
64    result.map(|_| output)
65}
66
67/// Creates an empty output tensor with the proper shape and decreasing strides to reduce the given `axis` of `input`
68/// or return `None` if `axis` is out-of-bound.
69pub fn init_reduce_output<Run: JitRuntime, In: JitElement, Out: JitElement>(
70    input: &JitTensor<Run>,
71    dim: usize,
72) -> Option<JitTensor<Run>> {
73    (dim < input.shape.num_dims()).then(|| {
74        let mut shape_out = input.shape.clone();
75        shape_out.dims[dim] = 1;
76        empty_device::<Run, Out>(input.client.clone(), input.device.clone(), shape_out)
77    })
78}
79
80/// Select a strategy to perform a reduction.
81#[derive(Copy, Clone, Debug)]
82pub enum ReduceStrategy {
83    /// Use a best-effort strategy based on the hardware capacity.
84    /// This differs from Autotune as it doesn't try and compare many strategies to select the best.
85    Unspecified,
86    /// Fix the exact strategy for the reduction.
87    Specific(cubecl::reduce::ReduceStrategy),
88    /// Use autotune to find the best strategy given the hardware and the inputs.
89    #[cfg(feature = "autotune")]
90    Autotune,
91}
92
93impl Default for ReduceStrategy {
94    fn default() -> Self {
95        #[cfg(feature = "autotune")]
96        return Self::Autotune;
97
98        #[cfg(not(feature = "autotune"))]
99        return Self::Unspecified;
100    }
101}