#version 450
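// Matrix transpose compute shader (Vulkan-style descriptor bindings).
// Computes: result[j * rows + i] = input_data[i * cols + j]
//   for 0 <= i < rows, 0 <= j < cols (input is rows x cols, output is cols x rows).
// Stages each 16x16 tile in shared memory so that both the global reads and the
// global writes touch consecutive addresses across adjacent invocations.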
// 16x16 invocations per work group; each work group transposes one 16x16 tile.
// The tile must stay square (main() reads it back with swapped indices), and
// the `tile` array below must be at least 16 x 17 to match.
layout(local_size_x = 16, local_size_y = 16, local_size_z = 1) in;

// Source matrix: rows x cols, row-major (element (i, j) at i * cols + j).
// Only read by this shader. std430 makes the tight 4-byte float stride explicit
// instead of relying on the target's default block packing.
layout(std430, set = 0, binding = 0) readonly buffer InputBuffer {
    float input_data[];
};

// Destination matrix: cols x rows, row-major (element (j, i) at j * rows + i).
// Only written by this shader.
layout(std430, set = 0, binding = 1) writeonly buffer OutputBuffer {
    float result[];
};

// Dimensions of the *input* matrix; the output has them swapped.
layout(std140, set = 0, binding = 2) uniform UniformBuffer {
    uint rows; // Number of rows in input matrix
    uint cols; // Number of columns in input matrix
};

// Work-group-local staging tile for the transpose.
// The extra column (17 instead of 16) shifts each row by one element so the
// column-wise reads in main()'s store phase do not hit the same shared-memory
// bank for every invocation (avoids bank conflicts).
shared float tile[16][17];
// Transposes one square tile of the input matrix per work group.
//
// Phase 1 loads input tile (gl_WorkGroupID.y, gl_WorkGroupID.x) into `tile`
// with row-contiguous (coalesced) reads. Phase 2 writes it to output tile
// (gl_WorkGroupID.x, gl_WorkGroupID.y), reading `tile` with swapped indices so
// that the global writes are row-contiguous as well.
//
// NOTE(review): assumes the host dispatches ceil(cols / 16) x ceil(rows / 16)
// work groups (x spans input columns, y spans input rows) — confirm at the
// dispatch site.
void main() {
    // Square tile edge, taken from the declared local size instead of repeating
    // the literal 16, so the index math cannot drift from the layout(...) line.
    // `tile` must be sized at least TILE x (TILE + 1).
    const uint TILE = gl_WorkGroupSize.x;

    uint x = gl_GlobalInvocationID.x; // input column
    uint y = gl_GlobalInvocationID.y; // input row
    uint local_x = gl_LocalInvocationID.x;
    uint local_y = gl_LocalInvocationID.y;

    // Phase 1: load. Consecutive local_x read consecutive elements of one
    // input row.
    if (x < cols && y < rows) {
        tile[local_y][local_x] = input_data[y * cols + x];
    } else {
        // Keep every tile slot initialized. Out-of-range slots are never copied
        // out because of the bounds check in phase 2.
        tile[local_y][local_x] = 0.0;
    }

    // Every invocation reaches this barrier (uniform control flow). It ensures
    // the whole tile is written before any invocation reads an element that a
    // different invocation stored.
    barrier();

    // Phase 2: swap the work-group coordinates to address the mirrored output
    // tile.
    uint trans_x = gl_WorkGroupID.y * TILE + local_x; // output column == input row
    uint trans_y = gl_WorkGroupID.x * TILE + local_y; // output row == input column

    // The output is cols x rows, so its row stride is `rows`. Consecutive
    // local_x again map to consecutive addresses.
    if (trans_x < rows && trans_y < cols) {
        // result[trans_y][trans_x] = input[trans_x][trans_y] = tile[local_x][local_y]
        result[trans_y * rows + trans_x] = tile[local_x][local_y];
    }
}