//! Autotune support for reduce kernels — burn_cubecl/kernel/reduce/tune.rs

1#![allow(missing_docs)]
2
3use burn_tensor::ElementConversion;
4use cubecl::{
5    AutotuneKey,
6    client::ComputeClient,
7    reduce::{ReduceFamily, tune_key::ReduceAutotuneKey},
8    tune::{LocalTuner, TunableSet, local_tuner},
9};
10use serde::{Deserialize, Serialize};
11
12use crate::{
13    CubeAutotuneKey, CubeElement, CubeRuntime, CubeTuneId, kernel::prng::random_like_uniform,
14    ops::numeric::empty_device, tensor::CubeTensor,
15};
16
17/// Executes autotune on reduce operations.
18pub fn autotune_reduce<
19    Run: CubeRuntime,
20    In: CubeElement,
21    Out: CubeElement,
22    Rd: cubecl::reduce::ReduceFamily,
23>(
24    client: &ComputeClient<Run::Server, Run::Channel>,
25    input: CubeTensor<Run>,
26    output: CubeTensor<Run>,
27    dim: usize,
28    config: Rd::Config,
29) {
30    use reduce_ops::*;
31
32    static TUNER: LocalTuner<ReduceAutotuneKey, CubeTuneId> = local_tuner!("reduce-dim");
33
34    let tunables = TunableSet::new(create_key::<Run, Rd>, reduce_input_gen::<Run, In, Out, Rd>)
35        .with_tunable(reduce::<Run, In, Out, Rd>)
36        .with_tunable(reduce_shared::<Run, In, Out, Rd>)
37        .with_tunable(reduce_plane::<Run, In, Out, Rd>)
38        .with_tunable(reduce_shared_plane::<Run, In, Out, Rd>);
39
40    TUNER.execute(
41        &CubeTuneId::new::<Run>(&input.client, &input.device),
42        client,
43        &tunables,
44        (input, output, dim, config),
45    );
46}
47
48pub(crate) fn create_key<Run: CubeRuntime, Rd: ReduceFamily>(
49    input: &CubeTensor<Run>,
50    output: &CubeTensor<Run>,
51    axis: &usize,
52    _config: &Rd::Config,
53) -> ReduceAutotuneKey {
54    let elem_input = input.dtype.into();
55    let elem_output = output.dtype.into();
56
57    ReduceAutotuneKey::generate(
58        elem_input,
59        elem_output,
60        &input.shape.dims,
61        input.strides[*axis] == 1,
62        *axis,
63    )
64}
65
66mod reduce_ops {
67    #![allow(missing_docs)]
68
69    use cubecl::reduce::ReduceFamily;
70
71    use super::*;
72
73    pub(crate) fn reduce_input_gen<
74        Run: CubeRuntime,
75        In: CubeElement,
76        Out: CubeElement,
77        Rd: ReduceFamily,
78    >(
79        _key: &ReduceAutotuneKey,
80        input: &CubeTensor<Run>,
81        output: &CubeTensor<Run>,
82        dim: &usize,
83        config: &Rd::Config,
84    ) -> (CubeTensor<Run>, CubeTensor<Run>, usize, Rd::Config) {
85        let random_bounds: (In, In) = ((-10.0_f32).elem::<In>(), (10.0_f32).elem::<In>());
86        let input = random_like_uniform(input, random_bounds.0, random_bounds.1);
87
88        let output = empty_device::<Run, Out>(
89            output.client.clone(),
90            output.device.clone(),
91            output.shape.clone(),
92        );
93
94        (input, output, *dim, *config)
95    }
96
97    pub(crate) fn reduce<
98        Run: CubeRuntime,
99        In: CubeElement,
100        Out: CubeElement,
101        Rd: cubecl::reduce::ReduceFamily,
102    >(
103        input: CubeTensor<Run>,
104        output: CubeTensor<Run>,
105        axis: usize,
106        config: Rd::Config,
107    ) -> Result<(), String> {
108        cubecl::reduce::reduce::<Run, In, Out, Rd>(
109            &input.client,
110            input.as_handle_ref(),
111            output.as_handle_ref(),
112            axis,
113            Some(cubecl::reduce::ReduceStrategy {
114                shared: false,
115                use_planes: false,
116            }),
117            config,
118        )
119        .map_err(|e| format!("{e}"))
120    }
121
122    pub(crate) fn reduce_shared<
123        Run: CubeRuntime,
124        In: CubeElement,
125        Out: CubeElement,
126        Rd: cubecl::reduce::ReduceFamily,
127    >(
128        input: CubeTensor<Run>,
129        output: CubeTensor<Run>,
130        axis: usize,
131        config: Rd::Config,
132    ) -> Result<(), String> {
133        cubecl::reduce::reduce::<Run, In, Out, Rd>(
134            &input.client,
135            input.as_handle_ref(),
136            output.as_handle_ref(),
137            axis,
138            Some(cubecl::reduce::ReduceStrategy {
139                shared: true,
140                use_planes: false,
141            }),
142            config,
143        )
144        .map_err(|e| format!("{e}"))
145    }
146
147    pub(crate) fn reduce_plane<
148        Run: CubeRuntime,
149        In: CubeElement,
150        Out: CubeElement,
151        Rd: cubecl::reduce::ReduceFamily,
152    >(
153        input: CubeTensor<Run>,
154        output: CubeTensor<Run>,
155        axis: usize,
156        config: Rd::Config,
157    ) -> Result<(), String> {
158        cubecl::reduce::reduce::<Run, In, Out, Rd>(
159            &input.client,
160            input.as_handle_ref(),
161            output.as_handle_ref(),
162            axis,
163            Some(cubecl::reduce::ReduceStrategy {
164                shared: false,
165                use_planes: true,
166            }),
167            config,
168        )
169        .map_err(|e| format!("{e}"))
170    }
171
172    pub(crate) fn reduce_shared_plane<
173        Run: CubeRuntime,
174        In: CubeElement,
175        Out: CubeElement,
176        Rd: cubecl::reduce::ReduceFamily,
177    >(
178        input: CubeTensor<Run>,
179        output: CubeTensor<Run>,
180        axis: usize,
181        config: Rd::Config,
182    ) -> Result<(), String> {
183        cubecl::reduce::reduce::<Run, In, Out, Rd>(
184            &input.client,
185            input.as_handle_ref(),
186            output.as_handle_ref(),
187            axis,
188            Some(cubecl::reduce::ReduceStrategy {
189                shared: true,
190                use_planes: true,
191            }),
192            config,
193        )
194        .map_err(|e| format!("{e}"))
195    }
196}
197
/// Executes autotune on the sum operation, picking the fastest between a
/// chained reduce and several one-shot `shared_sum` variants.
#[cfg(feature = "autotune")]
pub fn autotune_sum<Run: CubeRuntime, E: CubeElement>(
    client: &ComputeClient<Run::Server, Run::Channel>,
    input: CubeTensor<Run>,
) -> CubeTensor<Run> {
    use sum_ops::*;

    // Single tuner shared across calls; results are cached per `CubeTuneId`.
    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!("autotune-sum");

    // The const generic on `sum_one_shot` sweeps powers of two; it is
    // forwarded to `cubecl::reduce::shared_sum` — presumably a parallelism
    // factor, confirm against the cubecl docs.
    let tunables = TunableSet::new(create_key_sum::<Run>, sum_input_gen::<Run, E>)
        .with_tunable(sum_chained::<Run, E>)
        .with_tunable(sum_one_shot::<Run, E, 1>)
        .with_tunable(sum_one_shot::<Run, E, 2>)
        .with_tunable(sum_one_shot::<Run, E, 4>)
        .with_tunable(sum_one_shot::<Run, E, 8>)
        .with_tunable(sum_one_shot::<Run, E, 16>)
        .with_tunable(sum_one_shot::<Run, E, 32>)
        .with_tunable(sum_one_shot::<Run, E, 64>);

    TUNER.execute(
        &CubeTuneId::new::<Run>(&input.client, &input.device),
        client,
        &tunables,
        input,
    )
}
225
226pub(crate) fn create_key_sum<Run: CubeRuntime>(input: &CubeTensor<Run>) -> CubeAutotuneKey {
227    CubeAutotuneKey::Sum(SumAutotuneKey::generate(input))
228}
229
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of sum versions
pub struct SumAutotuneKey {
    /// Element type of the tensor being summed.
    dtype: burn_tensor::DType,
    /// Total number of elements in the input tensor.
    // NOTE(review): `#[autotune(anchor)]` presumably buckets `length` so nearby
    // sizes share cached tune results — confirm against cubecl's AutotuneKey docs.
    #[autotune(anchor)]
    length: usize,
}
237
238impl SumAutotuneKey {
239    pub(crate) fn generate<Run: CubeRuntime>(input: &CubeTensor<Run>) -> Self {
240        let dtype = input.dtype;
241        let length = input.shape.num_elements();
242        Self { dtype, length }
243    }
244}
245mod sum_ops {
246    #![allow(missing_docs)]
247    use super::*;
248
249    pub(crate) fn sum_input_gen<Run: CubeRuntime, E: CubeElement>(
250        _key: &CubeAutotuneKey,
251        input: &CubeTensor<Run>,
252    ) -> CubeTensor<Run> {
253        let random_bounds: (E, E) = ((-10.0_f32).elem::<E>(), (10.0_f32).elem::<E>());
254        random_like_uniform(input, random_bounds.0, random_bounds.1)
255    }
256
257    pub(crate) fn sum_one_shot<Run: CubeRuntime, E: CubeElement, const C: u32>(
258        input: CubeTensor<Run>,
259    ) -> Result<CubeTensor<Run>, String> {
260        let client = input.client.clone();
261        let device = input.device.clone();
262        let handle = client.create(E::as_bytes(&[E::from_int(0)]));
263        let output = CubeTensor::new_contiguous(client, device, [1].into(), handle, E::dtype());
264
265        cubecl::reduce::shared_sum::<Run, E>(
266            &input.client,
267            input.as_handle_ref(),
268            output.as_handle_ref(),
269            C,
270        )
271        .map_err(|e| e.to_string())
272        .map(|_| output)
273    }
274
275    #[cfg(feature = "autotune")]
276    pub(crate) fn sum_chained<Run: CubeRuntime, E: CubeElement>(
277        input: CubeTensor<Run>,
278    ) -> Result<CubeTensor<Run>, String> {
279        crate::kernel::reduce::reduce::<Run, E, E>(
280            input,
281            crate::kernel::reduce::ReduceStrategy::Autotune,
282            cubecl::reduce::instructions::ReduceFnConfig::Sum,
283        )
284        .map_err(|e| e.to_string())
285    }
286}