sochdb_query/
unified_fusion.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Unified Hybrid Fusion with Mandatory Pre-Filtering (Task 7)
16//!
17//! This module implements hybrid retrieval (vector + BM25) that **never**
18//! post-filters. The key insight is:
19//!
20//! > Both vector and BM25 executors receive the **same** AllowedSet,
21//! > produce candidates **guaranteed** within it, then fusion merges by doc_id.
22//!
23//! ## Anti-Pattern (What We Avoid)
24//!
25//! ```text
26//! BAD: vector_search() → candidates → filter → too few
27//!      bm25_search() → candidates → filter → inconsistent
28//!      fusion(unfiltered_v, unfiltered_b) → filter at end → broken!
29//! ```
30//!
31//! ## Correct Pattern
32//!
33//! ```text
34//! GOOD: compute AllowedSet from FilterIR
35//!       vector_search(query, allowed_set) → filtered_v
36//!       bm25_search(query, allowed_set) → filtered_b
37//!       fusion(filtered_v, filtered_b) → already correct!
38//! ```
39//!
40//! ## Fusion Cost
41//!
42//! With pre-filtered candidates:
//! - Merging is O(k_v + k_b) via a hash join; ranking the merged set adds a sort, O(m log m) for the m ≤ k_v + k_b distinct documents
44//! - Total work is proportional to constrained candidate sizes
45//! - No wasted scoring on disallowed documents
46
47use std::collections::HashMap;
48use std::sync::Arc;
49
50use crate::candidate_gate::AllowedSet;
51use crate::filter_ir::{AuthScope, FilterIR};
52use crate::filtered_vector_search::ScoredResult;
53use crate::namespace::NamespaceScope;
54
55// ============================================================================
56// Fusion Configuration
57// ============================================================================
58
/// Strategy used to merge per-modality candidate lists into one ranking.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion: `score(d) = Σ_m 1 / (k + rank_m(d))`, where
    /// `rank_m(d)` is the 1-based position of `d` in modality `m`'s list.
    /// Modalities that did not return `d` contribute nothing; all modalities
    /// are weighted equally (there are no per-modality weights).
    Rrf { k: f32 },

    /// Weighted sum of per-modality min-max normalized scores.
    Linear { vector_weight: f32, bm25_weight: f32 },

    /// Per-document maximum of the per-modality min-max normalized scores.
    Max,

    /// Cascade: the `primary` modality decides which documents are eligible,
    /// the other modality decides their order.
    Cascade { primary: Modality },
}

/// A single retrieval modality.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Modality {
    Vector,
    Bm25,
}

impl Default for FusionMethod {
    /// RRF with the conventional `k = 60`.
    fn default() -> Self {
        Self::Rrf { k: 60.0 }
    }
}

/// Configuration for hybrid fusion.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// How candidates from the individual modalities are combined.
    pub method: FusionMethod,

    /// Number of candidates requested from each modality's executor.
    pub candidates_per_modality: usize,

    /// Maximum number of fused results returned.
    pub final_k: usize,

    /// If set, fused results scoring below this value are dropped before the
    /// result list is truncated to `final_k`.
    pub min_score: Option<f32>,
}

impl Default for FusionConfig {
    /// RRF (`k = 60`), 100 candidates per modality, 10 final results and no
    /// score threshold.
    fn default() -> Self {
        Self {
            method: FusionMethod::default(),
            candidates_per_modality: 100,
            final_k: 10,
            min_score: None,
        }
    }
}
114
115// ============================================================================
116// Unified Hybrid Query
117// ============================================================================
118
119/// A hybrid query that enforces pre-filtering
120#[derive(Debug, Clone)]
121pub struct UnifiedHybridQuery {
122    /// Namespace scope (mandatory)
123    pub namespace: NamespaceScope,
124    
125    /// Vector query (optional)
126    pub vector_query: Option<VectorQuerySpec>,
127    
128    /// BM25 query (optional)
129    pub bm25_query: Option<Bm25QuerySpec>,
130    
131    /// User-provided filter
132    pub filter: FilterIR,
133    
134    /// Fusion configuration
135    pub fusion_config: FusionConfig,
136}
137
138/// Vector query specification
139#[derive(Debug, Clone)]
140pub struct VectorQuerySpec {
141    /// Query embedding
142    pub embedding: Vec<f32>,
143    /// ef_search for HNSW
144    pub ef_search: usize,
145}
146
147/// BM25 query specification
148#[derive(Debug, Clone)]
149pub struct Bm25QuerySpec {
150    /// Query text (will be tokenized)
151    pub text: String,
152    /// Fields to search
153    pub fields: Vec<String>,
154}
155
156impl UnifiedHybridQuery {
157    /// Create a new hybrid query (namespace is mandatory)
158    pub fn new(namespace: NamespaceScope) -> Self {
159        Self {
160            namespace,
161            vector_query: None,
162            bm25_query: None,
163            filter: FilterIR::all(),
164            fusion_config: FusionConfig::default(),
165        }
166    }
167    
168    /// Add vector search
169    pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
170        self.vector_query = Some(VectorQuerySpec {
171            embedding,
172            ef_search: 100,
173        });
174        self
175    }
176    
177    /// Add BM25 search
178    pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
179        self.bm25_query = Some(Bm25QuerySpec {
180            text: text.into(),
181            fields: vec!["content".to_string()],
182        });
183        self
184    }
185    
186    /// Add filter
187    pub fn with_filter(mut self, filter: FilterIR) -> Self {
188        self.filter = filter;
189        self
190    }
191    
192    /// Set fusion config
193    pub fn with_fusion(mut self, config: FusionConfig) -> Self {
194        self.fusion_config = config;
195        self
196    }
197    
198    /// Compute the complete effective filter
199    ///
200    /// This combines namespace scope + user filter. Auth scope is added later.
201    pub fn effective_filter(&self) -> FilterIR {
202        self.namespace.to_filter_ir().and(self.filter.clone())
203    }
204}
205
206// ============================================================================
207// Filtered Candidates
208// ============================================================================
209
210/// Candidates from a single modality (already filtered)
211#[derive(Debug)]
212pub struct FilteredCandidates {
213    /// Modality source
214    pub modality: Modality,
215    /// Scored results (doc_id, score)
216    pub results: Vec<ScoredResult>,
217    /// Whether the allowed set was applied
218    pub filtered: bool,
219}
220
221impl FilteredCandidates {
222    /// Create from vector search results
223    pub fn from_vector(results: Vec<ScoredResult>) -> Self {
224        Self {
225            modality: Modality::Vector,
226            results,
227            filtered: true,
228        }
229    }
230    
231    /// Create from BM25 results
232    pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
233        Self {
234            modality: Modality::Bm25,
235            results,
236            filtered: true,
237        }
238    }
239}
240
241// ============================================================================
242// Fusion Engine
243// ============================================================================
244
245/// The fusion engine that combines candidates from multiple modalities
246pub struct FusionEngine {
247    config: FusionConfig,
248}
249
250impl FusionEngine {
251    /// Create a new fusion engine
252    pub fn new(config: FusionConfig) -> Self {
253        Self { config }
254    }
255    
256    /// Fuse candidates from vector and BM25 search
257    ///
258    /// INVARIANT: Both candidate sets are already filtered to AllowedSet.
259    /// This function does NOT apply any additional filtering.
260    pub fn fuse(
261        &self,
262        vector_candidates: Option<FilteredCandidates>,
263        bm25_candidates: Option<FilteredCandidates>,
264    ) -> FusionResult {
265        // Validate that candidates are pre-filtered
266        if let Some(ref vc) = vector_candidates {
267            debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
268        }
269        if let Some(ref bc) = bm25_candidates {
270            debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
271        }
272        
273        match self.config.method {
274            FusionMethod::Rrf { k } => self.fuse_rrf(vector_candidates, bm25_candidates, k),
275            FusionMethod::Linear { vector_weight, bm25_weight } => {
276                self.fuse_linear(vector_candidates, bm25_candidates, vector_weight, bm25_weight)
277            }
278            FusionMethod::Max => self.fuse_max(vector_candidates, bm25_candidates),
279            FusionMethod::Cascade { primary } => {
280                self.fuse_cascade(vector_candidates, bm25_candidates, primary)
281            }
282        }
283    }
284    
285    /// Reciprocal Rank Fusion
286    ///
287    /// score(d) = Σ w_i / (k + rank_i(d))
288    fn fuse_rrf(
289        &self,
290        vector: Option<FilteredCandidates>,
291        bm25: Option<FilteredCandidates>,
292        k: f32,
293    ) -> FusionResult {
294        let mut scores: HashMap<u64, f32> = HashMap::new();
295        
296        // Add vector ranks
297        if let Some(vc) = vector {
298            for (rank, result) in vc.results.iter().enumerate() {
299                let rrf_score = 1.0 / (k + rank as f32 + 1.0);
300                *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
301            }
302        }
303        
304        // Add BM25 ranks
305        if let Some(bc) = bm25 {
306            for (rank, result) in bc.results.iter().enumerate() {
307                let rrf_score = 1.0 / (k + rank as f32 + 1.0);
308                *scores.entry(result.doc_id).or_insert(0.0) += rrf_score;
309            }
310        }
311        
312        self.collect_top_k(scores)
313    }
314    
315    /// Linear combination fusion
316    fn fuse_linear(
317        &self,
318        vector: Option<FilteredCandidates>,
319        bm25: Option<FilteredCandidates>,
320        vector_weight: f32,
321        bm25_weight: f32,
322    ) -> FusionResult {
323        let mut scores: HashMap<u64, f32> = HashMap::new();
324        
325        // Normalize and add vector scores
326        if let Some(vc) = vector {
327            let normalized = self.normalize_scores(&vc.results);
328            for (doc_id, score) in normalized {
329                *scores.entry(doc_id).or_insert(0.0) += score * vector_weight;
330            }
331        }
332        
333        // Normalize and add BM25 scores
334        if let Some(bc) = bm25 {
335            let normalized = self.normalize_scores(&bc.results);
336            for (doc_id, score) in normalized {
337                *scores.entry(doc_id).or_insert(0.0) += score * bm25_weight;
338            }
339        }
340        
341        self.collect_top_k(scores)
342    }
343    
344    /// Max-score fusion
345    fn fuse_max(
346        &self,
347        vector: Option<FilteredCandidates>,
348        bm25: Option<FilteredCandidates>,
349    ) -> FusionResult {
350        let mut scores: HashMap<u64, f32> = HashMap::new();
351        
352        if let Some(vc) = vector {
353            let normalized = self.normalize_scores(&vc.results);
354            for (doc_id, score) in normalized {
355                let entry = scores.entry(doc_id).or_insert(0.0);
356                *entry = entry.max(score);
357            }
358        }
359        
360        if let Some(bc) = bm25 {
361            let normalized = self.normalize_scores(&bc.results);
362            for (doc_id, score) in normalized {
363                let entry = scores.entry(doc_id).or_insert(0.0);
364                *entry = entry.max(score);
365            }
366        }
367        
368        self.collect_top_k(scores)
369    }
370    
371    /// Cascade fusion: use primary modality to filter, secondary to rank
372    fn fuse_cascade(
373        &self,
374        vector: Option<FilteredCandidates>,
375        bm25: Option<FilteredCandidates>,
376        primary: Modality,
377    ) -> FusionResult {
378        let (primary_candidates, secondary_candidates) = match primary {
379            Modality::Vector => (vector, bm25),
380            Modality::Bm25 => (bm25, vector),
381        };
382        
383        // Get primary doc IDs
384        let primary_ids: std::collections::HashSet<u64> = primary_candidates
385            .as_ref()
386            .map(|c| c.results.iter().map(|r| r.doc_id).collect())
387            .unwrap_or_default();
388        
389        // Score by secondary, but only docs in primary
390        let mut scores: HashMap<u64, f32> = HashMap::new();
391        
392        if let Some(sc) = secondary_candidates {
393            for result in &sc.results {
394                if primary_ids.contains(&result.doc_id) {
395                    scores.insert(result.doc_id, result.score);
396                }
397            }
398        }
399        
400        // If secondary doesn't score some docs, use primary order
401        if let Some(pc) = primary_candidates {
402            for (rank, result) in pc.results.iter().enumerate() {
403                scores.entry(result.doc_id).or_insert(-(rank as f32));
404            }
405        }
406        
407        self.collect_top_k(scores)
408    }
409    
410    /// Normalize scores to [0, 1] using min-max normalization
411    fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
412        if results.is_empty() {
413            return vec![];
414        }
415        
416        let min = results.iter().map(|r| r.score).fold(f32::INFINITY, f32::min);
417        let max = results.iter().map(|r| r.score).fold(f32::NEG_INFINITY, f32::max);
418        let range = max - min;
419        
420        if range == 0.0 {
421            return results.iter().map(|r| (r.doc_id, 1.0)).collect();
422        }
423        
424        results.iter()
425            .map(|r| (r.doc_id, (r.score - min) / range))
426            .collect()
427    }
428    
429    /// Collect top-k results from score map
430    fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
431        let mut results: Vec<ScoredResult> = scores
432            .into_iter()
433            .map(|(doc_id, score)| ScoredResult::new(doc_id, score))
434            .collect();
435        
436        // Sort by score descending
437        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
438        
439        // Apply min_score filter
440        if let Some(min) = self.config.min_score {
441            results.retain(|r| r.score >= min);
442        }
443        
444        // Truncate to k
445        results.truncate(self.config.final_k);
446        
447        FusionResult {
448            results,
449            method: self.config.method,
450        }
451    }
452}
453
454/// Result of fusion
455#[derive(Debug)]
456pub struct FusionResult {
457    /// Final ranked results
458    pub results: Vec<ScoredResult>,
459    /// Method used
460    pub method: FusionMethod,
461}
462
463// ============================================================================
464// Unified Hybrid Executor
465// ============================================================================
466
467/// Trait for vector search executor
468pub trait VectorExecutor {
469    fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
470}
471
472/// Trait for BM25 executor
473pub trait Bm25Executor {
474    fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
475}
476
477/// The unified hybrid executor
478///
479/// This is the main entry point that enforces the "no post-filtering" contract.
480pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
481    vector_executor: Arc<V>,
482    bm25_executor: Arc<B>,
483    fusion_engine: FusionEngine,
484}
485
486impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
487    /// Create a new executor
488    pub fn new(
489        vector_executor: Arc<V>,
490        bm25_executor: Arc<B>,
491        fusion_config: FusionConfig,
492    ) -> Self {
493        Self {
494            vector_executor,
495            bm25_executor,
496            fusion_engine: FusionEngine::new(fusion_config),
497        }
498    }
499    
500    /// Execute a hybrid query with mandatory pre-filtering
501    ///
502    /// # Contract
503    ///
504    /// 1. Computes `effective_filter = auth_scope ∧ query_filter`
505    /// 2. Converts to `AllowedSet` (via metadata index)
506    /// 3. Passes SAME `AllowedSet` to BOTH vector and BM25 executors
507    /// 4. Fuses already-filtered results
508    ///
509    /// NO POST-FILTERING occurs in this function.
510    pub fn execute(
511        &self,
512        query: &UnifiedHybridQuery,
513        _auth_scope: &AuthScope,
514        allowed_set: &AllowedSet, // Pre-computed from FilterIR + AuthScope
515    ) -> FusionResult {
516        // Short-circuit if empty
517        if allowed_set.is_empty() {
518            return FusionResult {
519                results: vec![],
520                method: self.fusion_engine.config.method,
521            };
522        }
523        
524        let k = self.fusion_engine.config.candidates_per_modality;
525        
526        // Vector search (with AllowedSet)
527        let vector_candidates = query.vector_query.as_ref().map(|vq| {
528            let results = self.vector_executor.search(&vq.embedding, k, allowed_set);
529            FilteredCandidates::from_vector(results)
530        });
531        
532        // BM25 search (with SAME AllowedSet)
533        let bm25_candidates = query.bm25_query.as_ref().map(|bq| {
534            let results = self.bm25_executor.search(&bq.text, k, allowed_set);
535            FilteredCandidates::from_bm25(results)
536        });
537        
538        // Fuse (both are already filtered - no post-filtering!)
539        self.fusion_engine.fuse(vector_candidates, bm25_candidates)
540    }
541}
542
543// ============================================================================
544// Tests
545// ============================================================================
546
#[cfg(test)]
mod tests {
    use super::*;

    /// Doc IDs of a fusion result in ranked order.
    fn ranked_ids(result: &FusionResult) -> Vec<u64> {
        result.results.iter().map(|r| r.doc_id).collect()
    }

    fn rrf_inputs() -> (FilteredCandidates, FilteredCandidates) {
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0), // doc 2 is in both
            ScoredResult::new(4, 4.0),
            ScoredResult::new(1, 3.0), // doc 1 is in both
        ]);
        (vector, bm25)
    }

    #[test]
    fn test_rrf_fusion() {
        let config = FusionConfig {
            method: FusionMethod::Rrf { k: 60.0 },
            candidates_per_modality: 10,
            final_k: 5,
            min_score: None,
        };
        let engine = FusionEngine::new(config);
        let (vector, bm25) = rrf_inputs();

        let result = engine.fuse(Some(vector), Some(bm25));

        // 1-based ranks with k = 60:
        //   doc 2: 1/62 (vector #2) + 1/61 (bm25 #1)  ≈ 0.03252
        //   doc 1: 1/61 (vector #1) + 1/63 (bm25 #3)  ≈ 0.03227
        //   doc 4: 1/62 (bm25 #2)                     ≈ 0.01613
        //   doc 3: 1/63 (vector #3)                   ≈ 0.01587
        assert_eq!(ranked_ids(&result), vec![2, 1, 4, 3]);
        let expected_top = 1.0 / 62.0 + 1.0 / 61.0;
        assert!((result.results[0].score - expected_top).abs() < 1e-6);
    }

    #[test]
    fn test_rrf_respects_final_k() {
        let config = FusionConfig {
            method: FusionMethod::Rrf { k: 60.0 },
            candidates_per_modality: 10,
            final_k: 2,
            min_score: None,
        };
        let engine = FusionEngine::new(config);
        let (vector, bm25) = rrf_inputs();

        let result = engine.fuse(Some(vector), Some(bm25));

        assert_eq!(ranked_ids(&result), vec![2, 1]);
    }

    #[test]
    fn test_linear_fusion() {
        let config = FusionConfig {
            method: FusionMethod::Linear {
                vector_weight: 0.6,
                bm25_weight: 0.4,
            },
            candidates_per_modality: 10,
            final_k: 5,
            min_score: None,
        };
        let engine = FusionEngine::new(config);

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0), // Different scale
            ScoredResult::new(3, 5.0),
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        // Normalized: vector {1: 1.0, 2: 0.0}, bm25 {2: 1.0, 3: 0.0}.
        // Fused:      1 = 0.6, 2 = 0.0*0.6 + 1.0*0.4 = 0.4, 3 = 0.0.
        assert_eq!(ranked_ids(&result), vec![1, 2, 3]);
        assert!((result.results[0].score - 0.6).abs() < 1e-6);
        assert!((result.results[1].score - 0.4).abs() < 1e-6);
        assert!(result.results[2].score.abs() < 1e-6);
    }

    #[test]
    fn test_min_score_applied_after_fusion() {
        let config = FusionConfig {
            method: FusionMethod::Linear {
                vector_weight: 0.6,
                bm25_weight: 0.4,
            },
            candidates_per_modality: 10,
            final_k: 5,
            min_score: Some(0.5),
        };
        let engine = FusionEngine::new(config);

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0),
            ScoredResult::new(3, 5.0),
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        // Only doc 1 (0.6) clears the 0.5 threshold.
        assert_eq!(ranked_ids(&result), vec![1]);
    }

    #[test]
    fn test_cascade_restricts_to_primary() {
        let config = FusionConfig {
            method: FusionMethod::Cascade {
                primary: Modality::Vector,
            },
            candidates_per_modality: 10,
            final_k: 10,
            min_score: None,
        };
        let engine = FusionEngine::new(config);

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(5, 0.7),
        ]);
        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(3, 4.0), // not in primary -> excluded
            ScoredResult::new(8, 3.0), // not in primary -> excluded
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        // Doc 2 is ranked by its BM25 score; docs 1 and 5 follow in vector order.
        assert_eq!(ranked_ids(&result), vec![2, 1, 5]);
    }

    #[test]
    fn test_fuse_without_candidates() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        // No modality queried = empty result
        let result = engine.fuse(None, None);
        assert!(result.results.is_empty());
    }

    #[test]
    fn test_score_normalization() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        let results = vec![
            ScoredResult::new(1, 100.0),
            ScoredResult::new(2, 50.0),
            ScoredResult::new(3, 0.0),
        ];

        let normalized = engine.normalize_scores(&results);

        // Should be normalized to [0, 1]
        assert_eq!(normalized.len(), 3);
        let scores: HashMap<u64, f32> = normalized.into_iter().collect();
        assert!((scores[&1] - 1.0).abs() < 0.001);
        assert!((scores[&2] - 0.5).abs() < 0.001);
        assert!((scores[&3] - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_no_post_filter_invariant() {
        // This test verifies the core invariant:
        // result-set ⊆ allowed-set
        //
        // If this invariant is violated, it indicates a security issue.

        let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
        let allowed_set = AllowedSet::from_iter(allowed.iter().copied());

        // Simulated executor output (already restricted to the AllowedSet)
        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9), // in allowed set
            ScoredResult::new(2, 0.8), // in allowed set
            ScoredResult::new(5, 0.7), // in allowed set
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0), // in allowed set
            ScoredResult::new(3, 4.0), // in allowed set
            ScoredResult::new(8, 3.0), // in allowed set
        ]);

        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let result = engine.fuse(Some(vector), Some(bm25));

        // Union of both inputs: {1, 2, 3, 5, 8}
        assert_eq!(result.results.len(), 5);

        // INVARIANT: Every result doc_id must be in the allowed set
        for doc in &result.results {
            assert!(
                allowed_set.contains(doc.doc_id),
                "INVARIANT VIOLATION: doc_id {} not in allowed set",
                doc.doc_id
            );
        }
        verify_no_post_filter_invariant(&result, &allowed_set).assert_valid();
    }
}
684
685// ============================================================================
686// Invariant Verification
687// ============================================================================
688
689/// Verify that a fusion result respects the no-post-filtering invariant
690/// 
691/// This function should be used in tests and optionally in debug builds
692/// to verify that the security invariant holds.
693///
694/// # Invariant
695///
696/// `∀ doc ∈ result: doc.id ∈ allowed_set`
697///
698/// This is the "monotone property" from the architecture document.
699pub fn verify_no_post_filter_invariant(
700    result: &FusionResult,
701    allowed_set: &AllowedSet,
702) -> InvariantVerification {
703    let mut violations = Vec::new();
704    
705    for doc in &result.results {
706        if !allowed_set.contains(doc.doc_id) {
707            violations.push(doc.doc_id);
708        }
709    }
710    
711    if violations.is_empty() {
712        InvariantVerification::Valid
713    } else {
714        InvariantVerification::Violated { doc_ids: violations }
715    }
716}
717
/// Outcome of checking the no-post-filter invariant.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvariantVerification {
    /// Every result document was inside the allowed set.
    Valid,
    /// These doc IDs appeared in the results but are not in the allowed set.
    Violated { doc_ids: Vec<u64> },
}

impl InvariantVerification {
    /// Returns `true` when no violation was found.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        matches!(self, Self::Valid)
    }

    /// Panics, listing the offending doc IDs, if the invariant was violated.
    ///
    /// Marked `#[track_caller]` so the reported panic location is the test or
    /// call site that asserted, not this helper.
    ///
    /// # Panics
    ///
    /// Panics when `self` is `Violated`.
    #[track_caller]
    pub fn assert_valid(&self) {
        match self {
            Self::Valid => {}
            Self::Violated { doc_ids } => {
                panic!(
                    "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
                    doc_ids.len(),
                    doc_ids
                );
            }
        }
    }
}