#version 450
// SGEMM-lite compute shader: out[i, j] = sum_d samples[i, d] * centroids[j, d].
// samples is row-major (n_samples, dimension); centroids is row-major
// (n_list, dimension); output is row-major (n_samples, n_list).
// 2D grid — one thread per output cell. Used by the Intel IVF index
// (k-means assignment) mirroring src/shaders/metal_hnsw.metal::sgemm_dot.
// Workgroup shape: a 16 x 16 x 1 tile of invocations, i.e. 256 output cells
// per workgroup.
layout(local_size_x = 16,
       local_size_y = 16,
       local_size_z = 1) in;

// Input A: sample vectors, indexed as samples[i * dimension + d].
layout(binding = 0, std430) readonly buffer Samples {
    float samples[];
};

// Input B: centroid vectors, indexed as centroids[j * dimension + d].
layout(binding = 1, std430) readonly buffer Centroids {
    float centroids[];
};

// Output: dot products, indexed as out_[i * n_list + j]. Write-only.
layout(binding = 2, std430) writeonly buffer Out {
    float out_[];
};

// Problem sizes supplied per dispatch. Member order is part of the contract
// with the host code that fills the push-constant range.
layout(push_constant) uniform PushConstants {
    uint dimension;  // length of each sample / centroid vector (row stride of both inputs)
    uint n_list;     // number of centroids (row stride of the output)
    uint n_samples;  // number of samples (row count of the output)
} pc;
// Entry point: each invocation computes one dot product between a sample row
// and a centroid row and stores it in the output matrix.
//   gl_GlobalInvocationID.x -> sample index   (row of the output)
//   gl_GlobalInvocationID.y -> centroid index (column of the output)
// NOTE(review): neighbouring invocations along x write addresses n_list
// floats apart; whether that matters for performance depends on the target
// hardware — confirm with profiling before changing the mapping, since the
// host dispatch dimensions depend on it.
void main() {
    uvec2 cell = gl_GlobalInvocationID.xy;

    // Invocations that fall outside the (n_samples, n_list) output do nothing;
    // this lets the dispatch be rounded up to whole workgroups.
    bool inRange = (cell.x < pc.n_samples) && (cell.y < pc.n_list);
    if (!inRange) {
        return;
    }

    // Offsets of the first element of the selected sample / centroid rows.
    uint sampleRow = cell.x * pc.dimension;
    uint centroidRow = cell.y * pc.dimension;

    // Accumulate the products in ascending component order.
    float acc = 0.0;
    uint k = 0u;
    while (k < pc.dimension) {
        acc += samples[sampleRow + k] * centroids[centroidRow + k];
        k++;
    }

    out_[cell.x * pc.n_list + cell.y] = acc;
}