1#![allow(missing_docs)]
2
3use burn_tensor::ElementConversion;
4use cubecl::{
5 client::ComputeClient,
6 tune,
7 tune::{local_tuner, tune_with, LocalTuner},
8 AutotuneKey,
9};
10use serde::{Deserialize, Serialize};
11
12use crate::{
13 kernel::prng::random_like_uniform, ops::numeric::empty_device, tensor::JitTensor,
14 JitAutotuneKey, JitElement, JitRuntime, JitTuneId,
15};
16
17pub fn autotune_reduce<
19 Run: JitRuntime,
20 In: JitElement,
21 Out: JitElement,
22 Rd: cubecl::reduce::Reduce,
23>(
24 client: &ComputeClient<Run::Server, Run::Channel>,
25 input: JitTensor<Run>,
26 output: JitTensor<Run>,
27 dim: usize,
28) -> Result<(), cubecl::reduce::ReduceError> {
29 static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
30
31 TUNER.execute(
32 &JitTuneId::new::<Run>(&input.device),
33 client,
34 Box::new(ReduceOps::<Run, In, Out, Rd>::new(input, output, dim)),
35 );
36
37 Ok(())
38}
39
/// Autotune key for reduce kernels: tuning results are cached per distinct key.
///
/// The `#[autotune(anchor)]` fields are anchored (rounded) by the derive macro
/// so that tensors of similar size share one cache entry instead of
/// re-benchmarking for every exact shape.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct ReduceAutotuneKey {
    // Element type of the reduced tensor.
    dtype: burn_tensor::DType,
    // Number of elements along the reduced axis.
    #[autotune(anchor)]
    reduce_axis_shape: usize,
    // Stride of the reduced axis (captures the memory layout).
    #[autotune(anchor)]
    reduce_axis_stride: usize,
    // Product of the shapes of all axes with a stride larger than the
    // reduced axis — the number of independent outer slices.
    #[autotune(anchor)]
    outer_axes_product: usize, }
51
52impl ReduceAutotuneKey {
53 pub(crate) fn generate<Run: JitRuntime>(input: &JitTensor<Run>, axis: usize) -> Self {
54 let rank = input.shape.num_dims();
55
56 if axis > rank {
57 panic!("axis {axis} is out-of-bound for a rank of {rank}");
58 }
59
60 let dtype = input.dtype;
61 let reduce_axis_shape = input.shape.dims[axis];
62 let reduce_axis_stride = input.strides[axis];
63
64 let outer_axes_product = input
65 .strides
66 .iter()
67 .zip(input.shape.dims.iter())
68 .filter_map(|(stride, shape)| (*stride > reduce_axis_stride).then_some(shape))
69 .product();
70
71 Self::new(
72 dtype,
73 reduce_axis_shape,
74 reduce_axis_stride,
75 outer_axes_product,
76 )
77 }
78}
79
80pub(crate) fn create_key<Run: JitRuntime>(
81 input: &JitTensor<Run>,
82 _output: &JitTensor<Run>,
83 dim: &usize,
84) -> JitAutotuneKey {
85 JitAutotuneKey::Reduce(ReduceAutotuneKey::generate(input, *dim))
86}
87
88pub use reduce_ops::*;
89mod reduce_ops {
90 #![allow(missing_docs)]
91
92 use super::*;
93
94 #[tune(
95 operations(reduce, reduce_shared, reduce_plane, reduce_shared_plane),
96 create_key = create_key::<Run>,
97 should_run = should_run
98)]
99 fn reduce_ops<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
100 key: JitAutotuneKey,
101 input: JitTensor<Run>,
102 output: JitTensor<Run>,
103 dim: usize,
104 ) {
105 let random_bounds: (In, In) = ((-10.0_f32).elem::<In>(), (10.0_f32).elem::<In>());
106 let input = random_like_uniform(input, random_bounds.0, random_bounds.1);
107
108 let output = empty_device::<Run, Out>(
109 output.client.clone(),
110 output.device.clone(),
111 output.shape.clone(),
112 );
113
114 tune_with!(input, output, dim)
115 }
116
117 fn should_run<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
118 op: &ReduceOps<Run, In, Out, Rd>,
119 _key: &JitAutotuneKey,
120 index: usize,
121 ) -> bool {
122 match index {
123 2 | 3 => {
125 let properties = op.input.client.properties();
126 properties.feature_enabled(cubecl::Feature::Plane)
127 && properties
128 .hardware_properties()
129 .defined_plane_size()
130 .is_some()
131 }
132 _ => true,
133 }
134 }
135
136 fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
137 input: JitTensor<Run>,
138 output: JitTensor<Run>,
139 axis: usize,
140 ) -> Result<(), String> {
141 cubecl::reduce::reduce::<Run, In, Out, Rd>(
142 &input.client,
143 input.as_handle_ref(),
144 output.as_handle_ref(),
145 axis,
146 Some(cubecl::reduce::ReduceStrategy {
147 shared: false,
148 use_planes: false,
149 }),
150 )
151 .map_err(|e| format!("{e}"))
152 }
153
154 fn reduce_shared<
155 Run: JitRuntime,
156 In: JitElement,
157 Out: JitElement,
158 Rd: cubecl::reduce::Reduce,
159 >(
160 input: JitTensor<Run>,
161 output: JitTensor<Run>,
162 axis: usize,
163 ) -> Result<(), String> {
164 cubecl::reduce::reduce::<Run, In, Out, Rd>(
165 &input.client,
166 input.as_handle_ref(),
167 output.as_handle_ref(),
168 axis,
169 Some(cubecl::reduce::ReduceStrategy {
170 shared: true,
171 use_planes: false,
172 }),
173 )
174 .map_err(|e| format!("{e}"))
175 }
176
177 fn reduce_plane<
178 Run: JitRuntime,
179 In: JitElement,
180 Out: JitElement,
181 Rd: cubecl::reduce::Reduce,
182 >(
183 input: JitTensor<Run>,
184 output: JitTensor<Run>,
185 axis: usize,
186 ) -> Result<(), String> {
187 cubecl::reduce::reduce::<Run, In, Out, Rd>(
188 &input.client,
189 input.as_handle_ref(),
190 output.as_handle_ref(),
191 axis,
192 Some(cubecl::reduce::ReduceStrategy {
193 shared: false,
194 use_planes: true,
195 }),
196 )
197 .map_err(|e| format!("{e}"))
198 }
199
200 fn reduce_shared_plane<
201 Run: JitRuntime,
202 In: JitElement,
203 Out: JitElement,
204 Rd: cubecl::reduce::Reduce,
205 >(
206 input: JitTensor<Run>,
207 output: JitTensor<Run>,
208 axis: usize,
209 ) -> Result<(), String> {
210 cubecl::reduce::reduce::<Run, In, Out, Rd>(
211 &input.client,
212 input.as_handle_ref(),
213 output.as_handle_ref(),
214 axis,
215 Some(cubecl::reduce::ReduceStrategy {
216 shared: true,
217 use_planes: true,
218 }),
219 )
220 .map_err(|e| format!("{e}"))
221 }
222}