1#![allow(missing_docs)]
2
3use burn_tensor::ElementConversion;
4use cubecl::{
5 AutotuneKey,
6 client::ComputeClient,
7 reduce::{ReduceFamily, tune_key::ReduceAutotuneKey},
8 tune::{LocalTuner, TunableSet, local_tuner},
9};
10use serde::{Deserialize, Serialize};
11
12use crate::{
13 CubeAutotuneKey, CubeElement, CubeRuntime, CubeTuneId, kernel::prng::random_like_uniform,
14 ops::numeric::empty_device, tensor::CubeTensor,
15};
16
17pub fn autotune_reduce<
19 Run: CubeRuntime,
20 In: CubeElement,
21 Out: CubeElement,
22 Rd: cubecl::reduce::ReduceFamily,
23>(
24 client: &ComputeClient<Run::Server, Run::Channel>,
25 input: CubeTensor<Run>,
26 output: CubeTensor<Run>,
27 dim: usize,
28 config: Rd::Config,
29) {
30 use reduce_ops::*;
31
32 static TUNER: LocalTuner<ReduceAutotuneKey, CubeTuneId> = local_tuner!("reduce-dim");
33
34 let tunables = TunableSet::new(create_key::<Run, Rd>, reduce_input_gen::<Run, In, Out, Rd>)
35 .with_tunable(reduce::<Run, In, Out, Rd>)
36 .with_tunable(reduce_shared::<Run, In, Out, Rd>)
37 .with_tunable(reduce_plane::<Run, In, Out, Rd>)
38 .with_tunable(reduce_shared_plane::<Run, In, Out, Rd>);
39
40 TUNER.execute(
41 &CubeTuneId::new::<Run>(&input.client, &input.device),
42 client,
43 &tunables,
44 (input, output, dim, config),
45 );
46}
47
48pub(crate) fn create_key<Run: CubeRuntime, Rd: ReduceFamily>(
49 input: &CubeTensor<Run>,
50 output: &CubeTensor<Run>,
51 axis: &usize,
52 _config: &Rd::Config,
53) -> ReduceAutotuneKey {
54 let elem_input = input.dtype.into();
55 let elem_output = output.dtype.into();
56
57 ReduceAutotuneKey::generate(
58 elem_input,
59 elem_output,
60 &input.shape.dims,
61 input.strides[*axis] == 1,
62 *axis,
63 )
64}
65
66mod reduce_ops {
67 #![allow(missing_docs)]
68
69 use cubecl::reduce::ReduceFamily;
70
71 use super::*;
72
73 pub(crate) fn reduce_input_gen<
74 Run: CubeRuntime,
75 In: CubeElement,
76 Out: CubeElement,
77 Rd: ReduceFamily,
78 >(
79 _key: &ReduceAutotuneKey,
80 input: &CubeTensor<Run>,
81 output: &CubeTensor<Run>,
82 dim: &usize,
83 config: &Rd::Config,
84 ) -> (CubeTensor<Run>, CubeTensor<Run>, usize, Rd::Config) {
85 let random_bounds: (In, In) = ((-10.0_f32).elem::<In>(), (10.0_f32).elem::<In>());
86 let input = random_like_uniform(input, random_bounds.0, random_bounds.1);
87
88 let output = empty_device::<Run, Out>(
89 output.client.clone(),
90 output.device.clone(),
91 output.shape.clone(),
92 );
93
94 (input, output, *dim, *config)
95 }
96
97 pub(crate) fn reduce<
98 Run: CubeRuntime,
99 In: CubeElement,
100 Out: CubeElement,
101 Rd: cubecl::reduce::ReduceFamily,
102 >(
103 input: CubeTensor<Run>,
104 output: CubeTensor<Run>,
105 axis: usize,
106 config: Rd::Config,
107 ) -> Result<(), String> {
108 cubecl::reduce::reduce::<Run, In, Out, Rd>(
109 &input.client,
110 input.as_handle_ref(),
111 output.as_handle_ref(),
112 axis,
113 Some(cubecl::reduce::ReduceStrategy {
114 shared: false,
115 use_planes: false,
116 }),
117 config,
118 )
119 .map_err(|e| format!("{e}"))
120 }
121
122 pub(crate) fn reduce_shared<
123 Run: CubeRuntime,
124 In: CubeElement,
125 Out: CubeElement,
126 Rd: cubecl::reduce::ReduceFamily,
127 >(
128 input: CubeTensor<Run>,
129 output: CubeTensor<Run>,
130 axis: usize,
131 config: Rd::Config,
132 ) -> Result<(), String> {
133 cubecl::reduce::reduce::<Run, In, Out, Rd>(
134 &input.client,
135 input.as_handle_ref(),
136 output.as_handle_ref(),
137 axis,
138 Some(cubecl::reduce::ReduceStrategy {
139 shared: true,
140 use_planes: false,
141 }),
142 config,
143 )
144 .map_err(|e| format!("{e}"))
145 }
146
147 pub(crate) fn reduce_plane<
148 Run: CubeRuntime,
149 In: CubeElement,
150 Out: CubeElement,
151 Rd: cubecl::reduce::ReduceFamily,
152 >(
153 input: CubeTensor<Run>,
154 output: CubeTensor<Run>,
155 axis: usize,
156 config: Rd::Config,
157 ) -> Result<(), String> {
158 cubecl::reduce::reduce::<Run, In, Out, Rd>(
159 &input.client,
160 input.as_handle_ref(),
161 output.as_handle_ref(),
162 axis,
163 Some(cubecl::reduce::ReduceStrategy {
164 shared: false,
165 use_planes: true,
166 }),
167 config,
168 )
169 .map_err(|e| format!("{e}"))
170 }
171
172 pub(crate) fn reduce_shared_plane<
173 Run: CubeRuntime,
174 In: CubeElement,
175 Out: CubeElement,
176 Rd: cubecl::reduce::ReduceFamily,
177 >(
178 input: CubeTensor<Run>,
179 output: CubeTensor<Run>,
180 axis: usize,
181 config: Rd::Config,
182 ) -> Result<(), String> {
183 cubecl::reduce::reduce::<Run, In, Out, Rd>(
184 &input.client,
185 input.as_handle_ref(),
186 output.as_handle_ref(),
187 axis,
188 Some(cubecl::reduce::ReduceStrategy {
189 shared: true,
190 use_planes: true,
191 }),
192 config,
193 )
194 .map_err(|e| format!("{e}"))
195 }
196}
197
/// Autotunes the full-tensor sum, choosing between a chained reduce and
/// one-shot kernels with varying per-unit accumulation counts.
#[cfg(feature = "autotune")]
pub fn autotune_sum<Run: CubeRuntime, E: CubeElement>(
    client: &ComputeClient<Run::Server, Run::Channel>,
    input: CubeTensor<Run>,
) -> CubeTensor<Run> {
    use sum_ops::*;

    static TUNER: LocalTuner<CubeAutotuneKey, CubeTuneId> = local_tuner!("autotune-sum");

    // One-shot candidates sweep the accumulation count over powers of two;
    // const generics force each count to be registered explicitly.
    let candidates = TunableSet::new(create_key_sum::<Run>, sum_input_gen::<Run, E>)
        .with_tunable(sum_chained::<Run, E>)
        .with_tunable(sum_one_shot::<Run, E, 1>)
        .with_tunable(sum_one_shot::<Run, E, 2>)
        .with_tunable(sum_one_shot::<Run, E, 4>)
        .with_tunable(sum_one_shot::<Run, E, 8>)
        .with_tunable(sum_one_shot::<Run, E, 16>)
        .with_tunable(sum_one_shot::<Run, E, 32>)
        .with_tunable(sum_one_shot::<Run, E, 64>);

    // Compute the tune id before `input` is moved into the tuner.
    let tune_id = CubeTuneId::new::<Run>(&input.client, &input.device);
    TUNER.execute(&tune_id, client, &candidates, input)
}
225
226pub(crate) fn create_key_sum<Run: CubeRuntime>(input: &CubeTensor<Run>) -> CubeAutotuneKey {
227 CubeAutotuneKey::Sum(SumAutotuneKey::generate(input))
228}
229
/// Autotune key for the full-tensor sum kernel.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct SumAutotuneKey {
    /// Element type of the tensor being summed.
    dtype: burn_tensor::DType,
    /// Total element count; anchored so similar sizes share a tuning result.
    #[autotune(anchor)]
    length: usize,
}
237
238impl SumAutotuneKey {
239 pub(crate) fn generate<Run: CubeRuntime>(input: &CubeTensor<Run>) -> Self {
240 let dtype = input.dtype;
241 let length = input.shape.num_elements();
242 Self { dtype, length }
243 }
244}
245mod sum_ops {
246 #![allow(missing_docs)]
247 use super::*;
248
249 pub(crate) fn sum_input_gen<Run: CubeRuntime, E: CubeElement>(
250 _key: &CubeAutotuneKey,
251 input: &CubeTensor<Run>,
252 ) -> CubeTensor<Run> {
253 let random_bounds: (E, E) = ((-10.0_f32).elem::<E>(), (10.0_f32).elem::<E>());
254 random_like_uniform(input, random_bounds.0, random_bounds.1)
255 }
256
257 pub(crate) fn sum_one_shot<Run: CubeRuntime, E: CubeElement, const C: u32>(
258 input: CubeTensor<Run>,
259 ) -> Result<CubeTensor<Run>, String> {
260 let client = input.client.clone();
261 let device = input.device.clone();
262 let handle = client.create(E::as_bytes(&[E::from_int(0)]));
263 let output = CubeTensor::new_contiguous(client, device, [1].into(), handle, E::dtype());
264
265 cubecl::reduce::shared_sum::<Run, E>(
266 &input.client,
267 input.as_handle_ref(),
268 output.as_handle_ref(),
269 C,
270 )
271 .map_err(|e| e.to_string())
272 .map(|_| output)
273 }
274
275 #[cfg(feature = "autotune")]
276 pub(crate) fn sum_chained<Run: CubeRuntime, E: CubeElement>(
277 input: CubeTensor<Run>,
278 ) -> Result<CubeTensor<Run>, String> {
279 crate::kernel::reduce::reduce::<Run, E, E>(
280 input,
281 crate::kernel::reduce::ReduceStrategy::Autotune,
282 cubecl::reduce::instructions::ReduceFnConfig::Sum,
283 )
284 .map_err(|e| e.to_string())
285 }
286}