flashlight_tensor 0.4.5

gpu/cpu tensor library focused around matrix and neural network operations
Documentation
@group(0) @binding(0)
var<storage, read> input: array<f32>; //input_cache, grad_output

@group(0) @binding(1)
var<storage, read> shapes: array<u32>;

@group(0) @binding(3)
var<storage, read_write> output: array<f32>;

fn idx_to_global(idx: u32, shape: array<u32, 6>, rank: u32) -> array<u32, 6> {
    var used = idx;
    var out: array<u32, 6>;
    var shape_prod = 1u;

    for (var i = 0u; i < rank; i++) {
        shape_prod *= shape[i];
    }

    for (var i = 0u; i < rank; i++) {
        shape_prod = shape_prod / shape[i];
        out[i] = used / shape_prod;
        used = used % shape_prod;
    }

    return out;
}

fn global_to_idx(pos: array<u32, 6>, shape: array<u32, 6>, rank: u32) -> u32 {
    var idx = 0u;
    var stride = 1u;
    for (var i = rank; i > 0u; i--) {
        idx += pos[i - 1u] * stride;
        stride *= shape[i - 1u];
    }
    return idx;
}

fn relu_der(x: f32) -> f32 {
    if (x > 0.0) {
        return 1.0;
    }
    return 0.0;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>){
	
	let idx = global_id.y * 65535u + global_id.x;
	if (idx >= arrayLength(&output)) {
		return;
	}

	let grad_offset = shapes[0] * shapes[1];

	let sample_size = grad_offset + shapes[2] * shapes[3];
	let sample_count: u32 = arrayLength(&input) / sample_size;

	let sample_idx = idx / sample_size;
    let inner_idx = idx % sample_size;

	var input_shape: array<u32, 6>;
	var grad_shape: array<u32, 6>;
	var output_shape: array<u32, 6>;

	for (var i=0u; i<2; i++){
		input_shape[i] = shapes[i];
		grad_shape[i] = shapes[i+2];
		output_shape[i] = shapes[i+4];
	}

	let output_pos = idx_to_global(idx, output_shape, 2);

	var input_pos: array<u32, 6>;
    var grad_pos: array<u32, 6>;

    for (var i = 0u; i < 2; i++) {
		if (input_shape[i] == 1u){
			input_pos[i] = 0u;
		}
		else{
			input_pos[i] = output_pos[i];
		}

		if (grad_shape[i] == 1u){
			grad_pos[i] = 0u;
		}
		else{
			grad_pos[i] = output_pos[i];
		}
    }

	let input_idx = sample_idx * sample_size + global_to_idx(input_pos, input_shape, 2);
	let grad_idx = sample_idx * sample_size + grad_offset + global_to_idx(grad_pos, grad_shape, 2);

	output[idx] = relu_der(input[input_idx]) * input[grad_idx];
}