pub mod config;
pub mod embedding_writer;
pub mod memory;
pub mod request;
pub mod response;
use std::sync::Arc;
use tokio::sync::Mutex;
use crate::data::{DataValue, Document};
use crate::embedding::embedder::{EmbedInput, Embedder};
use crate::embedding::per_field::PerFieldEmbedder;
use crate::error::{LaurusError, Result};
use crate::storage::Storage;
use crate::vector::core::vector::Vector;
use crate::vector::index::VectorIndex;
use crate::vector::index::config::VectorIndexTypeConfig;
use crate::vector::index::factory::VectorIndexFactory;
use crate::vector::search::searcher::{VectorIndexQuery, VectorIndexSearcher};
use crate::vector::writer::VectorIndexWriter;
use self::config::VectorIndexConfig;
use self::request::{VectorScoreMode, VectorSearchRequest};
use self::response::{VectorHit, VectorSearchResults, VectorStats};
/// High-level handle over a single vector index, adding lazily created, cached writer and
/// searcher instances on top of it.
///
/// Writes (`upsert_document_by_internal_id`, `delete_document_by_internal_id`) are staged in
/// the cached writer until `commit`; searches reuse a cached searcher that is invalidated by
/// `commit`, `optimize`, `refresh`, `close`, `add_field` and `delete_field`.
pub struct VectorStore {
    /// The underlying index; every operation of the store is delegated to it.
    index: Box<dyn VectorIndex>,
    /// Writer created on first write and reused until `commit` takes it (or it is cleared).
    /// An async (tokio) mutex, because it is acquired from `async fn`s.
    writer_cache: Mutex<Option<Box<dyn VectorIndexWriter>>>,
    /// Searcher created on first search and shared by concurrent readers until invalidated.
    searcher_cache: parking_lot::RwLock<Option<Box<dyn VectorIndexSearcher>>>,
}
impl std::fmt::Debug for VectorStore {
    /// Renders the store as `VectorStore { index: .. }`.
    ///
    /// Only the underlying index is shown; the writer and searcher caches are deliberately
    /// left out of the debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("VectorStore");
        builder.field("index", &self.index);
        builder.finish()
    }
}
impl VectorStore {
pub fn new(storage: Arc<dyn Storage>, config: VectorIndexConfig) -> Result<Self> {
let index_type_config = Self::extract_index_type_config(&config);
Self::with_index_type_config(storage, index_type_config)
}
pub fn with_index_type_config(
storage: Arc<dyn Storage>,
config: VectorIndexTypeConfig,
) -> Result<Self> {
let index = VectorIndexFactory::open_or_create(storage, "vector_index", config)?;
Ok(Self {
index,
writer_cache: Mutex::new(None),
searcher_cache: parking_lot::RwLock::new(None),
})
}
fn extract_index_type_config(config: &VectorIndexConfig) -> VectorIndexTypeConfig {
use crate::vector::core::field::FieldOption;
use crate::vector::index::config::{FlatIndexConfig, HnswIndexConfig, IvfIndexConfig};
for field_config in config.fields.values() {
if let Some(ref vector_opt) = field_config.vector {
return match vector_opt {
FieldOption::Flat(opt) => VectorIndexTypeConfig::Flat(FlatIndexConfig {
dimension: opt.dimension,
distance_metric: opt.distance,
embedder: config.embedder.clone(),
..Default::default()
}),
FieldOption::Hnsw(opt) => VectorIndexTypeConfig::HNSW(HnswIndexConfig {
dimension: opt.dimension,
distance_metric: opt.distance,
m: opt.m,
ef_construction: opt.ef_construction,
embedder: config.embedder.clone(),
..Default::default()
}),
FieldOption::Ivf(opt) => VectorIndexTypeConfig::IVF(IvfIndexConfig {
dimension: opt.dimension,
distance_metric: opt.distance,
n_clusters: opt.n_clusters,
n_probe: opt.n_probe,
embedder: config.embedder.clone(),
..Default::default()
}),
};
}
}
VectorIndexTypeConfig::HNSW(HnswIndexConfig {
embedder: config.embedder.clone(),
..Default::default()
})
}
pub async fn upsert_document_by_internal_id(&self, doc_id: u64, doc: Document) -> Result<()> {
let embedder = self.index.embedder();
let mut embedded_vectors: Vec<(u64, String, Vector)> = Vec::new();
for (field_name, value) in &doc.fields {
let vector = match value {
DataValue::Vector(v) => Vector::new(v.clone()),
DataValue::Text(_) | DataValue::Bytes(_, _) => {
Self::embed_value(&*embedder, field_name, value).await?
}
_ => continue,
};
embedded_vectors.push((doc_id, field_name.clone(), vector));
}
let mut guard = self.writer_cache.lock().await;
if guard.is_none() {
*guard = Some(self.index.writer()?);
}
let writer = guard.as_mut().unwrap();
writer.delete_document(doc_id)?;
writer.add_vectors(embedded_vectors)?;
Ok(())
}
async fn embed_value(
embedder: &dyn Embedder,
field_name: &str,
value: &DataValue,
) -> Result<Vector> {
match value {
DataValue::Text(_) if !embedder.supports_text() => {
return Err(LaurusError::invalid_argument(format!(
"Embedder '{}' does not support text input",
embedder.name()
)));
}
DataValue::Bytes(_, mime) if !embedder.supports_image() => {
if mime.as_ref().is_some_and(|m| m.starts_with("image/")) {
return Err(LaurusError::invalid_argument(format!(
"Embedder '{}' does not support image input",
embedder.name()
)));
}
}
_ => {}
}
let (text_owned, bytes_owned, mime_owned) = match value {
DataValue::Text(t) => (Some(t.clone()), None, None),
DataValue::Bytes(b, m) => (None, Some(b.clone()), m.clone()),
_ => {
return Err(LaurusError::invalid_argument(
"Unsupported data type for embedding",
));
}
};
let input = if let Some(ref text) = text_owned {
EmbedInput::Text(text)
} else if let Some(ref bytes) = bytes_owned {
EmbedInput::Bytes(bytes, mime_owned.as_deref())
} else {
return Err(LaurusError::internal("Unreachable state in embed_value"));
};
if let Some(per_field) = embedder.as_any().downcast_ref::<PerFieldEmbedder>() {
per_field.embed_field(field_name, &input).await
} else {
embedder.embed(&input).await
}
}
pub async fn delete_document_by_internal_id(&self, doc_id: u64) -> Result<()> {
let mut guard = self.writer_cache.lock().await;
if guard.is_none() {
*guard = Some(self.index.writer()?);
}
let writer = guard.as_mut().unwrap();
writer.delete_document(doc_id)?;
Ok(())
}
pub async fn commit(&self) -> Result<()> {
if let Some(mut writer) = self.writer_cache.lock().await.take() {
writer.commit()?;
}
self.index.storage().sync()?;
self.index.refresh()?;
*self.searcher_cache.write() = None;
Ok(())
}
pub fn optimize(&self) -> Result<()> {
self.index.optimize()?;
*self.searcher_cache.write() = None;
Ok(())
}
pub fn refresh(&self) -> Result<()> {
*self.searcher_cache.write() = None;
Ok(())
}
fn acquire_searcher_guard(
&self,
) -> Result<parking_lot::RwLockReadGuard<'_, Option<Box<dyn VectorIndexSearcher>>>> {
{
let guard = self.searcher_cache.read();
if guard.is_some() {
return Ok(guard);
}
}
let mut guard = self.searcher_cache.write();
if guard.is_none() {
*guard = Some(self.index.searcher()?);
}
Ok(parking_lot::RwLockWriteGuard::downgrade(guard))
}
pub fn search_index(
&self,
request: &VectorIndexQuery,
) -> Result<crate::vector::search::searcher::VectorIndexQueryResults> {
let guard = self.acquire_searcher_guard()?;
guard.as_ref().unwrap().search(request)
}
pub fn search(&self, request: VectorSearchRequest) -> Result<VectorSearchResults> {
use crate::vector::search::searcher::VectorSearchQuery;
let query_vectors = match &request.query {
VectorSearchQuery::Vectors(vecs) => vecs,
VectorSearchQuery::Payloads(_) => {
return Err(crate::error::LaurusError::invalid_argument(
"VectorStore::search requires pre-embedded vectors; \
Payloads must be embedded before calling this method",
));
}
};
if query_vectors.is_empty() {
return Ok(VectorSearchResults::default());
}
let searcher_guard = self.acquire_searcher_guard()?;
let searcher = searcher_guard.as_ref().unwrap();
let mut all_hits: std::collections::HashMap<u64, f32> = std::collections::HashMap::new();
for qv in query_vectors {
let index_request = VectorIndexQuery::new(qv.vector.clone())
.top_k(request.params.limit.saturating_mul(2));
let results = searcher.search(&index_request)?;
for result in results.results {
if request
.params
.allowed_ids
.as_ref()
.is_some_and(|allowed| !allowed.contains(&result.doc_id))
{
continue;
}
if result.similarity < request.params.min_score {
continue;
}
let weighted_score = result.similarity * qv.weight;
let entry = all_hits.entry(result.doc_id).or_insert(0.0);
match request.params.score_mode {
VectorScoreMode::WeightedSum | VectorScoreMode::LateInteraction => {
*entry += weighted_score;
}
VectorScoreMode::MaxSim => {
if weighted_score > *entry {
*entry = weighted_score;
}
}
}
}
}
let mut hits: Vec<VectorHit> = all_hits
.into_iter()
.map(|(doc_id, score)| VectorHit {
doc_id,
score,
field_hits: vec![],
})
.collect();
hits.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
if hits.len() > request.params.limit {
hits.truncate(request.params.limit);
}
Ok(VectorSearchResults { hits })
}
pub fn count(&self, request: VectorIndexQuery) -> Result<u64> {
let guard = self.acquire_searcher_guard()?;
guard.as_ref().unwrap().count(request)
}
pub fn stats(&self) -> Result<VectorStats> {
let reader = self.index.reader()?;
let doc_count = reader.vector_count();
let index_dimension = reader.dimension();
let mut fields = std::collections::HashMap::new();
if let Ok(field_names) = reader.field_names() {
for name in field_names {
let vectors = reader.get_vectors_by_field(&name).unwrap_or_default();
let vector_count = vectors.len();
let dimension = vectors
.first()
.map(|(_, v)| v.data.len())
.unwrap_or(index_dimension);
fields.insert(
name,
crate::vector::index::field::VectorFieldStats {
vector_count,
dimension,
},
);
}
}
Ok(VectorStats {
document_count: doc_count,
fields,
})
}
pub fn storage(&self) -> &Arc<dyn Storage> {
self.index.storage()
}
pub async fn close(&self) -> Result<()> {
*self.writer_cache.lock().await = None;
*self.searcher_cache.write() = None;
self.index.close()
}
pub fn is_closed(&self) -> bool {
self.index.is_closed()
}
pub fn embedder(&self) -> Arc<dyn Embedder> {
self.index.embedder()
}
pub fn last_wal_seq(&self) -> u64 {
self.index.last_wal_seq()
}
pub fn set_last_wal_seq(&self, seq: u64) {
let _ = self.index.set_last_wal_seq(seq);
}
pub async fn add_field(
&self,
name: &str,
embedder: Option<Arc<dyn crate::embedding::embedder::Embedder>>,
) {
if let Some(field_embedder) = embedder {
let index_embedder = self.index.embedder();
if let Some(pfe) = index_embedder
.as_any()
.downcast_ref::<crate::embedding::per_field::PerFieldEmbedder>()
{
pfe.add_embedder(name, field_embedder);
}
}
*self.writer_cache.lock().await = None;
*self.searcher_cache.write() = None;
}
pub async fn delete_field(&self, name: &str) {
let index_embedder = self.index.embedder();
if let Some(pfe) = index_embedder
.as_any()
.downcast_ref::<crate::embedding::per_field::PerFieldEmbedder>()
{
pfe.remove_embedder(name);
}
*self.writer_cache.lock().await = None;
*self.searcher_cache.write() = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::{MemoryStorage, MemoryStorageConfig};

    /// Builds a `VectorStore` over fresh in-memory storage with the default index type.
    fn memory_store() -> VectorStore {
        let backing = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        VectorStore::with_index_type_config(backing, VectorIndexTypeConfig::default())
            .expect("creating a VectorStore over memory storage should succeed")
    }

    #[test]
    fn test_vectorstore_creation() {
        let store = memory_store();
        assert!(!store.is_closed(), "a freshly created store must be open");
    }

    #[tokio::test]
    async fn test_vectorstore_close() {
        let store = memory_store();
        assert!(!store.is_closed(), "a freshly created store must be open");
        store.close().await.unwrap();
        assert!(store.is_closed(), "close() must mark the store as closed");
    }
}