use manifoldb_vector::distance::sparse::SparseDistanceMetric;
use manifoldb_vector::distance::DistanceMetric;
use serde::{Deserialize, Serialize};
/// Complete configuration for a vector: how it is represented, how two
/// vectors are compared, and how they are indexed.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct VectorConfig {
    /// Representation (dense, sparse, multi-vector or binary) and its size.
    pub vector_type: VectorType,
    /// Distance function used to compare vectors.
    pub distance: DistanceType,
    /// Index structure and its parameters.
    pub index: IndexConfig,
}
impl VectorConfig {
    /// Builds a dense-vector configuration.
    ///
    /// Vectors have `dimension` components and are compared with `distance`.
    /// The index is HNSW with default parameters.
    #[must_use]
    pub fn dense(dimension: usize, distance: DistanceMetric) -> Self {
        let vector_type = VectorType::Dense { dimension };
        let index = IndexConfig::hnsw_default();
        Self { vector_type, distance: DistanceType::Dense(distance), index }
    }

    /// Builds a sparse-vector configuration.
    ///
    /// Uses sparse dot-product distance and an inverted index with default
    /// parameters. `max_dimension` is stored unchanged in [`VectorType::Sparse`].
    #[must_use]
    pub fn sparse(max_dimension: u32) -> Self {
        let distance = DistanceType::Sparse(SparseDistanceMetric::DotProduct);
        let index = IndexConfig::inverted_default();
        Self { vector_type: VectorType::Sparse { max_dimension }, distance, index }
    }

    /// Builds a multi-vector configuration in which each vector has
    /// `token_dim` components.
    ///
    /// Individual vectors are compared with dense dot product. The HNSW index
    /// is configured with [`AggregationMethod::MaxSim`] aggregation.
    #[must_use]
    pub fn multi_vector(token_dim: usize) -> Self {
        let distance = DistanceType::Dense(DistanceMetric::DotProduct);
        let index = IndexConfig::hnsw_with_aggregation(AggregationMethod::MaxSim);
        Self { vector_type: VectorType::Multi { token_dim }, distance, index }
    }

    /// Builds a binary-vector configuration of `bits` bits, compared with
    /// Hamming distance and indexed by a default HNSW index.
    #[must_use]
    pub fn binary(bits: usize) -> Self {
        let distance = DistanceType::Binary(BinaryDistanceType::Hamming);
        let index = IndexConfig::hnsw_default();
        Self { vector_type: VectorType::Binary { bits }, distance, index }
    }

    /// Returns this configuration with its distance replaced by `distance`.
    ///
    /// The vector type and index are kept. No check is made that the new
    /// distance family matches the vector type.
    #[must_use]
    pub fn with_distance(self, distance: DistanceType) -> Self {
        Self { distance, ..self }
    }

    /// Returns this configuration with its index replaced by `index`.
    #[must_use]
    pub fn with_index(self, index: IndexConfig) -> Self {
        Self { index, ..self }
    }

    /// Returns the fixed per-vector size, if the vector type has one.
    ///
    /// Returns `dimension` for dense vectors, `token_dim` for multi-vectors and
    /// `bits` for binary vectors. Sparse vectors yield `None`.
    #[must_use]
    pub fn dimension(&self) -> Option<usize> {
        match self.vector_type {
            VectorType::Sparse { .. } => None,
            VectorType::Dense { dimension: size }
            | VectorType::Multi { token_dim: size }
            | VectorType::Binary { bits: size } => Some(size),
        }
    }
}
/// Shape of a stored vector. Each variant carries the size information
/// relevant to that representation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum VectorType {
    /// Fixed-length float vector with `dimension` components.
    Dense { dimension: usize },
    /// Sparse vector. `max_dimension` presumably bounds the index space
    /// (e.g. a vocabulary size); confirm against the sparse vector code.
    Sparse { max_dimension: u32 },
    /// Several vectors per item, each with `token_dim` components
    /// (presumably one per token; confirm).
    Multi { token_dim: usize },
    /// Bit vector of `bits` bits.
    Binary { bits: usize },
}
impl VectorType {
#[must_use]
pub fn is_dense(&self) -> bool {
matches!(self, Self::Dense { .. })
}
#[must_use]
pub fn is_sparse(&self) -> bool {
matches!(self, Self::Sparse { .. })
}
#[must_use]
pub fn is_multi(&self) -> bool {
matches!(self, Self::Multi { .. })
}
#[must_use]
pub fn is_binary(&self) -> bool {
matches!(self, Self::Binary { .. })
}
}
/// Distance function, grouped by the vector family it applies to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum DistanceType {
    /// Metric for dense float vectors (also used for multi-vector configs).
    Dense(DistanceMetric),
    /// Metric for sparse vectors.
    Sparse(SparseDistanceMetric),
    /// Metric for binary (bit) vectors.
    Binary(BinaryDistanceType),
}
/// Distance functions available for binary vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum BinaryDistanceType {
    /// Hamming distance.
    Hamming,
    /// Normalized Hamming distance. The normalization (presumably by bit
    /// count) is defined by the distance implementation; confirm there.
    HammingNormalized,
    /// Jaccard distance.
    Jaccard,
}
/// Index structure for a vector, plus an optional score aggregation.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct IndexConfig {
    /// Which index structure to build, with its parameters.
    pub method: IndexMethod,
    /// How per-vector scores are combined. Set for multi-vector configs;
    /// `None` otherwise by default.
    pub aggregation: Option<AggregationMethod>,
}
impl IndexConfig {
    /// HNSW index with [`HnswParams::default`] and no aggregation.
    #[must_use]
    pub fn hnsw_default() -> Self {
        Self::hnsw(HnswParams::default())
    }

    /// HNSW index with the given parameters and no aggregation.
    #[must_use]
    pub fn hnsw(params: HnswParams) -> Self {
        Self { aggregation: None, method: IndexMethod::Hnsw(params) }
    }

    /// Default-parameter HNSW index that combines scores with `aggregation`.
    #[must_use]
    pub fn hnsw_with_aggregation(aggregation: AggregationMethod) -> Self {
        Self::hnsw_default().with_aggregation(aggregation)
    }

    /// Inverted index with [`InvertedIndexParams::default`] and no aggregation.
    #[must_use]
    pub fn inverted_default() -> Self {
        Self::inverted(InvertedIndexParams::default())
    }

    /// Inverted index with the given parameters and no aggregation.
    #[must_use]
    pub fn inverted(params: InvertedIndexParams) -> Self {
        Self { aggregation: None, method: IndexMethod::Inverted(params) }
    }

    /// Flat (no auxiliary structure) index with no aggregation.
    #[must_use]
    pub fn flat() -> Self {
        Self { aggregation: None, method: IndexMethod::Flat }
    }

    /// Returns this config with `aggregation` set, replacing any previous value.
    #[must_use]
    pub fn with_aggregation(self, aggregation: AggregationMethod) -> Self {
        Self { aggregation: Some(aggregation), ..self }
    }
}
/// Index structure to build for a vector.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum IndexMethod {
    /// HNSW graph index. The default for dense, multi-vector and binary configs.
    Hnsw(HnswParams),
    /// Inverted index. The default for sparse configs.
    Inverted(InvertedIndexParams),
    /// No index structure. Presumably exhaustive search; confirm in the search code.
    Flat,
}
/// Tuning parameters for an HNSW index.
///
/// Field meanings below follow the usual HNSW naming. Confirm them against the
/// index implementation that consumes these values.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct HnswParams {
    /// Target number of links per node (HNSW `M`).
    pub m: usize,
    /// Link limit for the base layer. [`HnswParams::new`] sets it to `2 * m`.
    pub m_max0: usize,
    /// Candidate-list size used while building the graph.
    pub ef_construction: usize,
    /// Candidate-list size used while searching.
    pub ef_search: usize,
}
impl Default for HnswParams {
    /// `m = 16` (so `m_max0 = 32`), `ef_construction = 200`, `ef_search = 50`.
    fn default() -> Self {
        let m = 16;
        Self {
            ef_search: 50,
            ef_construction: 200,
            m_max0: m * 2,
            m,
        }
    }
}
impl HnswParams {
    /// Creates HNSW parameters with `m` links per node.
    ///
    /// `m` is raised to at least 2. `m_max0` is set to twice the resulting `m`,
    /// saturating at `usize::MAX` instead of overflowing. The `ef_*` values
    /// come from [`HnswParams::default`], so `HnswParams::new(16)` equals the
    /// default.
    #[must_use]
    pub fn new(m: usize) -> Self {
        // Clamp to the smallest value this constructor accepts.
        let m = m.max(2);
        // A plain `m * 2` panics in debug builds (and wraps in release,
        // yielding m_max0 < m) when m > usize::MAX / 2. Saturate instead.
        Self { m, m_max0: m.saturating_mul(2), ..Self::default() }
    }

    /// Returns `self` with `ef_construction` replaced by `ef`.
    #[must_use]
    pub const fn with_ef_construction(mut self, ef: usize) -> Self {
        self.ef_construction = ef;
        self
    }

    /// Returns `self` with `ef_search` replaced by `ef`.
    #[must_use]
    pub const fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }
}
/// Tuning parameters for an inverted (sparse) index.
#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct InvertedIndexParams {
    /// Whether IDF (inverse document frequency) weighting is enabled.
    /// Where it is applied is up to the index implementation; confirm there.
    pub use_idf: bool,
    /// Minimum term frequency. Presumably entries below it are ignored; confirm.
    pub min_term_freq: u32,
}
impl Default for InvertedIndexParams {
    /// Same as [`InvertedIndexParams::new`]: IDF enabled, `min_term_freq` of 1.
    fn default() -> Self {
        // Delegate to the `const` constructor so the default values are written
        // in exactly one place and cannot drift apart.
        Self::new()
    }
}
impl InvertedIndexParams {
    /// Creates parameters with IDF enabled and a minimum term frequency of 1.
    ///
    /// Unlike `Default::default`, this can be used in `const` contexts.
    #[must_use]
    pub const fn new() -> Self {
        Self { min_term_freq: 1, use_idf: true }
    }

    /// Returns `self` with `use_idf` replaced by the given flag.
    #[must_use]
    pub const fn with_idf(self, use_idf: bool) -> Self {
        Self { use_idf, ..self }
    }

    /// Returns `self` with `min_term_freq` replaced by `min_freq`.
    #[must_use]
    pub const fn with_min_term_freq(self, min_freq: u32) -> Self {
        Self { min_term_freq: min_freq, ..self }
    }
}
/// How per-vector scores are combined into one score for a multi-vector item.
///
/// The exact formulas live in the search implementation; the notes below are
/// the presumed meanings and should be confirmed there.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Presumably late-interaction "MaxSim": best match per query vector, summed.
    MaxSim,
    /// Presumably the mean of the per-vector scores.
    Average,
    /// Presumably the sum of the per-vector scores.
    Sum,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dense_config() {
        let config = VectorConfig::dense(768, DistanceMetric::Cosine);
        assert!(config.vector_type.is_dense());
        assert_eq!(config.dimension(), Some(768));
        assert!(matches!(config.distance, DistanceType::Dense(DistanceMetric::Cosine)));
        assert!(matches!(config.index.method, IndexMethod::Hnsw(_)));
    }

    #[test]
    fn test_sparse_config() {
        let config = VectorConfig::sparse(30522);
        assert!(config.vector_type.is_sparse());
        assert_eq!(config.dimension(), None);
        assert!(matches!(config.distance, DistanceType::Sparse(SparseDistanceMetric::DotProduct)));
        assert!(matches!(config.index.method, IndexMethod::Inverted(_)));
    }

    #[test]
    fn test_multi_vector_config() {
        let config = VectorConfig::multi_vector(128);
        assert!(config.vector_type.is_multi());
        assert_eq!(config.dimension(), Some(128));
        assert!(matches!(config.distance, DistanceType::Dense(DistanceMetric::DotProduct)));
        assert!(matches!(config.index.method, IndexMethod::Hnsw(_)));
        assert_eq!(config.index.aggregation, Some(AggregationMethod::MaxSim));
    }

    #[test]
    fn test_binary_config() {
        let config = VectorConfig::binary(1024);
        assert!(config.vector_type.is_binary());
        assert_eq!(config.dimension(), Some(1024));
        assert!(matches!(config.distance, DistanceType::Binary(BinaryDistanceType::Hamming)));
        assert!(matches!(config.index.method, IndexMethod::Hnsw(_)));
    }

    #[test]
    fn test_hnsw_params_builder() {
        let params = HnswParams::new(32).with_ef_construction(400).with_ef_search(100);
        assert_eq!(params.m, 32);
        assert_eq!(params.m_max0, 64);
        assert_eq!(params.ef_construction, 400);
        assert_eq!(params.ef_search, 100);
    }

    #[test]
    fn test_hnsw_params_new_clamps_small_m() {
        let params = HnswParams::new(0);
        assert_eq!(params.m, 2);
        assert_eq!(params.m_max0, 4);
        assert_eq!(HnswParams::new(1), HnswParams::new(2));
    }

    #[test]
    fn test_hnsw_params_new_matches_default() {
        assert_eq!(HnswParams::new(16), HnswParams::default());
    }

    #[test]
    fn test_inverted_params_builder() {
        assert_eq!(InvertedIndexParams::new(), InvertedIndexParams::default());
        let params = InvertedIndexParams::new().with_idf(false).with_min_term_freq(3);
        assert!(!params.use_idf);
        assert_eq!(params.min_term_freq, 3);
    }

    #[test]
    fn test_index_config_constructors() {
        assert_eq!(IndexConfig::flat(), IndexConfig { method: IndexMethod::Flat, aggregation: None });
        assert_eq!(
            IndexConfig::hnsw_with_aggregation(AggregationMethod::Sum),
            IndexConfig::hnsw_default().with_aggregation(AggregationMethod::Sum)
        );
        assert_eq!(IndexConfig::hnsw(HnswParams::default()), IndexConfig::hnsw_default());
        assert_eq!(
            IndexConfig::inverted(InvertedIndexParams::default()),
            IndexConfig::inverted_default()
        );
    }

    #[test]
    fn test_with_distance_overrides_only_distance() {
        let config = VectorConfig::binary(256)
            .with_distance(DistanceType::Binary(BinaryDistanceType::Jaccard));
        assert_eq!(config.distance, DistanceType::Binary(BinaryDistanceType::Jaccard));
        assert_eq!(config.vector_type, VectorType::Binary { bits: 256 });
        assert_eq!(config.index, IndexConfig::hnsw_default());
    }

    #[test]
    fn test_custom_index_config() {
        let config = VectorConfig::dense(768, DistanceMetric::Euclidean).with_index(IndexConfig {
            method: IndexMethod::Hnsw(HnswParams::new(32).with_ef_construction(500)),
            aggregation: None,
        });
        if let IndexMethod::Hnsw(params) = &config.index.method {
            assert_eq!(params.m, 32);
            assert_eq!(params.ef_construction, 500);
        } else {
            panic!("Expected HNSW index");
        }
    }
}