use manifoldb_storage::StorageEngine;
use manifoldb_vector::distance::sparse::SparseDistanceMetric;
use manifoldb_vector::distance::DistanceMetric;
use super::config::{
AggregationMethod, HnswParams, IndexConfig, InvertedIndexParams, VectorConfig, VectorType,
};
use super::error::{ApiError, ApiResult};
use super::handle::CollectionHandle;
use super::metadata::CollectionName;
use crate::collection::config::DistanceType;
/// Fluent builder that accumulates named vector definitions for a collection.
///
/// A builder is started with [`CollectionBuilder::new`], extended through the
/// `with_*` methods (each of which appends one entry to `vectors`), and finally
/// consumed by [`CollectionBuilder::build`].
pub struct CollectionBuilder<E>
where
    E: StorageEngine,
{
    /// Storage engine that is handed to `CollectionHandle::create` on build.
    pub(crate) engine: E,
    /// Name of the collection being defined.
    pub(crate) name: CollectionName,
    /// `(vector name, configuration)` pairs in the order they were registered.
    pub(crate) vectors: Vec<(String, VectorConfig)>,
}
impl<E: StorageEngine> CollectionBuilder<E> {
pub(crate) fn new(engine: E, name: CollectionName) -> Self {
Self { engine, name, vectors: Vec::new() }
}
#[must_use]
pub fn with_dense_vector(
mut self,
name: impl Into<String>,
dimension: usize,
distance: DistanceMetric,
) -> Self {
self.vectors.push((name.into(), VectorConfig::dense(dimension, distance)));
self
}
#[must_use]
pub fn with_dense_vector_hnsw(
mut self,
name: impl Into<String>,
dimension: usize,
distance: DistanceMetric,
hnsw_params: HnswParams,
) -> Self {
let config = VectorConfig {
vector_type: VectorType::Dense { dimension },
distance: DistanceType::Dense(distance),
index: IndexConfig::hnsw(hnsw_params),
};
self.vectors.push((name.into(), config));
self
}
#[must_use]
pub fn with_sparse_vector(self, name: impl Into<String>) -> Self {
self.with_sparse_vector_config(name, 30522, SparseDistanceMetric::DotProduct)
}
#[must_use]
pub fn with_sparse_vector_config(
mut self,
name: impl Into<String>,
max_dimension: u32,
distance: SparseDistanceMetric,
) -> Self {
let config = VectorConfig {
vector_type: VectorType::Sparse { max_dimension },
distance: DistanceType::Sparse(distance),
index: IndexConfig::inverted_default(),
};
self.vectors.push((name.into(), config));
self
}
#[must_use]
pub fn with_sparse_vector_inverted(
mut self,
name: impl Into<String>,
max_dimension: u32,
distance: SparseDistanceMetric,
inverted_params: InvertedIndexParams,
) -> Self {
let config = VectorConfig {
vector_type: VectorType::Sparse { max_dimension },
distance: DistanceType::Sparse(distance),
index: IndexConfig::inverted(inverted_params),
};
self.vectors.push((name.into(), config));
self
}
#[must_use]
pub fn with_multi_vector(mut self, name: impl Into<String>, token_dim: usize) -> Self {
self.vectors.push((name.into(), VectorConfig::multi_vector(token_dim)));
self
}
#[must_use]
pub fn with_multi_vector_aggregation(
mut self,
name: impl Into<String>,
token_dim: usize,
aggregation: AggregationMethod,
) -> Self {
let config = VectorConfig {
vector_type: VectorType::Multi { token_dim },
distance: DistanceType::Dense(DistanceMetric::DotProduct),
index: IndexConfig::hnsw_with_aggregation(aggregation),
};
self.vectors.push((name.into(), config));
self
}
#[must_use]
pub fn with_binary_vector(mut self, name: impl Into<String>, bits: usize) -> Self {
self.vectors.push((name.into(), VectorConfig::binary(bits)));
self
}
#[must_use]
pub fn with_vector(mut self, name: impl Into<String>, config: VectorConfig) -> Self {
self.vectors.push((name.into(), config));
self
}
pub fn build(self) -> ApiResult<CollectionHandle<E>> {
if self.vectors.is_empty() {
return Err(ApiError::InvalidFilter(
"collection must have at least one vector".to_string(),
));
}
CollectionHandle::create(self.engine, self.name, self.vectors)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A dense config reports itself as dense and exposes its dimension.
    #[test]
    fn test_vector_config_dense() {
        let dense = VectorConfig::dense(768, DistanceMetric::Cosine);

        assert!(dense.vector_type.is_dense());
        assert_eq!(Some(768), dense.dimension());
    }

    /// A sparse config reports itself as sparse and has no fixed dimension.
    #[test]
    fn test_vector_config_sparse() {
        let sparse = VectorConfig::sparse(30522);

        assert!(sparse.vector_type.is_sparse());
        assert!(sparse.dimension().is_none());
    }

    /// A multi-vector config reports itself as multi and exposes the token dimension.
    #[test]
    fn test_vector_config_multi() {
        let multi = VectorConfig::multi_vector(128);

        assert!(multi.vector_type.is_multi());
        assert_eq!(Some(128), multi.dimension());
    }

    /// The HNSW parameter builder keeps the values it was given.
    #[test]
    fn test_hnsw_params() {
        let base = HnswParams::new(32);
        let tuned = base.with_ef_construction(400).with_ef_search(100);

        assert_eq!(tuned.m, 32);
        // `m_max0` is expected to be 64 for `m = 32` (presumably 2 * m; see `HnswParams::new`).
        assert_eq!(tuned.m_max0, 64);
        assert_eq!(tuned.ef_construction, 400);
        assert_eq!(tuned.ef_search, 100);
    }
}