#![allow(missing_docs)]
use cubecl::prelude::*;
use std::iter::Sum;
use crate::gpu::tensor::*;
use crate::gpu::*;
use crate::utils::dist::Dist;
/// Borrowed view over a batch of `n` row-major flattened vectors plus their
/// precomputed per-vector norms. `norm` is only read on the cosine path
/// (see `query_batch_gpu`); callers pass an empty slice otherwise.
pub struct BatchData<'a, T> {
// Row-major vector data, `n * dim` scalars.
pub data: &'a [T],
// Per-vector L2 norms (length `n` when used; may be empty for Euclidean).
pub norm: &'a [T],
// Number of vectors in the batch.
pub n: usize,
}
impl<'a, T> BatchData<'a, T> {
pub fn new(data: &'a [T], norm: &'a [T], n: usize) -> Self {
Self { data, norm, n }
}
}
/// Tiled squared-Euclidean distance kernel.
///
/// Grid layout: X indexes DB vectors within the current chunk; the (Z, Y)
/// grid dimensions folded with `UNIT_POS_Y` index queries. Each workgroup
/// first cooperatively stages its tile of query vectors into shared memory as
/// plain scalars (4 lanes per `Line` — `dim_scalars = dim_lines * 4`, so the
/// line width is assumed to be 4), then every in-range thread computes the
/// squared L2 distance of one (query, db) pair and stores it at
/// `distances[query_idx * dist_stride + db_idx]`.
#[cube(launch_unchecked)]
pub fn euclidean_tiled<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
distances: &mut Tensor<F>,
// Global index of this chunk's first DB vector (offsets DB reads only).
db_start: u32,
// Number of DB vectors in this chunk.
n_db_chunk: u32,
// Number of query vectors in this batch.
n_queries: u32,
// Row stride of `distances` (the maximum chunk width).
dist_stride: u32,
// Vector dimension expressed in `Line`s, known at compile time.
#[comptime] dim_lines: usize,
) {
let db_idx = ABSOLUTE_POS_X as usize;
// Flatten the (Z, Y) cube position plus the local row into a query index.
let query_idx =
((CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y) as usize;
let local_y = UNIT_POS_Y as usize;
let local_x = UNIT_POS_X as usize;
let dim_scalars = dim_lines * 4usize;
let wg_y = WORKGROUP_SIZE_Y as usize;
// One de-vectorized query row per workgroup Y-row.
let mut s_query = SharedMemory::<F>::new(wg_y * dim_scalars);
let thread_id = local_y * WORKGROUP_SIZE_X as usize + local_x;
let total_threads = WORKGROUP_SIZE_X as usize * wg_y;
let total_elems = wg_y * dim_scalars;
// First query handled by this workgroup.
let q_base = query_idx - local_y;
// Cooperative strided load: every thread participates, including threads
// whose own (query, db) pair is out of range.
let mut load_idx = thread_id;
while load_idx < total_elems {
let q_local = load_idx / dim_scalars;
let elem = load_idx % dim_scalars;
let q_global = q_base + q_local;
if q_global < n_queries as usize {
let line_idx = elem / 4usize;
let lane = elem % 4usize;
let line_val = query_vectors[q_global * dim_lines + line_idx];
s_query[load_idx] = line_val[lane];
} else {
// Zero-fill slots for out-of-range queries so the tile is fully defined.
s_query[load_idx] = F::new(0.0);
}
load_idx += total_threads;
}
sync_cube();
// Bounds check only after the cooperative load + barrier, so no thread
// exits before the shared-memory phase completes.
if query_idx >= n_queries as usize || db_idx >= n_db_chunk as usize {
terminate!();
}
let global_db_idx = db_start as usize + db_idx;
let q_shared_base = local_y * dim_scalars;
let mut sum = F::new(0.0);
// Accumulate squared differences lane by lane against the cached query row.
for i in 0..dim_lines {
let d_line = db_vectors[global_db_idx * dim_lines + i];
let s_off = q_shared_base + i * 4usize;
let diff0 = s_query[s_off] - d_line[0];
let diff1 = s_query[s_off + 1usize] - d_line[1];
let diff2 = s_query[s_off + 2usize] - d_line[2];
let diff3 = s_query[s_off + 3usize] - d_line[3];
sum += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
}
distances[query_idx * dist_stride as usize + db_idx] = sum;
}
/// Tiled cosine-distance kernel (`1 - dot / (|q| * |d|)`), structured exactly
/// like `euclidean_tiled`: queries are staged into shared memory as scalars
/// (4 lanes per `Line`), then each thread computes the dot product of one
/// (query, db) pair and normalizes by the precomputed norms.
#[cube(launch_unchecked)]
pub fn cosine_tiled<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
// Precomputed L2 norms, indexed by chunk-local query index.
query_norms: &Tensor<F>,
// Precomputed L2 norms, indexed by global DB index.
db_norms: &Tensor<F>,
distances: &mut Tensor<F>,
// Global index of this chunk's first DB vector.
db_start: u32,
// Number of DB vectors in this chunk.
n_db_chunk: u32,
// Number of query vectors in this batch.
n_queries: u32,
// Row stride of `distances` (the maximum chunk width).
dist_stride: u32,
// Vector dimension in `Line`s, known at compile time.
#[comptime] dim_lines: usize,
) {
let db_idx = ABSOLUTE_POS_X as usize;
// Flatten the (Z, Y) cube position plus the local row into a query index.
let query_idx =
((CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y) as usize;
let local_y = UNIT_POS_Y as usize;
let local_x = UNIT_POS_X as usize;
let dim_scalars = dim_lines * 4usize;
let wg_y = WORKGROUP_SIZE_Y as usize;
let mut s_query = SharedMemory::<F>::new(wg_y * dim_scalars);
let thread_id = local_y * WORKGROUP_SIZE_X as usize + local_x;
let total_threads = WORKGROUP_SIZE_X as usize * wg_y;
let total_elems = wg_y * dim_scalars;
let q_base = query_idx - local_y;
// Cooperative strided load of the workgroup's query tile (all threads help).
let mut load_idx = thread_id;
while load_idx < total_elems {
let q_local = load_idx / dim_scalars;
let elem = load_idx % dim_scalars;
let q_global = q_base + q_local;
if q_global < n_queries as usize {
let line_idx = elem / 4usize;
let lane = elem % 4usize;
let line_val = query_vectors[q_global * dim_lines + line_idx];
s_query[load_idx] = line_val[lane];
} else {
s_query[load_idx] = F::new(0.0);
}
load_idx += total_threads;
}
sync_cube();
// Exit out-of-range threads only after the barrier.
if query_idx >= n_queries as usize || db_idx >= n_db_chunk as usize {
terminate!();
}
let global_db_idx = db_start as usize + db_idx;
let q_shared_base = local_y * dim_scalars;
let mut dot = F::new(0.0);
for i in 0..dim_lines {
let d_line = db_vectors[global_db_idx * dim_lines + i];
let s_off = q_shared_base + i * 4usize;
dot += s_query[s_off] * d_line[0]
+ s_query[s_off + 1usize] * d_line[1]
+ s_query[s_off + 2usize] * d_line[2]
+ s_query[s_off + 3usize] * d_line[3];
}
let q_norm = query_norms[query_idx];
let d_norm = db_norms[global_db_idx];
distances[query_idx * dist_stride as usize + db_idx] = F::new(1.0) - (dot / (q_norm * d_norm));
}
#[allow(dead_code)]
/// Returns whether the coalesced top-k path should be used on this runtime:
/// enabled everywhere except backends whose name reports "metal".
fn prefer_coalesced_topk<R: Runtime>(client: &ComputeClient<R>) -> bool {
    let backend = R::name(client).to_lowercase();
    !backend.contains("metal")
}
/// Seeds the per-query top-k buffers: every distance slot becomes the
/// `f32::MAX` sentinel and every index slot 0, so later merge kernels can
/// treat the buffers as an already-sorted (all-worst) running top-k list.
#[cube(launch_unchecked)]
pub fn init_topk<F: Float>(dists: &mut Tensor<F>, indices: &mut Tensor<u32>) {
// Queries come from the (Z, Y) grid; the k slots come from X.
let query_idx =
((CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y) as usize;
let k_idx = ABSOLUTE_POS_X as usize;
let k = dists.shape(1);
if query_idx >= dists.shape(0) || k_idx >= k {
terminate!();
}
let offset = query_idx * dists.stride(0) + k_idx;
dists[offset] = F::new(f32::MAX);
indices[offset] = 0u32;
}
/// Sequential per-query top-k merge: one thread scans its query's full row of
/// chunk distances and insertion-sorts better candidates into the running,
/// ascending `out_dists`/`out_indices` lists (assumed seeded by `init_topk`).
/// Stored indices are global: `chunk_offset + column`.
#[cube(launch_unchecked)]
pub fn extract_topk<F: Float>(
distances: &Tensor<F>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
// Global index of this chunk's first DB vector.
chunk_offset: u32,
// Number of valid columns in `distances` for this chunk.
actual_chunk_size: u32,
) {
// One thread per query, spread over the (Y, X) grid.
let query_idx =
((CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X) as usize;
if query_idx >= distances.shape(0) {
terminate!();
}
let k = out_dists.shape(1);
let dist_offset = query_idx * distances.stride(0);
let out_offset = query_idx * out_dists.stride(0);
for i in 0..actual_chunk_size {
let dist = distances[dist_offset + i as usize];
// Only candidates beating the current worst entry are inserted.
if dist < out_dists[out_offset + k - 1] {
// Latch-style scan: `insert_pos` captures only the FIRST slot whose
// value exceeds `dist` (guard `insert_pos == k - 1` keeps later
// iterations from overwriting it).
let mut insert_pos: usize = k - 1;
for j in 0..k {
if dist < out_dists[out_offset + j] && insert_pos == k - 1 {
insert_pos = j;
}
}
// Shift entries at/after the insert position down by one slot,
// walking from the tail to avoid clobbering.
for j in 0..k - 1 {
let src = k - 2 - j;
let dst = k - 1 - j;
if src >= insert_pos {
out_dists[out_offset + dst] = out_dists[out_offset + src];
out_indices[out_offset + dst] = out_indices[out_offset + src];
}
}
out_dists[out_offset + insert_pos] = dist;
out_indices[out_offset + insert_pos] = chunk_offset + i;
}
}
}
/// Cooperative (one-workgroup-per-query) top-k merge for a distance chunk.
///
/// Each thread scans a strided subset of its query's chunk columns and keeps
/// a private sorted top-k in registers (`local_*`, comptime-sized `k`). The
/// per-thread lists are then dumped into shared memory, after which thread 0
/// sequentially merges every other thread's list into its own segment
/// (segment 0), folds in the existing running top-k from `out_dists` /
/// `out_indices`, and writes the merged result back.
///
/// `k_param` is the runtime copy of the comptime `k`, used where a runtime
/// value is required. Shared memory is sized `32 * k` — NOTE(review): this
/// assumes WORKGROUP_SIZE_X <= 32; confirm at the launch site.
#[cube(launch_unchecked)]
pub fn extract_topk_coalesced<F: Float>(
distances: &Tensor<F>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
// Global index of this chunk's first DB vector.
chunk_offset: u32,
// Number of valid columns in `distances` for this chunk.
actual_chunk_size: u32,
// Row stride of `distances`.
dist_stride: u32,
// Runtime mirror of `k`.
k_param: u32,
#[comptime] k: usize,
) {
// One workgroup (cube) per query.
let query_idx = CUBE_POS_X as usize;
let tx = UNIT_POS_X as usize;
let wg = WORKGROUP_SIZE_X as usize;
let kr = k_param as usize;
if query_idx >= out_dists.shape(0) {
terminate!();
}
// Private sorted top-k, seeded with the all-worst sentinel.
let mut local_dists = Array::<F>::new(k);
let mut local_indices = Array::<u32>::new(k);
for i in 0..k {
local_dists[i] = F::new(f32::MAX);
local_indices[i] = 0u32;
}
let dist_base = query_idx * dist_stride as usize;
// Strided scan: thread tx handles columns tx, tx+wg, tx+2*wg, ...
let mut col = tx as u32;
while col < actual_chunk_size {
let dist = distances[dist_base + col as usize];
if dist < local_dists[kr - 1] {
// Latch the first slot whose value exceeds `dist`.
let mut pos = kr - 1;
for j in 0..k {
if dist < local_dists[j] && pos == kr - 1 {
pos = j;
}
}
// Shift tail entries down one slot, then insert.
let mut s = kr - 1;
while s > pos {
local_dists[s] = local_dists[s - 1];
local_indices[s] = local_indices[s - 1];
s -= 1usize;
}
local_dists[pos] = dist;
local_indices[pos] = chunk_offset + col;
}
col += wg as u32;
}
// Publish each thread's private list to its shared-memory segment.
let mut s_dist = SharedMemory::<F>::new(32 * k);
let mut s_idx = SharedMemory::<u32>::new(32 * k);
let s_base = tx * kr;
for i in 0..k {
s_dist[s_base + i] = local_dists[i];
s_idx[s_base + i] = local_indices[i];
}
sync_cube();
// Thread 0 merges all other segments into segment 0.
if tx == 0usize {
let mut t = 1usize;
while t < wg {
let t_base = t * kr;
let mut done: u32 = 0u32;
for i in 0..k {
if done == 0u32 {
let cd = s_dist[t_base + i];
let ci = s_idx[t_base + i];
// Each segment is sorted ascending, so once an entry fails to
// beat the current worst, the rest of the segment cannot either.
if cd >= s_dist[kr - 1] {
done = 1u32;
} else {
let mut pos = kr - 1;
for j in 0..k {
if cd < s_dist[j] && pos == kr - 1 {
pos = j;
}
}
let mut s_i = kr - 1;
while s_i > pos {
s_dist[s_i] = s_dist[s_i - 1];
s_idx[s_i] = s_idx[s_i - 1];
s_i -= 1usize;
}
s_dist[pos] = cd;
s_idx[pos] = ci;
}
}
}
t += 1usize;
}
// Fold the query's existing running top-k (from previous chunks) into
// the merged list before overwriting it.
let out_base = query_idx * kr;
for i in 0..k {
let running_dist = out_dists[out_base + i];
let running_idx = out_indices[out_base + i];
if running_dist < s_dist[kr - 1] {
let mut pos = kr - 1;
for j in 0..k {
if running_dist < s_dist[j] && pos == kr - 1 {
pos = j;
}
}
let mut s_i = kr - 1;
while s_i > pos {
s_dist[s_i] = s_dist[s_i - 1];
s_idx[s_i] = s_idx[s_i - 1];
s_i -= 1usize;
}
s_dist[pos] = running_dist;
s_idx[pos] = running_idx;
}
}
// Write the final merged top-k back to global memory.
for i in 0..k {
out_dists[out_base + i] = s_dist[i];
out_indices[out_base + i] = s_idx[i];
}
}
}
/// Host-side brute-force k-nearest-neighbour driver.
///
/// Uploads the whole database once, then processes queries in chunks of
/// QUERY_CHUNK_SIZE. For each query chunk it seeds the top-k buffers via
/// `init_topk`, sweeps the database in chunks of DB_CHUNK_SIZE (tiled
/// distance kernel followed by `extract_topk` to fold each chunk into the
/// running per-query top-k), then reads the results back to the host.
///
/// Returns `(indices, distances)`: one ascending-distance, length-`k` list
/// per query. `dim` is assumed to be a multiple of LINE_SIZE (see
/// `dim_lines = dim / vec_size`) — TODO confirm callers guarantee this.
/// `query_data.norm` / `db_data.norm` are only read for `Dist::Cosine`.
pub fn query_batch_gpu<T, R>(
k: usize,
query_data: &BatchData<T>,
db_data: &BatchData<T>,
dim: usize,
metric: &Dist,
device: R::Device,
verbose: bool,
) -> (Vec<Vec<usize>>, Vec<Vec<T>>)
where
R: Runtime,
T: Float + Sum + cubecl::CubeElement + num_traits::Float + num_traits::FromPrimitive,
{
let client = R::client(&device);
let vec_size = LINE_SIZE as usize;
let dim_lines = dim / vec_size;
let n_query_chunks = query_data.n.div_ceil(QUERY_CHUNK_SIZE);
let n_db_chunks = db_data.n.div_ceil(DB_CHUNK_SIZE);
// The full database (and its norms for cosine) stays on-device for the
// whole batch; only queries are streamed chunk by chunk.
let db_gpu = GpuTensor::<R, T>::from_slice(db_data.data, vec![db_data.n, dim], &client);
let db_norms_gpu = if *metric == Dist::Cosine {
Some(GpuTensor::<R, T>::from_slice(
db_data.norm,
vec![db_data.n],
&client,
))
} else {
None
};
let mut all_indices = Vec::with_capacity(query_data.n);
let mut all_distances = Vec::with_capacity(query_data.n);
// Width of the reusable distance matrix (last chunk may be narrower).
let max_db_chunk = DB_CHUNK_SIZE.min(db_data.n);
for query_chunk_idx in 0..n_query_chunks {
if verbose && query_chunk_idx % 10 == 0 {
println!(
"Processed {} query chunks out of {}",
query_chunk_idx, n_query_chunks
);
}
let query_start = query_chunk_idx * QUERY_CHUNK_SIZE;
let query_end = (query_start + QUERY_CHUNK_SIZE).min(query_data.n);
let n_q = query_end - query_start;
let query_gpu = GpuTensor::<R, T>::from_slice(
&query_data.data[query_start * dim..query_end * dim],
vec![n_q, dim],
&client,
);
let query_norms_gpu = if *metric == Dist::Cosine {
Some(GpuTensor::<R, T>::from_slice(
&query_data.norm[query_start..query_end],
vec![n_q],
&client,
))
} else {
None
};
// Running top-k buffers for this query chunk, seeded with MAX sentinels.
let topk_dists = GpuTensor::<R, T>::empty(vec![n_q, k], &client);
let topk_indices = GpuTensor::<R, u32>::empty(vec![n_q, k], &client);
let init_gx = (k as u32).div_ceil(WORKGROUP_SIZE_X);
let (init_gy, init_gz) = grid_2d((n_q as u32).div_ceil(WORKGROUP_SIZE_Y));
unsafe {
let _ = init_topk::launch_unchecked::<T, R>(
&client,
CubeCount::Static(init_gx, init_gy, init_gz),
CubeDim::new_2d(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y),
topk_dists.clone().into_tensor_arg(1),
topk_indices.clone().into_tensor_arg(1),
);
}
// Distance matrix reused across all DB chunks of this query chunk.
let distances_gpu = GpuTensor::<R, T>::empty(vec![n_q, max_db_chunk], &client);
for db_chunk_idx in 0..n_db_chunks {
let db_start = db_chunk_idx * DB_CHUNK_SIZE;
let db_end = (db_start + DB_CHUNK_SIZE).min(db_data.n);
let n_db = db_end - db_start;
let grid_x = (n_db as u32).div_ceil(WORKGROUP_SIZE_X);
let (grid_y, grid_z) = grid_2d((n_q as u32).div_ceil(WORKGROUP_SIZE_Y));
// Fill `distances_gpu` with this chunk's (query, db) distances.
match *metric {
Dist::Euclidean => unsafe {
let _ = euclidean_tiled::launch_unchecked::<T, R>(
&client,
CubeCount::Static(grid_x, grid_y, grid_z),
CubeDim::new_2d(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y),
query_gpu.clone().into_tensor_arg(vec_size),
db_gpu.clone().into_tensor_arg(vec_size),
distances_gpu.clone().into_tensor_arg(1),
ScalarArg {
elem: db_start as u32,
},
ScalarArg { elem: n_db as u32 },
ScalarArg { elem: n_q as u32 },
ScalarArg {
elem: max_db_chunk as u32,
},
dim_lines,
);
},
Dist::Cosine => unsafe {
let _ = cosine_tiled::launch_unchecked::<T, R>(
&client,
CubeCount::Static(grid_x, grid_y, grid_z),
CubeDim::new_2d(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y),
query_gpu.clone().into_tensor_arg(vec_size),
db_gpu.clone().into_tensor_arg(vec_size),
query_norms_gpu.as_ref().unwrap().clone().into_tensor_arg(1),
db_norms_gpu.as_ref().unwrap().clone().into_tensor_arg(1),
distances_gpu.clone().into_tensor_arg(1),
ScalarArg {
elem: db_start as u32,
},
ScalarArg { elem: n_db as u32 },
ScalarArg { elem: n_q as u32 },
ScalarArg {
elem: max_db_chunk as u32,
},
dim_lines,
);
},
}
// Fold this chunk's distances into the running top-k lists.
let (extract_grid_x, extract_grid_y) = grid_2d((n_q as u32).div_ceil(WORKGROUP_SIZE_X));
unsafe {
let _ = extract_topk::launch_unchecked::<T, R>(
&client,
CubeCount::Static(extract_grid_x, extract_grid_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
distances_gpu.clone().into_tensor_arg(1),
topk_dists.clone().into_tensor_arg(1),
topk_indices.clone().into_tensor_arg(1),
ScalarArg {
elem: db_start as u32,
},
ScalarArg { elem: n_db as u32 },
);
}
}
// Read this query chunk's results back and convert indices to usize.
let final_dists = topk_dists.read(&client);
let final_indices = topk_indices.read(&client);
for q in 0..n_q {
let start = q * k;
let end = start + k;
all_distances.push(final_dists[start..end].to_vec());
all_indices.push(
final_indices[start..end]
.iter()
.map(|&i| i as usize)
.collect(),
);
}
}
(all_indices, all_distances)
}
/// Cooperative per-query top-k reduction over IVF candidate lists.
///
/// Same structure as `extract_topk_coalesced` — strided per-thread private
/// top-k, shared-memory dump, sequential merge by thread 0 — but the number
/// of valid candidates comes from `candidates_per_query`, and the output is
/// OVERWRITTEN rather than merged with any previous `out_dists` contents.
///
/// `k_param` is the runtime mirror of the comptime `k`. Shared memory is
/// sized `32 * k` — NOTE(review): assumes WORKGROUP_SIZE_X <= 32; confirm at
/// the launch site.
#[cube(launch_unchecked)]
pub fn reduce_ivf_topk_coalesced<F: Float>(
candidate_dists: &Tensor<F>,
candidate_indices: &Tensor<u32>,
// Number of valid candidates per query row.
candidates_per_query: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
k_param: u32,
#[comptime] k: usize,
) {
// One workgroup per query, spread over the (Y, X) grid.
let q_idx = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) as usize;
let tx = UNIT_POS_X;
let wg = WORKGROUP_SIZE_X;
if q_idx >= candidate_dists.shape(0) {
terminate!();
}
let count = candidates_per_query[q_idx];
let in_base = q_idx * candidate_dists.stride(0);
// Private sorted top-k, seeded with the all-worst sentinel.
let mut local_dists = Array::<F>::new(k);
let mut local_indices = Array::<u32>::new(k);
let mut init_i = 0u32;
while init_i < k_param {
local_dists[init_i as usize] = F::new(f32::MAX);
local_indices[init_i as usize] = 0u32;
init_i += 1u32;
}
// Strided scan of this query's candidate row.
let mut col: u32 = tx;
while col < count {
let dist = candidate_dists[in_base + col as usize];
if dist < local_dists[(k_param - 1u32) as usize] {
// Linear search for the first slot `dist` beats.
let mut pos = 0u32;
let mut found: bool = false;
while pos < k_param && !found {
if dist < local_dists[pos as usize] {
found = true;
} else {
pos += 1u32;
}
}
if found {
// Shift the tail down one slot, then insert.
let mut sh = k_param - 1u32;
while sh > pos {
local_dists[sh as usize] = local_dists[(sh - 1u32) as usize];
local_indices[sh as usize] = local_indices[(sh - 1u32) as usize];
sh -= 1u32;
}
local_dists[pos as usize] = dist;
local_indices[pos as usize] = candidate_indices[in_base + col as usize];
}
}
col += wg;
}
// Publish each thread's private list to its shared-memory segment.
let mut s_dist = SharedMemory::<F>::new(32 * k);
let mut s_idx = SharedMemory::<u32>::new(32 * k);
let s_base = tx as usize * k_param as usize;
let mut cp = 0u32;
while cp < k_param {
s_dist[s_base + cp as usize] = local_dists[cp as usize];
s_idx[s_base + cp as usize] = local_indices[cp as usize];
cp += 1u32;
}
sync_cube();
// Thread 0 merges every other segment into segment 0, then writes out.
if tx == 0u32 {
let mut t: u32 = 1u32;
while t < wg {
let t_base = t as usize * k_param as usize;
let mut i = 0u32;
let mut early_stop = false;
while i < k_param && !early_stop {
let cd = s_dist[t_base + i as usize];
let ci = s_idx[t_base + i as usize];
// Segments are sorted ascending: once an entry cannot beat the
// current worst, the rest of the segment cannot either.
if cd >= s_dist[(k_param - 1u32) as usize] {
early_stop = true;
} else {
let mut pos = 0u32;
let mut found = false;
while pos < k_param && !found {
if cd < s_dist[pos as usize] {
found = true;
} else {
pos += 1u32;
}
}
if found {
let mut sh = k_param - 1u32;
while sh > pos {
s_dist[sh as usize] = s_dist[(sh - 1u32) as usize];
s_idx[sh as usize] = s_idx[(sh - 1u32) as usize];
sh -= 1u32;
}
s_dist[pos as usize] = cd;
s_idx[pos as usize] = ci;
}
i += 1u32;
}
}
t += 1u32;
}
let out_base = q_idx * k_param as usize;
let mut w = 0u32;
while w < k_param {
out_dists[out_base + w as usize] = s_dist[w as usize];
out_indices[out_base + w as usize] = s_idx[w as usize];
w += 1u32;
}
}
}
/// Simple (non-tiled) brute-force squared-Euclidean kernel: one thread per
/// (query, db) pair, reading both vectors straight from global memory.
/// X indexes the DB chunk row, Y the query row.
#[cube(launch_unchecked)]
pub fn euclidean_distances_gpu_chunk<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_chunk: &Tensor<Line<F>>,
distances: &mut Tensor<F>,
) {
let query_idx = ABSOLUTE_POS_Y as usize;
let db_idx = ABSOLUTE_POS_X as usize;
if query_idx < query_vectors.shape(0) && db_idx < db_chunk.shape(0) {
// shape(1) here is the per-row length in `Line`s.
let dim_lines = query_vectors.shape(1);
let mut sum = F::new(0.0);
for i in 0..dim_lines {
let q_line = query_vectors[query_idx * query_vectors.stride(0) + i];
let d_line = db_chunk[db_idx * db_chunk.stride(0) + i];
let diff = q_line - d_line;
let sq = diff * diff;
// Reduce the 4 lanes of the vectorized squared difference.
sum += sq[0];
sum += sq[1];
sum += sq[2];
sum += sq[3];
}
distances[query_idx * distances.stride(0) + db_idx] = sum;
}
}
/// Simple (non-tiled) brute-force cosine-distance kernel
/// (`1 - dot / (|q| * |d|)`): one thread per (query, db) pair, norms
/// precomputed and indexed by chunk-local row.
#[cube(launch_unchecked)]
pub fn cosine_distances_gpu_chunk<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_chunk: &Tensor<Line<F>>,
query_norms: &Tensor<F>,
db_norms: &Tensor<F>,
distances: &mut Tensor<F>,
) {
let query_idx = ABSOLUTE_POS_Y as usize;
let db_idx = ABSOLUTE_POS_X as usize;
if query_idx < query_vectors.shape(0) && db_idx < db_chunk.shape(0) {
// shape(1) here is the per-row length in `Line`s.
let dim_lines = query_vectors.shape(1);
let mut dot = F::new(0.0);
for i in 0..dim_lines {
let q_line = query_vectors[query_idx * query_vectors.stride(0) + i];
let d_line = db_chunk[db_idx * db_chunk.stride(0) + i];
let prod = q_line * d_line;
// Reduce the 4 lanes of the vectorized product.
dot += prod[0];
dot += prod[1];
dot += prod[2];
dot += prod[3];
}
let q_norm = query_norms[query_idx];
let d_norm = db_norms[db_idx];
distances[query_idx * distances.stride(0) + db_idx] =
F::new(1.0) - (dot / (q_norm * d_norm));
}
}
/// IVF candidate-distance kernel (Euclidean): for each active query
/// (`active_indices` maps the Y position to a real query id) and each DB
/// vector in the `[db_start, db_start + db_count)` window, computes the
/// squared L2 distance and scatters it into the query's candidate row at
/// `write_offsets[active_q] + local_db_idx`, recording the global DB index.
#[cube(launch_unchecked)]
pub fn compute_candidates_euclidean<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
// Maps compacted (active) query slots to real query indices.
active_indices: &Tensor<u32>,
// Per-active-query base offset into the candidate row.
write_offsets: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
db_start: u32,
db_count: u32,
) {
let local_db_idx = ABSOLUTE_POS_X;
let active_q_idx = ABSOLUTE_POS_Y;
if active_q_idx >= active_indices.len() as u32 || local_db_idx >= db_count {
terminate!();
}
let real_q_idx = active_indices[active_q_idx as usize];
let write_pos = write_offsets[active_q_idx as usize] + local_db_idx;
let db_idx = db_start + local_db_idx;
let mut sum = F::new(0.0);
// shape(1) appears to be in scalars here (divided by LINE_SIZE to get
// lines) — unlike the chunk kernels above; verify against the launch site.
let dim_lines = query_vectors.shape(1) / LINE_SIZE as usize;
let q_offset = real_q_idx as usize * dim_lines;
let d_offset = db_idx as usize * dim_lines;
for i in 0..dim_lines {
let q_line = query_vectors[q_offset + i];
let d_line = db_vectors[d_offset + i];
let diff = q_line - d_line;
let sq = diff * diff;
sum += sq[0];
sum += sq[1];
sum += sq[2];
sum += sq[3];
}
let out_offset = real_q_idx as usize * out_dists.stride(0) + write_pos as usize;
out_dists[out_offset] = sum;
out_indices[out_offset] = db_idx;
}
/// IVF candidate-distance kernel (cosine): mirrors
/// `compute_candidates_euclidean` but accumulates a dot product and writes
/// `1 - dot / (|q| * |d|)` using the precomputed norms (query norms indexed
/// by real query id, DB norms by global DB id).
#[cube(launch_unchecked)]
pub fn compute_candidates_cosine<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
query_norms: &Tensor<F>,
db_norms: &Tensor<F>,
// Maps compacted (active) query slots to real query indices.
active_indices: &Tensor<u32>,
// Per-active-query base offset into the candidate row.
write_offsets: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
db_start: u32,
db_count: u32,
) {
let local_db_idx = ABSOLUTE_POS_X;
let active_q_idx = ABSOLUTE_POS_Y;
if active_q_idx >= active_indices.len() as u32 || local_db_idx >= db_count {
terminate!();
}
let real_q_idx = active_indices[active_q_idx as usize];
let write_pos = write_offsets[active_q_idx as usize] + local_db_idx;
let db_idx = db_start + local_db_idx;
// shape(1) appears to be in scalars here; see note on the Euclidean twin.
let dim_lines = query_vectors.shape(1) / LINE_SIZE as usize;
let mut dot = F::new(0.0);
let q_offset = real_q_idx as usize * dim_lines;
let d_offset = db_idx as usize * dim_lines;
for i in 0..dim_lines {
let q_line = query_vectors[q_offset + i];
let d_line = db_vectors[d_offset + i];
let prod = q_line * d_line;
dot += prod[0];
dot += prod[1];
dot += prod[2];
dot += prod[3];
}
let q_norm = query_norms[real_q_idx as usize];
let d_norm = db_norms[db_idx as usize];
let out_offset = real_q_idx as usize * out_dists.stride(0) + write_pos as usize;
out_dists[out_offset] = F::new(1.0) - (dot / (q_norm * d_norm));
out_indices[out_offset] = db_idx;
}
/// Task-based IVF distance kernel (Euclidean). Each "task" bundles a
/// (query, DB window, output offset) triple described by the four parallel
/// `task_*` tensors; the (Z, Y) grid indexes tasks and X indexes DB vectors
/// within the task's window. Each thread computes one squared L2 distance
/// and scatters it to the query's candidate row.
#[cube(launch_unchecked)]
pub fn compute_ivf_mega_euclidean<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
// Parallel task descriptors: query id, DB window start, output base, count.
task_q_idx: &Tensor<u32>,
task_db_start: &Tensor<u32>,
task_write_offset: &Tensor<u32>,
task_db_count: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
) {
let local_db_idx = ABSOLUTE_POS_X;
// Fold the (Z, Y) cube position plus the local row into a task index.
let task_idx = (CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y;
if task_idx >= task_q_idx.len() as u32 {
terminate!();
}
let db_count = task_db_count[task_idx as usize];
if local_db_idx >= db_count {
terminate!();
}
let q_idx = task_q_idx[task_idx as usize];
let db_start = task_db_start[task_idx as usize];
let write_offset = task_write_offset[task_idx as usize];
let real_db_idx = db_start + local_db_idx;
let write_pos = write_offset + local_db_idx;
let mut sum = F::new(0.0);
// shape(1) appears to be in scalars here (divided by LINE_SIZE).
let dim_lines = query_vectors.shape(1) / LINE_SIZE as usize;
let q_offset = q_idx as usize * dim_lines;
let d_offset = real_db_idx as usize * dim_lines;
for i in 0..dim_lines {
let q_line = query_vectors[q_offset + i];
let d_line = db_vectors[d_offset + i];
let diff = q_line - d_line;
let sq = diff * diff;
sum += sq[0];
sum += sq[1];
sum += sq[2];
sum += sq[3];
}
let out_offset = q_idx as usize * out_dists.stride(0) + write_pos as usize;
out_dists[out_offset] = sum;
out_indices[out_offset] = real_db_idx;
}
/// Task-based IVF distance kernel (cosine). Mirrors
/// `compute_ivf_mega_euclidean` but accumulates a dot product and writes
/// `1 - dot / (|q| * |d|)` using the precomputed norms.
#[cube(launch_unchecked)]
pub fn compute_ivf_mega_cosine<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
query_norms: &Tensor<F>,
db_norms: &Tensor<F>,
// Parallel task descriptors: query id, DB window start, output base, count.
task_q_idx: &Tensor<u32>,
task_db_start: &Tensor<u32>,
task_write_offset: &Tensor<u32>,
task_db_count: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
) {
let local_db_idx = ABSOLUTE_POS_X;
// Fold the (Z, Y) cube position plus the local row into a task index.
let task_idx = (CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y;
if task_idx >= task_q_idx.len() as u32 {
terminate!();
}
let db_count = task_db_count[task_idx as usize];
if local_db_idx >= db_count {
terminate!();
}
let q_idx = task_q_idx[task_idx as usize];
let db_start = task_db_start[task_idx as usize];
let write_offset = task_write_offset[task_idx as usize];
let real_db_idx = db_start + local_db_idx;
let write_pos = write_offset + local_db_idx;
let mut dot = F::new(0.0);
// shape(1) appears to be in scalars here (divided by LINE_SIZE).
let dim_lines = query_vectors.shape(1) / LINE_SIZE as usize;
let q_offset = q_idx as usize * dim_lines;
let d_offset = real_db_idx as usize * dim_lines;
for i in 0..dim_lines {
let q_line = query_vectors[q_offset + i];
let d_line = db_vectors[d_offset + i];
let prod = q_line * d_line;
dot += prod[0];
dot += prod[1];
dot += prod[2];
dot += prod[3];
}
let q_norm = query_norms[q_idx as usize];
let d_norm = db_norms[real_db_idx as usize];
let out_offset = q_idx as usize * out_dists.stride(0) + write_pos as usize;
out_dists[out_offset] = F::new(1.0) - (dot / (q_norm * d_norm));
out_indices[out_offset] = real_db_idx;
}
/// Sequential per-query top-k reduction over IVF candidate lists: one thread
/// per query insertion-sorts its `candidates_per_query[q]` candidates into
/// the running ascending `out_dists`/`out_indices` lists. The output buffers
/// are read as running state, so they must be pre-seeded (e.g. `init_topk`).
#[cube(launch_unchecked)]
pub fn reduce_ivf_topk<F: Float>(
candidate_dists: &Tensor<F>,
candidate_indices: &Tensor<u32>,
// Number of valid candidates per query row.
candidates_per_query: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
) {
// One thread per query, spread over the (Y, X) grid.
let q_idx = ((CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X) as usize;
if q_idx >= candidate_dists.shape(0) {
terminate!();
}
let k = out_dists.shape(1);
let count = candidates_per_query[q_idx];
let in_offset = q_idx * candidate_dists.stride(0);
let out_offset = q_idx * out_dists.stride(0);
for i in 0..count {
let dist = candidate_dists[in_offset + i as usize];
let idx = candidate_indices[in_offset + i as usize];
// Only candidates beating the current worst entry are inserted.
if dist < out_dists[out_offset + k - 1] {
// Latch the first slot whose value exceeds `dist`.
let mut insert_pos: usize = k - 1;
for j in 0..k {
if dist < out_dists[out_offset + j] && insert_pos == k - 1 {
insert_pos = j;
}
}
// Shift entries at/after the insert position down one slot (tail-first).
for j in 0..k - 1 {
let src = k - 2 - j;
let dst = k - 1 - j;
if src >= insert_pos {
out_dists[out_offset + dst] = out_dists[out_offset + src];
out_indices[out_offset + dst] = out_indices[out_offset + src];
}
}
out_dists[out_offset + insert_pos] = dist;
out_indices[out_offset + insert_pos] = idx;
}
}
}
/// Shared-memory variant of `compute_ivf_mega_euclidean`: lane 0 of each
/// workgroup row stages the row's task descriptors into shared memory, then
/// all threads cooperatively stage the tasks' query vectors (as scalars,
/// 4 lanes per `Line`) before computing distances against the cached copies.
///
/// The descriptor arrays and the query tile are hard-sized for 32 workgroup
/// rows — NOTE(review): assumes WORKGROUP_SIZE_Y <= 32; confirm at the
/// launch site. Out-of-range tasks get zeroed descriptors, which makes the
/// tile loader fall back to (redundantly) reading query 0.
#[cube(launch_unchecked)]
pub fn compute_ivf_mega_euclidean_cached<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
// Parallel task descriptors: query id, DB window start, output base, count.
task_q_idx: &Tensor<u32>,
task_db_start: &Tensor<u32>,
task_write_offset: &Tensor<u32>,
task_db_count: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
// Total number of valid tasks.
n_tasks: u32,
// Vector dimension in `Line`s, known at compile time.
#[comptime] dim_lines: usize,
) {
let local_db_idx = ABSOLUTE_POS_X;
// Fold the (Z, Y) cube position plus the local row into a task index.
let task_idx = (CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y;
let local_y = UNIT_POS_Y as usize;
let local_x = UNIT_POS_X as usize;
let dim_scalars = dim_lines * 4usize;
let wg_y = WORKGROUP_SIZE_Y as usize;
// Per-row task descriptors cached in shared memory.
let mut s_q_idx = SharedMemory::<u32>::new(32usize);
let mut s_db_start = SharedMemory::<u32>::new(32usize);
let mut s_write_offset = SharedMemory::<u32>::new(32usize);
let mut s_db_count = SharedMemory::<u32>::new(32usize);
// Only lane 0 of each row fetches its row's descriptors.
if local_x == 0usize {
if task_idx < n_tasks {
s_q_idx[local_y] = task_q_idx[task_idx as usize];
s_db_start[local_y] = task_db_start[task_idx as usize];
s_write_offset[local_y] = task_write_offset[task_idx as usize];
s_db_count[local_y] = task_db_count[task_idx as usize];
} else {
// Harmless defaults so the tile loader below stays in bounds.
s_q_idx[local_y] = 0u32;
s_db_start[local_y] = 0u32;
s_write_offset[local_y] = 0u32;
s_db_count[local_y] = 0u32;
}
}
sync_cube();
// Cooperative strided load of every row's query vector into shared memory.
let mut s_query = SharedMemory::<F>::new(32 * dim_scalars);
let thread_id = local_y * WORKGROUP_SIZE_X as usize + local_x;
let total_threads = WORKGROUP_SIZE_X as usize * wg_y;
let total_elems = wg_y * dim_scalars;
let mut load_idx = thread_id;
while load_idx < total_elems {
let q_local = load_idx / dim_scalars;
let elem = load_idx % dim_scalars;
let q_global = s_q_idx[q_local];
let line_idx = elem / 4usize;
let lane = elem % 4usize;
let line_val = query_vectors[q_global as usize * dim_lines + line_idx];
s_query[load_idx] = line_val[lane];
load_idx += total_threads;
}
sync_cube();
// Bounds checks only after both cooperative phases and barriers.
if task_idx >= n_tasks {
terminate!();
}
let db_count = s_db_count[local_y];
if local_db_idx >= db_count {
terminate!();
}
let real_db_idx = s_db_start[local_y] + local_db_idx;
let write_pos = s_write_offset[local_y] + local_db_idx;
let q_shared_base = local_y * dim_scalars;
let d_offset = real_db_idx as usize * dim_lines;
let mut sum = F::new(0.0);
// Squared L2 distance against the cached query row.
for i in 0..dim_lines {
let d_line = db_vectors[d_offset + i];
let s_off = q_shared_base + i * 4usize;
let diff0 = s_query[s_off] - d_line[0];
let diff1 = s_query[s_off + 1usize] - d_line[1];
let diff2 = s_query[s_off + 2usize] - d_line[2];
let diff3 = s_query[s_off + 3usize] - d_line[3];
sum += diff0 * diff0 + diff1 * diff1 + diff2 * diff2 + diff3 * diff3;
}
let q_idx = s_q_idx[local_y];
let out_offset = q_idx as usize * out_dists.stride(0) + write_pos as usize;
out_dists[out_offset] = sum;
out_indices[out_offset] = real_db_idx;
}
/// Shared-memory variant of `compute_ivf_mega_cosine`: identical staging
/// scheme to `compute_ivf_mega_euclidean_cached`, additionally caching the
/// per-task query norm. Descriptor arrays are hard-sized for 32 workgroup
/// rows — NOTE(review): assumes WORKGROUP_SIZE_Y <= 32; confirm at the
/// launch site. Out-of-range tasks get a norm of 1.0 to avoid a divide by
/// zero in any speculative path.
#[cube(launch_unchecked)]
pub fn compute_ivf_mega_cosine_cached<F: Float>(
query_vectors: &Tensor<Line<F>>,
db_vectors: &Tensor<Line<F>>,
query_norms: &Tensor<F>,
db_norms: &Tensor<F>,
// Parallel task descriptors: query id, DB window start, output base, count.
task_q_idx: &Tensor<u32>,
task_db_start: &Tensor<u32>,
task_write_offset: &Tensor<u32>,
task_db_count: &Tensor<u32>,
out_dists: &mut Tensor<F>,
out_indices: &mut Tensor<u32>,
// Total number of valid tasks.
n_tasks: u32,
// Vector dimension in `Line`s, known at compile time.
#[comptime] dim_lines: usize,
) {
let local_db_idx = ABSOLUTE_POS_X;
// Fold the (Z, Y) cube position plus the local row into a task index.
let task_idx = (CUBE_POS_Z * CUBE_COUNT_Y + CUBE_POS_Y) * WORKGROUP_SIZE_Y + UNIT_POS_Y;
let local_y = UNIT_POS_Y as usize;
let local_x = UNIT_POS_X as usize;
let dim_scalars = dim_lines * 4usize;
let wg_y = WORKGROUP_SIZE_Y as usize;
// Per-row task descriptors plus the query norm, cached in shared memory.
let mut s_q_idx = SharedMemory::<u32>::new(32usize);
let mut s_db_start = SharedMemory::<u32>::new(32usize);
let mut s_write_offset = SharedMemory::<u32>::new(32usize);
let mut s_db_count = SharedMemory::<u32>::new(32usize);
let mut s_query_norms = SharedMemory::<F>::new(32usize);
// Only lane 0 of each row fetches its row's descriptors.
if local_x == 0usize {
if task_idx < n_tasks {
let q = task_q_idx[task_idx as usize];
s_q_idx[local_y] = q;
s_db_start[local_y] = task_db_start[task_idx as usize];
s_write_offset[local_y] = task_write_offset[task_idx as usize];
s_db_count[local_y] = task_db_count[task_idx as usize];
s_query_norms[local_y] = query_norms[q as usize];
} else {
// Harmless defaults so the tile loader below stays in bounds.
s_q_idx[local_y] = 0u32;
s_db_start[local_y] = 0u32;
s_write_offset[local_y] = 0u32;
s_db_count[local_y] = 0u32;
s_query_norms[local_y] = F::new(1.0);
}
}
sync_cube();
// Cooperative strided load of every row's query vector into shared memory.
let mut s_query = SharedMemory::<F>::new(32 * dim_scalars);
let thread_id = local_y * WORKGROUP_SIZE_X as usize + local_x;
let total_threads = WORKGROUP_SIZE_X as usize * wg_y;
let total_elems = wg_y * dim_scalars;
let mut load_idx = thread_id;
while load_idx < total_elems {
let q_local = load_idx / dim_scalars;
let elem = load_idx % dim_scalars;
let q_global = s_q_idx[q_local];
let line_idx = elem / 4usize;
let lane = elem % 4usize;
let line_val = query_vectors[q_global as usize * dim_lines + line_idx];
s_query[load_idx] = line_val[lane];
load_idx += total_threads;
}
sync_cube();
// Bounds checks only after both cooperative phases and barriers.
if task_idx >= n_tasks {
terminate!();
}
let db_count = s_db_count[local_y];
if local_db_idx >= db_count {
terminate!();
}
let real_db_idx = s_db_start[local_y] + local_db_idx;
let write_pos = s_write_offset[local_y] + local_db_idx;
let q_shared_base = local_y * dim_scalars;
let d_offset = real_db_idx as usize * dim_lines;
let mut dot = F::new(0.0);
// Dot product against the cached query row.
for i in 0..dim_lines {
let d_line = db_vectors[d_offset + i];
let s_off = q_shared_base + i * 4usize;
dot += s_query[s_off] * d_line[0]
+ s_query[s_off + 1usize] * d_line[1]
+ s_query[s_off + 2usize] * d_line[2]
+ s_query[s_off + 3usize] * d_line[3];
}
let q_norm = s_query_norms[local_y];
let d_norm = db_norms[real_db_idx as usize];
let q_idx = s_q_idx[local_y];
let out_offset = q_idx as usize * out_dists.stride(0) + write_pos as usize;
out_dists[out_offset] = F::new(1.0) - (dot / (q_norm * d_norm));
out_indices[out_offset] = real_db_idx;
}
#[cfg(test)]
#[cfg(feature = "gpu-tests")]
mod tests {
use super::*;
use cubecl::wgpu::{WgpuDevice, WgpuRuntime};
/// Returns the default wgpu device wrapped in `Option` so tests can bail out
/// (via `let Some(...) else { return }`) on platforms without GPU support.
fn try_device() -> Option<WgpuDevice> {
    let device = WgpuDevice::default();
    Some(device)
}
/// CPU reference: dense `nq x ndb` matrix of squared L2 distances, row-major
/// by query.
fn cpu_euclidean_dists(
    queries: &[f32],
    db: &[f32],
    nq: usize,
    ndb: usize,
    dim: usize,
) -> Vec<f32> {
    let mut out = Vec::with_capacity(nq * ndb);
    for q in 0..nq {
        let q_row = &queries[q * dim..(q + 1) * dim];
        for d in 0..ndb {
            let d_row = &db[d * dim..(d + 1) * dim];
            let mut sum = 0.0f32;
            for (a, b) in q_row.iter().zip(d_row) {
                let diff = a - b;
                sum += diff * diff;
            }
            out.push(sum);
        }
    }
    out
}
/// CPU reference: dense `nq x ndb` matrix of cosine distances
/// (`1 - dot / (|q| * |d|)`), using the supplied precomputed norms.
fn cpu_cosine_dists(
    queries: &[f32],
    db: &[f32],
    q_norms: &[f32],
    d_norms: &[f32],
    nq: usize,
    ndb: usize,
    dim: usize,
) -> Vec<f32> {
    let mut out = vec![0.0f32; nq * ndb];
    for q in 0..nq {
        for d in 0..ndb {
            let mut dot = 0.0f32;
            for (a, b) in queries[q * dim..(q + 1) * dim]
                .iter()
                .zip(&db[d * dim..(d + 1) * dim])
            {
                dot += a * b;
            }
            out[q * ndb + d] = 1.0 - dot / (q_norms[q] * d_norms[d]);
        }
    }
    out
}
/// CPU reference top-k: for each query row, returns the indices and values of
/// the `k` smallest distances in ascending order. Ties keep the lower index
/// first (stable sort), matching the GPU pipeline's insertion behavior.
fn cpu_topk(
    distances: &[f32],
    nq: usize,
    ndb: usize,
    k: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<f32>>) {
    let mut indices = Vec::with_capacity(nq);
    let mut dists = Vec::with_capacity(nq);
    for row in distances.chunks(ndb).take(nq) {
        let mut order: Vec<usize> = (0..ndb).collect();
        // Stable sort by value only, so equal distances preserve index order.
        order.sort_by(|&a, &b| row[a].partial_cmp(&row[b]).unwrap());
        indices.push(order.iter().take(k).copied().collect());
        dists.push(order.iter().take(k).map(|&i| row[i]).collect());
    }
    (indices, dists)
}
/// Euclidean (L2) norm of a vector.
fn l2_norm(v: &[f32]) -> f32 {
    let mut sq = 0.0f32;
    for &x in v {
        sq += x * x;
    }
    sq.sqrt()
}
#[test]
fn test_pipeline_euclidean_dim8() {
    // End-to-end check: GPU top-k distances match the CPU reference (dim 8).
    let Some(device) = try_device() else { return };
    let (nq, ndb, dim, k) = (10usize, 50usize, 8usize, 5usize);
    let queries: Vec<f32> = (0..nq * dim)
        .map(|i| ((i * 13 + 7) % 29) as f32 * 0.1)
        .collect();
    let db: Vec<f32> = (0..ndb * dim)
        .map(|i| ((i * 17 + 3) % 31) as f32 * 0.1)
        .collect();
    let query_batch = BatchData::new(&queries, &[], nq);
    let db_batch = BatchData::new(&db, &[], ndb);
    let (_, gpu_dist) = query_batch_gpu::<f32, WgpuRuntime>(
        k,
        &query_batch,
        &db_batch,
        dim,
        &Dist::Euclidean,
        device,
        false,
    );
    let reference = cpu_euclidean_dists(&queries, &db, nq, ndb, dim);
    let (_, cpu_dist) = cpu_topk(&reference, nq, ndb, k);
    for q in 0..nq {
        for i in 0..k {
            let (g, c) = (gpu_dist[q][i], cpu_dist[q][i]);
            assert!(
                (g - c).abs() < 1e-3,
                "Query {} rank {}: gpu dist {} != cpu dist {}",
                q,
                i,
                g,
                c
            );
        }
    }
}
#[test]
fn test_pipeline_euclidean_dim32() {
    // Same end-to-end check at a larger dimension (dim 32, looser tolerance).
    let Some(device) = try_device() else { return };
    let (nq, ndb, dim, k) = (8usize, 40usize, 32usize, 5usize);
    let queries: Vec<f32> = (0..nq * dim)
        .map(|i| ((i * 13 + 7) % 29) as f32 * 0.1)
        .collect();
    let db: Vec<f32> = (0..ndb * dim)
        .map(|i| ((i * 17 + 3) % 31) as f32 * 0.1)
        .collect();
    let query_batch = BatchData::new(&queries, &[], nq);
    let db_batch = BatchData::new(&db, &[], ndb);
    let (_, gpu_dist) = query_batch_gpu::<f32, WgpuRuntime>(
        k,
        &query_batch,
        &db_batch,
        dim,
        &Dist::Euclidean,
        device,
        false,
    );
    let reference = cpu_euclidean_dists(&queries, &db, nq, ndb, dim);
    let (_, cpu_dist) = cpu_topk(&reference, nq, ndb, k);
    for q in 0..nq {
        for i in 0..k {
            let (g, c) = (gpu_dist[q][i], cpu_dist[q][i]);
            assert!(
                (g - c).abs() < 1e-2,
                "dim=32 query {} rank {}: gpu dist {} != cpu dist {}",
                q,
                i,
                g,
                c
            );
        }
    }
}
#[test]
fn test_pipeline_cosine_dim32() {
    // End-to-end check of the cosine path against the CPU reference.
    let Some(device) = try_device() else { return };
    let (nq, ndb, dim, k) = (4usize, 20usize, 32usize, 3usize);
    let queries: Vec<f32> = (0..nq * dim)
        .map(|i| ((i * 7 + 1) % 11) as f32 + 0.5)
        .collect();
    let db: Vec<f32> = (0..ndb * dim)
        .map(|i| ((i * 13 + 3) % 17) as f32 + 0.5)
        .collect();
    let q_norms: Vec<f32> = queries.chunks(dim).map(l2_norm).collect();
    let d_norms: Vec<f32> = db.chunks(dim).map(l2_norm).collect();
    let query_batch = BatchData::new(&queries, &q_norms, nq);
    let db_batch = BatchData::new(&db, &d_norms, ndb);
    let (_, gpu_dist) = query_batch_gpu::<f32, WgpuRuntime>(
        k,
        &query_batch,
        &db_batch,
        dim,
        &Dist::Cosine,
        device,
        false,
    );
    let reference = cpu_cosine_dists(&queries, &db, &q_norms, &d_norms, nq, ndb, dim);
    let (_, cpu_dist) = cpu_topk(&reference, nq, ndb, k);
    for q in 0..nq {
        for i in 0..k {
            let (g, c) = (gpu_dist[q][i], cpu_dist[q][i]);
            assert!(
                (g - c).abs() < 1e-3,
                "Cosine query {} rank {}: gpu dist {} != cpu dist {}",
                q,
                i,
                g,
                c
            );
        }
    }
}
#[test]
fn test_self_query_finds_self() {
    let Some(device) = try_device() else { return };
    let n = 64usize;
    let dim = 32usize;
    // Query the dataset against itself: rank-0 must be the vector's own index
    // with (near-)zero distance.
    let data: Vec<f32> = (0..n * dim).map(|i| (i as f32) * 0.3 + 0.1).collect();
    let batch = BatchData::new(&data, &[], n);
    let (indices, distances) = query_batch_gpu::<f32, WgpuRuntime>(
        3,
        &batch,
        &batch,
        dim,
        &Dist::Euclidean,
        device,
        false,
    );
    for (q, (idx_row, dist_row)) in indices.iter().zip(&distances).enumerate() {
        assert_eq!(
            idx_row[0], q,
            "Query {} nearest should be itself, got {}",
            q, idx_row[0]
        );
        assert!(
            dist_row[0] < 1e-4,
            "Self-distance for query {}: got {}",
            q,
            dist_row[0]
        );
    }
}
#[test]
fn test_output_is_sorted() {
    let Some(device) = try_device() else { return };
    let nq = 16usize;
    let ndb = 64usize;
    let dim = 32usize;
    let k = 5usize;
    let queries: Vec<f32> = (0..nq * dim).map(|i| ((i * 7 + 3) % 13) as f32).collect();
    let db: Vec<f32> = (0..ndb * dim).map(|i| ((i * 11 + 5) % 17) as f32).collect();
    let qb = BatchData::new(&queries, &[], nq);
    let dbb = BatchData::new(&db, &[], ndb);
    let (indices, distances) =
        query_batch_gpu::<f32, WgpuRuntime>(k, &qb, &dbb, dim, &Dist::Euclidean, device, false);
    for q in 0..nq {
        // Distances must be non-decreasing across ranks.
        for (w_idx, pair) in distances[q].windows(2).enumerate() {
            let i = w_idx + 1;
            assert!(
                pair[1] >= pair[0],
                "Query {}: not sorted at {}: {} < {}",
                q,
                i,
                pair[1],
                pair[0]
            );
        }
        // All k returned indices must be distinct.
        let unique: std::collections::HashSet<usize> = indices[q].iter().copied().collect();
        assert_eq!(unique.len(), k, "Query {}: duplicate indices", q);
    }
}
#[test]
fn test_k_equals_one() {
    let Some(device) = try_device() else { return };
    // Four 4-d db vectors; row 1 ([1,0,0,0]) is closest to the query [0.9,0,0,0].
    let data: Vec<f32> = vec![
        0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 10.0, 10.0, 10.0, 10.0,
    ];
    let query: Vec<f32> = vec![0.9, 0.0, 0.0, 0.0];
    let qb = BatchData::new(&query, &[], 1);
    let dbb = BatchData::new(&data, &[], 4);
    let (idx, dist) =
        query_batch_gpu::<f32, WgpuRuntime>(1, &qb, &dbb, 4, &Dist::Euclidean, device, false);
    assert_eq!(idx[0][0], 1);
    // Squared distance: (0.9 - 1.0)^2 = 0.01.
    assert!((dist[0][0] - 0.01).abs() < 1e-3);
}
#[test]
fn test_planted_nearest() {
    let Some(device) = try_device() else { return };
    let dim = 32usize;
    let k = 3usize;
    let nq = 2usize;
    let ndb = 200usize;
    // Pseudo-random database with an exact copy of the target planted at row 73.
    let mut db: Vec<f32> = (0..ndb * dim).map(|i| ((i * 17 + 3) % 31) as f32).collect();
    let target = vec![100.0f32; dim];
    db[73 * dim..74 * dim].copy_from_slice(&target);
    // Query 0 is the target nudged slightly; query 1 is all zeros.
    let mut queries = target.clone();
    queries[0] += 0.001;
    queries.resize(nq * dim, 0.0);
    let qb = BatchData::new(&queries, &[], nq);
    let dbb = BatchData::new(&db, &[], ndb);
    let (idx, dist) =
        query_batch_gpu::<f32, WgpuRuntime>(k, &qb, &dbb, dim, &Dist::Euclidean, device, false);
    assert_eq!(idx[0][0], 73, "Should find planted nearest at index 73");
    assert!(dist[0][0] < 0.01);
    // Full top-k for both queries must agree with the CPU reference.
    let cpu_d = cpu_euclidean_dists(&queries, &db, nq, ndb, dim);
    let (cpu_idx, _) = cpu_topk(&cpu_d, nq, ndb, k);
    for q in 0..nq {
        assert_eq!(idx[q], cpu_idx[q], "Query {} mismatch vs CPU", q);
    }
}
#[test]
fn test_single_query_single_db() {
    let Some(device) = try_device() else { return };
    // Smallest possible problem: 1 query vs 1 db vector of dim 4.
    // Element-wise diff is [4,4,4,4], so squared distance = 4 * 16 = 64.
    let query = vec![1.0f32, 2.0, 3.0, 4.0];
    let db = vec![5.0f32, 6.0, 7.0, 8.0];
    let qb = BatchData::new(&query, &[], 1);
    let dbb = BatchData::new(&db, &[], 1);
    let (idx, dist) =
        query_batch_gpu::<f32, WgpuRuntime>(1, &qb, &dbb, 4, &Dist::Euclidean, device, false);
    assert_eq!(idx[0][0], 0);
    assert!((dist[0][0] - 64.0).abs() < 1e-3);
}
/// Uploads candidate lists to the GPU, initializes the top-k buffers with
/// `init_topk`, runs the serial `reduce_ivf_topk` kernel, and reads results
/// back to the host.
///
/// Inputs are row-major `[n_queries, max_candidates]` slices;
/// `candidates_per_query` holds the valid-candidate count per query row.
/// Returns flattened `[n_queries * k]` (distances, indices).
fn run_serial_reduce(
    candidate_dists: &[f32],
    candidate_indices: &[u32],
    candidates_per_query: &[u32],
    n_queries: usize,
    max_candidates: usize,
    k: usize,
    device: &WgpuDevice,
) -> (Vec<f32>, Vec<u32>) {
    let client = WgpuRuntime::client(device);
    let cd_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(
        candidate_dists,
        vec![n_queries, max_candidates],
        &client,
    );
    let ci_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(
        candidate_indices,
        vec![n_queries, max_candidates],
        &client,
    );
    let cpq_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(
        candidates_per_query,
        vec![n_queries],
        &client,
    );
    // Output buffers; init_topk fills them before the reduce kernel runs.
    let topk_d = GpuTensor::<WgpuRuntime, f32>::empty(vec![n_queries, k], &client);
    let topk_i = GpuTensor::<WgpuRuntime, u32>::empty(vec![n_queries, k], &client);
    // Init grid: x covers the k slots, y/z cover the queries.
    let init_gx = (k as u32).div_ceil(WORKGROUP_SIZE_X);
    let (init_gy, init_gz) = grid_2d((n_queries as u32).div_ceil(WORKGROUP_SIZE_Y));
    // SAFETY(review): launch_unchecked skips cubecl's bounds validation; the
    // grid dims computed above are assumed to cover all (k, query) slots.
    unsafe {
        let _ = init_topk::launch_unchecked::<f32, WgpuRuntime>(
            &client,
            CubeCount::Static(init_gx, init_gy, init_gz),
            CubeDim::new_2d(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y),
            topk_d.clone().into_tensor_arg(1),
            topk_i.clone().into_tensor_arg(1),
        );
    }
    // Reduce grid: one thread per query (serial scan over its candidates).
    let (rgx, rgy) = grid_2d((n_queries as u32).div_ceil(WORKGROUP_SIZE_X));
    unsafe {
        let _ = reduce_ivf_topk::launch_unchecked::<f32, WgpuRuntime>(
            &client,
            CubeCount::Static(rgx, rgy, 1),
            CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
            cd_gpu.into_tensor_arg(1),
            ci_gpu.into_tensor_arg(1),
            cpq_gpu.into_tensor_arg(1),
            topk_d.clone().into_tensor_arg(1),
            topk_i.clone().into_tensor_arg(1),
        );
    }
    // Synchronous read-back of both result buffers.
    (topk_d.read(&client), topk_i.read(&client))
}
/// Like `run_serial_reduce`, but runs the coalesced variant
/// `reduce_ivf_topk_coalesced` (one workgroup per query rather than one
/// thread per query). Used to cross-check the coalesced kernel against the
/// serial reference.
///
/// Inputs are row-major `[n_queries, max_candidates]`; returns flattened
/// `[n_queries * k]` (distances, indices).
fn run_coalesced_reduce(
    candidate_dists: &[f32],
    candidate_indices: &[u32],
    candidates_per_query: &[u32],
    n_queries: usize,
    max_candidates: usize,
    k: usize,
    device: &WgpuDevice,
) -> (Vec<f32>, Vec<u32>) {
    let client = WgpuRuntime::client(device);
    let cd_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(
        candidate_dists,
        vec![n_queries, max_candidates],
        &client,
    );
    let ci_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(
        candidate_indices,
        vec![n_queries, max_candidates],
        &client,
    );
    let cpq_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(
        candidates_per_query,
        vec![n_queries],
        &client,
    );
    // NOTE(review): no init_topk pass here — unlike the serial path, the
    // coalesced kernel apparently initializes its output itself; verify
    // against the kernel definition.
    let topk_d = GpuTensor::<WgpuRuntime, f32>::empty(vec![n_queries, k], &client);
    let topk_i = GpuTensor::<WgpuRuntime, u32>::empty(vec![n_queries, k], &client);
    // One workgroup per query (grid sized by n_queries, not n_queries / WG).
    let (rgx, rgy) = grid_2d(n_queries as u32);
    // SAFETY(review): launch_unchecked skips bounds validation; grid must
    // cover every query row.
    unsafe {
        let _ = reduce_ivf_topk_coalesced::launch_unchecked::<f32, WgpuRuntime>(
            &client,
            CubeCount::Static(rgx, rgy, 1),
            CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
            cd_gpu.into_tensor_arg(1),
            ci_gpu.into_tensor_arg(1),
            cpq_gpu.into_tensor_arg(1),
            topk_d.clone().into_tensor_arg(1),
            topk_i.clone().into_tensor_arg(1),
            // k passed both as a runtime scalar and as a comptime argument.
            ScalarArg { elem: k as u32 },
            k,
        );
    }
    (topk_d.read(&client), topk_i.read(&client))
}
#[test]
fn test_coalesced_reduce_trivial() {
    let Some(device) = try_device() else { return };
    let k = 3usize;
    let n_queries = 1usize;
    let max_candidates = 5usize;
    // Five candidates in strictly decreasing distance order; the top-3 must
    // come back ascending.
    let candidate_dists = vec![50.0f32, 40.0, 30.0, 20.0, 10.0];
    let candidate_indices = vec![100u32, 101, 102, 103, 104];
    let (dists, indices) = run_coalesced_reduce(
        &candidate_dists,
        &candidate_indices,
        &[5u32],
        n_queries,
        max_candidates,
        k,
        &device,
    );
    println!("Trivial: dists={:?}, indices={:?}", dists, indices);
    assert_eq!(indices, vec![104, 103, 102], "Wrong indices");
    assert_eq!(dists, vec![10.0, 20.0, 30.0], "Wrong distances");
}
#[test]
fn test_coalesced_reduce_single_stride() {
    let Some(device) = try_device() else { return };
    let k = 3usize;
    let n_queries = 1usize;
    let max_candidates = 20usize;
    // Distances 20 down to 1, indices 200..=219: the best three candidates
    // are the final entries of the row.
    let candidate_dists: Vec<f32> = (1..=20).rev().map(|i| i as f32).collect();
    let candidate_indices: Vec<u32> = (200u32..220).collect();
    let (dists, indices) = run_coalesced_reduce(
        &candidate_dists,
        &candidate_indices,
        &[20u32],
        n_queries,
        max_candidates,
        k,
        &device,
    );
    println!("Single stride: dists={:?}, indices={:?}", dists, indices);
    assert_eq!(dists, vec![1.0, 2.0, 3.0]);
    assert_eq!(indices, vec![219, 218, 217]);
}
#[test]
fn test_coalesced_reduce_needs_merge() {
    let Some(device) = try_device() else { return };
    let k = 5usize;
    let n_queries = 1usize;
    let max_candidates = 200usize;
    // Mostly-sentinel row with the true top-5 scattered far apart, so
    // different threads each find winners and the cross-thread merge runs.
    let mut candidate_dists = vec![999.0f32; max_candidates];
    let candidate_indices: Vec<u32> = (0..max_candidates as u32).collect();
    for &(slot, d) in &[(3usize, 1.0f32), (35, 2.0), (64, 3.0), (100, 4.0), (129, 5.0)] {
        candidate_dists[slot] = d;
    }
    let candidates_per_query = vec![max_candidates as u32];
    let (dists, indices) = run_coalesced_reduce(
        &candidate_dists,
        &candidate_indices,
        &candidates_per_query,
        n_queries,
        max_candidates,
        k,
        &device,
    );
    println!("Merge test: dists={:?}, indices={:?}", dists, indices);
    assert_eq!(dists, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
    assert_eq!(indices, vec![3, 35, 64, 100, 129]);
}
#[test]
fn test_coalesced_vs_serial_random() {
    let Some(device) = try_device() else { return };
    let k = 10usize;
    let n_queries = 8usize;
    let max_candidates = 500usize;
    let total = n_queries * max_candidates;
    // Deterministic pseudo-random distances over the flattened buffer;
    // indices encode (query, candidate) as q * 10000 + c.
    let candidate_dists: Vec<f32> = (0..total)
        .map(|idx| ((idx * 17 + 31) % 9973) as f32 * 0.1)
        .collect();
    let candidate_indices: Vec<u32> = (0..n_queries)
        .flat_map(|q| (0..max_candidates).map(move |c| (q * 10000 + c) as u32))
        .collect();
    let candidates_per_query = vec![max_candidates as u32; n_queries];
    let (serial_d, serial_i) = run_serial_reduce(
        &candidate_dists,
        &candidate_indices,
        &candidates_per_query,
        n_queries,
        max_candidates,
        k,
        &device,
    );
    let (coal_d, coal_i) = run_coalesced_reduce(
        &candidate_dists,
        &candidate_indices,
        &candidates_per_query,
        n_queries,
        max_candidates,
        k,
        &device,
    );
    // Both reducers must produce identical per-query top-k results.
    for q in 0..n_queries {
        let range = q * k..(q + 1) * k;
        let sd = &serial_d[range.clone()];
        let cd = &coal_d[range.clone()];
        let si = &serial_i[range.clone()];
        let ci = &coal_i[range];
        for i in 0..k {
            assert!(
                (sd[i] - cd[i]).abs() < 1e-4,
                "Query {} rank {}: serial dist {} != coalesced dist {}",
                q,
                i,
                sd[i],
                cd[i]
            );
            assert_eq!(
                si[i], ci[i],
                "Query {} rank {}: serial idx {} != coalesced idx {}",
                q, i, si[i], ci[i]
            );
        }
    }
}
#[test]
fn test_coalesced_reduce_multi_query() {
    let Some(device) = try_device() else { return };
    let k = 3usize;
    let n_queries = 4usize;
    let max_candidates = 10usize;
    // Every query row gets baseline distances 100..110, then one uniquely
    // small distance planted at slot q + 1 so the winner differs per query.
    let mut candidate_dists = vec![f32::MAX; n_queries * max_candidates];
    let mut candidate_indices = vec![0u32; n_queries * max_candidates];
    for q in 0..n_queries {
        let row = q * max_candidates;
        for c in 0..max_candidates {
            candidate_dists[row + c] = 100.0 + c as f32;
            candidate_indices[row + c] = (q * 1000 + c) as u32;
        }
        candidate_dists[row + q + 1] = 0.5 + q as f32 * 0.1;
    }
    let candidates_per_query = vec![max_candidates as u32; n_queries];
    let (dists, indices) = run_coalesced_reduce(
        &candidate_dists,
        &candidate_indices,
        &candidates_per_query,
        n_queries,
        max_candidates,
        k,
        &device,
    );
    for q in 0..n_queries {
        // Rank-0 slot of each query must hold the planted minimum.
        let best_dist = dists[q * k];
        let best_idx = indices[q * k];
        let expected_dist = 0.5 + q as f32 * 0.1;
        let expected_idx = (q * 1000 + q + 1) as u32;
        assert!(
            (best_dist - expected_dist).abs() < 1e-4,
            "Query {}: best dist {} != expected {}",
            q,
            best_dist,
            expected_dist
        );
        assert_eq!(
            best_idx, expected_idx,
            "Query {}: best idx {} != expected {}",
            q, best_idx, expected_idx
        );
    }
}
#[test]
fn test_coalesced_reduce_fewer_than_k() {
    let Some(device) = try_device() else { return };
    // Only 2 valid candidates while k = 5: the remaining output slots must
    // stay at the sentinel value.
    let k = 5usize;
    let n_queries = 1usize;
    let max_candidates = 10usize;
    let mut candidate_dists = vec![f32::MAX; max_candidates];
    let mut candidate_indices = vec![0u32; max_candidates];
    candidate_dists[..2].copy_from_slice(&[5.0, 3.0]);
    candidate_indices[..2].copy_from_slice(&[42, 99]);
    let (dists, indices) = run_coalesced_reduce(
        &candidate_dists,
        &candidate_indices,
        &[2u32],
        n_queries,
        max_candidates,
        k,
        &device,
    );
    println!("Fewer than k: dists={:?}, indices={:?}", dists, indices);
    assert!((dists[0] - 3.0).abs() < 1e-4, "First should be 3.0");
    assert!((dists[1] - 5.0).abs() < 1e-4, "Second should be 5.0");
    assert_eq!(indices[0], 99);
    assert_eq!(indices[1], 42);
    assert!(dists[2] >= f32::MAX / 2.0, "Slot 2 should be sentinel");
}
/// Uploads queries/db and a task list, launches one of the two mega-Euclidean
/// IVF kernels (shared-memory "cached" variant when `use_cached` is true,
/// otherwise the plain variant), and reads back the `[n_queries,
/// max_candidates]` distance/index buffers.
///
/// Each task tuple is unpacked as (t.0, t.1, t.2, t.3) into per-task arrays
/// whose names suggest (query index, db start, write offset, db count) —
/// matches how the known-answer tests construct them; confirm against the
/// kernel definitions.
#[allow(clippy::too_many_arguments)]
fn run_mega_euclidean(
    queries: &[f32],
    db: &[f32],
    tasks: &[(u32, u32, u32, u32)],
    n_queries: usize,
    n_db: usize,
    dim: usize,
    max_candidates: usize,
    device: &WgpuDevice,
    use_cached: bool,
) -> (Vec<f32>, Vec<u32>) {
    let client = WgpuRuntime::client(device);
    // Vectors are uploaded with line (SIMD lane) width LINE_SIZE; dim must be
    // divisible by it for dim_lines to be exact.
    let vec_size = LINE_SIZE as usize;
    let dim_lines = dim / vec_size;
    let n_tasks = tasks.len();
    let q_gpu =
        GpuTensor::<WgpuRuntime, f32>::from_slice(queries, vec![n_queries, dim], &client);
    let db_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(db, vec![n_db, dim], &client);
    // Split the task tuples into four parallel arrays for the kernel.
    let task_q: Vec<u32> = tasks.iter().map(|t| t.0).collect();
    let task_db_s: Vec<u32> = tasks.iter().map(|t| t.1).collect();
    let task_wo: Vec<u32> = tasks.iter().map(|t| t.2).collect();
    let task_dc: Vec<u32> = tasks.iter().map(|t| t.3).collect();
    let tq_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(&task_q, vec![n_tasks], &client);
    let tds_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(&task_db_s, vec![n_tasks], &client);
    let two_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(&task_wo, vec![n_tasks], &client);
    let tdc_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(&task_dc, vec![n_tasks], &client);
    let out_d = GpuTensor::<WgpuRuntime, f32>::empty(vec![n_queries, max_candidates], &client);
    let out_i = GpuTensor::<WgpuRuntime, u32>::empty(vec![n_queries, max_candidates], &client);
    // Grid x spans the largest per-task db count; y/z span the tasks.
    let max_db_count = tasks.iter().map(|t| t.3).max().unwrap_or(0);
    let gx = max_db_count.div_ceil(WORKGROUP_SIZE_X).max(1);
    let (gy, gz) = grid_2d((n_tasks as u32).div_ceil(WORKGROUP_SIZE_Y));
    if use_cached {
        // SAFETY(review): launch_unchecked skips bounds validation; the grid
        // computed above must cover every (db row, task) pair.
        unsafe {
            let _ = compute_ivf_mega_euclidean_cached::launch_unchecked::<f32, WgpuRuntime>(
                &client,
                CubeCount::Static(gx, gy, gz),
                CubeDim::new_2d(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y),
                q_gpu.into_tensor_arg(vec_size),
                db_gpu.into_tensor_arg(vec_size),
                tq_gpu.into_tensor_arg(1),
                tds_gpu.into_tensor_arg(1),
                two_gpu.into_tensor_arg(1),
                tdc_gpu.into_tensor_arg(1),
                out_d.clone().into_tensor_arg(1),
                out_i.clone().into_tensor_arg(1),
                // Cached variant additionally takes the task count and the
                // comptime line count for its shared-memory tile.
                ScalarArg {
                    elem: n_tasks as u32,
                },
                dim_lines,
            );
        }
    } else {
        // Plain variant: same tensors, no scalar/comptime extras.
        unsafe {
            let _ = compute_ivf_mega_euclidean::launch_unchecked::<f32, WgpuRuntime>(
                &client,
                CubeCount::Static(gx, gy, gz),
                CubeDim::new_2d(WORKGROUP_SIZE_X, WORKGROUP_SIZE_Y),
                q_gpu.into_tensor_arg(vec_size),
                db_gpu.into_tensor_arg(vec_size),
                tq_gpu.into_tensor_arg(1),
                tds_gpu.into_tensor_arg(1),
                two_gpu.into_tensor_arg(1),
                tdc_gpu.into_tensor_arg(1),
                out_d.clone().into_tensor_arg(1),
                out_i.clone().into_tensor_arg(1),
            );
        }
    }
    (out_d.read(&client), out_i.read(&client))
}
/// Reference squared Euclidean distance between query row `q` and db row `d`,
/// both stored row-major with `dim` elements per row.
fn cpu_sq_euclidean(queries: &[f32], db: &[f32], q: usize, d: usize, dim: usize) -> f32 {
    let q_row = &queries[q * dim..(q + 1) * dim];
    let d_row = &db[d * dim..(d + 1) * dim];
    q_row
        .iter()
        .zip(d_row)
        .map(|(a, b)| (a - b) * (a - b))
        .sum()
}
#[test]
fn test_mega_cached_known_answer() {
    let Some(device) = try_device() else { return };
    let dim = 4usize;
    let n_queries = 2usize;
    let n_db = 5usize;
    let max_candidates = 5usize;
    let queries = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
    let db = vec![
        1.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 2.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 3.0,
    ];
    // q0 covers db rows 0..3 then 3..5 (two tasks); q1 covers db rows 0..3.
    let tasks = vec![
        (0u32, 0u32, 0u32, 3u32),
        (0u32, 3u32, 3u32, 2u32),
        (1u32, 0u32, 0u32, 3u32),
    ];
    let (dists, indices) = run_mega_euclidean(
        &queries,
        &db,
        &tasks,
        n_queries,
        n_db,
        dim,
        max_candidates,
        &device,
        true,
    );
    // Task 0: q0 against db rows 0..3, written at offset 0.
    for c in 0..3 {
        let expected = cpu_sq_euclidean(&queries, &db, 0, c, dim);
        assert!(
            (dists[c] - expected).abs() < 1e-4,
            "q0 vs db{}: got {} expected {}",
            c,
            dists[c],
            expected,
        );
        assert_eq!(indices[c], c as u32);
    }
    // Task 1: q0 against db rows 3..5, written at offset 3.
    for c in 0..2 {
        let db_idx = 3 + c;
        let expected = cpu_sq_euclidean(&queries, &db, 0, db_idx, dim);
        assert!(
            (dists[3 + c] - expected).abs() < 1e-4,
            "q0 vs db{}: got {} expected {}",
            db_idx,
            dists[3 + c],
            expected,
        );
        assert_eq!(indices[3 + c], db_idx as u32);
    }
    // Task 2: q1 against db rows 0..3, written in q1's row of the output.
    for c in 0..3 {
        let expected = cpu_sq_euclidean(&queries, &db, 1, c, dim);
        assert!(
            (dists[max_candidates + c] - expected).abs() < 1e-4,
            "q1 vs db{}: got {} expected {}",
            c,
            dists[max_candidates + c],
            expected,
        );
    }
    println!("Known-answer: PASSED");
}
/// Same known-answer fixture as `test_mega_cached_known_answer`, but run
/// through the NON-cached mega-Euclidean kernel (`use_cached = false`).
///
/// The original body of this test was a byte-for-byte duplicate of the cached
/// test (same fixture, same `use_cached = true`), so it added no coverage
/// while the uncached kernel branch of `run_mega_euclidean` was never
/// exercised anywhere. Both kernels must produce the same squared-Euclidean
/// distances, so the expectations are unchanged.
#[test]
fn test_mega_cached_known_answer_2() {
    let Some(device) = try_device() else { return };
    let dim = 4usize;
    let n_queries = 2usize;
    let n_db = 5usize;
    let max_candidates = 5usize;
    let queries = vec![1.0f32, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
    let db = vec![
        1.0, 1.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 2.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 3.0,
    ];
    // q0 covers db rows 0..3 then 3..5 (two tasks); q1 covers db rows 0..3.
    let tasks = vec![
        (0u32, 0u32, 0u32, 3u32),
        (0u32, 3u32, 3u32, 2u32),
        (1u32, 0u32, 0u32, 3u32),
    ];
    // use_cached = false: exercise the plain (non-shared-memory) kernel.
    let (dists, indices) = run_mega_euclidean(
        &queries,
        &db,
        &tasks,
        n_queries,
        n_db,
        dim,
        max_candidates,
        &device,
        false,
    );
    // Task 0: q0 against db rows 0..3, written at offset 0.
    for c in 0..3 {
        let expected = cpu_sq_euclidean(&queries, &db, 0, c, dim);
        let got = dists[c];
        assert!(
            (got - expected).abs() < 1e-4,
            "q0 vs db{}: got {} expected {}",
            c,
            got,
            expected,
        );
        assert_eq!(indices[c], c as u32);
    }
    // Task 1: q0 against db rows 3..5, written at offset 3.
    for c in 0..2 {
        let db_idx = 3 + c;
        let expected = cpu_sq_euclidean(&queries, &db, 0, db_idx, dim);
        let got = dists[3 + c];
        assert!(
            (got - expected).abs() < 1e-4,
            "q0 vs db{}: got {} expected {}",
            db_idx,
            got,
            expected,
        );
        assert_eq!(indices[3 + c], db_idx as u32);
    }
    // Task 2: q1 against db rows 0..3, written in q1's row of the output.
    for c in 0..3 {
        let expected = cpu_sq_euclidean(&queries, &db, 1, c, dim);
        let got = dists[max_candidates + c];
        assert!(
            (got - expected).abs() < 1e-4,
            "q1 vs db{}: got {} expected {}",
            c,
            got,
            expected,
        );
    }
    println!("Known-answer (uncached): PASSED");
}
}