cubecl_reduce/instructions/
mean.rs1use cubecl_core as cubecl;
2use cubecl_core::prelude::*;
3
4use crate::precision::ReducePrecision;
5
6use super::{ReduceCoordinate, ReduceFamily, ReduceInstruction, ReduceRequirements, Sum};
7
8#[derive(Debug, CubeType, Clone)]
9pub struct Mean {
10 pub(crate) sum: Sum,
11}
12
13impl ReduceFamily for Mean {
14 type Instruction<P: ReducePrecision> = Self;
15 type Config = ();
16}
17
18#[cube]
19fn null_input<P: ReducePrecision, SI: ReduceInstruction<P>>(
20 sum: &SI,
21 #[comptime] line_size: u32,
22) -> Line<P::EI> {
23 SI::null_input(sum, line_size)
24}
25
26#[cube]
27impl<P: ReducePrecision> ReduceInstruction<P> for Mean {
28 type AccumulatorItem = Line<P::EA>;
29 type SharedAccumulator = SharedMemory<Line<P::EA>>;
30 type Config = ();
31
32 fn requirements(_this: &Self) -> ReduceRequirements {
33 ReduceRequirements { coordinates: false }
34 }
35 fn from_config(_config: Self::Config) -> Self {
36 Mean { sum: Sum {} }
37 }
38
39 fn null_input(this: &Self, #[comptime] line_size: u32) -> Line<P::EI> {
40 <Sum as ReduceInstruction<P>>::null_input(&this.sum, line_size)
41 }
42
43 fn null_accumulator(this: &Self, #[comptime] line_size: u32) -> Self::AccumulatorItem {
44 <Sum as ReduceInstruction<P>>::null_accumulator(&this.sum, line_size)
45 }
46
47 fn assign_accumulator(
48 this: &Self,
49 destination: &mut Self::AccumulatorItem,
50 source: &Self::AccumulatorItem,
51 ) {
52 <Sum as ReduceInstruction<P>>::assign_accumulator(&this.sum, destination, source);
53 }
54
55 fn reduce(
56 this: &Self,
57 accumulator: &Self::AccumulatorItem,
58 item: Line<P::EI>,
59 _coordinate: ReduceCoordinate,
60 #[comptime] use_planes: bool,
61 ) -> Self::AccumulatorItem {
62 <Sum as ReduceInstruction<P>>::reduce(&this.sum, accumulator, item, _coordinate, use_planes)
63 }
64
65 fn fuse_accumulators(
66 this: &Self,
67 lhs: Self::AccumulatorItem,
68 rhs: Self::AccumulatorItem,
69 ) -> Self::AccumulatorItem {
70 <Sum as ReduceInstruction<P>>::fuse_accumulators(&this.sum, lhs, rhs)
71 }
72
73 fn merge_line<Out: Numeric>(
76 this: &Self,
77 accumulator: Self::AccumulatorItem,
78 shape_axis_reduce: u32,
79 ) -> Out {
80 <Sum as ReduceInstruction<P>>::merge_line::<Out>(&this.sum, accumulator, shape_axis_reduce)
81 / Out::cast_from(shape_axis_reduce)
82 }
83
84 fn to_output_perpendicular<Out: Numeric>(
85 this: &Self,
86 accumulator: Self::AccumulatorItem,
87 shape_axis_reduce: u32,
88 ) -> Line<Out> {
89 let sum = <Sum as ReduceInstruction<P>>::to_output_perpendicular::<Out>(
90 &this.sum,
91 accumulator,
92 shape_axis_reduce,
93 );
94 sum / Line::empty(accumulator.size()).fill(Out::cast_from(shape_axis_reduce))
95 }
96}