wonnx 0.5.1

Wonnx is an ONNX runtime based on wgpu aimed at being a universal GPU runtime, written in Rust.
Documentation

{%- include "structs.wgsl" -%}

// Source buffer: read-only storage at group 0, binding 0.
// NOTE(review): the Array type is not declared here; presumably it comes from the
// included structs.wgsl and exposes the `data` member indexed in main - confirm.
@group(0) @binding(0)
var<storage, read> input_0: Array;

// One read_write storage buffer per entry of o_lens, named output_0, output_1, ...
// The template's loop.index is 1-based, so the first output lands on group 0,
// binding 1 (right after the input); every fourth index moves to the next group
// and restarts at binding 0.
{% for output in o_lens %}
	@group({{ loop.index / 4 | int }}) @binding({{ loop.index % 4}})
	var<storage, read_write> output_{{ loop.index0 }}: Array;
{% endfor %}

// Compute entry point: each invocation copies one element of input_0 into the
// output buffer whose range along `axis` contains that element's coordinate.
// global_id.x is used as the flat index of the input element to copy.
@compute @workgroup_size(256, 1, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
	let gidx = global_id.x;

	// Invocations at or beyond the first i_lens entry do nothing
	// (presumably that entry is the input element count - confirm against the caller).
	if (gidx < {{ i_lens[0] }}u) {
		// Decompose the flat index into per-dimension coordinates d_0, d_1, ...
		// Each coordinate is the running remainder divided by that dimension's chunk;
		// the last coordinate is the running remainder itself.
		// NOTE(review): i_chunks[0] is presumably the input's per-dimension strides - confirm.
		// NOTE(review): the remainder is recomputed from gidx rather than from rest; the two
		// agree only when every chunk evenly divides the one before it (as contiguous
		// strides do) - confirm this always holds for i_chunks[0].
		var rest = gidx;

		{%- for chunks in i_chunks[0] -%}
			{% if loop.last %}
				let d_{{ loop.index0 }} = rest; 
			{% else %}
				let d_{{ loop.index0 }} = rest / {{ chunks }}u; 
				rest = gidx % {{ chunks }}u; 
			{% endif %}
		{%- endfor -%}

		// One guarded copy per output. The entries of `split` are used as cumulative upper
		// bounds along `axis`: the first output takes coordinates below the first entry,
		// and output k takes coordinates from entry k-1 (inclusive) up to entry k (exclusive).
		// The destination index is the sum of coordinate * chunk over that output's o_chunks row.
		{% for output in o_lens %}
			{%- if loop.first %}
				// First output: its range starts at 0, so the axis coordinate is used unchanged.
				if (d_{{ axis }} < {{ split | first }}u) {
					let index = 
						{%- for chunk in o_chunks | first -%}
							{%- if not loop.first %}
								+
							{%- endif -%}
							d_{{ loop.index0 }} * {{ chunk }}u
						{%- endfor -%}
					;

					output_{{ loop.index0 }}.data[index] = input_0.data[gidx];
				}
			{%- else %}
				// Later outputs: the axis coordinate is rebased by subtracting the previous
				// bound, so that this output's data starts at coordinate 0 along the axis.
				{% set split_output = split | nth(n=loop.index0 -1) %}
				if ((d_{{ axis }} >= {{ split_output }}u) && (d_{{ axis }} < {{ split | nth(n=loop.index0)}}u)) {
					let index = 
						{%- for chunk in o_chunks | nth(n=loop.index0) -%}
							{%- if not loop.first %}
								+
							{%- endif -%}
							
							{%- if loop.index0 == axis %}
								(d_{{ loop.index0 }} - {{ split_output }}u) * {{ chunk }}u
							{% else %}
								d_{{ loop.index0 }} * {{ chunk }}u
							{%- endif -%}
						{%- endfor -%}
					;
					output_{{ loop.index0 }}.data[index] = input_0.data[gidx];
				}
			{% endif %}
		{% endfor %}
	}
}