// cubecl_reduce/instructions/prod.rs
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::{instructions::ReduceRequirements, precision::ReducePrecision};

use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};

8#[derive(Debug, CubeType, Clone)]
9pub struct Prod {}
10
11impl ReduceFamily for Prod {
12 type Instruction<P: ReducePrecision> = Self;
13 type Config = ();
14}
15
16#[cube]
17impl<P: ReducePrecision> ReduceInstruction<P> for Prod {
18 type AccumulatorItem = Line<P::EA>;
19 type SharedAccumulator = SharedMemory<Line<P::EA>>;
20 type Config = ();
21
22 fn requirements(_this: &Self) -> ReduceRequirements {
23 ReduceRequirements { coordinates: false }
24 }
25
26 fn from_config(_config: Self::Config) -> Self {
27 Prod {}
28 }
29 fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
30 Line::empty(line_size).fill(P::EI::from_int(1))
31 }
32
33 fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
34 Line::empty(line_size).fill(P::EA::from_int(1))
35 }
36
37 fn assign_accumulator(
38 _this: &Self,
39 destination: &mut Self::AccumulatorItem,
40 source: &Self::AccumulatorItem,
41 ) {
42 *destination = *source;
43 }
44
45 fn reduce(
46 _this: &Self,
47 accumulator: &Self::AccumulatorItem,
48 item: Line<P::EI>,
49 _coordinate: ReduceCoordinate,
50 #[comptime] use_planes: bool,
51 ) -> Self::AccumulatorItem {
52 let item = Line::cast_from(item);
53 if use_planes {
54 *accumulator * plane_prod(item)
55 } else {
56 *accumulator * item
57 }
58 }
59
60 fn fuse_accumulators(
61 _this: &Self,
62 lhs: Self::AccumulatorItem,
63 rhs: Self::AccumulatorItem,
64 ) -> Self::AccumulatorItem {
65 lhs * rhs
66 }
67
68 fn merge_line<Out: Numeric>(
69 _this: &Self,
70 accumulator: Self::AccumulatorItem,
71 _shape_axis_reduce: u32,
72 ) -> Out {
73 let mut prod = P::EA::from_int(1);
74 #[unroll]
75 for k in 0..accumulator.size() {
76 prod *= accumulator[k];
77 }
78 Out::cast_from(prod)
79 }
80
81 fn to_output_perpendicular<Out: Numeric>(
82 _this: &Self,
83 accumulator: Self::AccumulatorItem,
84 _shape_axis_reduce: u32,
85 ) -> Line<Out> {
86 Line::cast_from(accumulator)
87 }
88}