// Deny-by-warning on undocumented public items and a missing crate-level doc.
#![warn(missing_docs)]
#![warn(rustdoc::missing_crate_level_docs)]

// On wasm builds, use the tiny `wee_alloc` allocator to reduce binary size.
#[cfg(feature = "wasm")]
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;

/// Distance metrics (see [`Distance`]).
pub mod distance;
/// E8-lattice quantization codec (see [`E8Codec`], [`HadamardTransform`]).
pub mod e8;
/// Crate error type and `Result` alias.
pub mod error;
/// Metadata filter expressions applied to search results.
pub mod filter;
/// H4-lattice quantization codec (see [`H4Codec`]).
pub mod h4;
/// HNSW approximate-nearest-neighbor index.
pub mod hnsw;
/// Per-vector metadata payloads.
pub mod metadata;
/// Optional on-disk persistence backends (sled / RocksDB).
#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
pub mod persistence;
/// Quantization scheme configuration.
pub mod quantization;
/// Raw vector storage.
pub mod storage;
/// Python bindings.
#[cfg(feature = "python")]
pub mod python;

// Convenience re-exports of the most commonly used types.
pub use distance::Distance;
pub use e8::{E8Codec, HadamardTransform};
pub use error::{EmbedVecError, Result};
pub use filter::FilterExpr;
pub use h4::{H4Codec, hadamard4_inplace};
pub use hnsw::HnswIndex;
pub use metadata::Metadata;
#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
pub use persistence::{BackendConfig, BackendType, PersistenceBackend};
pub use quantization::Quantization;
pub use storage::VectorStorage;

use ordered_float::OrderedFloat;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A single search result: the matched vector's id, its score, and the
/// metadata payload stored alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hit {
    /// Internal id assigned to the vector at insertion time.
    pub id: usize,
    /// Score reported by the index for this hit (results are sorted ascending
    /// by this value in `search` — presumably distance-like; confirm against
    /// `HnswIndex::search`).
    pub score: f32,
    /// Caller-supplied metadata stored with the vector.
    pub payload: Metadata,
}
impl Hit {
pub fn new(id: usize, score: f32, payload: Metadata) -> Self {
Self { id, score, payload }
}
}
/// Fluent builder for [`EmbedVec`]; obtain one via [`EmbedVec::builder`] or
/// [`EmbedVecBuilder::new`].
#[derive(Debug, Clone)]
pub struct EmbedVecBuilder {
    // Dimensionality every inserted/query vector must have.
    dimension: usize,
    // Distance metric used by the index.
    distance: Distance,
    // HNSW `M` parameter (max links per node).
    m: usize,
    // HNSW construction beam width.
    ef_construction: usize,
    // Quantization scheme for stored vectors.
    quantization: Quantization,
    // Optional persistence backend configuration (None = in-memory only).
    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    persistence_config: Option<persistence::BackendConfig>,
}
impl EmbedVecBuilder {
    /// Starts a builder for a `dimension`-dimensional store with defaults:
    /// cosine distance, `m = 16`, `ef_construction = 200`, no quantization,
    /// and (when a persistence feature is enabled) no backend configured.
    pub fn new(dimension: usize) -> Self {
        Self {
            dimension,
            distance: Distance::Cosine,
            m: 16,
            ef_construction: 200,
            quantization: Quantization::None,
            #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
            persistence_config: None,
        }
    }
    /// Overrides the vector dimensionality.
    pub fn dimension(mut self, dim: usize) -> Self {
        self.dimension = dim;
        self
    }
    /// Sets the distance metric.
    pub fn metric(mut self, distance: Distance) -> Self {
        self.distance = distance;
        self
    }
    /// Sets the HNSW `M` parameter (max links per node).
    pub fn m(mut self, m: usize) -> Self {
        self.m = m;
        self
    }
    /// Sets the HNSW construction beam width (`ef_construction`).
    pub fn ef_construction(mut self, ef: usize) -> Self {
        self.ef_construction = ef;
        self
    }
    /// Sets the quantization scheme for stored vectors.
    pub fn quantization(mut self, quant: Quantization) -> Self {
        self.quantization = quant;
        self
    }
    /// Enables persistence at `path` using the default backend configuration.
    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    pub fn persistence(mut self, path: impl Into<String>) -> Self {
        self.persistence_config = Some(persistence::BackendConfig::new(path));
        self
    }
    /// Enables persistence with an explicit backend configuration.
    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    pub fn persistence_config(mut self, config: persistence::BackendConfig) -> Self {
        self.persistence_config = Some(config);
        self
    }
    /// Consumes the builder and constructs the store (async flavor).
    #[cfg(feature = "async")]
    pub async fn build(self) -> Result<EmbedVec> {
        EmbedVec::from_builder(self).await
    }
    /// Consumes the builder and constructs the store (sync flavor).
    #[cfg(not(feature = "async"))]
    pub fn build(self) -> Result<EmbedVec> {
        EmbedVec::new_internal(
            self.dimension,
            self.distance,
            self.m,
            self.ef_construction,
            self.quantization,
            #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
            self.persistence_config,
        )
    }
}
/// An embedded vector database: HNSW index + vector storage + per-vector
/// metadata, with optional quantization codecs and an optional persistence
/// backend. The shared state is wrapped in `Arc<RwLock<_>>` (parking_lot).
pub struct EmbedVec {
    // Dimensionality every inserted/query vector must match.
    dimension: usize,
    // Distance metric; `Cosine` triggers L2-normalization of inputs.
    distance: Distance,
    /// The HNSW graph index.
    pub index: Arc<RwLock<HnswIndex>>,
    /// Raw (possibly quantized) vector storage.
    pub storage: Arc<RwLock<VectorStorage>>,
    /// Metadata payloads, indexed by vector id.
    pub metadata: Arc<RwLock<Vec<Metadata>>>,
    // Active quantization scheme.
    quantization: Quantization,
    // Codec present only when `quantization` is the matching E8 variant.
    e8_codec: Option<E8Codec>,
    // Codec present only when `quantization` is the matching H4 variant.
    h4_codec: Option<H4Codec>,
    // Optional on-disk backend; None means purely in-memory.
    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    backend: Option<Box<dyn persistence::PersistenceBackend>>,
}
impl EmbedVec {
#[cfg(all(feature = "async", any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
pub async fn new(
dim: usize,
distance: Distance,
m: usize,
ef_construction: usize,
) -> Result<Self> {
Self::new_internal(dim, distance, m, ef_construction, Quantization::None, None)
}
#[cfg(all(feature = "async", not(any(feature = "persistence-sled", feature = "persistence-rocksdb"))))]
pub async fn new(
dim: usize,
distance: Distance,
m: usize,
ef_construction: usize,
) -> Result<Self> {
Self::new_internal(dim, distance, m, ef_construction, Quantization::None)
}
#[cfg(all(feature = "async", any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
pub async fn with_persistence(
path: impl AsRef<std::path::Path>,
dim: usize,
distance: Distance,
m: usize,
ef_construction: usize,
) -> Result<Self> {
let path_str = path.as_ref().to_string_lossy().to_string();
let config = persistence::BackendConfig::new(path_str);
Self::new_internal(
dim,
distance,
m,
ef_construction,
Quantization::None,
Some(config),
)
}
#[cfg(all(feature = "async", any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
pub async fn with_backend(
config: persistence::BackendConfig,
dim: usize,
distance: Distance,
m: usize,
ef_construction: usize,
) -> Result<Self> {
Self::new_internal(
dim,
distance,
m,
ef_construction,
Quantization::None,
Some(config),
)
}
#[cfg(feature = "async")]
async fn from_builder(builder: EmbedVecBuilder) -> Result<Self> {
Self::new_internal(
builder.dimension,
builder.distance,
builder.m,
builder.ef_construction,
builder.quantization,
#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
builder.persistence_config,
)
}
#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
pub fn new_internal(
dim: usize,
distance: Distance,
m: usize,
ef_construction: usize,
quantization: Quantization,
persistence_config: Option<persistence::BackendConfig>,
) -> Result<Self> {
if dim == 0 {
return Err(EmbedVecError::InvalidDimension(dim));
}
let index = HnswIndex::new(m, ef_construction, distance);
let storage = VectorStorage::new(dim, quantization.clone());
let e8_codec = match &quantization {
Quantization::E8 { bits_per_block, use_hadamard, random_seed } =>
Some(E8Codec::new(dim, *bits_per_block, *use_hadamard, *random_seed)),
_ => None,
};
let h4_codec = match &quantization {
Quantization::H4 { use_hadamard, random_seed } =>
Some(H4Codec::new(dim, *use_hadamard, *random_seed)),
_ => None,
};
let backend = if let Some(config) = persistence_config {
Some(persistence::create_backend(&config)?)
} else {
None
};
Ok(Self {
dimension: dim,
distance,
index: Arc::new(RwLock::new(index)),
storage: Arc::new(RwLock::new(storage)),
metadata: Arc::new(RwLock::new(Vec::new())),
quantization,
e8_codec,
h4_codec,
backend,
})
}
#[cfg(not(any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
pub fn new_internal(
dim: usize,
distance: Distance,
m: usize,
ef_construction: usize,
quantization: Quantization,
) -> Result<Self> {
if dim == 0 {
return Err(EmbedVecError::InvalidDimension(dim));
}
let index = HnswIndex::new(m, ef_construction, distance);
let storage = VectorStorage::new(dim, quantization.clone());
let e8_codec = match &quantization {
Quantization::E8 { bits_per_block, use_hadamard, random_seed } =>
Some(E8Codec::new(dim, *bits_per_block, *use_hadamard, *random_seed)),
_ => None,
};
let h4_codec = match &quantization {
Quantization::H4 { use_hadamard, random_seed } =>
Some(H4Codec::new(dim, *use_hadamard, *random_seed)),
_ => None,
};
Ok(Self {
dimension: dim,
distance,
index: Arc::new(RwLock::new(index)),
storage: Arc::new(RwLock::new(storage)),
metadata: Arc::new(RwLock::new(Vec::new())),
quantization,
e8_codec,
h4_codec,
})
}
pub fn builder() -> EmbedVecBuilder {
EmbedVecBuilder::new(768) }
#[cfg(feature = "async")]
pub async fn add(&mut self, vector: &[f32], payload: impl Into<Metadata>) -> Result<usize> {
self.add_internal(vector, payload.into())
}
#[cfg(feature = "async")]
pub async fn add_many(
&mut self,
vectors: &[Vec<f32>],
payloads: Vec<impl Into<Metadata>>,
) -> Result<()> {
if vectors.len() != payloads.len() {
return Err(EmbedVecError::MismatchedLengths {
vectors: vectors.len(),
payloads: payloads.len(),
});
}
for (vector, payload) in vectors.iter().zip(payloads.into_iter()) {
self.add_internal(vector, payload.into())?;
}
Ok(())
}
pub fn add_internal(&mut self, vector: &[f32], payload: Metadata) -> Result<usize> {
if vector.len() != self.dimension {
return Err(EmbedVecError::DimensionMismatch {
expected: self.dimension,
got: vector.len(),
});
}
let processed_vector = if self.distance == Distance::Cosine {
normalize_vector(vector)
} else {
vector.to_vec()
};
let id = {
let mut storage = self.storage.write();
storage.add(&processed_vector, self.e8_codec.as_ref(), self.h4_codec.as_ref())?
};
{
let mut meta = self.metadata.write();
if id >= meta.len() {
meta.resize(id + 1, Metadata::default());
}
meta[id] = payload;
}
{
let mut index = self.index.write();
let storage = self.storage.read();
index.insert(id, &processed_vector, &storage, self.e8_codec.as_ref())?;
}
Ok(id)
}
#[cfg(feature = "async")]
pub async fn search(
&self,
query: &[f32],
k: usize,
ef_search: usize,
filter: Option<FilterExpr>,
) -> Result<Vec<Hit>> {
self.search_internal(query, k, ef_search, filter)
}
pub fn search_internal(
&self,
query: &[f32],
k: usize,
ef_search: usize,
filter: Option<FilterExpr>,
) -> Result<Vec<Hit>> {
if query.len() != self.dimension {
return Err(EmbedVecError::DimensionMismatch {
expected: self.dimension,
got: query.len(),
});
}
let processed_query = if self.distance == Distance::Cosine {
normalize_vector(query)
} else {
query.to_vec()
};
let candidates = {
let index = self.index.read();
let storage = self.storage.read();
index.search(
&processed_query,
k,
ef_search,
&storage,
self.e8_codec.as_ref(),
)?
};
let metadata = self.metadata.read();
let mut results: Vec<Hit> = candidates
.into_iter()
.filter_map(|(id, score)| {
let payload = metadata.get(id)?.clone();
if let Some(ref f) = filter {
if !f.matches(&payload) {
return None;
}
}
Some(Hit::new(id, score, payload))
})
.take(k)
.collect();
results.sort_by_key(|h| OrderedFloat(h.score));
Ok(results)
}
#[cfg(feature = "async")]
pub async fn len(&self) -> usize {
self.storage.read().len()
}
#[cfg(feature = "async")]
pub async fn is_empty(&self) -> bool {
self.storage.read().is_empty()
}
#[cfg(feature = "async")]
pub async fn clear(&mut self) -> Result<()> {
{
let mut storage = self.storage.write();
storage.clear();
}
{
let mut metadata = self.metadata.write();
metadata.clear();
}
{
let mut index = self.index.write();
index.clear();
}
Ok(())
}
#[cfg(all(feature = "async", feature = "persistence"))]
pub async fn flush(&mut self) -> Result<()> {
if let Some(ref db) = self.db {
db.flush()
.map_err(|e| EmbedVecError::PersistenceError(e.to_string()))?;
}
Ok(())
}
pub fn quantization(&self) -> &Quantization {
&self.quantization
}
#[cfg(feature = "async")]
pub async fn set_quantization(&mut self, quant: Quantization) -> Result<()> {
self.e8_codec = match &quant {
Quantization::E8 { bits_per_block, use_hadamard, random_seed } =>
Some(E8Codec::new(self.dimension, *bits_per_block, *use_hadamard, *random_seed)),
_ => None,
};
self.h4_codec = match &quant {
Quantization::H4 { use_hadamard, random_seed } =>
Some(H4Codec::new(self.dimension, *use_hadamard, *random_seed)),
_ => None,
};
self.quantization = quant.clone();
let mut storage = self.storage.write();
storage.set_quantization(quant, self.e8_codec.as_ref(), self.h4_codec.as_ref())?;
Ok(())
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn distance(&self) -> Distance {
self.distance
}
}
/// Returns an L2-normalized copy of `v`.
///
/// Near-zero vectors (norm <= 1e-10) are returned unchanged to avoid
/// division blow-up.
fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let squared_sum: f32 = v.iter().map(|x| x * x).sum();
    let norm = squared_sum.sqrt();
    if norm <= 1e-10 {
        return v.to_vec();
    }
    v.iter().map(|&x| x / norm).collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserting a single vector and querying with the same vector should
    /// return exactly that vector as the top hit.
    #[tokio::test]
    async fn test_basic_operations() {
        let mut store = EmbedVec::new(4, Distance::Cosine, 16, 100).await.unwrap();
        let inserted_id = store
            .add(&[1.0, 0.0, 0.0, 0.0], serde_json::json!({"test": "value"}))
            .await
            .unwrap();
        assert_eq!(inserted_id, 0);

        let hits = store.search(&[1.0, 0.0, 0.0, 0.0], 1, 50, None).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 0);
    }

    /// Adding a vector of the wrong dimensionality must be rejected.
    #[tokio::test]
    async fn test_dimension_mismatch() {
        let mut store = EmbedVec::new(4, Distance::Cosine, 16, 100).await.unwrap();
        let outcome = store
            .add(&[1.0, 0.0, 0.0], serde_json::json!({}))
            .await;
        assert!(outcome.is_err());
    }

    /// Insert + search round-trip with H4 quantization enabled.
    #[tokio::test]
    async fn test_h4_quantization_end_to_end() {
        let mut store = EmbedVec::builder()
            .dimension(8)
            .metric(Distance::Cosine)
            .quantization(Quantization::h4_default())
            .build()
            .await
            .unwrap();
        let inserted_id = store
            .add(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], serde_json::json!({"lattice": "h4"}))
            .await
            .unwrap();
        assert_eq!(inserted_id, 0);

        let hits = store
            .search(&[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 1, 50, None)
            .await
            .unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 0);
    }
}