use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
use crate::embedding::precomputed::PrecomputedEmbedder;
use crate::error::{LaurusError, Result};
use crate::lexical::store::config::LexicalIndexConfig;
use crate::maintenance::deletion::DeletionConfig;
use crate::vector::core::distance::DistanceMetric;
use crate::vector::core::field::FieldOption;
use crate::vector::core::quantization;
use crate::vector::core::vector::Vector;
/// Selects how an index is loaded.
///
/// Variants are serialized in `snake_case` (`in_memory`, `mmap`), and
/// [`IndexLoadingMode::InMemory`] is the `Default` value.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IndexLoadingMode {
    /// The default mode.
    #[default]
    InMemory,
    /// Alternative mode (presumably memory-mapped access — confirm against the loader).
    Mmap,
}
/// Index-type-specific configuration, with one variant per index kind.
///
/// Serialized as an internally tagged enum: the variant name is written under
/// the `"type"` key next to the fields of the wrapped configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum VectorIndexTypeConfig {
/// Settings carried by a [`FlatIndexConfig`].
Flat(FlatIndexConfig),
/// Settings carried by an [`HnswIndexConfig`]; this is the `Default` variant.
HNSW(HnswIndexConfig),
/// Settings carried by an [`IvfIndexConfig`].
IVF(IvfIndexConfig),
}
impl Default for VectorIndexTypeConfig {
    /// Returns the `HNSW` variant wrapping `HnswIndexConfig::default()`.
    fn default() -> Self {
        let hnsw = HnswIndexConfig::default();
        Self::HNSW(hnsw)
    }
}
impl VectorIndexTypeConfig {
pub fn index_type_name(&self) -> &'static str {
match self {
VectorIndexTypeConfig::Flat(_) => "Flat",
VectorIndexTypeConfig::HNSW(_) => "HNSW",
VectorIndexTypeConfig::IVF(_) => "IVF",
}
}
pub fn dimension(&self) -> usize {
match self {
VectorIndexTypeConfig::Flat(config) => config.dimension,
VectorIndexTypeConfig::HNSW(config) => config.dimension,
VectorIndexTypeConfig::IVF(config) => config.dimension,
}
}
pub fn distance_metric(&self) -> DistanceMetric {
match self {
VectorIndexTypeConfig::Flat(config) => config.distance_metric,
VectorIndexTypeConfig::HNSW(config) => config.distance_metric,
VectorIndexTypeConfig::IVF(config) => config.distance_metric,
}
}
pub fn max_vectors_per_segment(&self) -> u64 {
match self {
VectorIndexTypeConfig::Flat(config) => config.max_vectors_per_segment,
VectorIndexTypeConfig::HNSW(config) => config.max_vectors_per_segment,
VectorIndexTypeConfig::IVF(config) => config.max_vectors_per_segment,
}
}
pub fn merge_factor(&self) -> u32 {
match self {
VectorIndexTypeConfig::Flat(config) => config.merge_factor,
VectorIndexTypeConfig::HNSW(config) => config.merge_factor,
VectorIndexTypeConfig::IVF(config) => config.merge_factor,
}
}
}
/// Configuration for the `Flat` index variant.
///
/// `Debug` is implemented by hand rather than derived. The `embedder` field is
/// excluded from serde and is filled in by `default_embedder` when a value is
/// deserialized.
#[derive(Clone, Serialize, Deserialize)]
pub struct FlatIndexConfig {
/// Vector dimension; `Default` uses 128.
pub dimension: usize,
/// Loading mode; takes `IndexLoadingMode::default()` when absent from serialized input.
#[serde(default)]
pub loading_mode: IndexLoadingMode,
/// Distance metric; `Default` uses `DistanceMetric::Cosine`.
pub distance_metric: DistanceMetric,
/// Vector-normalization flag; `Default` uses `true`.
pub normalize_vectors: bool,
/// Per-segment vector limit; `Default` uses 1000000.
pub max_vectors_per_segment: u64,
/// Write buffer size; `Default` uses `1024 * 1024` (presumably bytes — confirm).
pub write_buffer_size: usize,
/// Quantization flag; `Default` uses `false`.
pub use_quantization: bool,
/// Quantization method; `Default` uses `QuantizationMethod::None`.
pub quantization_method: quantization::QuantizationMethod,
/// Merge factor; `Default` uses 10.
pub merge_factor: u32,
/// Segment count limit; `Default` uses 100.
pub max_segments: u32,
/// Embedder handle. Never serialized; deserialization calls `default_embedder`.
#[serde(skip)]
#[serde(default = "default_embedder")]
pub embedder: Arc<dyn Embedder>,
}
fn default_embedder() -> Arc<dyn Embedder> {
use async_trait::async_trait;
#[derive(Debug)]
struct MockEmbedder;
#[async_trait]
impl Embedder for MockEmbedder {
async fn embed(&self, _input: &EmbedInput<'_>) -> Result<Vector> {
Ok(Vector::new(vec![0.0; 384]))
}
fn supported_input_types(&self) -> Vec<EmbedInputType> {
vec![EmbedInputType::Text]
}
fn name(&self) -> &str {
"MockEmbedder"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
Arc::new(MockEmbedder)
}
impl Default for FlatIndexConfig {
    /// Returns the baseline flat-index configuration: 128 dimensions, cosine
    /// distance, normalization on, quantization off, and the placeholder
    /// embedder from `default_embedder`.
    fn default() -> Self {
        const DEFAULT_WRITE_BUFFER_SIZE: usize = 1024 * 1024;

        Self {
            dimension: 128,
            loading_mode: Default::default(),
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            max_vectors_per_segment: 1_000_000,
            write_buffer_size: DEFAULT_WRITE_BUFFER_SIZE,
            use_quantization: false,
            quantization_method: quantization::QuantizationMethod::None,
            merge_factor: 10,
            max_segments: 100,
            embedder: default_embedder(),
        }
    }
}
impl std::fmt::Debug for FlatIndexConfig {
    /// Formats every field exactly once, in declaration order.
    ///
    /// The embedder is a trait object, so it is represented by the string
    /// returned from its `name()` method instead of a derived `Debug` dump.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FlatIndexConfig")
            .field("dimension", &self.dimension)
            .field("loading_mode", &self.loading_mode)
            .field("distance_metric", &self.distance_metric)
            .field("normalize_vectors", &self.normalize_vectors)
            .field("max_vectors_per_segment", &self.max_vectors_per_segment)
            .field("write_buffer_size", &self.write_buffer_size)
            .field("use_quantization", &self.use_quantization)
            .field("quantization_method", &self.quantization_method)
            .field("merge_factor", &self.merge_factor)
            .field("max_segments", &self.max_segments)
            .field("embedder", &self.embedder.name())
            .finish()
    }
}
/// Configuration for the `HNSW` index variant.
///
/// `Debug` is implemented by hand rather than derived. The `embedder` field is
/// excluded from serde and is filled in by `default_embedder` when a value is
/// deserialized.
#[derive(Clone, Serialize, Deserialize)]
pub struct HnswIndexConfig {
/// Vector dimension; `Default` uses 128.
pub dimension: usize,
/// Loading mode; takes `IndexLoadingMode::default()` when absent from serialized input.
#[serde(default)]
pub loading_mode: IndexLoadingMode,
/// Distance metric; `Default` uses `DistanceMetric::Cosine`.
pub distance_metric: DistanceMetric,
/// Vector-normalization flag; `Default` uses `true`.
pub normalize_vectors: bool,
/// HNSW `m` parameter; `Default` uses 16 (presumably max links per node — confirm).
pub m: usize,
/// HNSW `ef_construction` parameter; `Default` uses 200.
pub ef_construction: usize,
/// Per-segment vector limit; `Default` uses 1000000.
pub max_vectors_per_segment: u64,
/// Write buffer size; `Default` uses `1024 * 1024` (presumably bytes — confirm).
pub write_buffer_size: usize,
/// Quantization flag; `Default` uses `false`.
pub use_quantization: bool,
/// Quantization method; `Default` uses `QuantizationMethod::None`.
pub quantization_method: quantization::QuantizationMethod,
/// Merge factor; `Default` uses 10.
pub merge_factor: u32,
/// Segment count limit; `Default` uses 100.
pub max_segments: u32,
/// Embedder handle. Never serialized; deserialization calls `default_embedder`.
#[serde(skip)]
#[serde(default = "default_embedder")]
pub embedder: Arc<dyn Embedder>,
}
impl Default for HnswIndexConfig {
    /// Returns the baseline HNSW configuration: 128 dimensions, cosine
    /// distance, `m = 16`, `ef_construction = 200`, quantization off, and the
    /// placeholder embedder from `default_embedder`.
    fn default() -> Self {
        const DEFAULT_WRITE_BUFFER_SIZE: usize = 1024 * 1024;

        Self {
            dimension: 128,
            loading_mode: Default::default(),
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            m: 16,
            ef_construction: 200,
            max_vectors_per_segment: 1_000_000,
            write_buffer_size: DEFAULT_WRITE_BUFFER_SIZE,
            use_quantization: false,
            quantization_method: quantization::QuantizationMethod::None,
            merge_factor: 10,
            max_segments: 100,
            embedder: default_embedder(),
        }
    }
}
impl std::fmt::Debug for HnswIndexConfig {
    /// Formats every field exactly once, in declaration order.
    ///
    /// The embedder is a trait object, so it is represented by the string
    /// returned from its `name()` method instead of a derived `Debug` dump.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HnswIndexConfig")
            .field("dimension", &self.dimension)
            .field("loading_mode", &self.loading_mode)
            .field("distance_metric", &self.distance_metric)
            .field("normalize_vectors", &self.normalize_vectors)
            .field("m", &self.m)
            .field("ef_construction", &self.ef_construction)
            .field("max_vectors_per_segment", &self.max_vectors_per_segment)
            .field("write_buffer_size", &self.write_buffer_size)
            .field("use_quantization", &self.use_quantization)
            .field("quantization_method", &self.quantization_method)
            .field("merge_factor", &self.merge_factor)
            .field("max_segments", &self.max_segments)
            .field("embedder", &self.embedder.name())
            .finish()
    }
}
/// Configuration for the `IVF` index variant.
///
/// `Debug` is implemented by hand rather than derived. The `embedder` field is
/// excluded from serde and is filled in by `default_embedder` when a value is
/// deserialized.
#[derive(Clone, Serialize, Deserialize)]
pub struct IvfIndexConfig {
/// Vector dimension; `Default` uses 128.
pub dimension: usize,
/// Loading mode; takes `IndexLoadingMode::default()` when absent from serialized input.
#[serde(default)]
pub loading_mode: IndexLoadingMode,
/// Distance metric; `Default` uses `DistanceMetric::Cosine`.
pub distance_metric: DistanceMetric,
/// Vector-normalization flag; `Default` uses `true`.
pub normalize_vectors: bool,
/// Cluster count; `Default` uses 100.
pub n_clusters: usize,
/// Probe count; `Default` uses 1 (presumably clusters searched per query — confirm).
pub n_probe: usize,
/// Per-segment vector limit; `Default` uses 1000000.
pub max_vectors_per_segment: u64,
/// Write buffer size; `Default` uses `1024 * 1024` (presumably bytes — confirm).
pub write_buffer_size: usize,
/// Quantization flag; `Default` uses `false`.
pub use_quantization: bool,
/// Quantization method; `Default` uses `QuantizationMethod::None`.
pub quantization_method: quantization::QuantizationMethod,
/// Merge factor; `Default` uses 10.
pub merge_factor: u32,
/// Segment count limit; `Default` uses 100.
pub max_segments: u32,
/// Embedder handle. Never serialized; deserialization calls `default_embedder`.
#[serde(skip)]
#[serde(default = "default_embedder")]
pub embedder: Arc<dyn Embedder>,
}
impl Default for IvfIndexConfig {
    /// Returns the baseline IVF configuration: 128 dimensions, cosine
    /// distance, 100 clusters probed one at a time, quantization off, and the
    /// placeholder embedder from `default_embedder`.
    fn default() -> Self {
        const DEFAULT_WRITE_BUFFER_SIZE: usize = 1024 * 1024;

        Self {
            dimension: 128,
            loading_mode: Default::default(),
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            n_clusters: 100,
            n_probe: 1,
            max_vectors_per_segment: 1_000_000,
            write_buffer_size: DEFAULT_WRITE_BUFFER_SIZE,
            use_quantization: false,
            quantization_method: quantization::QuantizationMethod::None,
            merge_factor: 10,
            max_segments: 100,
            embedder: default_embedder(),
        }
    }
}
impl std::fmt::Debug for IvfIndexConfig {
    /// Formats every field in declaration order; the embedder is shown via
    /// the string returned from its `name()` method.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("IvfIndexConfig");
        out.field("dimension", &self.dimension);
        out.field("loading_mode", &self.loading_mode);
        out.field("distance_metric", &self.distance_metric);
        out.field("normalize_vectors", &self.normalize_vectors);
        out.field("n_clusters", &self.n_clusters);
        out.field("n_probe", &self.n_probe);
        out.field("max_vectors_per_segment", &self.max_vectors_per_segment);
        out.field("write_buffer_size", &self.write_buffer_size);
        out.field("use_quantization", &self.use_quantization);
        out.field("quantization_method", &self.quantization_method);
        out.field("merge_factor", &self.merge_factor);
        out.field("max_segments", &self.max_segments);
        out.field("embedder", &self.embedder.name());
        out.finish()
    }
}
/// Top-level configuration of a vector index.
///
/// `Debug`, `Serialize` and `Deserialize` are implemented by hand in this
/// file; the `embedder` is left out of the serialized form.
#[derive(Clone)]
pub struct VectorIndexConfig {
/// Per-field configuration, keyed by field name.
pub fields: HashMap<String, VectorFieldConfig>,
/// Names of the default fields; `validate` requires each to be a key of `fields`.
pub default_fields: Vec<String>,
/// Free-form JSON metadata entries.
pub metadata: HashMap<String, serde_json::Value>,
/// Embedder handle; not serialized, and set to a `PrecomputedEmbedder` on deserialization.
pub embedder: Arc<dyn Embedder>,
/// Deletion settings.
pub deletion_config: DeletionConfig,
/// Shard identifier; the builder uses 0 when none is given.
pub shard_id: u16,
/// Lexical index configuration (presumably for the metadata store — confirm).
pub metadata_config: LexicalIndexConfig,
}
impl std::fmt::Debug for VectorIndexConfig {
    /// Formats every field in declaration order.
    ///
    /// The embedder is rendered through `format_args!("{:?}", ..)`, so it is
    /// always printed with the plain `{:?}` form, even under `{:#?}`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("VectorIndexConfig");
        out.field("fields", &self.fields);
        out.field("default_fields", &self.default_fields);
        out.field("metadata", &self.metadata);
        out.field("embedder", &format_args!("{:?}", self.embedder));
        out.field("deletion_config", &self.deletion_config);
        out.field("shard_id", &self.shard_id);
        out.field("metadata_config", &self.metadata_config);
        out.finish()
    }
}
impl VectorIndexConfig {
    /// Starts a new, empty [`VectorIndexConfigBuilder`].
    pub fn builder() -> VectorIndexConfigBuilder {
        VectorIndexConfigBuilder::new()
    }

    /// Checks that every name in `default_fields` is a key of `fields`.
    ///
    /// # Errors
    ///
    /// Returns `LaurusError::invalid_config` naming the first default field
    /// (in list order) that has no entry in `fields`.
    pub fn validate(&self) -> Result<()> {
        let undefined = self
            .default_fields
            .iter()
            .find(|name| !self.fields.contains_key(name.as_str()));

        match undefined {
            Some(field) => Err(LaurusError::invalid_config(format!(
                "default field '{field}' is not defined"
            ))),
            None => Ok(()),
        }
    }

    /// Borrows the configured embedder handle.
    pub fn get_embedder(&self) -> &Arc<dyn Embedder> {
        let Self { embedder, .. } = self;
        embedder
    }
}
impl Default for VectorIndexConfig {
fn default() -> Self {
Self::builder()
.build()
.expect("Default config should be valid")
}
}
/// Builder for [`VectorIndexConfig`].
///
/// Optional settings left as `None` are replaced by fallbacks in `build`.
pub struct VectorIndexConfigBuilder {
/// Field configurations collected so far, keyed by field name.
fields: HashMap<String, VectorFieldConfig>,
/// Default field names collected so far.
default_fields: Vec<String>,
/// Metadata entries collected so far.
metadata: HashMap<String, serde_json::Value>,
/// Explicit embedder; `build` falls back to `PrecomputedEmbedder::new()`.
embedder: Option<Arc<dyn Embedder>>,
/// Explicit deletion settings; `build` falls back to `DeletionConfig::default()`.
deletion_config: Option<DeletionConfig>,
/// Explicit shard id; `build` falls back to 0.
shard_id: Option<u16>,
/// Explicit metadata config; `build` falls back to `LexicalIndexConfig::default()`.
metadata_config: Option<LexicalIndexConfig>,
}
impl VectorIndexConfigBuilder {
    /// Creates a builder with no fields, no metadata and no overrides.
    pub fn new() -> Self {
        Self {
            embedder: None,
            deletion_config: None,
            shard_id: None,
            metadata_config: None,
            fields: HashMap::new(),
            default_fields: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Sets the embedder stored in the built configuration.
    pub fn embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
        self.embedder.replace(embedder);
        self
    }

    /// Registers `config` under `name`, replacing any earlier entry, and adds
    /// `name` to the default fields if it is not already listed.
    pub fn field(mut self, name: impl Into<String>, config: VectorFieldConfig) -> Self {
        let key: String = name.into();
        let already_default = self.default_fields.iter().any(|known| *known == key);
        if !already_default {
            self.default_fields.push(key.clone());
        }
        self.fields.insert(key, config);
        self
    }

    /// Registers a vector-only field (no lexical option) under `name`.
    ///
    /// Behaves like [`Self::field`] with a `VectorFieldConfig` whose `vector`
    /// is `Some(option.into())` and whose `lexical` is `None`. Always returns
    /// `Ok`.
    pub fn add_field(
        self,
        name: impl Into<String>,
        option: impl Into<FieldOption>,
    ) -> Result<Self> {
        let key: String = name.into();
        let vector_only = VectorFieldConfig {
            vector: Some(option.into()),
            lexical: None,
        };
        Ok(self.field(key, vector_only))
    }

    /// Forwards to [`Self::add_field`] with the same arguments.
    pub fn image_field(
        self,
        name: impl Into<String>,
        option: impl Into<FieldOption>,
    ) -> Result<Self> {
        self.add_field(name, option)
    }

    /// Adds `name` to the default fields unless it is already listed.
    pub fn default_field(mut self, name: impl Into<String>) -> Self {
        let candidate: String = name.into();
        if self.default_fields.iter().all(|known| *known != candidate) {
            self.default_fields.push(candidate);
        }
        self
    }

    /// Replaces the whole default-field list with `fields`.
    pub fn default_fields(mut self, fields: Vec<String>) -> Self {
        self.default_fields = fields;
        self
    }

    /// Stores `value` under `key` in the metadata map, replacing any earlier value.
    pub fn metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key: String = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Sets the deletion settings.
    pub fn deletion_config(mut self, config: DeletionConfig) -> Self {
        self.deletion_config.replace(config);
        self
    }

    /// Sets the shard id.
    pub fn shard_id(mut self, shard_id: u16) -> Self {
        self.shard_id.replace(shard_id);
        self
    }

    /// Sets the metadata (lexical index) configuration.
    pub fn metadata_config(mut self, config: LexicalIndexConfig) -> Self {
        self.metadata_config.replace(config);
        self
    }

    /// Assembles the configuration and validates it.
    ///
    /// Unset options fall back to `PrecomputedEmbedder::new()`, the `Default`
    /// deletion and metadata configs, and shard id 0.
    ///
    /// # Errors
    ///
    /// Propagates the error from `VectorIndexConfig::validate`.
    pub fn build(self) -> Result<VectorIndexConfig> {
        let Self {
            fields,
            default_fields,
            metadata,
            embedder,
            deletion_config,
            shard_id,
            metadata_config,
        } = self;

        let embedder = embedder.unwrap_or_else(|| Arc::new(PrecomputedEmbedder::new()));
        let built = VectorIndexConfig {
            fields,
            default_fields,
            metadata,
            embedder,
            deletion_config: deletion_config.unwrap_or_default(),
            shard_id: shard_id.unwrap_or(0),
            metadata_config: metadata_config.unwrap_or_default(),
        };

        built.validate()?;
        Ok(built)
    }
}
impl Default for VectorIndexConfigBuilder {
fn default() -> Self {
Self::new()
}
}
impl Serialize for VectorIndexConfig {
    /// Serializes every field except `embedder`, which is a trait object and
    /// is deliberately left out of the serialized form.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // Must equal the number of `serialize_field` calls below: formats that
        // write a length header rely on this count being accurate.
        const SERIALIZED_FIELD_COUNT: usize = 6;

        let mut state = serializer.serialize_struct("VectorIndexConfig", SERIALIZED_FIELD_COUNT)?;
        state.serialize_field("fields", &self.fields)?;
        state.serialize_field("default_fields", &self.default_fields)?;
        state.serialize_field("metadata", &self.metadata)?;
        state.serialize_field("deletion_config", &self.deletion_config)?;
        state.serialize_field("shard_id", &self.shard_id)?;
        state.serialize_field("metadata_config", &self.metadata_config)?;
        state.end()
    }
}
impl<'de> Deserialize<'de> for VectorIndexConfig {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
#[derive(Deserialize)]
struct VectorIndexConfigHelper {
fields: HashMap<String, VectorFieldConfig>,
default_fields: Vec<String>,
#[serde(default)]
metadata: HashMap<String, serde_json::Value>,
#[serde(default)]
deletion_config: DeletionConfig,
#[serde(default)]
shard_id: u16,
#[serde(default)]
metadata_config: LexicalIndexConfig,
}
let helper = VectorIndexConfigHelper::deserialize(deserializer)?;
Ok(VectorIndexConfig {
fields: helper.fields,
default_fields: helper.default_fields,
metadata: helper.metadata,
deletion_config: helper.deletion_config,
shard_id: helper.shard_id,
metadata_config: helper.metadata_config,
embedder: Arc::new(PrecomputedEmbedder::new()),
})
}
}
/// Configuration of a single field, with optional vector and lexical parts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorFieldConfig {
/// Vector-side option; `None` when absent from serialized input.
#[serde(default)]
pub vector: Option<FieldOption>,
/// Lexical-side option, using the lexical module's own `FieldOption` type.
pub lexical: Option<crate::lexical::core::field::FieldOption>,
}
impl Default for VectorFieldConfig {
    /// Returns a config with both the vector and the lexical option set to
    /// their respective `Default` values.
    fn default() -> Self {
        use crate::lexical::core::field::FieldOption as LexicalFieldOption;

        let vector = FieldOption::default();
        let lexical = LexicalFieldOption::default();
        Self {
            vector: Some(vector),
            lexical: Some(lexical),
        }
    }
}
impl VectorFieldConfig {
    /// Returns the default field weight, `1.0`.
    pub fn default_weight() -> f32 {
        const DEFAULT_WEIGHT: f32 = 1.0;
        DEFAULT_WEIGHT
    }
}