// Uniform parameters read by `stats_arithmetic_mean`.
// Four u32 fields = 16 bytes; the host must upload a buffer with exactly this layout.
struct ArithmeticMeanParams {
// Number of bytes (not u32 words) of `input_words` to average; also the divisor of the mean.
byte_len: u32,
// reserved0..reserved2 are not read by the shader code visible here.
// NOTE(review): presumably padding to a 16-byte uniform struct and/or room for
// future parameters — confirm against the host-side layout before repurposing.
reserved0: u32,
reserved1: u32,
reserved2: u32,
}
// Shader resources; group/binding numbers and address spaces must match the host's bind group layout.
// params: scalar parameters for the reduction (see ArithmeticMeanParams).
@group(0) @binding(0) var<uniform> params: ArithmeticMeanParams;
// input_words: source bytes packed into u32 words, read one byte at a time through
// `vyre_packed_byte` (defined elsewhere). NOTE(review): byte order within a word is
// whatever that helper implements — presumably little-endian, confirm.
@group(0) @binding(1) var<storage, read> input_words: array<u32>;
// output_bits: result slot; stats_arithmetic_mean writes the f32 mean's raw bits
// (bitcast<u32>) into output_bits[0].
@group(0) @binding(2) var<storage, read_write> output_bits: array<u32>;
// Computes the arithmetic mean of the first `params.byte_len` bytes packed in
// `input_words` and stores it, as raw f32 bits, in `output_bits[0]`.
//
// The reduction is serial: exactly one invocation (global id (0, 0, 0)) does all
// the work and every other invocation returns immediately, so the dispatch size
// does not change the result. An empty input (byte_len == 0) yields 0.0.
//
// The byte total is kept as a 64-bit value split across two u32 words
// (sum_hi:sum_lo). A single u32 accumulator wraps silently once the total exceeds
// 2^32 - 1 (e.g. more than ~16.8 million 0xFF bytes), corrupting the mean. The
// largest possible total, 255 * (2^32 - 1), needs only 40 bits.
@compute @workgroup_size(1, 1, 1)
fn stats_arithmetic_mean(@builtin(global_invocation_id) id: vec3<u32>) {
    // Check every component so a dispatch with y/z > 1 cannot produce
    // several invocations racing to write output_bits[0].
    if (any(id != vec3<u32>(0u))) {
        return;
    }

    let byte_count = params.byte_len;
    if (byte_count == 0u) {
        output_bits[0] = bitcast<u32>(0.0f);
        return;
    }

    var sum_lo = 0u;
    var sum_hi = 0u;
    for (var index = 0u; index < byte_count; index += 1u) {
        let next_lo = sum_lo + vyre_packed_byte(&input_words, index);
        // u32 addition wraps in WGSL; a wrapped result is smaller than the
        // previous low word, which signals a carry into the high word.
        if (next_lo < sum_lo) {
            sum_hi += 1u;
        }
        sum_lo = next_lo;
    }

    // Recombine hi:lo as f32. 2^32 is exactly representable, so the remaining
    // error is ordinary f32 rounding rather than integer wraparound.
    let total = f32(sum_hi) * 4294967296.0f + f32(sum_lo);
    let mean = total / f32(byte_count);
    output_bits[0] = bitcast<u32>(mean);
}