use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use parking_lot::RwLock;
use crate::storage::VectorIndexManager;
/// Configuration for [`SemanticSearch`].
///
/// Fields annotated with `#[serde(default ...)]` may be omitted when deserializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchConfig {
/// Store searched when a request does not name any stores.
pub default_store: String,
/// Embedding model name; not read by the code in this file.
pub embedding_model: Option<String>,
/// Length of the vectors produced by the built-in embedding routine.
pub dimensions: usize,
/// Metric used to convert vector-index distances into scores.
pub metric: DistanceMetric,
/// Not read by the code in this file. Deserializes to `false` when omitted.
#[serde(default)]
pub hybrid_enabled: bool,
/// Reranker model name; not read by the code in this file.
pub reranker_model: Option<String>,
/// Not read by the code in this file (expansion is driven by the request's
/// `expand_query` flag). Deserializes to `false` when omitted.
#[serde(default)]
pub query_expansion: bool,
/// Whether `SemanticSearch::new` creates an embedding cache. Defaults to `true`.
#[serde(default = "default_true")]
pub cache_embeddings: bool,
/// BM25 `k1` parameter passed to the keyword index. Defaults to `1.2`.
#[serde(default = "default_bm25_k1")]
pub bm25_k1: f32,
/// BM25 `b` parameter passed to the keyword index. Defaults to `0.75`.
#[serde(default = "default_bm25_b")]
pub bm25_b: f32,
}
/// Serde default helper for boolean fields that should be `true` when omitted.
fn default_true() -> bool {
true
}
/// Serde default for `SemanticSearchConfig::bm25_k1`.
fn default_bm25_k1() -> f32 {
1.2
}
/// Serde default for `SemanticSearchConfig::bm25_b`.
fn default_bm25_b() -> f32 {
0.75
}
/// Distance metric of the vector index.
///
/// Serialized in lowercase (`"cosine"`, `"euclidean"`, `"dotproduct"`, `"manhattan"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum DistanceMetric {
Cosine,
Euclidean,
DotProduct,
Manhattan,
}
/// The default metric is [`DistanceMetric::Cosine`].
impl Default for DistanceMetric {
fn default() -> Self {
Self::Cosine
}
}
/// A single search request handled by `SemanticSearch::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchRequest {
/// What to search for.
pub query: SearchQuery,
/// Stores to query in the vector index; the configured default store is used when `None`.
pub stores: Option<Vec<String>>,
/// Maximum number of results returned. Defaults to 10.
#[serde(default = "default_top_k")]
pub top_k: usize,
/// Results scoring below this value are dropped.
pub min_score: Option<f32>,
/// Metadata filters; a result must satisfy all of them.
pub filters: Option<Vec<MetadataFilter>>,
/// Search strategy. Defaults to `SearchMode::Semantic`.
#[serde(default)]
pub mode: SearchMode,
/// Weight of the vector ranking in hybrid mode (the keyword ranking gets `1 - alpha`);
/// 0.5 is used when `None`.
pub alpha: Option<f32>,
/// Whether results carry a vector. Defaults to `false`.
#[serde(default)]
pub include_vectors: bool,
/// Presumably whether results should carry metadata. Defaults to `true`.
#[serde(default = "default_true")]
pub include_metadata: bool,
/// Whether to attach highlight snippets to results. Defaults to `false`.
#[serde(default)]
pub highlight: bool,
/// Namespace copied onto results that come from the vector index.
pub namespace: Option<String>,
/// Whether to run the reranking pass. Defaults to `false`.
#[serde(default)]
pub rerank: bool,
/// Whether to expand the query into several sub-queries. Defaults to `false`.
#[serde(default)]
pub expand_query: bool,
/// Metadata field used to keep only the best result per distinct value.
pub group_by: Option<String>,
/// Metadata field used to drop results repeating an already seen value.
pub distinct_by: Option<String>,
}
/// Serde default for `SemanticSearchRequest::top_k`.
fn default_top_k() -> usize {
10
}
/// The payload of a search request.
///
/// The enum is `#[serde(untagged)]`, so deserialization tries the variants in
/// declaration order: a JSON string becomes `Text`, an array of numbers `Vector`,
/// an array of strings `MultiQuery`, and an object with an `image` key `Image`.
/// NOTE(review): an empty JSON array presumably matches `Vector` first — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum SearchQuery {
/// Free-text query.
Text(String),
/// Pre-computed query embedding.
Vector(Vec<f32>),
/// Several text queries treated as one.
MultiQuery(Vec<String>),
/// Image query; only `alt_text` is used when embedding in this file.
Image { image: String, alt_text: Option<String> },
}
/// Search strategy selector, serialized in snake_case
/// (`"semantic"`, `"keyword"`, `"hybrid"`, `"multi_modal"`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum SearchMode {
/// Vector similarity search (the default).
#[default]
Semantic,
/// BM25 keyword search.
Keyword,
/// Fusion of the vector and keyword rankings.
Hybrid,
/// Embeds the raw request query and runs a vector search.
MultiModal,
}
/// A predicate over one metadata field of a search result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetadataFilter {
/// Metadata key the predicate applies to.
pub field: String,
/// Comparison to perform.
pub operator: FilterOperator,
/// Right-hand operand; its expected JSON shape depends on `operator`
/// (for example an array for `In`, `NotIn` and `Between`).
pub value: serde_json::Value,
}
/// Comparison operators available to [`MetadataFilter`], serialized in snake_case
/// (for example `"not_in"`, `"starts_with"`, `"is_not_null"`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum FilterOperator {
Eq,
Ne,
Gt,
Gte,
Lt,
Lte,
/// Field value is an element of the filter's array value.
In,
/// Field value is not an element of the filter's array value.
NotIn,
/// String field contains the filter's string value.
Contains,
StartsWith,
EndsWith,
Exists,
IsNull,
IsNotNull,
/// Field value lies within `[value[0], value[1]]`, both ends inclusive.
Between,
/// String field matches the filter's string value interpreted as a regex.
Regex,
}
/// Result of `SemanticSearch::search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SemanticSearchResponse {
/// Ranked results, at most `top_k` of them.
pub results: Vec<SearchResult>,
/// Number of entries in `results`.
pub total: usize,
/// Wall-clock duration of the whole search, in milliseconds.
pub query_time_ms: u64,
/// Time spent embedding the queries, in milliseconds.
pub embedding_time_ms: Option<u64>,
/// Time spent reranking, in milliseconds; `None` when reranking was not requested.
pub rerank_time_ms: Option<u64>,
/// Text queries produced by query expansion; `None` when expansion was not requested.
pub expanded_queries: Option<Vec<String>>,
/// Facet counts; `search` in this file always sets this to `None`.
pub facets: Option<HashMap<String, Vec<FacetValue>>>,
}
/// One ranked hit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
/// Document identifier.
pub id: String,
/// Relevance score; its scale depends on the search mode that produced it.
pub score: f32,
/// Document text, when known.
pub content: Option<String>,
/// Vector attached to the hit when the request asked for vectors.
pub vector: Option<Vec<f32>>,
/// Document metadata, when known.
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// Highlight snippets, filled in when the request asked for highlighting.
pub highlights: Option<Vec<Highlight>>,
/// Store the hit came from.
pub store: String,
pub namespace: Option<String>,
/// Score assigned by the reranking pass, when it ran.
pub rerank_score: Option<f32>,
}
/// Highlighted snippet for one query term.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Highlight {
/// Name of the field the snippet was taken from.
pub field: String,
/// Snippet text with matches wrapped in `<mark>` tags.
pub text: String,
/// `(start, end)` byte ranges of the matches.
pub positions: Vec<(usize, usize)>,
}
/// A facet bucket: a value and a count.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetValue {
pub value: String,
pub count: usize,
}
/// A document handed to `SemanticSearch::index_document`.
#[derive(Debug, Clone)]
pub struct IndexedDocument {
/// Unique document identifier.
pub id: String,
/// Text indexed for keyword search.
pub content: String,
/// Embedding of the document; documents without one are not vector-searchable.
pub vector: Option<Vec<f32>>,
/// Arbitrary metadata used by filters, grouping and de-duplication.
pub metadata: Option<HashMap<String, serde_json::Value>>,
pub namespace: Option<String>,
/// Store the document belongs to.
pub store: String,
}
pub struct Bm25Index {
doc_terms: HashMap<String, HashMap<String, usize>>,
doc_lengths: HashMap<String, usize>,
idf: HashMap<String, f32>,
avg_doc_length: f32,
num_docs: usize,
doc_content: HashMap<String, String>,
doc_metadata: HashMap<String, HashMap<String, serde_json::Value>>,
k1: f32,
b: f32,
}
impl Bm25Index {
pub fn new(k1: f32, b: f32) -> Self {
Self {
doc_terms: HashMap::new(),
doc_lengths: HashMap::new(),
idf: HashMap::new(),
avg_doc_length: 0.0,
num_docs: 0,
doc_content: HashMap::new(),
doc_metadata: HashMap::new(),
k1,
b,
}
}
fn tokenize(text: &str) -> Vec<String> {
text.to_lowercase()
.split(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty() && s.len() > 1)
.map(|s| s.to_string())
.collect()
}
pub fn add_document(&mut self, doc: &IndexedDocument) {
let terms = Self::tokenize(&doc.content);
let doc_len = terms.len();
let mut term_counts: HashMap<String, usize> = HashMap::new();
for term in &terms {
*term_counts.entry(term.clone()).or_insert(0) += 1;
}
for term in term_counts.keys() {
let df = self.idf.entry(term.clone()).or_insert(0.0);
*df += 1.0;
}
self.doc_terms.insert(doc.id.clone(), term_counts);
self.doc_lengths.insert(doc.id.clone(), doc_len);
self.doc_content.insert(doc.id.clone(), doc.content.clone());
if let Some(ref meta) = doc.metadata {
self.doc_metadata.insert(doc.id.clone(), meta.clone());
}
self.num_docs += 1;
let total_length: usize = self.doc_lengths.values().sum();
self.avg_doc_length = total_length as f32 / self.num_docs as f32;
self.recalculate_idf();
}
fn recalculate_idf(&mut self) {
let n = self.num_docs as f32;
for (term, df) in self.idf.iter_mut() {
let doc_freq = self.doc_terms.values()
.filter(|terms| terms.contains_key(term))
.count() as f32;
*df = ((n - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
}
}
pub fn search(&self, query: &str, top_k: usize) -> Vec<(String, f32, Option<String>)> {
let query_terms = Self::tokenize(query);
let mut scores: HashMap<String, f32> = HashMap::new();
for (doc_id, term_freqs) in &self.doc_terms {
let doc_len = *self.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
let mut score = 0.0;
for term in &query_terms {
if let Some(&tf) = term_freqs.get(term) {
let idf = *self.idf.get(term).unwrap_or(&0.0);
let tf_norm = (tf as f32 * (self.k1 + 1.0))
/ (tf as f32 + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length));
score += idf * tf_norm;
}
}
if score > 0.0 {
scores.insert(doc_id.clone(), score);
}
}
let mut results: Vec<_> = scores.into_iter().collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(top_k);
results.into_iter()
.map(|(id, score)| {
let content = self.doc_content.get(&id).cloned();
(id, score, content)
})
.collect()
}
pub fn remove_document(&mut self, doc_id: &str) {
self.doc_terms.remove(doc_id);
self.doc_lengths.remove(doc_id);
self.doc_content.remove(doc_id);
self.doc_metadata.remove(doc_id);
if self.num_docs > 0 {
self.num_docs -= 1;
if self.num_docs > 0 {
let total_length: usize = self.doc_lengths.values().sum();
self.avg_doc_length = total_length as f32 / self.num_docs as f32;
}
self.recalculate_idf();
}
}
pub fn len(&self) -> usize {
self.num_docs
}
pub fn is_empty(&self) -> bool {
self.num_docs == 0
}
}
/// A cached embedding together with the moment it was stored.
struct CacheEntry {
    vector: Vec<f32>,
    timestamp: std::time::Instant,
}

/// Bounded in-memory cache from query text to embedding, with a time-to-live.
struct EmbeddingCache {
    entries: HashMap<String, CacheEntry>,
    /// Capacity above which an entry is evicted before a new key is inserted.
    max_entries: usize,
    /// Age in seconds after which an entry is no longer served.
    ttl_seconds: u64,
}

impl EmbeddingCache {
    /// Creates an empty cache holding at most `max_entries` entries for one hour each.
    fn new(max_entries: usize) -> Self {
        Self {
            entries: HashMap::new(),
            max_entries,
            ttl_seconds: 3600,
        }
    }

    /// Returns a copy of the embedding stored under `key`, unless it has expired.
    fn get(&self, key: &str) -> Option<Vec<f32>> {
        self.entries
            .get(key)
            .filter(|entry| entry.timestamp.elapsed().as_secs() < self.ttl_seconds)
            .map(|entry| entry.vector.clone())
    }

    /// Stores `vector` under `key`, refreshing its timestamp.
    ///
    /// Overwriting an existing key never evicts anything. When a new key would
    /// exceed the capacity, expired entries are purged first and, if that does
    /// not free a slot, the oldest entry is evicted.
    fn set(&mut self, key: String, vector: Vec<f32>) {
        let is_new_key = !self.entries.contains_key(&key);
        if is_new_key && self.entries.len() >= self.max_entries {
            let ttl = self.ttl_seconds;
            self.entries
                .retain(|_, entry| entry.timestamp.elapsed().as_secs() < ttl);
            if self.entries.len() >= self.max_entries {
                let oldest_key = self
                    .entries
                    .iter()
                    .min_by_key(|(_, entry)| entry.timestamp)
                    .map(|(k, _)| k.clone());
                if let Some(oldest) = oldest_key {
                    self.entries.remove(&oldest);
                }
            }
        }
        self.entries.insert(
            key,
            CacheEntry {
                vector,
                timestamp: std::time::Instant::now(),
            },
        );
    }

    /// Removes every entry.
    fn clear(&mut self) {
        self.entries.clear();
    }
}
/// Search engine combining a BM25 keyword index, an optional vector index and an
/// in-memory document store. All mutable state sits behind `RwLock`s, so the
/// methods take `&self`.
pub struct SemanticSearch {
config: SemanticSearchConfig,
// `None` when embedding caching is disabled in the configuration.
embedding_cache: RwLock<Option<EmbeddingCache>>,
bm25_index: RwLock<Bm25Index>,
// Set through `with_vector_index`; `None` until then.
vector_index: Option<Arc<VectorIndexManager>>,
// Indexed documents keyed by document id.
document_store: RwLock<HashMap<String, IndexedDocument>>,
}
impl SemanticSearch {
/// Builds a search engine from `config`.
///
/// An embedding cache with room for 10000 entries is created when
/// `config.cache_embeddings` is set. No vector index is attached.
pub fn new(config: SemanticSearchConfig) -> Self {
    let cache = config
        .cache_embeddings
        .then(|| EmbeddingCache::new(10000));
    let keyword_index = Bm25Index::new(config.bm25_k1, config.bm25_b);
    Self {
        embedding_cache: RwLock::new(cache),
        bm25_index: RwLock::new(keyword_index),
        vector_index: None,
        document_store: RwLock::new(HashMap::new()),
        config,
    }
}
pub fn with_vector_index(mut self, index: Arc<VectorIndexManager>) -> Self {
self.vector_index = Some(index);
self
}
pub fn index_document(&self, doc: IndexedDocument) -> Result<(), SearchError> {
{
let mut bm25 = self.bm25_index.write();
bm25.add_document(&doc);
}
if let (Some(ref index), Some(ref vector)) = (&self.vector_index, &doc.vector) {
let row_id = hash_string_to_u64(&doc.id);
let store_name = format!("{}_{}", doc.store, "vectors");
if let Err(_) = index.insert_vector(&store_name, row_id, vector) {
}
}
{
let mut store = self.document_store.write();
store.insert(doc.id.clone(), doc);
}
Ok(())
}
pub fn remove_document(&self, doc_id: &str) -> Result<(), SearchError> {
{
let mut bm25 = self.bm25_index.write();
bm25.remove_document(doc_id);
}
if let Some(ref index) = self.vector_index {
let row_id = hash_string_to_u64(doc_id);
let _ = index.delete_vector(&self.config.default_store, row_id);
}
{
let mut store = self.document_store.write();
store.remove(doc_id);
}
Ok(())
}
/// Runs `request` through the full pipeline and returns the ranked results.
///
/// Steps, in order: optional query expansion, query embedding, retrieval in the
/// requested mode, metadata filters, minimum-score cut-off, optional reranking,
/// grouping, de-duplication, highlighting, truncation to `top_k`, and finally
/// removal of metadata when `include_metadata` is `false` (done last because
/// filters, grouping and de-duplication need it).
///
/// # Errors
///
/// Propagates any [`SearchError`] raised by the individual stages.
pub async fn search(&self, request: SemanticSearchRequest) -> Result<SemanticSearchResponse, SearchError> {
    let started = std::time::Instant::now();

    let queries = if request.expand_query {
        self.expand_query(&request.query).await?
    } else {
        vec![request.query.clone()]
    };

    let embed_started = std::time::Instant::now();
    let query_vectors = self.embed_queries(&queries).await?;
    let embedding_time = embed_started.elapsed().as_millis() as u64;

    let mut results = match request.mode {
        SearchMode::Semantic => self.vector_search(&query_vectors, &request).await?,
        SearchMode::Keyword => self.keyword_search(&queries, &request).await?,
        SearchMode::Hybrid => {
            let alpha = request.alpha.unwrap_or(0.5);
            self.hybrid_search(&queries, &query_vectors, alpha, &request).await?
        }
        SearchMode::MultiModal => self.multimodal_search(&request.query, &request).await?,
    };

    if let Some(filters) = &request.filters {
        results = self.apply_filters(results, filters);
    }
    if let Some(min_score) = request.min_score {
        results.retain(|r| r.score >= min_score);
    }

    let rerank_time = if request.rerank {
        let rerank_started = std::time::Instant::now();
        results = self.rerank_results(&queries, results).await?;
        Some(rerank_started.elapsed().as_millis() as u64)
    } else {
        None
    };

    if let Some(group_by) = &request.group_by {
        results = self.group_results(results, group_by);
    }
    if let Some(distinct_by) = &request.distinct_by {
        results = self.distinct_results(results, distinct_by);
    }
    if request.highlight {
        results = self.add_highlights(results, &queries);
    }

    results.truncate(request.top_k);
    if !request.include_metadata {
        for result in &mut results {
            result.metadata = None;
        }
    }
    let total = results.len();

    // Only text queries are reported back; other query kinds have no text form.
    let expanded_queries = request.expand_query.then(|| {
        queries
            .iter()
            .filter_map(|q| match q {
                SearchQuery::Text(text) => Some(text.clone()),
                _ => None,
            })
            .collect()
    });

    Ok(SemanticSearchResponse {
        results,
        total,
        query_time_ms: started.elapsed().as_millis() as u64,
        embedding_time_ms: Some(embedding_time),
        rerank_time_ms: rerank_time,
        expanded_queries,
        facets: None,
    })
}
/// Expands `query` into a list of sub-queries.
///
/// * `Text`: the original text first, then the parts separated by `" or "` (if
///   any), then every whitespace-separated word longer than four bytes when the
///   text has more than one word. Empty parts and repeated strings are skipped
///   so no query is embedded or searched twice.
/// * `MultiQuery`: one `Text` query per entry.
/// * `Vector` / `Image`: returned unchanged.
async fn expand_query(&self, query: &SearchQuery) -> Result<Vec<SearchQuery>, SearchError> {
    match query {
        SearchQuery::Text(text) => {
            let mut candidates: Vec<&str> = Vec::new();
            if text.contains(" or ") {
                candidates.extend(text.split(" or ").map(str::trim));
            }
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() > 1 {
                candidates.extend(words.into_iter().filter(|word| word.len() > 4));
            }

            let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
            seen.insert(text.as_str());
            // The original query always comes first, even when it is empty.
            let mut expanded = vec![SearchQuery::Text(text.clone())];
            for candidate in candidates {
                if !candidate.is_empty() && seen.insert(candidate) {
                    expanded.push(SearchQuery::Text(candidate.to_string()));
                }
            }
            Ok(expanded)
        }
        SearchQuery::MultiQuery(texts) => {
            Ok(texts.iter().cloned().map(SearchQuery::Text).collect())
        }
        SearchQuery::Vector(_) | SearchQuery::Image { .. } => Ok(vec![query.clone()]),
    }
}
async fn embed_queries(&self, queries: &[SearchQuery]) -> Result<Vec<Vec<f32>>, SearchError> {
let mut vectors = Vec::new();
for query in queries {
match query {
SearchQuery::Text(text) => {
{
let cache = self.embedding_cache.read();
if let Some(ref cache) = *cache {
if let Some(vec) = cache.get(text) {
vectors.push(vec);
continue;
}
}
}
let embedding = self.generate_embedding(text);
{
let mut cache = self.embedding_cache.write();
if let Some(ref mut cache) = *cache {
cache.set(text.clone(), embedding.clone());
}
}
vectors.push(embedding);
}
SearchQuery::Vector(vec) => {
vectors.push(vec.clone());
}
SearchQuery::MultiQuery(texts) => {
let mut avg = vec![0.0f32; self.config.dimensions];
for text in texts {
let emb = self.generate_embedding(text);
for (i, v) in emb.iter().enumerate() {
if let Some(slot) = avg.get_mut(i) {
*slot += v / texts.len() as f32;
}
}
}
vectors.push(avg);
}
SearchQuery::Image { image, alt_text } => {
let text = alt_text.as_ref().map(|s| s.as_str()).unwrap_or("image");
let embedding = self.generate_embedding(text);
vectors.push(embedding);
}
}
}
Ok(vectors)
}
/// Builds a deterministic embedding of `text` with `config.dimensions` components.
///
/// Each token is hashed into one of the components and contributes
/// `1 / (position + 1)`, so earlier tokens weigh more. The result is scaled to
/// unit length unless it is all zeros. With `dimensions == 0` an empty vector is
/// returned (the bucket computation would otherwise divide by zero).
fn generate_embedding(&self, text: &str) -> Vec<f32> {
    let dims = self.config.dimensions;
    let mut embedding = vec![0.0f32; dims];
    if dims == 0 {
        return embedding;
    }

    for (position, token) in Bm25Index::tokenize(text).iter().enumerate() {
        let bucket = (hash_string_to_u64(token) as usize) % dims;
        embedding[bucket] += 1.0 / (position + 1) as f32;
    }

    let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        for component in &mut embedding {
            *component /= norm;
        }
    }
    embedding
}
async fn vector_search(
&self,
vectors: &[Vec<f32>],
request: &SemanticSearchRequest,
) -> Result<Vec<SearchResult>, SearchError> {
let mut all_results = Vec::new();
if let Some(ref index) = self.vector_index {
let stores = request.stores.clone()
.unwrap_or_else(|| vec![self.config.default_store.clone()]);
for store in stores {
for vector in vectors {
let store_name = format!("{}_vectors", store);
if let Ok(results) = index.search(&store_name, vector, request.top_k * 2) {
for (row_id, distance) in results {
let score = match self.config.metric {
DistanceMetric::Cosine => 1.0 - distance,
DistanceMetric::DotProduct => distance,
DistanceMetric::Euclidean => 1.0 / (1.0 + distance),
DistanceMetric::Manhattan => 1.0 / (1.0 + distance),
};
let doc_id = format!("doc_{}", row_id);
let doc_store = self.document_store.read();
let content = doc_store.get(&doc_id).map(|d| d.content.clone());
let metadata = doc_store.get(&doc_id).and_then(|d| d.metadata.clone());
all_results.push(SearchResult {
id: doc_id,
score,
content,
vector: if request.include_vectors { Some(vector.clone()) } else { None },
metadata,
highlights: None,
store: store.clone(),
namespace: request.namespace.clone(),
rerank_score: None,
});
}
}
}
}
}
if let (true, Some(first_vec)) = (all_results.is_empty(), vectors.first()) {
let doc_store = self.document_store.read();
for (doc_id, doc) in doc_store.iter() {
if let Some(ref doc_vec) = doc.vector {
let score = cosine_similarity(first_vec, doc_vec);
all_results.push(SearchResult {
id: doc_id.clone(),
score,
content: Some(doc.content.clone()),
vector: if request.include_vectors { Some(doc_vec.clone()) } else { None },
metadata: doc.metadata.clone(),
highlights: None,
store: doc.store.clone(),
namespace: doc.namespace.clone(),
rerank_score: None,
});
}
}
}
all_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
all_results.dedup_by(|a, b| a.id == b.id);
all_results.truncate(request.top_k);
Ok(all_results)
}
/// Runs a BM25 search for every text contained in `queries`.
///
/// `Text` queries contribute their text and `MultiQuery` queries contribute each
/// of their texts; `Vector` and `Image` queries have no text and are skipped.
/// Hits are enriched from the document store, sorted by descending score,
/// reduced to the best hit per document id, and truncated to `top_k`.
async fn keyword_search(
    &self,
    queries: &[SearchQuery],
    request: &SemanticSearchRequest,
) -> Result<Vec<SearchResult>, SearchError> {
    let texts: Vec<&str> = queries
        .iter()
        .flat_map(|query| match query {
            SearchQuery::Text(text) => vec![text.as_str()],
            SearchQuery::MultiQuery(texts) => texts.iter().map(String::as_str).collect(),
            SearchQuery::Vector(_) | SearchQuery::Image { .. } => Vec::new(),
        })
        .collect();

    let bm25 = self.bm25_index.read();
    // Taken once for the whole search rather than once per hit.
    let doc_store = self.document_store.read();

    let mut all_results = Vec::new();
    for text in texts {
        for (id, score, content) in bm25.search(text, request.top_k.saturating_mul(2)) {
            let doc = doc_store.get(&id);
            all_results.push(SearchResult {
                score,
                content,
                vector: None,
                metadata: doc.and_then(|d| d.metadata.clone()),
                highlights: None,
                store: doc
                    .map(|d| d.store.clone())
                    .unwrap_or_else(|| self.config.default_store.clone()),
                namespace: doc.and_then(|d| d.namespace.clone()),
                rerank_score: None,
                id,
            });
        }
    }

    all_results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
    // After sorting, the first occurrence of an id is its best-scoring hit.
    // Adjacent-only de-duplication would miss repeats that have different scores.
    let mut seen_ids = std::collections::HashSet::new();
    all_results.retain(|result| seen_ids.insert(result.id.clone()));
    all_results.truncate(request.top_k);
    Ok(all_results)
}
/// Combines vector and keyword search through reciprocal rank fusion.
///
/// `alpha` weighs the vector ranking and `1 - alpha` the keyword ranking. It is
/// clamped to `[0, 1]`, and a non-finite value falls back to `0.5`, so a
/// malformed request can never produce negative or NaN weights.
async fn hybrid_search(
    &self,
    queries: &[SearchQuery],
    vectors: &[Vec<f32>],
    alpha: f32,
    request: &SemanticSearchRequest,
) -> Result<Vec<SearchResult>, SearchError> {
    let alpha = if alpha.is_finite() { alpha.clamp(0.0, 1.0) } else { 0.5 };
    let vector_results = self.vector_search(vectors, request).await?;
    let keyword_results = self.keyword_search(queries, request).await?;
    Ok(self.reciprocal_rank_fusion(
        &[vector_results, keyword_results],
        &[alpha, 1.0 - alpha],
    ))
}
/// Embeds the single `query` and delegates to the vector search.
async fn multimodal_search(
    &self,
    query: &SearchQuery,
    request: &SemanticSearchRequest,
) -> Result<Vec<SearchResult>, SearchError> {
    let single_query = std::slice::from_ref(query);
    let embedded = self.embed_queries(single_query).await?;
    let hits = self.vector_search(&embedded, request).await;
    hits
}
fn apply_filters(&self, results: Vec<SearchResult>, filters: &[MetadataFilter]) -> Vec<SearchResult> {
results.into_iter()
.filter(|result| {
if let Some(ref metadata) = result.metadata {
filters.iter().all(|filter| {
self.evaluate_filter(metadata, filter)
})
} else {
filters.is_empty()
}
})
.collect()
}
fn evaluate_filter(&self, metadata: &HashMap<String, serde_json::Value>, filter: &MetadataFilter) -> bool {
let value = match metadata.get(&filter.field) {
Some(v) => v,
None => return matches!(filter.operator, FilterOperator::IsNull | FilterOperator::Exists),
};
match filter.operator {
FilterOperator::Eq => value == &filter.value,
FilterOperator::Ne => value != &filter.value,
FilterOperator::Gt => compare_json_values(value, &filter.value) == Some(std::cmp::Ordering::Greater),
FilterOperator::Gte => matches!(compare_json_values(value, &filter.value), Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal)),
FilterOperator::Lt => compare_json_values(value, &filter.value) == Some(std::cmp::Ordering::Less),
FilterOperator::Lte => matches!(compare_json_values(value, &filter.value), Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal)),
FilterOperator::In => {
if let serde_json::Value::Array(arr) = &filter.value {
arr.contains(value)
} else {
false
}
}
FilterOperator::NotIn => {
if let serde_json::Value::Array(arr) = &filter.value {
!arr.contains(value)
} else {
true
}
}
FilterOperator::Contains => {
if let (serde_json::Value::String(s), serde_json::Value::String(pattern)) = (value, &filter.value) {
s.contains(pattern.as_str())
} else {
false
}
}
FilterOperator::StartsWith => {
if let (serde_json::Value::String(s), serde_json::Value::String(pattern)) = (value, &filter.value) {
s.starts_with(pattern.as_str())
} else {
false
}
}
FilterOperator::EndsWith => {
if let (serde_json::Value::String(s), serde_json::Value::String(pattern)) = (value, &filter.value) {
s.ends_with(pattern.as_str())
} else {
false
}
}
FilterOperator::Exists => true,
FilterOperator::IsNull => value.is_null(),
FilterOperator::IsNotNull => !value.is_null(),
FilterOperator::Between => {
if let serde_json::Value::Array(arr) = &filter.value {
if let (Some(low), Some(high)) = (arr.first(), arr.get(1)) {
let gte = matches!(compare_json_values(value, low), Some(std::cmp::Ordering::Greater | std::cmp::Ordering::Equal));
let lte = matches!(compare_json_values(value, high), Some(std::cmp::Ordering::Less | std::cmp::Ordering::Equal));
gte && lte
} else {
false
}
} else {
false
}
}
FilterOperator::Regex => {
if let (serde_json::Value::String(s), serde_json::Value::String(pattern)) = (value, &filter.value) {
regex::Regex::new(pattern)
.map(|re| re.is_match(s))
.unwrap_or(false)
} else {
false
}
}
}
}
async fn rerank_results(
&self,
queries: &[SearchQuery],
mut results: Vec<SearchResult>,
) -> Result<Vec<SearchResult>, SearchError> {
let query_terms: Vec<String> = queries.iter()
.filter_map(|q| {
if let SearchQuery::Text(t) = q {
Some(Bm25Index::tokenize(t))
} else {
None
}
})
.flatten()
.collect();
for result in &mut results {
let mut boost = 0.0;
if let Some(ref content) = result.content {
let content_lower = content.to_lowercase();
for term in &query_terms {
if content_lower.contains(term) {
boost += 0.1;
}
if let Some(SearchQuery::Text(query)) = queries.first() {
if content_lower.contains(&query.to_lowercase()) {
boost += 0.3;
}
}
}
}
result.rerank_score = Some(result.score * (1.0 + boost));
}
results.sort_by(|a, b| {
b.rerank_score.unwrap_or(b.score)
.partial_cmp(&a.rerank_score.unwrap_or(a.score))
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
/// Merges several ranked lists with weighted reciprocal rank fusion.
///
/// A result at zero-based `rank` in a list with weight `w` contributes
/// `w / (60 + rank + 1)`; contributions for the same id are summed and the
/// first occurrence of a result supplies the remaining fields. The merged list
/// is sorted by descending fused score, with equal scores ordered by id so the
/// output does not depend on hash-map iteration order.
fn reciprocal_rank_fusion(
    &self,
    result_sets: &[Vec<SearchResult>],
    weights: &[f32],
) -> Vec<SearchResult> {
    // Rank-smoothing constant of the fusion formula.
    let k = 60.0;
    let mut fused: HashMap<String, (f32, SearchResult)> = HashMap::new();
    for (results, weight) in result_sets.iter().zip(weights.iter()) {
        for (rank, result) in results.iter().enumerate() {
            let rrf_score = weight / (k + rank as f32 + 1.0);
            fused
                .entry(result.id.clone())
                .and_modify(|(score, _)| *score += rrf_score)
                .or_insert_with(|| (rrf_score, result.clone()));
        }
    }

    let mut merged: Vec<SearchResult> = fused
        .into_values()
        .map(|(score, mut result)| {
            result.score = score;
            result
        })
        .collect();
    merged.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.id.cmp(&b.id))
    });
    merged
}
/// Keeps the highest-scoring result for each distinct value of metadata `field`.
///
/// Results lacking the field (or lacking metadata) form one group under the key
/// `"_none_"`. Among equal scores the earliest result wins. Groups appear in
/// the order in which they first occur in `results`, so the incoming ranking is
/// preserved instead of being replaced by hash-map iteration order.
fn group_results(&self, results: Vec<SearchResult>, field: &str) -> Vec<SearchResult> {
    let mut slot_by_key: HashMap<String, usize> = HashMap::new();
    let mut representatives: Vec<SearchResult> = Vec::new();
    for result in results {
        let key = result
            .metadata
            .as_ref()
            .and_then(|m| m.get(field))
            .map(|v| v.to_string())
            .unwrap_or_else(|| "_none_".to_string());
        match slot_by_key.get(&key) {
            Some(&slot) => {
                if result.score > representatives[slot].score {
                    representatives[slot] = result;
                }
            }
            None => {
                slot_by_key.insert(key, representatives.len());
                representatives.push(result);
            }
        }
    }
    representatives
}
fn distinct_results(&self, results: Vec<SearchResult>, field: &str) -> Vec<SearchResult> {
let mut seen: HashMap<String, bool> = HashMap::new();
results.into_iter()
.filter(|result| {
let key = result.metadata.as_ref()
.and_then(|m| m.get(field))
.map(|v| v.to_string())
.unwrap_or_else(|| result.id.clone());
if seen.contains_key(&key) {
false
} else {
seen.insert(key, true);
true
}
})
.collect()
}
/// Attaches highlight snippets to every result that has content and at least
/// one match for the tokens of the text queries.
fn add_highlights(&self, mut results: Vec<SearchResult>, queries: &[SearchQuery]) -> Vec<SearchResult> {
    let mut terms: Vec<String> = Vec::new();
    for query in queries {
        if let SearchQuery::Text(text) = query {
            terms.extend(Bm25Index::tokenize(text));
        }
    }

    for hit in results.iter_mut() {
        let found = match hit.content.as_deref() {
            Some(content) => generate_highlights(content, &terms),
            None => continue,
        };
        if !found.is_empty() {
            hit.highlights = Some(found);
        }
    }
    results
}
pub fn clear_cache(&self) {
let mut cache = self.embedding_cache.write();
if let Some(ref mut cache) = *cache {
cache.clear();
}
}
/// Reports the current sizes of the embedding cache, the keyword index and the
/// document store. Each lock is held only while its size is read.
pub fn stats(&self) -> SearchStats {
    let cached_embeddings = self
        .embedding_cache
        .read()
        .as_ref()
        .map_or(0, |cache| cache.entries.len());
    let indexed_documents = self.bm25_index.read().len();
    let document_store_size = self.document_store.read().len();
    SearchStats {
        cached_embeddings,
        indexed_documents,
        document_store_size,
    }
}
}
/// Size counters returned by `SemanticSearch::stats`.
#[derive(Debug, Clone)]
pub struct SearchStats {
/// Number of entries in the embedding cache (0 when caching is disabled).
pub cached_embeddings: usize,
/// Number of documents in the BM25 index.
pub indexed_documents: usize,
/// Number of documents in the document store.
pub document_store_size: usize,
}
/// Errors reported by the search API. Every variant carries a human-readable
/// detail string that is appended to the variant's message.
#[derive(Debug, thiserror::Error)]
pub enum SearchError {
#[error("Embedding error: {0}")]
Embedding(String),
#[error("Index error: {0}")]
Index(String),
#[error("Invalid query: {0}")]
InvalidQuery(String),
#[error("Store not found: {0}")]
StoreNotFound(String),
#[error("Filter error: {0}")]
Filter(String),
}
/// Default configuration: store `"default"`, 1536 dimensions, cosine metric,
/// embedding cache enabled, BM25 parameters `k1 = 1.2` and `b = 0.75`.
///
/// NOTE(review): `hybrid_enabled` is `true` here but deserializes to `false`
/// when omitted from input; confirm which default is intended.
impl Default for SemanticSearchConfig {
    fn default() -> Self {
        let default_store = String::from("default");
        Self {
            default_store,
            embedding_model: None,
            dimensions: 1536,
            metric: DistanceMetric::Cosine,
            hybrid_enabled: true,
            reranker_model: None,
            query_expansion: false,
            cache_embeddings: true,
            bm25_k1: 1.2,
            bm25_b: 0.75,
        }
    }
}
/// Hashes `s` to a `u64` with the standard library's `DefaultHasher`.
///
/// NOTE(review): the algorithm behind `DefaultHasher` is not guaranteed to stay
/// the same across Rust releases, so values persisted outside the process
/// (for example vector-index row ids) may not survive a toolchain upgrade.
fn hash_string_to_u64(s: &str) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(s, &mut state);
    state.finish()
}
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when either has zero norm
/// (which includes empty slices).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut squares_a = 0.0f32;
    let mut squares_b = 0.0f32;
    for (x, y) in a.iter().zip(b) {
        dot += x * y;
        squares_a += x * x;
        squares_b += y * y;
    }

    let (norm_a, norm_b) = (squares_a.sqrt(), squares_b.sqrt());
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }
    dot / (norm_a * norm_b)
}
fn compare_json_values(a: &serde_json::Value, b: &serde_json::Value) -> Option<std::cmp::Ordering> {
match (a, b) {
(serde_json::Value::Number(n1), serde_json::Value::Number(n2)) => {
n1.as_f64().partial_cmp(&n2.as_f64())
}
(serde_json::Value::String(s1), serde_json::Value::String(s2)) => {
Some(s1.cmp(s2))
}
_ => None,
}
}
/// Builds one highlight snippet per query term that occurs in `content`.
///
/// Matching is case-insensitive: `query_terms` are expected to be lowercase
/// already and are searched in the lowercased content. Every non-overlapping
/// match is wrapped in `<mark>` tags and surrounded by roughly 30 bytes of
/// context on each side; `"..."` marks text that was left out. The reported
/// `positions` are byte ranges into the original `content`.
///
/// The function never panics on input text:
/// * match offsets found in the lowercased text are mapped back to the original
///   text, because lowercasing can change byte lengths;
/// * context windows are snapped to character boundaries;
/// * the context after a match stops at the next match, so nearby matches do
///   not produce overlapping slices;
/// * empty terms are skipped.
fn generate_highlights(content: &str, query_terms: &[String]) -> Vec<Highlight> {
    const CONTEXT_BYTES: usize = 30;

    let content_lower = content.to_lowercase();
    // origin[i] is the byte offset in `content` of the character that produced
    // byte i of `content_lower`.
    let mut origin: Vec<usize> = Vec::with_capacity(content_lower.len());
    for (offset, ch) in content.char_indices() {
        let lowered_len: usize = ch.to_lowercase().map(char::len_utf8).sum();
        origin.extend(std::iter::repeat(offset).take(lowered_len));
    }

    // Start of the original character that contains lowercased byte `lower_idx`.
    let original_start =
        |lower_idx: usize| -> usize { origin.get(lower_idx).copied().unwrap_or(content.len()) };
    // End of the original character that contains the last byte before `lower_end`.
    let original_end = |lower_end: usize| -> usize {
        if lower_end == 0 {
            return 0;
        }
        match origin.get(lower_end - 1) {
            Some(&char_start) => {
                let width = content[char_start..]
                    .chars()
                    .next()
                    .map_or(0, char::len_utf8);
                char_start + width
            }
            None => content.len(),
        }
    };
    let floor_boundary = |index: usize| -> usize {
        let mut i = index.min(content.len());
        while !content.is_char_boundary(i) {
            i -= 1;
        }
        i
    };
    let ceil_boundary = |index: usize| -> usize {
        let mut i = index.min(content.len());
        while !content.is_char_boundary(i) {
            i += 1;
        }
        i
    };

    let mut highlights = Vec::new();
    for term in query_terms {
        if term.is_empty() {
            continue;
        }

        let mut positions: Vec<(usize, usize)> = Vec::new();
        for (lower_start, matched) in content_lower.match_indices(term.as_str()) {
            let start = original_start(lower_start);
            let end = original_end(lower_start + matched.len());
            let previous_end = positions.last().map_or(0, |&(_, e)| e);
            // Distinct matches can only collapse onto the same original
            // character in exotic case mappings; keep the first one then.
            if start >= previous_end && start < end {
                positions.push((start, end));
            }
        }
        if positions.is_empty() {
            continue;
        }

        let mut highlighted = String::new();
        let mut last_end = 0;
        for (i, &(pos_start, pos_end)) in positions.iter().enumerate() {
            let context_start = floor_boundary(pos_start.saturating_sub(CONTEXT_BYTES));
            if context_start > last_end {
                highlighted.push_str("...");
            }
            let actual_start = context_start.max(last_end);
            highlighted.push_str(&content[actual_start..pos_start]);
            highlighted.push_str("<mark>");
            highlighted.push_str(&content[pos_start..pos_end]);
            highlighted.push_str("</mark>");
            // Trailing context must not run past the next match, which is
            // emitted (with its own markup) by the next iteration.
            let next_start = positions.get(i + 1).map_or(content.len(), |&(s, _)| s);
            let context_end =
                ceil_boundary(pos_end.saturating_add(CONTEXT_BYTES)).min(next_start);
            highlighted.push_str(&content[pos_end..context_end]);
            last_end = context_end;
        }
        if last_end < content.len() {
            highlighted.push_str("...");
        }

        highlights.push(Highlight {
            field: "content".to_string(),
            text: highlighted,
            positions,
        });
    }
    highlights
}
#[cfg(test)]
mod tests {
use super::*;
// Two documents sharing the words "quick" and "brown" must both be returned,
// with a positive score for the best hit.
#[test]
fn test_bm25_index() {
let mut index = Bm25Index::new(1.2, 0.75);
index.add_document(&IndexedDocument {
id: "doc1".to_string(),
content: "The quick brown fox jumps over the lazy dog".to_string(),
vector: None,
metadata: None,
namespace: None,
store: "default".to_string(),
});
index.add_document(&IndexedDocument {
id: "doc2".to_string(),
content: "A quick brown dog runs in the park".to_string(),
vector: None,
metadata: None,
namespace: None,
store: "default".to_string(),
});
let results = index.search("quick brown", 10);
assert_eq!(results.len(), 2);
assert!(results[0].1 > 0.0);
}
// Identical vectors have similarity 1; orthogonal vectors have similarity 0.
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
let c = vec![0.0, 1.0, 0.0];
assert!(cosine_similarity(&a, &c).abs() < 0.001);
}
// Terms present in the content produce at least one snippet containing a
// `<mark>` tag.
#[test]
fn test_highlights() {
let content = "The quick brown fox jumps over the lazy dog";
let terms = vec!["quick".to_string(), "fox".to_string()];
let highlights = generate_highlights(content, &terms);
assert!(!highlights.is_empty());
assert!(highlights[0].text.contains("<mark>"));
}
// End-to-end keyword search: two documents are indexed without vectors and a
// keyword-mode request with highlighting must return at least one result.
#[tokio::test]
async fn test_semantic_search() {
let config = SemanticSearchConfig::default();
let search = SemanticSearch::new(config);
search.index_document(IndexedDocument {
id: "doc1".to_string(),
content: "Machine learning is a subset of artificial intelligence".to_string(),
vector: None,
metadata: Some(HashMap::from([
("category".to_string(), serde_json::json!("tech")),
])),
namespace: None,
store: "default".to_string(),
}).unwrap();
search.index_document(IndexedDocument {
id: "doc2".to_string(),
content: "Deep learning uses neural networks for pattern recognition".to_string(),
vector: None,
metadata: Some(HashMap::from([
("category".to_string(), serde_json::json!("tech")),
])),
namespace: None,
store: "default".to_string(),
}).unwrap();
let request = SemanticSearchRequest {
query: SearchQuery::Text("machine learning AI".to_string()),
stores: None,
top_k: 10,
min_score: None,
filters: None,
mode: SearchMode::Keyword,
alpha: None,
include_vectors: false,
include_metadata: true,
highlight: true,
namespace: None,
rerank: false,
expand_query: false,
group_by: None,
distinct_by: None,
};
let response = search.search(request).await.unwrap();
assert!(!response.results.is_empty());
}
// Exercises the `Eq`, `Gt` and `Contains` operators against a small metadata map.
#[test]
fn test_filter_evaluation() {
let config = SemanticSearchConfig::default();
let search = SemanticSearch::new(config);
let metadata: HashMap<String, serde_json::Value> = HashMap::from([
("count".to_string(), serde_json::json!(10)),
("name".to_string(), serde_json::json!("test")),
]);
// Equality on a numeric field.
let filter = MetadataFilter {
field: "count".to_string(),
operator: FilterOperator::Eq,
value: serde_json::json!(10),
};
assert!(search.evaluate_filter(&metadata, &filter));
// 10 > 5.
let filter = MetadataFilter {
field: "count".to_string(),
operator: FilterOperator::Gt,
value: serde_json::json!(5),
};
assert!(search.evaluate_filter(&metadata, &filter));
// "test" contains "es".
let filter = MetadataFilter {
field: "name".to_string(),
operator: FilterOperator::Contains,
value: serde_json::json!("es"),
};
assert!(search.evaluate_filter(&metadata, &filter));
}
}