use alloc::vec;
use alloc::vec::Vec;
use crate::vector_ops::DistanceMetric;
/// Minimal deterministic pseudo-random generator built from xor/shift steps
/// over a single 64-bit state word.
pub(crate) struct Xorshift64(u64);

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A zero seed is swapped for a fixed non-zero constant: the xor/shift
    /// update maps a zero state to zero, so it would never produce anything else.
    pub fn new(seed: u64) -> Self {
        match seed {
            0 => Self(0x5EED_DEAD_BEEF_CAFE),
            nonzero => Self(nonzero),
        }
    }

    /// Derives a seed by folding the bit patterns of `data` into a 64-bit
    /// value (xor, wrapping multiply, xor-shift per element) and forwards it
    /// to [`Xorshift64::new`].
    pub fn from_data(data: &[f32]) -> Self {
        let digest = data
            .iter()
            .fold(0x517c_c1b7_2722_0a95_u64, |acc, value| {
                let mixed =
                    (acc ^ u64::from(value.to_bits())).wrapping_mul(0x9e37_79b9_7f4a_7c15);
                mixed ^ (mixed >> 30)
            });
        Self::new(digest)
    }

    /// Advances the state with the 13 / 7 / 17 shift triple and returns the
    /// new state.
    #[inline]
    pub fn next_u64(&mut self) -> u64 {
        let step1 = self.0 ^ (self.0 << 13);
        let step2 = step1 ^ (step1 >> 7);
        let step3 = step2 ^ (step2 << 17);
        self.0 = step3;
        step3
    }

    /// Returns a value in `[0, 1)` built from the top 53 bits of the next
    /// 64-bit output.
    #[inline]
    #[allow(clippy::cast_precision_loss)]
    pub fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        let scale = (1u64 << 53) as f64;
        top_bits as f64 / scale
    }

    /// Returns an index in `[0, n)`, or `0` when `n == 0`.
    ///
    /// The index is the high 64 bits of `random * n`. Draws whose low 64 bits
    /// fall below `2^64 mod n` are rejected and redrawn so every index is
    /// equally likely.
    #[inline]
    #[allow(clippy::cast_possible_truncation)]
    pub fn next_usize(&mut self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        let bound = n as u64;
        let mut product = u128::from(self.next_u64()) * u128::from(bound);
        // Only compute the (division-based) rejection limit when the cheap
        // pre-check says a rejection is possible at all.
        if (product as u64) < bound {
            let reject_below = bound.wrapping_neg() % bound;
            while (product as u64) < reject_below {
                product = u128::from(self.next_u64()) * u128::from(bound);
            }
        }
        (product >> 64) as usize
    }
}
fn kmeans_pp_init(
flat_vectors: &[f32],
dim: usize,
k: usize,
rng: &mut Xorshift64,
metric: DistanceMetric,
) -> Vec<f32> {
assert!(dim > 0, "kmeans: dimension must be positive");
let n = flat_vectors.len() / dim;
debug_assert!(k > 0 && k <= n);
let mut centroids = Vec::with_capacity(k * dim);
let first = rng.next_usize(n);
centroids.extend_from_slice(&flat_vectors[first * dim..(first + 1) * dim]);
let mut min_dists = vec![f32::INFINITY; n];
for c in 1..k {
let prev_start = (c - 1) * dim;
let prev_centroid = ¢roids[prev_start..prev_start + dim];
let mut total: f64 = 0.0;
for i in 0..n {
let v = &flat_vectors[i * dim..(i + 1) * dim];
let d = metric.compute(v, prev_centroid);
if d < min_dists[i] {
min_dists[i] = d;
}
total += f64::from(min_dists[i]);
}
if total <= 0.0 {
centroids.extend_from_slice(&flat_vectors[rng.next_usize(n) * dim..][..dim]);
continue;
}
let threshold = rng.next_f64() * total;
let mut cumulative: f64 = 0.0;
let mut chosen = n - 1;
for (i, dist) in min_dists.iter().enumerate().take(n) {
cumulative += f64::from(*dist);
if cumulative >= threshold {
chosen = i;
break;
}
}
centroids.extend_from_slice(&flat_vectors[chosen * dim..(chosen + 1) * dim]);
}
centroids
}
fn assign_all(
flat_vectors: &[f32],
dim: usize,
centroids: &[f32],
k: usize,
assignments: &mut Vec<u32>,
metric: DistanceMetric,
) -> f64 {
let n = flat_vectors.len() / dim;
assignments.clear();
assignments.reserve(n);
let mut total_dist: f64 = 0.0;
for i in 0..n {
let v = &flat_vectors[i * dim..(i + 1) * dim];
let mut best_c = 0u32;
let mut best_d = f32::INFINITY;
for c in 0..k {
let centroid = ¢roids[c * dim..(c + 1) * dim];
let d = metric.compute(v, centroid);
if d < best_d {
best_d = d;
#[allow(clippy::cast_possible_truncation)]
{
best_c = c as u32;
}
}
}
assignments.push(best_c);
total_dist += f64::from(best_d);
}
total_dist
}
/// Overwrites each of the `k` centroids with the mean of the vectors assigned
/// to it.
///
/// `flat_vectors` is read as `len / dim` vectors of `dim` floats and
/// `assignments[i]` names the cluster of vector `i`. Sums are accumulated in
/// `f64` and narrowed to `f32` on write-back. A cluster with no members has
/// its centroid set to all zeros.
///
/// # Panics
///
/// Panics if `dim == 0`, if `assignments` is shorter than the number of
/// vectors, if an assignment is `>= k`, or if `centroids` holds fewer than
/// `k * dim` values.
#[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
fn recompute_centroids(
    flat_vectors: &[f32],
    dim: usize,
    k: usize,
    assignments: &[u32],
    centroids: &mut [f32],
) {
    let num_vectors = flat_vectors.len() / dim;
    let mut sums = vec![0.0f64; k * dim];
    let mut members = vec![0u64; k];

    // Pass 1: per-cluster component sums and member counts.
    for (row, &label) in assignments[..num_vectors].iter().enumerate() {
        let cluster = label as usize;
        members[cluster] += 1;
        let point = &flat_vectors[row * dim..(row + 1) * dim];
        let bucket = &mut sums[cluster * dim..(cluster + 1) * dim];
        bucket
            .iter_mut()
            .zip(point)
            .for_each(|(acc, &x)| *acc += f64::from(x));
    }

    // Pass 2: turn sums into means, zeroing clusters that received nothing.
    for (cluster, &count) in members.iter().enumerate() {
        let span = cluster * dim..(cluster + 1) * dim;
        let target = &mut centroids[span.clone()];
        if count == 0 {
            target.fill(0.0);
            continue;
        }
        let inverse = 1.0 / count as f64;
        for (out, &sum) in target.iter_mut().zip(&sums[span]) {
            *out = (sum * inverse) as f32;
        }
    }
}
/// Clusters `flat_vectors` into at most `k` centroids with iterative k-means.
///
/// `flat_vectors` is read as `len / dim` vectors of `dim` floats. Centroids
/// are seeded by `kmeans_pp_init`, using a generator seeded from the bits of
/// the first vector, then refined for up to `max_iter` rounds of assignment
/// and mean recomputation. Iteration stops early once the relative drop in
/// total assignment distance between consecutive rounds is below `0.001`.
/// A cluster that ends a round with no members keeps its previous centroid.
///
/// Returns `min(k, n) * dim` floats (centroids concatenated), or an empty
/// vector when `dim == 0`, there is no complete input vector, or `k == 0`.
pub fn kmeans(
    flat_vectors: &[f32],
    dim: usize,
    k: usize,
    max_iter: usize,
    metric: DistanceMetric,
) -> Vec<f32> {
    // `dim.max(1)` avoids dividing by zero before the `dim == 0` check below.
    let n = flat_vectors.len() / dim.max(1);
    if n == 0 || dim == 0 {
        return Vec::new();
    }
    let k = k.min(n);
    if k == 0 {
        return Vec::new();
    }

    let mut rng = Xorshift64::from_data(&flat_vectors[..dim.min(flat_vectors.len())]);
    let mut centroids = kmeans_pp_init(flat_vectors, dim, k, &mut rng, metric);
    let mut assignments = Vec::with_capacity(n);
    let mut prev_dist = f64::INFINITY;
    let mut old_centroids = vec![0.0f32; k * dim];
    let mut counts = vec![0u32; k];

    for iter in 0..max_iter {
        let total_dist = assign_all(flat_vectors, dim, &centroids, k, &mut assignments, metric);
        // The first round has no meaningful baseline, hence the `iter > 0` guard.
        let improvement = (prev_dist - total_dist) / prev_dist.max(1e-12);
        if improvement < 0.001 && iter > 0 {
            break;
        }
        prev_dist = total_dist;

        old_centroids.copy_from_slice(&centroids);
        recompute_centroids(flat_vectors, dim, k, &assignments, &mut centroids);

        // `recompute_centroids` zeroes empty clusters; restore their previous
        // position instead so they are not dragged to the origin.
        counts.fill(0);
        for &a in &assignments {
            counts[a as usize] += 1;
        }
        for c in 0..k {
            if counts[c] == 0 {
                centroids[c * dim..(c + 1) * dim]
                    .copy_from_slice(&old_centroids[c * dim..(c + 1) * dim]);
            }
        }
    }
    centroids
}
pub fn assign_nearest(
vector: &[f32],
centroids: &[f32],
dim: usize,
num_clusters: usize,
metric: DistanceMetric,
) -> (u32, f32) {
let mut best_c = 0u32;
let mut best_d = f32::INFINITY;
for c in 0..num_clusters {
let start = c * dim;
let end = start + dim;
if end > centroids.len() {
break;
}
let centroid = ¢roids[start..end];
let d = metric.compute(vector, centroid);
if d < best_d {
best_d = d;
#[allow(clippy::cast_possible_truncation)]
{
best_c = c as u32;
}
}
}
(best_c, best_d)
}
/// Returns the clusters to probe for `query`, as `(cluster index, distance)`
/// pairs.
///
/// `centroids` is read as up to `num_clusters` centroids of `dim` floats;
/// centroids that would run past the end of the slice are skipped. `nprobe`
/// is clamped to `1..=num_clusters`.
///
/// * When `diversity.enabled()` is false, the result is the `nprobe` closest
///   clusters sorted by ascending distance (`f32::total_cmp` order).
/// * Otherwise the `2 * nprobe` closest clusters are shortlisted (sorted
///   ascending), their centroid vectors are gathered in the same order, and
///   the final choice is delegated to
///   `crate::probe_select::select_diverse_probes`.
///
/// Returns an empty vector when `num_clusters == 0` or no centroid fits in
/// `centroids`.
pub fn nearest_clusters(
    query: &[f32],
    centroids: &[f32],
    dim: usize,
    num_clusters: usize,
    nprobe: usize,
    metric: DistanceMetric,
    diversity: crate::probe_select::DiversityConfig,
) -> Vec<(u32, f32)> {
    if num_clusters == 0 {
        return Vec::new();
    }
    let nprobe = nprobe.min(num_clusters).max(1);

    #[allow(clippy::cast_possible_truncation)]
    let mut dists: Vec<(u32, f32)> = (0..num_clusters)
        .filter_map(|c| {
            let start = c * dim;
            let end = start + dim;
            if end > centroids.len() {
                return None;
            }
            let centroid = &centroids[start..end];
            Some((c as u32, metric.compute(query, centroid)))
        })
        .collect();
    if dists.is_empty() {
        return Vec::new();
    }

    if !diversity.enabled() {
        // Partition so the `nprobe` smallest come first, then sort only those.
        let select_idx = (nprobe - 1).min(dists.len() - 1);
        dists.select_nth_unstable_by(select_idx, |a, b| a.1.total_cmp(&b.1));
        dists.truncate(nprobe.min(dists.len()));
        dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        return dists;
    }

    // `nprobe >= 1` and `dists` is non-empty, so `shortlist_size >= 1`.
    let shortlist_size = (nprobe.saturating_mul(2)).min(dists.len());
    dists.select_nth_unstable_by(shortlist_size - 1, |a, b| a.1.total_cmp(&b.1));
    dists.truncate(shortlist_size);
    dists.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));

    let mut shortlist_centroids: Vec<f32> = Vec::with_capacity(shortlist_size * dim);
    for &(cid, _) in &dists {
        let start = cid as usize * dim;
        let end = start + dim;
        // Every entry in `dists` already passed this bound in the filter
        // above; the check is kept as a defensive guard.
        if end > centroids.len() {
            continue;
        }
        shortlist_centroids.extend_from_slice(&centroids[start..end]);
    }

    crate::probe_select::select_diverse_probes(
        &dists,
        &shortlist_centroids,
        dim,
        nprobe,
        diversity,
        metric,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two generators built from the same seed must emit identical streams.
    #[test]
    fn xorshift_deterministic() {
        let mut rng1 = Xorshift64::new(42);
        let mut rng2 = Xorshift64::new(42);
        for _ in 0..100 {
            assert_eq!(rng1.next_u64(), rng2.next_u64());
        }
    }

    /// `next_f64` must stay inside the half-open unit interval.
    #[test]
    fn xorshift_f64_range() {
        let mut rng = Xorshift64::new(12345);
        for _ in 0..10_000 {
            let v = rng.next_f64();
            assert!((0.0..1.0).contains(&v));
        }
    }

    /// Two well-separated 2-D blobs must yield one centroid near each blob.
    #[test]
    fn kmeans_basic_2d() {
        #[rustfmt::skip]
        let data: Vec<f32> = vec![
            0.1, 0.1,
            -0.1, 0.2,
            0.0, -0.1,
            0.2, 0.0,
            10.0, 10.1,
            9.9, 10.0,
            10.1, 9.9,
            10.0, 10.0,
        ];
        let centroids = kmeans(&data, 2, 2, 25, DistanceMetric::EuclideanSq);
        assert_eq!(centroids.len(), 4);
        let c0 = &centroids[0..2];
        let c1 = &centroids[2..4];
        let near_origin = |c: &[f32]| c[0].abs() < 1.0 && c[1].abs() < 1.0;
        let near_ten = |c: &[f32]| (c[0] - 10.0).abs() < 1.0 && (c[1] - 10.0).abs() < 1.0;
        // Centroid order is not specified, so accept either arrangement.
        assert!(
            (near_origin(c0) && near_ten(c1)) || (near_origin(c1) && near_ten(c0)),
            "centroids did not converge: c0={c0:?}, c1={c1:?}"
        );
    }

    /// A point next to the second centroid must be assigned to index 1.
    #[test]
    fn assign_nearest_basic() {
        let centroids = vec![0.0, 0.0, 10.0, 10.0];
        let (c, _d) = assign_nearest(&[9.0, 9.0], &centroids, 2, 2, DistanceMetric::EuclideanSq);
        assert_eq!(c, 1);
    }

    /// With `lambda: 0.0` the two closest clusters come back nearest-first.
    #[test]
    fn nearest_clusters_ordering() {
        let centroids = vec![0.0, 0.0, 5.0, 5.0, 10.0, 10.0];
        let result = nearest_clusters(
            &[4.0, 4.0],
            &centroids,
            2,
            3,
            2,
            DistanceMetric::EuclideanSq,
            crate::probe_select::DiversityConfig { lambda: 0.0 },
        );
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, 1);
        assert_eq!(result[1].0, 0);
    }
}