hnsw-stable 0.10.1

Pure-Rust port of hnswlib (HNSW approximate nearest neighbors)
Documentation
use crate::kernels::cosine_distance;
use crate::kernels::cosine_distance_qi8;
use crate::kernels::inner_product_distance;
use crate::kernels::inner_product_distance_qi8;
use crate::kernels::l2_sq_qi8;
use crate::kernels::Kernel;
use crate::scalar::Scalar;
use crate::vector::Dense;
use crate::vector::Qi8;
use crate::vector::Qi8Ref;
use crate::vector::VectorFamily;
use std::marker::PhantomData;

/// A distance function over one family of vector representations.
///
/// The supertrait bounds require every implementor to be
/// `Clone + Send + Sync + 'static`.
///
/// NOTE(review): presumably a smaller return value means "closer"; confirm
/// against the index code that consumes this trait.
pub trait Metric: Clone + Send + Sync + 'static
{
  /// Vector representation this metric operates on. Its borrowed form,
  /// `VectorFamily::Ref`, is the argument type of [`Metric::distance`].
  type Family: VectorFamily;
  /// Returns the distance between `a` and `b` as an `f32`.
  ///
  /// The two operands carry independent lifetimes (`'a`, `'b`), so they may
  /// be borrowed from different owners.
  fn distance<'a, 'b>(
    &self,
    a: <Self::Family as VectorFamily>::Ref<'a>,
    b: <Self::Family as VectorFamily>::Ref<'b>,
  ) -> f32;
}

/// Marker type selecting the L2 metric over dense vectors of scalar type `S`
/// (`f32` when the parameter is omitted).
///
/// The only field is a `PhantomData<S>`, so values are zero-sized.
#[derive(Clone, Debug, Default)]
pub struct L2<S: Scalar = f32>(PhantomData<S>);

impl<S: Scalar> L2<S>
{
  /// Builds the marker value; there is no state to configure.
  pub fn new() -> Self
  {
    L2(PhantomData::<S>)
  }
}

impl<S: Scalar> Metric for L2<S>
{
  type Family = Dense<S>;

  /// Distance between two dense vectors, delegated to the scalar type's
  /// `Kernel::l2_sq` implementation.
  ///
  /// The unit test in this file compares the `f32` case against the sum of
  /// squared element differences. Equal operand lengths are asserted in debug
  /// builds only.
  fn distance(&self, a: &[S], b: &[S]) -> f32
  {
    let (lhs_len, rhs_len) = (a.len(), b.len());
    debug_assert_eq!(lhs_len, rhs_len);

    <S as Kernel>::l2_sq(a, b)
  }
}

/// Marker type selecting the inner-product metric over dense vectors of
/// scalar type `S` (`f32` when the parameter is omitted).
///
/// The only field is a `PhantomData<S>`, so values are zero-sized.
#[derive(Clone, Debug, Default)]
pub struct InnerProduct<S: Scalar = f32>(PhantomData<S>);

impl<S: Scalar> InnerProduct<S>
{
  /// Builds the marker value; there is no state to configure.
  pub fn new() -> Self
  {
    InnerProduct(PhantomData::<S>)
  }
}

impl<S: Scalar> Metric for InnerProduct<S>
{
  type Family = Dense<S>;

  /// Distance between two dense vectors, delegated to
  /// `inner_product_distance`.
  ///
  /// The unit test in this file compares the `f32` case against
  /// `1 - sum(a[i] * b[i])`. Equal operand lengths are asserted in debug
  /// builds only.
  fn distance(&self, a: &[S], b: &[S]) -> f32
  {
    let (lhs_len, rhs_len) = (a.len(), b.len());
    debug_assert_eq!(lhs_len, rhs_len);

    inner_product_distance::<S>(a, b)
  }
}

/// Marker type selecting the cosine metric over dense vectors of scalar type
/// `S` (`f32` when the parameter is omitted).
///
/// The only field is a `PhantomData<S>`, so values are zero-sized.
#[derive(Clone, Debug, Default)]
pub struct Cosine<S: Scalar = f32>(PhantomData<S>);

impl<S: Scalar> Cosine<S>
{
  /// Builds the marker value; there is no state to configure.
  pub fn new() -> Self
  {
    Cosine(PhantomData::<S>)
  }
}

impl<S: Scalar> Metric for Cosine<S>
{
  type Family = Dense<S>;

  /// Distance between two dense vectors, delegated to `cosine_distance`.
  ///
  /// Equal operand lengths are asserted in debug builds only.
  fn distance(&self, a: &[S], b: &[S]) -> f32
  {
    let (lhs_len, rhs_len) = (a.len(), b.len());
    debug_assert_eq!(lhs_len, rhs_len);

    cosine_distance::<S>(a, b)
  }
}

/// Marker type selecting the L2 metric over the `Qi8` vector family.
///
/// A unit struct: it carries no data.
#[derive(Clone, Debug, Default)]
pub struct L2Qi8;

impl L2Qi8
{
  /// Builds the marker value; there is no state to configure.
  pub fn new() -> Self
  {
    L2Qi8
  }
}

impl Metric for L2Qi8
{
  type Family = Qi8;

  /// Distance between two `Qi8` vectors, computed by `l2_sq_qi8` from each
  /// operand's `data`, `scale` and `zero_point`.
  fn distance<'a, 'b>(&self, a: Qi8Ref<'a>, b: Qi8Ref<'b>) -> f32
  {
    let (lhs_data, lhs_scale, lhs_zero) = (a.data, a.scale, a.zero_point);
    let (rhs_data, rhs_scale, rhs_zero) = (b.data, b.scale, b.zero_point);

    l2_sq_qi8(lhs_data, lhs_scale, lhs_zero, rhs_data, rhs_scale, rhs_zero)
  }
}

/// Marker type selecting the inner-product metric over the `Qi8` vector
/// family.
///
/// A unit struct: it carries no data.
#[derive(Clone, Debug, Default)]
pub struct InnerProductQi8;

impl InnerProductQi8
{
  /// Builds the marker value; there is no state to configure.
  pub fn new() -> Self
  {
    InnerProductQi8
  }
}

impl Metric for InnerProductQi8
{
  type Family = Qi8;

  /// Distance between two `Qi8` vectors, computed by
  /// `inner_product_distance_qi8` from each operand's `data`, `scale` and
  /// `zero_point`.
  fn distance<'a, 'b>(&self, a: Qi8Ref<'a>, b: Qi8Ref<'b>) -> f32
  {
    let (lhs_data, lhs_scale, lhs_zero) = (a.data, a.scale, a.zero_point);
    let (rhs_data, rhs_scale, rhs_zero) = (b.data, b.scale, b.zero_point);

    inner_product_distance_qi8(lhs_data, lhs_scale, lhs_zero, rhs_data, rhs_scale, rhs_zero)
  }
}

/// Marker type selecting the cosine metric over the `Qi8` vector family.
///
/// A unit struct: it carries no data.
#[derive(Clone, Debug, Default)]
pub struct CosineQi8;

impl CosineQi8
{
  /// Builds the marker value; there is no state to configure.
  pub fn new() -> Self
  {
    CosineQi8
  }
}

impl Metric for CosineQi8
{
  type Family = Qi8;

  /// Distance between two `Qi8` vectors, computed by `cosine_distance_qi8`
  /// from each operand's `data`, `scale` and `zero_point`.
  fn distance<'a, 'b>(&self, a: Qi8Ref<'a>, b: Qi8Ref<'b>) -> f32
  {
    let (lhs_data, lhs_scale, lhs_zero) = (a.data, a.scale, a.zero_point);
    let (rhs_data, rhs_scale, rhs_zero) = (b.data, b.scale, b.zero_point);

    cosine_distance_qi8(lhs_data, lhs_scale, lhs_zero, rhs_data, rhs_scale, rhs_zero)
  }
}

/// Rescales `vector` in place to unit Euclidean length.
///
/// Every element is widened with `Scalar::to_f32`, scaled, and written back
/// with `Scalar::from_f32`; all accumulation happens in `f32`.
///
/// The vector is left untouched when it has no usable direction: it is empty,
/// every element is zero, or some element is NaN or infinite.
///
/// When the squared norm overflows to infinity or underflows out of the
/// normal `f32` range, the elements are divided by the largest magnitude
/// before the norm is taken. Without that step a large finite vector would be
/// silently zeroed and a tiny non-zero vector would be skipped or normalized
/// with heavily reduced precision.
pub fn normalize_cosine_in_place<S: Scalar>(vector: &mut [S])
{
  let norm_sq = vector.iter().fold(0.0_f32, |acc, &v| {
    let v = v.to_f32();
    acc + v * v
  });

  // Common case: the squared norm is a normal float, so the direct formula
  // is accurate and nothing else needs to be inspected.
  if norm_sq.is_normal()
  {
    let inv_norm = norm_sq.sqrt().recip();
    for v in vector.iter_mut()
    {
      *v = S::from_f32(v.to_f32() * inv_norm);
    }
    return;
  }

  // Squares are non-negative, so a NaN sum can only come from a NaN element.
  if norm_sq.is_nan()
  {
    return;
  }

  // `norm_sq` is now zero, subnormal, or infinite. Distinguish a genuinely
  // degenerate vector from one whose squares merely left the f32 range.
  let max_abs = vector
    .iter()
    .fold(0.0_f32, |acc, &v| acc.max(v.to_f32().abs()));
  if max_abs == 0.0 || max_abs.is_infinite()
  {
    return;
  }

  // After dividing by `max_abs` every element lies in [-1, 1] and at least
  // one has magnitude 1, so the scaled norm is >= 1 and cannot underflow.
  // Division is used instead of multiplying by `1 / max_abs` because the
  // reciprocal of a very small magnitude would itself overflow.
  let scaled_norm = vector
    .iter()
    .fold(0.0_f32, |acc, &v| {
      let scaled = v.to_f32() / max_abs;
      acc + scaled * scaled
    })
    .sqrt();
  for v in vector.iter_mut()
  {
    *v = S::from_f32(v.to_f32() / max_abs / scaled_norm);
  }
}

#[cfg(test)]
mod tests
{
  use super::*;
  use approx::assert_relative_eq;
  use rand::rngs::StdRng;
  use rand::Rng;
  use rand::SeedableRng;

  /// Vector lengths exercised by each test: small sizes plus values on and
  /// next to powers of two.
  const DIMS: &[usize] = &[
    1, 2, 3, 4, 5, 7, 8, 9, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129, 255,
  ];

  /// Random vectors drawn per dimension in each test.
  const TRIALS: usize = 100;

  /// Draws `dim` values uniformly from `[-1, 1)`.
  fn random_vec(rng: &mut StdRng, dim: usize) -> Vec<f32>
  {
    (0..dim).map(|_| rng.gen_range(-1.0..1.0)).collect()
  }

  /// Reference L2 distance: the sum of squared element differences.
  fn l2_ref(a: &[f32], b: &[f32]) -> f32
  {
    a.iter()
      .zip(b)
      .map(|(x, y)| {
        let diff = x - y;
        diff * diff
      })
      .sum()
  }

  /// Reference inner-product distance: one minus the dot product.
  fn ip_ref(a: &[f32], b: &[f32]) -> f32
  {
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    1.0 - dot
  }

  #[test]
  fn l2_matches_scalar_across_dims()
  {
    let mut rng = StdRng::seed_from_u64(123);
    let metric = L2::<f32>::new();
    for &dim in DIMS
    {
      for _ in 0..TRIALS
      {
        // `a` is drawn before `b` so the seeded sequence stays reproducible.
        let a = random_vec(&mut rng, dim);
        let b = random_vec(&mut rng, dim);
        let actual = metric.distance(&a, &b);
        let expected = l2_ref(&a, &b);
        assert_relative_eq!(actual, expected, epsilon = 1e-3, max_relative = 1e-3);
      }
    }
  }

  #[test]
  fn inner_product_matches_scalar_across_dims()
  {
    let mut rng = StdRng::seed_from_u64(456);
    let metric = InnerProduct::<f32>::new();
    for &dim in DIMS
    {
      for _ in 0..TRIALS
      {
        // `a` is drawn before `b` so the seeded sequence stays reproducible.
        let a = random_vec(&mut rng, dim);
        let b = random_vec(&mut rng, dim);
        let actual = metric.distance(&a, &b);
        let expected = ip_ref(&a, &b);
        assert_relative_eq!(actual, expected, epsilon = 1e-3, max_relative = 1e-3);
      }
    }
  }
}