cubecl_reduce/instructions/
mean.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{Reduce, ReduceInstruction, Sum};
5
6#[derive(Debug)]
7pub struct Mean;
8
9impl Reduce for Mean {
10 type Instruction<In: Numeric> = Self;
11}
12
13#[cube]
14impl<In: Numeric> ReduceInstruction<In> for Mean {
15 type AccumulatorItem = Line<In>;
16 type SharedAccumulator = SharedMemory<Line<In>>;
17
18 fn null_input(#[comptime] line_size: u32) -> Line<In> {
19 Sum::null_input(line_size)
20 }
21
22 fn null_accumulator(#[comptime] line_size: u32) -> Self::AccumulatorItem {
23 Sum::null_accumulator(line_size)
24 }
25
26 fn assign_accumulator(destination: &mut Self::AccumulatorItem, source: &Self::AccumulatorItem) {
27 Sum::assign_accumulator(destination, source);
28 }
29
30 fn reduce(
31 accumulator: &Self::AccumulatorItem,
32 item: Line<In>,
33 _coordinate: Line<u32>,
34 #[comptime] use_planes: bool,
35 ) -> Self::AccumulatorItem {
36 Sum::reduce(accumulator, item, _coordinate, use_planes)
37 }
38
39 fn fuse_accumulators(
40 lhs: Self::AccumulatorItem,
41 rhs: Self::AccumulatorItem,
42 ) -> Self::AccumulatorItem {
43 Sum::fuse_accumulators(lhs, rhs)
44 }
45
46 fn merge_line<Out: Numeric>(accumulator: Self::AccumulatorItem, shape_axis_reduce: u32) -> Out {
47 Sum::merge_line::<Out>(accumulator, shape_axis_reduce) / Out::cast_from(shape_axis_reduce)
48 }
49
50 fn to_output_perpendicular<Out: Numeric>(
51 accumulator: Self::AccumulatorItem,
52 shape_axis_reduce: u32,
53 ) -> Line<Out> {
54 let sum = Sum::to_output_perpendicular::<Out>(accumulator, shape_axis_reduce);
55 sum / Line::empty(accumulator.size()).fill(Out::cast_from(shape_axis_reduce))
56 }
57}