cubecl_reduce/instructions/
mean.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::precision::ReducePrecision;
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction, ReduceRequirements, Sum};
7
8#[derive(Debug, CubeType, Clone)]
9pub struct Mean {
10    pub(crate) sum: Sum,
11}
12
13impl ReduceFamily for Mean {
14    type Instruction<P: ReducePrecision> = Self;
15    type Config = ();
16}
17
18#[cube]
19fn null_input<P: ReducePrecision, SI: ReduceInstruction<P>>(
20    sum: &SI,
21    #[comptime] line_size: u32,
22) -> Line<P::EI> {
23    SI::null_input(sum, line_size)
24}
25
26#[cube]
27impl<P: ReducePrecision> ReduceInstruction<P> for Mean {
28    type AccumulatorItem = Line<P::EA>;
29    type SharedAccumulator = SharedMemory<Line<P::EA>>;
30    type Config = ();
31
32    fn requirements(_this: &Self) -> ReduceRequirements {
33        ReduceRequirements { coordinates: false }
34    }
35    fn from_config(_config: Self::Config) -> Self {
36        Mean { sum: Sum {} }
37    }
38
39    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
40        <Sum as ReduceInstruction<P>>::null_input(&this.sum, line_size)
41    }
42
43    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
44        <Sum as ReduceInstruction<P>>::null_accumulator(&this.sum, line_size)
45    }
46
47    fn assign_accumulator(
48        this: &Self,
49        destination: &mut Self::AccumulatorItem,
50        source: &Self::AccumulatorItem,
51    ) {
52        <Sum as ReduceInstruction<P>>::assign_accumulator(&this.sum, destination, source);
53    }
54
55    fn reduce(
56        this: &Self,
57        accumulator: &Self::AccumulatorItem,
58        item: Line<P::EI>,
59        _coordinate: ReduceCoordinate,
60        #[comptime] use_planes: bool,
61    ) -> Self::AccumulatorItem {
62        <Sum as ReduceInstruction<P>>::reduce(&this.sum, accumulator, item, _coordinate, use_planes)
63    }
64
65    fn fuse_accumulators(
66        this: &Self,
67        lhs: Self::AccumulatorItem,
68        rhs: Self::AccumulatorItem,
69    ) -> Self::AccumulatorItem {
70        <Sum as ReduceInstruction<P>>::fuse_accumulators(&this.sum, lhs, rhs)
71    }
72
73    // TODO Remove shape_axis_reduce when fusion-on-write is well supported for reduce instructions.
74    //      Then, an instruction like Mean can be implemented by fusing a <Sum as ReduceInstruction<P>> reduction and a element-wise division.
75    fn merge_line<Out: Numeric>(
76        this: &Self,
77        accumulator: Self::AccumulatorItem,
78        shape_axis_reduce: u32,
79    ) -> Out {
80        <Sum as ReduceInstruction<P>>::merge_line::<Out>(&this.sum, accumulator, shape_axis_reduce)
81            / Out::cast_from(shape_axis_reduce)
82    }
83
84    fn to_output_perpendicular<Out: Numeric>(
85        this: &Self,
86        accumulator: Self::AccumulatorItem,
87        shape_axis_reduce: u32,
88    ) -> Line<Out> {
89        let sum = <Sum as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
90            &this.sum,
91            accumulator,
92            shape_axis_reduce,
93        );
94        sum / Line::empty(accumulator.size()).fill(Out::cast_from(shape_axis_reduce))
95    }
96}