1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::VirtualTensor;
4
5use crate::BoundChecksInner;
6use crate::LineMode;
7use crate::ReduceParams;
8use crate::instructions::*;
9use crate::precision::ReducePrecision;
10
/// Describes how one reduction walks over its input.
///
/// Two cursors advance together:
/// - a *line index* into the list of input lines, starting at `index_start`;
/// - a *coordinate* along the reduced axis, running from `coordinate_start` up to (but
///   excluding) `coordinate_end` in increments of `coordinate_step`.
///
/// The coordinate drives loop termination and bound checks; the line index is what is
/// actually passed to `items.read`.
#[derive(CubeType)]
pub struct ReduceRange {
    /// Line index of the first input line visited.
    pub index_start: u32,
    /// Distance, in lines, between two lines that are adjacent along the reduced axis.
    pub index_step: u32,
    /// First coordinate along the reduced axis (always 0 for ranges built by `ReduceRange::new`).
    pub coordinate_start: u32,
    /// Exclusive upper bound for the coordinate (the input's extent along the reduced axis
    /// for ranges built by `ReduceRange::new`).
    pub coordinate_end: u32,
    /// Coordinate advance per loop iteration of the unit, plane or cube doing the walk.
    pub coordinate_step: u32,
}
20
21#[cube]
22impl ReduceRange {
23 pub(crate) fn new<P: ReducePrecision, Out: Numeric>(
24 reduce_index: u32,
25 input: &VirtualTensor<P::EI>,
26 output: &mut VirtualTensor<Out, ReadWrite>,
27 axis_reduce: u32,
28 #[comptime] params: ReduceParams,
29 ) -> ReduceRange {
30 match comptime!(params.line_mode) {
31 LineMode::Parallel => {
32 Self::new_parallel::<P, Out>(reduce_index, input, output, axis_reduce, params)
33 }
34 LineMode::Perpendicular => {
35 Self::new_perpendicular::<P, Out>(reduce_index, input, output, axis_reduce, params)
36 }
37 }
38 }
39
40 fn new_parallel<P: ReducePrecision, Out: Numeric>(
41 reduce_index: u32,
42 input: &VirtualTensor<P::EI>,
43 output: &mut VirtualTensor<Out, ReadWrite>,
44 axis_reduce: u32,
45 #[comptime] params: ReduceParams,
46 ) -> ReduceRange {
47 let shape_axis = input.shape(axis_reduce);
48
49 let mut index_start = 0;
50 for axis in 0..input.rank() {
51 let coordinate = output.coordinate(reduce_index, axis);
52 index_start += coordinate * input.stride(axis);
53 }
54 index_start /= params.line_size_input;
55
56 let coordinate_end = shape_axis;
57
58 let coordinate_step = if params.shared.is_some() {
59 CUBE_DIM * params.line_size_input
60 } else if params.use_planes {
61 CUBE_DIM_X * params.line_size_input
62 } else {
63 params.line_size_input.runtime()
64 };
65
66 ReduceRange {
67 index_start,
68 index_step: 1,
69 coordinate_start: 0,
70 coordinate_end,
71 coordinate_step,
72 }
73 }
74
75 fn new_perpendicular<P: ReducePrecision, Out: Numeric>(
76 reduce_index: u32,
77 input: &VirtualTensor<P::EI>,
78 output: &mut VirtualTensor<Out, ReadWrite>,
79 axis_reduce: u32,
80 #[comptime] params: ReduceParams,
81 ) -> ReduceRange {
82 let shape_axis = input.shape(axis_reduce);
83
84 let mut index_start = 0;
85 for axis in 0..input.rank() {
86 let coordinate = output.coordinate(reduce_index * params.line_size_input, axis);
87 index_start += coordinate * input.stride(axis);
88 }
89 index_start /= params.line_size_input;
90
91 let index_step = input.stride(axis_reduce) / params.line_size_input;
92
93 let coordinate_end = shape_axis;
94
95 let coordinate_step = if params.shared.is_some() {
96 CUBE_DIM
97 } else if params.use_planes {
98 CUBE_DIM_X
99 } else {
100 1_u32.runtime()
101 };
102
103 ReduceRange {
104 index_start,
105 index_step,
106 coordinate_start: 0,
107 coordinate_step,
108 coordinate_end,
109 }
110 }
111}
112
113#[cube]
122pub fn reduce_slice<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
123 items: &I,
124 range: ReduceRange,
125 inst: &R,
126 #[comptime] line_size: u32,
127 #[comptime] line_mode: LineMode,
128) -> R::AccumulatorItem {
129 let mut accumulator = R::null_accumulator(inst, line_size);
130
131 let mut index = range.index_start;
132 for coordinate in range_stepped(
133 range.coordinate_start,
134 range.coordinate_end,
135 range.coordinate_step,
136 ) {
137 let requirements = R::requirements(inst);
138 let coordinates = if comptime![requirements.coordinates] {
139 ReduceCoordinate::new_Required(fill_coordinate_line(coordinate, line_size, line_mode))
140 } else {
141 ReduceCoordinate::new_NotRequired()
142 };
143 reduce_inplace::<P, R>(
144 inst,
145 &mut accumulator,
146 items.read(index),
147 coordinates,
148 false,
149 );
150 index += range.index_step;
151 }
152
153 accumulator
154}
155
156#[cube]
169pub fn reduce_slice_plane<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
170 items: &I,
171 inst: &R,
172 range: ReduceRange,
173 #[comptime] line_size: u32,
174 #[comptime] line_mode: LineMode,
175 #[comptime] bound_checks: BoundChecksInner,
176) -> R::AccumulatorItem {
177 let plane_dim = CUBE_DIM_X;
178
179 let mut accumulator = R::null_accumulator(inst, line_size);
180
181 let mut first_index = range.index_start;
182 for first_coordinate in range_stepped(
183 range.coordinate_start,
184 range.coordinate_end,
185 range.coordinate_step,
186 ) {
187 let unit_coordinate_offset = match line_mode {
188 LineMode::Parallel => UNIT_POS_X * line_size,
189 LineMode::Perpendicular => UNIT_POS_X,
190 };
191 let unit_coordinate = first_coordinate + unit_coordinate_offset;
192
193 let requirements = R::requirements(inst);
194 let coordinates = if comptime![requirements.coordinates] {
195 ReduceCoordinate::new_Required(fill_coordinate_line(
196 unit_coordinate,
197 line_size,
198 line_mode,
199 ))
200 } else {
201 ReduceCoordinate::new_NotRequired()
202 };
203
204 let index = first_index + UNIT_POS_X * range.index_step;
205 let item = match bound_checks {
206 BoundChecksInner::None => items.read(index),
207 BoundChecksInner::Mask => {
208 let mask = unit_coordinate < range.coordinate_end;
209 let index = index * u32::cast_from(mask);
210 select(mask, items.read(index), R::null_input(inst, line_size))
211 }
212 BoundChecksInner::Branch => {
213 if unit_coordinate < range.coordinate_end {
214 items.read(index)
215 } else {
216 R::null_input(inst, line_size)
217 }
218 }
219 };
220
221 reduce_inplace::<P, R>(inst, &mut accumulator, item, coordinates, true);
222
223 first_index += plane_dim * range.index_step;
224 }
225 accumulator
226}
227
228#[cube]
241pub fn reduce_slice_shared<P: ReducePrecision, I: List<Line<P::EI>>, R: ReduceInstruction<P>>(
242 items: &I,
243 inst: &R,
244 range: ReduceRange,
245 #[comptime] accumulator_size: u32,
246 #[comptime] line_size: u32,
247 #[comptime] line_mode: LineMode,
248 #[comptime] use_planes: bool,
249 #[comptime] bound_checks: BoundChecksInner,
250) -> R::SharedAccumulator {
251 let accumulator_index = if use_planes { UNIT_POS_Y } else { UNIT_POS };
253
254 let requirements = R::requirements(inst);
255 let mut accumulator =
256 R::SharedAccumulator::allocate(accumulator_size, line_size, requirements.coordinates);
257
258 R::SharedAccumulator::write(
259 &mut accumulator,
260 accumulator_index,
261 R::null_accumulator(inst, line_size),
262 );
263
264 let mut first_index = range.index_start;
265 for first_coordinate in range_stepped(
266 range.coordinate_start,
267 range.coordinate_end,
268 range.coordinate_step,
269 ) {
270 let unit_coordinate_offset = match line_mode {
271 LineMode::Parallel => UNIT_POS * line_size,
272 LineMode::Perpendicular => UNIT_POS,
273 };
274 let unit_coordinate = first_coordinate + unit_coordinate_offset;
275
276 let index = first_index + UNIT_POS * range.index_step;
277
278 let item = match bound_checks {
279 BoundChecksInner::None => items.read(index),
280 BoundChecksInner::Mask => {
281 let mask = unit_coordinate < range.coordinate_end;
282 let index = index * u32::cast_from(mask);
283 select(mask, items.read(index), R::null_input(inst, line_size))
284 }
285 BoundChecksInner::Branch => {
286 if unit_coordinate < range.coordinate_end {
287 items.read(index)
288 } else {
289 R::null_input(inst, line_size)
290 }
291 }
292 };
293
294 let coordinates = if comptime! {requirements.coordinates} {
295 let coordinate = fill_coordinate_line(unit_coordinate, line_size, line_mode);
296 let coordinate = select(
297 unit_coordinate < range.coordinate_end,
298 coordinate,
299 Line::empty(line_size).fill(u32::MAX),
300 );
301
302 ReduceCoordinate::new_Required(coordinate)
303 } else {
304 ReduceCoordinate::new_NotRequired()
305 };
306
307 reduce_shared_inplace::<P, R>(
308 inst,
309 &mut accumulator,
310 accumulator_index,
311 item,
312 coordinates,
313 use_planes,
314 );
315 first_index += range.index_step * CUBE_DIM;
316 }
317 accumulator
318}
319
320#[cube]
323fn fill_coordinate_line(
324 first: u32,
325 #[comptime] line_size: u32,
326 #[comptime] line_mode: LineMode,
327) -> Line<u32> {
328 match comptime!(line_mode) {
329 LineMode::Parallel => {
330 let mut coordinates = Line::empty(line_size);
331 #[unroll]
332 for j in 0..line_size {
333 coordinates[j] = first + j;
334 }
335 coordinates
336 }
337 LineMode::Perpendicular => Line::empty(line_size).fill(first),
338 }
339}
340
/// Fuses entries `0..size` of `accumulator` pairwise into entry 0 and returns entry 0.
///
/// The fusion proceeds in levels. At each level, unit `UNIT_POS` fuses entry
/// `jump * (2 * UNIT_POS + 1)` into entry `jump * 2 * UNIT_POS` with
/// `fuse_accumulator_inplace`, then `jump` doubles:
///
///     0   1   2   3   4   5   6   7
///     +---+   +---+   +---+   +---+
///     +-------+       +-------+
///     +---------------+
///     *
///
/// A `sync_cube` barrier separates levels.
///
/// No bound checks are performed. NOTE(review): callers presumably must guarantee that `size`
/// does not exceed the accumulator's allocated length and that the cube has at least
/// `size / 2` units. Confirm at call sites.
#[cube]
pub fn reduce_tree<P: ReducePrecision, Inst: ReduceInstruction<P>>(
    inst: &Inst,
    accumulator: &mut Inst::SharedAccumulator,
    #[comptime] size: u32,
) -> Inst::AccumulatorItem {
    // The branch is chosen at compile time. The power-of-two path always halves exactly; the
    // general path also handles an odd number of remaining entries.
    if comptime!(size.is_power_of_two()) {
        let mut num_active_units = size.runtime();
        let mut jump = 1;
        while num_active_units > 1 {
            num_active_units /= 2;
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            if UNIT_POS < num_active_units {
                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
            }
            jump *= 2;
            // The next level reads entries written by other units at this level.
            sync_cube();
        }
    } else {
        let mut num_remaining_items = size.runtime();
        let mut jump = 1;
        while num_remaining_items > 1 {
            let destination = jump * 2 * UNIT_POS;
            let origin = jump * (2 * UNIT_POS + 1);
            // Only floor(n / 2) complete pairs exist; an odd trailing entry is left untouched.
            if UNIT_POS < num_remaining_items / 2 {
                fuse_accumulator_inplace::<P, Inst>(inst, accumulator, destination, origin);
            }
            // An untouched trailing entry survives to the next level, hence the rounding up.
            num_remaining_items = num_remaining_items.div_ceil(2);
            jump *= 2;
            // The next level reads entries written by other units at this level.
            sync_cube();
        }
    }
    // Barrier before reading entry 0; also reached when `size <= 1` and no level ran.
    sync_cube();
    Inst::SharedAccumulator::read(accumulator, 0)
}
401
402#[cube]
425pub fn reduce_sum_shuffle<F: Float>(value: F) -> F {
426 let v1 = value + plane_shuffle_xor(value, 16);
428 let v2 = v1 + plane_shuffle_xor(v1, 8);
429 let v3 = v2 + plane_shuffle_xor(v2, 4);
430 let v4 = v3 + plane_shuffle_xor(v3, 2);
431 v4 + plane_shuffle_xor(v4, 1)
432}
433
434#[cube]
437pub fn reduce_max_shuffle<F: Float>(value: F) -> F {
438 let v1 = F::max(value, plane_shuffle_xor(value, 16));
439 let v2 = F::max(v1, plane_shuffle_xor(v1, 8));
440 let v3 = F::max(v2, plane_shuffle_xor(v2, 4));
441 let v4 = F::max(v3, plane_shuffle_xor(v3, 2));
442 F::max(v4, plane_shuffle_xor(v4, 1))
443}
444
445#[cube]
448pub fn reduce_min_shuffle<F: Float>(value: F) -> F {
449 let v1 = F::min(value, plane_shuffle_xor(value, 16));
450 let v2 = F::min(v1, plane_shuffle_xor(v1, 8));
451 let v3 = F::min(v2, plane_shuffle_xor(v2, 4));
452 let v4 = F::min(v3, plane_shuffle_xor(v3, 2));
453 F::min(v4, plane_shuffle_xor(v4, 1))
454}
455
456#[cube]
459pub fn reduce_prod_shuffle<F: Float>(value: F) -> F {
460 let v1 = value * plane_shuffle_xor(value, 16);
461 let v2 = v1 * plane_shuffle_xor(v1, 8);
462 let v3 = v2 * plane_shuffle_xor(v2, 4);
463 let v4 = v3 * plane_shuffle_xor(v3, 2);
464 v4 * plane_shuffle_xor(v4, 1)
465}