burn_jit/kernel/reduce/
tune.rs

1#![allow(missing_docs)]
2
3use burn_tensor::ElementConversion;
4use cubecl::{
5    client::ComputeClient,
6    tune,
7    tune::{local_tuner, tune_with, LocalTuner},
8    AutotuneKey,
9};
10use serde::{Deserialize, Serialize};
11
12use crate::{
13    kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor,
14    JitAutotuneKey, JitElement, JitRuntime, JitTuneId,
15};
16
17/// Executes autotune on reduce operations.
18pub fn autotune_reduce<
19    Run: JitRuntime,
20    In: JitElement,
21    Out: JitElement,
22    Rd: cubecl::reduce::Reduce,
23>(
24    client: &ComputeClient<Run::Server, Run::Channel>,
25    input: JitTensor<Run>,
26    output: JitTensor<Run>,
27    dim: usize,
28) -> Result<(), cubecl::reduce::ReduceError> {
29    static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
30
31    TUNER.execute(
32        &JitTuneId::new::<Run>(&input.device),
33        client,
34        Box::new(ReduceOps::<Run, In, Out, Rd>::new(input, output, dim)),
35    );
36
37    Ok(())
38}
39
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of reduce versions
pub struct ReduceAutotuneKey {
    /// Element type of the tensor being reduced.
    dtype: burn_tensor::DType,
    /// Length of the reduced axis.
    #[autotune(anchor)]
    reduce_axis_shape: usize,
    /// Stride of the reduced axis (distinguishes contiguous from strided reductions).
    #[autotune(anchor)]
    reduce_axis_stride: usize,
    /// The product of the shapes of all axes with greater strides than the reduced axis.
    #[autotune(anchor)]
    outer_axes_product: usize,
}
51
52impl ReduceAutotuneKey {
53    pub(crate) fn generate<Run: JitRuntime>(input: &JitTensor<Run>, axis: usize) -> Self {
54        let rank = input.shape.num_dims();
55
56        if axis > rank {
57            panic!("axis {axis} is out-of-bound for a rank of {rank}");
58        }
59
60        let dtype = input.dtype;
61        let reduce_axis_shape = input.shape.dims[axis];
62        let reduce_axis_stride = input.strides[axis];
63
64        let outer_axes_product = input
65            .strides
66            .iter()
67            .zip(input.shape.dims.iter())
68            .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape))
69            .product();
70
71        Self::new(
72            dtype,
73            reduce_axis_shape,
74            reduce_axis_stride,
75            outer_axes_product,
76        )
77    }
78}
79
80pub(crate) fn create_key<Run: JitRuntime>(
81    input: &JitTensor<Run>,
82    _output: &JitTensor<Run>,
83    dim: &usize,
84) -> JitAutotuneKey {
85    JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim))
86}
87
88pub use reduce_ops::*;
89mod reduce_ops {
90    #![allow(missing_docs)]
91
92    use super::*;
93
    // Autotune operation set: the `tune` proc-macro expands this function into
    // the `ReduceOps` struct used above, benchmarking each listed operation
    // (filtered by `should_run`) and caching the fastest per `create_key` key.
    #[tune(
    operations(reduce, reduce_shared, reduce_plane, reduce_shared_plane),
    create_key = create_key::<Run>,
    should_run = should_run
)]
    fn reduce_ops<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
        key: JitAutotuneKey,
        input: JitTensor<Run>,
        output: JitTensor<Run>,
        dim: usize,
    ) {
        // Benchmark on uniformly-random data of the same shape so the timing
        // does not depend on the caller's actual tensor contents.
        let random_bounds: (In, In) = ((-10.0_f32).elem::<In>(), (10.0_f32).elem::<In>());
        let input = random_like_uniform(input, random_bounds.0, random_bounds.1);

        // Fresh output buffer so benchmarking never clobbers the real output.
        let output = empty_device::<Run, Out>(
            output.client.clone(),
            output.device.clone(),
            output.shape.clone(),
        );

        tune_with!(input, output, dim)
    }
116
117    fn should_run<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
118        op: &ReduceOps<Run, In, Out, Rd>,
119        _key: &JitAutotuneKey,
120        index: usize,
121    ) -> bool {
122        match index {
123            // if strategy uses planes
124            2 | 3 => {
125                let properties = op.input.client.properties();
126                properties.feature_enabled(cubecl::Feature::Plane)
127                    && properties
128                        .hardware_properties()
129                        .defined_plane_size()
130                        .is_some()
131            }
132            _ => true,
133        }
134    }
135
136    fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
137        input: JitTensor<Run>,
138        output: JitTensor<Run>,
139        axis: usize,
140    ) -> Result<(), String> {
141        cubecl::reduce::reduce::<Run, In, Out, Rd>(
142            &input.client,
143            input.as_handle_ref(),
144            output.as_handle_ref(),
145            axis,
146            Some(cubecl::reduce::ReduceStrategy {
147                shared: false,
148                use_planes: false,
149            }),
150        )
151        .map_err(|e| format!("{e}"))
152    }
153
154    fn reduce_shared<
155        Run: JitRuntime,
156        In: JitElement,
157        Out: JitElement,
158        Rd: cubecl::reduce::Reduce,
159    >(
160        input: JitTensor<Run>,
161        output: JitTensor<Run>,
162        axis: usize,
163    ) -> Result<(), String> {
164        cubecl::reduce::reduce::<Run, In, Out, Rd>(
165            &input.client,
166            input.as_handle_ref(),
167            output.as_handle_ref(),
168            axis,
169            Some(cubecl::reduce::ReduceStrategy {
170                shared: true,
171                use_planes: false,
172            }),
173        )
174        .map_err(|e| format!("{e}"))
175    }
176
177    fn reduce_plane<
178        Run: JitRuntime,
179        In: JitElement,
180        Out: JitElement,
181        Rd: cubecl::reduce::Reduce,
182    >(
183        input: JitTensor<Run>,
184        output: JitTensor<Run>,
185        axis: usize,
186    ) -> Result<(), String> {
187        cubecl::reduce::reduce::<Run, In, Out, Rd>(
188            &input.client,
189            input.as_handle_ref(),
190            output.as_handle_ref(),
191            axis,
192            Some(cubecl::reduce::ReduceStrategy {
193                shared: false,
194                use_planes: true,
195            }),
196        )
197        .map_err(|e| format!("{e}"))
198    }
199
200    fn reduce_shared_plane<
201        Run: JitRuntime,
202        In: JitElement,
203        Out: JitElement,
204        Rd: cubecl::reduce::Reduce,
205    >(
206        input: JitTensor<Run>,
207        output: JitTensor<Run>,
208        axis: usize,
209    ) -> Result<(), String> {
210        cubecl::reduce::reduce::<Run, In, Out, Rd>(
211            &input.client,
212            input.as_handle_ref(),
213            output.as_handle_ref(),
214            axis,
215            Some(cubecl::reduce::ReduceStrategy {
216                shared: true,
217                use_planes: true,
218            }),
219        )
220        .map_err(|e| format!("{e}"))
221    }
222}