use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
use crate::data::DataValue;
use crate::embedding::embedder::{EmbedInput, Embedder};
use crate::embedding::per_field::PerFieldEmbedder;
use crate::error::Result;
use crate::vector::core::vector::StoredVector;
use crate::vector::core::vector::Vector;
use crate::vector::store::config::VectorFieldConfig;
use crate::vector::store::request::QueryVector;
use crate::vector::writer::VectorIndexWriter;
/// A named vector field: static configuration plus writer/reader access to
/// the field's underlying vector index.
#[async_trait]
pub trait VectorField: Send + Sync + Debug {
    /// Name of the field as referenced by documents and queries.
    fn name(&self) -> &str;
    /// Configuration this field was created with.
    fn config(&self) -> &VectorFieldConfig;
    /// Borrowed access to the field's writer.
    fn writer(&self) -> &dyn VectorFieldWriter;
    /// Borrowed access to the field's reader.
    fn reader(&self) -> &dyn VectorFieldReader;
    /// Owned, shareable handle to the writer (for callers that outlive `self`).
    fn writer_handle(&self) -> Arc<dyn VectorFieldWriter>;
    /// Owned, shareable handle to the reader.
    fn reader_handle(&self) -> Arc<dyn VectorFieldReader>;
    /// Downcast hook so callers can recover the concrete field type.
    fn as_any(&self) -> &dyn Any;
    /// Optimizes the underlying index; by default delegates to the writer.
    async fn optimize(&self) -> Result<()> {
        self.writer().optimize().await
    }
}
/// Write-side interface for a single vector field.
#[async_trait]
pub trait VectorFieldWriter: Send + Sync + Debug {
    /// Adds an already-materialized vector for `doc_id` at `version`.
    async fn add_stored_vector(
        &self,
        doc_id: u64,
        vector: &StoredVector,
        version: u64,
    ) -> Result<()>;
    /// Adds a raw `DataValue`. The default implementation only accepts
    /// `DataValue::Vector`; implementations with an embedding helper may
    /// also accept other value kinds (see `LegacyVectorFieldWriter`).
    async fn add_value(&self, doc_id: u64, value: &DataValue, version: u64) -> Result<()> {
        if let DataValue::Vector(v) = value {
            let sv = StoredVector::new(v.clone());
            self.add_stored_vector(doc_id, &sv, version).await
        } else {
            Err(crate::error::LaurusError::invalid_argument(
                "add_value not supported for this field writer (needs embedding helper)",
            ))
        }
    }
    /// Removes the vectors stored for `doc_id` at `version`.
    async fn delete_document(&self, doc_id: u64, version: u64) -> Result<()>;
    /// Whether the writer is backed by persistent storage.
    async fn has_storage(&self) -> bool;
    /// Snapshot of the writer's current `(doc_id, field, vector)` triples.
    async fn vectors(&self) -> Vec<(u64, String, Vector)>;
    /// Discards current index state and rebuilds it from `vectors`.
    async fn rebuild(&self, vectors: Vec<(u64, String, Vector)>) -> Result<()>;
    /// Persists pending changes.
    async fn flush(&self) -> Result<()>;
    /// Rebuilds/compacts the index for query performance.
    async fn optimize(&self) -> Result<()>;
}
/// Read-side interface for a single vector field.
pub trait VectorFieldReader: Send + Sync + Debug {
    /// Runs a similarity search against this field.
    fn search(&self, request: FieldSearchInput) -> Result<FieldSearchResults>;
    /// Returns count/dimension statistics for this field.
    fn stats(&self) -> Result<VectorFieldStats>;
}
/// Parameters for a search over a single vector field.
#[derive(Debug, Clone)]
pub struct FieldSearchInput {
    /// Target field name.
    pub field: String,
    /// One or more query vectors to search with.
    pub query_vectors: Vec<QueryVector>,
    /// Maximum number of hits to return.
    pub limit: usize,
    /// Optional pre-filter: when `Some`, only these doc ids may match.
    pub allowed_ids: Option<std::collections::HashSet<u64>>,
}
/// Result set of a single-field search.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct FieldSearchResults {
    /// Matching documents; deserializes to an empty list when the key is
    /// absent from the payload.
    #[serde(default)]
    pub hits: Vec<FieldHit>,
}
/// A single search match within one field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FieldHit {
    /// Id of the matching document.
    pub doc_id: u64,
    /// Field the match came from.
    pub field: String,
    /// Similarity score — presumably higher is better; confirm against the
    /// reader implementation's scoring convention.
    pub score: f32,
    /// Raw distance between query and stored vector.
    pub distance: f32,
}
/// Summary statistics for a vector field.
#[derive(Debug, Clone, Copy, Default)]
pub struct VectorFieldStats {
    /// Number of vectors stored in the field.
    pub vector_count: usize,
    /// Dimensionality of the stored vectors.
    pub dimension: usize,
}
/// Adapts a legacy [`VectorIndexWriter`] to the [`VectorFieldWriter`] trait,
/// optionally embedding non-vector values before insertion.
///
/// The `W: VectorIndexWriter` bound was moved off the struct definition and
/// onto the impl blocks (which already carried it) — struct-level trait
/// bounds are an anti-pattern: they force the bound onto every use of the
/// type without enabling anything the impls don't already require.
pub struct LegacyVectorFieldWriter<W> {
    // Field this writer serves; attached to every stored vector triple.
    field_name: String,
    // Wrapped legacy writer, serialized behind an async mutex so it can be
    // held across `.await` points.
    writer: Mutex<W>,
    // Optional embedder used by `add_value` for text/bytes inputs.
    embedder: Option<Arc<dyn Embedder>>,
}
impl<W: VectorIndexWriter> std::fmt::Debug for LegacyVectorFieldWriter<W> {
    /// Renders the adapter with the embedder reduced to its name, since the
    /// `Embedder` trait object itself is not printed directly.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let embedder_name = self.embedder.as_ref().map(|e| e.name().to_string());
        f.debug_struct("LegacyVectorFieldWriter")
            .field("field_name", &self.field_name)
            .field("writer", &self.writer)
            .field("embedder", &embedder_name)
            .finish()
    }
}
impl<W: VectorIndexWriter> LegacyVectorFieldWriter<W> {
    /// Creates an adapter serving `field_name` over the given legacy
    /// `writer`, with no embedding helper attached.
    pub fn new(field_name: impl Into<String>, writer: W) -> Self {
        let field_name = field_name.into();
        Self {
            field_name,
            writer: Mutex::new(writer),
            embedder: None,
        }
    }

    /// Builder-style setter attaching an embedder, enabling `add_value`
    /// for non-vector inputs.
    pub fn with_embedder(mut self, embedder: Arc<dyn Embedder>) -> Self {
        self.embedder = Some(embedder);
        self
    }

    /// Name of the field this writer serves.
    pub fn field_name(&self) -> &str {
        self.field_name.as_str()
    }

    /// Converts a `StoredVector` into the legacy `(doc_id, field, Vector)`
    /// triple the wrapped writer expects.
    fn to_legacy_vector(&self, doc_id: u64, stored: &StoredVector) -> (u64, String, Vector) {
        (
            doc_id,
            self.field_name.clone(),
            Vector::new(stored.data.to_vec()),
        )
    }

    /// Test-only view of the vectors currently buffered in the writer.
    #[cfg(test)]
    pub(crate) async fn pending_vectors(&self) -> Vec<(u64, String, Vector)> {
        self.writer.lock().await.vectors().to_vec()
    }
}
#[async_trait]
impl<W> VectorFieldWriter for LegacyVectorFieldWriter<W>
where
    W: VectorIndexWriter,
{
    /// Adds one stored vector. `_version` is ignored: the legacy writer
    /// has no versioning concept.
    async fn add_stored_vector(
        &self,
        doc_id: u64,
        vector: &StoredVector,
        _version: u64,
    ) -> Result<()> {
        // Convert before locking; only the insertion needs the mutex.
        let legacy = self.to_legacy_vector(doc_id, vector);
        let mut guard = self.writer.lock().await;
        guard.add_vectors(vec![legacy])
    }

    /// Adds a `DataValue`: vectors go straight to the index; text/bytes are
    /// embedded first when an embedder is configured; anything else falls
    /// through to the wrapped writer's own `add_value`.
    async fn add_value(&self, doc_id: u64, value: &DataValue, _version: u64) -> Result<()> {
        // Fast path: the value is already a vector.
        if let DataValue::Vector(v) = value {
            let legacy = (doc_id, self.field_name.clone(), Vector::new(v.clone()));
            let mut guard = self.writer.lock().await;
            return guard.add_vectors(vec![legacy]);
        }
        if let Some(ref embedder) = self.embedder {
            let input = match value {
                DataValue::Text(t) => EmbedInput::Text(t),
                DataValue::Bytes(b, m) => EmbedInput::Bytes(b, m.as_deref()),
                _ => {
                    return Err(crate::error::LaurusError::invalid_argument(
                        "Unsupported data type for embedding",
                    ));
                }
            };
            // Embed before taking the lock so a potentially slow embedding
            // call does not serialize concurrent writers.
            let vector = if let Some(pf) = embedder.as_any().downcast_ref::<PerFieldEmbedder>() {
                // Per-field embedders route through the field name.
                pf.embed_field(&self.field_name, &input).await?
            } else {
                embedder.embed(&input).await?
            };
            let mut guard = self.writer.lock().await;
            return guard.add_vectors(vec![(doc_id, self.field_name.clone(), vector)]);
        }
        // No embedder: let the wrapped writer interpret the value itself.
        let mut guard = self.writer.lock().await;
        VectorIndexWriter::add_value(&mut *guard, doc_id, self.field_name.clone(), value.clone())
            .await
    }

    async fn has_storage(&self) -> bool {
        self.writer.lock().await.has_storage()
    }

    async fn vectors(&self) -> Vec<(u64, String, Vector)> {
        self.writer.lock().await.vectors().to_vec()
    }

    /// Discards current index state and rebuilds it from `vectors`.
    async fn rebuild(&self, vectors: Vec<(u64, String, Vector)>) -> Result<()> {
        let mut guard = self.writer.lock().await;
        guard.rollback()?;
        guard.build(vectors)?;
        guard.finalize()?;
        Ok(())
    }

    /// Deletes `doc_id`. Errors from the wrapped writer are deliberately
    /// swallowed (best-effort delete, matching the original behavior) —
    /// NOTE(review): confirm the legacy writer only errors on "not found".
    async fn delete_document(&self, doc_id: u64, _version: u64) -> Result<()> {
        let mut guard = self.writer.lock().await;
        let _ = guard.delete_document(doc_id);
        Ok(())
    }

    async fn flush(&self) -> Result<()> {
        self.writer.lock().await.commit()?;
        Ok(())
    }

    /// Rebuilds the index from its current vectors and commits the result.
    ///
    /// Fix: the mutex is now held across the entire snapshot → rollback →
    /// build → finalize → commit cycle. The previous implementation took
    /// and released the lock three separate times (`vectors()`, `rebuild()`,
    /// `flush()`), so vectors added concurrently between the snapshot and
    /// the rebuild were silently dropped by the rebuild.
    async fn optimize(&self) -> Result<()> {
        let mut guard = self.writer.lock().await;
        let vectors = guard.vectors().to_vec();
        guard.rollback()?;
        guard.build(vectors)?;
        guard.finalize()?;
        guard.commit()?;
        Ok(())
    }
}
/// A `VectorField` assembled from pre-built writer and reader handles,
/// plus the field's name and configuration.
#[derive(Debug)]
pub struct AdapterBackedVectorField {
    // Field name reported by `VectorField::name`.
    name: String,
    // Static configuration reported by `VectorField::config`.
    config: VectorFieldConfig,
    // Shared write handle for this field.
    writer: Arc<dyn VectorFieldWriter>,
    // Shared read handle for this field.
    reader: Arc<dyn VectorFieldReader>,
}
impl AdapterBackedVectorField {
    /// Bundles a name, configuration, and pre-built writer/reader handles
    /// into a `VectorField` implementation.
    pub fn new(
        name: impl Into<String>,
        config: VectorFieldConfig,
        writer: Arc<dyn VectorFieldWriter>,
        reader: Arc<dyn VectorFieldReader>,
    ) -> Self {
        let name = name.into();
        Self {
            name,
            config,
            writer,
            reader,
        }
    }

    /// Borrow of the shared writer handle.
    pub fn writer_handle(&self) -> &Arc<dyn VectorFieldWriter> {
        &self.writer
    }

    /// Borrow of the shared reader handle.
    pub fn reader_handle(&self) -> &Arc<dyn VectorFieldReader> {
        &self.reader
    }
}
#[async_trait]
impl VectorField for AdapterBackedVectorField {
fn name(&self) -> &str {
&self.name
}
fn config(&self) -> &VectorFieldConfig {
&self.config
}
fn writer(&self) -> &dyn VectorFieldWriter {
self.writer.as_ref()
}
fn reader(&self) -> &dyn VectorFieldReader {
self.reader.as_ref()
}
fn writer_handle(&self) -> Arc<dyn VectorFieldWriter> {
self.writer.clone()
}
fn reader_handle(&self) -> Arc<dyn VectorFieldReader> {
self.reader.clone()
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::vector::core::vector::StoredVector;
    use crate::vector::index::config::{FlatIndexConfig, HnswIndexConfig, IvfIndexConfig};
    use crate::vector::index::flat::writer::FlatIndexWriter;
    use crate::vector::index::hnsw::writer::HnswIndexWriter;
    use crate::vector::index::ivf::writer::IvfIndexWriter;
    use crate::vector::writer::VectorIndexWriterConfig;

    /// Two-dimensional unit vector shared by the tests below.
    fn sample_stored_vector() -> StoredVector {
        StoredVector::new(vec![1.0, 0.0])
    }

    fn flat_writer() -> FlatIndexWriter {
        FlatIndexWriter::new(
            FlatIndexConfig {
                dimension: 2,
                normalize_vectors: false,
                ..Default::default()
            },
            VectorIndexWriterConfig::default(),
            "test_flat",
        )
        .unwrap()
    }

    fn hnsw_writer() -> HnswIndexWriter {
        HnswIndexWriter::new(
            HnswIndexConfig {
                dimension: 2,
                normalize_vectors: false,
                ..Default::default()
            },
            VectorIndexWriterConfig::default(),
            "test_hnsw",
        )
        .unwrap()
    }

    fn ivf_writer() -> IvfIndexWriter {
        IvfIndexWriter::new(
            IvfIndexConfig {
                dimension: 2,
                normalize_vectors: false,
                ..Default::default()
            },
            VectorIndexWriterConfig::default(),
            "test_ivf",
        )
        .unwrap()
    }

    #[tokio::test]
    async fn test_adapter_flat() {
        let field_writer = LegacyVectorFieldWriter::new("body", flat_writer());
        assert_eq!(field_writer.field_name(), "body");
        assert!(!field_writer.has_storage().await);
    }

    #[tokio::test]
    async fn test_adapter_hnsw() {
        let field_writer = LegacyVectorFieldWriter::new("body", hnsw_writer());
        assert_eq!(field_writer.field_name(), "body");
    }

    #[tokio::test]
    async fn test_adapter_ivf() {
        let field_writer = LegacyVectorFieldWriter::new("body", ivf_writer());
        assert_eq!(field_writer.field_name(), "body");
    }

    #[tokio::test]
    async fn test_adapter_deletion_error_handling() {
        let field_writer = LegacyVectorFieldWriter::new("body", flat_writer());
        field_writer
            .add_stored_vector(3, &sample_stored_vector(), 1)
            .await
            .unwrap();
        // Deletion is best-effort and must not surface writer errors.
        assert!(field_writer.delete_document(3, 2).await.is_ok());
    }

    #[tokio::test]
    async fn adapter_stores_vector_with_correct_doc_id() {
        let field_writer = LegacyVectorFieldWriter::new("body", flat_writer());
        field_writer
            .add_stored_vector(5, &sample_stored_vector(), 1)
            .await
            .unwrap();
        let pending = field_writer.pending_vectors().await;
        assert_eq!(pending.len(), 1);
        let (doc_id, field, vector) = &pending[0];
        assert_eq!(*doc_id, 5);
        assert_eq!(field, "body");
        assert_eq!(vector.data.len(), 2);
    }
}