use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use laurus::Engine;
use laurus::LaurusError;
use laurus::lexical::TermQuery;
use laurus::lexical::TextOption;
use laurus::storage::memory::MemoryStorage;
use laurus::vector::FlatOption;
use laurus::vector::QueryVector;
use laurus::vector::Vector;
use laurus::{DataValue, Document};
use laurus::{EmbedInput, EmbedInputType, Embedder};
use laurus::{FieldOption, Schema};
use laurus::{FusionAlgorithm, LexicalSearchQuery, SearchRequestBuilder, VectorSearchQuery};
/// Test double for an embedder backed by an in-memory lookup table.
///
/// The table sits behind `Arc<Mutex<..>>`, so clones of a `MockEmbedder`
/// share the same entries.
#[derive(Debug, Clone)]
struct MockEmbedder {
    /// Maps an input text to the vector registered for it via [`MockEmbedder::add`].
    vectors: Arc<Mutex<HashMap<String, Vec<f32>>>>,
}

impl MockEmbedder {
    /// Creates an embedder with an empty lookup table.
    fn new() -> Self {
        let vectors = Arc::default();
        Self { vectors }
    }

    /// Registers `vector` under `text`, replacing any vector stored for the same text.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    fn add(&self, text: &str, vector: Vec<f32>) {
        let mut table = self.vectors.lock().unwrap();
        table.insert(text.to_owned(), vector);
    }
}
#[async_trait]
impl Embedder for MockEmbedder {
    /// Resolves a text input to the vector previously registered for it.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error when the input is not text, or when
    /// no vector has been registered for the given text.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    async fn embed(
        &self,
        input: &EmbedInput<'_>,
    ) -> std::result::Result<laurus::vector::core::vector::Vector, LaurusError> {
        // Reject anything that is not text before touching the lookup table.
        let text = match input {
            EmbedInput::Text(text) => *text,
            _ => {
                return Err(LaurusError::invalid_argument(
                    "MockEmbedder only supports text",
                ));
            }
        };

        let known = self.vectors.lock().unwrap();
        known
            .get(text)
            .map(|values| laurus::vector::core::vector::Vector::new(values.clone()))
            .ok_or_else(|| {
                LaurusError::invalid_argument(format!("MockEmbedder: unknown text '{}'", text))
            })
    }

    /// Reports text as the only supported input type.
    fn supported_input_types(&self) -> Vec<EmbedInputType> {
        vec![EmbedInputType::Text]
    }

    /// Always `true`: text inputs are supported.
    fn supports_text(&self) -> bool {
        true
    }

    /// Always `false`: image inputs are not supported.
    fn supports_image(&self) -> bool {
        false
    }

    /// Fixed identifier for this embedder.
    fn name(&self) -> &str {
        "MockEmbedder"
    }

    /// Exposes `self` as `&dyn Any`.
    fn as_any(&self) -> &dyn Any {
        self
    }
}
/// Builds an in-memory engine with a text field `title` and a 3-dimensional
/// flat vector field `title_vec`, wired to a `MockEmbedder` that knows the
/// texts "apple" and "banana".
///
/// # Errors
///
/// Propagates any error returned while building the engine.
async fn create_hybrid_engine() -> std::result::Result<Engine, Box<dyn std::error::Error>> {
    // Schema: one lexical field and one flat vector field of dimension 3.
    let title_vec_option = FieldOption::Flat(FlatOption {
        dimension: 3,
        ..Default::default()
    });
    let schema = Schema::builder()
        .add_field("title", FieldOption::Text(TextOption::default()))
        .add_field("title_vec", title_vec_option)
        .build();

    // Embedder pre-loaded with the two vectors the test relies on.
    let embedder = Arc::new(MockEmbedder::new());
    embedder.add("apple", vec![0.9, 0.1, 0.0]);
    embedder.add("banana", vec![0.0, 1.0, 0.0]);

    let storage = Arc::new(MemoryStorage::new(Default::default()));

    Ok(Engine::builder(storage, schema)
        .embedder(embedder)
        .build()
        .await?)
}
/// Indexes an "apple" and a "banana" document, then checks three searches
/// against the same engine: vector-only, lexical-only, and a hybrid query
/// combining both with RRF fusion (`k = 60.0`).
#[tokio::test(flavor = "multi_thread")]
async fn test_hybrid_search_unification() -> std::result::Result<(), Box<dyn std::error::Error>> {
    let engine = create_hybrid_engine().await?;

    // Each document carries the same word in the text field and in the
    // vector field, both supplied as text values.
    for (id, word) in [("1", "apple"), ("2", "banana")] {
        let payload = Document::builder()
            .add_field("title", DataValue::Text(word.into()))
            .add_field("title_vec", DataValue::Text(word.into()))
            .build();
        engine.add_document(id, payload).await?;
    }
    engine.commit().await?;

    // 1. Vector-only search using the vector registered for "apple".
    let apple_query = QueryVector {
        vector: Vector::new(vec![0.9, 0.1, 0.0]),
        weight: 1.0,
        fields: None,
    };
    let vector_request = SearchRequestBuilder::new()
        .vector_query(VectorSearchQuery::Vectors(vec![apple_query]))
        .limit(10)
        .build();
    let vector_hits = engine.search(vector_request).await?;
    assert!(!vector_hits.is_empty(), "Vector search should return hits");
    let vector_top_id = vector_hits[0].id.clone();

    // 2. Lexical-only search for the term "banana" in `title`.
    let banana_term = TermQuery::new("title", "banana");
    let lexical_request = SearchRequestBuilder::new()
        .lexical_query(LexicalSearchQuery::Obj(Box::new(banana_term)))
        .limit(10)
        .build();
    let lexical_hits = engine.search(lexical_request).await?;
    assert!(!lexical_hits.is_empty(), "Lexical search should return hits");
    assert_ne!(
        lexical_hits[0].id, vector_top_id,
        "Banana should be different from Apple"
    );

    // 3. Hybrid search: the "banana" vector plus the "banana" term, fused with RRF.
    let banana_query = QueryVector {
        vector: Vector::new(vec![0.0, 1.0, 0.0]),
        weight: 1.0,
        fields: None,
    };
    let hybrid_term = TermQuery::new("title", "banana");
    let hybrid_request = SearchRequestBuilder::new()
        .vector_query(VectorSearchQuery::Vectors(vec![banana_query]))
        .lexical_query(LexicalSearchQuery::Obj(Box::new(hybrid_term)))
        .fusion_algorithm(FusionAlgorithm::RRF { k: 60.0 })
        .limit(10)
        .build();
    let hybrid_hits = engine.search(hybrid_request).await?;
    assert!(!hybrid_hits.is_empty());
    assert_eq!(
        hybrid_hits[0].id, lexical_hits[0].id,
        "Banana should win in hybrid search"
    );

    Ok(())
}