1use cubecl_core::{
2 prelude::*, server::ComputeServer, tensor_line_size_parallel, tensor_line_size_perpendicular,
3};
4use cubecl_std::tensor::is_contiguous;
5
6use crate::ReduceStrategy;
7
/// Number of planes placed in a cube by `ReduceConfig::generate_cube_dim`, unless the
/// hardware's `max_units_per_cube` forces fewer.
const DEFAULT_PLANE_COUNT: u32 = 8;
10
/// Orientation of the input lines relative to the reduced axis.
///
/// Selected by `ReduceConfig::generate_line_mode` from the input stride of the reduced axis.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
pub enum LineMode {
    /// Used when the reduced axis has stride 1 in the input; the input line size is then
    /// computed with `tensor_line_size_parallel` along that axis.
    Parallel,
    /// Used when the reduced axis does not have stride 1; each agent then covers
    /// `line_size_input` reductions at once (see `ReduceConfig::generate_cube_count`).
    Perpendicular,
}
16
/// Flavor of bound checking applied inside the reduction.
///
/// `ReduceConfig::new` defaults to `Mask`. NOTE(review): the exact behavior of each variant
/// is implemented by the reduce kernels, not in this file — confirm there.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
pub enum BoundChecksInner {
    /// No inner bound check (presumably when the work is known to be in range).
    None,
    /// Presumably guards out-of-range accesses by masking values.
    Mask,
    /// Presumably guards out-of-range accesses with an explicit branch.
    Branch,
}
31
32#[derive(Debug, Clone)]
33pub struct ReduceConfig {
34 pub cube_count: CubeCount,
35 pub cube_dim: CubeDim,
36 pub line_mode: LineMode,
37 pub line_size_input: u32,
38 pub line_size_output: u32,
39 pub bound_checks: bool,
40 pub bound_checks_inner: BoundChecksInner,
41}
42
43impl ReduceConfig {
44 pub(crate) fn generate<R: Runtime>(
45 client: &ComputeClient<R::Server>,
46 input: &TensorHandleRef<R>,
47 output: &TensorHandleRef<R>,
48 axis: usize,
49 strategy: &ReduceStrategy,
50 dtype: StorageType,
51 ) -> ReduceConfig {
52 let reduce_count = output.size() as u32;
53 ReduceConfig::new()
54 .generate_line_mode(input, axis)
55 .generate_line_size::<R>(input, output, axis, dtype)
56 .generate_cube_dim(client, strategy.use_planes)
57 .generate_cube_count::<R>(reduce_count, strategy)
58 }
59
60 fn new() -> Self {
61 Self {
63 cube_count: CubeCount::new_single(),
64 cube_dim: CubeDim::new_single(),
65 line_mode: LineMode::Parallel,
66 line_size_input: 1,
67 line_size_output: 1,
68 bound_checks: true,
69 bound_checks_inner: BoundChecksInner::Mask,
70 }
71 }
72
73 fn generate_line_mode<R: Runtime>(mut self, input: &TensorHandleRef<R>, axis: usize) -> Self {
74 let stride = input.strides[axis];
75 self.line_mode = if stride == 1 {
76 LineMode::Parallel
77 } else {
78 LineMode::Perpendicular
79 };
80 self
81 }
82
83 fn generate_line_size<R: Runtime>(
84 mut self,
85 input: &TensorHandleRef<R>,
86 output: &TensorHandleRef<R>,
87 axis: usize,
88 dtype: StorageType,
89 ) -> Self {
90 let supported_line_sizes = R::io_optimized_line_sizes_unchecked(dtype.size());
91 self.line_size_input = match self.line_mode {
92 LineMode::Parallel => {
93 tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis)
94 as u32
95 }
96 LineMode::Perpendicular => {
97 let mut input_axis_and_strides =
125 input.strides.iter().enumerate().collect::<Vec<_>>();
126 input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
127 let input_sorted_axis = input_axis_and_strides
128 .into_iter()
129 .map(|(a, _)| a)
130 .take_while(|a| *a != axis);
131
132 let mut output_axis_and_strides =
133 output.strides.iter().enumerate().collect::<Vec<_>>();
134 output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
135 let output_sorted_axis = output_axis_and_strides
136 .into_iter()
137 .filter_map(|(a, _)| (a != axis).then_some(a));
138
139 let max_line_size = input_sorted_axis
140 .zip(output_sorted_axis)
141 .filter_map(|(i, o)| (i == o).then_some(output.shape[i]))
142 .product();
143
144 tensor_line_size_perpendicular(
145 supported_line_sizes.filter(|size| {
146 *size as usize <= max_line_size && max_line_size % *size as usize == 0
147 }),
148 input.shape,
149 input.strides,
150 axis,
151 ) as u32
152 }
153 };
154
155 if self.line_size_input > 1 && self.line_mode == LineMode::Perpendicular {
156 let rank = output.strides.len();
158 let is_contiguous =
159 is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
160 && output.strides[rank - 1] == 1;
161 let shape = output.shape.get(axis + 1).cloned().unwrap_or(1) as u32;
162
163 if is_contiguous && shape.is_multiple_of(self.line_size_input) {
164 self.line_size_output = self.line_size_input;
165 }
166 }
167 self
168 }
169
170 pub fn generate_cube_dim<S: ComputeServer>(
171 mut self,
172 client: &ComputeClient<S>,
173 use_planes: bool,
174 ) -> Self {
175 let hw_properties = &client.properties().hardware;
176
177 let plane_dim = if use_planes {
178 hw_properties.plane_size_min
179 } else {
180 hw_properties.plane_size_max
181 };
182
183 let plane_count = if plane_dim * DEFAULT_PLANE_COUNT > hw_properties.max_units_per_cube {
184 hw_properties.max_units_per_cube / plane_dim
185 } else {
186 DEFAULT_PLANE_COUNT
187 };
188
189 self.cube_dim = CubeDim::new_2d(plane_dim, plane_count);
190 self
191 }
192
193 pub fn generate_cube_count<R: Runtime>(
194 mut self,
195 reduce_count: u32,
196 strategy: &ReduceStrategy,
197 ) -> Self {
198 let agent_count_per_cube = match strategy {
200 ReduceStrategy { shared: true, .. } => 1,
201 ReduceStrategy { use_planes: true, .. } => self.cube_dim.y,
202 ReduceStrategy { use_planes: false, .. } => self.cube_dim.num_elems(),
203 };
204 let reduce_count_per_cube = match self.line_mode {
205 LineMode::Parallel => agent_count_per_cube,
206 LineMode::Perpendicular => agent_count_per_cube * self.line_size_input,
207 };
208
209 let cube_count = reduce_count.div_ceil(reduce_count_per_cube);
210
211 self.do_bound_checks_if(reduce_count_per_cube * cube_count > reduce_count);
212
213 let (max_x, max_y, _) = R::max_cube_count();
215 let mut cube_count_x = cube_count;
216 let mut cube_count_y = 1;
217 let mut cube_count_z = 1;
218 while cube_count_x > max_x {
219 cube_count_x /= 2;
220 cube_count_y *= 2;
221 }
222 while cube_count_y > max_y {
223 cube_count_y /= 2;
224 cube_count_z *= 2;
225 }
226 self.cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
227 self.do_bound_checks_if(cube_count_x * cube_count_y != cube_count);
228
229 self
230 }
231
232 fn do_bound_checks_if(&mut self, condition: bool) {
233 self.bound_checks = self.bound_checks || condition;
234 }
235}