burn_cubecl/kernel/reduce/base.rs
#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
    CubeRuntime,
    ops::numeric::{empty_device_dtype, zeros_client},
    tensor::CubeTensor,
};
use burn_backend::{DType, Shape};
use cubecl::{AutotuneKey, client::ComputeClient, features::TypeUsage, ir::StorageType};
use cubek::reduce::{
    ReduceDtypes, ReduceError, ReduceStrategy,
    components::instructions::ReduceOperationConfig,
    launch::{LineSizeStrategy, RoutineStrategy},
    routines::{BlueprintStrategy, unit::UnitStrategy},
    shared_sum,
};
use serde::{Deserialize, Serialize};

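/// Autotune key for sum kernels: inputs are grouped by element type and
/// (anchored) length so that similar workloads reuse the same tuned strategy.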
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct SumAutotuneKey {
    /// Element type of the tensor being summed.
    pub dtype: burn_backend::DType,
    /// Total number of elements, anchored so that nearby sizes share a key.
    #[autotune(anchor)]
    pub length: usize,
}

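/// Returns `true` when the runtime supports atomic additions on `dtype`,
/// which the one-shot shared sum relies on.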
fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
    client
        .properties()
        .type_usage(StorageType::Atomic(dtype.into()))
        .contains(TypeUsage::AtomicAdd)
}

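/// Sum all elements of `tensor`, downgrading a [`SumStrategy::OneShot`] request
/// to [`SumStrategy::Chained`] when the runtime lacks atomic-add support.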
pub fn sum_fallback<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    mut strategy: SumStrategy,
) -> Result<CubeTensor<R>, ReduceError> {
    if matches!(strategy, SumStrategy::OneShot(_))
        && !supports_atomic_add(&tensor.client, tensor.dtype)
    {
        strategy = SumStrategy::Chained(Default::default());
    }
    sum(tensor, strategy)
}

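/// Sum all elements of `tensor` into a single-element tensor using the given
/// [`SumStrategy`].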
pub fn sum<Run: CubeRuntime>(
    tensor: CubeTensor<Run>,
    strategy: SumStrategy,
) -> Result<CubeTensor<Run>, ReduceError> {
    let client = tensor.client.clone();
    let device = tensor.device.clone();

    match strategy {
        SumStrategy::OneShot(cube_count) => {
            let output = zeros_client(client.clone(), device, [1].into(), tensor.dtype);
            shared_sum::<Run>(
                &client,
                tensor.as_handle_ref(),
                output.as_handle_ref(),
                cube_count,
                tensor.dtype.into(),
            )?;

            Ok(output)
        }
        SumStrategy::Chained(strategy) => {
            reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
        }
        #[cfg(feature = "autotune")]
        SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
    }
}

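/// Strategy used when summing a whole tensor.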
pub enum SumStrategy {
    /// Sum all elements in a single kernel launch using atomic additions,
    /// with the given number of cubes.
    OneShot(u32),
    /// Sum by chaining axis-wise reductions until one element remains.
    Chained(KernelReduceStrategy),
    /// Pick the fastest strategy by benchmarking at runtime.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for SumStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        return Self::OneShot(4);
    }
}

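/// Reduce all elements of `tensor` down to a single-element tensor by applying
/// the reduction described by `config` over every axis, one at a time.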
pub fn reduce<Run: CubeRuntime>(
    mut tensor: CubeTensor<Run>,
    output_dtype: Option<DType>,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
    // Reduce the axes in order, from the smallest to the largest.
    let sorted_axis = argsort(&tensor.shape);
    for axis in sorted_axis {
        tensor = reduce_dim::<Run>(tensor, output_dtype, axis, strategy.clone(), config)?;
    }
    // The result now holds a single element; reshape it into a rank-one tensor.
    tensor.shape = Shape::new([1]);
    tensor.strides = vec![1];
    Ok(tensor)
}

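/// Return the axis indices sorted by ascending axis size.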
fn argsort(shape: &[usize]) -> Vec<usize> {
    let mut indices = (0..shape.len()).collect::<Vec<_>>();
    indices.sort_by_key(|&i| &shape[i]);
    indices
}

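/// Reduce `input` along `dim` with the reduction described by `config`,
/// returning a tensor whose shape matches the input except that `dim` has size 1.
///
/// `output_dtype` must be provided when `config` is `ArgMax` or `ArgMin`.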
pub fn reduce_dim<Run: CubeRuntime>(
    input: CubeTensor<Run>,
    output_dtype: Option<DType>,
    dim: usize,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
    debug_assert!(
        !matches!(
            config,
            ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin
        ) || output_dtype.is_some(),
        "The `output_dtype` has to be `Some` when the `config` is `ArgMax` or `ArgMin`."
    );

    let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
    let client = input.client.clone();
    let output = init_reduce_output::<Run>(&input, dim, &dtypes).ok_or(
        cubek::reduce::ReduceError::InvalidAxis {
            axis: dim,
            rank: input.shape.num_dims(),
        },
    )?;

    let result = match strategy {
        KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
            &client,
            input.as_handle_ref(),
            output.as_handle_ref(),
            dim,
            ReduceStrategy {
                routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
                line_size: LineSizeStrategy {
                    parallel_output_vectorization: false,
                },
            },
            config,
            dtypes,
        ),
        KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
            &client,
            input.as_handle_ref(),
            output.as_handle_ref(),
            dim,
            strategy,
            config,
            dtypes,
        ),
        #[cfg(feature = "autotune")]
        KernelReduceStrategy::Autotune => {
            autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
            Ok(())
        }
    };
    result.map(|_| output)
}

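/// Allocate the output tensor for a reduction of `input` along `dim`, or `None`
/// if `dim` is out of bounds. The output shape matches the input except that
/// `dim` has size 1, and its element type comes from `dtypes.output`.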
pub fn init_reduce_output<Run: CubeRuntime>(
    input: &CubeTensor<Run>,
    dim: usize,
    dtypes: &ReduceDtypes,
) -> Option<CubeTensor<Run>> {
    (dim < input.shape.num_dims()).then(|| {
        let mut shape_out = input.shape.clone();
        shape_out.dims[dim] = 1;
        empty_device_dtype::<Run>(
            input.client.clone(),
            input.device.clone(),
            shape_out,
            dtypes.output.elem_type().into(),
        )
    })
}

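/// Strategy used when launching a single axis-wise reduction kernel.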
#[derive(Clone, Debug)]
pub enum KernelReduceStrategy {
    /// Use a default, non-tuned strategy.
    Unspecified,
    /// Use the provided [`ReduceStrategy`](cubek::reduce::launch::ReduceStrategy).
    Specific(cubek::reduce::launch::ReduceStrategy),
    /// Pick the fastest strategy by benchmarking at runtime.
    #[cfg(feature = "autotune")]
    Autotune,
}

impl Default for KernelReduceStrategy {
    fn default() -> Self {
        #[cfg(feature = "autotune")]
        return Self::Autotune;

        #[cfg(not(feature = "autotune"))]
        return Self::Unspecified;
    }
}