use std::collections::{HashMap, HashSet};
use std::cmp::Ordering;
use std::sync::Arc;
use crate::context_query::VectorIndex;
use crate::soch_ql::SochValue;
/// Declarative description of one hybrid (vector + lexical) search request.
///
/// Usually assembled with the chainable builder methods on this type and then
/// handed to `HybridQueryExecutor::execute`.
#[derive(Debug, Clone)]
pub struct HybridQuery {
/// Name of the collection; the executor passes it to both the vector index
/// and the lexical index.
pub collection: String,
/// Optional vector-similarity component of the query.
pub vector: Option<VectorQueryComponent>,
/// Optional keyword component of the query. Its text is also what the
/// executor uses as the rerank query.
pub lexical: Option<LexicalQueryComponent>,
/// Metadata predicates; a candidate is kept only if it satisfies all of them.
pub filters: Vec<MetadataFilter>,
/// How the per-component ranks/scores are merged into one score.
pub fusion: FusionConfig,
/// Optional reranking settings applied after fusion.
pub rerank: Option<RerankConfig>,
/// Maximum number of results returned.
pub limit: usize,
/// When set, results whose fused score is below this value are dropped.
pub min_score: Option<f32>,
}
impl HybridQuery {
pub fn new(collection: &str) -> Self {
Self {
collection: collection.to_string(),
vector: None,
lexical: None,
filters: Vec::new(),
fusion: FusionConfig::default(),
rerank: None,
limit: 10,
min_score: None,
}
}
pub fn with_vector(mut self, embedding: Vec<f32>, weight: f32) -> Self {
self.vector = Some(VectorQueryComponent {
embedding,
weight,
ef_search: 100,
});
self
}
pub fn with_vector_text(mut self, text: String, weight: f32) -> Self {
self.vector = Some(VectorQueryComponent {
embedding: Vec::new(), weight,
ef_search: 100,
});
self.lexical = self.lexical.or(Some(LexicalQueryComponent {
query: text,
weight: 0.0, fields: vec!["content".to_string()],
}));
self
}
pub fn with_lexical(mut self, query: &str, weight: f32) -> Self {
self.lexical = Some(LexicalQueryComponent {
query: query.to_string(),
weight,
fields: vec!["content".to_string()],
});
self
}
pub fn with_lexical_fields(mut self, query: &str, weight: f32, fields: Vec<String>) -> Self {
self.lexical = Some(LexicalQueryComponent {
query: query.to_string(),
weight,
fields,
});
self
}
pub fn filter(mut self, field: &str, op: FilterOp, value: SochValue) -> Self {
self.filters.push(MetadataFilter {
field: field.to_string(),
op,
value,
});
self
}
pub fn filter_eq(self, field: &str, value: impl Into<SochValue>) -> Self {
self.filter(field, FilterOp::Eq, value.into())
}
pub fn filter_range(mut self, field: &str, min: Option<SochValue>, max: Option<SochValue>) -> Self {
if let Some(min_val) = min {
self.filters.push(MetadataFilter {
field: field.to_string(),
op: FilterOp::Gte,
value: min_val,
});
}
if let Some(max_val) = max {
self.filters.push(MetadataFilter {
field: field.to_string(),
op: FilterOp::Lte,
value: max_val,
});
}
self
}
pub fn with_fusion(mut self, method: FusionMethod) -> Self {
self.fusion.method = method;
self
}
pub fn with_rrf_k(mut self, k: f32) -> Self {
self.fusion.rrf_k = k;
self
}
pub fn with_rerank(mut self, model: &str, top_n: usize) -> Self {
self.rerank = Some(RerankConfig {
model: model.to_string(),
top_n,
batch_size: 32,
});
self
}
pub fn limit(mut self, limit: usize) -> Self {
self.limit = limit;
self
}
pub fn min_score(mut self, score: f32) -> Self {
self.min_score = Some(score);
self
}
}
/// Vector-similarity part of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct VectorQueryComponent {
/// Query embedding. The executor skips vector search entirely when this is empty.
pub embedding: Vec<f32>,
/// Multiplier applied to the vector contribution by the weighted fusion methods.
pub weight: f32,
/// Search-breadth setting (the builders set 100).
// NOTE(review): not read by the executor in the code visible here; presumably
// an HNSW-style `ef_search` knob — confirm against the vector index.
pub ef_search: usize,
}
/// Keyword part of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct LexicalQueryComponent {
/// Free-text query string.
pub query: String,
/// Multiplier applied to the lexical contribution by the weighted fusion
/// methods. The executor only runs the lexical search when this is `> 0.0`.
pub weight: f32,
/// Fields to search. They are forwarded to `LexicalIndex::search`, which
/// currently ignores them.
pub fields: Vec<String>,
}
/// A single metadata predicate of the form `field <op> value`.
///
/// A candidate that lacks `field` in its metadata does not match.
#[derive(Debug, Clone)]
pub struct MetadataFilter {
/// Metadata key to look up on each candidate.
pub field: String,
/// Comparison applied between the candidate's value and `value`.
pub op: FilterOp,
/// Right-hand operand of the comparison.
pub value: SochValue,
}
/// Comparison operators usable in a [`MetadataFilter`].
///
/// The ordering operators only succeed when both operands are the same kind of
/// value (int/int, uint/uint, float/float or text/text); any other pairing
/// fails the filter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOp {
/// Document value equals the filter value.
Eq,
/// Document value differs from the filter value.
Ne,
/// Document value is strictly greater than the filter value.
Gt,
/// Document value is greater than or equal to the filter value.
Gte,
/// Document value is strictly less than the filter value.
Lt,
/// Document value is less than or equal to the filter value.
Lte,
/// Text document value contains the filter text as a substring, or array
/// document value contains the filter value as an element.
Contains,
/// Document value is an element of the filter value, which must be an array.
In,
}
/// Settings controlling how vector and lexical results are merged.
#[derive(Debug, Clone)]
pub struct FusionConfig {
/// Which fusion formula to apply.
pub method: FusionMethod,
/// Constant added to the zero-based rank in the reciprocal-rank-fusion denominator.
pub rrf_k: f32,
/// Score-normalization flag.
// NOTE(review): not read by the fusion code visible here — confirm whether
// normalization is implemented elsewhere.
pub normalize: bool,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
method: FusionMethod::Rrf,
rrf_k: 60.0,
normalize: true,
}
}
}
/// Formula used to merge the vector and lexical signals of a candidate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
/// Reciprocal rank fusion: each component adds `weight / (rrf_k + rank)`,
/// with a zero-based rank.
Rrf,
/// Sum of `weight * score` over the components that matched.
WeightedSum,
/// Larger of the two `weight * score` values (a missing component counts as 0).
Max,
/// As implemented by the executor: the unweighted mean of the raw scores of
/// the components that matched.
Rsf,
}
/// Settings for the optional rerank stage that runs after fusion.
#[derive(Debug, Clone)]
pub struct RerankConfig {
/// Name of the reranking model.
// NOTE(review): the rerank step visible in this file does not consult the
// model name — confirm where it is used.
pub model: String,
/// How many of the best fused candidates are passed through reranking.
pub top_n: usize,
/// Batch size (the builder sets 32).
// NOTE(review): not read by the rerank step visible in this file.
pub batch_size: usize,
}
/// A query paired with an ordered list of execution steps and a cost estimate.
// NOTE(review): nothing in the code visible here builds or consumes a plan;
// the unit of `estimated_cost` is not established here — confirm with the planner.
#[derive(Debug, Clone)]
pub struct HybridExecutionPlan {
/// The query this plan was derived from.
pub query: HybridQuery,
/// Steps of the plan, in order.
pub steps: Vec<ExecutionStep>,
/// Estimated cost of running the plan.
pub estimated_cost: f64,
}
/// One stage of a [`HybridExecutionPlan`].
// NOTE(review): no code visible here constructs or interprets these variants;
// the descriptions below follow the variant and field names — confirm against
// the planner/executor that uses them.
#[derive(Debug, Clone)]
pub enum ExecutionStep {
/// Vector similarity search over `collection`.
VectorSearch {
collection: String,
ef_search: usize,
weight: f32,
},
/// Keyword search for `query` over `fields` of `collection`.
LexicalSearch {
collection: String,
query: String,
fields: Vec<String>,
weight: f32,
},
/// Metadata filtering with the given predicates.
PreFilter {
filters: Vec<MetadataFilter>,
},
/// Score fusion with the given method and RRF constant.
Fusion {
method: FusionMethod,
rrf_k: f32,
},
/// Reranking of the best `top_n` candidates with `model`.
Rerank {
model: String,
top_n: usize,
},
/// Truncation to `count` results with an optional score floor.
Limit {
count: usize,
min_score: Option<f32>,
},
/// Redaction of the listed fields using `method`.
Redact {
fields: Vec<String>,
method: RedactionMethod,
},
}
/// How a redacted field is treated by an [`ExecutionStep::Redact`] step.
// NOTE(review): redaction is not implemented in the code visible here; the
// variant meanings below are read off the names — confirm.
#[derive(Debug, Clone)]
pub enum RedactionMethod {
/// Carries a replacement string.
Replace(String),
/// Presumably masks the value.
Mask,
/// Presumably drops the field.
Remove,
/// Presumably substitutes a hash of the value.
Hash,
}
/// Runs [`HybridQuery`] values against a vector index and a lexical index.
pub struct HybridQueryExecutor<V: VectorIndex> {
// Shared handle to the vector index queried via `search_by_embedding`.
vector_index: Arc<V>,
// Shared handle to the keyword index queried via `LexicalIndex::search`.
lexical_index: Arc<LexicalIndex>,
}
impl<V: VectorIndex> HybridQueryExecutor<V> {
    /// Creates an executor over the given shared vector and lexical indexes.
    pub fn new(vector_index: Arc<V>, lexical_index: Arc<LexicalIndex>) -> Self {
        Self {
            vector_index,
            lexical_index,
        }
    }

    /// Executes `query` and returns the ranked results together with statistics.
    ///
    /// Pipeline:
    /// 1. Vector search, when a vector component with a non-empty embedding is set.
    /// 2. Lexical search, when a lexical component with a weight `> 0.0` is set.
    /// 3. Metadata filtering; every filter must match. Candidates found only by
    ///    the lexical index carry no metadata, so any filter excludes them.
    /// 4. Score fusion, then sorting by descending fused score. Ties are broken
    ///    by ascending id so the output order is deterministic.
    /// 5. Optional reranking, `min_score` cut-off and truncation to `limit`.
    ///
    /// # Errors
    /// Returns `HybridQueryError::VectorSearchError` when the vector index fails,
    /// and forwards any error from the lexical index.
    pub fn execute(&self, query: &HybridQuery) -> Result<HybridQueryResult, HybridQueryError> {
        let mut candidates: HashMap<String, CandidateDoc> = HashMap::new();
        // Over-fetch so filtering and fusion still leave enough candidates;
        // saturating so an extreme `limit` cannot overflow.
        let overfetch = query.limit.saturating_mul(3).max(100);
        let mut vector_candidates: usize = 0;
        let mut lexical_candidates: usize = 0;

        if let Some(vector) = &query.vector {
            if !vector.embedding.is_empty() {
                let results = self
                    .vector_index
                    .search_by_embedding(&query.collection, &vector.embedding, overfetch, None)
                    .map_err(HybridQueryError::VectorSearchError)?;
                for (rank, result) in results.iter().enumerate() {
                    vector_candidates += 1;
                    let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
                        Self::new_candidate(&result.id, &result.content, result.metadata.clone())
                    });
                    entry.vector_rank = Some(rank);
                    entry.vector_score = Some(result.score);
                }
            }
        }

        if let Some(lexical) = &query.lexical {
            // A zero weight marks a placeholder component (see
            // `HybridQuery::with_vector_text`); searching would contribute nothing.
            if lexical.weight > 0.0 {
                let results = self.lexical_index.search(
                    &query.collection,
                    &lexical.query,
                    &lexical.fields,
                    overfetch,
                )?;
                lexical_candidates = results.len();
                for (rank, result) in results.iter().enumerate() {
                    let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
                        Self::new_candidate(&result.id, &result.content, HashMap::new())
                    });
                    entry.lexical_rank = Some(rank);
                    entry.lexical_score = Some(result.score);
                }
            }
        }

        let filtered: Vec<CandidateDoc> = candidates
            .into_values()
            .filter(|doc| self.matches_filters(doc, &query.filters))
            .collect();
        // Number of candidates that survived the metadata filters.
        let filtered_candidates = filtered.len();

        let fusion_started = std::time::Instant::now();
        let mut fused = self.fuse_scores(filtered, query)?;
        fused.sort_by(Self::compare_ranked);
        let fusion_time_us = Self::elapsed_micros(fusion_started);

        let mut rerank_time_us = 0;
        if let Some(rerank) = &query.rerank {
            let rerank_started = std::time::Instant::now();
            // The lexical query text doubles as the rerank query; empty when absent.
            let rerank_text = query
                .lexical
                .as_ref()
                .map(|lexical| lexical.query.as_str())
                .unwrap_or("");
            fused = self.rerank(&fused, rerank_text, rerank)?;
            rerank_time_us = Self::elapsed_micros(rerank_started);
        }

        if let Some(min) = query.min_score {
            fused.retain(|doc| doc.fused_score >= min);
        }
        fused.truncate(query.limit);

        let results: Vec<HybridSearchResult> = fused
            .into_iter()
            .map(|doc| HybridSearchResult {
                id: doc.id,
                score: doc.fused_score,
                content: doc.content,
                metadata: doc.metadata,
                vector_score: doc.vector_score,
                lexical_score: doc.lexical_score,
            })
            .collect();

        Ok(HybridQueryResult {
            results,
            query: query.clone(),
            stats: HybridQueryStats {
                vector_candidates,
                lexical_candidates,
                filtered_candidates,
                fusion_time_us,
                rerank_time_us,
            },
        })
    }

    /// Builds a candidate that has not been scored by either index yet.
    fn new_candidate(id: &str, content: &str, metadata: HashMap<String, SochValue>) -> CandidateDoc {
        CandidateDoc {
            id: id.to_string(),
            content: content.to_string(),
            metadata,
            vector_rank: None,
            vector_score: None,
            lexical_rank: None,
            lexical_score: None,
            fused_score: 0.0,
        }
    }

    /// Orders candidates by descending fused score, then ascending id.
    ///
    /// NaN scores sort after every real score. This keeps the comparator a
    /// total order, which `sort_by` requires, and the id tie-break makes the
    /// order independent of hash-map iteration order.
    fn compare_ranked(a: &CandidateDoc, b: &CandidateDoc) -> Ordering {
        let by_score = match (a.fused_score.is_nan(), b.fused_score.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            (false, false) => b
                .fused_score
                .partial_cmp(&a.fused_score)
                .unwrap_or(Ordering::Equal),
        };
        by_score.then_with(|| a.id.cmp(&b.id))
    }

    /// Microseconds elapsed since `started`, clamped to `u64::MAX`.
    fn elapsed_micros(started: std::time::Instant) -> u64 {
        started.elapsed().as_micros().min(u128::from(u64::MAX)) as u64
    }

    /// Returns `true` when `doc` satisfies every filter.
    ///
    /// A filter on a field that is missing from the document's metadata fails.
    fn matches_filters(&self, doc: &CandidateDoc, filters: &[MetadataFilter]) -> bool {
        filters.iter().all(|filter| {
            doc.metadata
                .get(&filter.field)
                .map_or(false, |value| self.match_filter(value, &filter.op, &filter.value))
        })
    }

    /// Evaluates `doc_value <op> filter_value`.
    fn match_filter(&self, doc_value: &SochValue, op: &FilterOp, filter_value: &SochValue) -> bool {
        match op {
            FilterOp::Eq => doc_value == filter_value,
            FilterOp::Ne => doc_value != filter_value,
            FilterOp::Gt => self.compare_values(doc_value, filter_value) == Some(Ordering::Greater),
            FilterOp::Gte => matches!(
                self.compare_values(doc_value, filter_value),
                Some(Ordering::Greater | Ordering::Equal)
            ),
            FilterOp::Lt => self.compare_values(doc_value, filter_value) == Some(Ordering::Less),
            FilterOp::Lte => matches!(
                self.compare_values(doc_value, filter_value),
                Some(Ordering::Less | Ordering::Equal)
            ),
            FilterOp::Contains => self.value_contains(doc_value, filter_value),
            FilterOp::In => self.value_in_set(doc_value, filter_value),
        }
    }

    /// Orders two values of the same kind; `None` for mixed kinds, for kinds
    /// without an ordering here, and for float comparisons involving NaN.
    fn compare_values(&self, a: &SochValue, b: &SochValue) -> Option<Ordering> {
        match (a, b) {
            (SochValue::Int(a), SochValue::Int(b)) => Some(a.cmp(b)),
            (SochValue::UInt(a), SochValue::UInt(b)) => Some(a.cmp(b)),
            (SochValue::Float(a), SochValue::Float(b)) => a.partial_cmp(b),
            (SochValue::Text(a), SochValue::Text(b)) => Some(a.cmp(b)),
            _ => None,
        }
    }

    /// Substring test for text/text, element test for an array document value.
    fn value_contains(&self, doc_value: &SochValue, search_value: &SochValue) -> bool {
        match (doc_value, search_value) {
            (SochValue::Text(text), SochValue::Text(search)) => text.contains(search.as_str()),
            (SochValue::Array(arr), _) => arr.contains(search_value),
            _ => false,
        }
    }

    /// `true` when `set_value` is an array that contains `doc_value`.
    fn value_in_set(&self, doc_value: &SochValue, set_value: &SochValue) -> bool {
        if let SochValue::Array(arr) = set_value {
            arr.contains(doc_value)
        } else {
            false
        }
    }

    /// Computes `fused_score` for every candidate according to `query.fusion`.
    ///
    /// Ranks are zero-based. A component that did not return the document
    /// contributes nothing. `Rsf` ignores the component weights and averages
    /// the raw scores that are present.
    fn fuse_scores(
        &self,
        candidates: Vec<CandidateDoc>,
        query: &HybridQuery,
    ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
        let vector_weight = query.vector.as_ref().map(|v| v.weight).unwrap_or(0.0);
        let lexical_weight = query.lexical.as_ref().map(|l| l.weight).unwrap_or(0.0);
        let mut fused = candidates;
        match query.fusion.method {
            FusionMethod::Rrf => {
                for doc in &mut fused {
                    let mut score = 0.0;
                    if let Some(rank) = doc.vector_rank {
                        score += vector_weight / (query.fusion.rrf_k + rank as f32);
                    }
                    if let Some(rank) = doc.lexical_rank {
                        score += lexical_weight / (query.fusion.rrf_k + rank as f32);
                    }
                    doc.fused_score = score;
                }
            }
            FusionMethod::WeightedSum => {
                for doc in &mut fused {
                    let mut score = 0.0;
                    if let Some(s) = doc.vector_score {
                        score += vector_weight * s;
                    }
                    if let Some(s) = doc.lexical_score {
                        score += lexical_weight * s;
                    }
                    doc.fused_score = score;
                }
            }
            FusionMethod::Max => {
                for doc in &mut fused {
                    let v_score = doc.vector_score.map(|s| vector_weight * s).unwrap_or(0.0);
                    let l_score = doc.lexical_score.map(|s| lexical_weight * s).unwrap_or(0.0);
                    doc.fused_score = v_score.max(l_score);
                }
            }
            FusionMethod::Rsf => {
                for doc in &mut fused {
                    let mut score = 0.0;
                    let mut count = 0;
                    if let Some(s) = doc.vector_score {
                        score += s;
                        count += 1;
                    }
                    if let Some(s) = doc.lexical_score {
                        score += s;
                        count += 1;
                    }
                    doc.fused_score = if count > 0 { score / count as f32 } else { 0.0 };
                }
            }
        }
        Ok(fused)
    }

    /// Boosts the first `config.top_n` candidates by term overlap with `query`
    /// and re-sorts them.
    ///
    /// Each whitespace-separated query term that also occurs in the document
    /// content (exact, case-sensitive match) adds `0.01` to the fused score.
    /// `candidates` must already be sorted by descending score. Boosts are
    /// never negative, so re-sorting the boosted head keeps the whole list in
    /// descending order; candidates past `top_n` are returned unchanged.
    fn rerank(
        &self,
        candidates: &[CandidateDoc],
        query: &str,
        config: &RerankConfig,
    ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
        let mut reranked = candidates.to_vec();
        let head_len = config.top_n.min(reranked.len());
        let query_terms: HashSet<&str> = query.split_whitespace().collect();

        let head = &mut reranked[..head_len];
        for doc in head.iter_mut() {
            let content_terms: HashSet<&str> = doc.content.split_whitespace().collect();
            let overlap = query_terms.intersection(&content_terms).count();
            doc.fused_score += (overlap as f32) * 0.01;
        }
        // The boost may have changed the relative order inside the head.
        head.sort_by(Self::compare_ranked);

        Ok(reranked)
    }
}
// Working record for one document while a hybrid query is being executed.
#[derive(Debug, Clone)]
struct CandidateDoc {
// Document identifier; also the key candidates are de-duplicated on.
id: String,
// Document text as returned by whichever index produced the candidate first.
content: String,
// Metadata from the vector index; empty for candidates found only lexically.
metadata: HashMap<String, SochValue>,
// Zero-based position in the vector results, if the vector search returned it.
vector_rank: Option<usize>,
// Score reported by the vector index, if any.
vector_score: Option<f32>,
// Zero-based position in the lexical results, if the lexical search returned it.
lexical_rank: Option<usize>,
// Score reported by the lexical index, if any.
lexical_score: Option<f32>,
// Combined score written by the fusion step (and adjusted by reranking).
fused_score: f32,
}
/// In-memory keyword index holding one inverted index per named collection.
pub struct LexicalIndex {
// Collection name -> inverted index, guarded by a reader/writer lock.
collections: std::sync::RwLock<HashMap<String, InvertedIndex>>,
}
// Inverted index for a single collection, with the state its BM25-style
// scoring formula reads.
struct InvertedIndex {
// Lower-cased term -> list of (document id, term frequency in that document).
postings: HashMap<String, Vec<(String, u32)>>,
// Document id -> number of whitespace-separated tokens in the document.
doc_lengths: HashMap<String, u32>,
// Document id -> original, unmodified content.
documents: HashMap<String, String>,
// Mean of `doc_lengths`, recomputed whenever a document is indexed.
avg_doc_len: f32,
// Term-frequency saturation parameter of the scoring formula (created as 1.2).
k1: f32,
// Length-normalization parameter of the scoring formula (created as 0.75).
b: f32,
}
/// One hit returned by [`LexicalIndex::search`].
#[derive(Debug, Clone)]
pub struct LexicalSearchResult {
/// Identifier of the matching document.
pub id: String,
/// Relevance score summed over the query terms that matched.
pub score: f32,
/// Original content of the document (empty if it is not stored).
pub content: String,
}
impl LexicalIndex {
pub fn new() -> Self {
Self {
collections: std::sync::RwLock::new(HashMap::new()),
}
}
pub fn create_collection(&self, name: &str) {
let mut collections = self.collections.write().unwrap();
collections.insert(name.to_string(), InvertedIndex {
postings: HashMap::new(),
doc_lengths: HashMap::new(),
documents: HashMap::new(),
avg_doc_len: 0.0,
k1: 1.2,
b: 0.75,
});
}
pub fn index_document(&self, collection: &str, id: &str, content: &str) -> Result<(), HybridQueryError> {
let mut collections = self.collections.write().unwrap();
let index = collections.get_mut(collection)
.ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
let tokens: Vec<String> = content
.split_whitespace()
.map(|t| t.to_lowercase())
.collect();
let doc_len = tokens.len() as u32;
index.doc_lengths.insert(id.to_string(), doc_len);
index.documents.insert(id.to_string(), content.to_string());
let total_len: u32 = index.doc_lengths.values().sum();
index.avg_doc_len = total_len as f32 / index.doc_lengths.len() as f32;
let mut term_freqs: HashMap<String, u32> = HashMap::new();
for token in &tokens {
*term_freqs.entry(token.clone()).or_insert(0) += 1;
}
for (term, freq) in term_freqs {
index.postings
.entry(term)
.or_insert_with(Vec::new)
.push((id.to_string(), freq));
}
Ok(())
}
pub fn search(
&self,
collection: &str,
query: &str,
_fields: &[String],
limit: usize,
) -> Result<Vec<LexicalSearchResult>, HybridQueryError> {
let collections = self.collections.read().unwrap();
let index = collections.get(collection)
.ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
let query_terms: Vec<String> = query
.split_whitespace()
.map(|t| t.to_lowercase())
.collect();
let n = index.doc_lengths.len() as f32;
let mut scores: HashMap<String, f32> = HashMap::new();
for term in &query_terms {
if let Some(postings) = index.postings.get(term) {
let df = postings.len() as f32;
let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
for (doc_id, tf) in postings {
let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
let tf = *tf as f32;
let score = idf * (tf * (index.k1 + 1.0)) /
(tf + index.k1 * (1.0 - index.b + index.b * doc_len / index.avg_doc_len));
*scores.entry(doc_id.clone()).or_insert(0.0) += score;
}
}
}
let mut results: Vec<_> = scores.into_iter().collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
let results: Vec<LexicalSearchResult> = results
.into_iter()
.take(limit)
.map(|(id, score)| {
let content = index.documents.get(&id).cloned().unwrap_or_default();
LexicalSearchResult { id, score, content }
})
.collect();
Ok(results)
}
}
impl Default for LexicalIndex {
fn default() -> Self {
Self::new()
}
}
/// One ranked document returned by a hybrid query.
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
/// Identifier of the document.
pub id: String,
/// Final fused score the results are ordered by.
pub score: f32,
/// Document text.
pub content: String,
/// Document metadata; empty when the document was found only by the lexical index.
pub metadata: HashMap<String, SochValue>,
/// Score from the vector index, when the vector search returned this document.
pub vector_score: Option<f32>,
/// Score from the lexical index, when the lexical search returned this document.
pub lexical_score: Option<f32>,
}
/// Outcome of executing a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct HybridQueryResult {
/// Ranked hits, already cut to the query's limit.
pub results: Vec<HybridSearchResult>,
/// Copy of the query that produced these results.
pub query: HybridQuery,
/// Execution statistics.
pub stats: HybridQueryStats,
}
/// Counters and timings describing one hybrid query execution.
///
/// All fields default to zero.
#[derive(Debug, Clone, Default)]
pub struct HybridQueryStats {
/// Candidate count attributed to the vector search.
pub vector_candidates: usize,
/// Candidate count attributed to the lexical search.
pub lexical_candidates: usize,
/// Candidate count associated with metadata filtering.
// NOTE(review): the name does not say whether this counts candidates kept or
// removed by the filters — confirm with the executor that fills it in.
pub filtered_candidates: usize,
/// Time spent in score fusion, in microseconds.
pub fusion_time_us: u64,
/// Time spent in reranking, in microseconds.
pub rerank_time_us: u64,
}
/// Errors produced while indexing documents or executing a hybrid query.
///
/// Every variant carries a human-readable detail string.
#[derive(Debug, Clone)]
pub enum HybridQueryError {
/// The named collection does not exist in the lexical index.
CollectionNotFound(String),
/// The vector index reported an error; the payload is its message.
VectorSearchError(String),
/// Failure in the lexical search stage.
LexicalSearchError(String),
/// Failure in the metadata filtering stage.
FilterError(String),
/// Failure in the rerank stage.
RerankError(String),
}
impl std::fmt::Display for HybridQueryError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::CollectionNotFound(name) => write!(f, "Collection not found: {}", name),
Self::VectorSearchError(msg) => write!(f, "Vector search error: {}", msg),
Self::LexicalSearchError(msg) => write!(f, "Lexical search error: {}", msg),
Self::FilterError(msg) => write!(f, "Filter error: {}", msg),
Self::RerankError(msg) => write!(f, "Rerank error: {}", msg),
}
}
}
// Marks `HybridQueryError` as a standard error type. The empty impl keeps the
// default trait methods, so no underlying `source()` is reported.
impl std::error::Error for HybridQueryError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent builder records every component it is given.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("documents")
            .with_vector(vec![0.1, 0.2, 0.3], 0.7)
            .with_lexical("search query", 0.3)
            .filter_eq("category", SochValue::Text("tech".to_string()))
            .with_fusion(FusionMethod::Rrf)
            .with_rerank("cross-encoder", 20)
            .limit(10);
        assert_eq!(query.collection, "documents");
        assert!(query.vector.is_some());
        assert!(query.lexical.is_some());
        assert_eq!(query.filters.len(), 1);
        assert_eq!(query.limit, 10);
    }

    /// Only documents containing at least one query term are returned.
    #[test]
    fn test_lexical_index_bm25() {
        let index = LexicalIndex::new();
        index.create_collection("test");
        index.index_document("test", "doc1", "the quick brown fox").unwrap();
        index.index_document("test", "doc2", "the lazy dog sleeps").unwrap();
        index.index_document("test", "doc3", "quick fox jumps over the lazy dog").unwrap();
        let results = index.search("test", "quick fox", &[], 10).unwrap();
        assert!(!results.is_empty());
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        assert!(ids.contains(&"doc1") || ids.contains(&"doc3"));
        assert!(!ids.contains(&"doc2"));
    }

    /// Sanity check of the RRF arithmetic with zero-based ranks 0 and 5.
    #[test]
    fn test_rrf_fusion() {
        let k = 60.0;
        let vector_weight = 0.7;
        let lexical_weight = 0.3;
        let score = vector_weight / (k + 0.0) + lexical_weight / (k + 5.0);
        assert!(score > 0.01 && score < 0.02);
    }

    /// A candidate whose metadata satisfies both filters: the `status` value
    /// equals the first filter's value and `count` reaches the second
    /// filter's lower bound.
    #[test]
    fn test_filter_matching() {
        let filters = vec![
            MetadataFilter {
                field: "status".to_string(),
                op: FilterOp::Eq,
                value: SochValue::Text("active".to_string()),
            },
            MetadataFilter {
                field: "count".to_string(),
                op: FilterOp::Gte,
                value: SochValue::Int(10),
            },
        ];
        let mut metadata = HashMap::new();
        metadata.insert("status".to_string(), SochValue::Text("active".to_string()));
        metadata.insert("count".to_string(), SochValue::Int(15));
        let doc = CandidateDoc {
            id: "test".to_string(),
            content: "test content".to_string(),
            metadata,
            vector_rank: None,
            vector_score: None,
            lexical_rank: None,
            lexical_score: None,
            fused_score: 0.0,
        };

        // A filter on an absent field can never match, so check presence first.
        for filter in &filters {
            assert!(
                doc.metadata.contains_key(&filter.field),
                "missing metadata field {}",
                filter.field
            );
        }
        assert_eq!(doc.metadata.get("status"), Some(&filters[0].value));
        // Fail loudly instead of silently passing when `count` is missing or
        // is not an integer.
        match (doc.metadata.get("count"), &filters[1].value) {
            (Some(SochValue::Int(count)), SochValue::Int(min)) => assert!(count >= min),
            other => panic!("expected integer count and bound, got {:?}", other),
        }
    }
}