burn_cubecl/kernel/reduce/base.rs

#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
    CubeRuntime,
    ops::numeric::{empty_device_contiguous_dtype, zeros_client},
    tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use burn_std::Metadata;
use cubecl::{
    AutotuneKey,
    client::ComputeClient,
    features::AtomicUsage,
    ir::{StorageType, Type},
};
use cubek::reduce::{
    ReduceDtypes, ReduceError, ReduceStrategy,
    components::instructions::ReduceOperationConfig,
    launch::{RoutineStrategy, VectorizationStrategy},
    routines::{BlueprintStrategy, unit::UnitStrategy},
    shared_sum,
};
use serde::{Deserialize, Serialize};

/// Autotune key for the sum kernels.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct SumAutotuneKey {
    /// Element type of the tensor being summed.
    dtype: DType,
    /// Total number of elements in the input, anchored for autotuning.
    #[autotune(anchor)]
    length: usize,
}

/// Returns `true` when the runtime supports atomic add for the given `dtype`.
fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
    client
        .properties()
        .atomic_type_usage(Type::new(StorageType::Atomic(dtype.into())))
        .contains(AtomicUsage::Add)
}

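/// Sums all elements of `tensor` like [`sum`], but downgrades a
/// [`SumStrategy::OneShot`] request to [`SumStrategy::Chained`] when the
/// runtime has no atomic-add support for the tensor's dtype.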
pub fn sum_fallback<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    mut strategy: SumStrategy,
) -> Result<CubeTensor<R>, ReduceError> {
    if matches!(strategy, SumStrategy::OneShot(_))
        && !supports_atomic_add(&tensor.client, tensor.dtype)
    {
        strategy = SumStrategy::Chained(Default::default());
    }
    sum(tensor, strategy)
}

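/// Sums all elements of `tensor` into a single-element tensor using the
/// given [`SumStrategy`].
///
/// A minimal usage sketch; the tensor and runtime setup are assumed to
/// happen elsewhere:
///
/// ```ignore
/// // `tensor: CubeTensor<Run>` produced by any earlier operation.
/// let total = sum::<Run>(tensor, SumStrategy::default())?;
/// // `total` holds a single element: the sum of all inputs.
/// ```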
pub fn sum<Run: CubeRuntime>(
    tensor: CubeTensor<Run>,
    strategy: SumStrategy,
) -> Result<CubeTensor<Run>, ReduceError> {
    let client = tensor.client.clone();
    let device = tensor.device.clone();

    match strategy {
        SumStrategy::OneShot(cube_count) => {
            let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);
            let dtype = tensor.dtype;

            shared_sum::<Run>(
                &client,
                tensor.binding(),
                output.clone().binding(),
                cube_count,
                dtype.into(),
            )?;

            Ok(output)
        }
        SumStrategy::Chained(strategy) => {
            reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
        }
        #[cfg(feature = "autotune")]
        SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
    }
}

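/// Strategy used to sum all elements of a tensor.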
pub enum SumStrategy {
    /// Sum all elements in a single launch that accumulates with atomic
    /// adds, using the given cube count.
    OneShot(u32),
    /// Sum by chaining one reduction per axis.
    Chained(KernelReduceStrategy),
    /// Pick the fastest strategy by benchmarking at runtime.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for SumStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        return Self::OneShot(4);
    }
}

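/// Reduces `tensor` over all of its axes, one axis at a time, producing a
/// single-element tensor.
///
/// A minimal usage sketch; the tensor setup is assumed to happen elsewhere.
/// This mirrors what [`sum`] does for [`SumStrategy::Chained`]:
///
/// ```ignore
/// let total = reduce::<Run>(
///     tensor,
///     None,
///     KernelReduceStrategy::default(),
///     ReduceOperationConfig::Sum,
/// )?;
/// ```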
pub fn reduce<Run: CubeRuntime>(
    mut tensor: CubeTensor<Run>,
    output_dtype: Option<DType>,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, ReduceError> {
    // Reduce the axes in ascending order of size.
    let sorted_axis = argsort(tensor.meta.shape());
    for axis in sorted_axis {
        tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;
    }
    // Every axis now has length 1; collapse the metadata to a single element.
    *tensor.meta = Metadata::new([1], [1]);
    Ok(tensor)
}

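/// Returns the axis indices of `shape`, sorted by dimension size, smallest first.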
fn argsort(shape: &[usize]) -> Vec<usize> {
    let mut indices = (0..shape.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| shape[i]);
    indices
}

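/// Reduces `input` along the single axis `dim`. The reduced axis is kept in
/// the output with length 1, or `k` for the top-k operations.
///
/// `output_dtype` must be `Some` for `ArgMax`, `ArgMin` and `ArgTopK`, which
/// produce indices rather than values.
///
/// A minimal usage sketch; the tensor setup is assumed to happen elsewhere
/// and the axis is chosen for illustration:
///
/// ```ignore
/// let out = reduce_dim::<Run>(
///     input,
///     None,
///     1,
///     KernelReduceStrategy::default(),
///     ReduceOperationConfig::Sum,
/// )?;
/// ```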
pub fn reduce_dim<Run: CubeRuntime>(
    input: CubeTensor<Run>,
    output_dtype: Option<DType>,
    dim: usize,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, ReduceError> {
    debug_assert!(
        !matches!(
            config,
            ReduceOperationConfig::ArgMax
                | ReduceOperationConfig::ArgMin
                | ReduceOperationConfig::ArgTopK(_)
        ) || output_dtype.is_some(),
        "The `output_dtype` must be `Some` when the `config` is `ArgMax`, `ArgMin` or `ArgTopK`."
    );

    // Top-k reductions keep `k` elements along the reduced axis; all other
    // reductions collapse it to a single element.
    let accumulator_len = match config {
        ReduceOperationConfig::ArgTopK(k) => k,
        ReduceOperationConfig::TopK(k) => k,
        _ => 1,
    };
    let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
    let client = input.client.clone();
    let output = init_reduce_output::<Run>(&input, dim, &dtypes, accumulator_len).ok_or(
        ReduceError::InvalidAxis {
            axis: dim,
            rank: input.meta.num_dims(),
        },
    )?;

    let result = match strategy {
        KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
            &client,
            input.binding(),
            output.clone().binding(),
            dim,
            ReduceStrategy {
                routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
                vectorization: VectorizationStrategy {
                    parallel_output_vectorization: false,
                },
            },
            config,
            dtypes,
        ),
        KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
            &client,
            input.binding(),
            output.clone().binding(),
            dim,
            strategy,
            config,
            dtypes,
        ),
        #[cfg(feature = "autotune")]
        KernelReduceStrategy::Autotune => {
            autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
            Ok(())
        }
    };
    result.map(|_| output)
}

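/// Allocates a contiguous output tensor for a reduction along `dim`: the
/// shape matches the input except that `dim` is resized to `accumulator_len`.
/// Returns `None` when `dim` is out of range for the input rank.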
pub fn init_reduce_output<Run: CubeRuntime>(
    input: &CubeTensor<Run>,
    dim: usize,
    dtypes: &ReduceDtypes,
    accumulator_len: usize,
) -> Option<CubeTensor<Run>> {
    (dim < input.meta.num_dims()).then(|| {
        let mut shape_out = input.shape();
        shape_out[dim] = accumulator_len;
        empty_device_contiguous_dtype(
            input.client.clone(),
            input.device.clone(),
            shape_out,
            dtypes.output.elem_type().into(),
        )
    })
}

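/// Strategy used to select the kernel for a single-axis reduction.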
#[derive(Clone, Debug)]
pub enum KernelReduceStrategy {
    /// Use the default routine: a unit routine with an inferred blueprint and
    /// no parallel output vectorization.
    Unspecified,
    /// Use the provided strategy as-is.
    Specific(cubek::reduce::launch::ReduceStrategy),
    /// Pick the fastest strategy by benchmarking at runtime.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for KernelReduceStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        return Self::Unspecified;
    }
}