use crate::core::error::{Error, Result, VectorError};
use crate::core::property::MAX_VECTOR_DIMENSIONS;
use crate::index::vector::{CustomMetric, DistanceMetric, Quantization, StorageMode};
use std::io::{Read, Write};
use std::sync::Arc;
/// Tuning and layout parameters for an HNSW vector index.
///
/// Build one with [`HnswConfig::new`] plus the `with_*` methods, or via
/// `HnswIndexBuilder`. Only part of this struct survives
/// `serialize_into`/`deserialize_from`: `storage` and `custom_metric` are not
/// written and come back as their defaults.
#[derive(Debug, Clone)]
pub struct HnswConfig {
/// Number of components per vector. `deserialize_from` rejects values above
/// `MAX_VECTOR_DIMENSIONS`; the `Default` value is 0.
pub dimensions: usize,
/// Built-in distance metric (serialized as a single tag byte).
pub metric: DistanceMetric,
/// HNSW `M` parameter (default 16). Per the HNSW paper's naming this bounds
/// per-node links; the actual use lives in the index code (not visible here).
pub m: usize,
/// HNSW `efConstruction` (default 128), presumably the candidate-list size
/// used while inserting — confirm against the index implementation.
pub ef_construction: usize,
/// HNSW `ef` used at query time (default 64) — presumably the search
/// candidate-list size; confirm against the index implementation.
pub ef_search: usize,
/// Initial capacity hint (the builder exposes it as `initial_capacity`);
/// default 0.
pub capacity: usize,
/// Vector quantization mode (serialized as a single tag byte, optional on read).
pub quantization: Quantization,
/// Storage mode. NOT persisted by `serialize_into`.
pub storage: StorageMode,
/// Optional user-supplied distance function. NOT persisted by `serialize_into`.
pub custom_metric: Option<CustomMetric>,
}
impl PartialEq for HnswConfig {
fn eq(&self, other: &Self) -> bool {
self.dimensions == other.dimensions
&& self.metric == other.metric
&& self.m == other.m
&& self.ef_construction == other.ef_construction
&& self.ef_search == other.ef_search
&& self.capacity == other.capacity
&& self.quantization == other.quantization
&& self.storage == other.storage
&& self.custom_metric == other.custom_metric
}
}
// Baseline settings: `HnswConfig::new` starts from these, and
// `deserialize_from` falls back to them for fields it does not read.
impl Default for HnswConfig {
    /// Cosine metric, `m = 16`, `ef_construction = 128`, `ef_search = 64`,
    /// zero dimensions and capacity, default quantization and storage, and no
    /// custom metric.
    fn default() -> Self {
        const M: usize = 16;
        const EF_CONSTRUCTION: usize = 128;
        const EF_SEARCH: usize = 64;

        Self {
            dimensions: 0,
            metric: DistanceMetric::Cosine,
            m: M,
            ef_construction: EF_CONSTRUCTION,
            ef_search: EF_SEARCH,
            capacity: 0,
            quantization: Default::default(),
            storage: Default::default(),
            custom_metric: None,
        }
    }
}
impl HnswConfig {
pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
HnswConfig {
dimensions,
metric,
..Default::default()
}
}
pub fn with_m(mut self, m: usize) -> Self {
self.m = m;
self
}
pub fn with_ef_construction(mut self, ef_construction: usize) -> Self {
self.ef_construction = ef_construction;
self
}
pub fn with_ef_search(mut self, ef_search: usize) -> Self {
self.ef_search = ef_search;
self
}
pub fn with_capacity(mut self, capacity: usize) -> Self {
self.capacity = capacity;
self
}
pub fn with_dimensions(mut self, dimensions: usize) -> Self {
self.dimensions = dimensions;
self
}
pub fn with_metric(mut self, metric: DistanceMetric) -> Self {
self.metric = metric;
self
}
pub fn with_quantization(mut self, quantization: Quantization) -> Self {
self.quantization = quantization;
self
}
pub fn with_storage(mut self, storage: StorageMode) -> Self {
self.storage = storage;
self
}
pub fn with_custom_metric<F>(mut self, name: &str, f: F) -> Self
where
F: Fn(&[f32], &[f32]) -> f32 + Send + Sync + 'static,
{
self.custom_metric = Some(CustomMetric {
name: name.to_string(),
distance_fn: Arc::new(f),
});
self
}
pub fn serialize_into<W: Write>(&self, writer: &mut W) -> Result<()> {
writer.write_all(&(self.dimensions as u64).to_le_bytes())?;
writer.write_all(&[self.metric.to_u8()])?;
writer.write_all(&(self.m as u64).to_le_bytes())?;
writer.write_all(&(self.ef_construction as u64).to_le_bytes())?;
writer.write_all(&(self.ef_search as u64).to_le_bytes())?;
writer.write_all(&(self.capacity as u64).to_le_bytes())?;
writer.write_all(&[self.quantization.to_u8()])?;
Ok(())
}
pub fn deserialize_from<R: Read>(reader: &mut R) -> Result<Self> {
let mut buf_u64 = [0u8; 8];
let mut buf_u8 = [0u8; 1];
reader.read_exact(&mut buf_u64)?;
let dimensions = u64::from_le_bytes(buf_u64) as usize;
if dimensions > MAX_VECTOR_DIMENSIONS {
return Err(Error::Vector(VectorError::InvalidVector {
reason: format!(
"dimensions {} exceeds maximum allowed {}",
dimensions, MAX_VECTOR_DIMENSIONS
),
}));
}
reader.read_exact(&mut buf_u8)?;
let metric = DistanceMetric::from_u8(buf_u8[0])?;
reader.read_exact(&mut buf_u64)?;
let m = u64::from_le_bytes(buf_u64) as usize;
reader.read_exact(&mut buf_u64)?;
let ef_construction = u64::from_le_bytes(buf_u64) as usize;
reader.read_exact(&mut buf_u64)?;
let ef_search = u64::from_le_bytes(buf_u64) as usize;
reader.read_exact(&mut buf_u64)?;
let capacity = u64::from_le_bytes(buf_u64) as usize;
let quantization = match reader.read_exact(&mut buf_u8) {
Ok(_) => Quantization::from_u8(buf_u8[0])?,
Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
Quantization::default()
}
Err(e) => return Err(e.into()),
};
Ok(HnswConfig {
dimensions,
metric,
m,
ef_construction,
ef_search,
capacity,
quantization,
..Default::default()
})
}
}
/// Fluent builder that accumulates an [`HnswConfig`] and then constructs the
/// index through `build`.
pub struct HnswIndexBuilder {
/// Configuration being assembled; handed to the index constructor by `build`.
pub(crate) config: HnswConfig,
}
impl HnswIndexBuilder {
pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
HnswIndexBuilder {
config: HnswConfig {
dimensions,
metric,
..Default::default()
},
}
}
pub fn from_config(config: &HnswConfig) -> Self {
HnswIndexBuilder {
config: config.clone(),
}
}
pub fn m(mut self, m: usize) -> Self {
self.config.m = m;
self
}
pub fn ef_construction(mut self, ef_construction: usize) -> Self {
self.config.ef_construction = ef_construction;
self
}
pub fn ef_search(mut self, ef_search: usize) -> Self {
self.config.ef_search = ef_search;
self
}
pub fn initial_capacity(mut self, capacity: usize) -> Self {
self.config.capacity = capacity;
self
}
pub fn quantization(mut self, quantization: Quantization) -> Self {
self.config.quantization = quantization;
self
}
pub fn storage(mut self, storage: StorageMode) -> Self {
self.config.storage = storage;
self
}
pub fn with_custom_metric<F>(mut self, name: &str, f: F) -> Self
where
F: Fn(&[f32], &[f32]) -> f32 + Send + Sync + 'static,
{
self.config = self.config.with_custom_metric(name, f);
self
}
pub fn build(self) -> Result<super::HnswIndex> {
super::HnswIndex::new_internal(self.config)
}
}