1use std::collections::HashMap;
51use std::sync::Arc;
52
53use crate::candidate_gate::AllowedSet;
54use crate::filter_ir::{AuthScope, FilterIR};
55use crate::filtered_vector_search::ScoredResult;
56use crate::grep_executor::GrepMode;
57use crate::namespace::NamespaceScope;
58
/// Strategy used by `FusionEngine` to merge ranked lists from several lanes.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum FusionMethod {
    /// Reciprocal-rank fusion: every lane adds `weight / (k + rank)` to a
    /// document's score, where `rank` is 1-based within that lane.
    Rrf {
        /// Constant added to the 1-based rank in the denominator.
        k: f32,
        /// Weight applied to the vector lane.
        vector_weight: f32,
        /// Weight applied to the BM25 lane.
        bm25_weight: f32,
    },

    /// Weighted sum of each lane's min-max normalized scores.
    Linear {
        /// Weight applied to the vector lane.
        vector_weight: f32,
        /// Weight applied to the BM25 lane.
        bm25_weight: f32,
    },

    /// Keeps, per document, the largest min-max normalized score of any lane.
    Max,

    /// Cascade: documents come from the `primary` lane; those that also occur
    /// in the other lane take that lane's score (see
    /// `FusionEngine::fuse_cascade`).
    Cascade { primary: Modality },
}
85
/// Retrieval lane that produced a candidate list.
///
/// The enum is fieldless, so it additionally derives `Eq` and `Hash`; this
/// lets callers use a modality as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Modality {
    /// Embedding / vector similarity search.
    Vector,
    /// BM25 lexical search.
    Bm25,
    /// Pattern (grep) search.
    Grep,
}
94
95impl Default for FusionMethod {
96 fn default() -> Self {
97 Self::Rrf {
98 k: 60.0,
99 vector_weight: 1.0,
100 bm25_weight: 1.0,
101 }
102 }
103}
104
/// Tuning knobs for a `FusionEngine`.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// How the per-lane lists are merged.
    pub method: FusionMethod,

    /// Number of candidates requested from each lane before fusion
    /// (passed as `k` to the lane executors).
    pub candidates_per_modality: usize,

    /// Maximum number of fused results returned.
    pub final_k: usize,

    /// When set, fused results scoring below this value are dropped.
    pub min_score: Option<f32>,
}
120
121impl Default for FusionConfig {
122 fn default() -> Self {
123 Self {
124 method: FusionMethod::default(),
125 candidates_per_modality: 100,
126 final_k: 10,
127 min_score: None,
128 }
129 }
130}
131
/// A single hybrid query: up to three retrieval lanes plus a shared filter.
#[derive(Debug, Clone)]
pub struct UnifiedHybridQuery {
    /// Namespace scope; converted to a filter by `effective_filter`.
    pub namespace: NamespaceScope,

    /// Vector lane specification; `None` disables the lane.
    pub vector_query: Option<VectorQuerySpec>,

    /// BM25 lane specification; `None` disables the lane.
    pub bm25_query: Option<Bm25QuerySpec>,

    /// Grep lane specification; `None` disables the lane.
    pub grep_query: Option<GrepQuerySpec>,

    /// Caller-supplied filter, AND-ed with the namespace filter by
    /// `effective_filter`.
    pub filter: FilterIR,

    /// Fusion settings carried with the query.
    /// NOTE(review): `UnifiedHybridExecutor::execute` in this file reads the
    /// engine's own config rather than this field — confirm where it is used.
    pub fusion_config: FusionConfig,
}
157
/// Parameters for the vector lane.
#[derive(Debug, Clone)]
pub struct VectorQuerySpec {
    /// Query embedding handed to the vector executor.
    pub embedding: Vec<f32>,
    /// Search-effort setting (builder default: 100).
    /// NOTE(review): not forwarded to `VectorExecutor::search` in this file —
    /// confirm where it is consumed.
    pub ef_search: usize,
}
166
/// Parameters for the BM25 lane.
#[derive(Debug, Clone)]
pub struct Bm25QuerySpec {
    /// Query text handed to the BM25 executor.
    pub text: String,
    /// Field names associated with the query (builder default: `["content"]`).
    /// NOTE(review): not forwarded to `Bm25Executor::search` in this file —
    /// confirm where it is consumed.
    pub fields: Vec<String>,
}
175
/// Parameters for the grep lane.
#[derive(Debug, Clone)]
pub struct GrepQuerySpec {
    /// Pattern handed to the grep executor.
    pub pattern: String,
    /// `Gate` narrows the allowed set for the other lanes; `Rank` adds the
    /// grep hits as an extra fusion lane.
    pub mode: GrepMode,
    /// Lane weight used for fusion in `Rank` mode.
    pub weight: f32,
}
191
192impl UnifiedHybridQuery {
193 pub fn new(namespace: NamespaceScope) -> Self {
195 Self {
196 namespace,
197 vector_query: None,
198 bm25_query: None,
199 grep_query: None,
200 filter: FilterIR::all(),
201 fusion_config: FusionConfig::default(),
202 }
203 }
204
205 pub fn with_vector(mut self, embedding: Vec<f32>) -> Self {
207 self.vector_query = Some(VectorQuerySpec {
208 embedding,
209 ef_search: 100,
210 });
211 self
212 }
213
214 pub fn with_bm25(mut self, text: impl Into<String>) -> Self {
216 self.bm25_query = Some(Bm25QuerySpec {
217 text: text.into(),
218 fields: vec!["content".to_string()],
219 });
220 self
221 }
222
223 pub fn with_grep(mut self, pattern: impl Into<String>, mode: GrepMode) -> Self {
225 self.grep_query = Some(GrepQuerySpec {
226 pattern: pattern.into(),
227 mode,
228 weight: 1.0,
229 });
230 self
231 }
232
233 pub fn with_grep_weighted(
235 mut self,
236 pattern: impl Into<String>,
237 mode: GrepMode,
238 weight: f32,
239 ) -> Self {
240 self.grep_query = Some(GrepQuerySpec {
241 pattern: pattern.into(),
242 mode,
243 weight,
244 });
245 self
246 }
247
248 pub fn with_filter(mut self, filter: FilterIR) -> Self {
250 self.filter = filter;
251 self
252 }
253
254 pub fn with_fusion(mut self, config: FusionConfig) -> Self {
256 self.fusion_config = config;
257 self
258 }
259
260 pub fn effective_filter(&self) -> FilterIR {
264 self.namespace.to_filter_ir().and(self.filter.clone())
265 }
266}
267
/// Candidate list produced by one lane, tagged with its origin.
#[derive(Debug)]
pub struct FilteredCandidates {
    /// Lane that produced `results`.
    pub modality: Modality,
    /// Scored candidates; rank-based fusion uses their order as the ranking.
    pub results: Vec<ScoredResult>,
    /// Marks the list as already filtered; the fusion entry points
    /// `debug_assert!` this flag.
    pub filtered: bool,
}
282
283impl FilteredCandidates {
284 pub fn from_vector(results: Vec<ScoredResult>) -> Self {
286 Self {
287 modality: Modality::Vector,
288 results,
289 filtered: true,
290 }
291 }
292
293 pub fn from_bm25(results: Vec<ScoredResult>) -> Self {
295 Self {
296 modality: Modality::Bm25,
297 results,
298 filtered: true,
299 }
300 }
301
302 pub fn from_grep(results: Vec<ScoredResult>) -> Self {
304 Self {
305 modality: Modality::Grep,
306 results,
307 filtered: true,
308 }
309 }
310}
311
/// Newtype around a raw `u64` document identifier.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct DocId(pub u64);

impl DocId {
    /// Returns the wrapped raw identifier.
    #[inline]
    pub const fn get(self) -> u64 {
        let DocId(raw) = self;
        raw
    }
}

impl From<u64> for DocId {
    /// Wraps a raw identifier.
    fn from(raw: u64) -> Self {
        Self(raw)
    }
}

impl From<DocId> for u64 {
    /// Unwraps to the raw identifier.
    fn from(doc: DocId) -> Self {
        doc.get()
    }
}
365
/// Borrowed ranked list fed to `fuse_rrf_weighted`.
pub struct RankedList<'a> {
    /// Results in rank order (index 0 is rank 1).
    pub results: &'a [ScoredResult],
    /// Multiplier applied to every contribution from this list.
    pub weight: f32,
}
376
/// One fusion input: a lane's candidates together with its weight.
pub struct WeightedLane {
    /// Candidates produced by the lane.
    pub candidates: FilteredCandidates,
    /// Weight applied to the lane (used by the RRF and Linear methods).
    pub weight: f32,
}
387
388pub fn fuse_rrf_weighted(lists: &[RankedList<'_>], k: f32) -> HashMap<DocId, f32> {
400 let mut scores: HashMap<DocId, f32> = HashMap::new();
401 for list in lists {
402 for (rank, result) in list.results.iter().enumerate() {
403 let contribution = list.weight / (k + (rank as f32 + 1.0));
404 *scores.entry(DocId(result.doc_id)).or_insert(0.0) += contribution;
405 }
406 }
407 scores
408}
409
/// Merges per-lane candidate lists according to a `FusionConfig`.
pub struct FusionEngine {
    /// Method, candidate budget, result count and score threshold.
    config: FusionConfig,
}
418
419impl FusionEngine {
420 pub fn new(config: FusionConfig) -> Self {
422 Self { config }
423 }
424
425 pub fn fuse(
434 &self,
435 vector_candidates: Option<FilteredCandidates>,
436 bm25_candidates: Option<FilteredCandidates>,
437 ) -> FusionResult {
438 if let Some(ref vc) = vector_candidates {
440 debug_assert!(vc.filtered, "Vector candidates must be pre-filtered!");
441 }
442 if let Some(ref bc) = bm25_candidates {
443 debug_assert!(bc.filtered, "BM25 candidates must be pre-filtered!");
444 }
445
446 if let FusionMethod::Cascade { primary } = self.config.method {
449 return self.fuse_cascade(vector_candidates, bm25_candidates, primary);
450 }
451
452 let (vector_weight, bm25_weight) = self.method_weights();
453 let mut lanes: Vec<WeightedLane> = Vec::with_capacity(2);
454 if let Some(vc) = vector_candidates {
455 lanes.push(WeightedLane {
456 candidates: vc,
457 weight: vector_weight,
458 });
459 }
460 if let Some(bc) = bm25_candidates {
461 lanes.push(WeightedLane {
462 candidates: bc,
463 weight: bm25_weight,
464 });
465 }
466 self.fuse_multi(lanes)
467 }
468
469 pub(crate) fn method_weights(&self) -> (f32, f32) {
474 match self.config.method {
475 FusionMethod::Rrf {
476 vector_weight,
477 bm25_weight,
478 ..
479 } => (vector_weight, bm25_weight),
480 FusionMethod::Linear {
481 vector_weight,
482 bm25_weight,
483 } => (vector_weight, bm25_weight),
484 FusionMethod::Max | FusionMethod::Cascade { .. } => (1.0, 1.0),
485 }
486 }
487
488 pub fn fuse_multi(&self, lanes: Vec<WeightedLane>) -> FusionResult {
498 for lane in &lanes {
499 debug_assert!(
500 lane.candidates.filtered,
501 "Fusion lanes must be pre-filtered!"
502 );
503 }
504
505 match self.config.method {
506 FusionMethod::Rrf { k, .. } => {
507 let ranked: Vec<RankedList<'_>> = lanes
508 .iter()
509 .map(|lane| RankedList {
510 results: &lane.candidates.results,
511 weight: lane.weight,
512 })
513 .collect();
514 let scores = fuse_rrf_weighted(&ranked, k)
515 .into_iter()
516 .map(|(doc, score)| (doc.0, score))
517 .collect();
518 self.collect_top_k(scores)
519 }
520 FusionMethod::Linear { .. } => {
521 let mut scores: HashMap<u64, f32> = HashMap::new();
522 for lane in &lanes {
523 for (doc_id, score) in self.normalize_scores(&lane.candidates.results) {
524 *scores.entry(doc_id).or_insert(0.0) += score * lane.weight;
525 }
526 }
527 self.collect_top_k(scores)
528 }
529 FusionMethod::Max => {
530 let mut scores: HashMap<u64, f32> = HashMap::new();
531 for lane in &lanes {
532 for (doc_id, score) in self.normalize_scores(&lane.candidates.results) {
533 let entry = scores.entry(doc_id).or_insert(0.0);
534 *entry = entry.max(score);
535 }
536 }
537 self.collect_top_k(scores)
538 }
539 FusionMethod::Cascade { primary } => {
540 let mut vector = None;
545 let mut bm25 = None;
546 for lane in lanes {
547 match lane.candidates.modality {
548 Modality::Vector => vector = Some(lane.candidates),
549 Modality::Bm25 => bm25 = Some(lane.candidates),
550 Modality::Grep => {}
551 }
552 }
553 self.fuse_cascade(vector, bm25, primary)
554 }
555 }
556 }
557
558 fn fuse_cascade(
560 &self,
561 vector: Option<FilteredCandidates>,
562 bm25: Option<FilteredCandidates>,
563 primary: Modality,
564 ) -> FusionResult {
565 let (primary_candidates, secondary_candidates) = match primary {
566 Modality::Vector => (vector, bm25),
567 Modality::Bm25 => (bm25, vector),
568 Modality::Grep => (vector, bm25),
572 };
573
574 let primary_ids: std::collections::HashSet<u64> = primary_candidates
576 .as_ref()
577 .map(|c| c.results.iter().map(|r| r.doc_id).collect())
578 .unwrap_or_default();
579
580 let mut scores: HashMap<u64, f32> = HashMap::new();
582
583 if let Some(sc) = secondary_candidates {
584 for result in &sc.results {
585 if primary_ids.contains(&result.doc_id) {
586 scores.insert(result.doc_id, result.score);
587 }
588 }
589 }
590
591 if let Some(pc) = primary_candidates {
593 for (rank, result) in pc.results.iter().enumerate() {
594 scores.entry(result.doc_id).or_insert(-(rank as f32));
595 }
596 }
597
598 self.collect_top_k(scores)
599 }
600
601 fn normalize_scores(&self, results: &[ScoredResult]) -> Vec<(u64, f32)> {
603 if results.is_empty() {
604 return vec![];
605 }
606
607 let min = results
608 .iter()
609 .map(|r| r.score)
610 .fold(f32::INFINITY, f32::min);
611 let max = results
612 .iter()
613 .map(|r| r.score)
614 .fold(f32::NEG_INFINITY, f32::max);
615 let range = max - min;
616
617 if range == 0.0 {
618 return results.iter().map(|r| (r.doc_id, 1.0)).collect();
619 }
620
621 results
622 .iter()
623 .map(|r| (r.doc_id, (r.score - min) / range))
624 .collect()
625 }
626
627 fn collect_top_k(&self, scores: HashMap<u64, f32>) -> FusionResult {
629 let mut results: Vec<ScoredResult> = scores
630 .into_iter()
631 .map(|(doc_id, score)| ScoredResult::new(doc_id, score))
632 .collect();
633
634 results.sort_by(|a, b| {
636 b.score
637 .partial_cmp(&a.score)
638 .unwrap_or(std::cmp::Ordering::Equal)
639 });
640
641 if let Some(min) = self.config.min_score {
643 results.retain(|r| r.score >= min);
644 }
645
646 results.truncate(self.config.final_k);
648
649 FusionResult {
650 results,
651 method: self.config.method,
652 }
653 }
654}
655
/// Output of a fusion run.
#[derive(Debug)]
pub struct FusionResult {
    /// Fused results, best score first, at most `final_k` entries.
    pub results: Vec<ScoredResult>,
    /// Fusion method configured on the engine that produced the result.
    pub method: FusionMethod,
}
664
/// Vector search backend used by `UnifiedHybridExecutor`.
pub trait VectorExecutor {
    /// Searches for `query`, given a result budget `k` and the set of
    /// documents the search is allowed to return.
    fn search(&self, query: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
}
673
/// BM25 search backend used by `UnifiedHybridExecutor`.
pub trait Bm25Executor {
    /// Searches for `query`, given a result budget `k` and the set of
    /// documents the search is allowed to return.
    fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult>;
}
678
/// Grep backend used as the optional third lane of `UnifiedHybridExecutor`.
pub trait GrepLaneExecutor {
    /// Runs `pattern` in the given `mode`, given a result budget `k` and the
    /// set of documents the search is allowed to return.
    ///
    /// NOTE(review): the executor passes `k = 0` in `Gate` mode; presumably
    /// implementations treat that as "no limit" — confirm.
    fn grep(
        &self,
        pattern: &str,
        k: usize,
        allowed: &AllowedSet,
        mode: GrepMode,
    ) -> Vec<ScoredResult>;
}
695
/// Runs the enabled lanes of a `UnifiedHybridQuery` and fuses their results.
pub struct UnifiedHybridExecutor<V: VectorExecutor, B: Bm25Executor> {
    /// Backend for the vector lane.
    vector_executor: Arc<V>,
    /// Backend for the BM25 lane.
    bm25_executor: Arc<B>,
    /// Optional grep backend; when `None`, grep queries are ignored.
    grep_executor: Option<Arc<dyn GrepLaneExecutor>>,
    /// Engine (and configuration) used to fuse the lanes.
    fusion_engine: FusionEngine,
}
705
706impl<V: VectorExecutor, B: Bm25Executor> UnifiedHybridExecutor<V, B> {
707 pub fn new(
709 vector_executor: Arc<V>,
710 bm25_executor: Arc<B>,
711 fusion_config: FusionConfig,
712 ) -> Self {
713 Self {
714 vector_executor,
715 bm25_executor,
716 grep_executor: None,
717 fusion_engine: FusionEngine::new(fusion_config),
718 }
719 }
720
721 pub fn with_grep_executor(mut self, grep_executor: Arc<dyn GrepLaneExecutor>) -> Self {
726 self.grep_executor = Some(grep_executor);
727 self
728 }
729
730 pub fn execute(
742 &self,
743 query: &UnifiedHybridQuery,
744 _auth_scope: &AuthScope,
745 allowed_set: &AllowedSet, ) -> FusionResult {
747 if allowed_set.is_empty() {
749 return FusionResult {
750 results: vec![],
751 method: self.fusion_engine.config.method,
752 };
753 }
754
755 let k = self.fusion_engine.config.candidates_per_modality;
756
757 let mut grep_rank: Option<FilteredCandidates> = None;
762 let mut grep_weight = 1.0_f32;
763 let mut gated: Option<AllowedSet> = None;
764 if let (Some(gq), Some(grep)) = (query.grep_query.as_ref(), self.grep_executor.as_ref()) {
765 match gq.mode {
766 GrepMode::Gate => {
767 let hits = grep.grep(&gq.pattern, 0, allowed_set, GrepMode::Gate);
769 gated = Some(AllowedSet::from_iter(hits.into_iter().map(|r| r.doc_id)));
770 }
771 GrepMode::Rank => {
772 let hits = grep.grep(&gq.pattern, k, allowed_set, GrepMode::Rank);
773 grep_rank = Some(FilteredCandidates::from_grep(hits));
774 grep_weight = gq.weight;
775 }
776 }
777 }
778
779 let effective_allowed: &AllowedSet = gated.as_ref().unwrap_or(allowed_set);
781 if effective_allowed.is_empty() {
782 return FusionResult {
783 results: vec![],
784 method: self.fusion_engine.config.method,
785 };
786 }
787
788 let vector_candidates = query.vector_query.as_ref().map(|vq| {
790 let results = self
791 .vector_executor
792 .search(&vq.embedding, k, effective_allowed);
793 FilteredCandidates::from_vector(results)
794 });
795
796 let bm25_candidates = query.bm25_query.as_ref().map(|bq| {
798 let results = self.bm25_executor.search(&bq.text, k, effective_allowed);
799 FilteredCandidates::from_bm25(results)
800 });
801
802 let (vector_weight, bm25_weight) = self.fusion_engine.method_weights();
804 let mut lanes: Vec<WeightedLane> = Vec::with_capacity(3);
805 if let Some(vc) = vector_candidates {
806 lanes.push(WeightedLane {
807 candidates: vc,
808 weight: vector_weight,
809 });
810 }
811 if let Some(bc) = bm25_candidates {
812 lanes.push(WeightedLane {
813 candidates: bc,
814 weight: bm25_weight,
815 });
816 }
817 if let Some(gc) = grep_rank {
818 lanes.push(WeightedLane {
819 candidates: gc,
820 weight: grep_weight,
821 });
822 }
823
824 self.fusion_engine.fuse_multi(lanes)
825 }
826}
827
// Unit tests for the fusion engine and the three-lane hybrid executor.
#[cfg(test)]
mod tests {
    use super::*;

    // RRF over two lists: documents 1 and 2 occur in both lists and must be
    // part of the fused output.
    #[test]
    fn test_rrf_fusion() {
        let config = FusionConfig {
            method: FusionMethod::Rrf {
                k: 60.0,
                vector_weight: 1.0,
                bm25_weight: 1.0,
            },
            candidates_per_modality: 10,
            final_k: 5,
            min_score: None,
        };

        let engine = FusionEngine::new(config);

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(3, 0.7),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(4, 4.0),
            ScoredResult::new(1, 3.0),
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        assert!(!result.results.is_empty());

        let top_ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
        assert!(top_ids.contains(&1));
        assert!(top_ids.contains(&2));
    }

    // Pins the RRF formula: ranks are 1-based, the list weight multiplies
    // each contribution, and contributions sum across lists.
    #[test]
    fn test_fuse_rrf_weighted_is_1_indexed_and_weighted() {
        let k = 60.0_f32;
        let docs = [ScoredResult::new(7, 0.9), ScoredResult::new(8, 0.5)];
        let scores = fuse_rrf_weighted(
            &[RankedList {
                results: &docs,
                weight: 2.0,
            }],
            k,
        );

        let s7 = scores[&DocId(7)];
        let s8 = scores[&DocId(8)];
        assert!(
            (s7 - 2.0 / (k + 1.0)).abs() < 1e-6,
            "rank-1 must use 1-indexed weighted score"
        );
        assert!(
            (s8 - 2.0 / (k + 2.0)).abs() < 1e-6,
            "rank-2 must use 1-indexed weighted score"
        );
        assert!(s7 > s8, "earlier rank must score higher");

        // The same document at rank 1 in two lists with different weights.
        let list_a = [ScoredResult::new(1, 0.0)];
        let list_b = [ScoredResult::new(1, 0.0)];
        let merged = fuse_rrf_weighted(
            &[
                RankedList {
                    results: &list_a,
                    weight: 1.0,
                },
                RankedList {
                    results: &list_b,
                    weight: 3.0,
                },
            ],
            k,
        );
        let expected = 1.0 / (k + 1.0) + 3.0 / (k + 1.0);
        assert!(
            (merged[&DocId(1)] - expected).abs() < 1e-6,
            "weights must sum across lists"
        );
    }

    // Smoke test: linear fusion over two lists yields a non-empty result.
    #[test]
    fn test_linear_fusion() {
        let config = FusionConfig {
            method: FusionMethod::Linear {
                vector_weight: 0.6,
                bm25_weight: 0.4,
            },
            candidates_per_modality: 10,
            final_k: 5,
            min_score: None,
        };

        let engine = FusionEngine::new(config);

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 1.0),
            ScoredResult::new(2, 0.5),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 10.0),
            ScoredResult::new(3, 5.0),
        ]);

        let result = engine.fuse(Some(vector), Some(bm25));

        assert!(!result.results.is_empty());
    }

    // Despite the name, this exercises `fuse` with no candidate lists at
    // all: the fused result must be empty.
    #[test]
    fn test_empty_allowed_set() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        let result = engine.fuse(None, None);
        assert!(result.results.is_empty());
    }

    // Min-max normalization maps the largest score to 1, the smallest to 0
    // and the midpoint to 0.5.
    #[test]
    fn test_score_normalization() {
        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);

        let results = vec![
            ScoredResult::new(1, 100.0),
            ScoredResult::new(2, 50.0),
            ScoredResult::new(3, 0.0),
        ];

        let normalized = engine.normalize_scores(&results);

        assert_eq!(normalized.len(), 3);
        let scores: HashMap<u64, f32> = normalized.into_iter().collect();
        assert!((scores[&1] - 1.0).abs() < 0.001);
        assert!((scores[&2] - 0.5).abs() < 0.001);
        assert!((scores[&3] - 0.0).abs() < 0.001);
    }

    // Fusing lists that only contain allowed documents must never produce a
    // document outside the allowed set.
    #[test]
    fn test_no_post_filter_invariant() {
        let allowed: std::collections::HashSet<u64> = [1, 2, 3, 5, 8].into_iter().collect();
        let allowed_set = AllowedSet::from_iter(allowed.iter().copied());

        let vector = FilteredCandidates::from_vector(vec![
            ScoredResult::new(1, 0.9),
            ScoredResult::new(2, 0.8),
            ScoredResult::new(5, 0.7),
        ]);

        let bm25 = FilteredCandidates::from_bm25(vec![
            ScoredResult::new(2, 5.0),
            ScoredResult::new(3, 4.0),
            ScoredResult::new(8, 3.0),
        ]);

        let config = FusionConfig::default();
        let engine = FusionEngine::new(config);
        let result = engine.fuse(Some(vector), Some(bm25));

        for doc in &result.results {
            assert!(
                allowed_set.contains(doc.doc_id),
                "INVARIANT VIOLATION: doc_id {} not in allowed set",
                doc.doc_id
            );
        }
    }

    // Imports and fixtures for the executor-level tests below.
    use crate::grep_executor::GrepMode;
    use crate::namespace::Namespace;
    use crate::trigram_index::TrigramIndex;

    // Vector backend stub: returns its canned results that are in the
    // allowed set, at most `k` of them.
    struct MockVector(Vec<ScoredResult>);
    impl VectorExecutor for MockVector {
        fn search(&self, _q: &[f32], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
            self.0
                .iter()
                .filter(|r| allowed.contains(r.doc_id))
                .take(k)
                .cloned()
                .collect()
        }
    }

    // BM25 backend stub with the same behaviour as `MockVector`.
    struct MockBm25(Vec<ScoredResult>);
    impl Bm25Executor for MockBm25 {
        fn search(&self, _q: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
            self.0
                .iter()
                .filter(|r| allowed.contains(r.doc_id))
                .take(k)
                .cloned()
                .collect()
        }
    }

    // Grep lane backed by the crate's `GrepExecutor` over a trigram index.
    struct RealGrep {
        index: TrigramIndex,
    }
    impl GrepLaneExecutor for RealGrep {
        fn grep(
            &self,
            pattern: &str,
            k: usize,
            allowed: &AllowedSet,
            mode: GrepMode,
        ) -> Vec<ScoredResult> {
            let exec = crate::grep_executor::GrepExecutor::new(&self.index);
            // A search error is reported as "no hits".
            match exec.search(pattern, allowed, k, mode) {
                Ok(results) => results
                    .hits
                    .into_iter()
                    .map(|h| ScoredResult::new(h.doc_id, h.score))
                    .collect(),
                Err(_) => Vec::new(),
            }
        }
    }

    // Query scoped to the single namespace "test", with no lanes enabled.
    fn test_query() -> UnifiedHybridQuery {
        UnifiedHybridQuery::new(NamespaceScope::single(Namespace::new("test").unwrap()))
    }

    // Four-document index: documents 1, 3 and 4 contain "compute_idf",
    // document 2 does not.
    fn grep_index() -> TrigramIndex {
        let mut idx = TrigramIndex::new();
        idx.insert(1, "fn alpha() { compute_idf() }");
        idx.insert(2, "fn beta() { unrelated helper }");
        idx.insert(3, "fn gamma() { compute_idf() twice compute_idf() }");
        idx.insert(4, "struct Config { compute_idf: bool }");
        idx
    }

    // Rank mode: grep hits join the fusion as a third lane, so document 3
    // (returned only by grep) shows up, and nothing leaves the allowed set.
    #[test]
    fn test_three_lane_rank_fusion_respects_allowed_set() {
        let vector = MockVector(vec![
            ScoredResult::new(2, 0.9),
            ScoredResult::new(1, 0.8),
            ScoredResult::new(4, 0.2),
        ]);
        let bm25 = MockBm25(vec![ScoredResult::new(2, 5.0), ScoredResult::new(1, 3.0)]);
        let grep = RealGrep {
            index: grep_index(),
        };

        let allowed = AllowedSet::from_iter([1, 2, 3, 4]);
        let executor =
            UnifiedHybridExecutor::new(Arc::new(vector), Arc::new(bm25), FusionConfig::default())
                .with_grep_executor(Arc::new(grep));

        let query = test_query()
            .with_vector(vec![0.0; 4])
            .with_bm25("anything")
            .with_grep("compute_idf", GrepMode::Rank);

        let result = executor.execute(&query, &AuthScope::for_namespace("test"), &allowed);

        assert!(!result.results.is_empty());
        for r in &result.results {
            assert!(
                allowed.contains(r.doc_id),
                "result {} escaped allowed set",
                r.doc_id
            );
        }
        let ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
        assert!(
            ids.contains(&3),
            "grep-only doc 3 should appear via the third lane, got {ids:?}"
        );
    }

    // Gate mode: the grep hits {1, 3, 4} become the allowed set for the other
    // lanes, so document 2 is excluded even though both lanes rank it first.
    #[test]
    fn test_grep_gate_narrows_before_other_lanes() {
        let vector = MockVector(vec![
            ScoredResult::new(2, 0.9),
            ScoredResult::new(1, 0.8),
            ScoredResult::new(4, 0.7),
            ScoredResult::new(3, 0.6),
        ]);
        let bm25 = MockBm25(vec![ScoredResult::new(2, 5.0)]);
        let grep = RealGrep {
            index: grep_index(),
        };

        let allowed = AllowedSet::from_iter([1, 2, 3, 4]);
        let executor =
            UnifiedHybridExecutor::new(Arc::new(vector), Arc::new(bm25), FusionConfig::default())
                .with_grep_executor(Arc::new(grep));

        let query = test_query()
            .with_vector(vec![0.0; 4])
            .with_bm25("anything")
            .with_grep("compute_idf", GrepMode::Gate);

        let result = executor.execute(&query, &AuthScope::for_namespace("test"), &allowed);

        assert!(!result.results.is_empty());
        let gate: std::collections::HashSet<u64> = [1, 3, 4].into_iter().collect();
        for r in &result.results {
            assert!(
                gate.contains(&r.doc_id),
                "doc {} not in grep gate {{1,3,4}}",
                r.doc_id
            );
        }
        assert!(
            !result.results.iter().any(|r| r.doc_id == 2),
            "doc 2 (no compute_idf) must be gated out"
        );
    }

    // Without a grep backend the grep query is ignored: no gating happens and
    // both remaining lanes contribute.
    #[test]
    fn test_grep_query_ignored_without_grep_executor() {
        let vector = MockVector(vec![ScoredResult::new(1, 0.9)]);
        let bm25 = MockBm25(vec![ScoredResult::new(2, 5.0)]);
        let allowed = AllowedSet::from_iter([1, 2, 3, 4]);
        let executor =
            UnifiedHybridExecutor::new(Arc::new(vector), Arc::new(bm25), FusionConfig::default());

        let query = test_query()
            .with_vector(vec![0.0; 4])
            .with_bm25("anything")
            .with_grep("compute_idf", GrepMode::Gate);

        let result = executor.execute(&query, &AuthScope::for_namespace("test"), &allowed);
        let ids: Vec<u64> = result.results.iter().map(|r| r.doc_id).collect();
        assert!(
            ids.contains(&1) && ids.contains(&2),
            "without a grep executor both lanes survive, got {ids:?}"
        );
    }
}
1198
1199pub fn verify_no_post_filter_invariant(
1214 result: &FusionResult,
1215 allowed_set: &AllowedSet,
1216) -> InvariantVerification {
1217 let mut violations = Vec::new();
1218
1219 for doc in &result.results {
1220 if !allowed_set.contains(doc.doc_id) {
1221 violations.push(doc.doc_id);
1222 }
1223 }
1224
1225 if violations.is_empty() {
1226 InvariantVerification::Valid
1227 } else {
1228 InvariantVerification::Violated {
1229 doc_ids: violations,
1230 }
1231 }
1232}
1233
/// Outcome of checking fused results against an allowed set.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InvariantVerification {
    /// Every result was in the allowed set.
    Valid,
    /// Some results were outside the allowed set; `doc_ids` lists them.
    Violated { doc_ids: Vec<u64> },
}

impl InvariantVerification {
    /// Returns `true` only for `Valid`.
    pub fn is_valid(&self) -> bool {
        match self {
            Self::Valid => true,
            Self::Violated { .. } => false,
        }
    }

    /// Does nothing for `Valid`.
    ///
    /// # Panics
    ///
    /// Panics for `Violated`, reporting the count and list of offending ids.
    pub fn assert_valid(&self) {
        if let Self::Violated { doc_ids } = self {
            panic!(
                "NO-POST-FILTER INVARIANT VIOLATED: {} docs not in allowed set: {:?}",
                doc_ids.len(),
                doc_ids
            );
        }
    }
}