1use cubecl_core::{
2 prelude::*, server::ComputeServer, tensor_line_size_parallel, tensor_line_size_perpendicular,
3};
4use cubecl_std::tensor::is_contiguous;
5
6use crate::ReduceStrategy;
7
/// Default number of planes per cube.
///
/// Used as the second (`y`) extent of the cube dimension built by
/// `ReduceConfig::generate_cube_dim`; the first extent is the hardware plane size.
const DEFAULT_PLANE_COUNT: u32 = 8;
10
/// Orientation of the input lines with respect to the reduced axis.
///
/// `ReduceConfig` picks [`LineMode::Parallel`] when the reduced axis has stride 1 in
/// the input tensor, and [`LineMode::Perpendicular`] otherwise.
#[derive(Clone, Copy, Debug)]
#[derive(Eq, Hash, PartialEq)]
pub enum LineMode {
    /// Lines are taken along the reduced axis; each agent handles one reduction.
    Parallel,
    /// Lines are taken across the reduced axis, so a single line covers as many
    /// reductions as its line size.
    Perpendicular,
}
16
/// Kind of bound check requested for the inner part of the reduce kernel.
///
/// NOTE(review): these variants are interpreted by kernel code that is not visible
/// here; the per-variant descriptions follow the names only and should be confirmed.
/// `ReduceConfig` defaults to [`BoundChecksInner::Mask`].
#[derive(Clone, Copy, Debug)]
#[derive(Eq, Hash, PartialEq)]
pub enum BoundChecksInner {
    /// Presumably: no inner bound check.
    None,
    /// Presumably: out-of-range elements are masked rather than skipped.
    Mask,
    /// Presumably: out-of-range elements are skipped with a conditional branch.
    Branch,
}
31
32#[derive(Debug, Clone)]
33pub struct ReduceConfig {
34 pub cube_count: CubeCount,
35 pub cube_dim: CubeDim,
36 pub line_mode: LineMode,
37 pub line_size_input: u32,
38 pub line_size_output: u32,
39 pub bound_checks: bool,
40 pub bound_checks_inner: BoundChecksInner,
41}
42
43impl ReduceConfig {
44 pub(crate) fn generate<R: Runtime, In: CubePrimitive>(
45 client: &ComputeClient<R::Server>,
46 input: &TensorHandleRef<R>,
47 output: &TensorHandleRef<R>,
48 axis: usize,
49 strategy: &ReduceStrategy,
50 ) -> ReduceConfig {
51 let reduce_count = output.size() as u32;
52 ReduceConfig::new()
53 .generate_line_mode(input, axis)
54 .generate_line_size::<R, In>(input, output, axis)
55 .generate_cube_dim(client, strategy.use_planes)
56 .generate_cube_count::<R>(reduce_count, strategy)
57 }
58
59 fn new() -> Self {
60 Self {
62 cube_count: CubeCount::new_single(),
63 cube_dim: CubeDim::new_single(),
64 line_mode: LineMode::Parallel,
65 line_size_input: 1,
66 line_size_output: 1,
67 bound_checks: true,
68 bound_checks_inner: BoundChecksInner::Mask,
69 }
70 }
71
72 fn generate_line_mode<R: Runtime>(mut self, input: &TensorHandleRef<R>, axis: usize) -> Self {
73 let stride = input.strides[axis];
74 self.line_mode = if stride == 1 {
75 LineMode::Parallel
76 } else {
77 LineMode::Perpendicular
78 };
79 self
80 }
81
82 fn generate_line_size<R: Runtime, In: CubePrimitive>(
83 mut self,
84 input: &TensorHandleRef<R>,
85 output: &TensorHandleRef<R>,
86 axis: usize,
87 ) -> Self {
88 let supported_line_sizes = R::io_optimized_line_sizes_unchecked(size_of::<In>());
89 self.line_size_input = match self.line_mode {
90 LineMode::Parallel => {
91 tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis)
92 as u32
93 }
94 LineMode::Perpendicular => {
95 let mut input_axis_and_strides =
123 input.strides.iter().enumerate().collect::<Vec<_>>();
124 input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
125 let input_sorted_axis = input_axis_and_strides
126 .into_iter()
127 .map(|(a, _)| a)
128 .take_while(|a| *a != axis);
129
130 let mut output_axis_and_strides =
131 output.strides.iter().enumerate().collect::<Vec<_>>();
132 output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
133 let output_sorted_axis = output_axis_and_strides
134 .into_iter()
135 .filter_map(|(a, _)| (a != axis).then_some(a));
136
137 let max_line_size = input_sorted_axis
138 .zip(output_sorted_axis)
139 .filter_map(|(i, o)| (i == o).then_some(output.shape[i]))
140 .product();
141
142 tensor_line_size_perpendicular(
143 supported_line_sizes.filter(|size| {
144 *size as usize <= max_line_size && max_line_size % *size as usize == 0
145 }),
146 input.shape,
147 input.strides,
148 axis,
149 ) as u32
150 }
151 };
152
153 if self.line_size_input > 1 && self.line_mode == LineMode::Perpendicular {
154 let rank = output.strides.len();
156 let is_contiguous =
157 is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
158 && output.strides[rank - 1] == 1;
159 let shape = output.shape.get(axis + 1).cloned().unwrap_or(1) as u32;
160
161 if is_contiguous && shape.is_multiple_of(self.line_size_input) {
162 self.line_size_output = self.line_size_input;
163 }
164 }
165 self
166 }
167
168 pub fn generate_cube_dim<S: ComputeServer>(
169 mut self,
170 client: &ComputeClient<S>,
171 use_planes: bool,
172 ) -> Self {
173 self.cube_dim = if use_planes {
174 let plane_dim = client.properties().hardware.plane_size_min;
175 CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
176 } else {
177 let plane_dim = client.properties().hardware.plane_size_max;
178 CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
179 };
180 self
181 }
182
183 pub fn generate_cube_count<R: Runtime>(
184 mut self,
185 reduce_count: u32,
186 strategy: &ReduceStrategy,
187 ) -> Self {
188 let agent_count_per_cube = match strategy {
190 ReduceStrategy { shared: true, .. } => 1,
191 ReduceStrategy { use_planes: true, .. } => self.cube_dim.y,
192 ReduceStrategy { use_planes: false, .. } => self.cube_dim.num_elems(),
193 };
194 let reduce_count_per_cube = match self.line_mode {
195 LineMode::Parallel => agent_count_per_cube,
196 LineMode::Perpendicular => agent_count_per_cube * self.line_size_input,
197 };
198
199 let cube_count = reduce_count.div_ceil(reduce_count_per_cube);
200
201 self.do_bound_checks_if(reduce_count_per_cube * cube_count > reduce_count);
202
203 let (max_x, max_y, _) = R::max_cube_count();
205 let mut cube_count_x = cube_count;
206 let mut cube_count_y = 1;
207 let mut cube_count_z = 1;
208 while cube_count_x > max_x {
209 cube_count_x /= 2;
210 cube_count_y *= 2;
211 }
212 while cube_count_y > max_y {
213 cube_count_y /= 2;
214 cube_count_z *= 2;
215 }
216 self.cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
217 self.do_bound_checks_if(cube_count_x * cube_count_y != cube_count);
218
219 self
220 }
221
222 fn do_bound_checks_if(&mut self, condition: bool) {
223 self.bound_checks = self.bound_checks || condition;
224 }
225}