pub(super) const COSINE_SHADER: &str = r"
struct Params {
dimension: u32,
num_vectors: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> vectors: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn batch_cosine(@builtin(global_invocation_id) id: vec3<u32>) {
let idx = id.x;
if (idx >= params.num_vectors) {
return;
}
let dim = params.dimension;
let offset = idx * dim;
var dot: f32 = 0.0;
var norm_q: f32 = 0.0;
var norm_v: f32 = 0.0;
for (var i: u32 = 0u; i < dim; i = i + 1u) {
let q = query[i];
let v = vectors[offset + i];
dot = dot + q * v;
norm_q = norm_q + q * q;
norm_v = norm_v + v * v;
}
let denom = sqrt(norm_q) * sqrt(norm_v);
if (denom > 0.0) {
results[idx] = dot / denom;
} else {
results[idx] = 0.0;
}
}
";
pub(super) const EUCLIDEAN_SHADER: &str = r"
struct Params {
dimension: u32,
num_vectors: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> vectors: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn batch_euclidean(@builtin(global_invocation_id) id: vec3<u32>) {
let idx = id.x;
if (idx >= params.num_vectors) {
return;
}
let dim = params.dimension;
let offset = idx * dim;
var sum_sq: f32 = 0.0;
for (var i: u32 = 0u; i < dim; i = i + 1u) {
let diff = query[i] - vectors[offset + i];
sum_sq = sum_sq + diff * diff;
}
results[idx] = sqrt(sum_sq);
}
";
pub(super) const PQ_KMEANS_ASSIGN_SHADER: &str = r"
struct Params {
num_vectors: u32,
num_centroids: u32,
subspace_dim: u32,
_padding: u32,
}
@group(0) @binding(0) var<storage, read> vectors: array<f32>;
@group(0) @binding(1) var<storage, read> centroids: array<f32>;
@group(0) @binding(2) var<storage, read_write> assignments: array<u32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn kmeans_assign(@builtin(global_invocation_id) id: vec3<u32>) {
let idx = id.x;
if (idx >= params.num_vectors) { return; }
let sd = params.subspace_dim;
let k = params.num_centroids;
let vec_offset = idx * sd;
var best_dist: f32 = 3.4028235e+38;
var best_idx: u32 = 0u;
for (var c: u32 = 0u; c < k; c = c + 1u) {
let cent_offset = c * sd;
var dist: f32 = 0.0;
for (var d: u32 = 0u; d < sd; d = d + 1u) {
let diff = vectors[vec_offset + d] - centroids[cent_offset + d];
dist = dist + diff * diff;
}
if (dist < best_dist) {
best_dist = dist;
best_idx = c;
}
}
assignments[idx] = best_idx;
}
";
pub(super) const DOT_PRODUCT_SHADER: &str = r"
struct Params {
dimension: u32,
num_vectors: u32,
}
@group(0) @binding(0) var<storage, read> query: array<f32>;
@group(0) @binding(1) var<storage, read> vectors: array<f32>;
@group(0) @binding(2) var<storage, read_write> results: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@compute @workgroup_size(256)
fn batch_dot(@builtin(global_invocation_id) id: vec3<u32>) {
let idx = id.x;
if (idx >= params.num_vectors) {
return;
}
let dim = params.dimension;
let offset = idx * dim;
var dot: f32 = 0.0;
for (var i: u32 = 0u; i < dim; i = i + 1u) {
dot = dot + query[i] * vectors[offset + i];
}
results[idx] = dot;
}
";