1use cubecl_core::{
2 channel::ComputeChannel, prelude::*, server::ComputeServer, tensor_line_size_parallel,
3 tensor_line_size_perpendicular,
4};
5use cubecl_std::tensor::is_contiguous;
6
7use crate::ReduceStrategy;
8
9const DEFAULT_CUBE_DIM: CubeDim = CubeDim::new_2d(32, 8);
11const DEFAULT_PLANE_COUNT: u32 = 8;
12
/// Orientation of the lines relative to the reduced axis.
///
/// `ReduceConfig::generate_line_mode` selects the variant from the input stride
/// of the reduced axis.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum LineMode {
    /// Selected when the input stride of the reduced axis is exactly `1`.
    Parallel,
    /// Selected for any other input stride of the reduced axis.
    Perpendicular,
}
18
/// Bound-check flavour stored in `ReduceConfig::bound_checks_inner`.
///
/// NOTE(review): the variant descriptions below are inferred from the variant names
/// only; confirm them against the kernels that consume this setting.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BoundChecksInner {
    /// Presumably: no inner bound check is performed.
    None,
    /// Presumably: out-of-bound values are masked out. This is the value a freshly
    /// created `ReduceConfig` starts with.
    Mask,
    /// Presumably: out-of-bound accesses are skipped with a branch.
    Branch,
}
33
34#[derive(Debug, Clone)]
35pub struct ReduceConfig {
36 pub cube_count: CubeCount,
37 pub cube_dim: CubeDim,
38 pub line_mode: LineMode,
39 pub line_size_input: u32,
40 pub line_size_output: u32,
41 pub bound_checks: bool,
42 pub bound_checks_inner: BoundChecksInner,
43}
44
45impl ReduceConfig {
46 pub(crate) fn generate<R: Runtime, In: CubePrimitive>(
47 client: &ComputeClient<R::Server, R::Channel>,
48 input: &TensorHandleRef<R>,
49 output: &TensorHandleRef<R>,
50 axis: usize,
51 strategy: &ReduceStrategy,
52 ) -> ReduceConfig {
53 let reduce_count = output.size() as u32;
54 ReduceConfig::new()
55 .generate_line_mode(input, axis)
56 .generate_line_size::<R, In>(input, output, axis)
57 .generate_cube_dim(client, strategy.use_planes)
58 .generate_cube_count::<R>(reduce_count, strategy)
59 }
60
61 fn new() -> Self {
62 Self {
64 cube_count: CubeCount::new_single(),
65 cube_dim: CubeDim::new_single(),
66 line_mode: LineMode::Parallel,
67 line_size_input: 1,
68 line_size_output: 1,
69 bound_checks: true,
70 bound_checks_inner: BoundChecksInner::Mask,
71 }
72 }
73
74 fn generate_line_mode<R: Runtime>(mut self, input: &TensorHandleRef<R>, axis: usize) -> Self {
75 let stride = input.strides[axis];
76 self.line_mode = if stride == 1 {
77 LineMode::Parallel
78 } else {
79 LineMode::Perpendicular
80 };
81 self
82 }
83
84 fn generate_line_size<R: Runtime, In: CubePrimitive>(
85 mut self,
86 input: &TensorHandleRef<R>,
87 output: &TensorHandleRef<R>,
88 axis: usize,
89 ) -> Self {
90 let elem = In::as_elem_native_unchecked();
91 let supported_line_sizes = R::line_size_elem(&elem);
92 self.line_size_input = match self.line_mode {
93 LineMode::Parallel => {
94 tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis)
95 as u32
96 }
97 LineMode::Perpendicular => {
98 let mut input_axis_and_strides =
126 input.strides.iter().enumerate().collect::<Vec<_>>();
127 input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
128 let input_sorted_axis = input_axis_and_strides
129 .into_iter()
130 .map(|(a, _)| a)
131 .take_while(|a| *a != axis);
132
133 let mut output_axis_and_strides =
134 output.strides.iter().enumerate().collect::<Vec<_>>();
135 output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
136 let output_sorted_axis = output_axis_and_strides
137 .into_iter()
138 .filter_map(|(a, _)| (a != axis).then_some(a));
139
140 let max_line_size = input_sorted_axis
141 .zip(output_sorted_axis)
142 .filter_map(|(i, o)| (i == o).then_some(output.shape[i]))
143 .product();
144
145 tensor_line_size_perpendicular(
146 supported_line_sizes.filter(|size| {
147 *size as usize <= max_line_size && max_line_size % *size as usize == 0
148 }),
149 input.shape,
150 input.strides,
151 axis,
152 ) as u32
153 }
154 };
155
156 if self.line_size_input > 1 && self.line_mode == LineMode::Perpendicular {
157 let rank = output.strides.len();
159 let is_contiguous =
160 is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
161 && output.strides[rank - 1] == 1;
162 let shape = output.shape.get(axis + 1).cloned().unwrap_or(1) as u32;
163
164 if is_contiguous && shape % self.line_size_input == 0 {
165 self.line_size_output = self.line_size_input;
166 }
167 }
168 self
169 }
170
171 pub fn generate_cube_dim<S: ComputeServer, C: ComputeChannel<S>>(
172 mut self,
173 client: &ComputeClient<S, C>,
174 use_planes: bool,
175 ) -> Self {
176 self.cube_dim = if use_planes {
177 let plane_dim = client.properties().hardware.plane_size_min;
178 CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
179 } else {
180 DEFAULT_CUBE_DIM
181 };
182 self
183 }
184
185 pub fn generate_cube_count<R: Runtime>(
186 mut self,
187 reduce_count: u32,
188 strategy: &ReduceStrategy,
189 ) -> Self {
190 let agent_count_per_cube = match strategy {
192 ReduceStrategy { shared: true, .. } => 1,
193 ReduceStrategy { use_planes: true, .. } => self.cube_dim.y,
194 ReduceStrategy { use_planes: false, .. } => self.cube_dim.num_elems(),
195 };
196 let reduce_count_per_cube = match self.line_mode {
197 LineMode::Parallel => agent_count_per_cube,
198 LineMode::Perpendicular => agent_count_per_cube * self.line_size_input,
199 };
200
201 let cube_count = reduce_count.div_ceil(reduce_count_per_cube);
202
203 self.do_bound_checks_if(reduce_count_per_cube * cube_count > reduce_count);
204
205 let (max_x, max_y, _) = R::max_cube_count();
207 let mut cube_count_x = cube_count;
208 let mut cube_count_y = 1;
209 let mut cube_count_z = 1;
210 while cube_count_x > max_x {
211 cube_count_x /= 2;
212 cube_count_y *= 2;
213 }
214 while cube_count_y > max_y {
215 cube_count_y /= 2;
216 cube_count_z *= 2;
217 }
218 self.cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
219 self.do_bound_checks_if(cube_count_x * cube_count_y != cube_count);
220
221 self
222 }
223
224 fn do_bound_checks_if(&mut self, condition: bool) {
225 self.bound_checks = self.bound_checks || condition;
226 }
227}