use serde::{Deserialize, Serialize};
use std::str::FromStr;
use crate::{Error, Result};
#[cfg(feature = "async")]
pub mod async_vector;
pub mod columnar;
pub mod flat;
pub mod hnsw;
pub mod simd;
#[cfg(feature = "tokio")]
pub use async_vector::AsyncVectorStoreAdapter;
#[cfg(feature = "async")]
pub use async_vector::{AsyncHnswIndex, AsyncVectorStore};
pub use columnar::{
key_layout as vector_key_layout, AppendResult, SearchStats, VectorSearchParams,
VectorSearchResult, VectorSegment, VectorStoreConfig, VectorStoreManager,
};
pub use hnsw::{HnswConfig, HnswIndex, HnswSearchResult, HnswStats};
pub use simd::{select_kernel, DistanceKernel, ScalarKernel};
#[cfg(all(test, not(target_arch = "wasm32")))]
mod disk;
#[cfg(all(test, not(target_arch = "wasm32")))]
mod integration;
/// Outcome record for a delete: how many vectors were deleted and which
/// segment ids were touched.
///
/// The derived `Default` is the empty outcome: zero vectors and no segments.
// NOTE(review): ordering/uniqueness of `segments_modified` is decided by
// whoever fills this in — confirm before relying on it.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct DeleteResult {
    /// Number of vectors reported as deleted.
    pub vectors_deleted: u64,
    /// Ids of the segments reported as modified by the delete.
    pub segments_modified: Vec<u64>,
}
/// Outcome record for compacting one segment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompactionResult {
    /// Id of the segment that was compacted.
    pub old_segment_id: u64,
    /// Id of the replacement segment, if any.
    // NOTE(review): presumably `None` when no replacement segment was written
    // (e.g. every vector was removed) — confirm against the producer.
    pub new_segment_id: Option<u64>,
    /// Number of vectors dropped during compaction.
    pub vectors_removed: u64,
    /// Amount of space freed.
    // NOTE(review): unit not visible here (presumably bytes) — confirm.
    pub space_reclaimed: u64,
}
/// Function used to compare two vectors; see [`score`] for the exact formulas.
///
/// With the default serde derive, the serialized form is the variant name
/// (`"Cosine"`, `"L2"`, `"InnerProduct"`), not the lowercase name returned by
/// [`Metric::as_str`].
// Keep the variant order stable: serde formats that encode enums by variant
// index would change their output if the variants were reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum Metric {
    /// Cosine similarity (dot product over the product of the norms).
    Cosine,
    /// Euclidean distance, returned negated by [`score`] so that a higher
    /// score means closer vectors.
    L2,
    /// Plain dot product.
    InnerProduct,
}
impl Metric {
pub fn as_str(&self) -> &'static str {
match self {
Metric::Cosine => "cosine",
Metric::L2 => "l2",
Metric::InnerProduct => "inner",
}
}
}
impl FromStr for Metric {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
match s.to_ascii_lowercase().as_str() {
"cosine" => Ok(Metric::Cosine),
"l2" => Ok(Metric::L2),
"inner" | "inner_product" | "innerproduct" => Ok(Metric::InnerProduct),
other => Err(Error::UnsupportedMetric {
metric: other.to_string(),
}),
}
}
}
/// Pairs a fixed vector dimension with the [`Metric`] used to score vectors
/// of that dimension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VectorType {
    /// Exact length every vector of this type must have.
    dim: usize,
    /// Metric applied by [`VectorType::score`].
    metric: Metric,
}
impl VectorType {
    /// Creates a descriptor for `dim`-element vectors compared with `metric`.
    ///
    /// `dim` is not validated; a zero dimension is accepted.
    pub fn new(dim: usize, metric: Metric) -> Self {
        Self { metric, dim }
    }

    /// Required vector length.
    pub fn dim(&self) -> usize {
        self.dim
    }

    /// Metric used by [`VectorType::score`].
    pub fn metric(&self) -> Metric {
        self.metric
    }

    /// Checks that `vector` has exactly [`dim`](Self::dim) elements.
    ///
    /// # Errors
    ///
    /// [`Error::DimensionMismatch`] with `expected = dim` and
    /// `actual = vector.len()` when they differ.
    pub fn validate(&self, vector: &[f32]) -> Result<()> {
        let actual = vector.len();
        validate_dimensions(self.dim, actual)
    }

    /// Scores `item` against `query` with this type's metric.
    ///
    /// Both slices are first checked against the declared dimension, `query`
    /// before `item`. A mismatch is therefore reported relative to `dim`,
    /// not relative to the other vector.
    ///
    /// # Errors
    ///
    /// [`Error::DimensionMismatch`] for the first slice whose length is not `dim`.
    pub fn score(&self, query: &[f32], item: &[f32]) -> Result<f32> {
        for &candidate in &[query, item] {
            self.validate(candidate)?;
        }
        score(self.metric, query, item)
    }
}
pub fn validate_dimensions(expected: usize, actual: usize) -> Result<()> {
if expected != actual {
return Err(Error::DimensionMismatch { expected, actual });
}
Ok(())
}
/// Scores `item` against `query` under `metric`.
///
/// * [`Metric::Cosine`]: dot product divided by the product of the Euclidean
///   norms, or `0.0` if either norm is zero.
/// * [`Metric::L2`]: Euclidean distance, negated (`[0, 0]` vs `[3, 4]` gives `-5.0`).
/// * [`Metric::InnerProduct`]: plain dot product.
///
/// Only checks that the two slices have equal length. Unlike
/// [`VectorType::score`], there is no declared dimension to check against.
///
/// # Errors
///
/// [`Error::DimensionMismatch`] with `expected = query.len()` and
/// `actual = item.len()` when the lengths differ.
pub fn score(metric: Metric, query: &[f32], item: &[f32]) -> Result<f32> {
    // Element-wise product sum, accumulated left to right. It is reused for
    // the squared norms (`dot(v, v)`) so every metric shares one code path.
    fn dot(lhs: &[f32], rhs: &[f32]) -> f32 {
        lhs.iter().zip(rhs).map(|(x, y)| x * y).sum::<f32>()
    }

    validate_dimensions(query.len(), item.len())?;

    let value = match metric {
        Metric::Cosine => {
            let query_norm = dot(query, query).sqrt();
            let item_norm = dot(item, item).sqrt();
            // A zero-length vector has no direction; report "unrelated" (0.0)
            // instead of dividing by zero.
            if query_norm != 0.0 && item_norm != 0.0 {
                dot(query, item) / (query_norm * item_norm)
            } else {
                0.0
            }
        }
        Metric::L2 => {
            let squared = query
                .iter()
                .zip(item)
                .map(|(x, y)| {
                    let diff = x - y;
                    diff * diff
                })
                .sum::<f32>();
            // Negated so that, as with the other metrics, higher means closer.
            -squared.sqrt()
        }
        Metric::InnerProduct => dot(query, item),
    };
    Ok(value)
}
#[cfg(all(test, not(target_arch = "wasm32")))]
mod tests {
    use super::*;

    #[test]
    fn rejects_dimension_mismatch() {
        let vt = VectorType::new(3, Metric::Cosine);
        let err = vt.validate(&[1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 3,
                actual: 2
            }
        ));
        let err = score(Metric::L2, &[1.0, 2.0], &[1.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 2,
                actual: 1
            }
        ));
    }

    #[test]
    fn vector_type_score_checks_declared_dim() {
        // The slices agree with each other but not with the declared dim, so
        // the error must come from the declared-dimension check.
        let vt = VectorType::new(3, Metric::InnerProduct);
        let err = vt.score(&[1.0, 2.0], &[1.0, 2.0]).unwrap_err();
        assert!(matches!(
            err,
            Error::DimensionMismatch {
                expected: 3,
                actual: 2
            }
        ));
    }

    #[test]
    fn computes_cosine() {
        let vt = VectorType::new(3, Metric::Cosine);
        let s = vt.score(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]).unwrap();
        assert_eq!(s, 0.0);
        let s = vt.score(&[1.0, 1.0, 0.0], &[1.0, 1.0, 0.0]).unwrap();
        assert!((s - 1.0).abs() < 1e-6);
    }

    #[test]
    fn cosine_with_zero_vector_is_zero() {
        // Covers the zero-norm guard on either side.
        let s = score(Metric::Cosine, &[0.0, 0.0], &[1.0, 2.0]).unwrap();
        assert_eq!(s, 0.0);
        let s = score(Metric::Cosine, &[1.0, 2.0], &[0.0, 0.0]).unwrap();
        assert_eq!(s, 0.0);
    }

    #[test]
    fn computes_l2_as_negative_distance() {
        let s = score(Metric::L2, &[0.0, 0.0], &[3.0, 4.0]).unwrap();
        assert!((s + 5.0).abs() < 1e-6);
    }

    #[test]
    fn computes_inner_product() {
        let s = score(Metric::InnerProduct, &[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).unwrap();
        assert_eq!(s, 32.0);
    }

    #[test]
    fn parses_metric_from_str() {
        assert_eq!(Metric::from_str("cosine").unwrap(), Metric::Cosine);
        assert_eq!(Metric::from_str("L2").unwrap(), Metric::L2);
        assert_eq!(
            Metric::from_str("inner_product").unwrap(),
            Metric::InnerProduct
        );
        let err = Metric::from_str("chebyshev").unwrap_err();
        assert!(matches!(err, Error::UnsupportedMetric { .. }));
    }

    #[test]
    fn metric_name_round_trips() {
        for m in [Metric::Cosine, Metric::L2, Metric::InnerProduct] {
            assert_eq!(m.as_str().parse::<Metric>().unwrap(), m);
        }
    }
}