// Uniform parameter block for the std-dev kernel (bound as `params` at @group(0) @binding(0)).
struct StdDevParams {
// Number of bytes of `input_words` to include in the statistic.
byte_len: u32,
// reserved0..reserved2 are never read by this shader. Presumably padding that keeps the
// struct at 16 bytes (four u32) to match the host-side buffer layout; confirm on the host.
reserved0: u32,
reserved1: u32,
reserved2: u32,
}
// Kernel parameters; `params.byte_len` bounds how many input bytes are read.
@group(0) @binding(0) var<uniform> params: StdDevParams;
// Input data, read one byte at a time through `vyre_packed_byte` (defined elsewhere).
// NOTE(review): presumably four bytes packed per u32 word; confirm against vyre_packed_byte.
@group(0) @binding(1) var<storage, read> input_words: array<u32>;
// Output; `stats_std_dev` stores the bit pattern of its f32 result in element 0.
@group(0) @binding(2) var<storage, read_write> output_bits: array<u32>;
// Population standard deviation of the first `params.byte_len` bytes of `input_words`.
//
// Only the invocation with id.x == 0 does any work. The result is stored as the raw bit
// pattern of an f32 in `output_bits[0]`, and an empty input yields 0.0. The divisor is
// the byte count, so this is the population (not sample) standard deviation.
//
// Bytes are first tallied into an exact 256-bin u32 histogram. The mean and variance are
// then computed from the bins in two passes. This avoids the precision loss of a naive
// per-byte f32 `sum` / `sum_sq` accumulation:
// - `sum_sq` exceeds 2^24 after ~258 bytes of 255, after which each addition rounds, and
//   for large inputs the accumulators stop growing at all.
// - `E[x^2] - E[x]^2` cancels catastrophically when the spread is small relative to the mean.
// Bin counts cannot overflow: the total count is bounded by `byte_len: u32`.
@compute @workgroup_size(1, 1, 1)
fn stats_std_dev(@builtin(global_invocation_id) id: vec3<u32>) {
    if (id.x != 0u) {
        return;
    }
    let byte_count = params.byte_len;
    if (byte_count == 0u) {
        output_bits[0] = bitcast<u32>(0.0f);
        return;
    }

    // Function-scope variables without an initializer are zero-initialized in WGSL.
    var histogram: array<u32, 256>;
    for (var index = 0u; index < byte_count; index = index + 1u) {
        // Mask to 0..255 so the histogram index is always in range.
        // NOTE(review): assumes vyre_packed_byte yields an unsigned byte value (0..255).
        let byte_value = u32(vyre_packed_byte(&input_words, index)) & 0xffu;
        histogram[byte_value] = histogram[byte_value] + 1u;
    }

    let n = f32(byte_count);

    // Pass 1 over the bins: the mean.
    var weighted_sum = 0.0f;
    for (var bucket = 0u; bucket < 256u; bucket = bucket + 1u) {
        weighted_sum = weighted_sum + f32(bucket) * f32(histogram[bucket]);
    }
    let mean = weighted_sum / n;

    // Pass 2 over the bins: sum of squared deviations from the mean. Every term is
    // non-negative, so there is no cancellation and no clamp to zero is needed.
    var squared_deviation_sum = 0.0f;
    for (var bucket = 0u; bucket < 256u; bucket = bucket + 1u) {
        let deviation = f32(bucket) - mean;
        squared_deviation_sum = squared_deviation_sum
            + f32(histogram[bucket]) * deviation * deviation;
    }

    output_bits[0] = bitcast<u32>(sqrt(squared_deviation_sum / n));
}