#![allow(missing_docs)]
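//! GPU-accelerated NN-Descent index construction using CubeCL kernels,
//! followed by CAGRA-style rank pruning of the resulting k-NN graph.
//! Search runs either on the CPU (greedy best-first over the navigation
//! graph) or in batches on the GPU.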
use cubecl::frontend::{Atomic, CubePrimitive, Float, SharedMemory};
use cubecl::prelude::*;
use faer::{MatRef, RowRef};
use fixedbitset::FixedBitSet;
use rayon::prelude::*;
use std::time::Instant;
use std::{cell::RefCell, cmp::Reverse, collections::BinaryHeap};
use thousands::*;
use crate::cpu::nndescent::*;
use crate::cpu::vamana::compute_medoid;
use crate::gpu::cagra_gpu_search::*;
use crate::gpu::forest_gpu::*;
use crate::gpu::tensor::*;
use crate::gpu::*;
use crate::prelude::*;
use crate::utils::*;
pub const MAX_PROPOSALS: usize = 128;
const DEFAULT_MAX_ITERS: usize = 15;
const DEFAULT_DELTA: f32 = 0.001;
const DEFAULT_RHO: f32 = 0.5;
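/// One step of the classic xorshift32 PRNG (Marsaglia, shifts 13/17/5).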
#[cube]
fn xorshift(state: u32) -> u32 {
let mut x = state;
x ^= x << 13;
x ^= x >> 17;
x ^= x << 5;
x
}
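/// Hashes a (node, entry, seed) triple into a pseudo-random u32; the
/// multiplier is Knuth's 32-bit golden-ratio constant. Used for rho-sampling.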
#[cube]
fn entry_hash(node: u32, entry: u32, seed: u32) -> u32 {
xorshift(node ^ (entry * 2654435769u32) ^ seed)
}
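/// Squared Euclidean distance between rows `a` and `b` of the vectorised
/// tensor. Assumes a line (SIMD) width of 4 scalars.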
#[cube]
fn dist_sq_euclidean<F: Float + CubePrimitive>(
vectors: &Tensor<Line<F>>,
a: u32,
b: u32,
#[comptime] dim_lines: usize,
) -> F {
let off_a = a as usize * dim_lines;
let off_b = b as usize * dim_lines;
let mut sum = F::new(0.0);
for i in 0..dim_lines {
let va = vectors[off_a + i];
let vb = vectors[off_b + i];
let diff = va - vb;
let sq = diff * diff;
sum += sq[0];
sum += sq[1];
sum += sq[2];
sum += sq[3];
}
sum
}
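/// Cosine distance `1 - dot(a, b) / (norm_a * norm_b)` using precomputed
/// norms. Assumes a line width of 4 scalars.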
#[cube]
fn dist_cosine<F: Float>(
vectors: &Tensor<Line<F>>,
norms: &Tensor<F>,
a: u32,
b: u32,
#[comptime] dim_lines: usize,
) -> F {
let off_a = a as usize * dim_lines;
let off_b = b as usize * dim_lines;
let mut dot = F::new(0.0);
for i in 0..dim_lines {
let va = vectors[off_a + i];
let vb = vectors[off_b + i];
let prod = va * vb;
dot += prod[0];
dot += prod[1];
dot += prod[2];
dot += prod[3];
}
F::new(1.0) - dot / (norms[a as usize] * norms[b as usize])
}
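/// Seeds each node's neighbour list with random candidates (self-edges
/// nudged away), kept sorted by distance via insertion sort. Every entry is
/// tagged as "new" in the top bit for the first NN-Descent iteration.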
#[cube(launch_unchecked)]
fn init_random_graph<F: Float>(
vectors: &Tensor<Line<F>>,
norms: &Tensor<F>,
graph_idx: &mut Tensor<u32>,
graph_dist: &mut Tensor<F>,
n: u32,
seed: u32,
#[comptime] use_cosine: bool,
#[comptime] dim_lines: usize,
) {
let node = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X;
if node >= n {
terminate!();
}
let k = graph_idx.shape(1);
let is_new_bit = 1u32 << 31;
let base = node as usize * k;
let mut rng = xorshift(node ^ seed ^ 0xDEADBEEFu32);
for slot in 0..k {
rng = xorshift(rng);
let mut pid = rng % n;
if pid == node {
pid = (pid + 1u32) % n;
}
let dist = if use_cosine {
dist_cosine(vectors, norms, node, pid, dim_lines)
} else {
dist_sq_euclidean(vectors, node, pid, dim_lines)
};
let mut insert_pos = slot;
for j in 0..slot {
if dist < graph_dist[base + j] && insert_pos == slot {
insert_pos = j;
}
}
for j in 0..slot {
let src = slot - 1 - j;
let dst = slot - j;
if src >= insert_pos {
graph_idx[base + dst] = graph_idx[base + src];
graph_dist[base + dst] = graph_dist[base + src];
}
}
graph_idx[base + insert_pos] = pid | is_new_bit;
graph_dist[base + insert_pos] = dist;
}
}
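/// Zeroes the per-node proposal counters and (from the first thread) the
/// global update counter before each NN-Descent iteration.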
#[cube(launch_unchecked)]
pub fn reset_proposals(prop_count: &mut Tensor<u32>, update_counter: &mut Tensor<u32>, n: u32) {
let idx = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X;
if idx < n {
prop_count[idx as usize] = 0u32;
}
if idx == 0u32 {
update_counter[0usize] = 0u32;
}
}
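/// Scatters every directed edge `node -> target` into `target`'s reverse
/// candidate list (capped at `build_k` entries), preserving the "new" flag.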
#[cube(launch_unchecked)]
pub fn build_reverse_candidates(
graph_idx: &Tensor<u32>,
reverse_idx: &mut Tensor<u32>,
reverse_count: &Tensor<Atomic<u32>>,
n: u32,
#[comptime] build_k: u32,
) {
let node = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X;
if node >= n {
terminate!();
}
let pid_mask = 0x7FFFFFFFu32;
let base = node as usize * build_k as usize;
let mut i = 0usize;
while i < build_k as usize {
let target_raw = graph_idx[base + i];
let target = target_raw & pid_mask;
if target < n && target != node {
let pos = reverse_count[target as usize].fetch_add(1u32);
if pos < build_k {
let rev_base = target as usize * build_k as usize;
let is_new_bit = target_raw & (1u32 << 31);
reverse_idx[rev_base + pos as usize] = node | is_new_bit;
}
}
i += 1usize;
}
}
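/// NN-Descent local join, one cube per node: stages the node's forward and
/// reverse neighbours in shared memory, rho-samples them, loads their vectors,
/// then computes all pairwise distances among the sampled candidates and
/// pushes improving pairs into both endpoints' proposal buffers.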
#[cube(launch_unchecked)]
pub fn local_join_shared<F: Float>(
vectors: &Tensor<Line<F>>,
norms: &Tensor<F>,
graph_idx: &Tensor<u32>,
graph_dist: &Tensor<F>,
reverse_idx: &Tensor<u32>,
reverse_count: &Tensor<u32>,
prop_idx: &mut Tensor<u32>,
prop_dist: &mut Tensor<F>,
prop_count: &Tensor<Atomic<u32>>,
n: u32,
rho_thresh: u32,
iter_seed: u32,
#[comptime] max_proposals: u32,
#[comptime] use_cosine: bool,
#[comptime] dim_lines: usize,
#[comptime] build_k: usize,
) {
let node = CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X;
if node >= n {
terminate!();
}
let tx = UNIT_POS_X;
let k = graph_idx.shape(1usize) as u32;
let pid_mask = 0x7FFFFFFFu32;
let is_new_bit = 1u32 << 31;
let max_cands_comp = build_k * 2usize;
let dim_scalars = dim_lines * 4usize;
let mut shared_vecs = SharedMemory::<F>::new(max_cands_comp * dim_scalars);
let mut shared_pids = SharedMemory::<u32>::new(max_cands_comp);
let mut shared_is_new = SharedMemory::<u32>::new(max_cands_comp);
let mut shared_norms = SharedMemory::<F>::new(max_cands_comp);
let mut shared_compact = SharedMemory::<u32>::new(2usize);
let mut shared_rev_count = SharedMemory::<u32>::new(1usize);
if tx == 0u32 {
let rc = reverse_count[node as usize];
shared_rev_count[0usize] = if rc > k { k } else { rc };
}
sync_cube();
let rev_k = shared_rev_count[0usize];
let raw_total = k + rev_k;
let mut i_load = tx;
while i_load < raw_total {
let entry = if i_load < k {
graph_idx[(node * k + i_load) as usize]
} else {
reverse_idx[(node * k + i_load - k) as usize]
};
shared_pids[i_load as usize] = entry & pid_mask;
shared_is_new[i_load as usize] = if entry >= is_new_bit {
#[allow(clippy::useless_conversion)]
1u32.into()
} else {
#[allow(clippy::useless_conversion)]
0u32.into()
};
i_load += WORKGROUP_SIZE_X;
}
sync_cube();
if tx == 0u32 {
let mut write = 0u32;
let mut has_new = 0u32;
let mut read = 0u32;
while read < raw_total {
let hash = entry_hash(node, read, iter_seed);
if (hash & 0xFFFFu32) < rho_thresh {
shared_pids[write as usize] = shared_pids[read as usize];
shared_is_new[write as usize] = shared_is_new[read as usize];
if shared_is_new[read as usize] != 0u32 {
has_new = 1u32;
}
write += 1u32;
}
read += 1u32;
}
shared_compact[0usize] = write;
shared_compact[1usize] = has_new;
}
sync_cube();
let total_cands = shared_compact[0usize];
let has_new = shared_compact[1usize];
if total_cands < 2u32 || has_new == 0u32 {
terminate!();
}
if use_cosine {
let mut i_norm = tx;
while i_norm < total_cands {
shared_norms[i_norm as usize] = norms[shared_pids[i_norm as usize] as usize];
i_norm += WORKGROUP_SIZE_X;
}
sync_cube();
}
let total_scalars = total_cands as usize * dim_scalars;
let mut idx_load = tx as usize;
while idx_load < total_scalars {
let n_idx = idx_load / dim_scalars;
let s_idx = idx_load % dim_scalars;
let line_idx = s_idx / 4usize;
let lane = s_idx % 4usize;
let pid = shared_pids[n_idx];
if pid < n {
let vec_offset = pid as usize * dim_lines + line_idx;
let line_val = vectors[vec_offset];
shared_vecs[idx_load] = line_val[lane];
}
idx_load += WORKGROUP_SIZE_X as usize;
}
sync_cube();
let num_pairs = (total_cands * (total_cands - 1u32)) / 2u32;
let mut pair_idx = tx as usize;
while pair_idx < num_pairs as usize {
let mut rem = pair_idx;
let mut i = 0usize;
let mut step = total_cands as usize - 1usize;
while rem >= step {
rem -= step;
i += 1usize;
step = total_cands as usize - 1usize - i;
}
let j = i + 1usize + rem;
let is_new_i = shared_is_new[i] != 0u32;
let is_new_j = shared_is_new[j] != 0u32;
let pid_i = shared_pids[i];
let pid_j = shared_pids[j];
if (is_new_i || is_new_j) && pid_i != pid_j {
let mut sum = F::new(0.0);
let mut s = 0usize;
while s < dim_scalars {
let va = shared_vecs[i * dim_scalars + s];
let vb = shared_vecs[j * dim_scalars + s];
if use_cosine {
sum += va * vb;
} else {
let diff = va - vb;
sum += diff * diff;
}
s += 1usize;
}
let dist = if use_cosine {
F::new(1.0) - (sum / (shared_norms[i] * shared_norms[j]))
} else {
sum
};
let thresh_i = graph_dist[pid_i as usize * k as usize + k as usize - 1usize];
let thresh_j = graph_dist[pid_j as usize * k as usize + k as usize - 1usize];
if dist < thresh_i {
let slot_i = prop_count[pid_i as usize].fetch_add(1u32);
if slot_i < max_proposals {
let off = pid_i as usize * max_proposals as usize + slot_i as usize;
prop_idx[off] = pid_j;
prop_dist[off] = dist;
}
}
if dist < thresh_j {
let slot_j = prop_count[pid_j as usize].fetch_add(1u32);
if slot_j < max_proposals {
let off = pid_j as usize * max_proposals as usize + slot_j as usize;
prop_idx[off] = pid_i;
prop_dist[off] = dist;
}
}
}
pair_idx += WORKGROUP_SIZE_X as usize;
}
}
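/// Merges a node's proposal buffer into its sorted neighbour list: clears the
/// old "new" flags, rejects duplicates and self-edges, insertion-sorts each
/// accepted candidate (tagged "new"), and adds the number of improvements to
/// the global update counter used for convergence testing.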
#[cube(launch_unchecked)]
pub fn merge_proposals<F: Float>(
graph_idx: &mut Tensor<u32>,
graph_dist: &mut Tensor<F>,
prop_idx: &Tensor<u32>,
prop_dist: &Tensor<F>,
prop_count: &Tensor<u32>,
update_counter: &Tensor<Atomic<u32>>,
n: u32,
#[comptime] max_proposals: u32,
) {
let node = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X;
if node >= n {
terminate!();
}
let k = graph_idx.shape(1);
let pid_mask = 0x7FFFFFFFu32;
let is_new_bit = 1u32 << 31;
let base = node as usize * k;
for j in 0..k {
graph_idx[base + j] = graph_idx[base + j] & pid_mask;
}
let raw_count = prop_count[node as usize];
let prop_base = node as usize * max_proposals as usize;
let mut improvements = 0u32;
for p in 0..max_proposals {
if p < raw_count {
let candidate = prop_idx[prop_base + p as usize];
let dist = prop_dist[prop_base + p as usize];
if dist < graph_dist[base + k - 1] {
let mut exists: u32 = 0u32;
for j in 0..k {
if (graph_idx[base + j] & pid_mask) == candidate {
exists = 1u32;
}
}
if exists == 0u32 && candidate != node {
let mut insert_pos = k - 1;
for j in 0..k {
if dist < graph_dist[base + j] && insert_pos == k - 1 {
insert_pos = j;
}
}
for j in 0..k - 1 {
let src = k - 2 - j;
let dst = k - 1 - j;
if src >= insert_pos {
graph_idx[base + dst] = graph_idx[base + src];
graph_dist[base + dst] = graph_dist[base + src];
}
}
graph_idx[base + insert_pos] = candidate | is_new_bit;
graph_dist[base + insert_pos] = dist;
improvements += 1u32;
}
}
}
}
if improvements > 0u32 {
update_counter[0usize].fetch_add(improvements);
}
}
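/// Proposes two-hop candidates (neighbours of neighbours), one cube per node.
/// Candidates closer than the node's current worst neighbour are appended to
/// the proposal buffer; once it is full, reservoir sampling decides which
/// proposals are kept.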
#[cube(launch_unchecked)]
pub fn two_hop_refinement<F: Float>(
vectors: &Tensor<Line<F>>,
norms: &Tensor<F>,
graph_idx: &Tensor<u32>,
graph_dist: &Tensor<F>,
prop_idx: &mut Tensor<u32>,
prop_dist: &mut Tensor<F>,
prop_count: &Tensor<Atomic<u32>>,
n: u32,
#[comptime] max_proposals: u32,
#[comptime] use_cosine: bool,
#[comptime] dim_lines: usize,
) {
let node = CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X;
if node >= n {
terminate!();
}
let tx = UNIT_POS_X;
let k = graph_idx.shape(1usize);
let pid_mask = 0x7FFFFFFFu32;
let graph_base = node as usize * k;
let dim_scalars = dim_lines * 4usize;
let mut shared_source = SharedMemory::<F>::new(dim_scalars);
let mut shared_worst_dist = SharedMemory::<F>::new(1usize);
let mut idx_load = tx as usize;
while idx_load < dim_scalars {
let line_idx = idx_load / 4usize;
let lane = idx_load % 4usize;
let vec_offset = node as usize * dim_lines + line_idx;
let line_val = vectors[vec_offset];
shared_source[idx_load] = line_val[lane];
idx_load += WORKGROUP_SIZE_X as usize;
}
if tx == 0u32 {
shared_worst_dist[0usize] = graph_dist[graph_base + k - 1usize];
}
sync_cube();
let worst_dist = shared_worst_dist[0usize];
let num_candidates = k * k;
let mut cand_idx = tx as usize;
while cand_idx < num_candidates {
let n1_idx = cand_idx / k;
let n2_idx = cand_idx % k;
let n1_raw = graph_idx[graph_base + n1_idx];
let n1_pid = n1_raw & pid_mask;
if n1_pid < n {
let n2_raw = graph_idx[n1_pid as usize * k + n2_idx];
let cand_pid = n2_raw & pid_mask;
if cand_pid < n && cand_pid != node {
let mut is_dup: bool = false;
let mut scan_idx = 0usize;
while scan_idx < k {
if (graph_idx[graph_base + scan_idx] & pid_mask) == cand_pid {
is_dup = true;
}
scan_idx += 1usize;
}
if !is_dup {
let mut sum = F::new(0.0);
let mut s = 0usize;
while s < dim_scalars {
let va = shared_source[s];
let line_idx = s / 4usize;
let lane = s % 4usize;
let line_val = vectors[cand_pid as usize * dim_lines + line_idx];
let vb = line_val[lane];
if use_cosine {
sum += va * vb;
} else {
let diff = va - vb;
sum += diff * diff;
}
s += 1usize;
}
let dist = if use_cosine {
F::new(1.0) - (sum / (norms[node as usize] * norms[cand_pid as usize]))
} else {
sum
};
if dist < worst_dist {
let slot = prop_count[node as usize].fetch_add(1u32);
if slot < max_proposals {
let off = node as usize * max_proposals as usize + slot as usize;
prop_idx[off] = cand_pid;
prop_dist[off] = dist;
} else {
let rand_val = xorshift(node ^ slot ^ cand_pid) % (slot + 1u32);
if rand_val < max_proposals {
let off =
node as usize * max_proposals as usize + rand_val as usize;
prop_idx[off] = cand_pid;
prop_dist[off] = dist;
}
}
}
}
}
}
cand_idx += WORKGROUP_SIZE_X as usize;
}
}
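/// CAGRA-style rank pruning, one cube per node: for each neighbour `y`, counts
/// detourable edges (a closer-ranked neighbour `z` that also links to `y`),
/// then a selection sort over packed `(detours << 16) | rank` keys keeps the
/// `d` entries with the fewest detours.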
#[cube(launch_unchecked)]
pub fn cagra_rank_prune_shared(
graph_idx: &Tensor<u32>,
pruned_idx: &mut Tensor<u32>,
n: u32,
#[comptime] k: usize,
#[comptime] d: usize,
) {
let node = CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X;
if node >= n {
terminate!();
}
let tx = UNIT_POS_X;
let k_u32 = k as u32;
let d_u32 = d as u32;
let graph_base = node * k_u32;
let mut shared_neighbors = SharedMemory::<u32>::new(k);
let mut shared_detours = SharedMemory::<u32>::new(k);
let mut i = tx;
while i < k_u32 {
shared_neighbors[i as usize] = graph_idx[(graph_base + i) as usize] & 0x7FFFFFFFu32;
shared_detours[i as usize] = 0u32;
i += WORKGROUP_SIZE_X;
}
sync_cube();
i = tx;
while i < k_u32 {
let y = shared_neighbors[i as usize];
let mut detours = 0u32;
let mut j = 0u32;
while j < i {
let z = shared_neighbors[j as usize];
let z_base = z * k_u32;
let mut found: bool = false;
let mut m = 0u32;
while m < i {
let z_neighbor = graph_idx[(z_base + m) as usize] & 0x7FFFFFFFu32;
if z_neighbor == y {
found = true;
}
m += 1u32;
}
if found {
detours += 1u32;
}
j += 1u32;
}
shared_detours[i as usize] = (detours << 16) | i;
i += WORKGROUP_SIZE_X;
}
sync_cube();
if tx == 0u32 {
let mut step = 0u32;
while step < d_u32 {
let mut min_val = 0xFFFFFFFFu32;
let mut min_idx = 0u32;
let mut scan = step;
while scan < k_u32 {
let val = shared_detours[scan as usize];
if val < min_val {
min_val = val;
min_idx = scan;
}
scan += 1u32;
}
let temp = shared_detours[step as usize];
shared_detours[step as usize] = shared_detours[min_idx as usize];
shared_detours[min_idx as usize] = temp;
let original_rank = min_val & 0xFFFFu32;
pruned_idx[(node * d_u32 + step) as usize] = shared_neighbors[original_rank as usize];
step += 1u32;
}
}
}
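/// Scatters the pruned forward edges into per-node reverse-edge lists, capped
/// at `d` entries per target.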
#[cube(launch_unchecked)]
pub fn cagra_build_reverse(
pruned_idx: &Tensor<u32>,
reverse_idx: &mut Tensor<u32>,
reverse_counts: &Tensor<Atomic<u32>>,
n: u32,
#[comptime] d: usize,
) {
let node = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X;
if node >= n {
terminate!();
}
let d_u32 = d as u32;
let mut i = 0u32;
while i < d_u32 {
let target = pruned_idx[(node * d_u32 + i) as usize];
if target < n {
let pos = reverse_counts[target as usize].fetch_add(1u32);
if pos < d_u32 {
reverse_idx[(target * d_u32 + pos) as usize] = node;
}
}
i += 1u32;
}
}
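/// Assembles the final degree-`d` graph: up to `d / 2` reverse edges first,
/// then forward pruned edges (deduplicated), padded with the sentinel id.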
#[cube(launch_unchecked)]
pub fn cagra_merge_graphs(
pruned_idx: &Tensor<u32>,
reverse_idx: &Tensor<u32>,
reverse_counts: &Tensor<u32>,
final_idx: &mut Tensor<u32>,
n: u32,
#[comptime] d: usize,
) {
let node = (CUBE_POS_Y * CUBE_COUNT_X + CUBE_POS_X) * WORKGROUP_SIZE_X + UNIT_POS_X;
if node >= n {
terminate!();
}
let d_u32 = d as u32;
let half_d = d_u32 / 2u32;
let rev_count = reverse_counts[node as usize];
let mut take_rev = rev_count;
if take_rev > half_d {
take_rev = half_d;
}
let mut final_count = 0u32;
let mut i = 0u32;
while i < take_rev {
final_idx[(node * d_u32 + final_count) as usize] = reverse_idx[(node * d_u32 + i) as usize];
final_count += 1u32;
i += 1u32;
}
let mut j = 0u32;
while j < d_u32 {
if final_count < d_u32 {
let candidate = pruned_idx[(node * d_u32 + j) as usize];
let mut is_dup: bool = false;
let mut c = 0u32;
while c < final_count {
if final_idx[(node * d_u32 + c) as usize] == candidate {
is_dup = true;
}
c += 1u32;
}
if !is_dup {
final_idx[(node * d_u32 + final_count) as usize] = candidate;
final_count += 1u32;
}
}
j += 1u32;
}
while final_count < d_u32 {
final_idx[(node * d_u32 + final_count) as usize] = 0x7FFFFFFFu32;
final_count += 1u32;
}
}
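/// Zero-pads each row from `dim` to `dim_padded` scalars so that rows are
/// aligned to the GPU line size.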
fn pad_vectors<T: Float>(flat: &[T], n: usize, dim: usize, dim_padded: usize) -> Vec<T> {
let mut padded = vec![T::zero(); n * dim_padded];
for i in 0..n {
let src = &flat[i * dim..(i + 1) * dim];
let dst = &mut padded[i * dim_padded..i * dim_padded + dim];
dst.copy_from_slice(src);
}
padded
}
thread_local! {
static QUERY_VISITED: RefCell<FixedBitSet> = const { RefCell::new(FixedBitSet::new()) };
static QUERY_CANDIDATES_F32: QueryCandF32 =
const { RefCell::new(BinaryHeap::new()) };
static QUERY_CANDIDATES_F64: QueryCandF64 =
const { RefCell::new(BinaryHeap::new()) };
static QUERY_RESULTS_F32: RefCell<BinaryHeap<(OrderedFloat<f32>, usize)>> =
const { RefCell::new(BinaryHeap::new()) };
static QUERY_RESULTS_F64: RefCell<BinaryHeap<(OrderedFloat<f64>, usize)>> =
const { RefCell::new(BinaryHeap::new()) };
}
macro_rules! impl_nndescent_gpu_query {
($float:ty, $cand_tls:ident, $res_tls:ident) => {
impl<R: Runtime> NNDescentQuery<$float> for NNDescentGpu<$float, R> {
fn query_internal(
&self,
query_vec: &[$float],
query_norm: $float,
k: usize,
ef: usize,
) -> (Vec<usize>, Vec<$float>) {
QUERY_VISITED.with(|visited_cell| {
$cand_tls.with(|cand_cell| {
$res_tls.with(|res_cell| {
let mut visited = visited_cell.borrow_mut();
let mut candidates = cand_cell.borrow_mut();
let mut results = res_cell.borrow_mut();
visited.clear();
visited.grow(self.n);
candidates.clear();
results.clear();
match self.metric {
Dist::Euclidean => self.query_euclidean(
query_vec,
k,
ef,
&mut visited,
&mut candidates,
&mut results,
),
Dist::Cosine => self.query_cosine(
query_vec,
query_norm,
k,
ef,
&mut visited,
&mut candidates,
&mut results,
),
}
})
})
})
}
#[inline(always)]
fn query_euclidean(
&self,
query_vec: &[$float],
k: usize,
ef: usize,
visited: &mut FixedBitSet,
candidates: &mut BinaryHeap<Reverse<(OrderedFloat<$float>, usize)>>,
results: &mut BinaryHeap<(OrderedFloat<$float>, usize)>,
) -> (Vec<usize>, Vec<$float>) {
let init_indices = self
.router
.find_entry_points(query_vec, (ef / 2).max(2 * k).min(self.n));
for &entry_idx in &init_indices {
if entry_idx >= self.n || visited.contains(entry_idx) {
continue;
}
visited.insert(entry_idx);
let dist = self.euclidean_distance_to_query(entry_idx, query_vec);
candidates.push(Reverse((OrderedFloat(dist), entry_idx)));
results.push((OrderedFloat(dist), entry_idx));
}
while results.len() > ef {
results.pop();
}
let mut lower_bound = if results.len() >= ef {
results.peek().unwrap().0 .0
} else {
<$float>::MAX
};
while let Some(Reverse((OrderedFloat(curr_dist), curr_idx))) = candidates.pop() {
if curr_dist > lower_bound {
break;
}
for &(nbr_idx, _) in self.graph_neighbours(curr_idx) {
if nbr_idx == SENTINEL_PID || visited.contains(nbr_idx) {
continue;
}
visited.insert(nbr_idx);
let dist = self.euclidean_distance_to_query(nbr_idx, query_vec);
if dist < lower_bound || results.len() < ef {
candidates.push(Reverse((OrderedFloat(dist), nbr_idx)));
if results.len() < ef {
results.push((OrderedFloat(dist), nbr_idx));
if results.len() == ef {
lower_bound = results.peek().unwrap().0 .0;
}
} else if dist < lower_bound {
results.pop();
results.push((OrderedFloat(dist), nbr_idx));
lower_bound = results.peek().unwrap().0 .0;
}
}
}
}
let mut final_results: Vec<_> = results.drain().collect();
final_results.sort_unstable_by(|a, b| a.0.cmp(&b.0));
final_results.truncate(k);
final_results
.into_iter()
.map(|(OrderedFloat(d), i)| (i, d))
.unzip()
}
#[inline(always)]
fn query_cosine(
&self,
query_vec: &[$float],
query_norm: $float,
k: usize,
ef: usize,
visited: &mut FixedBitSet,
candidates: &mut BinaryHeap<Reverse<(OrderedFloat<$float>, usize)>>,
results: &mut BinaryHeap<(OrderedFloat<$float>, usize)>,
) -> (Vec<usize>, Vec<$float>) {
let init_indices = self
.router
.find_entry_points(query_vec, (ef / 2).max(2 * k).min(self.n));
for &entry_idx in &init_indices {
if entry_idx >= self.n || visited.contains(entry_idx) {
continue;
}
visited.insert(entry_idx);
let dist = self.cosine_distance_to_query(entry_idx, query_vec, query_norm);
candidates.push(Reverse((OrderedFloat(dist), entry_idx)));
results.push((OrderedFloat(dist), entry_idx));
}
while results.len() > ef {
results.pop();
}
let mut lower_bound = if results.len() >= ef {
results.peek().unwrap().0 .0
} else {
<$float>::MAX
};
while let Some(Reverse((OrderedFloat(curr_dist), curr_idx))) = candidates.pop() {
if curr_dist > lower_bound {
break;
}
for &(nbr_idx, _) in self.graph_neighbours(curr_idx) {
if nbr_idx == SENTINEL_PID || visited.contains(nbr_idx) {
continue;
}
visited.insert(nbr_idx);
let dist = self.cosine_distance_to_query(nbr_idx, query_vec, query_norm);
if dist < lower_bound || results.len() < ef {
candidates.push(Reverse((OrderedFloat(dist), nbr_idx)));
if results.len() < ef {
results.push((OrderedFloat(dist), nbr_idx));
if results.len() == ef {
lower_bound = results.peek().unwrap().0 .0;
}
} else if dist < lower_bound {
results.pop();
results.push((OrderedFloat(dist), nbr_idx));
lower_bound = results.peek().unwrap().0 .0;
}
}
}
}
let mut final_results: Vec<_> = results.drain().collect();
final_results.sort_unstable_by(|a, b| a.0.cmp(&b.0));
final_results.truncate(k);
final_results
.into_iter()
.map(|(OrderedFloat(d), i)| (i, d))
.unzip()
}
}
};
}
impl_nndescent_gpu_query!(f32, QUERY_CANDIDATES_F32, QUERY_RESULTS_F32);
impl_nndescent_gpu_query!(f64, QUERY_CANDIDATES_F64, QUERY_RESULTS_F64);
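/// GPU-built NN-Descent index with a CAGRA-optimised navigation graph.
/// `knn_graph` holds the raw k-NN lists; `nav_graph` is the pruned graph used
/// for search. The GPU tensors are retained only when requested at build time
/// and are otherwise re-uploaded lazily for batched GPU queries.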
pub struct NNDescentGpu<T: AnnSearchFloat + AnnSearchGpuFloat, R: Runtime> {
pub vectors_flat: Vec<T>,
pub dim: usize,
pub n: usize,
pub k: usize,
pub norms: Vec<T>,
pub metric: Dist,
pub medoid: u32,
knn_graph: Vec<(usize, T)>,
nav_graph: Vec<(usize, T)>,
converged: bool,
router: ForestRouter<T>,
_device: R::Device,
dim_padded: usize,
nav_graph_gpu: Option<GpuTensor<R, u32>>,
vectors_gpu: Option<GpuTensor<R, T>>,
norms_gpu: Option<GpuTensor<R, T>>,
}
impl<T, R> VectorDistance<T> for NNDescentGpu<T, R>
where
T: AnnSearchFloat + AnnSearchGpuFloat,
R: Runtime,
{
fn vectors_flat(&self) -> &[T] {
&self.vectors_flat
}
fn dim(&self) -> usize {
self.dim
}
fn norms(&self) -> &[T] {
&self.norms
}
}
impl<T, R> NNDescentGpu<T, R>
where
R: Runtime,
T: AnnSearchFloat + cubecl::frontend::Float + cubecl::CubeElement,
Self: NNDescentQuery<T>,
{
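/// Builds the index on the GPU.
///
/// Pipeline: random graph initialisation, NN-Descent local-join iterations
/// (until the update rate drops below `delta` or `max_iters` is reached),
/// optional 2-hop refinement sweeps (`refine_knn`), then CAGRA-style rank
/// pruning and reverse-edge merging into the final navigation graph.
///
/// A minimal usage sketch, mirroring the unit tests below; the wgpu backend
/// types and a `Dist` import are assumed, and the block is not compiled as a
/// doctest:
///
/// ```ignore
/// use cubecl::wgpu::{WgpuDevice, WgpuRuntime};
/// use faer::Mat;
///
/// let data = Mat::from_fn(1_000, 64, |i, j| ((i * 7 + j) as f32).sin());
/// let index = NNDescentGpu::<f32, WgpuRuntime>::build(
///     data.as_ref(),
///     Dist::Euclidean,
///     Some(30), // k
///     None,     // build_k (defaults to 1.5 * k)
///     None,     // max_iters
///     None,     // n_trees
///     None,     // delta
///     None,     // rho
///     None,     // refine_knn
///     42,       // seed
///     false,    // verbose
///     false,    // retain_gpu
///     WgpuDevice::DefaultDevice,
/// );
/// let (ids, dists) = index.query(&vec![0.0f32; 64], 10, None);
/// ```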
#[allow(clippy::too_many_arguments)]
pub fn build(
data: MatRef<T>,
metric: Dist,
k: Option<usize>,
build_k: Option<usize>,
max_iters: Option<usize>,
n_trees: Option<usize>,
delta: Option<f32>,
rho: Option<f32>,
refine_knn: Option<usize>,
seed: usize,
verbose: bool,
retain_gpu: bool,
device: R::Device,
) -> Self {
let (vectors_flat, n, dim) = matrix_to_flat(data);
let k = k.unwrap_or(30);
let build_k = build_k.unwrap_or((1.5 * k as f32) as usize).max(k);
let max_iters = max_iters.unwrap_or(DEFAULT_MAX_ITERS);
let delta = delta.unwrap_or(DEFAULT_DELTA);
let rho = rho.unwrap_or(DEFAULT_RHO);
let rho_thresh = (rho * 65535.0) as u32;
let refine_knn = refine_knn.unwrap_or(0);
let medoid = compute_medoid(&vectors_flat, n, dim, metric);
let line = LINE_SIZE as usize;
let dim_padded = dim.next_multiple_of(line);
let dim_vec = dim_padded / line;
let vectors_padded = if dim_padded != dim {
pad_vectors(&vectors_flat, n, dim, dim_padded)
} else {
vectors_flat.clone()
};
let norms = if metric == Dist::Cosine {
(0..n)
.into_par_iter()
.map(|i| T::calculate_l2_norm(&vectors_flat[i * dim..(i + 1) * dim]))
.collect()
} else {
Vec::new()
};
if verbose {
println!(
"NNDescent-GPU: {} vectors, dim={} (padded to {}), k={}, build_k={}",
n.separate_with_underscores(),
dim,
dim_padded,
k,
build_k,
);
}
let start = Instant::now();
let n_trees_forest = n_trees.unwrap_or_else(|| {
let calculated = 5 + ((n as f64).powf(0.25)).round() as usize;
calculated.min(20)
});
let client = R::client(&device);
let use_cosine = metric == Dist::Cosine;
let vectors_gpu =
GpuTensor::<R, T>::from_slice(&vectors_padded, vec![n, dim_padded], &client);
let norms_gpu = if use_cosine {
GpuTensor::<R, T>::from_slice(&norms, vec![n], &client)
} else {
GpuTensor::<R, T>::from_slice(&[T::zero()], vec![1], &client)
};
let graph_idx_gpu = GpuTensor::<R, u32>::from_slice(
&vec![0x7FFFFFFFu32; n * build_k],
vec![n, build_k],
&client,
);
let graph_dist_gpu = GpuTensor::<R, T>::from_slice(
&vec![<T as num_traits::Float>::max_value(); n * build_k],
vec![n, build_k],
&client,
);
let max_prop = MAX_PROPOSALS;
let prop_idx_gpu = GpuTensor::<R, u32>::empty(vec![n, max_prop], &client);
let prop_dist_gpu = GpuTensor::<R, T>::empty(vec![n, max_prop], &client);
let prop_count_gpu = GpuTensor::<R, u32>::empty(vec![n], &client);
let update_counter_gpu = GpuTensor::<R, u32>::empty(vec![1], &client);
let (grid_n_x, grid_n_y) = grid_2d((n as u32).div_ceil(WORKGROUP_SIZE_X));
if verbose {
println!(" Random graph initialisation...");
}
unsafe {
let _ = init_random_graph::launch_unchecked::<T, R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.clone().into_tensor_arg(line),
norms_gpu.clone().into_tensor_arg(1),
graph_idx_gpu.clone().into_tensor_arg(1),
graph_dist_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
ScalarArg { elem: seed as u32 },
use_cosine,
dim_vec,
);
}
let router = gpu_forest_init(
&vectors_gpu,
&norms_gpu,
&graph_idx_gpu,
&graph_dist_gpu,
&prop_idx_gpu,
&prop_dist_gpu,
&prop_count_gpu,
&update_counter_gpu,
n,
dim,
dim_padded,
n_trees_forest,
seed,
use_cosine,
verbose,
&client,
);
let total_entries = (n * build_k) as u32;
let mark_grid_flat = total_entries.div_ceil(WORKGROUP_SIZE_X);
let mark_cubes_x = mark_grid_flat.min(65535);
let mark_cubes_y = mark_grid_flat.div_ceil(mark_cubes_x);
unsafe {
let _ = mark_all_new::launch_unchecked::<R>(
&client,
CubeCount::Static(mark_cubes_x, mark_cubes_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu.clone().into_tensor_arg(1),
ScalarArg {
elem: total_entries,
},
);
}
let iter_start = Instant::now();
let mut converged = false;
let reverse_idx_gpu = GpuTensor::<R, u32>::empty(vec![n, build_k], &client);
let reverse_count_gpu = GpuTensor::<R, u32>::empty(vec![n], &client);
for iter in 0..max_iters {
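// 2D launch for the per-node local-join kernel: the X grid dimension is
// capped at 65_535 cubes, so the remainder spills into Y; the kernel
// terminates early for node >= n.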
let cubes_x = 65535u32;
let cubes_y = (n as u32).div_ceil(cubes_x);
unsafe {
let _ = reset_proposals::launch_unchecked::<R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
prop_count_gpu.clone().into_tensor_arg(1),
update_counter_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
);
let _ = reset_proposals::launch_unchecked::<R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
reverse_count_gpu.clone().into_tensor_arg(1),
update_counter_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
);
}
unsafe {
let _ = build_reverse_candidates::launch_unchecked::<R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu.clone().into_tensor_arg(1),
reverse_idx_gpu.clone().into_tensor_arg(1),
reverse_count_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
build_k as u32,
);
}
let iter_seed = seed as u32 ^ (iter as u32).wrapping_mul(0x9E3779B9u32);
unsafe {
let _ = local_join_shared::launch_unchecked::<T, R>(
&client,
CubeCount::Static(cubes_x, cubes_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.clone().into_tensor_arg(line),
norms_gpu.clone().into_tensor_arg(1),
graph_idx_gpu.clone().into_tensor_arg(1),
graph_dist_gpu.clone().into_tensor_arg(1),
reverse_idx_gpu.clone().into_tensor_arg(1),
reverse_count_gpu.clone().into_tensor_arg(1),
prop_idx_gpu.clone().into_tensor_arg(1),
prop_dist_gpu.clone().into_tensor_arg(1),
prop_count_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
ScalarArg { elem: rho_thresh },
ScalarArg { elem: iter_seed },
MAX_PROPOSALS as u32,
use_cosine,
dim_vec,
build_k,
);
}
unsafe {
let _ = merge_proposals::launch_unchecked::<T, R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu.clone().into_tensor_arg(1),
graph_dist_gpu.clone().into_tensor_arg(1),
prop_idx_gpu.clone().into_tensor_arg(1),
prop_dist_gpu.clone().into_tensor_arg(1),
prop_count_gpu.clone().into_tensor_arg(1),
update_counter_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
MAX_PROPOSALS as u32,
);
}
let counter_data = update_counter_gpu.clone().read(&client);
let updates = counter_data[0] as f64;
let rate = updates / (n * build_k) as f64;
if verbose {
println!(
" Iter {}: {} updates (rate={:.6})",
iter + 1,
(updates as usize).separate_with_underscores(),
rate
);
}
if rate < delta as f64 {
if verbose {
println!(" Converged after {} iterations", iter + 1);
}
converged = true;
break;
}
}
if verbose {
println!(" NNDescent iterations: {:.2?}", iter_start.elapsed());
}
if verbose && refine_knn > 0 {
println!(" Running 2-Hop Refinement Sweep...");
}
let refinement_start = Instant::now();
let cubes_x = 65535u32;
let cubes_y = (n as u32).div_ceil(cubes_x);
for sweep in 0..refine_knn {
unsafe {
let _ = reset_proposals::launch_unchecked::<R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
prop_count_gpu.clone().into_tensor_arg(1),
update_counter_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
);
}
unsafe {
let _ = two_hop_refinement::launch_unchecked::<T, R>(
&client,
CubeCount::Static(cubes_x, cubes_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.clone().into_tensor_arg(line),
norms_gpu.clone().into_tensor_arg(1),
graph_idx_gpu.clone().into_tensor_arg(1),
graph_dist_gpu.clone().into_tensor_arg(1),
prop_idx_gpu.clone().into_tensor_arg(1),
prop_dist_gpu.clone().into_tensor_arg(1),
prop_count_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
MAX_PROPOSALS as u32,
use_cosine,
dim_vec,
);
}
unsafe {
let _ = merge_proposals::launch_unchecked::<T, R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu.clone().into_tensor_arg(1),
graph_dist_gpu.clone().into_tensor_arg(1),
prop_idx_gpu.clone().into_tensor_arg(1),
prop_dist_gpu.clone().into_tensor_arg(1),
prop_count_gpu.clone().into_tensor_arg(1),
update_counter_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
MAX_PROPOSALS as u32,
);
}
if verbose {
let counter_data = update_counter_gpu.clone().read(&client);
println!(" 2-Hop sweep {}: {} updates", sweep + 1, counter_data[0]);
}
}
if verbose && refine_knn > 0 {
println!(" NNDescent refinement done in: {:.2?}", refinement_start.elapsed());
}
let nndescent_idx = graph_idx_gpu.clone().read(&client);
let nndescent_dist = graph_dist_gpu.clone().read(&client);
let pid_mask = 0x7FFFFFFFu32;
let sentinel = 0x7FFFFFFFusize;
let mut knn_graph = vec![(sentinel, <T as num_traits::Float>::max_value()); n * k];
knn_graph
.par_chunks_mut(k)
.enumerate()
.for_each(|(i, slot)| {
let mut written = 0;
for j in 0..build_k {
if written >= k {
break;
}
let raw = nndescent_idx[i * build_k + j];
let pid = (raw & pid_mask) as usize;
if pid < n && pid != i && pid != sentinel {
let dist = nndescent_dist[i * build_k + j];
slot[written] = (pid, dist);
written += 1;
}
}
});
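// CAGRA-style optimisation: rank-prune to degree k, build reverse edges,
// then merge forward and reverse edges into the final navigation graph.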
let cagra_start = Instant::now();
let pruned_idx_gpu = GpuTensor::<R, u32>::empty(vec![n, k], &client);
let reverse_idx_gpu = GpuTensor::<R, u32>::empty(vec![n, k], &client);
let reverse_counts_gpu = GpuTensor::<R, u32>::from_slice(&vec![0u32; n], vec![n], &client);
let final_idx_gpu = GpuTensor::<R, u32>::empty(vec![n, k], &client);
let cubes_x = 65535u32;
let cubes_y = (n as u32).div_ceil(cubes_x);
unsafe {
let _ = cagra_rank_prune_shared::launch_unchecked::<R>(
&client,
CubeCount::Static(cubes_x, cubes_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu.into_tensor_arg(1),
pruned_idx_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
build_k,
k,
);
let _ = cagra_build_reverse::launch_unchecked::<R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
pruned_idx_gpu.clone().into_tensor_arg(1),
reverse_idx_gpu.clone().into_tensor_arg(1),
reverse_counts_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
k,
);
let _ = cagra_merge_graphs::launch_unchecked::<R>(
&client,
CubeCount::Static(grid_n_x, grid_n_y, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
pruned_idx_gpu.into_tensor_arg(1),
reverse_idx_gpu.into_tensor_arg(1),
reverse_counts_gpu.into_tensor_arg(1),
final_idx_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
k,
);
}
if verbose {
println!(" CAGRA optimisation: {:.2?}", cagra_start.elapsed());
}
let final_idx = final_idx_gpu.clone().read(&client);
let mut cagra_graph = vec![(sentinel, <T as num_traits::Float>::max_value()); n * k];
cagra_graph
.par_chunks_mut(k)
.enumerate()
.for_each(|(i, slot)| {
for j in 0..k {
let raw = final_idx[i * k + j];
let pid = (raw & pid_mask) as usize;
if pid < n && pid != sentinel {
let a = &vectors_flat[i * dim..(i + 1) * dim];
let b = &vectors_flat[pid * dim..(pid + 1) * dim];
let dist = match metric {
Dist::Euclidean => T::euclidean_simd(a, b),
Dist::Cosine => {
let dot = T::dot_simd(a, b);
T::one() - dot / (norms[i] * norms[pid])
}
};
slot[j] = (pid, dist);
}
}
slot.sort_unstable_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
});
});
if verbose {
println!(" Total build time: {:.2?}", start.elapsed());
}
let (nav_graph_gpu, vectors_gpu, norms_gpu) = if retain_gpu {
(Some(final_idx_gpu), Some(vectors_gpu), Some(norms_gpu))
} else {
(None, None, None)
};
Self {
vectors_flat,
dim,
dim_padded,
n,
k,
medoid,
norms,
metric,
router,
knn_graph,
nav_graph: cagra_graph,
converged,
nav_graph_gpu,
vectors_gpu,
norms_gpu,
_device: device,
}
}
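/// Single-query CPU search over the navigation graph. `ef_search` defaults to
/// `(2 * k).clamp(50, 200)` and is never smaller than `k`.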
pub fn query(
&self,
query_vec: &[T],
k: usize,
ef_search: Option<usize>,
) -> (Vec<usize>, Vec<T>) {
let k = k.min(self.n);
let ef = ef_search.unwrap_or_else(|| (k * 2).clamp(50, 200)).max(k);
let query_norm = if self.metric == Dist::Cosine {
num_traits::Float::sqrt(query_vec.iter().map(|x| *x * *x).sum::<T>())
} else {
T::one()
};
self.query_internal(query_vec, query_norm, k, ef)
}
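/// Like [`Self::query`], but takes a faer row reference; zero-copy when the
/// row is contiguous, otherwise the row is gathered into a temporary `Vec`.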
#[inline]
pub fn query_row(
&self,
query_row: RowRef<T>,
k: usize,
ef_search: Option<usize>,
) -> (Vec<usize>, Vec<T>) {
if query_row.col_stride() == 1 {
let slice =
unsafe { std::slice::from_raw_parts(query_row.as_ptr(), query_row.ncols()) };
return self.query(slice, k, ef_search);
}
let query_vec: Vec<T> = query_row.iter().cloned().collect();
self.query(&query_vec, k, ef_search)
}
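/// Batched GPU search. Entry points are chosen per query on the CPU (forest
/// router candidates sorted by true distance, with the medoid always first),
/// then the CAGRA search kernel runs on the GPU. Uploads tensors lazily if
/// they were not retained at build time.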
pub fn query_batch_gpu(
&mut self,
queries_flat: &[T],
n_queries: usize,
query_params: Option<CagraGpuSearchParams>,
k: usize,
seed: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<T>>)
where
T: AnnSearchGpuFloat + num_traits::Float,
{
assert_eq!(
queries_flat.len(),
n_queries * self.dim,
"queries_flat length must be n_queries * dim"
);
let query_params =
query_params.unwrap_or_else(|| CagraGpuSearchParams::from_graph(k, self.k));
let n_entry = query_params.get_n_entry();
self.ensure_gpu_tensors();
let client = R::client(&self._device);
let use_cosine = self.metric == Dist::Cosine;
let medoid = self.medoid;
let entry_flat: Vec<u32> = (0..n_queries)
.into_par_iter()
.flat_map_iter(|i| {
let query = &queries_flat[i * self.dim..(i + 1) * self.dim];
let mut candidates = self.router.find_entry_points(query, n_entry * 4);
candidates.sort_unstable_by(|&a, &b| {
let dist_a = match self.metric {
Dist::Euclidean => {
let va = &self.vectors_flat[a * self.dim..(a + 1) * self.dim];
T::euclidean_simd(query, va)
}
Dist::Cosine => {
let va = &self.vectors_flat[a * self.dim..(a + 1) * self.dim];
let dot = T::dot_simd(query, va);
let q_norm = T::calculate_l2_norm(query);
T::one() - dot / (q_norm * self.norms[a])
}
};
let dist_b = match self.metric {
Dist::Euclidean => {
let vb = &self.vectors_flat[b * self.dim..(b + 1) * self.dim];
T::euclidean_simd(query, vb)
}
Dist::Cosine => {
let vb = &self.vectors_flat[b * self.dim..(b + 1) * self.dim];
let dot = T::dot_simd(query, vb);
let q_norm = T::calculate_l2_norm(query);
T::one() - dot / (q_norm * self.norms[b])
}
};
dist_a
.partial_cmp(&dist_b)
.unwrap_or(std::cmp::Ordering::Equal)
});
candidates.retain(|&e| e != medoid as usize);
candidates.truncate(n_entry - 1);
let mut final_entries = Vec::with_capacity(n_entry);
final_entries.push(medoid);
final_entries.extend(candidates.into_iter().map(|idx| idx as u32));
final_entries.resize(n_entry, 0);
final_entries.into_iter()
})
.collect();
cagra_search_batch_gpu(
queries_flat,
n_queries,
self.dim,
self.vectors_gpu.as_ref().unwrap(),
self.norms_gpu.as_ref().unwrap(),
self.nav_graph_gpu.as_ref().unwrap(),
self.n,
self.k,
k,
use_cosine,
seed,
&query_params,
Some(&entry_flat),
&client,
)
}
fn graph_neighbours(&self, idx: usize) -> &[(usize, T)] {
&self.nav_graph[idx * self.k..(idx + 1) * self.k]
}
pub fn converged(&self) -> bool {
self.converged
}
pub fn memory_usage_bytes(&self) -> usize {
std::mem::size_of_val(self)
+ self.vectors_flat.capacity() * std::mem::size_of::<T>()
+ self.norms.capacity() * std::mem::size_of::<T>()
+ self.nav_graph.capacity() * std::mem::size_of::<(usize, T)>()
+ self.knn_graph.capacity() * std::mem::size_of::<(usize, T)>()
}
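/// Returns the raw NN-Descent k-NN lists (and optionally their distances),
/// skipping sentinel slots.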
pub fn extract_knn(&self, return_dist: bool) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>) {
let sentinel = 0x7FFFFFFFusize;
let indices: Vec<Vec<usize>> = (0..self.n)
.map(|i| {
self.knn_graph[i * self.k..(i + 1) * self.k]
.iter()
.filter(|&&(pid, _)| pid != sentinel)
.map(|&(pid, _)| pid)
.collect()
})
.collect();
let distances = if return_dist {
Some(
(0..self.n)
.map(|i| {
self.knn_graph[i * self.k..(i + 1) * self.k]
.iter()
.filter(|&&(pid, _)| pid != sentinel)
.map(|&(_, dist)| dist)
.collect()
})
.collect(),
)
} else {
None
};
(indices, distances)
}
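/// Batched GPU search with every database vector as a query; entry points are
/// seeded from each vector's own k-NN row, padded with pseudo-random ids.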
pub fn self_query_gpu(
&mut self,
k: usize,
query_params: Option<CagraGpuSearchParams>,
seed: usize,
) -> (Vec<Vec<usize>>, Vec<Vec<T>>)
where
T: AnnSearchGpuFloat + AnnSearchFloat,
{
self.ensure_gpu_tensors();
let query_params =
query_params.unwrap_or_else(|| CagraGpuSearchParams::from_graph(k, self.k));
let n_entry = query_params.get_n_entry();
let client = R::client(&self._device);
let use_cosine = self.metric == Dist::Cosine;
let entry_flat: Vec<u32> = (0..self.n)
.flat_map(|i| {
let row = &self.knn_graph[i * self.k..(i + 1) * self.k];
let mut entries: Vec<u32> = row
.iter()
.filter(|&&(pid, _)| pid != SENTINEL_PID)
.take(n_entry)
.map(|&(pid, _)| pid as u32)
.collect();
let mut rng_val = (i as u32) ^ (seed as u32);
while entries.len() < n_entry {
rng_val = rng_val.wrapping_mul(1664525).wrapping_add(1013904223);
entries.push(rng_val % self.n as u32);
}
entries
})
.collect();
let queries_flat = self.vectors_flat.clone();
cagra_search_batch_gpu(
&queries_flat,
self.n,
self.dim,
self.vectors_gpu.as_ref().unwrap(),
self.norms_gpu.as_ref().unwrap(),
self.nav_graph_gpu.as_ref().unwrap(),
self.n,
self.k,
k,
use_cosine,
seed,
&query_params,
Some(&entry_flat),
&client,
)
}
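/// Lazily uploads vectors, norms, and the navigation graph to the GPU when
/// they were not retained from the build.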
fn ensure_gpu_tensors(&mut self) {
if self.nav_graph_gpu.is_some() {
return;
}
let client = R::client(&self._device);
let dim_padded = self.dim_padded;
let vectors_padded = if dim_padded != self.dim {
pad_vectors(&self.vectors_flat, self.n, self.dim, dim_padded)
} else {
self.vectors_flat.clone()
};
self.vectors_gpu = Some(GpuTensor::<R, T>::from_slice(
&vectors_padded,
vec![self.n, dim_padded],
&client,
));
self.norms_gpu = Some(if self.metric == Dist::Cosine {
GpuTensor::<R, T>::from_slice(&self.norms, vec![self.n], &client)
} else {
GpuTensor::<R, T>::from_slice(&[T::zero()], vec![1], &client)
});
let nav_flat: Vec<u32> = self.nav_graph.iter().map(|&(pid, _)| pid as u32).collect();
self.nav_graph_gpu = Some(GpuTensor::<R, u32>::from_slice(
&nav_flat,
vec![self.n, self.k],
&client,
));
}
}
#[cfg(test)]
mod tests {
use super::*;
use cubecl::wgpu::WgpuDevice;
use cubecl::wgpu::WgpuRuntime;
use faer::Mat;
fn try_device() -> Option<WgpuDevice> {
let device = WgpuDevice::DefaultDevice;
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
cubecl::wgpu::WgpuRuntime::client(&device);
}));
result.ok().map(|_| device)
}
#[test]
fn test_nndescent_gpu_basic() {
let Some(device) = try_device() else {
eprintln!("Skipping test: no wgpu backend available");
return;
};
let data = Mat::from_fn(20, 4, |i, j| ((i * 3 + j) as f32) / 10.0);
let index = NNDescentGpu::<f32, WgpuRuntime>::build(
data.as_ref(),
Dist::Euclidean,
Some(5),
None,
Some(10),
None,
Some(0.001),
Some(0.5),
None,
42,
false,
false,
device,
);
assert_eq!(index.nav_graph.len(), 20 * 5);
for i in 0..20 {
let nbrs = index.graph_neighbours(i);
assert_eq!(nbrs.len(), 5);
for w in nbrs.windows(2) {
assert!(w[1].1 >= w[0].1);
}
}
}
#[test]
fn test_nndescent_gpu_cosine() {
let Some(device) = try_device() else {
eprintln!("Skipping test: no wgpu backend available");
return;
};
let data = Mat::from_fn(16, 4, |i, _| (i as f32) + 1.0);
let index = NNDescentGpu::<f32, WgpuRuntime>::build(
data.as_ref(),
Dist::Cosine,
Some(3),
None,
Some(10),
None,
Some(0.001),
Some(0.5),
None,
42,
false,
false,
device,
);
assert_eq!(index.nav_graph.len(), 16 * 3);
assert!(!index.norms.is_empty());
}
#[test]
fn test_nndescent_gpu_padded_dim() {
let Some(device) = try_device() else {
eprintln!("Skipping test: no wgpu backend available");
return;
};
let data = Mat::from_fn(12, 3, |i, j| (i + j) as f32);
let index = NNDescentGpu::<f32, WgpuRuntime>::build(
data.as_ref(),
Dist::Euclidean,
Some(3),
None,
Some(10),
None,
Some(0.001),
Some(0.5),
None,
42,
false,
false,
device,
);
assert_eq!(index.dim, 3);
assert_eq!(index.nav_graph.len(), 12 * 3);
}
#[test]
fn test_extract_knn() {
let Some(device) = try_device() else {
eprintln!("Skipping test: no wgpu backend available");
return;
};
let data = Mat::from_fn(20, 4, |i, j| ((i * 3 + j) as f32) / 10.0);
let index = NNDescentGpu::<f32, WgpuRuntime>::build(
data.as_ref(),
Dist::Euclidean,
Some(5),
None,
Some(10),
None,
Some(0.001),
Some(0.5),
None,
42,
false,
false,
device,
);
let (indices, Some(distances)) = index.extract_knn(true) else {
panic!("Expected distances");
};
assert_eq!(indices.len(), 20);
assert_eq!(distances.len(), 20);
for i in 0..20 {
assert_eq!(indices[i].len(), 5);
assert_eq!(distances[i].len(), 5);
assert!(!indices[i].contains(&i));
}
let (indices, dists) = index.extract_knn(false);
assert_eq!(indices.len(), 20);
assert!(dists.is_none());
}
}
#[cfg(test)]
#[cfg(feature = "gpu-tests")]
mod kernel_tests {
use super::*;
use cubecl::wgpu::WgpuDevice;
use cubecl::wgpu::WgpuRuntime;
use faer::Mat;
fn try_device() -> Option<WgpuDevice> {
let device = WgpuDevice::DefaultDevice;
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
cubecl::wgpu::WgpuRuntime::client(&device);
}));
result.ok().map(|_| device)
}
#[cube(launch_unchecked)]
fn probe_stride<F: Float>(vectors: &Tensor<Line<F>>, out: &mut Tensor<u32>) {
if ABSOLUTE_POS_X == 0u32 {
out[0usize] = vectors.stride(0) as u32;
out[1usize] = vectors.shape(1) as u32;
out[2usize] = vectors.stride(1) as u32;
out[3usize] = vectors.shape(0) as u32;
}
}
#[test]
fn test_stride_probe() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let line: usize = LINE_SIZE as usize;
let n = 8usize;
let dim_padded = 32usize;
let dim_vec = dim_padded / line;
let data: Vec<f32> = (0..n * dim_padded).map(|i| i as f32).collect();
let vectors_gpu =
GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim_padded], &client);
let out_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(&[0u32; 4], vec![4], &client);
unsafe {
let _ = probe_stride::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new_2d(1, 1),
vectors_gpu.into_tensor_arg(line),
out_gpu.clone().into_tensor_arg(1),
);
}
let result = out_gpu.read(&client);
let stride_0 = result[0];
let shape_1 = result[1];
let stride_1 = result[2];
let shape_0 = result[3];
println!("Tensor [n={n}, dim_padded={dim_padded}] with line_size={line}:");
println!(" stride(0) = {stride_0} (expected {dim_vec} in Line units, or {dim_padded} in f32 units)");
println!(" shape(1) = {shape_1} (expected {dim_vec} in Line units)");
println!(" stride(1) = {stride_1} (expected 1)");
println!(" shape(0) = {shape_0} (expected {n})");
assert_eq!(shape_0, n as u32, "shape(0) should be n");
assert_eq!(
stride_0, dim_padded as u32,
"stride(0) should be dim_padded (f32 units)"
);
}
#[cube(launch_unchecked)]
fn read_vector_via_stride<F: Float>(
vectors: &Tensor<Line<F>>,
row_idx: u32,
out: &mut Tensor<F>,
#[comptime] dim_lines: usize,
) {
if ABSOLUTE_POS_X == 0u32 {
let off = row_idx as usize * dim_lines;
let mut d = 0usize;
while d < dim_lines {
let line_val = vectors[off + d];
out[d * 4usize] = line_val[0usize];
out[d * 4usize + 1usize] = line_val[1usize];
out[d * 4usize + 2usize] = line_val[2usize];
out[d * 4usize + 3usize] = line_val[3usize];
d += 1usize;
}
}
}
#[test]
fn test_vector_roundtrip_line() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let line: usize = LINE_SIZE as usize;
let n = 4usize;
let dim = 8usize;
let dim_vec = dim / line;
let mut data = vec![0.0f32; n * dim];
for i in 0..n {
for j in 0..dim {
data[i * dim + j] = (i * 100 + j) as f32;
}
}
let vectors_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim], &client);
for row in 0..n {
let out_gpu =
GpuTensor::<WgpuRuntime, f32>::from_slice(&vec![-1.0f32; dim], vec![dim], &client);
unsafe {
let _ = read_vector_via_stride::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new_2d(1, 1),
vectors_gpu.clone().into_tensor_arg(line),
ScalarArg { elem: row as u32 },
out_gpu.clone().into_tensor_arg(1),
dim_vec,
);
}
let result = out_gpu.read(&client);
let expected: Vec<f32> = (0..dim).map(|j| (row * 100 + j) as f32).collect();
println!("Row {row}: got {:?}", &result[..dim]);
println!(" exp {:?}", &expected);
for j in 0..dim {
if (result[j] - expected[j]).abs() > 1e-6 {
eprintln!(
"*** MISMATCH at row={row}, col={j}: got {}, expected {} ***",
result[j], expected[j]
);
}
}
assert_eq!(&result[..dim], &expected[..], "Row {row} data mismatch");
}
}
#[cube(launch_unchecked)]
fn compute_pairwise_dist<F: Float>(
vectors: &Tensor<Line<F>>,
norms: &Tensor<F>,
out_sq_euclid: &mut Tensor<F>,
out_cosine: &mut Tensor<F>,
n: u32,
#[comptime] use_cosine: bool,
#[comptime] dim_lines: usize,
) {
let idx = ABSOLUTE_POS_X;
let n_pairs = n * (n - 1u32) / 2u32;
if idx >= n_pairs {
terminate!();
}
let mut rem = idx;
let mut i = 0u32;
let mut step = n - 1u32;
while rem >= step {
rem -= step;
i += 1u32;
step = n - 1u32 - i;
}
let j = i + 1u32 + rem;
out_sq_euclid[idx as usize] = dist_sq_euclidean(vectors, i, j, dim_lines);
if use_cosine {
out_cosine[idx as usize] = dist_cosine(vectors, norms, i, j, dim_lines);
}
}
#[test]
fn test_gpu_distances_euclidean() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let line: usize = LINE_SIZE as usize;
let n = 4usize;
let dim = 8usize;
let dim_vec = dim / line;
let mut data = vec![0.0f32; n * dim];
data[0..dim].copy_from_slice(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
data[dim..2 * dim].copy_from_slice(&[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
data[2 * dim..3 * dim].copy_from_slice(&[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
data[3 * dim..4 * dim].copy_from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]);
let n_pairs = n * (n - 1) / 2;
let vectors_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim], &client);
let norms_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&[0.0f32], vec![1], &client);
let out_euclid = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; n_pairs],
vec![n_pairs],
&client,
);
let out_cosine = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; n_pairs],
vec![n_pairs],
&client,
);
unsafe {
let _ = compute_pairwise_dist::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.into_tensor_arg(line),
norms_gpu.into_tensor_arg(1),
out_euclid.clone().into_tensor_arg(1),
out_cosine.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
false,
dim_vec,
);
}
let euclid = out_euclid.read(&client);
let expected = [2.0f32, 1.0, 2.0, 1.0, 2.0, 3.0];
println!("Squared Euclidean distances:");
let pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)];
for (k, &(i, j)) in pairs.iter().enumerate() {
println!(
" ({i},{j}): gpu={:.4} expected={:.4} match={}",
euclid[k],
expected[k],
(euclid[k] - expected[k]).abs() < 1e-4
);
}
for (k, &exp) in expected.iter().enumerate() {
assert!(
(euclid[k] - exp).abs() < 1e-4,
"Pair {:?}: gpu={}, expected={}",
pairs[k],
euclid[k],
exp
);
}
}
#[test]
fn test_gpu_distances_cosine() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let line: usize = LINE_SIZE as usize;
let n = 4usize;
let dim = 8usize;
let dim_vec = dim / line;
let mut data = vec![0.0f32; n * dim];
data[0..dim].copy_from_slice(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
data[dim..2 * dim].copy_from_slice(&[0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
data[2 * dim..3 * dim].copy_from_slice(&[1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]);
data[3 * dim..4 * dim].copy_from_slice(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]);
let norms: Vec<f32> = (0..n)
.map(|i| {
let row = &data[i * dim..(i + 1) * dim];
row.iter().map(|x| x * x).sum::<f32>().sqrt()
})
.collect();
let n_pairs = n * (n - 1) / 2;
let vectors_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim], &client);
let norms_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&norms, vec![n], &client);
let out_euclid = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; n_pairs],
vec![n_pairs],
&client,
);
let out_cosine = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; n_pairs],
vec![n_pairs],
&client,
);
unsafe {
let _ = compute_pairwise_dist::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.into_tensor_arg(line),
norms_gpu.into_tensor_arg(1),
out_euclid.clone().into_tensor_arg(1),
out_cosine.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
true,
dim_vec,
);
}
let cosine = out_cosine.read(&client);
let sqrt2 = 2.0f32.sqrt();
let expected = [1.0, 1.0 - 1.0 / sqrt2, 1.0, 1.0 - 1.0 / sqrt2, 1.0, 1.0];
println!("Cosine distances:");
let pairs = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)];
for (k, &(i, j)) in pairs.iter().enumerate() {
println!(
" ({i},{j}): gpu={:.6} expected={:.6} match={}",
cosine[k],
expected[k],
(cosine[k] - expected[k]).abs() < 1e-4
);
if cosine[k] < -1e-6 {
eprintln!(" *** NEGATIVE cosine distance: {} ***", cosine[k]);
}
}
for (k, &exp) in expected.iter().enumerate() {
assert!(
(cosine[k] - exp).abs() < 1e-3,
"Pair {:?}: gpu={}, expected={}",
pairs[k],
cosine[k],
exp
);
}
}
#[test]
fn test_local_join_distances() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let line: usize = LINE_SIZE as usize;
let n = 8usize;
let dim = 8usize;
let dim_vec = dim / line;
let build_k = 4usize;
let mut data = vec![0.0f32; n * dim];
for i in 0..n {
data[i * dim + (i % dim)] = 1.0;
}
let norms = vec![1.0f32; n];
let is_new_bit = 1u32 << 31;
let mut graph_idx = vec![0u32; n * build_k];
let mut graph_dist = vec![0.0f32; n * build_k];
for i in 0..n {
for j in 0..build_k {
let nbr = (i + j + 1) % n;
graph_idx[i * build_k + j] = (nbr as u32) | is_new_bit;
let a = &data[i * dim..(i + 1) * dim];
let b = &data[nbr * dim..(nbr + 1) * dim];
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
graph_dist[i * build_k + j] = 1.0 - dot;
}
let base = i * build_k;
let mut pairs: Vec<(u32, f32)> = (0..build_k)
.map(|j| (graph_idx[base + j], graph_dist[base + j]))
.collect();
pairs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
for (j, (idx, dist)) in pairs.into_iter().enumerate() {
graph_idx[base + j] = idx;
graph_dist[base + j] = dist;
}
}
let vectors_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim], &client);
let norms_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&norms, vec![n], &client);
let graph_idx_gpu =
GpuTensor::<WgpuRuntime, u32>::from_slice(&graph_idx, vec![n, build_k], &client);
let graph_dist_gpu =
GpuTensor::<WgpuRuntime, f32>::from_slice(&graph_dist, vec![n, build_k], &client);
let reverse_idx_gpu = GpuTensor::<WgpuRuntime, u32>::from_slice(
&vec![0u32; n * build_k],
vec![n, build_k],
&client,
);
let reverse_count_gpu =
GpuTensor::<WgpuRuntime, u32>::from_slice(&vec![0u32; n], vec![n], &client);
let max_prop = MAX_PROPOSALS;
let prop_idx_gpu = GpuTensor::<WgpuRuntime, u32>::empty(vec![n, max_prop], &client);
let prop_dist_gpu = GpuTensor::<WgpuRuntime, f32>::empty(vec![n, max_prop], &client);
let prop_count_gpu =
GpuTensor::<WgpuRuntime, u32>::from_slice(&vec![0u32; n], vec![n], &client);
let rho_thresh = 65535u32;
unsafe {
let _ = local_join_shared::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(n as u32, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.into_tensor_arg(line),
norms_gpu.into_tensor_arg(1),
graph_idx_gpu.into_tensor_arg(1),
graph_dist_gpu.into_tensor_arg(1),
reverse_idx_gpu.into_tensor_arg(1),
reverse_count_gpu.into_tensor_arg(1),
prop_idx_gpu.clone().into_tensor_arg(1),
prop_dist_gpu.clone().into_tensor_arg(1),
prop_count_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
ScalarArg { elem: rho_thresh },
ScalarArg { elem: 42u32 },
MAX_PROPOSALS as u32,
true, // use_cosine
dim_vec,
build_k,
);
}
let p_idx = prop_idx_gpu.read(&client);
let p_dist = prop_dist_gpu.read(&client);
let p_count = prop_count_gpu.read(&client);
println!("Local join proposals (n={n}, build_k={build_k}, cosine):");
let mut any_negative = false;
let mut any_mismatch = false;
for node in 0..n {
let count = (p_count[node] as usize).min(max_prop);
if count == 0 {
continue;
}
println!(" node {node}: {count} proposals");
for p in 0..count.min(5) {
let cand = p_idx[node * max_prop + p] as usize;
let gpu_dist = p_dist[node * max_prop + p];
let a = &data[node * dim..(node + 1) * dim];
let b = &data[cand * dim..(cand + 1) * dim];
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let cpu_dist = 1.0 - dot / (norms[node] * norms[cand]);
let ok = (gpu_dist - cpu_dist).abs() < 1e-4;
println!(
" -> cand {cand}: gpu={:.6e} cpu={:.6e} match={ok}",
gpu_dist, cpu_dist
);
if gpu_dist < -1e-6 {
any_negative = true;
eprintln!(" *** NEGATIVE distance: {gpu_dist} ***");
}
if !ok {
any_mismatch = true;
}
}
}
assert!(
!any_negative,
"Negative cosine distances found in local_join proposals"
);
assert!(
!any_mismatch,
"Distance mismatches found in local_join proposals"
);
}
#[test]
fn test_merge_proposals() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let n = 4usize;
let k = 3usize;
let pid_mask = 0x7FFFFFFFu32;
let graph_idx_data: Vec<u32> = vec![
1, 2, 3, // node 0
0, 2, 3, // node 1
0, 1, 3, // node 2
0, 1, 2, // node 3
];
let graph_dist_data: Vec<f32> = vec![
0.1, 0.5, 0.9, // node 0
0.1, 0.3, 0.8, // node 1
0.2, 0.3, 0.7, // node 2
0.2, 0.4, 0.6, // node 3
];
let mut prop_idx = vec![0u32; n * MAX_PROPOSALS];
let mut prop_dist = vec![0.0f32; n * MAX_PROPOSALS];
let mut prop_count = vec![0u32; n];
// Two proposals for node 0; both pids already sit in its neighbour list,
// so the merge must reject them as duplicates.
prop_idx[0] = 2;
prop_dist[0] = 0.05;
prop_idx[1] = 1;
prop_dist[1] = 0.08;
prop_count[0] = 2;
let graph_idx_gpu =
GpuTensor::<WgpuRuntime, u32>::from_slice(&graph_idx_data, vec![n, k], &client);
let graph_dist_gpu =
GpuTensor::<WgpuRuntime, f32>::from_slice(&graph_dist_data, vec![n, k], &client);
let prop_idx_gpu =
GpuTensor::<WgpuRuntime, u32>::from_slice(&prop_idx, vec![n, MAX_PROPOSALS], &client);
let prop_dist_gpu =
GpuTensor::<WgpuRuntime, f32>::from_slice(&prop_dist, vec![n, MAX_PROPOSALS], &client);
let prop_count_gpu =
GpuTensor::<WgpuRuntime, u32>::from_slice(&prop_count, vec![n], &client);
let update_counter = GpuTensor::<WgpuRuntime, u32>::from_slice(&[0u32], vec![1], &client);
let grid_n = (n as u32).div_ceil(WORKGROUP_SIZE_X);
unsafe {
let _ = merge_proposals::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(grid_n, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu.clone().into_tensor_arg(1),
graph_dist_gpu.clone().into_tensor_arg(1),
prop_idx_gpu.into_tensor_arg(1),
prop_dist_gpu.into_tensor_arg(1),
prop_count_gpu.into_tensor_arg(1),
update_counter.clone().into_tensor_arg(1),
ScalarArg { elem: n as u32 },
MAX_PROPOSALS as u32,
);
}
let result_idx = graph_idx_gpu.read(&client);
let result_dist = graph_dist_gpu.read(&client);
let updates = update_counter.read(&client);
println!("Merge proposals result:");
println!(" Total updates: {}", updates[0]);
for node in 0..n {
let base = node * k;
print!(" Node {node}:");
for j in 0..k {
let pid = result_idx[base + j] & pid_mask;
let is_new = result_idx[base + j] & (1u32 << 31) != 0;
let dist = result_dist[base + j];
print!(" ({pid}, {dist:.4}{}) ", if is_new { "*" } else { "" });
}
println!();
}
let base = 0;
assert_eq!(
result_idx[base] & pid_mask,
1,
"Node 0, slot 0 should be pid=1"
);
assert_eq!(
result_idx[base + 1] & pid_mask,
2,
"Node 0, slot 1 should be pid=2"
);
assert_eq!(
result_idx[base + 2] & pid_mask,
3,
"Node 0, slot 2 should be pid=3"
);
let n2 = 5usize;
let k2 = 3usize;
let graph_idx_data2: Vec<u32> = vec![
1, 2, 3, // node 0
0, 2, 3, // node 1
0, 1, 3, // node 2
0, 1, 2, // node 3
0, 1, 2, // node 4
];
let graph_dist_data2: Vec<f32> = vec![
0.1, 0.5, 0.9, // node 0
0.1, 0.3, 0.8, // node 1
0.2, 0.3, 0.7, // node 2
0.2, 0.4, 0.6, // node 3
0.1, 0.2, 0.3, // node 4
];
let mut prop_idx2 = vec![0u32; n2 * MAX_PROPOSALS];
let mut prop_dist2 = vec![0.0f32; n2 * MAX_PROPOSALS];
let mut prop_count2 = vec![0u32; n2];
// Node 0 proposes pid=4 at dist 0.3: new, and better than slots 1 and 2.
prop_idx2[0] = 4;
prop_dist2[0] = 0.3;
prop_count2[0] = 1;
let graph_idx_gpu2 =
GpuTensor::<WgpuRuntime, u32>::from_slice(&graph_idx_data2, vec![n2, k2], &client);
let graph_dist_gpu2 =
GpuTensor::<WgpuRuntime, f32>::from_slice(&graph_dist_data2, vec![n2, k2], &client);
let prop_idx_gpu2 =
GpuTensor::<WgpuRuntime, u32>::from_slice(&prop_idx2, vec![n2, MAX_PROPOSALS], &client);
let prop_dist_gpu2 = GpuTensor::<WgpuRuntime, f32>::from_slice(
&prop_dist2,
vec![n2, MAX_PROPOSALS],
&client,
);
let prop_count_gpu2 =
GpuTensor::<WgpuRuntime, u32>::from_slice(&prop_count2, vec![n2], &client);
let update_counter2 = GpuTensor::<WgpuRuntime, u32>::from_slice(&[0u32], vec![1], &client);
let grid_n2 = (n2 as u32).div_ceil(WORKGROUP_SIZE_X);
unsafe {
let _ = merge_proposals::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(grid_n2, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
graph_idx_gpu2.clone().into_tensor_arg(1),
graph_dist_gpu2.clone().into_tensor_arg(1),
prop_idx_gpu2.into_tensor_arg(1),
prop_dist_gpu2.into_tensor_arg(1),
prop_count_gpu2.into_tensor_arg(1),
update_counter2.clone().into_tensor_arg(1),
ScalarArg { elem: n2 as u32 },
MAX_PROPOSALS as u32,
);
}
let r_idx = graph_idx_gpu2.read(&client);
let r_dist = graph_dist_gpu2.read(&client);
let r_updates = update_counter2.read(&client);
println!("\nMerge with new candidate:");
println!(" Updates: {}", r_updates[0]);
let base = 0;
for j in 0..k2 {
let pid = r_idx[base + j] & pid_mask;
let is_new = r_idx[base + j] & (1u32 << 31) != 0;
let dist = r_dist[base + j];
println!(" Node 0 slot {j}: pid={pid} dist={dist:.4} new={is_new}");
}
assert_eq!(r_updates[0], 1, "Should have exactly 1 update");
assert_eq!(r_idx[base] & pid_mask, 1, "Slot 0: pid=1 (unchanged)");
assert_eq!(r_idx[base + 1] & pid_mask, 4, "Slot 1: pid=4 (new)");
assert!(
r_idx[base + 1] & (1u32 << 31) != 0,
"Slot 1 should be flagged new"
);
assert_eq!(r_idx[base + 2] & pid_mask, 2, "Slot 2: pid=2 (shifted)");
assert!((r_dist[base] - 0.1).abs() < 1e-6);
assert!((r_dist[base + 1] - 0.3).abs() < 1e-6);
assert!((r_dist[base + 2] - 0.5).abs() < 1e-6);
for j in 0..k2 {
assert_ne!(
r_idx[base + j] & pid_mask,
3,
"pid=3 should have been evicted"
);
}
}
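// Builds a full index on 100 clustered points and checks recall of the
// extracted k-NN lists against a brute-force squared-Euclidean ground truth.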
#[test]
fn test_end_to_end_quality() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let n = 100;
let dim = 8;
let k = 5;
let data_flat: Vec<f32> = (0..n * dim)
.map(|idx| {
let i = idx / dim;
let j = idx % dim;
let cluster = (i / 10) as f32 * 5.0;
cluster + (i % 10) as f32 * 0.1 + j as f32 * 0.01
})
.collect();
let data = Mat::from_fn(n, dim, |i, j| data_flat[i * dim + j]);
let mut ground_truth: Vec<Vec<usize>> = Vec::with_capacity(n);
for i in 0..n {
let mut dists: Vec<(usize, f32)> = (0..n)
.filter(|&j| j != i)
.map(|j| {
let a = &data_flat[i * dim..(i + 1) * dim];
let b = &data_flat[j * dim..(j + 1) * dim];
let d: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
(j, d)
})
.collect();
dists.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
ground_truth.push(dists.iter().take(k).map(|&(j, _)| j).collect());
}
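// The Some(15), Some(0.001) and Some(0.5) arguments mirror DEFAULT_MAX_ITERS,
// DEFAULT_DELTA and DEFAULT_RHO defined at the top of this module.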
let index = NNDescentGpu::<f32, WgpuRuntime>::build(
data.as_ref(),
Dist::Euclidean,
Some(k),
None,
Some(15),
None,
Some(0.001),
Some(0.5),
None,
42,
true,
false,
device,
);
let (knn_indices, _) = index.extract_knn(false);
let mut total_hits = 0;
let total_possible = n * k;
for i in 0..n {
let gt_set: std::collections::HashSet<usize> =
ground_truth[i].iter().copied().collect();
for &idx in &knn_indices[i] {
if gt_set.contains(&idx) {
total_hits += 1;
}
}
}
let recall = total_hits as f64 / total_possible as f64;
println!("End-to-end extract recall@{k}: {recall:.4} ({total_hits}/{total_possible})");
assert!(recall > 0.7, "End-to-end recall too low: {recall:.4}");
}
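// Checks the vectorized distance kernel at dim = 32 (several Lines per
// vector) by comparing pair (0, 1) against a scalar CPU sum.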
#[test]
fn test_distances_dim32() {
let Some(device) = try_device() else {
return;
};
let client = WgpuRuntime::client(&device);
let line = LINE_SIZE as usize;
let n = 16usize;
let dim = 32usize;
let dim_vec = dim / line;
let data: Vec<f32> = (0..n * dim)
.map(|i| ((i % 7) as f32) * 0.1 + (i / dim) as f32)
.collect();
let vectors_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim], &client);
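// Euclidean path only (use_cosine = false below), so a one-element
// placeholder stands in for the norms tensor.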
let norms_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&[0.0f32], vec![1], &client);
let n_pairs = n * (n - 1) / 2;
let out_euclid = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; n_pairs],
vec![n_pairs],
&client,
);
let out_cos = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; n_pairs],
vec![n_pairs],
&client,
);
unsafe {
let _ = compute_pairwise_dist::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.into_tensor_arg(line),
norms_gpu.into_tensor_arg(1),
out_euclid.clone().into_tensor_arg(1),
out_cos.into_tensor_arg(1),
ScalarArg { elem: n as u32 },
false,
dim_vec,
);
}
let euclid = out_euclid.read(&client);
let a = &data[0..dim];
let b = &data[dim..2 * dim];
let expected: f32 = a.iter().zip(b).map(|(x, y)| (x - y) * (x - y)).sum();
println!("dim=32 dist(0,1): gpu={:.6} cpu={:.6}", euclid[0], expected);
assert!(
(euclid[0] - expected).abs() < 1e-3,
"dim=32 distance mismatch: gpu={}, cpu={}",
euclid[0],
expected
);
}
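// Debug kernel that mirrors the shared-memory staging pattern of the local
// join: cooperatively load two vectors into shared memory, then have a single
// thread compute their distance and dump both the result and the raw staged
// scalars for host-side inspection.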
#[cube(launch_unchecked)]
fn debug_shared_mem_dist<F: Float>(
vectors: &Tensor<Line<F>>,
norms: &Tensor<F>,
pid_a: u32,
pid_b: u32,
out_dist: &mut Tensor<F>,
out_raw: &mut Tensor<F>,
#[comptime] _max_proposals: u32,
#[comptime] use_cosine: bool,
#[comptime] dim_lines: usize,
#[comptime] build_k: usize,
) {
let tx = UNIT_POS_X;
let max_cands_comp = build_k * 2usize;
let dim_scalars = dim_lines * 4usize;
let mut shared_vecs = SharedMemory::<F>::new(max_cands_comp * dim_scalars);
let mut shared_pids = SharedMemory::<u32>::new(max_cands_comp);
let mut shared_norms = SharedMemory::<F>::new(max_cands_comp);
if tx == 0u32 {
shared_pids[0usize] = pid_a;
shared_pids[1usize] = pid_b;
if use_cosine {
shared_norms[0usize] = norms[pid_a as usize];
shared_norms[1usize] = norms[pid_b as usize];
}
}
sync_cube();
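// Cooperative load: threads copy scalars strided by WORKGROUP_SIZE_X,
// de-interleaving each Line into its four lanes.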
let total_scalars = 2usize * dim_scalars;
let mut idx_load = tx as usize;
while idx_load < total_scalars {
let n_idx = idx_load / dim_scalars;
let s_idx = idx_load % dim_scalars;
let line_idx = s_idx / 4usize;
let lane = s_idx % 4usize;
let pid = shared_pids[n_idx];
let vec_offset = pid as usize * dim_lines + line_idx;
let line_val = vectors[vec_offset];
shared_vecs[idx_load] = line_val[lane];
idx_load += WORKGROUP_SIZE_X as usize;
}
sync_cube();
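// Thread 0 alone reduces the staged scalars into a distance.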
if tx == 0u32 {
let mut sum = F::new(0.0);
let mut s = 0usize;
while s < dim_scalars {
let va = shared_vecs[s];
let vb = shared_vecs[dim_scalars + s];
if use_cosine {
sum += va * vb;
} else {
let diff = va - vb;
sum += diff * diff;
}
s += 1usize;
}
let dist = if use_cosine {
F::new(1.0) - (sum / (shared_norms[0usize] * shared_norms[1usize]))
} else {
sum
};
out_dist[0usize] = dist;
out_dist[1usize] = sum;
if use_cosine {
out_dist[2usize] = shared_norms[0usize];
out_dist[3usize] = shared_norms[1usize];
}
let mut i = 0usize;
while i < total_scalars {
out_raw[i] = shared_vecs[i];
i += 1usize;
}
}
}
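// Exercises debug_shared_mem_dist end to end: two vectors are staged through
// shared memory, the raw staged scalars are checked element by element, and
// the resulting distance must match a CPU computation.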
#[test]
fn test_shared_mem_local_join_pattern() {
let Some(device) = try_device() else {
eprintln!("Skipping: no wgpu backend");
return;
};
let client = WgpuRuntime::client(&device);
let line = LINE_SIZE as usize;
let n = 100usize;
let dim = 32usize;
let dim_vec = dim / line;
let build_k = 30usize;
let mut data = vec![0.0f32; n * dim];
for i in 0..n {
for j in 0..dim {
data[i * dim + j] = (i * 1000 + j) as f32;
}
}
let norms: Vec<f32> = (0..n)
.map(|i| {
let row = &data[i * dim..(i + 1) * dim];
row.iter().map(|x| x * x).sum::<f32>().sqrt()
})
.collect();
let vectors_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&data, vec![n, dim], &client);
let norms_gpu = GpuTensor::<WgpuRuntime, f32>::from_slice(&norms, vec![n], &client);
let out_dist = GpuTensor::<WgpuRuntime, f32>::from_slice(&[0.0f32; 4], vec![4], &client);
let out_raw = GpuTensor::<WgpuRuntime, f32>::from_slice(
&vec![0.0f32; 2 * dim],
vec![2 * dim],
&client,
);
let pid_a = 0u32;
let pid_b = 1u32;
unsafe {
let _ = debug_shared_mem_dist::launch_unchecked::<f32, WgpuRuntime>(
&client,
CubeCount::Static(1, 1, 1),
CubeDim::new_2d(WORKGROUP_SIZE_X, 1),
vectors_gpu.clone().into_tensor_arg(line),
norms_gpu.clone().into_tensor_arg(1),
ScalarArg { elem: pid_a },
ScalarArg { elem: pid_b },
out_dist.clone().into_tensor_arg(1),
out_raw.clone().into_tensor_arg(1),
MAX_PROPOSALS as u32,
false,
dim_vec,
build_k,
);
}
let dist_result = out_dist.read(&client);
let raw = out_raw.read(&client);
let expected_a: Vec<f32> = (0..dim).map(|j| j as f32).collect();
let expected_b: Vec<f32> = (0..dim).map(|j| (1000 + j) as f32).collect();
println!("Shared mem vec A (first 8): {:?}", &raw[..8]);
println!("Expected vec A (first 8): {:?}", &expected_a[..8]);
println!("Shared mem vec B (first 8): {:?}", &raw[dim..dim + 8]);
println!("Expected vec B (first 8): {:?}", &expected_b[..8]);
let vec_a_ok = (0..dim).all(|j| (raw[j] - expected_a[j]).abs() < 1e-4);
let vec_b_ok = (0..dim).all(|j| (raw[dim + j] - expected_b[j]).abs() < 1e-4);
println!("Vec A correct: {vec_a_ok}");
println!("Vec B correct: {vec_b_ok}");
let cpu_dist: f32 = expected_a
.iter()
.zip(&expected_b)
.map(|(a, b)| (a - b) * (a - b))
.sum();
let gpu_dist = dist_result[0];
println!(
"GPU dist: {gpu_dist:.4} CPU dist: {cpu_dist:.4} match: {}",
(gpu_dist - cpu_dist).abs() < 1e-2
);
assert!(vec_a_ok, "Vector A in shared memory is wrong");
assert!(vec_b_ok, "Vector B in shared memory is wrong");
assert!(
(gpu_dist - cpu_dist).abs() < 1e-2,
"Distance mismatch: gpu={gpu_dist}, cpu={cpu_dist}"
);
}
}