1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3use cubecl_std::tensor::r#virtual::ReadWrite;
4use cubecl_std::tensor::r#virtual::VirtualTensor;
5
6use crate::BoundChecksInner;
7use crate::args::ReduceArgs;
8use crate::args::TensorArgs;
9use crate::args::init_tensors;
10use crate::instructions::*;
11use crate::primitives::*;
12use crate::{LineMode, ReduceConfig, ReduceStrategy};
13
14pub(crate) fn launch_reduce<Run: Runtime, In: Numeric, Out: Numeric, Rd: ReduceFamily>(
18 client: &ComputeClient<Run::Server, Run::Channel>,
19 input: TensorHandleRef<Run>,
20 output: TensorHandleRef<Run>,
21 axis: u32,
22 config: ReduceConfig,
23 strategy: ReduceStrategy,
24 inst: Rd::Config,
25) {
26 let settings = ReduceParams {
27 shared: strategy.shared.then(|| {
28 if strategy.use_planes {
29 config.cube_dim.y
30 } else {
31 config.cube_dim.num_elems()
32 }
33 }),
34 use_planes: strategy.use_planes,
35 line_size_input: config.line_size_input,
36 line_size_output: config.line_size_output,
37 line_mode: config.line_mode,
38 bound_checks: config.bound_checks,
39 bound_checks_inner: config.bound_checks_inner,
40 };
41 unsafe {
42 reduce_kernel::launch_unchecked::<In, Out, Rd, TensorArgs, Run>(
43 client,
44 config.cube_count,
45 config.cube_dim,
46 input.as_tensor_arg(config.line_size_input as u8),
47 output.as_tensor_arg(config.line_size_output as u8),
48 ScalarArg::new(axis),
49 settings,
50 inst,
51 );
52 }
53}
54
55#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
56pub struct ReduceParams {
57 pub shared: Option<u32>, pub use_planes: bool,
59 pub line_size_input: u32,
60 pub line_size_output: u32,
61 pub line_mode: LineMode,
62 pub bound_checks: bool,
63 pub bound_checks_inner: BoundChecksInner,
64}
65
/// Entry point of the reduction: each "reduce index" (a unit, plane, or whole
/// cube depending on `params`) reduces one slice of `input` along
/// `axis_reduce` and writes one result into `output`.
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    let reduce_index = get_reduce_index(params);

    // Guard against launches with more reduce indices than outputs; excess
    // units terminate early instead of reading/writing out of bounds.
    if comptime![params.bound_checks]
        && reduce_index >= get_reduce_count(output.len() * params.line_size_output, params)
    {
        terminate!();
    }

    let range = ReduceRange::new::<In, Out>(reduce_index, &input, &mut output, axis_reduce, params);

    let inst = &R::Instruction::<In>::from_config(config);
    // Dispatch (at comptime) between the three reduction strategies:
    // shared-memory (optionally plane-assisted), plane-only, and plain unit.
    let accumulator = match comptime!((params.shared, params.use_planes)) {
        (Some(accumulator_size), use_planes) => {
            // Phase 1: each unit/plane accumulates into shared memory;
            // phase 2 (after the sync): tree-reduce the shared accumulator.
            let mut accumulator = reduce_slice_shared::<In, VirtualTensor<In>, R::Instruction<In>>(
                &input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            sync_units();
            reduce_tree::<In, R::Instruction<In>>(inst, &mut accumulator, accumulator_size)
        }
        (None, true) => reduce_slice_plane::<In, VirtualTensor<In>, R::Instruction<In>>(
            &input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
        (None, false) => reduce_slice::<In, VirtualTensor<In>, R::Instruction<In>>(
            &input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    // Only one writer per reduce index (see `elected_writer`) commits the
    // accumulated result to global memory.
    if elected_writer(params) {
        write_to_output::<In, Out, R::Instruction<In>>(
            &mut output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}
129
/// Map the current unit to the index of the reduction it participates in,
/// depending on the strategy selected at comptime.
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        // Shared strategy: the whole cube cooperates on one reduction.
        CUBE_POS
    } else if params.use_planes {
        // Plane strategy: one reduction per plane (row of the cube).
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        // Unit strategy: one independent reduction per unit.
        ABSOLUTE_POS
    }
}
140
/// Total number of reductions to perform, derived from the output size
/// (in scalar elements).
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        // Parallel: each reduction produces one scalar of the output.
        LineMode::Parallel => output_size,
        // Perpendicular: each reduction produces a whole input-sized line.
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}
148
/// Whether the current unit is the one responsible for writing the result of
/// its reduction to the output.
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        // Shared strategy: a single unit per cube writes the cube's result.
        UNIT_POS == 0
    } else if settings.use_planes {
        // Plane strategy: the first unit of each plane writes.
        UNIT_POS_X == 0
    } else {
        // Unit strategy: every unit owns its reduction and writes itself.
        // `.runtime()` keeps this a runtime value rather than a comptime
        // constant (cubecl-specific).
        true.runtime()
    }
}
159
/// Convert the accumulator into output element(s) and store them at the
/// position(s) owned by `reduce_index`.
#[cube]
fn write_to_output<In: Numeric, Out: Numeric, R: ReduceInstruction<In>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        LineMode::Parallel => {
            // The whole line was reduced to one scalar; write it as a single
            // output line.
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                // Same line width on both sides: write the line as-is.
                output.write(reduce_index, out);
            } else {
                // Output lines are narrower than input lines: split the
                // accumulated line into `num_iters` consecutive output lines.
                // NOTE(review): assumes line_size_input is a multiple of
                // line_size_output — confirm upstream config guarantees this.
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}