use cubecl_core as cubecl;
use cubecl_core::prelude::*;
use cubecl_std::tensor::r#virtual::ReadWrite;
use cubecl_std::tensor::r#virtual::VirtualTensor;

use crate::BoundChecksInner;
use crate::args::ReduceArgs;
use crate::args::TensorArgs;
use crate::args::init_tensors;
use crate::instructions::*;
use crate::precision::ReducePrecision;
use crate::primitives::*;
use crate::{LineMode, ReduceConfig, ReduceStrategy};

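/// Launch a reduction of `input` along `axis` into `output` using the given `strategy`.
///
/// The strategy and config are folded into comptime [`ReduceParams`]: with the shared
/// strategy, the accumulator size is `cube_dim.y` when plane instructions are used and the
/// total number of units in the cube otherwise. The kernel is then launched unchecked.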
pub(crate) fn launch_reduce<Run: Runtime, P: ReducePrecision, Out: Numeric, Rd: ReduceFamily>(
    client: &ComputeClient<Run::Server, Run::Channel>,
    input: TensorHandleRef<Run>,
    output: TensorHandleRef<Run>,
    axis: u32,
    config: ReduceConfig,
    strategy: ReduceStrategy,
    inst: Rd::Config,
) {
    let settings = ReduceParams {
        shared: strategy.shared.then(|| {
            if strategy.use_planes {
                config.cube_dim.y
            } else {
                config.cube_dim.num_elems()
            }
        }),
        use_planes: strategy.use_planes,
        line_size_input: config.line_size_input,
        line_size_output: config.line_size_output,
        line_mode: config.line_mode,
        bound_checks: config.bound_checks,
        bound_checks_inner: config.bound_checks_inner,
    };
    unsafe {
        reduce_kernel::launch_unchecked::<P::EI, Out, P::EA, Rd, TensorArgs, Run>(
            client,
            config.cube_count,
            config.cube_dim,
            input.as_tensor_arg(config.line_size_input as u8),
            output.as_tensor_arg(config.line_size_output as u8),
            ScalarArg::new(axis),
            settings,
            inst,
        );
    }
}

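/// Comptime parameters used to specialize [`reduce_kernel`].
///
/// `shared` is the size of the shared accumulator when the shared strategy is enabled,
/// `line_size_input`, `line_size_output` and `line_mode` describe the vectorization of the
/// tensors, and the bound-check flags select which out-of-bounds guards are generated.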
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ReduceParams {
    pub shared: Option<u32>,
    pub use_planes: bool,
    pub line_size_input: u32,
    pub line_size_output: u32,
    pub line_mode: LineMode,
    pub bound_checks: bool,
    pub bound_checks_inner: BoundChecksInner,
}

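/// Kernel entry point: wraps the launch arguments into virtual tensors and forwards
/// them to [`reduce_kernel_virtual`].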
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    reduce_kernel_virtual::<In, Out, Acc, R>(&input, &mut output, axis_reduce, params, config);
}

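/// Reduce `input` along `axis_reduce` into `output`.
///
/// Each reduction is identified by a `reduce_index` derived from the strategy; when bound
/// checks are enabled, units whose index falls outside the number of reductions terminate
/// early.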
#[cube]
pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<In>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let reduce_index = get_reduce_index(params);

    if comptime![params.bound_checks]
        && reduce_index >= get_reduce_count(output.len() * params.line_size_output, params)
    {
        terminate!();
    }

    reduce_kernel_inner::<(In, Acc), Out, R>(
        input,
        output,
        axis_reduce,
        reduce_index,
        params,
        config,
    )
}

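/// Perform a single reduction: select the primitive matching the comptime strategy
/// (shared memory, plane, or plain unit), reduce the slice assigned to `reduce_index`,
/// then let the elected writer store the accumulator into `output`.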
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: u32,
    reduce_index: u32,
    #[comptime] params: ReduceParams,
    #[comptime] config: R::Config,
) {
    let range = ReduceRange::new::<P, Out>(reduce_index, input, output, axis_reduce, params);

    let inst = &R::Instruction::<P>::from_config(config);
    let accumulator = match comptime!((params.shared, params.use_planes)) {
        (Some(accumulator_size), use_planes) => {
            let mut accumulator = reduce_slice_shared::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
                input,
                inst,
                range,
                accumulator_size,
                params.line_size_input,
                params.line_mode,
                use_planes,
                params.bound_checks_inner,
            );
            sync_cube();
            reduce_tree::<P, R::Instruction<P>>(inst, &mut accumulator, accumulator_size)
        }
        (None, true) => reduce_slice_plane::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            inst,
            range,
            params.line_size_input,
            params.line_mode,
            params.bound_checks_inner,
        ),
        (None, false) => reduce_slice::<P, VirtualTensor<P::EI>, R::Instruction<P>>(
            input,
            range,
            inst,
            params.line_size_input,
            params.line_mode,
        ),
    };

    if elected_writer(params) {
        write_to_output::<P, Out, R::Instruction<P>>(
            output,
            accumulator,
            reduce_index,
            input.shape(axis_reduce),
            params,
            inst,
        );
    }
}

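/// Index of the reduction handled by the current unit: one reduction per cube with the
/// shared strategy, one per plane with plane instructions, and one per unit otherwise.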
#[cube]
fn get_reduce_index(#[comptime] params: ReduceParams) -> u32 {
    if params.shared.is_some() {
        CUBE_POS
    } else if params.use_planes {
        CUBE_POS * CUBE_DIM_Y + UNIT_POS_Y
    } else {
        ABSOLUTE_POS
    }
}

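/// Total number of reductions, derived from the output size and the line mode.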
#[cube]
fn get_reduce_count(output_size: u32, #[comptime] params: ReduceParams) -> u32 {
    match comptime!(params.line_mode) {
        LineMode::Parallel => output_size,
        LineMode::Perpendicular => output_size / params.line_size_input,
    }
}

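/// Whether the current unit is in charge of writing the result: the unit with
/// `UNIT_POS == 0` for the shared strategy, the unit with `UNIT_POS_X == 0` within its
/// plane when plane instructions are used, and every unit otherwise.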
#[cube]
fn elected_writer(#[comptime] settings: ReduceParams) -> bool {
    if settings.shared.is_some() {
        UNIT_POS == 0
    } else if settings.use_planes {
        UNIT_POS_X == 0
    } else {
        true.runtime()
    }
}

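/// Write the final accumulator to `output` at `reduce_index`.
///
/// In parallel mode the accumulator line is merged into a single value. In perpendicular
/// mode it maps element-wise to output lines; when the output line is smaller than the
/// input line, it is split into `line_size_input / line_size_output` chunks written at
/// consecutive output indices.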
#[cube]
fn write_to_output<P: ReducePrecision, Out: Numeric, R: ReduceInstruction<P>>(
    output: &mut VirtualTensor<Out, ReadWrite>,
    accumulator: R::AccumulatorItem,
    reduce_index: u32,
    shape_axis_reduce: u32,
    #[comptime] settings: ReduceParams,
    inst: &R,
) {
    match comptime!(settings.line_mode) {
        LineMode::Parallel => {
            let result = R::merge_line::<Out>(inst, accumulator, shape_axis_reduce);
            output.write(reduce_index, Line::cast_from(result))
        }
        LineMode::Perpendicular => {
            let out = R::to_output_perpendicular(inst, accumulator, shape_axis_reduce);

            if comptime![settings.line_size_output == settings.line_size_input] {
                output.write(reduce_index, out);
            } else {
                let num_iters = comptime![settings.line_size_input / settings.line_size_output];

                #[unroll]
                for i in 0..num_iters {
                    let mut tmp = Line::empty(settings.line_size_output);

                    #[unroll]
                    for j in 0..settings.line_size_output {
                        tmp[j] = out[i * settings.line_size_output + j];
                    }

                    let index = num_iters * reduce_index + i;
                    output.write(index, tmp);
                }
            }
        }
    }
}