// cubecl_reduce/instructions/prod.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::{instructions::ReduceRequirements, precision::ReducePrecision};
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction};
7
8#[derive(Debug, CubeType, Clone)]
9pub struct Prod {}
10
11impl ReduceFamily for Prod {
12    type Instruction<P: ReducePrecision> = Self;
13    type Config = ();
14}
15
16#[cube]
17impl<P: ReducePrecision> ReduceInstruction<P> for Prod {
18    type AccumulatorItem = Line<P::EA>;
19    type SharedAccumulator = SharedMemory<Line<P::EA>>;
20    type Config = ();
21
22    fn requirements(_this: &Self) -> ReduceRequirements {
23        ReduceRequirements { coordinates: false }
24    }
25
26    fn from_config(_config: Self::Config) -> Self {
27        Prod {}
28    }
29    fn null_input(_this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
30        Line::empty(line_size).fill(P::EI::from_int(1))
31    }
32
33    fn null_accumulator(_this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
34        Line::empty(line_size).fill(P::EA::from_int(1))
35    }
36
37    fn assign_accumulator(
38        _this: &Self,
39        destination: &mut Self::AccumulatorItem,
40        source: &Self::AccumulatorItem,
41    ) {
42        *destination = *source;
43    }
44
45    fn reduce(
46        _this: &Self,
47        accumulator: &Self::AccumulatorItem,
48        item: Line<P::EI>,
49        _coordinate: ReduceCoordinate,
50        #[comptime] use_planes: bool,
51    ) -> Self::AccumulatorItem {
52        let item = Line::cast_from(item);
53        if use_planes {
54            *accumulator * plane_prod(item)
55        } else {
56            *accumulator * item
57        }
58    }
59
60    fn fuse_accumulators(
61        _this: &Self,
62        lhs: Self::AccumulatorItem,
63        rhs: Self::AccumulatorItem,
64    ) -> Self::AccumulatorItem {
65        lhs * rhs
66    }
67
68    fn merge_line<Out: Numeric>(
69        _this: &Self,
70        accumulator: Self::AccumulatorItem,
71        _shape_axis_reduce: u32,
72    ) -> Out {
73        let mut prod = P::EA::from_int(1);
74        #[unroll]
75        for k in 0..accumulator.size() {
76            prod *= accumulator[k];
77        }
78        Out::cast_from(prod)
79    }
80
81    fn to_output_perpendicular<Out: Numeric>(
82        _this: &Self,
83        accumulator: Self::AccumulatorItem,
84        _shape_axis_reduce: u32,
85    ) -> Line<Out> {
86        Line::cast_from(accumulator)
87    }
88}