1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::ReadWrite;
4use cubecl_std::tensor::r#virtual::VirtualTensor;
5
6use crate::BoundChecksInner;
7use crate::LineMode;
8use crate::ReduceParams;
9use crate::instructions::*;
10use crate::precision::ReducePrecision;
11
/// Iteration bounds used by the reduce kernels to walk the input buffer.
///
/// The `index_*` fields address the input as a buffer of `Line`s (element
/// offsets already divided by the input line size), while the `coordinate_*`
/// fields walk the logical coordinates along the reduced axis.
#[derive(CubeType)]
pub struct ReduceRange {
    /// First line index to read for this reduction.
    pub index_start: u32,
    /// Increment applied to the line index between two successive reads.
    pub index_step: u32,
    /// First coordinate along the reduced axis (both constructors set 0).
    pub coordinate_start: u32,
    /// One-past-last coordinate along the reduced axis (the axis shape).
    pub coordinate_end: u32,
    /// Coordinate stride between two iterations of the reduction loop.
    pub coordinate_step: u32,
}
21
#[cube]
impl ReduceRange {
    /// Builds the iteration range for the reduction identified by
    /// `reduce_index`, dispatching at compile time on the line mode.
    pub(crate) fn new<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        match comptime!(params.line_mode) {
            LineMode::Parallel => {
                Self::new_parallel::<P, Out>(reduce_index, input, output, axis_reduce, params)
            }
            LineMode::Perpendicular => {
                Self::new_perpendicular::<P, Out>(reduce_index, input, output, axis_reduce, params)
            }
        }
    }

    /// Range for the parallel line mode: lines lie along the reduced axis,
    /// so successive reads are contiguous (`index_step` is 1).
    fn new_parallel<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let shape_axis = input.shape(axis_reduce);

        // Linear element offset of the first input element of this reduction,
        // obtained by projecting the output coordinates through the input strides.
        let mut index_start = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(reduce_index, axis);
            index_start += coordinate * input.stride(axis);
        }
        // Convert the element offset into a line offset.
        index_start /= params.line_size_input;

        let coordinate_end = shape_axis;

        // Coordinates consumed per loop iteration, depending on how many units
        // cooperate on one reduction (whole cube / one plane / a single unit).
        // Each line already covers `line_size_input` coordinates in this mode.
        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM * params.line_size_input
        } else if params.use_planes {
            CUBE_DIM_X * params.line_size_input
        } else {
            params.line_size_input.runtime()
        };

        ReduceRange {
            index_start,
            index_step: 1,
            coordinate_start: 0,
            coordinate_end,
            coordinate_step,
        }
    }

    /// Range for the perpendicular line mode: each line groups
    /// `line_size_input` *distinct* reductions, so one line covers a single
    /// coordinate of the reduced axis and reads are strided.
    fn new_perpendicular<P: ReducePrecision, Out: Numeric>(
        reduce_index: u32,
        input: &VirtualTensor<P::EI>,
        output: &mut VirtualTensor<Out, ReadWrite>,
        axis_reduce: u32,
        #[comptime] params: ReduceParams,
    ) -> ReduceRange {
        let shape_axis = input.shape(axis_reduce);

        // Linear element offset of the first element; `reduce_index` is scaled
        // by the line size because one invocation handles a whole line of
        // adjacent reductions in this mode.
        let mut index_start = 0;
        for axis in 0..input.rank() {
            let coordinate = output.coordinate(reduce_index * params.line_size_input, axis);
            index_start += coordinate * input.stride(axis);
        }
        // Convert the element offset into a line offset.
        index_start /= params.line_size_input;

        // Stride (in lines) between two consecutive coordinates of the axis.
        let index_step = input.stride(axis_reduce) / params.line_size_input;

        let coordinate_end = shape_axis;

        // One coordinate per line here, so the step is just the number of
        // cooperating units (whole cube / one plane / a single unit).
        let coordinate_step = if params.shared.is_some() {
            CUBE_DIM
        } else if params.use_planes {
            CUBE_DIM_X
        } else {
            1_u32.runtime()
        };

        ReduceRange {
            index_start,
            index_step,
            coordinate_start: 0,
            coordinate_step,
            coordinate_end,
        }
    }
}
113
114#[cube]
123pub fn reduce_slice<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
124 items: &I,
125 range: ReduceRange,
126 inst: &R,
127 #[comptime] line_size: u32,
128 #[comptime] line_mode: LineMode,
129) -> R::AccumulatorItem {
130 let mut accumulator = R::null_accumulator(inst, line_size);
131
132 let mut index = range.index_start;
133 for coordinate in range_stepped(
134 range.coordinate_start,
135 range.coordinate_end,
136 range.coordinate_step,
137 ) {
138 let requirements = R::requirements(inst);
139 let coordinates = if comptime![requirements.coordinates] {
140 ReduceCoordinate::new_Required(fill_coordinate_line(coordinate, line_size, line_mode))
141 } else {
142 ReduceCoordinate::new_NotRequired()
143 };
144 reduce_inplace::<P, R>(
145 inst,
146 &mut accumulator,
147 items.read(index),
148 coordinates,
149 false,
150 );
151 index += range.index_step;
152 }
153
154 accumulator
155}
156
157#[cube]
170pub fn reduce_slice_plane<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
171 items: &I,
172 inst: &R,
173 range: ReduceRange,
174 #[comptime] line_size: u32,
175 #[comptime] line_mode: LineMode,
176 #[comptime] bound_checks: BoundChecksInner,
177) -> R::AccumulatorItem {
178 let plane_dim = CUBE_DIM_X;
179
180 let mut accumulator = R::null_accumulator(inst, line_size);
181
182 let mut first_index = range.index_start;
183 for first_coordinate in range_stepped(
184 range.coordinate_start,
185 range.coordinate_end,
186 range.coordinate_step,
187 ) {
188 let unit_coordinate_offset = match line_mode {
189 LineMode::Parallel => UNIT_POS_X * line_size,
190 LineMode::Perpendicular => UNIT_POS_X,
191 };
192 let unit_coordinate = first_coordinate + unit_coordinate_offset;
193
194 let requirements = R::requirements(inst);
195 let coordinates = if comptime![requirements.coordinates] {
196 ReduceCoordinate::new_Required(fill_coordinate_line(
197 unit_coordinate,
198 line_size,
199 line_mode,
200 ))
201 } else {
202 ReduceCoordinate::new_NotRequired()
203 };
204
205 let index = first_index + UNIT_POS_X * range.index_step;
206 let item = match bound_checks {
207 BoundChecksInner::None => items.read(index),
208 BoundChecksInner::Mask => {
209 let mask = unit_coordinate < range.coordinate_end;
210 let index = index * u32::cast_from(mask);
211 select(mask, items.read(index), R::null_input(inst, line_size))
212 }
213 BoundChecksInner::Branch => {
214 if unit_coordinate < range.coordinate_end {
215 items.read(index)
216 } else {
217 R::null_input(inst, line_size)
218 }
219 }
220 };
221
222 reduce_inplace::<P, R>(inst, &mut accumulator, item, coordinates, true);
223
224 first_index += plane_dim * range.index_step;
225 }
226 accumulator
227}
228
#[cube]
/// Reduces the lines of `items` covered by `range` cooperatively across the
/// whole cube, accumulating into shared memory.
///
/// Each unit (or each plane when `use_planes` is set) owns one slot of the
/// shared accumulator; the caller is expected to combine the slots afterwards
/// (e.g. with `reduce_tree`). `bound_checks` selects how out-of-range units
/// are neutralized (no check / branchless mask / explicit branch).
pub fn reduce_slice_shared<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
    items: &I,
    inst: &R,
    range: ReduceRange,
    #[comptime] accumulator_size: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
    #[comptime] use_planes: bool,
    #[comptime] bound_checks: BoundChecksInner,
) -> R::SharedAccumulator {
    // With planes, all units of a plane share the slot indexed by the plane
    // position; otherwise every unit has its own slot.
    let accumulator_index = if use_planes { UNIT_POS_Y } else { UNIT_POS };

    let requirements = R::requirements(inst);
    let mut accumulator =
        R::SharedAccumulator::allocate(accumulator_size, line_size, requirements.coordinates);

    // Initialize this unit's slot with the instruction's neutral value.
    R::SharedAccumulator::write(
        &mut accumulator,
        accumulator_index,
        R::null_accumulator(inst, line_size),
    );

    let mut first_index = range.index_start;
    for first_coordinate in range_stepped(
        range.coordinate_start,
        range.coordinate_end,
        range.coordinate_step,
    ) {
        // In parallel mode each unit's line spans `line_size` coordinates;
        // in perpendicular mode one line covers a single coordinate.
        let unit_coordinate_offset = match line_mode {
            LineMode::Parallel => UNIT_POS * line_size,
            LineMode::Perpendicular => UNIT_POS,
        };
        let unit_coordinate = first_coordinate + unit_coordinate_offset;

        let index = first_index + UNIT_POS * range.index_step;

        let item = match bound_checks {
            BoundChecksInner::None => items.read(index),
            BoundChecksInner::Mask => {
                let mask = unit_coordinate < range.coordinate_end;
                // Out-of-range units read index 0 (always valid) and their
                // value is replaced by the null input below.
                let index = index * u32::cast_from(mask);
                select(mask, items.read(index), R::null_input(inst, line_size))
            }
            BoundChecksInner::Branch => {
                if unit_coordinate < range.coordinate_end {
                    items.read(index)
                } else {
                    R::null_input(inst, line_size)
                }
            }
        };

        let coordinates = if comptime! {requirements.coordinates} {
            let coordinate = fill_coordinate_line(unit_coordinate, line_size, line_mode);
            // Out-of-range lanes get a u32::MAX sentinel coordinate —
            // presumably so coordinate-based instructions never pick them;
            // NOTE(review): confirm against the instruction implementations.
            let coordinate = select(
                unit_coordinate < range.coordinate_end,
                coordinate,
                Line::empty(line_size).fill(u32::MAX),
            );

            ReduceCoordinate::new_Required(coordinate)
        } else {
            ReduceCoordinate::new_NotRequired()
        };

        reduce_shared_inplace::<P, R>(
            inst,
            &mut accumulator,
            accumulator_index,
            item,
            coordinates,
            use_planes,
        );
        // The whole cube advances together: CUBE_DIM lines per iteration.
        first_index += range.index_step * CUBE_DIM;
    }
    accumulator
}
320
#[cube]
/// Builds the line of coordinates associated with one input line.
///
/// - Parallel: the line holds consecutive coordinates
///   `first, first + 1, ..., first + line_size - 1` (one per element).
/// - Perpendicular: every element of the line belongs to a different
///   reduction at the same axis coordinate, so the line is filled with `first`.
fn fill_coordinate_line(
    first: u32,
    #[comptime] line_size: u32,
    #[comptime] line_mode: LineMode,
) -> Line<u32> {
    match comptime!(line_mode) {
        LineMode::Parallel => {
            let mut coordinates = Line::empty(line_size);
            #[unroll]
            for j in 0..line_size {
                coordinates[j] = first + j;
            }
            coordinates
        }
        LineMode::Perpendicular => Line::empty(line_size).fill(first),
    }
}
341
#[cube]
/// Combines the `size` slots of a shared accumulator pairwise in a binary
/// tree, leaving the final result in slot 0 and returning it.
///
/// Every level fuses slot `origin` into slot `destination` and ends with a
/// `sync_cube()` so the next level sees the writes. The power-of-two path
/// halves the active-unit count exactly; the general path uses ceiling
/// division to keep a possible odd leftover slot for the next level.
/// NOTE(review): assumes at least `size / 2` units participate so every pair
/// has a unit to fuse it — confirm with the launch configuration of callers.
pub fn reduce_tree<P: ReducePrecision, Inst: ReduceInstruction<P>>(
    inst: &Inst,
    accumulator: &mut Inst::SharedAccumulator,
    #[comptime] size: u32,
) -> Inst::AccumulatorItem {
    if comptime!(size.is_power_of_two()) {
        let mut num_active_units = size.runtime();
        let mut jump = 1;
        while num_active_units > 1 {
            num_active_units /= 2;
            // Unit k fuses slot jump*(2k+1) into slot jump*2k.
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            if UNIT_POS < num_active_units {
                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
            }
            jump *= 2;
            sync_cube();
        }
    } else {
        let mut num_remaining_items = size.runtime();
        let mut jump = 1;
        while num_remaining_items > 1 {
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            // Only full pairs are fused; an odd trailing slot carries over.
            if UNIT_POS < num_remaining_items / 2 {
                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
            }
            num_remaining_items = div_ceil(num_remaining_items, 2);
            jump *= 2;
            sync_cube();
        }
    }
    sync_cube();
    Inst::SharedAccumulator::read(accumulator, 0)
}
402
#[cube]
#[allow(unknown_lints)] #[allow(clippy::manual_div_ceil)]
/// Ceiling of the integer division `a / b` (`b` must be non-zero).
///
/// Written manually — the clippy allow suggests `u32::div_ceil` is presumably
/// not usable inside `#[cube]` code; confirm before replacing. Note the sum
/// `a + b - 1` can overflow for values near `u32::MAX`.
fn div_ceil(a: u32, b: u32) -> u32 {
    (a + b - 1) / b
}