wonnx 0.5.1

Wonnx is an ONNX runtime based on wgpu, aimed at being a universal GPU runtime, and written in Rust.
Documentation

{%- include "structs.wgsl" -%}

// Flat, run-time-sized storage view whose element type is filled in by the
// template (`elem_type`). Used below for the X input and the Y output.
struct Block {
	data: array<{{ elem_type }}>,
};

// Storage bindings, declared in the order X, scale, B, mean, variance, followed by
// the output Y. X and Y use the templated `Block` element type. The four per-channel
// parameter tensors use `Array`, which is not declared in this file (presumably it
// comes from the included structs.wgsl).
// NOTE(review): after four bindings in group 0 the remaining buffers continue in
// group 1. This is presumably a host-side per-group binding limit; confirm against
// the pipeline layout code.

// X (input)
@group(0) @binding(0)
var<storage, read> input_0: Block;

// Scale (read at the channel index in main)
@group(0) @binding(1)
var<storage, read> input_1: Array;

// B (bias) (read at the channel index in main)
@group(0) @binding(2)
var<storage, read> input_2: Array;

// Input mean (read at the channel index in main)
@group(0) @binding(3)
var<storage, read> input_3: Array;

// Input variance (read at the channel index in main)
@group(1) @binding(0)
var<storage, read> input_4: Array;

// Y (Output): the only writable buffer; main writes one element per invocation
@group(1) @binding(1)
var<storage, read_write> output_0: Block;

// Batch-normalization kernel: each invocation computes exactly one element of Y.
// Because the workgroup size is 1, global_invocation_id is the work-item coordinate:
//   x = offset inside one channel's slab, y = channel index, z = batch index.
// NOTE(review): there is no bounds check, so this assumes the host dispatches exactly
// one invocation per element of X. Confirm against the dispatch code.
// NOTE(review): `batch_size` / `channel_size` are assumed to be the element strides
// between consecutive batch items / channels, rendered in by the template engine.
@compute @workgroup_size(1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
	let c = global_id.y;

	// Flat offset of this element inside both X and Y.
	let flat = global_id.z * {{ batch_size }}u
		+ global_id.y * {{ channel_size }}u
		+ global_id.x;

	// Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B
	// The per-channel parameters are read at index `c`. If `elem_type` is a vector and
	// `Array` holds scalars, WGSL vector-scalar arithmetic applies them to every lane.
	let centered = input_0.data[flat] - input_3.data[c];
	let std_dev = sqrt(input_4.data[c] + {{ epsilon }});
	let normalized = centered / std_dev;
	output_0.data[flat] = normalized * input_1.data[c] + input_2.data[c];
}