cubecl_reduce/instructions/mean.rs

1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{Reduce, ReduceInstruction, Sum};
5
/// Reduction that computes the arithmetic mean along the reduced axis.
///
/// This is a zero-sized marker type: it carries no data, and the reduction logic
/// lives in its `ReduceInstruction` implementation. Because it is stateless, it is
/// cheap to copy and has a single canonical value (`Mean::default()` == `Mean`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Mean;
8
9impl Reduce for Mean {
10    type Instruction<In: Numeric> = Self;
11}
12
13#[cube]
14impl<In: Numeric> ReduceInstruction<In> for Mean {
15    type AccumulatorItem = Line<In>;
16    type SharedAccumulator = SharedMemory<Line<In>>;
17
18    fn null_input(#[comptime] line_size: u32) -> Line<In> {
19        Sum::null_input(line_size)
20    }
21
22    fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
23        Sum::null_accumulator(line_size)
24    }
25
26    fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
27        Sum::assign_accumulator(destination, source);
28    }
29
30    fn reduce(
31        accumulator: &Self::AccumulatorItem,
32        item: Line<In>,
33        _coordinate: Line<u32>,
34        #[comptime] use_planes: bool,
35    ) -> Self::AccumulatorItem {
36        Sum::reduce(accumulator, item, _coordinate, use_planes)
37    }
38
39    fn fuse_accumulators(
40        lhs: Self::AccumulatorItem,
41        rhs: Self::AccumulatorItem,
42    ) -> Self::AccumulatorItem {
43        Sum::fuse_accumulators(lhs, rhs)
44    }
45
46    fn merge_line<Out: Numeric>(accumulator: Self::AccumulatorItem, shape_axis_reduce: u32) -> Out {
47        Sum::merge_line::<Out>(accumulator, shape_axis_reduce) / Out::cast_from(shape_axis_reduce)
48    }
49
50    fn to_output_perpendicular<Out: Numeric>(
51        accumulator: Self::AccumulatorItem,
52        shape_axis_reduce: u32,
53    ) -> Line<Out> {
54        let sum = Sum::to_output_perpendicular::<Out>(accumulator, shape_axis_reduce);
55        sum / Line::empty(accumulator.size()).fill(Out::cast_from(shape_axis_reduce))
56    }
57}