use std::collections::HashMap;
use std::sync::Arc;
use crate::candidate_gate::AllowedSet;
use crate::filter_ir::{AuthScope, FilterIR};
use crate::filtered_vector_search::ScoredResult;
use crate::namespace::NamespaceScope;
/// Strategy used to merge the vector and BM25 candidate lists into one ranking.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
/// Reciprocal rank fusion: every list a document appears in contributes
/// `1 / (k + rank + 1)`, where `rank` is the zero-based position in that list.
Rrf { k: f32 },
/// Weighted sum of each modality's min-max-normalized scores.
Linear { vector_weight: f32, bm25_weight: f32 },
/// Highest min-max-normalized score a document reaches in either modality.
Max,
/// Only documents from the `primary` modality's list are kept. Those also
/// present in the other list take that list's raw score; the rest are scored
/// by their negated position in the primary list.
Cascade { primary: Modality },
}
/// Retrieval modality that produced a candidate list.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Modality {
/// Embedding-based search (see `VectorExecutor`).
Vector,
/// Text search (see `Bm25Executor`).
Bm25,
}
impl Default for FusionMethod {
fn default() -> Self {
Self::Rrf { k: 60.0 }
}
}
/// Tunable parameters for fusing hybrid search results.
#[derive(Debug, Clone)]
pub struct FusionConfig {
/// Fusion strategy applied by `FusionEngine::fuse`.
pub method: FusionMethod,
/// Candidate budget passed as `k` to each modality's executor before fusion.
pub candidates_per_modality: usize,
/// Maximum number of fused results returned.
pub final_k: usize,
/// When set, fused results whose score is below this value are dropped.
pub min_score: Option<f32>,
}
impl Default for FusionConfig {
fn default() -> Self {
Self {
method: FusionMethod::default(),
candidates_per_modality: 100,
final_k: 10,
min_score: None,
}
}
}
/// A hybrid search request: optional vector and BM25 sub-queries evaluated
/// under a namespace scope and a caller-supplied filter.
#[derive(Debug, Clone)]
pub struct UnifiedHybridQuery {
/// Namespace the query is confined to; combined with `filter` by
/// `effective_filter`.
pub namespace: NamespaceScope,
/// Vector sub-query, or `None` when vector search is not requested.
pub vector_query: Option<VectorQuerySpec>,
/// BM25 sub-query, or `None` when text search is not requested.
pub bm25_query: Option<Bm25QuerySpec>,
/// Caller-supplied filter; `UnifiedHybridQuery::new` starts from `FilterIR::all()`.
pub filter: FilterIR,
/// Fusion settings carried with the query.
/// NOTE(review): `UnifiedHybridExecutor::execute` in this file fuses with the
/// executor's own config and never reads this field — confirm which is intended.
pub fusion_config: FusionConfig,
}
/// Parameters of the vector half of a hybrid query.
#[derive(Debug, Clone)]
pub struct VectorQuerySpec {
/// Query embedding handed to `VectorExecutor::search`.
pub embedding: Vec<f32>,
/// Search-breadth knob (`with_vector` sets 100). Presumably an HNSW-style
/// `ef` parameter — TODO confirm; nothing in this file reads it.
pub ef_search: usize,
}
/// Parameters of the BM25 half of a hybrid query.
#[derive(Debug, Clone)]
pub struct Bm25QuerySpec {
/// Query text handed to `Bm25Executor::search`.
pub text: String,
/// Field names associated with the query (`with_bm25` sets `["content"]`).
/// NOTE(review): nothing in this file reads this list — confirm where it is
/// consumed.
pub fields: Vec<String>,
}
impl UnifiedHybridQuery {
    /// Starts a query scoped to `namespace` with no sub-queries, the
    /// match-everything filter from `FilterIR::all()`, and the default fusion
    /// settings.
    pub fn new(namespace: NamespaceScope) -> Self {
        let filter = FilterIR::all();
        let fusion_config = FusionConfig::default();
        Self {
            namespace,
            vector_query: None,
            bm25_query: None,
            filter,
            fusion_config,
        }
    }

    /// Attaches a vector sub-query for `embedding` with `ef_search` set to 100,
    /// replacing any previous vector sub-query.
    pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
        let spec = VectorQuerySpec {
            embedding,
            ef_search: 100,
        };
        self.vector_query = Some(spec);
        self
    }

    /// Attaches a BM25 sub-query for `text` over the `"content"` field,
    /// replacing any previous BM25 sub-query.
    pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
        let spec = Bm25QuerySpec {
            text: text.into(),
            fields: vec![String::from("content")],
        };
        self.bm25_query = Some(spec);
        self
    }

    /// Replaces the caller-supplied filter.
    pub fn with_filter(mut self, filter: FilterIR) -> Self {
        self.filter = filter;
        self
    }

    /// Replaces the fusion settings carried by the query.
    pub fn with_fusion(mut self, config: FusionConfig) -> Self {
        self.fusion_config = config;
        self
    }

    /// Returns the namespace's filter AND-ed with a clone of the
    /// caller-supplied filter.
    pub fn effective_filter(&self) -> FilterIR {
        let namespace_filter = self.namespace.to_filter_ir();
        let user_filter = self.filter.clone();
        namespace_filter.and(user_filter)
    }
}
/// One modality's candidate list, tagged with whether it was already
/// restricted to the allowed set.
#[derive(Debug)]
pub struct FilteredCandidates {
/// Modality that produced `results`.
pub modality: Modality,
/// Scored candidates. Rank-based fusion treats the position in this vector
/// as the rank, so the list is expected to be ordered best-first.
pub results: Vec<ScoredResult>,
/// `true` when the list was filtered before scoring; `FusionEngine::fuse`
/// debug-asserts this flag.
pub filtered: bool,
}
impl FilteredCandidates {
pub fn from_vector(results: Vec<ScoredResult>) -> Self {
Self {
modality: Modality::Vector,
results,
filtered: true,
}
}
pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
Self {
modality: Modality::Bm25,
results,
filtered: true,
}
}
}
/// Combines per-modality candidate lists into a single ranked result list.
pub struct FusionEngine {
// Fusion method, result cap (`final_k`) and optional score floor used by `fuse`.
config: FusionConfig,
}
impl FusionEngine {
pub fn new(config: FusionConfig) -> Self {
Self { config }
}
pub fn fuse(
&self,
vector_candidates: Option<FilteredCandidates>,
bm25_candidates: Option<FilteredCandidates>,
) -> FusionResult {
if let Some(ref vc) = vector_candidates {
debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
}
if let Some(ref bc) = bm25_candidates {
debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
}
match self.config.method {
FusionMethod::Rrf { k } => self.fuse_rrf(vector_candidates, bm25_candidates, k),
FusionMethod::Linear { vector_weight, bm25_weight } => {
self.fuse_linear(vector_candidates, bm25_candidates, vector_weight, bm25_weight)
}
FusionMethod::Max => self.fuse_max(vector_candidates, bm25_candidates),
FusionMethod::Cascade { primary } => {
self.fuse_cascade(vector_candidates, bm25_candidates, primary)
}
}
}
fn fuse_rrf(
&self,
vector: Option<FilteredCandidates>,
bm25: Option<FilteredCandidates>,
k: f32,
) -> FusionResult {
let mut scores: HashMap<u64, f32> = HashMap::new();
if let Some(vc) = vector {
for (rank, result) in vc.results.iter().enumerate() {
let rrf_score = 1.0 / (k + rank as f32 + 1.0);
*scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
}
}
if let Some(bc) = bm25 {
for (rank, result) in bc.results.iter().enumerate() {
let rrf_score = 1.0 / (k + rank as f32 + 1.0);
*scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
}
}
self.collect_top_k(scores)
}
fn fuse_linear(
&self,
vector: Option<FilteredCandidates>,
bm25: Option<FilteredCandidates>,
vector_weight: f32,
bm25_weight: f32,
) -> FusionResult {
let mut scores: HashMap<u64, f32> = HashMap::new();
if let Some(vc) = vector {
let normalized = self.normalize_scores(&vc.results);
for (doc_id, score) in normalized {
*scores.entry(doc_id).or_insert(0.0) += score * vector_weight;
}
}
if let Some(bc) = bm25 {
let normalized = self.normalize_scores(&bc.results);
for (doc_id, score) in normalized {
*scores.entry(doc_id).or_insert(0.0) += score * bm25_weight;
}
}
self.collect_top_k(scores)
}
fn fuse_max(
&self,
vector: Option<FilteredCandidates>,
bm25: Option<FilteredCandidates>,
) -> FusionResult {
let mut scores: HashMap<u64, f32> = HashMap::new();
if let Some(vc) = vector {
let normalized = self.normalize_scores(&vc.results);
for (doc_id, score) in normalized {
let entry = scores.entry(doc_id).or_insert(0.0);
*entry = entry.max(score);
}
}
if let Some(bc) = bm25 {
let normalized = self.normalize_scores(&bc.results);
for (doc_id, score) in normalized {
let entry = scores.entry(doc_id).or_insert(0.0);
*entry = entry.max(score);
}
}
self.collect_top_k(scores)
}
fn fuse_cascade(
&self,
vector: Option<FilteredCandidates>,
bm25: Option<FilteredCandidates>,
primary: Modality,
) -> FusionResult {
let (primary_candidates, secondary_candidates) = match primary {
Modality::Vector => (vector, bm25),
Modality::Bm25 => (bm25, vector),
};
let primary_ids: std::collections::HashSet<u64> = primary_candidates
.as_ref()
.map(|c| c.results.iter().map(|r| r.doc_id).collect())
.unwrap_or_default();
let mut scores: HashMap<u64, f32> = HashMap::new();
if let Some(sc) = secondary_candidates {
for result in &sc.results {
if primary_ids.contains(&result.doc_id) {
scores.insert(result.doc_id, result.score);
}
}
}
if let Some(pc) = primary_candidates {
for (rank, result) in pc.results.iter().enumerate() {
scores.entry(result.doc_id).or_insert(-(rank as f32));
}
}
self.collect_top_k(scores)
}
fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
if results.is_empty() {
return vec![];
}
let min = results.iter().map(|r| r.score).fold(f32::INFINITY, f32::min);
let max = results.iter().map(|r| r.score).fold(f32::NEG_INFINITY, f32::max);
let range = max - min;
if range == 0.0 {
return results.iter().map(|r| (r.doc_id, 1.0)).collect();
}
results.iter()
.map(|r| (r.doc_id, (r.score - min) / range))
.collect()
}
fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
let mut results: Vec<ScoredResult> = scores
.into_iter()
.map(|(doc_id, score)| ScoredResult::new(doc_id, score))
.collect();
results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
if let Some(min) = self.config.min_score {
results.retain(|r| r.score >= min);
}
results.truncate(self.config.final_k);
FusionResult {
results,
method: self.config.method,
}
}
}
/// Output of a fusion pass.
#[derive(Debug)]
pub struct FusionResult {
/// Fused results, best first, capped at the configured `final_k`.
pub results: Vec<ScoredResult>,
/// Fusion method that produced `results`.
pub method: FusionMethod,
}
/// Vector-search backend used by `UnifiedHybridExecutor`.
pub trait VectorExecutor {
/// Searches for `query` with a candidate budget of `k`, restricted to `allowed`.
///
/// The caller in this file marks the returned list as pre-filtered without
/// re-checking it, and rank-based fusion uses list position as rank, so
/// implementations are expected to return only documents in `allowed`,
/// ordered best-first.
fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
}
/// BM25 text-search backend used by `UnifiedHybridExecutor`.
pub trait Bm25Executor {
/// Searches for `query` with a candidate budget of `k`, restricted to `allowed`.
///
/// The caller in this file marks the returned list as pre-filtered without
/// re-checking it, and rank-based fusion uses list position as rank, so
/// implementations are expected to return only documents in `allowed`,
/// ordered best-first.
fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
}
/// Runs the vector and BM25 halves of a hybrid query against the same allowed
/// set and fuses their results.
pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
// Backend for the vector sub-query.
vector_executor: Arc<V>,
// Backend for the BM25 sub-query.
bm25_executor: Arc<B>,
// Fuses the per-modality results; also supplies the candidate budget.
fusion_engine: FusionEngine,
}
impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
    /// Builds an executor over the two backends, fusing with `fusion_config`.
    pub fn new(
        vector_executor: Arc<V>,
        bm25_executor: Arc<B>,
        fusion_config: FusionConfig,
    ) -> Self {
        let fusion_engine = FusionEngine::new(fusion_config);
        Self {
            vector_executor,
            bm25_executor,
            fusion_engine,
        }
    }

    /// Executes `query` restricted to `allowed_set` and returns the fused
    /// ranking.
    ///
    /// An empty allowed set short-circuits to an empty result without calling
    /// either backend. Otherwise each requested modality is searched with the
    /// executor's `candidates_per_modality` budget (vector first, then BM25)
    /// and the lists are fused with the executor's own fusion engine.
    ///
    /// NOTE(review): `_auth_scope`, `query.filter` and `query.fusion_config`
    /// are not consulted here; `allowed_set` is the only restriction applied.
    pub fn execute(
        &self,
        query: &UnifiedHybridQuery,
        _auth_scope: &AuthScope,
        allowed_set: &AllowedSet,
    ) -> FusionResult {
        let config = &self.fusion_engine.config;
        if allowed_set.is_empty() {
            return FusionResult {
                results: Vec::new(),
                method: config.method,
            };
        }

        let budget = config.candidates_per_modality;
        let vector_candidates = match &query.vector_query {
            Some(spec) => {
                let hits = self
                    .vector_executor
                    .search(&spec.embedding, budget, allowed_set);
                Some(FilteredCandidates::from_vector(hits))
            }
            None => None,
        };
        let bm25_candidates = match &query.bm25_query {
            Some(spec) => {
                let hits = self.bm25_executor.search(&spec.text, budget, allowed_set);
                Some(FilteredCandidates::from_bm25(hits))
            }
            None => None,
        };

        self.fusion_engine.fuse(vector_candidates, bm25_candidates)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc ids of a fusion result, in ranked order.
    fn ranked_ids(result: &FusionResult) -> Vec<u64> {
        result.results.iter().map(|r| r.doc_id).collect()
    }

    /// Builds an engine for `method` that returns up to five results.
    fn engine_for(method: FusionMethod) -> FusionEngine {
        FusionEngine::new(FusionConfig {
            method,
            candidates_per_modality: 10,
            final_k: 5,
            min_score: None,
        })
    }

    #[test]
    fn test_rrf_fusion() {
        let engine = engine_for(FusionMethod::Rrf { k: 60.0 });
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(4, 4.0),
            ScoredResult::new(1, 3.0),
        ]);
        let result = engine.fuse(Some(vector), Some(bm25));
        assert!(!result.results.is_empty());
        let top_ids = ranked_ids(&result);
        assert!(top_ids.contains(&1));
        assert!(top_ids.contains(&2));
        // doc 2: 1/62 + 1/61, doc 1: 1/61 + 1/63, doc 4: 1/62, doc 3: 1/63.
        assert_eq!(top_ids, vec![2, 1, 4, 3]);
    }

    #[test]
    fn test_linear_fusion() {
        let engine = engine_for(FusionMethod::Linear {
            vector_weight: 0.6,
            bm25_weight: 0.4,
        });
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0),
            ScoredResult::new(3, 5.0),
        ]);
        let result = engine.fuse(Some(vector), Some(bm25));
        assert!(!result.results.is_empty());
        // Normalized: vector {1: 1.0, 2: 0.0}, bm25 {2: 1.0, 3: 0.0}.
        assert_eq!(ranked_ids(&result), vec![1, 2, 3]);
        assert!((result.results[0].score - 0.6).abs() < 0.001);
        assert!((result.results[1].score - 0.4).abs() < 0.001);
        assert!(result.results[2].score.abs() < 0.001);
    }

    #[test]
    fn test_max_fusion() {
        let engine = engine_for(FusionMethod::Max);
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
            ScoredResult::new(3, 0.0),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(1, 9.0),
            ScoredResult::new(3, 6.0),
            ScoredResult::new(2, 0.0),
        ]);
        let result = engine.fuse(Some(vector), Some(bm25));
        // Best normalized score per doc: 1 -> 1.0, 3 -> 0.667, 2 -> 0.5.
        assert_eq!(ranked_ids(&result), vec![1, 3, 2]);
    }

    #[test]
    fn test_cascade_fusion() {
        let engine = engine_for(FusionMethod::Cascade {
            primary: Modality::Vector,
        });
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(4, 4.0),
            ScoredResult::new(3, 1.0),
        ]);
        let result = engine.fuse(Some(vector), Some(bm25));
        // Doc 4 is absent from the primary list and must not appear; doc 1
        // has no secondary score and falls back to its negated primary rank.
        assert_eq!(ranked_ids(&result), vec![2, 3, 1]);
    }

    /// Fusing with no candidate lists at all yields an empty result.
    #[test]
    fn test_empty_allowed_set() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let result = engine.fuse(None, None);
        assert!(result.results.is_empty());
        assert_eq!(result.method, FusionMethod::default());
    }

    #[test]
    fn test_score_normalization() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let results = vec![
            ScoredResult::new(1, 100.0),
            ScoredResult::new(2, 50.0),
            ScoredResult::new(3, 0.0),
        ];
        let normalized = engine.normalize_scores(&results);
        assert_eq!(normalized.len(), 3);
        let scores: HashMap<u64, f32> = normalized.into_iter().collect();
        assert!((scores[&1] - 1.0).abs() < 0.001);
        assert!((scores[&2] - 0.5).abs() < 0.001);
        assert!((scores[&3] - 0.0).abs() < 0.001);

        // A constant-score list has zero range; every entry maps to 1.0.
        let flat = engine.normalize_scores(&[ScoredResult::new(7, 3.0), ScoredResult::new(8, 3.0)]);
        assert_eq!(flat, vec![(7, 1.0), (8, 1.0)]);

        // Empty input stays empty.
        assert!(engine.normalize_scores(&[]).is_empty());
    }

    #[test]
    fn test_no_post_filter_invariant() {
        let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
        let allowed_set = AllowedSet::from_iter(allowed.iter().copied());
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(5, 0.7),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(3, 4.0),
            ScoredResult::new(8, 3.0),
        ]);
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let result = engine.fuse(Some(vector), Some(bm25));
        for doc in &result.results {
            assert!(
                allowed_set.contains(doc.doc_id),
                "INVARIANT VIOLATION: doc_id {} not in allowed set",
                doc.doc_id
            );
        }
        // The public verifier must agree with the manual check above.
        assert!(verify_no_post_filter_invariant(&result, &allowed_set).is_valid());
    }
}
/// Checks that every document in `result` is a member of `allowed_set`.
///
/// Returns [`InvariantVerification::Valid`] when all documents are allowed;
/// otherwise returns [`InvariantVerification::Violated`] listing the offending
/// doc ids in result order.
pub fn verify_no_post_filter_invariant(
    result: &FusionResult,
    allowed_set: &AllowedSet,
) -> InvariantVerification {
    let doc_ids: Vec<u64> = result
        .results
        .iter()
        .map(|doc| doc.doc_id)
        .filter(|&id| !allowed_set.contains(id))
        .collect();

    if doc_ids.is_empty() {
        InvariantVerification::Valid
    } else {
        InvariantVerification::Violated { doc_ids }
    }
}
/// Outcome of checking a fusion result against the allowed set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvariantVerification {
/// Every returned document is in the allowed set.
Valid,
/// At least one returned document is outside the allowed set; `doc_ids`
/// lists the offenders.
Violated { doc_ids: Vec<u64> },
}
impl InvariantVerification {
    /// Returns `true` when no violation was recorded.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Violated { .. } => false,
        }
    }

    /// Does nothing when the check passed.
    ///
    /// # Panics
    ///
    /// Panics with the number and list of offending doc ids when the
    /// verification recorded a violation.
    pub fn assert_valid(&self) {
        if let Self::Violated { doc_ids } = self {
            panic!(
                "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
                doc_ids.len(),
                doc_ids
            );
        }
    }
}