use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::embedding::embedder::{EmbedInput, EmbedInputType, Embedder};
use crate::error::Result;
use crate::vector::core::distance::DistanceMetric;
use crate::vector::core::quantization;
use crate::vector::core::vector::Vector;
/// Per-vector normalization strategy applied by `utils::normalize_vectors_parallel`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum VectorNormalization {
    /// Leave vectors untouched.
    None,
    /// Delegate to `Vector::normalize`.
    L2,
    /// Divide each component by the sum of absolute component values.
    L1,
    /// Rescale each vector's components using that vector's minimum and maximum.
    MinMax,
}
/// Reasons a [`Vector`] can fail validation (see `utils::validate_vector`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum VectorValidationError {
    /// The vector's length differs from the length that was required.
    DimensionMismatch {
        /// Required number of components.
        expected: usize,
        /// Number of components actually present.
        actual: usize,
    },
    /// `Vector::is_valid` rejected the data (reported as "NaN or infinity").
    InvalidValues,
    /// The vector has no components.
    Empty,
    /// Free-form validation failure described by the contained message.
    Custom(String),
}
impl std::fmt::Display for VectorValidationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
VectorValidationError::DimensionMismatch { expected, actual } => {
write!(
f,
"Vector dimension mismatch: expected {expected}, got {actual}"
)
}
VectorValidationError::InvalidValues => {
write!(f, "Vector contains invalid values (NaN or infinity)")
}
VectorValidationError::Empty => {
write!(f, "Vector is empty")
}
VectorValidationError::Custom(msg) => write!(f, "Custom validation error: {msg}"),
}
}
}
impl std::error::Error for VectorValidationError {}
/// Free-standing helpers for validating, normalizing, and scoring [`Vector`]s.
pub mod utils {
    use super::*;
    use crate::vector::core::distance::DistanceMetric;

    /// Checks that `vector` can be used for indexing or search.
    ///
    /// Checks run in this order and the first failure is returned:
    /// 1. the vector must contain at least one component;
    /// 2. when `expected_dimension` is `Some(d)`, the vector must have exactly `d` components;
    /// 3. `Vector::is_valid` must return `true`.
    ///
    /// # Errors
    ///
    /// Returns `LaurusError::InvalidOperation` whose message is the `Display` text of the
    /// matching [`VectorValidationError`] (`Empty`, `DimensionMismatch`, or `InvalidValues`).
    pub fn validate_vector(vector: &Vector, expected_dimension: Option<usize>) -> Result<()> {
        if vector.data.is_empty() {
            return Err(crate::error::LaurusError::InvalidOperation(
                VectorValidationError::Empty.to_string(),
            ));
        }
        if let Some(expected_dim) = expected_dimension
            && vector.data.len() != expected_dim
        {
            return Err(crate::error::LaurusError::InvalidOperation(
                VectorValidationError::DimensionMismatch {
                    expected: expected_dim,
                    actual: vector.data.len(),
                }
                .to_string(),
            ));
        }
        if !vector.is_valid() {
            return Err(crate::error::LaurusError::InvalidOperation(
                VectorValidationError::InvalidValues.to_string(),
            ));
        }
        Ok(())
    }

    /// Normalizes every vector in `vectors` in place, processing vectors in parallel with rayon.
    ///
    /// * `None` leaves the vectors untouched.
    /// * `L2` calls `Vector::normalize` on each vector.
    /// * `L1` divides each component by the sum of absolute values; vectors whose L1 norm
    ///   is not positive are left unchanged.
    /// * `MinMax` maps each component `x` to `(x - min) / (max - min)`; vectors whose
    ///   range is not positive (e.g. all components equal, or empty) are left unchanged.
    ///   NaN components are ignored when finding `min`/`max` and stay NaN afterwards.
    pub fn normalize_vectors_parallel(vectors: &mut [Vector], method: VectorNormalization) {
        use rayon::prelude::*;
        match method {
            VectorNormalization::None => {}
            VectorNormalization::L2 => {
                vectors.par_iter_mut().for_each(|vector| {
                    vector.normalize();
                });
            }
            VectorNormalization::L1 => {
                vectors
                    .par_iter_mut()
                    .for_each(|vector| l1_normalize_in_place(&mut vector.data));
            }
            VectorNormalization::MinMax => {
                vectors
                    .par_iter_mut()
                    .for_each(|vector| min_max_normalize_in_place(&mut vector.data));
            }
        }
    }

    /// Divides every component by the L1 norm (sum of absolute values) when it is positive.
    fn l1_normalize_in_place(data: &mut [f32]) {
        let l1_norm: f32 = data.iter().map(|x| x.abs()).sum();
        if l1_norm > 0.0 {
            for value in data.iter_mut() {
                *value /= l1_norm;
            }
        }
    }

    /// Rescales components with the (NaN-ignoring) min/max bounds when their range is positive.
    fn min_max_normalize_in_place(data: &mut [f32]) {
        // `f32::min`/`f32::max` return the non-NaN operand, so a NaN component can never
        // become a bound and the result does not depend on where a NaN sits in the data.
        // Empty (or all-NaN) input keeps the (INFINITY, NEG_INFINITY) seed, whose range is
        // negative, so it is skipped below.
        let (min_val, max_val) = data
            .iter()
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), &x| {
                (lo.min(x), hi.max(x))
            });
        let range = max_val - min_val;
        if range > 0.0 {
            for value in data.iter_mut() {
                *value = (*value - min_val) / range;
            }
        }
    }

    /// Computes `metric.similarity(query, v)` for every `v` in `vectors`, preserving order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `DistanceMetric::similarity`; collection stops at
    /// that point. NOTE(review): presumably this covers dimension mismatches — confirm
    /// against `DistanceMetric`.
    pub fn batch_similarities(
        query: &Vector,
        vectors: &[Vector],
        metric: DistanceMetric,
    ) -> Result<Vec<f32>> {
        vectors
            .iter()
            .map(|vector| metric.similarity(&query.data, &vector.data))
            .collect()
    }

    /// Computes `metric.distance(query, v)` for every `v` in `vectors`, preserving order.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `DistanceMetric::distance`; collection stops at
    /// that point.
    pub fn batch_distances(
        query: &Vector,
        vectors: &[Vector],
        metric: DistanceMetric,
    ) -> Result<Vec<f32>> {
        vectors
            .iter()
            .map(|vector| metric.distance(&query.data, &vector.data))
            .collect()
    }
}
/// How an index's data is brought in when the index is opened.
///
/// Serialized in snake_case: `"in_memory"` / `"mmap"`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IndexLoadingMode {
    /// Load the index into memory (the default).
    #[default]
    InMemory,
    /// Access the index through a memory map.
    Mmap,
}
/// Index-type-specific configuration.
///
/// Serialized as an internally tagged enum: the variant name (`"Flat"`, `"HNSW"`, or
/// `"IVF"`) is stored in a `"type"` field next to the wrapped config's own fields.
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum VectorIndexTypeConfig {
    /// Flat index configuration.
    Flat(FlatIndexConfig),
    /// HNSW index configuration.
    HNSW(HnswIndexConfig),
    /// IVF index configuration.
    IVF(IvfIndexConfig),
}
/// Defaults to a flat index built from [`FlatIndexConfig::default`].
impl Default for VectorIndexTypeConfig {
    fn default() -> Self {
        let flat = FlatIndexConfig::default();
        Self::Flat(flat)
    }
}
impl VectorIndexTypeConfig {
pub fn index_type_name(&self) -> &'static str {
match self {
VectorIndexTypeConfig::Flat(_) => "Flat",
VectorIndexTypeConfig::HNSW(_) => "HNSW",
VectorIndexTypeConfig::IVF(_) => "IVF",
}
}
pub fn dimension(&self) -> usize {
match self {
VectorIndexTypeConfig::Flat(config) => config.dimension,
VectorIndexTypeConfig::HNSW(config) => config.dimension,
VectorIndexTypeConfig::IVF(config) => config.dimension,
}
}
pub fn distance_metric(&self) -> DistanceMetric {
match self {
VectorIndexTypeConfig::Flat(config) => config.distance_metric,
VectorIndexTypeConfig::HNSW(config) => config.distance_metric,
VectorIndexTypeConfig::IVF(config) => config.distance_metric,
}
}
pub fn max_vectors_per_segment(&self) -> u64 {
match self {
VectorIndexTypeConfig::Flat(config) => config.max_vectors_per_segment,
VectorIndexTypeConfig::HNSW(config) => config.max_vectors_per_segment,
VectorIndexTypeConfig::IVF(config) => config.max_vectors_per_segment,
}
}
pub fn merge_factor(&self) -> u32 {
match self {
VectorIndexTypeConfig::Flat(config) => config.merge_factor,
VectorIndexTypeConfig::HNSW(config) => config.merge_factor,
VectorIndexTypeConfig::IVF(config) => config.merge_factor,
}
}
}
/// Configuration for a flat index.
///
/// `Debug` is implemented by hand because `embedder` is a trait object; only the
/// embedder's name is printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct FlatIndexConfig {
    /// Vector dimension (default 128).
    pub dimension: usize,
    /// Loading strategy; falls back to `IndexLoadingMode::default()` when missing from
    /// serialized input.
    #[serde(default)]
    pub loading_mode: IndexLoadingMode,
    /// Metric used to compare vectors (default `DistanceMetric::Cosine`).
    pub distance_metric: DistanceMetric,
    /// Whether vectors should be normalized (default `true`).
    pub normalize_vectors: bool,
    /// Per-segment vector limit (default 1,000,000).
    pub max_vectors_per_segment: u64,
    /// Write buffer size (default 1024 * 1024 — presumably bytes).
    pub write_buffer_size: usize,
    /// Whether quantization is enabled (default `false`).
    pub use_quantization: bool,
    /// Quantization method (default `QuantizationMethod::None`).
    pub quantization_method: quantization::QuantizationMethod,
    /// Segment merge factor (default 10).
    pub merge_factor: u32,
    /// Segment count limit (default 100).
    pub max_segments: u32,
    /// Embedder attached to this index. Never serialized; deserialization fills it with
    /// the placeholder from `default_embedder`.
    #[serde(skip, default = "default_embedder")]
    pub embedder: Arc<dyn Embedder>,
}
fn default_embedder() -> Arc<dyn Embedder> {
use async_trait::async_trait;
#[derive(Debug)]
struct MockEmbedder;
#[async_trait]
impl Embedder for MockEmbedder {
async fn embed(&self, _input: &EmbedInput<'_>) -> Result<Vector> {
Ok(Vector::new(vec![0.0; 384]))
}
fn supported_input_types(&self) -> Vec<EmbedInputType> {
vec![EmbedInputType::Text]
}
fn name(&self) -> &str {
"MockEmbedder"
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
Arc::new(MockEmbedder)
}
/// 128-dimensional cosine flat index with normalization on, quantization off, and the
/// placeholder embedder.
impl Default for FlatIndexConfig {
    fn default() -> Self {
        Self {
            dimension: 128,
            loading_mode: IndexLoadingMode::default(),
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            max_vectors_per_segment: 1_000_000,
            // 1 MiB.
            write_buffer_size: 1 << 20,
            use_quantization: false,
            quantization_method: quantization::QuantizationMethod::None,
            merge_factor: 10,
            max_segments: 100,
            embedder: default_embedder(),
        }
    }
}
/// Prints every field; the embedder trait object is represented by its `name()`.
impl std::fmt::Debug for FlatIndexConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this impl won't compile.
        let Self {
            dimension,
            loading_mode,
            distance_metric,
            normalize_vectors,
            max_vectors_per_segment,
            write_buffer_size,
            use_quantization,
            quantization_method,
            merge_factor,
            max_segments,
            embedder,
        } = self;
        f.debug_struct("FlatIndexConfig")
            .field("dimension", dimension)
            .field("loading_mode", loading_mode)
            .field("distance_metric", distance_metric)
            .field("normalize_vectors", normalize_vectors)
            .field("max_vectors_per_segment", max_vectors_per_segment)
            .field("write_buffer_size", write_buffer_size)
            .field("use_quantization", use_quantization)
            .field("quantization_method", quantization_method)
            .field("merge_factor", merge_factor)
            .field("max_segments", max_segments)
            .field("embedder", &embedder.name())
            .finish()
    }
}
/// Configuration for an HNSW index.
///
/// `Debug` is implemented by hand because `embedder` is a trait object; only the
/// embedder's name is printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct HnswIndexConfig {
    /// Vector dimension (default 128).
    pub dimension: usize,
    /// Loading strategy; falls back to `IndexLoadingMode::default()` when missing from
    /// serialized input.
    #[serde(default)]
    pub loading_mode: IndexLoadingMode,
    /// Metric used to compare vectors (default `DistanceMetric::Cosine`).
    pub distance_metric: DistanceMetric,
    /// Whether vectors should be normalized (default `true`).
    pub normalize_vectors: bool,
    /// HNSW `m` parameter (default 16) — presumably the per-node neighbor count; confirm.
    pub m: usize,
    /// HNSW `ef_construction` parameter (default 200) — presumably the build-time
    /// candidate list size; confirm.
    pub ef_construction: usize,
    /// Per-segment vector limit (default 1,000,000).
    pub max_vectors_per_segment: u64,
    /// Write buffer size (default 1024 * 1024 — presumably bytes).
    pub write_buffer_size: usize,
    /// Whether quantization is enabled (default `false`).
    pub use_quantization: bool,
    /// Quantization method (default `QuantizationMethod::None`).
    pub quantization_method: quantization::QuantizationMethod,
    /// Segment merge factor (default 10).
    pub merge_factor: u32,
    /// Segment count limit (default 100).
    pub max_segments: u32,
    /// Embedder attached to this index. Never serialized; deserialization fills it with
    /// the placeholder from `default_embedder`.
    #[serde(skip, default = "default_embedder")]
    pub embedder: Arc<dyn Embedder>,
}
/// 128-dimensional cosine HNSW index (`m = 16`, `ef_construction = 200`) with
/// normalization on, quantization off, and the placeholder embedder.
impl Default for HnswIndexConfig {
    fn default() -> Self {
        Self {
            dimension: 128,
            loading_mode: IndexLoadingMode::default(),
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            m: 16,
            ef_construction: 200,
            max_vectors_per_segment: 1_000_000,
            // 1 MiB.
            write_buffer_size: 1 << 20,
            use_quantization: false,
            quantization_method: quantization::QuantizationMethod::None,
            merge_factor: 10,
            max_segments: 100,
            embedder: default_embedder(),
        }
    }
}
/// Prints every field; the embedder trait object is represented by its `name()`.
impl std::fmt::Debug for HnswIndexConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this impl won't compile.
        let Self {
            dimension,
            loading_mode,
            distance_metric,
            normalize_vectors,
            m,
            ef_construction,
            max_vectors_per_segment,
            write_buffer_size,
            use_quantization,
            quantization_method,
            merge_factor,
            max_segments,
            embedder,
        } = self;
        f.debug_struct("HnswIndexConfig")
            .field("dimension", dimension)
            .field("loading_mode", loading_mode)
            .field("distance_metric", distance_metric)
            .field("normalize_vectors", normalize_vectors)
            .field("m", m)
            .field("ef_construction", ef_construction)
            .field("max_vectors_per_segment", max_vectors_per_segment)
            .field("write_buffer_size", write_buffer_size)
            .field("use_quantization", use_quantization)
            .field("quantization_method", quantization_method)
            .field("merge_factor", merge_factor)
            .field("max_segments", max_segments)
            .field("embedder", &embedder.name())
            .finish()
    }
}
/// Configuration for an IVF index.
///
/// `Debug` is implemented by hand because `embedder` is a trait object; only the
/// embedder's name is printed.
#[derive(Clone, Serialize, Deserialize)]
pub struct IvfIndexConfig {
    /// Vector dimension (default 128).
    pub dimension: usize,
    /// Loading strategy; falls back to `IndexLoadingMode::default()` when missing from
    /// serialized input.
    #[serde(default)]
    pub loading_mode: IndexLoadingMode,
    /// Metric used to compare vectors (default `DistanceMetric::Cosine`).
    pub distance_metric: DistanceMetric,
    /// Whether vectors should be normalized (default `true`).
    pub normalize_vectors: bool,
    /// Cluster count (default 100) — presumably the number of inverted lists; confirm.
    pub n_clusters: usize,
    /// Probe count (default 1) — presumably clusters searched per query; confirm.
    pub n_probe: usize,
    /// Per-segment vector limit (default 1,000,000).
    pub max_vectors_per_segment: u64,
    /// Write buffer size (default 1024 * 1024 — presumably bytes).
    pub write_buffer_size: usize,
    /// Whether quantization is enabled (default `false`).
    pub use_quantization: bool,
    /// Quantization method (default `QuantizationMethod::None`).
    pub quantization_method: quantization::QuantizationMethod,
    /// Segment merge factor (default 10).
    pub merge_factor: u32,
    /// Segment count limit (default 100).
    pub max_segments: u32,
    /// Embedder attached to this index. Never serialized; deserialization fills it with
    /// the placeholder from `default_embedder`.
    #[serde(skip, default = "default_embedder")]
    pub embedder: Arc<dyn Embedder>,
}
/// 128-dimensional cosine IVF index (`n_clusters = 100`, `n_probe = 1`) with
/// normalization on, quantization off, and the placeholder embedder.
impl Default for IvfIndexConfig {
    fn default() -> Self {
        Self {
            dimension: 128,
            loading_mode: IndexLoadingMode::default(),
            distance_metric: DistanceMetric::Cosine,
            normalize_vectors: true,
            n_clusters: 100,
            n_probe: 1,
            max_vectors_per_segment: 1_000_000,
            // 1 MiB.
            write_buffer_size: 1 << 20,
            use_quantization: false,
            quantization_method: quantization::QuantizationMethod::None,
            merge_factor: 10,
            max_segments: 100,
            embedder: default_embedder(),
        }
    }
}
/// Prints every field; the embedder trait object is represented by its `name()`.
impl std::fmt::Debug for IvfIndexConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field without updating this impl won't compile.
        let Self {
            dimension,
            loading_mode,
            distance_metric,
            normalize_vectors,
            n_clusters,
            n_probe,
            max_vectors_per_segment,
            write_buffer_size,
            use_quantization,
            quantization_method,
            merge_factor,
            max_segments,
            embedder,
        } = self;
        f.debug_struct("IvfIndexConfig")
            .field("dimension", dimension)
            .field("loading_mode", loading_mode)
            .field("distance_metric", distance_metric)
            .field("normalize_vectors", normalize_vectors)
            .field("n_clusters", n_clusters)
            .field("n_probe", n_probe)
            .field("max_vectors_per_segment", max_vectors_per_segment)
            .field("write_buffer_size", write_buffer_size)
            .field("use_quantization", use_quantization)
            .field("quantization_method", quantization_method)
            .field("merge_factor", merge_factor)
            .field("max_segments", max_segments)
            .field("embedder", &embedder.name())
            .finish()
    }
}