//! Embedded vector database built around an HNSW nearest-neighbour index.
//!
//! [`EmbedVec`] stores fixed-dimension `f32` vectors, each with a [`Metadata`]
//! payload. It answers top-`k` queries that can optionally be restricted by a
//! [`FilterExpr`]. A database can be configured through [`EmbedVecBuilder`]. Stored
//! vectors may be compressed according to a [`Quantization`] setting; the E8 variant
//! uses [`E8Codec`].
//!
//! Cargo features referenced in this file:
//! - `async`: enables the async public API (`add`, `search`, `len`, `clear`, ...).
//! - `persistence-sled` / `persistence-rocksdb`: enable the `persistence` module and
//!   its backends.
//! - `python`: compiles the `python` module.
//! - `wasm`: installs `wee_alloc` as the global allocator.
#![warn(missing_docs)]
#![warn(rustdoc::missing_crate_level_docs)]

#[cfg(feature = "wasm")]
// NOTE(review): `wee_alloc` is flagged as unmaintained (RUSTSEC-2022-0054); consider
// relying on the default allocator for wasm targets instead.
#[global_allocator]
static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
130
131pub mod distance;
132pub mod e8;
133pub mod error;
134pub mod filter;
135pub mod hnsw;
136pub mod metadata;
137#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
138pub mod persistence;
139pub mod quantization;
140pub mod storage;
141
142#[cfg(feature = "python")]
143pub mod python;
144
145pub use distance::Distance;
147pub use e8::{E8Codec, HadamardTransform};
148pub use error::{EmbedVecError, Result};
149pub use filter::FilterExpr;
150pub use hnsw::HnswIndex;
151pub use metadata::Metadata;
152#[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
153pub use persistence::{BackendConfig, BackendType, PersistenceBackend};
154pub use quantization::Quantization;
155pub use storage::VectorStorage;
156
157use ordered_float::OrderedFloat;
158use parking_lot::RwLock;
159use serde::{Deserialize, Serialize};
160use std::sync::Arc;
161
/// A single result of a nearest-neighbour search on an `EmbedVec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Hit {
    /// Id of the matching vector, as returned when the vector was added.
    pub id: usize,
    /// Score reported by the HNSW index for this match.
    ///
    /// Search results are sorted in ascending order of this value.
    pub score: f32,
    /// Metadata payload stored alongside the vector.
    pub payload: Metadata,
}
172
173impl Hit {
174 pub fn new(id: usize, score: f32, payload: Metadata) -> Self {
176 Self { id, score, payload }
177 }
178}
179
/// Fluent configuration for an [`EmbedVec`].
///
/// Create one with [`EmbedVecBuilder::new`] or [`EmbedVec::builder`], adjust it with the
/// setter methods, then call `build`.
#[derive(Debug, Clone)]
pub struct EmbedVecBuilder {
    /// Number of components every vector must have.
    dimension: usize,
    /// Distance metric used by the index.
    distance: Distance,
    /// HNSW `m` parameter, forwarded to `HnswIndex::new`.
    m: usize,
    /// HNSW `ef_construction` parameter, forwarded to `HnswIndex::new`.
    ef_construction: usize,
    /// Quantization applied to stored vectors.
    quantization: Quantization,
    /// Optional persistence backend configuration (`None` keeps the database in memory).
    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    persistence_config: Option<persistence::BackendConfig>,
}
191
192impl EmbedVecBuilder {
193 pub fn new(dimension: usize) -> Self {
195 Self {
196 dimension,
197 distance: Distance::Cosine,
198 m: 16,
199 ef_construction: 200,
200 quantization: Quantization::None,
201 #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
202 persistence_config: None,
203 }
204 }
205
206 pub fn dimension(mut self, dim: usize) -> Self {
208 self.dimension = dim;
209 self
210 }
211
212 pub fn metric(mut self, distance: Distance) -> Self {
214 self.distance = distance;
215 self
216 }
217
218 pub fn m(mut self, m: usize) -> Self {
220 self.m = m;
221 self
222 }
223
224 pub fn ef_construction(mut self, ef: usize) -> Self {
226 self.ef_construction = ef;
227 self
228 }
229
230 pub fn quantization(mut self, quant: Quantization) -> Self {
232 self.quantization = quant;
233 self
234 }
235
236 #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
238 pub fn persistence(mut self, path: impl Into<String>) -> Self {
239 self.persistence_config = Some(persistence::BackendConfig::new(path));
240 self
241 }
242
243 #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
245 pub fn persistence_config(mut self, config: persistence::BackendConfig) -> Self {
246 self.persistence_config = Some(config);
247 self
248 }
249
250 #[cfg(feature = "async")]
252 pub async fn build(self) -> Result<EmbedVec> {
253 EmbedVec::from_builder(self).await
254 }
255
256 #[cfg(not(feature = "async"))]
258 pub fn build(self) -> Result<EmbedVec> {
259 EmbedVec::new_internal(
260 self.dimension,
261 self.distance,
262 self.m,
263 self.ef_construction,
264 self.quantization,
265 #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
266 self.persistence_config,
267 )
268 }
269}
270
/// An embedded vector database combining an HNSW index, vector storage, and
/// per-vector metadata payloads.
pub struct EmbedVec {
    /// Number of components every inserted vector and query must have.
    dimension: usize,
    /// Distance metric; with `Distance::Cosine`, vectors and queries are normalized first.
    distance: Distance,
    /// HNSW graph over the stored vectors.
    pub index: Arc<RwLock<HnswIndex>>,
    /// Stored vectors, encoded according to `quantization`.
    pub storage: Arc<RwLock<VectorStorage>>,
    /// Metadata payloads, indexed by vector id.
    pub metadata: Arc<RwLock<Vec<Metadata>>>,
    /// Active quantization setting.
    quantization: Quantization,
    /// E8 codec, present only when `quantization` is `Quantization::E8`.
    e8_codec: Option<E8Codec>,
    /// Persistence backend, present only when a backend config was supplied.
    #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
    backend: Option<Box<dyn persistence::PersistenceBackend>>,
}
294
295impl EmbedVec {
296 #[cfg(all(feature = "async", any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
314 pub async fn new(
315 dim: usize,
316 distance: Distance,
317 m: usize,
318 ef_construction: usize,
319 ) -> Result<Self> {
320 Self::new_internal(dim, distance, m, ef_construction, Quantization::None, None)
321 }
322
323 #[cfg(all(feature = "async", not(any(feature = "persistence-sled", feature = "persistence-rocksdb"))))]
324 pub async fn new(
325 dim: usize,
326 distance: Distance,
327 m: usize,
328 ef_construction: usize,
329 ) -> Result<Self> {
330 Self::new_internal(dim, distance, m, ef_construction, Quantization::None)
331 }
332
333 #[cfg(all(feature = "async", any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
335 pub async fn with_persistence(
336 path: impl AsRef<std::path::Path>,
337 dim: usize,
338 distance: Distance,
339 m: usize,
340 ef_construction: usize,
341 ) -> Result<Self> {
342 let path_str = path.as_ref().to_string_lossy().to_string();
343 let config = persistence::BackendConfig::new(path_str);
344 Self::new_internal(
345 dim,
346 distance,
347 m,
348 ef_construction,
349 Quantization::None,
350 Some(config),
351 )
352 }
353
354 #[cfg(all(feature = "async", any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
356 pub async fn with_backend(
357 config: persistence::BackendConfig,
358 dim: usize,
359 distance: Distance,
360 m: usize,
361 ef_construction: usize,
362 ) -> Result<Self> {
363 Self::new_internal(
364 dim,
365 distance,
366 m,
367 ef_construction,
368 Quantization::None,
369 Some(config),
370 )
371 }
372
373 #[cfg(feature = "async")]
375 async fn from_builder(builder: EmbedVecBuilder) -> Result<Self> {
376 Self::new_internal(
377 builder.dimension,
378 builder.distance,
379 builder.m,
380 builder.ef_construction,
381 builder.quantization,
382 #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
383 builder.persistence_config,
384 )
385 }
386
387 #[cfg(any(feature = "persistence-sled", feature = "persistence-rocksdb"))]
389 pub fn new_internal(
390 dim: usize,
391 distance: Distance,
392 m: usize,
393 ef_construction: usize,
394 quantization: Quantization,
395 persistence_config: Option<persistence::BackendConfig>,
396 ) -> Result<Self> {
397 if dim == 0 {
398 return Err(EmbedVecError::InvalidDimension(dim));
399 }
400
401 let index = HnswIndex::new(m, ef_construction, distance);
402 let storage = VectorStorage::new(dim, quantization.clone());
403
404 let e8_codec = match &quantization {
405 Quantization::None => None,
406 Quantization::E8 {
407 bits_per_block,
408 use_hadamard,
409 random_seed,
410 } => Some(E8Codec::new(dim, *bits_per_block, *use_hadamard, *random_seed)),
411 };
412
413 let backend = if let Some(config) = persistence_config {
414 Some(persistence::create_backend(&config)?)
415 } else {
416 None
417 };
418
419 Ok(Self {
420 dimension: dim,
421 distance,
422 index: Arc::new(RwLock::new(index)),
423 storage: Arc::new(RwLock::new(storage)),
424 metadata: Arc::new(RwLock::new(Vec::new())),
425 quantization,
426 e8_codec,
427 backend,
428 })
429 }
430
431 #[cfg(not(any(feature = "persistence-sled", feature = "persistence-rocksdb")))]
433 pub fn new_internal(
434 dim: usize,
435 distance: Distance,
436 m: usize,
437 ef_construction: usize,
438 quantization: Quantization,
439 ) -> Result<Self> {
440 if dim == 0 {
441 return Err(EmbedVecError::InvalidDimension(dim));
442 }
443
444 let index = HnswIndex::new(m, ef_construction, distance);
445 let storage = VectorStorage::new(dim, quantization.clone());
446
447 let e8_codec = match &quantization {
448 Quantization::None => None,
449 Quantization::E8 {
450 bits_per_block,
451 use_hadamard,
452 random_seed,
453 } => Some(E8Codec::new(dim, *bits_per_block, *use_hadamard, *random_seed)),
454 };
455
456 Ok(Self {
457 dimension: dim,
458 distance,
459 index: Arc::new(RwLock::new(index)),
460 storage: Arc::new(RwLock::new(storage)),
461 metadata: Arc::new(RwLock::new(Vec::new())),
462 quantization,
463 e8_codec,
464 })
465 }
466
467 pub fn builder() -> EmbedVecBuilder {
469 EmbedVecBuilder::new(768) }
471
472 #[cfg(feature = "async")]
481 pub async fn add(&mut self, vector: &[f32], payload: impl Into<Metadata>) -> Result<usize> {
482 self.add_internal(vector, payload.into())
483 }
484
485 #[cfg(feature = "async")]
491 pub async fn add_many(
492 &mut self,
493 vectors: &[Vec<f32>],
494 payloads: Vec<impl Into<Metadata>>,
495 ) -> Result<()> {
496 if vectors.len() != payloads.len() {
497 return Err(EmbedVecError::MismatchedLengths {
498 vectors: vectors.len(),
499 payloads: payloads.len(),
500 });
501 }
502
503 for (vector, payload) in vectors.iter().zip(payloads.into_iter()) {
504 self.add_internal(vector, payload.into())?;
505 }
506
507 Ok(())
508 }
509
510 pub fn add_internal(&mut self, vector: &[f32], payload: Metadata) -> Result<usize> {
512 if vector.len() != self.dimension {
513 return Err(EmbedVecError::DimensionMismatch {
514 expected: self.dimension,
515 got: vector.len(),
516 });
517 }
518
519 let processed_vector = if self.distance == Distance::Cosine {
521 normalize_vector(vector)
522 } else {
523 vector.to_vec()
524 };
525
526 let id = {
528 let mut storage = self.storage.write();
529 storage.add(&processed_vector, self.e8_codec.as_ref())?
530 };
531
532 {
534 let mut meta = self.metadata.write();
535 if id >= meta.len() {
536 meta.resize(id + 1, Metadata::default());
537 }
538 meta[id] = payload;
539 }
540
541 {
543 let mut index = self.index.write();
544 let storage = self.storage.read();
545 index.insert(id, &processed_vector, &storage, self.e8_codec.as_ref())?;
546 }
547
548 Ok(id)
549 }
550
551 #[cfg(feature = "async")]
562 pub async fn search(
563 &self,
564 query: &[f32],
565 k: usize,
566 ef_search: usize,
567 filter: Option<FilterExpr>,
568 ) -> Result<Vec<Hit>> {
569 self.search_internal(query, k, ef_search, filter)
570 }
571
572 pub fn search_internal(
574 &self,
575 query: &[f32],
576 k: usize,
577 ef_search: usize,
578 filter: Option<FilterExpr>,
579 ) -> Result<Vec<Hit>> {
580 if query.len() != self.dimension {
581 return Err(EmbedVecError::DimensionMismatch {
582 expected: self.dimension,
583 got: query.len(),
584 });
585 }
586
587 let processed_query = if self.distance == Distance::Cosine {
589 normalize_vector(query)
590 } else {
591 query.to_vec()
592 };
593
594 let candidates = {
596 let index = self.index.read();
597 let storage = self.storage.read();
598 index.search(
599 &processed_query,
600 k,
601 ef_search,
602 &storage,
603 self.e8_codec.as_ref(),
604 )?
605 };
606
607 let metadata = self.metadata.read();
609 let mut results: Vec<Hit> = candidates
610 .into_iter()
611 .filter_map(|(id, score)| {
612 let payload = metadata.get(id)?.clone();
613
614 if let Some(ref f) = filter {
616 if !f.matches(&payload) {
617 return None;
618 }
619 }
620
621 Some(Hit::new(id, score, payload))
622 })
623 .take(k)
624 .collect();
625
626 results.sort_by_key(|h| OrderedFloat(h.score));
628
629 Ok(results)
630 }
631
632 #[cfg(feature = "async")]
634 pub async fn len(&self) -> usize {
635 self.storage.read().len()
636 }
637
638 #[cfg(feature = "async")]
640 pub async fn is_empty(&self) -> bool {
641 self.storage.read().is_empty()
642 }
643
644 #[cfg(feature = "async")]
646 pub async fn clear(&mut self) -> Result<()> {
647 {
648 let mut storage = self.storage.write();
649 storage.clear();
650 }
651 {
652 let mut metadata = self.metadata.write();
653 metadata.clear();
654 }
655 {
656 let mut index = self.index.write();
657 index.clear();
658 }
659 Ok(())
660 }
661
662 #[cfg(all(feature = "async", feature = "persistence"))]
664 pub async fn flush(&mut self) -> Result<()> {
665 if let Some(ref db) = self.db {
666 db.flush()
667 .map_err(|e| EmbedVecError::PersistenceError(e.to_string()))?;
668 }
669 Ok(())
670 }
671
672 pub fn quantization(&self) -> &Quantization {
674 &self.quantization
675 }
676
677 #[cfg(feature = "async")]
679 pub async fn set_quantization(&mut self, quant: Quantization) -> Result<()> {
680 self.quantization = quant.clone();
681 self.e8_codec = match &quant {
682 Quantization::None => None,
683 Quantization::E8 {
684 bits_per_block,
685 use_hadamard,
686 random_seed,
687 } => Some(E8Codec::new(
688 self.dimension,
689 *bits_per_block,
690 *use_hadamard,
691 *random_seed,
692 )),
693 };
694
695 let mut storage = self.storage.write();
697 storage.set_quantization(quant, self.e8_codec.as_ref())?;
698
699 Ok(())
700 }
701
702 pub fn dimension(&self) -> usize {
704 self.dimension
705 }
706
707 pub fn distance(&self) -> Distance {
709 self.distance
710 }
711}
712
/// Returns a copy of `v` scaled to unit Euclidean length.
///
/// If the length is not greater than `1e-10` (including a zero, empty, or NaN-length
/// vector), the copy is returned unscaled to avoid dividing by (near) zero.
fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let length = v
        .iter()
        .map(|&component| component * component)
        .sum::<f32>()
        .sqrt();

    let mut unit = v.to_vec();
    if length > 1e-10 {
        for component in &mut unit {
            *component /= length;
        }
    }
    unit
}
722
// These tests drive the async public API (`new`, `add`, `search`, ...), which only
// exists when the `async` feature is enabled; without this gate `cargo test` fails to
// compile in non-async builds.
#[cfg(all(test, feature = "async"))]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_basic_operations() {
        let mut db = EmbedVec::new(4, Distance::Cosine, 16, 100).await.unwrap();

        let id = db
            .add(&[1.0, 0.0, 0.0, 0.0], serde_json::json!({"test": "value"}))
            .await
            .unwrap();
        assert_eq!(id, 0);

        let results = db.search(&[1.0, 0.0, 0.0, 0.0], 1, 50, None).await.unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 0);
    }

    #[tokio::test]
    async fn test_dimension_mismatch() {
        let mut db = EmbedVec::new(4, Distance::Cosine, 16, 100).await.unwrap();

        let result = db.add(&[1.0, 0.0, 0.0], serde_json::json!({})).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_zero_dimension_rejected() {
        let result = EmbedVec::new(0, Distance::Cosine, 16, 100).await;
        assert!(matches!(result, Err(EmbedVecError::InvalidDimension(0))));
    }

    #[tokio::test]
    async fn test_add_many_rejects_mismatched_lengths() {
        let mut db = EmbedVec::new(2, Distance::Cosine, 16, 100).await.unwrap();
        let vectors = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let payloads = vec![serde_json::json!({})];

        let result = db.add_many(&vectors, payloads).await;
        assert!(matches!(
            result,
            Err(EmbedVecError::MismatchedLengths {
                vectors: 2,
                payloads: 1
            })
        ));
        assert!(db.is_empty().await);
    }

    #[tokio::test]
    async fn test_clear_empties_database() {
        let mut db = EmbedVec::new(2, Distance::Cosine, 16, 100).await.unwrap();
        db.add(&[1.0, 0.0], serde_json::json!({})).await.unwrap();
        assert_eq!(db.len().await, 1);

        db.clear().await.unwrap();
        assert!(db.is_empty().await);
        assert_eq!(db.len().await, 0);
    }
}