use super::DistanceMetric;
/// Tuning parameters for an HNSW (Hierarchical Navigable Small World) index.
///
/// Build one with [`HnswConfig::new`], [`HnswConfig::high_recall`] or
/// [`HnswConfig::fast`], then adjust it with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct HnswConfig {
    /// Vector dimensionality this configuration is built for
    /// (presumably every inserted/query vector must have this length — confirm in the index).
    pub dimensions: usize,
    /// Distance metric used to compare vectors.
    pub metric: DistanceMetric,
    /// HNSW `M` parameter. The constructors and `with_m` derive `m_max` and
    /// `ml` from it.
    pub m: usize,
    /// Neighbor cap. The constructors and `with_m` set it to twice `m`;
    /// `with_m_max` overrides it independently.
    pub m_max: usize,
    /// HNSW `efConstruction` — presumably the candidate-list size used while
    /// inserting; confirm in the build code.
    pub ef_construction: usize,
    /// HNSW `ef` — presumably the candidate-list size used at query time;
    /// confirm in the search code.
    pub ef: usize,
    /// Level multiplier derived from `m` by the constructors and `with_m`
    /// (`1 / ln(m)` for the usual `m >= 2`); read by `expected_max_level`.
    pub ml: f64,
    /// Neighbor-selection factor: 1.0 in `new`/`fast`, 1.2 in `high_recall`.
    /// NOTE(review): presumably the pruning/diversity alpha of the neighbor
    /// heuristic — confirm where it is consumed.
    pub alpha: f32,
    /// Optional cap on the number of elements; `None` in every preset.
    /// Enforcement (if any) happens outside this type.
    pub max_elements: Option<usize>,
}
impl HnswConfig {
    /// Creates the balanced configuration: `m = 16` (so `m_max = 32`),
    /// `ef_construction = 128`, `ef = 50`, `alpha = 1.0`, no element cap.
    #[must_use]
    pub fn new(dimensions: usize, metric: DistanceMetric) -> Self {
        let m = 16;
        Self {
            dimensions,
            metric,
            m,
            m_max: Self::default_m_max(m),
            ef_construction: 128,
            ef: 50,
            ml: Self::level_multiplier(m),
            alpha: 1.0,
            max_elements: None,
        }
    }

    /// Creates a recall-oriented configuration: `m = 32` (so `m_max = 64`),
    /// `ef_construction = 256`, `ef = 100`, `alpha = 1.2`, no element cap.
    #[must_use]
    pub fn high_recall(dimensions: usize, metric: DistanceMetric) -> Self {
        // Expressed as overrides of `new` so the fields derived from `m`
        // (`m_max`, `ml`) are computed in exactly one place.
        Self::new(dimensions, metric)
            .with_m(32)
            .with_ef_construction(256)
            .with_ef(100)
            .with_alpha(1.2)
    }

    /// Creates a speed-oriented configuration: `m = 12` (so `m_max = 24`),
    /// `ef_construction = 100`, `ef = 32`, `alpha = 1.0`, no element cap.
    #[must_use]
    pub fn fast(dimensions: usize, metric: DistanceMetric) -> Self {
        Self::new(dimensions, metric)
            .with_m(12)
            .with_ef_construction(100)
            .with_ef(32)
    }

    /// Sets `m` and re-derives the fields that depend on it: `m_max` becomes
    /// `2 * m` (saturating) and `ml` becomes the level multiplier for `m`.
    ///
    /// Call [`with_m_max`](Self::with_m_max) *after* this to override `m_max`.
    #[must_use]
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = m;
        self.m_max = Self::default_m_max(m);
        self.ml = Self::level_multiplier(m);
        self
    }

    /// Overrides the neighbor cap `m_max`; `m` and `ml` are left unchanged.
    #[must_use]
    pub fn with_m_max(mut self, m_max: usize) -> Self {
        self.m_max = m_max;
        self
    }

    /// Sets `ef_construction`.
    #[must_use]
    pub fn with_ef_construction(mut self, ef_construction: usize) -> Self {
        self.ef_construction = ef_construction;
        self
    }

    /// Sets `ef`.
    #[must_use]
    pub fn with_ef(mut self, ef: usize) -> Self {
        self.ef = ef;
        self
    }

    /// Sets `alpha`.
    #[must_use]
    pub fn with_alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Caps the number of elements at `max`.
    #[must_use]
    pub fn with_max_elements(mut self, max: usize) -> Self {
        self.max_elements = Some(max);
        self
    }

    /// Estimates the highest level reached by `n` elements as
    /// `ceil(ln(n) * ml)`, i.e. `ceil(log_m(n))` when `ml = 1 / ln(m)`.
    ///
    /// Returns `0` for `n == 0` (and for `n == 1`, since `ln(1) == 0`).
    #[must_use]
    // The float -> usize cast is intentional: the value is expected to be a
    // small non-negative level count, and `as` saturates (NaN and negatives
    // become 0, +inf becomes usize::MAX) rather than invoking UB.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    pub fn expected_max_level(&self, n: usize) -> usize {
        if n == 0 {
            return 0;
        }
        ((n as f64).ln() * self.ml).ceil() as usize
    }

    /// Default neighbor cap for a given `m`: `2 * m`, saturating at
    /// `usize::MAX` instead of overflowing for pathological `m`.
    fn default_m_max(m: usize) -> usize {
        m.saturating_mul(2)
    }

    /// Level multiplier `1 / ln(m)` used by the presets and `with_m`.
    ///
    /// For `m == 1`, `ln(1) == 0` would make the multiplier `+inf` (so
    /// `expected_max_level` would return `usize::MAX`), and for `m == 0`,
    /// `ln(0) == -inf` yields `-0.0`. Neither value is a usable logarithm
    /// base, so both map to `0.0`, for which `expected_max_level` is `0`.
    fn level_multiplier(m: usize) -> f64 {
        if m <= 1 {
            0.0
        } else {
            1.0 / (m as f64).ln()
        }
    }
}
impl Default for HnswConfig {
fn default() -> Self {
Self::new(384, DistanceMetric::Cosine)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hnsw_config_new() {
        let config = HnswConfig::new(768, DistanceMetric::Euclidean);
        assert_eq!(config.dimensions, 768);
        assert_eq!(config.metric, DistanceMetric::Euclidean);
        assert_eq!(config.m, 16);
        assert_eq!(config.m_max, 32);
        assert_eq!(config.ef_construction, 128);
        assert_eq!(config.ef, 50);
        assert!((config.ml - 1.0 / (16.0_f64).ln()).abs() < 1e-10);
        assert!((config.alpha - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_hnsw_config_builder() {
        let config = HnswConfig::new(1536, DistanceMetric::Cosine)
            .with_m(24)
            .with_ef_construction(200)
            .with_ef(100);
        assert_eq!(config.m, 24);
        assert_eq!(config.m_max, 48);
        assert_eq!(config.ef_construction, 200);
        assert_eq!(config.ef, 100);
    }

    #[test]
    fn test_hnsw_config_high_recall_preset() {
        let config = HnswConfig::high_recall(384, DistanceMetric::Cosine);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.metric, DistanceMetric::Cosine);
        assert_eq!(config.m, 32);
        assert_eq!(config.m_max, 64);
        assert_eq!(config.ef_construction, 256);
        assert_eq!(config.ef, 100);
        assert!((config.ml - 1.0 / (32.0_f64).ln()).abs() < 1e-10);
        assert!((config.alpha - 1.2).abs() < f32::EPSILON);
    }

    #[test]
    fn test_hnsw_config_fast_preset() {
        let config = HnswConfig::fast(384, DistanceMetric::Euclidean);
        assert_eq!(config.dimensions, 384);
        assert_eq!(config.metric, DistanceMetric::Euclidean);
        assert_eq!(config.m, 12);
        assert_eq!(config.m_max, 24);
        assert_eq!(config.ef_construction, 100);
        assert_eq!(config.ef, 32);
        assert!((config.ml - 1.0 / (12.0_f64).ln()).abs() < 1e-10);
        assert!((config.alpha - 1.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_default_matches_new_384_cosine() {
        let default = HnswConfig::default();
        let explicit = HnswConfig::new(384, DistanceMetric::Cosine);
        assert_eq!(default.dimensions, 384);
        assert_eq!(default.metric, DistanceMetric::Cosine);
        assert_eq!(default.m, explicit.m);
        assert_eq!(default.m_max, explicit.m_max);
        assert_eq!(default.ef_construction, explicit.ef_construction);
        assert_eq!(default.ef, explicit.ef);
        assert!((default.ml - explicit.ml).abs() < 1e-12);
        assert!((default.alpha - explicit.alpha).abs() < f32::EPSILON);
        assert_eq!(default.max_elements, explicit.max_elements);
    }

    #[test]
    fn test_preset_ordering() {
        let default = HnswConfig::default();

        let high_recall = HnswConfig::high_recall(384, DistanceMetric::Cosine);
        assert!(high_recall.m > default.m);
        assert!(high_recall.ef_construction > default.ef_construction);
        assert!(high_recall.ef > default.ef);

        let fast = HnswConfig::fast(384, DistanceMetric::Cosine);
        assert!(fast.m < default.m);
        assert!(fast.ef_construction < default.ef_construction);
        assert!(fast.ef < default.ef);
    }

    #[test]
    fn test_with_m_updates_derived_fields() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine).with_m(24);
        assert_eq!(config.m, 24);
        assert_eq!(config.m_max, 48);
        assert!((config.ml - 1.0 / (24.0_f64).ln()).abs() < 1e-10);
    }

    #[test]
    fn test_with_m_max_override() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine)
            .with_m(16)
            .with_m_max(64);
        assert_eq!(config.m, 16);
        assert_eq!(config.m_max, 64);
    }

    #[test]
    fn test_with_alpha() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine).with_alpha(1.5);
        assert!((config.alpha - 1.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_expected_max_level() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine);
        assert_eq!(config.expected_max_level(0), 0);
        let level_100 = config.expected_max_level(100);
        assert!(level_100 > 0 && level_100 < 10);
        let level_1m = config.expected_max_level(1_000_000);
        assert!(level_1m > level_100);
    }

    #[test]
    fn test_expected_max_level_exact_values() {
        // With m = 16, ml = 1 / ln(16), so the level is ceil(log16(n)):
        // log16(1) = 0, log16(100) ~ 1.66, log16(1e6) ~ 4.98.
        let config = HnswConfig::new(384, DistanceMetric::Cosine);
        assert_eq!(config.expected_max_level(1), 0);
        assert_eq!(config.expected_max_level(100), 2);
        assert_eq!(config.expected_max_level(1_000_000), 5);
    }

    #[test]
    fn test_hnsw_config_max_elements() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine).with_max_elements(10_000);
        assert_eq!(config.max_elements, Some(10_000));
    }

    #[test]
    fn test_hnsw_config_no_max_elements_by_default() {
        let config = HnswConfig::new(384, DistanceMetric::Cosine);
        assert!(config.max_elements.is_none());
    }
}