1use crate::{
2 ReduceError, ReducePrecision, VectorizationMode,
3 components::{
4 args::{NumericVector, ReduceArgs, TensorArgs, init_tensors},
5 global::{
6 cube::GlobalFullCubeReduce, plane::GlobalFullPlaneReduce, unit::GlobalFullUnitReduce,
7 },
8 instructions::*,
9 },
10 launch::{ReduceStrategy, RoutineStrategy, generate_vector_size},
11 output_vectorization_axis,
12 routines::{
13 GlobalReduceBlueprint, ReduceBlueprint, ReduceProblem, ReduceVectorSettings, Routine,
14 cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
15 },
16};
17use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};
18
19#[derive(Clone, Copy, Debug)]
20pub struct ReduceDtypes {
21 pub input: StorageType,
22 pub output: StorageType,
23 pub accumulation: StorageType,
24}
25
26#[allow(clippy::too_many_arguments)]
30pub(crate) fn launch_reduce<Run: Runtime>(
31 client: &ComputeClient<Run>,
32 input: TensorBinding<Run>,
33 output: TensorBinding<Run>,
34 reduce_axis: usize,
35 strategy: ReduceStrategy,
36 dtypes: ReduceDtypes,
37 inst: ReduceOperationConfig,
38) -> Result<(), ReduceError> {
39 let address_type = input
40 .required_address_type(dtypes.input.size())
41 .max(output.required_address_type(dtypes.output.size()));
42
43 let reduce_len = input.shape[reduce_axis];
45 let input_elems: usize = input.shape.iter().copied().product();
46 let reduce_count = input_elems / reduce_len;
47
48 let problem = ReduceProblem {
49 reduce_len,
50 reduce_count,
51 axis: reduce_axis,
52 dtypes,
53 address_type,
54 };
55 let vectorization_mode = match input.strides[reduce_axis] {
56 1 => VectorizationMode::Parallel,
57 _ => VectorizationMode::Perpendicular,
58 };
59
60 let out_vec_axis = output_vectorization_axis(&input.strides, reduce_axis, vectorization_mode);
61
62 let (vector_size_input, vector_size_output) = generate_vector_size::<Run>(
63 client,
64 &input,
65 &output,
66 reduce_axis,
67 problem.dtypes.input,
68 vectorization_mode,
69 &strategy.vectorization,
70 );
71 let settings = ReduceVectorSettings {
72 vectorization_mode,
73 vector_size_input,
74 vector_size_output,
75 };
76
77 let (blueprint, settings) = match strategy.routine {
78 RoutineStrategy::Unit(strategy) => {
79 let routine = UnitRoutine;
80 routine.prepare(client, problem, settings, strategy)?
81 }
82 RoutineStrategy::Plane(strategy) => {
83 let routine = PlaneRoutine;
84 routine.prepare(client, problem, settings, strategy)?
85 }
86 RoutineStrategy::Cube(strategy) => {
87 let routine = CubeRoutine;
88 routine.prepare(client, problem, settings, strategy)?
89 }
90 };
91
92 unsafe {
93 reduce_kernel::launch_unchecked::<TensorArgs, Run>(
94 client,
95 settings.cube_count,
96 settings.cube_dim,
97 settings.address_type,
98 settings.vector.vector_size_input,
99 settings.vector.vector_size_output,
100 input.into_tensor_arg(),
101 output.into_tensor_arg(),
102 reduce_axis,
103 out_vec_axis,
104 blueprint,
105 inst,
106 dtypes.input,
107 dtypes.output,
108 dtypes.accumulation,
109 )
110 };
111
112 Ok(())
113}
114
/// Reduction kernel entry point (launched by `launch_reduce` via `launch_unchecked`).
///
/// Turns the launch arguments `RA::Input` / `RA::Output` into virtual tensors with
/// `init_tensors` and forwards everything to [`reduce_kernel_virtual`].
///
/// Generic parameters:
/// - `In` / `InSize`: input element type and input vector size.
/// - `Out` / `OutSize`: output element type and output vector size.
/// - `Acc`: accumulation element type.
/// - `RA`: argument layout (`launch_reduce` uses `TensorArgs`).
///
/// Arguments:
/// - `reduce_axis`: axis of `input` that is reduced.
/// - `out_vec_axis`: output vectorization axis (as computed by `output_vectorization_axis`).
/// - `blueprint` / `config`: compile-time reduction strategy and operation configuration.
/// - `_input_dtype` / `_output_dtype` / `_acc_dtype`: runtime storage types attached to
///   `In`, `Out` and `Acc` through `#[define(..)]`; they are not read in the body.
#[cube(launch_unchecked, address_type = "dynamic")]
pub fn reduce_kernel<
    In: Numeric,
    InSize: Size,
    Out: Numeric,
    OutSize: Size,
    Acc: Numeric,
    RA: ReduceArgs,
>(
    input: &RA::Input<In, InSize>,
    output: &mut RA::Output<Out, OutSize>,
    reduce_axis: usize,
    out_vec_axis: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    // Build virtual-tensor views over the raw kernel arguments.
    let (input, mut output) = init_tensors::<RA, In, InSize, Out, OutSize>(input, output);
    reduce_kernel_virtual::<In, InSize, Out, OutSize, Acc>(
        &input,
        &mut output,
        reduce_axis,
        out_vec_axis,
        blueprint,
        config,
    );
}
144
/// Runs the configured reduction on already-constructed virtual tensors.
///
/// Packs the generic parameters into the shapes `reduce_kernel_inner` expects: the precision
/// tuple `(In, InSize, Acc)` and the output vector type `(Out, OutSize)`. It always uses the
/// `ReduceOperation` family, configured by `config`.
#[cube]
pub fn reduce_kernel_virtual<
    In: Numeric,
    InSize: Size,
    Out: Numeric,
    OutSize: Size,
    Acc: Numeric,
>(
    input: &VirtualTensor<In, InSize>,
    output: &mut VirtualTensor<Out, OutSize, ReadWrite>,
    reduce_axis: usize,
    out_vec_axis: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
) {
    reduce_kernel_inner::<(In, InSize, Acc), (Out, OutSize), ReduceOperation>(
        input,
        output,
        reduce_axis,
        out_vec_axis,
        blueprint,
        config,
    )
}
169
/// Builds the reduce instruction from `config` and dispatches to the global reduce
/// implementation selected by `blueprint.global`.
///
/// - `GlobalReduceBlueprint::Cube` → `GlobalFullCubeReduce`
/// - `GlobalReduceBlueprint::Plane` → `GlobalFullPlaneReduce`
/// - `GlobalReduceBlueprint::Unit` → `GlobalFullUnitReduce`
///
/// Each implementation receives the same tensors, axes, instruction and vectorization mode,
/// plus its strategy-specific blueprint.
///
/// NOTE(review): `blueprint` is a `#[comptime]` parameter, so this `match` is presumably
/// resolved during kernel expansion and only one branch is emitted. Confirm against the
/// cubecl comptime semantics.
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: NumericVector, R: ReduceFamily>(
    input: &VirtualTensor<P::EI, P::SI>,
    output: &mut VirtualTensor<Out::T, Out::N, ReadWrite>,
    reduce_axis: usize,
    out_vec_axis: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: R::Config,
) {
    // Instantiate the concrete instruction for this precision from its family config.
    let inst = &R::Instruction::<P>::from_config(config);

    match blueprint.global {
        GlobalReduceBlueprint::Cube(cube) => {
            GlobalFullCubeReduce::execute::<P, Out, R::Instruction<P>>(
                input,
                output,
                reduce_axis,
                out_vec_axis,
                inst,
                blueprint.vectorization_mode,
                cube,
            )
        }
        GlobalReduceBlueprint::Plane(plane) => {
            GlobalFullPlaneReduce::execute::<P, Out, R::Instruction<P>>(
                input,
                output,
                reduce_axis,
                out_vec_axis,
                inst,
                blueprint.vectorization_mode,
                plane,
            )
        }
        GlobalReduceBlueprint::Unit(unit) => {
            GlobalFullUnitReduce::execute::<P, Out, R::Instruction<P>>(
                input,
                output,
                reduce_axis,
                out_vec_axis,
                inst,
                blueprint.vectorization_mode,
                unit,
            )
        }
    };
}