1use crate::{
2 LineMode, ReduceError, ReducePrecision,
3 components::{
4 args::{ReduceArgs, TensorArgs, init_tensors},
5 global::{
6 cube::GlobalFullCubeReduce, plane::GlobalFullPlaneReduce, unit::GlobalFullUnitReduce,
7 },
8 instructions::*,
9 },
10 launch::{ReduceStrategy, RoutineStrategy, generate_line_size},
11 routines::{
12 GlobalReduceBlueprint, ReduceBlueprint, ReduceLineSettings, ReduceProblem, Routine,
13 cube::CubeRoutine, plane::PlaneRoutine, unit::UnitRoutine,
14 },
15};
16use cubecl::{prelude::*, std::tensor::r#virtual::VirtualTensor};
17
18#[derive(Clone, Copy, Debug)]
19pub struct ReduceDtypes {
20 pub input: StorageType,
21 pub output: StorageType,
22 pub accumulation: StorageType,
23}
24
25#[allow(clippy::too_many_arguments)]
29pub(crate) fn launch_reduce<Run: Runtime>(
30 client: &ComputeClient<Run>,
31 input: TensorHandleRef<Run>,
32 output: TensorHandleRef<Run>,
33 axis: usize,
34 strategy: ReduceStrategy,
35 dtypes: ReduceDtypes,
36 inst: ReduceOperationConfig,
37) -> Result<(), ReduceError> {
38 let problem = ReduceProblem {
39 vector_size: input.shape[axis],
40 vector_count: output.shape.iter().copied().product(),
41 axis,
42 dtypes,
43 };
44 let line_mode = match input.strides[axis] {
45 1 => LineMode::Parallel,
46 _ => LineMode::Perpendicular,
47 };
48 let (line_size_input, line_size_output) = generate_line_size::<Run>(
49 client,
50 &input,
51 &output,
52 axis,
53 problem.dtypes.input,
54 line_mode,
55 &strategy.line_size,
56 );
57 let settings = ReduceLineSettings {
58 line_mode,
59 line_size_input,
60 line_size_output,
61 };
62
63 let (blueprint, settings) = match strategy.routine {
64 RoutineStrategy::Unit(strategy) => {
65 let routine = UnitRoutine;
66 routine.prepare(client, problem, settings, strategy)?
67 }
68 RoutineStrategy::Plane(strategy) => {
69 let routine = PlaneRoutine;
70 routine.prepare(client, problem, settings, strategy)?
71 }
72 RoutineStrategy::Cube(strategy) => {
73 let routine = CubeRoutine;
74 routine.prepare(client, problem, settings, strategy)?
75 }
76 };
77
78 unsafe {
79 reduce_kernel::launch_unchecked::<TensorArgs, Run>(
80 client,
81 settings.cube_count,
82 settings.cube_dim,
83 input.as_tensor_arg(settings.line.line_size_input),
84 output.as_tensor_arg(settings.line.line_size_output),
85 ScalarArg::new(axis),
86 blueprint,
87 inst,
88 dtypes.input,
89 dtypes.output,
90 dtypes.accumulation,
91 )
92 .map_err(ReduceError::Launch)
93 }
94}
95
/// Reduction kernel entry point, launched by `launch_reduce` via `launch_unchecked`.
///
/// - `input` / `output`: kernel arguments in the representation chosen by `RA`.
/// - `axis_reduce`: index of the axis being reduced.
/// - `blueprint`: comptime description of which global reduce strategy to run.
/// - `config`: comptime operation config, turned into an instruction downstream.
/// - The three `#[define]` parameters supply the runtime storage types bound to
///   `In`, `Out` and `Acc`; their values are never read in the body.
#[cube(launch_unchecked)]
pub fn reduce_kernel<In: Numeric, Out: Numeric, Acc: Numeric, RA: ReduceArgs>(
    input: &RA::Input<In>,
    output: &mut RA::Output<Out>,
    axis_reduce: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
    #[define(In)] _input_dtype: StorageType,
    #[define(Out)] _output_dtype: StorageType,
    #[define(Acc)] _acc_dtype: StorageType,
) {
    // Convert the `RA`-specific arguments into `VirtualTensor`s, which is the
    // form `reduce_kernel_virtual` accepts.
    let (input, mut output) = init_tensors::<RA, In, Out>(input, output);
    reduce_kernel_virtual::<In, Out, Acc>(&input, &mut output, axis_reduce, blueprint, config);
}
110
/// Runs a reduction over virtual tensors.
///
/// Uses the `ReduceOperation` family and passes `(In, Acc)` as the
/// `ReducePrecision`, pairing the input element type with the `Acc` type.
/// `blueprint` and `config` are comptime values and are forwarded unchanged.
#[cube]
pub fn reduce_kernel_virtual<In: Numeric, Out: Numeric, Acc: Numeric>(
    input: &VirtualTensor<In>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: ReduceOperationConfig,
) {
    reduce_kernel_inner::<(In, Acc), Out, ReduceOperation>(
        input,
        output,
        axis_reduce,
        blueprint,
        config,
    )
}
127
/// Builds the reduce instruction and dispatches to the global reduce strategy
/// selected by `blueprint.global`.
///
/// Each variant (cube, plane, unit) calls its `GlobalFull*Reduce::execute` with:
/// - the same input, output, axis, instruction and line mode;
/// - its own variant-specific blueprint.
#[cube]
fn reduce_kernel_inner<P: ReducePrecision, Out: Numeric, R: ReduceFamily>(
    input: &VirtualTensor<P::EI>,
    output: &mut VirtualTensor<Out, ReadWrite>,
    axis_reduce: usize,
    #[comptime] blueprint: ReduceBlueprint,
    #[comptime] config: R::Config,
) {
    // Instantiate the family's instruction for precision `P` from the comptime config.
    let inst = &R::Instruction::<P>::from_config(config);

    // `blueprint` is `#[comptime]`, so this selection presumably happens when the
    // kernel is expanded rather than at runtime (NOTE(review): confirm).
    match blueprint.global {
        GlobalReduceBlueprint::Cube(cube) => {
            GlobalFullCubeReduce::execute::<P, Out, R::Instruction<P>>(
                input,
                output,
                axis_reduce,
                inst,
                blueprint.line_mode,
                cube,
            )
        }
        GlobalReduceBlueprint::Plane(plane) => {
            GlobalFullPlaneReduce::execute::<P, Out, R::Instruction<P>>(
                input,
                output,
                axis_reduce,
                inst,
                blueprint.line_mode,
                plane,
            )
        }
        GlobalReduceBlueprint::Unit(unit) => {
            GlobalFullUnitReduce::execute::<P, Out, R::Instruction<P>>(
                input,
                output,
                axis_reduce,
                inst,
                blueprint.line_mode,
                unit,
            )
        }
    };
}