cubecl_reduce/instructions/mean.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction, ReduceRequirements, Sum};
5
6#[derive(Debug, CubeType, Clone)]
7pub struct Mean {
8    pub(crate) sum: Sum,
9}
10
11impl ReduceFamily for Mean {
12    type Instruction<In: Numeric> = Self;
13    type Config = ();
14}
15
16#[cube]
17impl<In: Numeric> ReduceInstruction<In> for Mean {
18    type AccumulatorItem = Line<In>;
19    type SharedAccumulator = SharedMemory<Line<In>>;
20    type Config = ();
21
22    fn requirements(_this: &Self) -> ReduceRequirements {
23        ReduceRequirements { coordinates: false }
24    }
25    fn from_config(_config: Self::Config) -> Self {
26        Mean { sum: Sum {} }
27    }
28
29    fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<In> {
30        Sum::null_input(&this.sum, line_size)
31    }
32
33    fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
34        Sum::null_accumulator(&this.sum, line_size)
35    }
36
37    fn assign_accumulator(
38        this: &Self,
39        destination: &mut Self::AccumulatorItem,
40        source: &Self::AccumulatorItem,
41    ) {
42        Sum::assign_accumulator(&this.sum, destination, source);
43    }
44
45    fn reduce(
46        this: &Self,
47        accumulator: &Self::AccumulatorItem,
48        item: Line<In>,
49        _coordinate: ReduceCoordinate,
50        #[comptime] use_planes: bool,
51    ) -> Self::AccumulatorItem {
52        Sum::reduce(&this.sum, accumulator, item, _coordinate, use_planes)
53    }
54
55    fn fuse_accumulators(
56        this: &Self,
57        lhs: Self::AccumulatorItem,
58        rhs: Self::AccumulatorItem,
59    ) -> Self::AccumulatorItem {
60        Sum::fuse_accumulators(&this.sum, lhs, rhs)
61    }
62
63    // TODO Remove shape_axis_reduce when fusion-on-write is well supported for reduce instructions.
64    //      Then, an instruction like Mean can be implemented by fusing a Sum reduction and a element-wise division.
65    fn merge_line<Out: Numeric>(
66        this: &Self,
67        accumulator: Self::AccumulatorItem,
68        shape_axis_reduce: u32,
69    ) -> Out {
70        Sum::merge_line::<Out>(&this.sum, accumulator, shape_axis_reduce)
71            / Out::cast_from(shape_axis_reduce)
72    }
73
74    fn to_output_perpendicular<Out: Numeric>(
75        this: &Self,
76        accumulator: Self::AccumulatorItem,
77        shape_axis_reduce: u32,
78    ) -> Line<Out> {
79        let sum = Sum::to_output_perpendicular::<Out>(&this.sum, accumulator, shape_axis_reduce);
80        sum / Line::empty(accumulator.size()).fill(Out::cast_from(shape_axis_reduce))
81    }
82}