Constant MATMUL_2D_SHADERS_PATH
/// WGSL source text for a tiled 2-D matrix-multiply compute shader.
///
/// Despite the `_PATH` suffix, the value shown here is the shader source
/// itself, not a filesystem path.
///
/// Bindings declared by the shader:
/// - `@group(0) @binding(0)` `heap`: a read/write storage `array<f32>` that
///   holds both inputs and the output.
/// - `@group(0) @binding(1)` `execute_args`: a uniform array of three
///   `ArrayMetadata` records. Entry 0 is read by `indexing_a`, entry 1 by
///   `indexing_b`, and entry 2 by `indexing_out`.
///
/// Pipeline-overridable constants:
/// - `LEN_HEAP_MINUS_ONE: u32` is declared without a default value.
/// - `WORKGROUP_SIZE: u32` defaults to 16 and is used for both the x and y
///   workgroup dimensions.
///
/// What `main` does:
/// - Reads `m` and `k` from `execute_args[0].shape[0]` and `n` from
///   `execute_args[1].shape[0][1]`.
/// - Iterates `ceil(k / WORKGROUP_SIZE)` times. On each iteration every
///   invocation stores one element of A into `tile_a` and one element of B
///   into `tile_b`, substituting `0.` when the row/column is outside the
///   matrix.
/// - Calls `workgroupBarrier()`, accumulates the products along the tile,
///   then calls `workgroupBarrier()` again before the next iteration.
/// - After the loop, invocations with `global_id.x >= m` or
///   `global_id.y >= n` return; the rest write `acc` to
///   `heap[indexing_out(global_id.x, global_id.y)]`.
///
/// Element addressing: each `indexing_*` helper computes
/// `pointer.x + offset + row * stride[0][0] + col * stride[0][1]`.
/// `indexing_a` and `indexing_b` return 0 instead of the computed index
/// unless it is strictly less than `LEN_HEAP_MINUS_ONE`; `indexing_out`
/// returns the computed index unchanged.
///
/// NOTE(review): the strict `<` means an index equal to `LEN_HEAP_MINUS_ONE`
/// is also redirected to 0. If the override really is "heap length minus
/// one", the last heap element can never be read as an input -- confirm
/// against the host code that sets the override.
///
/// NOTE(review): `tile_a` / `tile_b` are declared with a literal size of 16.
/// The embedded comment says they are adjusted to `WORKGROUP_SIZE` before
/// compilation; that rewriting is not visible here, so overriding
/// `WORKGROUP_SIZE` without it would presumably desynchronise the tile size
/// -- verify at the call site.
///
/// The `group_id` builtin parameter of `main` is declared but not used.
Source pub const MATMUL_2D_SHADERS_PATH: &'static str = "// cache\n// will be adjusted to WORKGROUP_SIZE before compilation\nvar<workgroup> tile_a: array<array<f32, 16>, 16>;\nvar<workgroup> tile_b: array<array<f32, 16>, 16>;\n\n// override\noverride LEN_HEAP_MINUS_ONE: u32;\noverride WORKGROUP_SIZE: u32 = 16;\n// INIT\nstruct ArrayMetadata{\n\tpointer: vec2<u32>, // 8\n\tlen: u32, // 4\n\toffset: u32, // 4\n\tdim: u32, // 4\n\tpadding0: u32, // 4\n\tpadding1: u32, // 4\n\tpadding2: u32, // 4\n\tshape: array<vec4<u32>, 2>, // 32\n\tstride: array<vec4<u32>, 2>, // 32\n\to_stride: array<vec4<u32>, 2>, // 32\n\tm_n_shape: array<vec4<u32>, 2>, // 32\n\tm_n_origin_stride: array<vec4<u32>, 2>, // 32\n\tpadding3: array<vec4<u32>, 4>,\n}\n// MODULE\n// // heap\n@group(0) @binding(0)\nvar<storage, read_write> heap: array<f32>;\n// // execute_args\n@group(0) @binding(1)\nvar<uniform> execute_args: array<ArrayMetadata, 3>;\n\n@compute @workgroup_size(WORKGROUP_SIZE, WORKGROUP_SIZE, 1)\nfn main(\n@builtin (local_invocation_id) local_id: vec3<u32>,\n@builtin (global_invocation_id) global_id: vec3<u32>,\n@builtin (workgroup_id) group_id: vec3<u32>,\n){\n let m = execute_args[0].shape[0][0];\n let k = execute_args[0].shape[0][1];\n let n = execute_args[1].shape[0][1];\n\n let iter = (k + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE;\n var acc = 0.;\n for (var i = 0u; i < iter; i++){\n // A\n let row_a = global_id.x;\n let col_a = local_id.y + WORKGROUP_SIZE * i;\n\n var a_value = select(0., heap[indexing_a(row_a, col_a)], row_a < m && col_a < k);\n tile_a[local_id.x][local_id.y] = a_value;\n\n // B\n let row_b = local_id.x + WORKGROUP_SIZE * i;\n let col_b = global_id.y;\n\n var b_value = select(0., heap[indexing_b(row_b, col_b)], row_b < k && col_b < n);\n tile_b[local_id.x][local_id.y] = b_value;\n\n workgroupBarrier();\n\n for (var ii = 0u; ii < WORKGROUP_SIZE; ii++){\n acc += tile_a[local_id.x][ii] * tile_b[ii][local_id.y];\n }\n\n workgroupBarrier();\n }\n\n if global_id.x >= m || 
global_id.y >= n{\n return;\n }\n\n heap[indexing_out(global_id.x, global_id.y)] = acc;\n}\n\nfn indexing_a(row:u32, col:u32)-> u32{\n let arr = execute_args[0];\n let stride = arr.stride[0];\n let index = arr.pointer.x + arr.offset + row * stride[0] + col * stride[1];\n return select(0, index, index < LEN_HEAP_MINUS_ONE);\n}\n\nfn indexing_b(row:u32, col:u32)-> u32{\n let arr = execute_args[1];\n let stride = arr.stride[0];\n let index = arr.pointer.x + arr.offset + row * stride[0] + col * stride[1];\n return select(0, index, index < LEN_HEAP_MINUS_ONE);\n}\n\nfn indexing_out(row: u32, col:u32)-> u32{\n let arr = execute_args[2];\n let stride = arr.stride[0];\n let index = arr.pointer.x + arr.offset + row * stride[0] + col * stride[1];\n return index;\n}\n";