cubecl_reduce/instructions/
mean.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction, ReduceRequirements, Sum};
5
6#[derive(Debug, CubeType, Clone)]
7pub struct Mean {
8 pub(crate) sum: Sum,
9}
10
11impl ReduceFamily for Mean {
12 type Instruction<In: Numeric> = Self;
13 type Config = ();
14}
15
16#[cube]
17impl<In: Numeric> ReduceInstruction<In> for Mean {
18 type AccumulatorItem = Line<In>;
19 type SharedAccumulator = SharedMemory<Line<In>>;
20 type Config = ();
21
22 fn requirements(_this: &Self) -> ReduceRequirements {
23 ReduceRequirements { coordinates: false }
24 }
25 fn from_config(_config: Self::Config) -> Self {
26 Mean { sum: Sum {} }
27 }
28
29 fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<In> {
30 Sum::null_input(&this.sum, line_size)
31 }
32
33 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
34 Sum::null_accumulator(&this.sum, line_size)
35 }
36
37 fn assign_accumulator(
38 this: &Self,
39 destination: &mut Self::AccumulatorItem,
40 source: &Self::AccumulatorItem,
41 ) {
42 Sum::assign_accumulator(&this.sum, destination, source);
43 }
44
45 fn reduce(
46 this: &Self,
47 accumulator: &Self::AccumulatorItem,
48 item: Line<In>,
49 _coordinate: ReduceCoordinate,
50 #[comptime] use_planes: bool,
51 ) -> Self::AccumulatorItem {
52 Sum::reduce(&this.sum, accumulator, item, _coordinate, use_planes)
53 }
54
55 fn fuse_accumulators(
56 this: &Self,
57 lhs: Self::AccumulatorItem,
58 rhs: Self::AccumulatorItem,
59 ) -> Self::AccumulatorItem {
60 Sum::fuse_accumulators(&this.sum, lhs, rhs)
61 }
62
63 fn merge_line<Out: Numeric>(
66 this: &Self,
67 accumulator: Self::AccumulatorItem,
68 shape_axis_reduce: u32,
69 ) -> Out {
70 Sum::merge_line::<Out>(&this.sum, accumulator, shape_axis_reduce)
71 / Out::cast_from(shape_axis_reduce)
72 }
73
74 fn to_output_perpendicular<Out: Numeric>(
75 this: &Self,
76 accumulator: Self::AccumulatorItem,
77 shape_axis_reduce: u32,
78 ) -> Line<Out> {
79 let sum = Sum::to_output_perpendicular::<Out>(&this.sum, accumulator, shape_axis_reduce);
80 sum / Line::empty(accumulator.size()).fill(Out::cast_from(shape_axis_reduce))
81 }
82}