1use std::collections::{HashMap, HashSet};
96use std::cmp::Ordering;
97use std::sync::Arc;
98
99use crate::context_query::VectorIndex;
100use crate::soch_ql::SochValue;
101
/// A hybrid search request: an optional vector component, an optional
/// lexical component, metadata filters, and the settings used to fuse,
/// rerank and trim the combined candidates.
#[derive(Debug, Clone)]
pub struct HybridQuery {
    /// Name of the collection handed to both the vector and the lexical search.
    pub collection: String,

    /// Vector (embedding) part of the query, if any.
    pub vector: Option<VectorQueryComponent>,

    /// Lexical (keyword) part of the query, if any.
    pub lexical: Option<LexicalQueryComponent>,

    /// Metadata predicates; a candidate must satisfy every one of them.
    pub filters: Vec<MetadataFilter>,

    /// How vector and lexical results are combined into one score.
    pub fusion: FusionConfig,

    /// Optional reranking pass applied after fusion.
    pub rerank: Option<RerankConfig>,

    /// Maximum number of results returned.
    pub limit: usize,

    /// When set, results whose final score is below this value are dropped.
    pub min_score: Option<f32>,
}
133
134impl HybridQuery {
135 pub fn new(collection: &str) -> Self {
137 Self {
138 collection: collection.to_string(),
139 vector: None,
140 lexical: None,
141 filters: Vec::new(),
142 fusion: FusionConfig::default(),
143 rerank: None,
144 limit: 10,
145 min_score: None,
146 }
147 }
148
149 pub fn with_vector(mut self, embedding: Vec<f32>, weight: f32) -> Self {
151 self.vector = Some(VectorQueryComponent {
152 embedding,
153 weight,
154 ef_search: 100,
155 });
156 self
157 }
158
159 pub fn with_vector_text(mut self, text: String, weight: f32) -> Self {
161 self.vector = Some(VectorQueryComponent {
162 embedding: Vec::new(), weight,
164 ef_search: 100,
165 });
166 self.lexical = self.lexical.or(Some(LexicalQueryComponent {
168 query: text,
169 weight: 0.0, fields: vec!["content".to_string()],
171 }));
172 self
173 }
174
175 pub fn with_lexical(mut self, query: &str, weight: f32) -> Self {
177 self.lexical = Some(LexicalQueryComponent {
178 query: query.to_string(),
179 weight,
180 fields: vec!["content".to_string()],
181 });
182 self
183 }
184
185 pub fn with_lexical_fields(mut self, query: &str, weight: f32, fields: Vec<String>) -> Self {
187 self.lexical = Some(LexicalQueryComponent {
188 query: query.to_string(),
189 weight,
190 fields,
191 });
192 self
193 }
194
195 pub fn filter(mut self, field: &str, op: FilterOp, value: SochValue) -> Self {
197 self.filters.push(MetadataFilter {
198 field: field.to_string(),
199 op,
200 value,
201 });
202 self
203 }
204
205 pub fn filter_eq(self, field: &str, value: impl Into<SochValue>) -> Self {
207 self.filter(field, FilterOp::Eq, value.into())
208 }
209
210 pub fn filter_range(mut self, field: &str, min: Option<SochValue>, max: Option<SochValue>) -> Self {
212 if let Some(min_val) = min {
213 self.filters.push(MetadataFilter {
214 field: field.to_string(),
215 op: FilterOp::Gte,
216 value: min_val,
217 });
218 }
219 if let Some(max_val) = max {
220 self.filters.push(MetadataFilter {
221 field: field.to_string(),
222 op: FilterOp::Lte,
223 value: max_val,
224 });
225 }
226 self
227 }
228
229 pub fn with_fusion(mut self, method: FusionMethod) -> Self {
231 self.fusion.method = method;
232 self
233 }
234
235 pub fn with_rrf_k(mut self, k: f32) -> Self {
237 self.fusion.rrf_k = k;
238 self
239 }
240
241 pub fn with_rerank(mut self, model: &str, top_n: usize) -> Self {
243 self.rerank = Some(RerankConfig {
244 model: model.to_string(),
245 top_n,
246 batch_size: 32,
247 });
248 self
249 }
250
251 pub fn limit(mut self, limit: usize) -> Self {
253 self.limit = limit;
254 self
255 }
256
257 pub fn min_score(mut self, score: f32) -> Self {
259 self.min_score = Some(score);
260 self
261 }
262}
263
/// Vector (embedding) part of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct VectorQueryComponent {
    /// Query embedding passed to the vector index. The executor in this file
    /// skips vector search entirely when this is empty.
    pub embedding: Vec<f32>,
    /// Weight applied to the vector contribution during score fusion.
    pub weight: f32,
    /// Search-effort setting; the query builder stores 100 here. The executor
    /// in this file does not read it.
    /// NOTE(review): presumably an HNSW-style `ef_search` hint — confirm.
    pub ef_search: usize,
}
274
/// Lexical (keyword) part of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct LexicalQueryComponent {
    /// Free-text query; the lexical index splits it on whitespace and
    /// lowercases the terms.
    pub query: String,
    /// Weight applied to the lexical contribution during score fusion. The
    /// executor in this file only runs the lexical search when this is
    /// greater than `0.0`.
    pub weight: f32,
    /// Fields to search. They are forwarded to `LexicalIndex::search`, which
    /// currently ignores them.
    pub fields: Vec<String>,
}
285
/// A single predicate over a candidate's metadata: `field <op> value`.
#[derive(Debug, Clone)]
pub struct MetadataFilter {
    /// Metadata key to look up on the candidate; a candidate without this key
    /// fails the filter.
    pub field: String,
    /// Comparison to perform.
    pub op: FilterOp,
    /// Right-hand side of the comparison.
    pub value: SochValue,
}
296
/// Comparison operators available to a [`MetadataFilter`].
///
/// The ordering operators (`Gt`, `Gte`, `Lt`, `Lte`) are evaluated by the
/// executor only when both sides are the same `Int`, `UInt`, `Float` or
/// `Text` variant; any other combination never matches.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOp {
    /// Document value equals the filter value.
    Eq,
    /// Document value differs from the filter value.
    Ne,
    /// Document value is strictly greater than the filter value.
    Gt,
    /// Document value is greater than or equal to the filter value.
    Gte,
    /// Document value is strictly less than the filter value.
    Lt,
    /// Document value is less than or equal to the filter value.
    Lte,
    /// Text document value contains the filter text as a substring, or array
    /// document value contains the filter value as an element.
    Contains,
    /// Document value is an element of the filter value, which must be an
    /// array.
    In,
}
317
/// Settings controlling how vector and lexical results are fused.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Fusion algorithm to apply.
    pub method: FusionMethod,
    /// Constant added to a candidate's rank in the reciprocal-rank-fusion
    /// denominator; only used by [`FusionMethod::Rrf`].
    pub rrf_k: f32,
    /// NOTE(review): the fusion code in this file never reads this flag;
    /// presumably it is meant to request score normalisation before fusion —
    /// confirm the intended behaviour.
    pub normalize: bool,
}
328
329impl Default for FusionConfig {
330 fn default() -> Self {
331 Self {
332 method: FusionMethod::Rrf,
333 rrf_k: 60.0,
334 normalize: true,
335 }
336 }
337}
338
/// Algorithms for combining vector and lexical results into one score.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal rank fusion: each source contributes
    /// `weight / (rrf_k + rank)`, where `rank` is the zero-based position of
    /// the candidate in that source's result list.
    Rrf,
    /// Sum of each source's raw score multiplied by its weight.
    WeightedSum,
    /// The larger of the two weighted raw scores.
    Max,
    /// As implemented in this file: the unweighted mean of whichever raw
    /// scores the candidate has.
    Rsf,
}
351
/// Settings for the optional reranking pass of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct RerankConfig {
    /// Name of the reranking model. The reranker in this file does not read
    /// it.
    pub model: String,
    /// Number of leading candidates whose score the reranker adjusts.
    pub top_n: usize,
    /// Batch size; the query builder stores 32 here. The reranker in this
    /// file does not read it.
    pub batch_size: usize,
}
362
/// A hybrid query together with an ordered list of execution steps and a
/// cost estimate.
///
/// NOTE(review): nothing in this part of the file builds or consumes a plan;
/// confirm how it is used before relying on these descriptions.
#[derive(Debug, Clone)]
pub struct HybridExecutionPlan {
    /// The query the plan was built for.
    pub query: HybridQuery,

    /// Steps of the plan, in order.
    pub steps: Vec<ExecutionStep>,

    /// Estimated cost of running the plan (unit not specified here).
    pub estimated_cost: f64,
}
379
/// One step of a [`HybridExecutionPlan`].
///
/// NOTE(review): the variants are not constructed or interpreted anywhere in
/// this part of the file, so the descriptions below are inferred from the
/// field names and mirror the stages of the executor — confirm.
#[derive(Debug, Clone)]
pub enum ExecutionStep {
    /// Vector search over `collection`.
    VectorSearch {
        collection: String,
        ef_search: usize,
        weight: f32,
    },

    /// Lexical search for `query` over `fields` of `collection`.
    LexicalSearch {
        collection: String,
        query: String,
        fields: Vec<String>,
        weight: f32,
    },

    /// Metadata filtering with the given predicates.
    PreFilter {
        filters: Vec<MetadataFilter>,
    },

    /// Score fusion with the given method and RRF constant.
    Fusion {
        method: FusionMethod,
        rrf_k: f32,
    },

    /// Reranking of the first `top_n` candidates with `model`.
    Rerank {
        model: String,
        top_n: usize,
    },

    /// Truncation to `count` results, optionally dropping low scores.
    Limit {
        count: usize,
        min_score: Option<f32>,
    },

    /// Redaction of the listed `fields` using `method`.
    Redact {
        fields: Vec<String>,
        method: RedactionMethod,
    },
}
441
/// How a redacted field is treated by an [`ExecutionStep::Redact`] step.
///
/// NOTE(review): no redaction logic is visible in this part of the file; the
/// per-variant descriptions are inferred from the names — confirm.
#[derive(Debug, Clone)]
pub enum RedactionMethod {
    /// Presumably: substitute the carried string for the field value.
    Replace(String),
    /// Presumably: mask the field value.
    Mask,
    /// Presumably: drop the field.
    Remove,
    /// Presumably: replace the field value with a hash of it.
    Hash,
}
454
/// Runs [`HybridQuery`] requests against a vector index and a lexical index.
pub struct HybridQueryExecutor<V: VectorIndex> {
    /// Shared vector index, queried through `search_by_embedding`.
    vector_index: Arc<V>,

    /// Shared lexical (BM25) index, queried through `LexicalIndex::search`.
    lexical_index: Arc<LexicalIndex>,
}
467
468impl<V: VectorIndex> HybridQueryExecutor<V> {
469 pub fn new(vector_index: Arc<V>, lexical_index: Arc<LexicalIndex>) -> Self {
471 Self {
472 vector_index,
473 lexical_index,
474 }
475 }
476
477 pub fn execute(&self, query: &HybridQuery) -> Result<HybridQueryResult, HybridQueryError> {
479 let mut candidates: HashMap<String, CandidateDoc> = HashMap::new();
480
481 let overfetch = (query.limit * 3).max(100);
483
484 if let Some(vector) = &query.vector {
486 if !vector.embedding.is_empty() {
487 let results = self.vector_index
488 .search_by_embedding(&query.collection, &vector.embedding, overfetch, None)
489 .map_err(HybridQueryError::VectorSearchError)?;
490
491 for (rank, result) in results.iter().enumerate() {
492 let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
493 CandidateDoc {
494 id: result.id.clone(),
495 content: result.content.clone(),
496 metadata: result.metadata.clone(),
497 vector_rank: None,
498 vector_score: None,
499 lexical_rank: None,
500 lexical_score: None,
501 fused_score: 0.0,
502 }
503 });
504 entry.vector_rank = Some(rank);
505 entry.vector_score = Some(result.score);
506 }
507 }
508 }
509
510 if let Some(lexical) = &query.lexical {
512 if lexical.weight > 0.0 {
513 let results = self.lexical_index.search(
514 &query.collection,
515 &lexical.query,
516 &lexical.fields,
517 overfetch,
518 )?;
519
520 for (rank, result) in results.iter().enumerate() {
521 let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
522 CandidateDoc {
523 id: result.id.clone(),
524 content: result.content.clone(),
525 metadata: HashMap::new(),
526 vector_rank: None,
527 vector_score: None,
528 lexical_rank: None,
529 lexical_score: None,
530 fused_score: 0.0,
531 }
532 });
533 entry.lexical_rank = Some(rank);
534 entry.lexical_score = Some(result.score);
535 }
536 }
537 }
538
539 let filtered: Vec<CandidateDoc> = candidates
541 .into_values()
542 .filter(|doc| self.matches_filters(doc, &query.filters))
543 .collect();
544
545 let mut fused = self.fuse_scores(filtered, query)?;
547
548 fused.sort_by(|a, b| b.fused_score.partial_cmp(&a.fused_score).unwrap_or(Ordering::Equal));
550
551 if let Some(rerank) = &query.rerank {
553 fused = self.rerank(&fused, &query.lexical.as_ref().map(|l| l.query.clone()).unwrap_or_default(), rerank)?;
554 }
555
556 if let Some(min) = query.min_score {
558 fused.retain(|doc| doc.fused_score >= min);
559 }
560
561 fused.truncate(query.limit);
563
564 let results: Vec<HybridSearchResult> = fused
566 .into_iter()
567 .map(|doc| HybridSearchResult {
568 id: doc.id,
569 score: doc.fused_score,
570 content: doc.content,
571 metadata: doc.metadata,
572 vector_score: doc.vector_score,
573 lexical_score: doc.lexical_score,
574 })
575 .collect();
576
577 Ok(HybridQueryResult {
578 results,
579 query: query.clone(),
580 stats: HybridQueryStats {
581 vector_candidates: 0, lexical_candidates: 0,
583 filtered_candidates: 0,
584 fusion_time_us: 0,
585 rerank_time_us: 0,
586 },
587 })
588 }
589
590 fn matches_filters(&self, doc: &CandidateDoc, filters: &[MetadataFilter]) -> bool {
592 for filter in filters {
593 if let Some(value) = doc.metadata.get(&filter.field) {
594 if !self.match_filter(value, &filter.op, &filter.value) {
595 return false;
596 }
597 } else {
598 return false;
600 }
601 }
602 true
603 }
604
605 fn match_filter(&self, doc_value: &SochValue, op: &FilterOp, filter_value: &SochValue) -> bool {
607 match op {
608 FilterOp::Eq => doc_value == filter_value,
609 FilterOp::Ne => doc_value != filter_value,
610 FilterOp::Gt => self.compare_values(doc_value, filter_value) == Some(Ordering::Greater),
611 FilterOp::Gte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Greater | Ordering::Equal)),
612 FilterOp::Lt => self.compare_values(doc_value, filter_value) == Some(Ordering::Less),
613 FilterOp::Lte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Less | Ordering::Equal)),
614 FilterOp::Contains => self.value_contains(doc_value, filter_value),
615 FilterOp::In => self.value_in_set(doc_value, filter_value),
616 }
617 }
618
619 fn compare_values(&self, a: &SochValue, b: &SochValue) -> Option<Ordering> {
621 match (a, b) {
622 (SochValue::Int(a), SochValue::Int(b)) => Some(a.cmp(b)),
623 (SochValue::UInt(a), SochValue::UInt(b)) => Some(a.cmp(b)),
624 (SochValue::Float(a), SochValue::Float(b)) => a.partial_cmp(b),
625 (SochValue::Text(a), SochValue::Text(b)) => Some(a.cmp(b)),
626 _ => None,
627 }
628 }
629
630 fn value_contains(&self, doc_value: &SochValue, search_value: &SochValue) -> bool {
632 match (doc_value, search_value) {
633 (SochValue::Text(text), SochValue::Text(search)) => text.contains(search.as_str()),
634 (SochValue::Array(arr), _) => arr.contains(search_value),
635 _ => false,
636 }
637 }
638
639 fn value_in_set(&self, doc_value: &SochValue, set_value: &SochValue) -> bool {
641 if let SochValue::Array(arr) = set_value {
642 arr.contains(doc_value)
643 } else {
644 false
645 }
646 }
647
648 fn fuse_scores(
650 &self,
651 candidates: Vec<CandidateDoc>,
652 query: &HybridQuery,
653 ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
654 let vector_weight = query.vector.as_ref().map(|v| v.weight).unwrap_or(0.0);
655 let lexical_weight = query.lexical.as_ref().map(|l| l.weight).unwrap_or(0.0);
656
657 let mut fused = candidates;
658
659 match query.fusion.method {
660 FusionMethod::Rrf => {
661 for doc in &mut fused {
664 let mut score = 0.0;
665
666 if let Some(rank) = doc.vector_rank {
667 score += vector_weight / (query.fusion.rrf_k + rank as f32);
668 }
669
670 if let Some(rank) = doc.lexical_rank {
671 score += lexical_weight / (query.fusion.rrf_k + rank as f32);
672 }
673
674 doc.fused_score = score;
675 }
676 }
677
678 FusionMethod::WeightedSum => {
679 for doc in &mut fused {
681 let mut score = 0.0;
682
683 if let Some(s) = doc.vector_score {
684 score += vector_weight * s;
685 }
686
687 if let Some(s) = doc.lexical_score {
688 score += lexical_weight * s;
689 }
690
691 doc.fused_score = score;
692 }
693 }
694
695 FusionMethod::Max => {
696 for doc in &mut fused {
698 let v_score = doc.vector_score.map(|s| vector_weight * s).unwrap_or(0.0);
699 let l_score = doc.lexical_score.map(|s| lexical_weight * s).unwrap_or(0.0);
700 doc.fused_score = v_score.max(l_score);
701 }
702 }
703
704 FusionMethod::Rsf => {
705 for doc in &mut fused {
707 let mut score = 0.0;
708 let mut count = 0;
709
710 if let Some(s) = doc.vector_score {
711 score += s;
712 count += 1;
713 }
714
715 if let Some(s) = doc.lexical_score {
716 score += s;
717 count += 1;
718 }
719
720 doc.fused_score = if count > 0 { score / count as f32 } else { 0.0 };
721 }
722 }
723 }
724
725 Ok(fused)
726 }
727
728 fn rerank(
730 &self,
731 candidates: &[CandidateDoc],
732 query: &str,
733 config: &RerankConfig,
734 ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
735 let to_rerank: Vec<_> = candidates.iter().take(config.top_n).cloned().collect();
737
738 let mut reranked = to_rerank;
741 let query_terms: HashSet<&str> = query.split_whitespace().collect();
742
743 for doc in &mut reranked {
744 let content_terms: HashSet<&str> = doc.content.split_whitespace().collect();
745 let overlap = query_terms.intersection(&content_terms).count();
746
747 doc.fused_score += (overlap as f32) * 0.01;
749 }
750
751 reranked.extend(candidates.iter().skip(config.top_n).cloned());
753
754 Ok(reranked)
755 }
756}
757
/// Working record for one document while a hybrid query is being executed.
#[derive(Debug, Clone)]
struct CandidateDoc {
    /// Document id; candidates from both searches are merged on it.
    id: String,
    /// Document text as returned by whichever search found it first.
    content: String,
    /// Metadata from the vector search result; empty for documents found
    /// only by the lexical search.
    metadata: HashMap<String, SochValue>,
    /// Zero-based position in the vector search results, if present there.
    vector_rank: Option<usize>,
    /// Score reported by the vector search, if present there.
    vector_score: Option<f32>,
    /// Zero-based position in the lexical search results, if present there.
    lexical_rank: Option<usize>,
    /// Score reported by the lexical search, if present there.
    lexical_score: Option<f32>,
    /// Combined score; starts at `0.0` and is set by fusion (and adjusted by
    /// reranking).
    fused_score: f32,
}
770
/// In-memory lexical index holding one inverted index per named collection.
pub struct LexicalIndex {
    /// Collection name -> inverted index, guarded by a read/write lock.
    collections: std::sync::RwLock<HashMap<String, InvertedIndex>>,
}
780
/// Inverted index and BM25 statistics for a single collection.
struct InvertedIndex {
    /// Lowercased term -> list of `(document id, term frequency)` entries.
    postings: HashMap<String, Vec<(String, u32)>>,

    /// Document id -> number of whitespace-separated tokens in the document.
    doc_lengths: HashMap<String, u32>,

    /// Document id -> original document text.
    documents: HashMap<String, String>,

    /// Mean of `doc_lengths`, recomputed on every indexing call.
    avg_doc_len: f32,

    /// BM25 term-frequency saturation parameter (collections are created
    /// with 1.2).
    k1: f32,
    /// BM25 length-normalisation parameter (collections are created with
    /// 0.75).
    b: f32,
}
799
/// One hit returned by `LexicalIndex::search`.
#[derive(Debug, Clone)]
pub struct LexicalSearchResult {
    /// Id of the matching document.
    pub id: String,
    /// BM25 score accumulated over the query terms the document contains.
    pub score: f32,
    /// Original text of the document (empty if the text is missing).
    pub content: String,
}
807
808impl LexicalIndex {
809 pub fn new() -> Self {
811 Self {
812 collections: std::sync::RwLock::new(HashMap::new()),
813 }
814 }
815
816 pub fn create_collection(&self, name: &str) {
818 let mut collections = self.collections.write().unwrap();
819 collections.insert(name.to_string(), InvertedIndex {
820 postings: HashMap::new(),
821 doc_lengths: HashMap::new(),
822 documents: HashMap::new(),
823 avg_doc_len: 0.0,
824 k1: 1.2,
825 b: 0.75,
826 });
827 }
828
829 pub fn index_document(&self, collection: &str, id: &str, content: &str) -> Result<(), HybridQueryError> {
831 let mut collections = self.collections.write().unwrap();
832 let index = collections.get_mut(collection)
833 .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
834
835 let tokens: Vec<String> = content
837 .split_whitespace()
838 .map(|t| t.to_lowercase())
839 .collect();
840
841 let doc_len = tokens.len() as u32;
842
843 index.doc_lengths.insert(id.to_string(), doc_len);
845 index.documents.insert(id.to_string(), content.to_string());
846
847 let total_len: u32 = index.doc_lengths.values().sum();
849 index.avg_doc_len = total_len as f32 / index.doc_lengths.len() as f32;
850
851 let mut term_freqs: HashMap<String, u32> = HashMap::new();
853 for token in &tokens {
854 *term_freqs.entry(token.clone()).or_insert(0) += 1;
855 }
856
857 for (term, freq) in term_freqs {
859 index.postings
860 .entry(term)
861 .or_insert_with(Vec::new)
862 .push((id.to_string(), freq));
863 }
864
865 Ok(())
866 }
867
868 pub fn search(
870 &self,
871 collection: &str,
872 query: &str,
873 _fields: &[String],
874 limit: usize,
875 ) -> Result<Vec<LexicalSearchResult>, HybridQueryError> {
876 let collections = self.collections.read().unwrap();
877 let index = collections.get(collection)
878 .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
879
880 let query_terms: Vec<String> = query
882 .split_whitespace()
883 .map(|t| t.to_lowercase())
884 .collect();
885
886 let n = index.doc_lengths.len() as f32;
887 let mut scores: HashMap<String, f32> = HashMap::new();
888
889 for term in &query_terms {
891 if let Some(postings) = index.postings.get(term) {
892 let df = postings.len() as f32;
893 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
894
895 for (doc_id, tf) in postings {
896 let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
897 let tf = *tf as f32;
898
899 let score = idf * (tf * (index.k1 + 1.0)) /
901 (tf + index.k1 * (1.0 - index.b + index.b * doc_len / index.avg_doc_len));
902
903 *scores.entry(doc_id.clone()).or_insert(0.0) += score;
904 }
905 }
906 }
907
908 let mut results: Vec<_> = scores.into_iter().collect();
910 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
911
912 let results: Vec<LexicalSearchResult> = results
914 .into_iter()
915 .take(limit)
916 .map(|(id, score)| {
917 let content = index.documents.get(&id).cloned().unwrap_or_default();
918 LexicalSearchResult { id, score, content }
919 })
920 .collect();
921
922 Ok(results)
923 }
924}
925
926impl Default for LexicalIndex {
927 fn default() -> Self {
928 Self::new()
929 }
930}
931
/// One document returned by a hybrid query.
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
    /// Document id.
    pub id: String,
    /// Final fused score of the document.
    pub score: f32,
    /// Document text.
    pub content: String,
    /// Document metadata; empty for documents found only by the lexical
    /// search.
    pub metadata: HashMap<String, SochValue>,
    /// Raw vector search score, if the vector search returned the document.
    pub vector_score: Option<f32>,
    /// Raw lexical search score, if the lexical search returned the document.
    pub lexical_score: Option<f32>,
}
952
/// Outcome of executing a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct HybridQueryResult {
    /// Matching documents, in the order produced by the executor.
    pub results: Vec<HybridSearchResult>,
    /// Copy of the query that was executed.
    pub query: HybridQuery,
    /// Execution statistics.
    pub stats: HybridQueryStats,
}
963
/// Counters and timings describing one hybrid query execution.
///
/// NOTE(review): the field meanings below are inferred from the names;
/// confirm them against the executor that fills this struct in.
#[derive(Debug, Clone, Default)]
pub struct HybridQueryStats {
    /// Presumably the number of candidates produced by the vector search.
    pub vector_candidates: usize,
    /// Presumably the number of candidates produced by the lexical search.
    pub lexical_candidates: usize,
    /// Candidate count related to metadata filtering.
    pub filtered_candidates: usize,
    /// Presumably the time spent on score fusion, in microseconds.
    pub fusion_time_us: u64,
    /// Presumably the time spent on reranking, in microseconds.
    pub rerank_time_us: u64,
}
978
/// Errors reported by hybrid query execution and the lexical index.
///
/// Every variant carries a human-readable detail string that is included in
/// the `Display` output.
#[derive(Debug, Clone)]
pub enum HybridQueryError {
    /// The named collection does not exist in the lexical index.
    CollectionNotFound(String),
    /// The vector index reported a failure; carries its message.
    VectorSearchError(String),
    /// A lexical search failure.
    LexicalSearchError(String),
    /// A metadata filter failure.
    FilterError(String),
    /// A reranking failure.
    RerankError(String),
}
993
994impl std::fmt::Display for HybridQueryError {
995 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
996 match self {
997 Self::CollectionNotFound(name) => write!(f, "Collection not found: {}", name),
998 Self::VectorSearchError(msg) => write!(f, "Vector search error: {}", msg),
999 Self::LexicalSearchError(msg) => write!(f, "Lexical search error: {}", msg),
1000 Self::FilterError(msg) => write!(f, "Filter error: {}", msg),
1001 Self::RerankError(msg) => write!(f, "Rerank error: {}", msg),
1002 }
1003 }
1004}
1005
// Marker impl with the default methods: lets `HybridQueryError` be used
// wherever a `std::error::Error` is expected (e.g. `Box<dyn Error>`).
impl std::error::Error for HybridQueryError {}
1007
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records every component it is given.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("documents")
            .with_vector(vec![0.1, 0.2, 0.3], 0.7)
            .with_lexical("search query", 0.3)
            .filter_eq("category", SochValue::Text("tech".to_string()))
            .with_fusion(FusionMethod::Rrf)
            .with_rerank("cross-encoder", 20)
            .limit(10);

        assert_eq!(query.collection, "documents");
        assert!(query.vector.is_some());
        assert!(query.lexical.is_some());
        assert_eq!(query.filters.len(), 1);
        assert_eq!(query.limit, 10);
    }

    /// BM25 search returns every document containing a query term and none
    /// that contain no query term.
    #[test]
    fn test_lexical_index_bm25() {
        let index = LexicalIndex::new();
        index.create_collection("test");

        index.index_document("test", "doc1", "the quick brown fox").unwrap();
        index.index_document("test", "doc2", "the lazy dog sleeps").unwrap();
        index.index_document("test", "doc3", "quick fox jumps over the lazy dog").unwrap();

        let results = index.search("test", "quick fox", &[], 10).unwrap();

        assert!(!results.is_empty());
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();
        // Both doc1 and doc3 contain "quick" and "fox", so both must be
        // returned; doc2 contains neither.
        assert!(ids.contains(&"doc1") && ids.contains(&"doc3"));
        assert!(!ids.contains(&"doc2"));
    }

    /// Sanity check of the RRF arithmetic for a document ranked first by the
    /// vector search and sixth by the lexical search (ranks are zero-based).
    #[test]
    fn test_rrf_fusion() {
        let k = 60.0;

        let vector_weight = 0.7;
        let lexical_weight = 0.3;

        let score = vector_weight / (k + 0.0) + lexical_weight / (k + 5.0);

        assert!(score > 0.01 && score < 0.02);
    }

    /// A candidate whose metadata satisfies the filters carries every
    /// filtered field with a value in range.
    ///
    /// This checks the fixture data only; it does not run the executor's
    /// filter evaluation, which needs a `VectorIndex` implementation.
    #[test]
    fn test_filter_matching() {
        let filters = vec![
            MetadataFilter {
                field: "status".to_string(),
                op: FilterOp::Eq,
                value: SochValue::Text("active".to_string()),
            },
            MetadataFilter {
                field: "count".to_string(),
                op: FilterOp::Gte,
                value: SochValue::Int(10),
            },
        ];

        let mut metadata = HashMap::new();
        metadata.insert("status".to_string(), SochValue::Text("active".to_string()));
        metadata.insert("count".to_string(), SochValue::Int(15));

        let doc = CandidateDoc {
            id: "test".to_string(),
            content: "test content".to_string(),
            metadata,
            vector_rank: None,
            vector_score: None,
            lexical_rank: None,
            lexical_score: None,
            fused_score: 0.0,
        };

        // Every filtered field must exist on the document.
        for filter in &filters {
            assert!(doc.metadata.contains_key(&filter.field));
        }

        assert!(doc.metadata.get("status") == Some(&SochValue::Text("active".to_string())));
        // Fail (rather than silently pass) when `count` is missing or not an Int.
        assert!(matches!(
            doc.metadata.get("count"),
            Some(SochValue::Int(count)) if *count >= 10
        ));
    }
}