/// Contiguous, row-major storage for `f32` vectors that all share one dimension.
///
/// Vector `i` occupies `data[i * dim..(i + 1) * dim]`.
#[derive(Clone)]
pub struct FlatVectors {
    /// Flattened components, `dim` values per stored vector.
    pub data: Vec<f32>,
    /// Number of components in every stored vector.
    pub dim: usize,
    /// Number of vectors appended through [`FlatVectors::push`].
    pub count: usize,
}

impl FlatVectors {
    /// Creates an empty store for vectors of dimension `dim`.
    pub fn new(dim: usize) -> Self {
        Self {
            data: Vec::new(),
            dim,
            count: 0,
        }
    }

    /// Creates an empty store with backing capacity reserved for `n` vectors
    /// of dimension `dim` (`n * dim` floats).
    pub fn with_capacity(dim: usize, n: usize) -> Self {
        Self {
            data: Vec::with_capacity(n * dim),
            dim,
            count: 0,
        }
    }

    /// Appends `vector` as the next row and increments the vector count.
    ///
    /// # Panics
    ///
    /// Panics if `vector.len() != self.dim`. The check is unconditional because
    /// a wrong-length row would shift every later row and make `get` return
    /// misaligned data.
    #[inline]
    pub fn push(&mut self, vector: &[f32]) {
        assert_eq!(
            vector.len(),
            self.dim,
            "vector length must equal the store dimension"
        );
        self.data.extend_from_slice(vector);
        self.count += 1;
    }

    /// Returns the components of the vector stored at row `idx`.
    ///
    /// # Panics
    ///
    /// Panics if the row's range lies outside `data`.
    #[inline]
    pub fn get(&self, idx: usize) -> &[f32] {
        let start = idx * self.dim;
        &self.data[start..start + self.dim]
    }

    /// Overwrites every component of row `idx` with `0.0`; the row itself stays
    /// in place and `len()` is unchanged.
    ///
    /// Zeros (not NaN) are written so the slot stays safe to feed into distance
    /// computations and ordered comparisons.
    ///
    /// # Panics
    ///
    /// Panics if the row's range lies outside `data`.
    #[inline]
    pub fn zero_out(&mut self, idx: usize) {
        let start = idx * self.dim;
        self.data[start..start + self.dim].fill(0.0);
    }

    /// Number of vectors appended so far.
    pub fn len(&self) -> usize {
        self.count
    }

    /// Returns `true` when no vector has been appended.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }
}
/// Squared Euclidean distance between `a` and `b`.
///
/// The implementation is chosen at compile time: `simd_l2_squared` when the
/// `simd` feature is enabled, `scalar_l2_squared` otherwise. Debug builds
/// assert that both slices have the same length.
#[inline]
pub fn l2_squared(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    #[cfg(feature = "simd")]
    let distance = simd_l2_squared(a, b);
    #[cfg(not(feature = "simd"))]
    let distance = scalar_l2_squared(a, b);

    distance
}
/// Portable squared Euclidean distance over the first `a.len()` components.
///
/// Elements are processed in blocks of 16 using four independent accumulators
/// (element `j` of a block feeds accumulator `j % 4`); any leftover elements
/// are added to the first accumulator.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. Extra elements in `b` are ignored.
#[inline]
pub fn scalar_l2_squared(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let b = &b[..n];

    // Largest prefix whose length is a multiple of 16.
    let unrolled = n - n % 16;
    let (a_main, a_tail) = a.split_at(unrolled);
    let (b_main, b_tail) = b.split_at(unrolled);

    let mut lanes = [0.0f32; 4];
    for (qa, qb) in a_main.chunks_exact(4).zip(b_main.chunks_exact(4)) {
        for ((acc, &x), &y) in lanes.iter_mut().zip(qa).zip(qb) {
            let diff = x - y;
            *acc += diff * diff;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        let diff = x - y;
        lanes[0] += diff * diff;
    }

    lanes[0] + lanes[1] + lanes[2] + lanes[3]
}
/// Squared Euclidean distance computed through the `simsimd` crate.
///
/// The library result is narrowed to `f32`. When the library call yields
/// `None`, the portable `scalar_l2_squared` is used instead.
// NOTE(review): presumably `None` signals mismatched slice lengths — confirm
// against the simsimd documentation.
#[cfg(feature = "simd")]
#[inline]
pub fn simd_l2_squared(a: &[f32], b: &[f32]) -> f32 {
    match simsimd::SpatialSimilarity::sqeuclidean(a, b) {
        Some(dist) => dist as f32,
        None => scalar_l2_squared(a, b),
    }
}
/// Negated inner (dot) product of `a` and `b`.
///
/// With the `simd` feature the product comes from `simsimd`, falling back to
/// `scalar_inner_product` when the library yields `None`; without the feature
/// the scalar routine is always used. Both paths return the negated value.
/// Debug builds assert that both slices have the same length.
// NOTE(review): the negation presumably lets callers treat the result as a
// distance (smaller = more similar) — confirm against callers.
#[inline]
pub fn inner_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());

    #[cfg(feature = "simd")]
    let negated = match simsimd::SpatialSimilarity::inner(a, b) {
        Some(dot) => -(dot as f32),
        None => scalar_inner_product(a, b),
    };
    #[cfg(not(feature = "simd"))]
    let negated = scalar_inner_product(a, b);

    negated
}
/// Portable negated dot product over the first `a.len()` components.
///
/// Elements are processed in blocks of 16 using four independent accumulators
/// (element `j` of a block feeds accumulator `j % 4`); leftover elements are
/// added to the first accumulator. The summed result is negated before return.
///
/// # Panics
///
/// Panics if `b` is shorter than `a`. Extra elements in `b` are ignored.
#[inline]
fn scalar_inner_product(a: &[f32], b: &[f32]) -> f32 {
    let n = a.len();
    let b = &b[..n];

    // Largest prefix whose length is a multiple of 16.
    let unrolled = n - n % 16;
    let (a_main, a_tail) = a.split_at(unrolled);
    let (b_main, b_tail) = b.split_at(unrolled);

    let mut lanes = [0.0f32; 4];
    for (qa, qb) in a_main.chunks_exact(4).zip(b_main.chunks_exact(4)) {
        for ((acc, &x), &y) in lanes.iter_mut().zip(qa).zip(qb) {
            *acc += x * y;
        }
    }
    for (&x, &y) in a_tail.iter().zip(b_tail) {
        lanes[0] += x * y;
    }

    -(lanes[0] + lanes[1] + lanes[2] + lanes[3])
}
/// Sums one lookup-table entry per code: entry `i * k + codes[i]` of `table`
/// for every position `i` in `codes`.
///
/// `table` is therefore read as consecutive rows of `k` entries, one row per
/// code position. An empty `codes` slice yields `0.0`.
///
/// Lookups are bounds-checked: this is a safe function, so it must not read
/// outside `table` no matter what `codes`, `table` and `k` the caller passes.
///
/// # Panics
///
/// Panics if any computed index `i * k + codes[i]` is outside `table`.
#[inline]
pub fn pq_asymmetric_distance(codes: &[u8], table: &[f32], k: usize) -> f32 {
    codes
        .iter()
        .enumerate()
        .fold(0.0f32, |dist, (i, &code)| {
            dist + table[i * k + usize::from(code)]
        })
}
/// Set of visited ids in `0..n` that can be cleared in constant time.
///
/// Every id carries the generation in which it was last inserted; an id is a
/// member only while that stamp equals the current generation, so bumping the
/// generation empties the set without touching the per-id storage.
pub struct VisitedSet {
    /// Current generation. Starts at 1 so the initial stamps (0) never match.
    generation: u64,
    /// Per-id stamp: the generation in which the id was last inserted.
    gens: Vec<u64>,
}

impl VisitedSet {
    /// Creates an empty set able to track ids `0..n`.
    pub fn new(n: usize) -> Self {
        Self {
            generation: 1,
            gens: vec![0u64; n],
        }
    }

    /// Empties the set by advancing to a new generation.
    #[inline]
    pub fn clear(&mut self) {
        self.generation += 1;
    }

    /// Marks `id` as visited in the current generation.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not below the `n` given to [`VisitedSet::new`].
    #[inline]
    pub fn insert(&mut self, id: u32) {
        self.gens[id as usize] = self.generation;
    }

    /// Returns `true` if `id` was inserted since the last `clear`.
    ///
    /// # Panics
    ///
    /// Panics if `id` is not below the `n` given to [`VisitedSet::new`].
    #[inline]
    pub fn contains(&self, id: u32) -> bool {
        self.gens[id as usize] == self.generation
    }
}
#[cfg(feature = "gpu")]
/// Batched nearest-neighbour distance search, compiled only with the `gpu`
/// feature.
///
/// NOTE(review): nothing in this module calls a device API; distances are
/// computed on the CPU with rayon and the scalar L2 kernel. The backend value
/// is only recorded and reported.
pub mod gpu {
    use super::FlatVectors;

    /// Compute backend recorded by a [`GpuDistanceContext`].
    #[derive(Debug, Clone, Copy)]
    pub enum GpuBackend {
        Metal,
        Cuda,
        Vulkan,
    }

    /// Handle for batched distance computations.
    pub struct GpuDistanceContext {
        /// Backend chosen at construction time.
        backend: GpuBackend,
        /// Set to 4096 by `new`; not read anywhere in this module.
        batch_size: usize,
    }

    impl GpuDistanceContext {
        /// Creates a context, selecting `Metal` on macOS and `Cuda` on every
        /// other target.
        ///
        /// As written this always returns `Some`.
        pub fn new() -> Option<Self> {
            #[cfg(target_os = "macos")]
            let backend = GpuBackend::Metal;
            #[cfg(not(target_os = "macos"))]
            let backend = GpuBackend::Cuda;
            Some(Self {
                backend,
                batch_size: 4096,
            })
        }

        /// Returns up to `k` `(index, squared L2 distance)` pairs for the
        /// stored vectors closest to `query`, sorted by ascending distance.
        ///
        /// Distances that are NaN are ordered after every real distance, so a
        /// NaN component in `query` or in a stored vector cannot make the sort
        /// panic.
        pub fn batch_l2_squared(
            &self,
            query: &[f32],
            vectors: &FlatVectors,
            k: usize,
        ) -> Vec<(u32, f32)> {
            use rayon::prelude::*;

            // Result ids are `u32`; the cast truncates if the store holds more
            // than `u32::MAX` vectors.
            let mut dists: Vec<(u32, f32)> = (0..vectors.count as u32)
                .into_par_iter()
                .map(|i| {
                    let v = vectors.get(i as usize);
                    (i, super::scalar_l2_squared(query, v))
                })
                .collect();

            // `partial_cmp` is `None` only when a NaN is involved. Falling back
            // to comparing the `is_nan` flags gives a total order with all NaNs
            // equal to each other and greater than any number.
            dists.sort_unstable_by(|a, b| {
                a.1.partial_cmp(&b.1)
                    .unwrap_or_else(|| a.1.is_nan().cmp(&b.1.is_nan()))
            });
            dists.truncate(k);
            dists
        }

        /// Backend selected when the context was created.
        pub fn backend(&self) -> GpuBackend {
            self.backend
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing computed distances to expected values.
    const EPS: f32 = 1e-6;

    /// Squared L2 distance of two short vectors matches the hand-computed value.
    #[test]
    fn test_l2_squared() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        let got = l2_squared(&lhs, &rhs);
        assert!((got - 27.0).abs() < EPS);
    }

    /// A vector is at (numerically) zero distance from itself.
    #[test]
    fn test_l2_identical() {
        let ones = [1.0f32; 128];
        let got = l2_squared(&ones, &ones);
        assert!(got < 1e-10);
    }

    /// `inner_product` returns the negated dot product.
    #[test]
    fn test_inner_product() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        let got = inner_product(&lhs, &rhs);
        assert!((got - (-32.0)).abs() < EPS);
    }

    /// Pushed rows are counted and read back unchanged.
    #[test]
    fn test_flat_vectors() {
        let rows = [[1.0f32, 2.0, 3.0], [4.0f32, 5.0, 6.0]];
        let mut store = FlatVectors::new(3);
        for row in &rows {
            store.push(row);
        }
        assert_eq!(store.len(), 2);
        assert_eq!(store.get(0), &[1.0, 2.0, 3.0]);
        assert_eq!(store.get(1), &[4.0, 5.0, 6.0]);
    }

    /// Membership follows inserts and is reset by `clear`.
    #[test]
    fn test_visited_set() {
        let mut visited = VisitedSet::new(100);
        visited.insert(42);
        assert!(visited.contains(42));
        assert!(!visited.contains(43));

        visited.clear();
        assert!(!visited.contains(42));

        visited.insert(43);
        assert!(visited.contains(43));
    }

    /// Two code positions with `k = 4` read entries 1 and 4 + 2 of the flat table.
    #[test]
    fn test_pq_flat_table() {
        let table = [
            0.1f32, 0.2, 0.3, 0.4, // row for code position 0
            0.5, 0.6, 0.7, 0.8, // row for code position 1
        ];
        let codes = [1u8, 2u8];
        let dist = pq_asymmetric_distance(&codes, &table, 4);
        let expected = 0.2f32 + 0.7;
        assert!((dist - expected).abs() < EPS);
    }
}