use std::fmt;
use bytemuck::cast_slice;
use half::f16;
use crate::distance::{DistanceMetric, vtype_to_scalar_kind};
use crate::types::VectorType;
/// Tuning parameters for an HNSW index.
///
/// Pass `None` to `HnswIndex::new` to use [`HnswParams::default`].
/// The type is `Copy` and comparable, so parameter sets can be stored and
/// checked for equality directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HnswParams {
    /// Forwarded to the underlying index as its `connectivity` option.
    pub m: usize,
    /// Forwarded to the underlying index as its `expansion_add` option.
    pub ef_construction: usize,
    /// Forwarded to the underlying index as its `expansion_search` option.
    pub ef_search: usize,
}

impl Default for HnswParams {
    /// Returns `m = 16`, `ef_construction = 200`, `ef_search = 64`.
    fn default() -> Self {
        Self {
            m: 16,
            ef_construction: 200,
            ef_search: 64,
        }
    }
}
/// Error returned by the index wrapper; carries a human-readable message.
///
/// The `Display` form is the message prefixed with `index error: `.
#[derive(Debug)]
pub struct IndexError(pub String);

impl fmt::Display for IndexError {
    /// Writes the fixed prefix followed by the wrapped message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let IndexError(message) = self;
        f.write_fmt(format_args!("index error: {}", message))
    }
}

// No `source()` override: the error only holds an already formatted string.
impl std::error::Error for IndexError {}
/// Vector index backed by a `usearch::Index`, addressed by `u64` keys and fed
/// with raw byte blobs whose element type is described by `vtype`.
pub struct HnswIndex {
/// The wrapped `usearch` index that stores and searches the vectors.
inner: usearch::Index,
/// Dimensionality passed to the constructor; stored but not read by the
/// methods of this type.
_dim: usize,
/// Element type of the blobs given to `add` and `search`; selects how their
/// bytes are reinterpreted.
vtype: VectorType,
}
impl HnswIndex {
pub fn new(
dim: usize,
vtype: VectorType,
metric: DistanceMetric,
params: Option<HnswParams>,
) -> Result<Self, IndexError> {
let p = params.unwrap_or_default();
let opts = usearch::IndexOptions {
dimensions: dim,
metric: metric.to_usearch(),
quantization: vtype_to_scalar_kind(vtype),
connectivity: p.m,
expansion_add: p.ef_construction,
expansion_search: p.ef_search,
multi: false,
};
let inner = usearch::Index::new(&opts).map_err(|e| IndexError(e.to_string()))?;
Ok(Self {
inner,
_dim: dim,
vtype,
})
}
pub fn len(&self) -> usize {
self.inner.size()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn add(&self, key: u64, blob: &[u8]) -> Result<(), IndexError> {
self.reserve_if_needed()?;
match self.vtype {
VectorType::Float4 => {
let v: &[f32] = cast_slice(blob);
self.inner
.add(key, v)
.map_err(|e| IndexError(e.to_string()))
}
VectorType::Float8 => {
let v: &[f64] = cast_slice(blob);
self.inner
.add(key, v)
.map_err(|e| IndexError(e.to_string()))
}
VectorType::Int1 => {
let v: &[i8] = cast_slice(blob);
self.inner
.add(key, v)
.map_err(|e| IndexError(e.to_string()))
}
VectorType::Float2 => {
let v: &[f16] = cast_slice(blob);
let f: Vec<f32> = v.iter().map(|x| x.to_f32()).collect();
self.inner
.add(key, &f)
.map_err(|e| IndexError(e.to_string()))
}
VectorType::Int2 => {
let v: &[i16] = cast_slice(blob);
let f: Vec<f32> = v.iter().map(|x| *x as f32).collect();
self.inner
.add(key, &f)
.map_err(|e| IndexError(e.to_string()))
}
VectorType::Int4 => {
let v: &[i32] = cast_slice(blob);
let f: Vec<f32> = v.iter().map(|x| *x as f32).collect();
self.inner
.add(key, &f)
.map_err(|e| IndexError(e.to_string()))
}
}
}
pub fn search(&self, query_blob: &[u8], k: usize) -> Result<Vec<(u64, f32)>, IndexError> {
if self.is_empty() {
return Ok(Vec::new());
}
let matches = match self.vtype {
VectorType::Float4 => {
let q: &[f32] = cast_slice(query_blob);
self.inner.search(q, k)
}
VectorType::Float8 => {
let q: &[f64] = cast_slice(query_blob);
self.inner.search(q, k)
}
VectorType::Int1 => {
let q: &[i8] = cast_slice(query_blob);
self.inner.search(q, k)
}
VectorType::Float2 => {
let q: &[f16] = cast_slice(query_blob);
let f: Vec<f32> = q.iter().map(|x| x.to_f32()).collect();
self.inner.search(&f, k)
}
VectorType::Int2 => {
let q: &[i16] = cast_slice(query_blob);
let f: Vec<f32> = q.iter().map(|x| *x as f32).collect();
self.inner.search(&f, k)
}
VectorType::Int4 => {
let q: &[i32] = cast_slice(query_blob);
let f: Vec<f32> = q.iter().map(|x| *x as f32).collect();
self.inner.search(&f, k)
}
}
.map_err(|e| IndexError(e.to_string()))?;
Ok(matches.keys.into_iter().zip(matches.distances).collect())
}
pub fn remove(&self, key: u64) -> Result<(), IndexError> {
self.inner
.remove(key)
.map(|_| ())
.map_err(|e| IndexError(e.to_string()))
}
pub fn save_to_buffer(&self) -> Result<Vec<u8>, IndexError> {
let len = self.inner.serialized_length();
let mut buf = vec![0u8; len];
self.inner
.save_to_buffer(&mut buf)
.map_err(|e| IndexError(e.to_string()))?;
Ok(buf)
}
pub fn load_from_buffer(&self, buf: &[u8]) -> Result<(), IndexError> {
self.inner
.load_from_buffer(buf)
.map_err(|e| IndexError(e.to_string()))
}
fn reserve_if_needed(&self) -> Result<(), IndexError> {
if self.inner.size() >= self.inner.capacity() {
let new_cap = (self.inner.capacity() * 2).max(64);
self.inner
.reserve(new_cap)
.map_err(|e| IndexError(e.to_string()))?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytemuck::cast_slice;

    /// The three unit axes of 3-D space as `f32` rows.
    const AXES_F32: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
    /// The three unit axes of 3-D space as `f64` rows.
    const AXES_F64: [[f64; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];

    /// Raw native-endian bytes of an `f32` vector.
    fn f32_blob(values: &[f32]) -> Vec<u8> {
        let bytes: &[u8] = cast_slice(values);
        bytes.to_owned()
    }

    /// Raw native-endian bytes of an `f64` vector.
    fn f64_blob(values: &[f64]) -> Vec<u8> {
        let bytes: &[u8] = cast_slice(values);
        bytes.to_owned()
    }

    /// Fails the test when construction returned an error.
    fn assert_built(outcome: Result<HnswIndex, IndexError>) {
        assert!(outcome.is_ok(), "expected Ok, got {:?}", outcome.err());
    }

    /// Builds a 3-D `Float4` index and stores the i-th unit axis under `keys[i]`.
    fn seeded_f32(metric: DistanceMetric, params: Option<HnswParams>, keys: &[u64]) -> HnswIndex {
        let index = HnswIndex::new(3, VectorType::Float4, metric, params).unwrap();
        for (&key, axis) in keys.iter().zip(AXES_F32.iter()) {
            index.add(key, &f32_blob(axis)).unwrap();
        }
        index
    }

    /// Builds a 3-D `Float8` L2 index and stores the i-th unit axis under `keys[i]`.
    fn seeded_f64(keys: &[u64]) -> HnswIndex {
        let index = HnswIndex::new(3, VectorType::Float8, DistanceMetric::L2, None).unwrap();
        for (&key, axis) in keys.iter().zip(AXES_F64.iter()) {
            index.add(key, &f64_blob(axis)).unwrap();
        }
        index
    }

    /// Runs a `k = 1` search, checks that exactly one hit came back, and returns its key.
    fn sole_hit(index: &HnswIndex, query: &[u8]) -> u64 {
        let hits = index.search(query, 1).unwrap();
        assert_eq!(hits.len(), 1);
        hits[0].0
    }

    #[test]
    fn hnsw_params_default_values() {
        let HnswParams {
            m,
            ef_construction,
            ef_search,
        } = HnswParams::default();
        assert_eq!(m, 16);
        assert_eq!(ef_construction, 200);
        assert_eq!(ef_search, 64);
    }

    #[test]
    fn new_float4_l2_does_not_error() {
        assert_built(HnswIndex::new(
            3,
            VectorType::Float4,
            DistanceMetric::L2,
            None,
        ));
    }

    #[test]
    fn new_float8_cosine_does_not_error() {
        assert_built(HnswIndex::new(
            4,
            VectorType::Float8,
            DistanceMetric::Cosine,
            None,
        ));
    }

    #[test]
    fn new_with_custom_params_does_not_error() {
        let custom = HnswParams {
            m: 8,
            ef_construction: 64,
            ef_search: 32,
        };
        assert_built(HnswIndex::new(
            3,
            VectorType::Float4,
            DistanceMetric::L2,
            Some(custom),
        ));
    }

    #[test]
    fn len_empty_index_is_zero() {
        let index = seeded_f32(DistanceMetric::L2, None, &[]);
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
    }

    #[test]
    fn len_increases_after_add() {
        let index = seeded_f32(DistanceMetric::L2, None, &[]);
        for (position, axis) in AXES_F32.iter().enumerate() {
            index.add(position as u64 + 1, &f32_blob(axis)).unwrap();
            assert_eq!(index.len(), position + 1);
            if position == 0 {
                assert!(!index.is_empty());
            }
        }
    }

    #[test]
    fn search_nearest_orthogonal_float4() {
        let index = seeded_f32(DistanceMetric::L2, None, &[1, 2, 3]);
        let key = sole_hit(&index, &f32_blob(&[0.9, 0.1, 0.0]));
        assert_eq!(
            key, 1,
            "expected key 1 ([1,0,0]) as nearest, got key {}",
            key
        );
    }

    #[test]
    fn search_returns_empty_on_empty_index() {
        let index = seeded_f32(DistanceMetric::L2, None, &[]);
        let hits = index.search(&f32_blob(&[1.0, 0.0, 0.0]), 5).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn search_k_larger_than_index_returns_all_vectors() {
        let index = seeded_f32(DistanceMetric::L2, None, &[1, 2]);
        let hits = index.search(&f32_blob(&[1.0, 0.0, 0.0]), 10).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn remove_decreases_len() {
        let index = seeded_f32(DistanceMetric::L2, None, &[10, 20, 30]);
        assert_eq!(index.len(), 3);
        index.remove(20).unwrap();
        assert_eq!(index.len(), 2);
    }

    #[test]
    fn remove_key_no_longer_returned_by_search() {
        let index = seeded_f32(DistanceMetric::L2, None, &[1, 2, 3]);
        index.remove(2).unwrap();
        let returned_keys: Vec<u64> = index
            .search(&f32_blob(&[0.0, 1.0, 0.0]), 3)
            .unwrap()
            .into_iter()
            .map(|(key, _)| key)
            .collect();
        assert!(
            !returned_keys.contains(&2),
            "removed key 2 should not appear in search results, got {:?}",
            returned_keys
        );
    }

    #[test]
    fn save_load_roundtrip_float4() {
        let src = seeded_f32(DistanceMetric::L2, None, &[1, 2, 3]);
        let buf = src.save_to_buffer().unwrap();
        assert!(!buf.is_empty(), "serialized buffer must not be empty");

        let dst = seeded_f32(DistanceMetric::L2, None, &[]);
        dst.load_from_buffer(&buf).unwrap();
        assert_eq!(dst.len(), src.len());

        let key = sole_hit(&dst, &f32_blob(&[0.9, 0.1, 0.0]));
        assert_eq!(key, 1, "post-load search should return key 1, got {}", key);
    }

    #[test]
    fn add_search_float8() {
        let index = seeded_f64(&[1, 2, 3]);
        let key = sole_hit(&index, &f64_blob(&[0.1, 0.0, 0.9]));
        assert_eq!(
            key, 3,
            "expected key 3 ([0,0,1]) as nearest, got key {}",
            key
        );
    }

    #[test]
    fn save_load_roundtrip_float8() {
        let src = seeded_f64(&[1, 2, 3]);
        let buf = src.save_to_buffer().unwrap();

        let dst = seeded_f64(&[]);
        dst.load_from_buffer(&buf).unwrap();
        assert_eq!(dst.len(), 3);

        let key = sole_hit(&dst, &f64_blob(&[0.0, 0.9, 0.1]));
        assert_eq!(key, 2, "post-load search should return key 2, got {}", key);
    }

    #[test]
    fn custom_params_index_behaves_correctly() {
        let custom = HnswParams {
            m: 4,
            ef_construction: 32,
            ef_search: 16,
        };
        let index = seeded_f32(DistanceMetric::Cosine, Some(custom), &[1, 2, 3]);
        assert_eq!(index.len(), 3);

        let key = sole_hit(&index, &f32_blob(&[0.0, 0.1, 0.9]));
        assert_eq!(
            key, 3,
            "expected key 3 ([0,0,1]) as nearest under cosine, got {}",
            key
        );
    }
}