Skip to main content

MATMUL_2D_SHADERS_PATH

Constant MATMUL_2D_SHADERS_PATH 

Source
/// WGSL source of a tiled 2-D matrix-multiplication compute shader.
///
/// # Shader interface
///
/// * `@group(0) @binding(0)` `heap`: read-write storage `array<f32>` that holds
///   both input matrices and the output.
/// * `@group(0) @binding(1)` `execute_args`: uniform `array<ArrayMetadata, 3>`.
///   Entry 0 is used to address operand A, entry 1 operand B, and entry 2 the
///   output (see `indexing_a`, `indexing_b`, `indexing_out` in the shader).
/// * `override LEN_HEAP_MINUS_ONE: u32` has no default in the shader text.
/// * `override WORKGROUP_SIZE: u32 = 16` sets the workgroup size to
///   `WORKGROUP_SIZE x WORKGROUP_SIZE x 1`.
///
/// # What the shader does
///
/// * `m` and `k` are read from `execute_args[0].shape[0][0]` and `[0][1]`;
///   `n` is read from `execute_args[1].shape[0][1]`.
/// * The `k` dimension is walked in `ceil(k / WORKGROUP_SIZE)` steps. In each
///   step every invocation stores one element of A into `tile_a` and one
///   element of B into `tile_b` (workgroup memory), a `workgroupBarrier()` is
///   issued, the per-tile partial dot product is added to `acc`, and a second
///   barrier is issued before the tiles are overwritten.
/// * Elements whose row/column fall outside `m x k` (A) or `k x n` (B) are
///   replaced by `0.` so that partial tiles do not affect the sum.
/// * Invocations with `global_id.x >= m` or `global_id.y >= n` return without
///   writing; all others store `acc` at `indexing_out(global_id.x, global_id.y)`.
/// * Element addresses are computed as
///   `pointer.x + offset + row * stride[0][0] + col * stride[0][1]`; only the
///   `x` component of `pointer` is used.
///
/// # Caveats
///
/// * `tile_a` / `tile_b` are declared as fixed 16x16 arrays. The shader's own
///   comment says they "will be adjusted to WORKGROUP_SIZE before
///   compilation", so this text is presumably patched by the caller before use
///   when `WORKGROUP_SIZE` is overridden — TODO confirm against the call site.
/// * NOTE(review): `indexing_a` / `indexing_b` redirect any index that is not
///   strictly below `LEN_HEAP_MINUS_ONE` to index 0. An index *equal* to
///   `LEN_HEAP_MINUS_ONE` is therefore also redirected; if the override holds
///   the last valid heap index this looks like an off-by-one (`<` vs `<=`) —
///   confirm against the host code that sets the override.
/// * NOTE(review): `indexing_out` applies no such bound on the write index.
/// * NOTE(review): despite the `_PATH` suffix, the value is the shader source
///   text itself, not a filesystem path — presumably produced by embedding a
///   file at build time; verify.
pub const MATMUL_2D_SHADERS_PATH: &'static str = "// cache\n// will be adjusted to WORKGROUP_SIZE before compilation\nvar<workgroup> tile_a: array<array<f32, 16>, 16>;\nvar<workgroup> tile_b: array<array<f32, 16>, 16>;\n\n// override\noverride LEN_HEAP_MINUS_ONE: u32;\noverride WORKGROUP_SIZE: u32 = 16;\n// INIT\nstruct ArrayMetadata{\n\tpointer: vec2<u32>, // 8\n\tlen: u32, // 4\n\toffset: u32, // 4\n\tdim: u32, // 4\n\tpadding0: u32, // 4\n\tpadding1: u32, // 4\n\tpadding2: u32, // 4\n\tshape: array<vec4<u32>, 2>, // 32\n\tstride: array<vec4<u32>, 2>, // 32\n\to_stride: array<vec4<u32>, 2>, // 32\n\tm_n_shape: array<vec4<u32>, 2>, // 32\n\tm_n_origin_stride: array<vec4<u32>, 2>, // 32\n\tpadding3: array<vec4<u32>, 4>,\n}\n// MODULE\n// // heap\n@group(0) @binding(0)\nvar<storage, read_write> heap: array<f32>;\n// // execute_args\n@group(0) @binding(1)\nvar<uniform> execute_args: array<ArrayMetadata, 3>;\n\n@compute @workgroup_size(WORKGROUP_SIZE, WORKGROUP_SIZE, 1)\nfn main(\n@builtin (local_invocation_id) local_id: vec3<u32>,\n@builtin (global_invocation_id) global_id: vec3<u32>,\n@builtin (workgroup_id) group_id: vec3<u32>,\n){\n    let m = execute_args[0].shape[0][0];\n    let k = execute_args[0].shape[0][1];\n    let n = execute_args[1].shape[0][1];\n\n    let iter = (k + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;\n    var acc = 0.;\n    for (var i = 0u; i < iter; i++){\n        // A\n        let row_a = global_id.x;\n        let col_a = local_id.y + WORKGROUP_SIZE * i;\n\n        var a_value = select(0., heap[indexing_a(row_a, col_a)], row_a < m && col_a < k);\n        tile_a[local_id.x][local_id.y] = a_value;\n\n        // B\n        let row_b = local_id.x + WORKGROUP_SIZE * i;\n        let col_b = global_id.y;\n\n        var b_value = select(0., heap[indexing_b(row_b, col_b)], row_b < k && col_b < n);\n        tile_b[local_id.x][local_id.y] = b_value;\n\n        workgroupBarrier();\n\n        for (var ii = 0u; ii < WORKGROUP_SIZE; ii++){\n            acc += 
tile_a[local_id.x][ii] * tile_b[ii][local_id.y];\n        }\n\n        workgroupBarrier();\n    }\n\n    if global_id.x >= m || global_id.y >= n{\n        return;\n    }\n\n    heap[indexing_out(global_id.x, global_id.y)] = acc;\n}\n\nfn indexing_a(row:u32, col:u32)-> u32{\n    let arr = execute_args[0];\n    let stride = arr.stride[0];\n    let index = arr.pointer.x + arr.offset + row * stride[0] + col * stride[1];\n    return select(0, index, index < LEN_HEAP_MINUS_ONE);\n}\n\nfn indexing_b(row:u32, col:u32)-> u32{\n    let arr = execute_args[1];\n    let stride = arr.stride[0];\n    let index = arr.pointer.x + arr.offset + row * stride[0] + col * stride[1];\n    return select(0, index, index < LEN_HEAP_MINUS_ONE);\n}\n\nfn indexing_out(row: u32, col:u32)-> u32{\n    let arr = execute_args[2];\n    let stride = arr.stride[0];\n    let index = arr.pointer.x + arr.offset + row * stride[0] + col * stride[1];\n    return index;\n}\n";