1use std::collections::{HashMap, HashSet};
93use std::cmp::Ordering;
94use std::sync::Arc;
95
96use crate::context_query::VectorIndex;
97use crate::soch_ql::SochValue;
98
/// A hybrid search request over a single collection.
///
/// A query may carry a vector component, a lexical component, metadata
/// filters, fusion settings and an optional rerank step. Values are normally
/// assembled with the builder methods on this type, starting from
/// [`HybridQuery::new`].
#[derive(Debug, Clone)]
pub struct HybridQuery {
    /// Name of the collection to search.
    pub collection: String,

    /// Optional vector (embedding) component of the query.
    pub vector: Option<VectorQueryComponent>,

    /// Optional lexical (keyword) component of the query.
    pub lexical: Option<LexicalQueryComponent>,

    /// Metadata filters; the builder methods append to this list.
    pub filters: Vec<MetadataFilter>,

    /// Settings that control how component scores are fused.
    pub fusion: FusionConfig,

    /// Optional reranking configuration; `None` means no rerank step.
    pub rerank: Option<RerankConfig>,

    /// Maximum number of results to return (`new` sets this to 10).
    pub limit: usize,

    /// Optional minimum score; `None` disables the threshold.
    pub min_score: Option<f32>,
}
130
131impl HybridQuery {
132 pub fn new(collection: &str) -> Self {
134 Self {
135 collection: collection.to_string(),
136 vector: None,
137 lexical: None,
138 filters: Vec::new(),
139 fusion: FusionConfig::default(),
140 rerank: None,
141 limit: 10,
142 min_score: None,
143 }
144 }
145
146 pub fn with_vector(mut self, embedding: Vec<f32>, weight: f32) -> Self {
148 self.vector = Some(VectorQueryComponent {
149 embedding,
150 weight,
151 ef_search: 100,
152 });
153 self
154 }
155
156 pub fn with_vector_text(mut self, text: String, weight: f32) -> Self {
158 self.vector = Some(VectorQueryComponent {
159 embedding: Vec::new(), weight,
161 ef_search: 100,
162 });
163 self.lexical = self.lexical.or(Some(LexicalQueryComponent {
165 query: text,
166 weight: 0.0, fields: vec!["content".to_string()],
168 }));
169 self
170 }
171
172 pub fn with_lexical(mut self, query: &str, weight: f32) -> Self {
174 self.lexical = Some(LexicalQueryComponent {
175 query: query.to_string(),
176 weight,
177 fields: vec!["content".to_string()],
178 });
179 self
180 }
181
182 pub fn with_lexical_fields(mut self, query: &str, weight: f32, fields: Vec<String>) -> Self {
184 self.lexical = Some(LexicalQueryComponent {
185 query: query.to_string(),
186 weight,
187 fields,
188 });
189 self
190 }
191
192 pub fn filter(mut self, field: &str, op: FilterOp, value: SochValue) -> Self {
194 self.filters.push(MetadataFilter {
195 field: field.to_string(),
196 op,
197 value,
198 });
199 self
200 }
201
202 pub fn filter_eq(self, field: &str, value: impl Into<SochValue>) -> Self {
204 self.filter(field, FilterOp::Eq, value.into())
205 }
206
207 pub fn filter_range(mut self, field: &str, min: Option<SochValue>, max: Option<SochValue>) -> Self {
209 if let Some(min_val) = min {
210 self.filters.push(MetadataFilter {
211 field: field.to_string(),
212 op: FilterOp::Gte,
213 value: min_val,
214 });
215 }
216 if let Some(max_val) = max {
217 self.filters.push(MetadataFilter {
218 field: field.to_string(),
219 op: FilterOp::Lte,
220 value: max_val,
221 });
222 }
223 self
224 }
225
226 pub fn with_fusion(mut self, method: FusionMethod) -> Self {
228 self.fusion.method = method;
229 self
230 }
231
232 pub fn with_rrf_k(mut self, k: f32) -> Self {
234 self.fusion.rrf_k = k;
235 self
236 }
237
238 pub fn with_rerank(mut self, model: &str, top_n: usize) -> Self {
240 self.rerank = Some(RerankConfig {
241 model: model.to_string(),
242 top_n,
243 batch_size: 32,
244 });
245 self
246 }
247
248 pub fn limit(mut self, limit: usize) -> Self {
250 self.limit = limit;
251 self
252 }
253
254 pub fn min_score(mut self, score: f32) -> Self {
256 self.min_score = Some(score);
257 self
258 }
259}
260
/// The vector (embedding similarity) part of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct VectorQueryComponent {
    /// Query embedding. `HybridQuery::with_vector_text` leaves this empty.
    pub embedding: Vec<f32>,
    /// Weight applied to this component during score fusion.
    pub weight: f32,
    /// Search breadth setting; the builders set it to 100.
    /// NOTE(review): presumably an HNSW-style `ef` parameter — confirm
    /// against the vector index implementation.
    pub ef_search: usize,
}
271
/// The lexical (keyword) part of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct LexicalQueryComponent {
    /// Raw query text.
    pub query: String,
    /// Weight applied to this component during score fusion.
    pub weight: f32,
    /// Document fields to search; the builders default to `["content"]`.
    pub fields: Vec<String>,
}
282
/// A single metadata predicate of the form `field <op> value`.
#[derive(Debug, Clone)]
pub struct MetadataFilter {
    /// Metadata key the predicate is evaluated against.
    pub field: String,
    /// Comparison operator.
    pub op: FilterOp,
    /// Right-hand operand of the comparison.
    pub value: SochValue,
}
293
/// Comparison operators usable in a [`MetadataFilter`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FilterOp {
    /// Document value equals the filter value.
    Eq,
    /// Document value differs from the filter value.
    Ne,
    /// Document value is strictly greater than the filter value.
    Gt,
    /// Document value is greater than or equal to the filter value.
    Gte,
    /// Document value is strictly less than the filter value.
    Lt,
    /// Document value is less than or equal to the filter value.
    Lte,
    /// Document text contains the filter text, or the document array
    /// contains the filter value.
    Contains,
    /// Document value is an element of the filter value, which must be an
    /// array.
    In,
}
314
/// Settings that control how vector and lexical scores are combined.
#[derive(Debug, Clone)]
pub struct FusionConfig {
    /// Fusion algorithm to apply.
    pub method: FusionMethod,
    /// The `k` constant added to the rank in reciprocal rank fusion.
    pub rrf_k: f32,
    /// Score-normalization flag.
    /// NOTE(review): no code visible in this file reads this flag — confirm
    /// whether normalization is applied elsewhere.
    pub normalize: bool,
}
325
326impl Default for FusionConfig {
327 fn default() -> Self {
328 Self {
329 method: FusionMethod::Rrf,
330 rrf_k: 60.0,
331 normalize: true,
332 }
333 }
334}
335
/// Algorithms for combining the vector and lexical result lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal rank fusion: each component contributes
    /// `weight / (rrf_k + rank)`, where `rank` is the zero-based position.
    Rrf,
    /// Sum of `weight * score` over the components that returned the document.
    WeightedSum,
    /// Larger of the two `weight * score` values.
    Max,
    /// Unweighted mean of the component scores that are present.
    Rsf,
}
348
/// Settings for the optional rerank step of a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct RerankConfig {
    /// Name of the reranking model.
    pub model: String,
    /// Number of leading candidates that are rescored.
    pub top_n: usize,
    /// Batch size; `HybridQuery::with_rerank` sets it to 32.
    /// NOTE(review): presumably the number of documents sent to the model
    /// per call — confirm against the reranker implementation.
    pub batch_size: usize,
}
359
/// A query together with an ordered list of execution steps and a cost
/// estimate.
///
/// NOTE(review): nothing in the code visible here builds or consumes a plan;
/// confirm how it relates to `HybridQueryExecutor::execute`.
#[derive(Debug, Clone)]
pub struct HybridExecutionPlan {
    /// The query this plan was derived from.
    pub query: HybridQuery,

    /// Steps of the plan, in order.
    pub steps: Vec<ExecutionStep>,

    /// Estimated cost of running the plan (unit not established here).
    pub estimated_cost: f64,
}
376
/// One step of a [`HybridExecutionPlan`].
///
/// The descriptions below follow the variant and field names only.
/// NOTE(review): no code visible here interprets these steps — confirm the
/// exact semantics against the planner.
#[derive(Debug, Clone)]
pub enum ExecutionStep {
    /// Vector similarity search over a collection.
    VectorSearch {
        collection: String,
        ef_search: usize,
        weight: f32,
    },

    /// Keyword search over selected fields of a collection.
    LexicalSearch {
        collection: String,
        query: String,
        fields: Vec<String>,
        weight: f32,
    },

    /// Metadata filtering.
    PreFilter {
        filters: Vec<MetadataFilter>,
    },

    /// Fusion of component scores.
    Fusion {
        method: FusionMethod,
        rrf_k: f32,
    },

    /// Reranking of the leading results.
    Rerank {
        model: String,
        top_n: usize,
    },

    /// Result count limit and optional score threshold.
    Limit {
        count: usize,
        min_score: Option<f32>,
    },

    /// Redaction of selected fields.
    Redact {
        fields: Vec<String>,
        method: RedactionMethod,
    },
}
438
/// How a field is redacted by an [`ExecutionStep::Redact`] step.
///
/// NOTE(review): variant meanings are taken from their names; no code
/// visible here applies them — confirm against the redaction implementation.
#[derive(Debug, Clone)]
pub enum RedactionMethod {
    /// Carries a replacement string.
    Replace(String),
    /// Mask the value.
    Mask,
    /// Remove the value.
    Remove,
    /// Hash the value.
    Hash,
}
451
452pub struct HybridQueryExecutor<V: VectorIndex> {
458 vector_index: Arc<V>,
460
461 lexical_index: Arc<LexicalIndex>,
463}
464
465impl<V: VectorIndex> HybridQueryExecutor<V> {
466 pub fn new(vector_index: Arc<V>, lexical_index: Arc<LexicalIndex>) -> Self {
468 Self {
469 vector_index,
470 lexical_index,
471 }
472 }
473
474 pub fn execute(&self, query: &HybridQuery) -> Result<HybridQueryResult, HybridQueryError> {
476 let mut candidates: HashMap<String, CandidateDoc> = HashMap::new();
477
478 let overfetch = (query.limit * 3).max(100);
480
481 if let Some(vector) = &query.vector {
483 if !vector.embedding.is_empty() {
484 let results = self.vector_index
485 .search_by_embedding(&query.collection, &vector.embedding, overfetch, None)
486 .map_err(HybridQueryError::VectorSearchError)?;
487
488 for (rank, result) in results.iter().enumerate() {
489 let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
490 CandidateDoc {
491 id: result.id.clone(),
492 content: result.content.clone(),
493 metadata: result.metadata.clone(),
494 vector_rank: None,
495 vector_score: None,
496 lexical_rank: None,
497 lexical_score: None,
498 fused_score: 0.0,
499 }
500 });
501 entry.vector_rank = Some(rank);
502 entry.vector_score = Some(result.score);
503 }
504 }
505 }
506
507 if let Some(lexical) = &query.lexical {
509 if lexical.weight > 0.0 {
510 let results = self.lexical_index.search(
511 &query.collection,
512 &lexical.query,
513 &lexical.fields,
514 overfetch,
515 )?;
516
517 for (rank, result) in results.iter().enumerate() {
518 let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
519 CandidateDoc {
520 id: result.id.clone(),
521 content: result.content.clone(),
522 metadata: HashMap::new(),
523 vector_rank: None,
524 vector_score: None,
525 lexical_rank: None,
526 lexical_score: None,
527 fused_score: 0.0,
528 }
529 });
530 entry.lexical_rank = Some(rank);
531 entry.lexical_score = Some(result.score);
532 }
533 }
534 }
535
536 let filtered: Vec<CandidateDoc> = candidates
538 .into_values()
539 .filter(|doc| self.matches_filters(doc, &query.filters))
540 .collect();
541
542 let mut fused = self.fuse_scores(filtered, query)?;
544
545 fused.sort_by(|a, b| b.fused_score.partial_cmp(&a.fused_score).unwrap_or(Ordering::Equal));
547
548 if let Some(rerank) = &query.rerank {
550 fused = self.rerank(&fused, &query.lexical.as_ref().map(|l| l.query.clone()).unwrap_or_default(), rerank)?;
551 }
552
553 if let Some(min) = query.min_score {
555 fused.retain(|doc| doc.fused_score >= min);
556 }
557
558 fused.truncate(query.limit);
560
561 let results: Vec<HybridSearchResult> = fused
563 .into_iter()
564 .map(|doc| HybridSearchResult {
565 id: doc.id,
566 score: doc.fused_score,
567 content: doc.content,
568 metadata: doc.metadata,
569 vector_score: doc.vector_score,
570 lexical_score: doc.lexical_score,
571 })
572 .collect();
573
574 Ok(HybridQueryResult {
575 results,
576 query: query.clone(),
577 stats: HybridQueryStats {
578 vector_candidates: 0, lexical_candidates: 0,
580 filtered_candidates: 0,
581 fusion_time_us: 0,
582 rerank_time_us: 0,
583 },
584 })
585 }
586
587 fn matches_filters(&self, doc: &CandidateDoc, filters: &[MetadataFilter]) -> bool {
589 for filter in filters {
590 if let Some(value) = doc.metadata.get(&filter.field) {
591 if !self.match_filter(value, &filter.op, &filter.value) {
592 return false;
593 }
594 } else {
595 return false;
597 }
598 }
599 true
600 }
601
602 fn match_filter(&self, doc_value: &SochValue, op: &FilterOp, filter_value: &SochValue) -> bool {
604 match op {
605 FilterOp::Eq => doc_value == filter_value,
606 FilterOp::Ne => doc_value != filter_value,
607 FilterOp::Gt => self.compare_values(doc_value, filter_value) == Some(Ordering::Greater),
608 FilterOp::Gte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Greater | Ordering::Equal)),
609 FilterOp::Lt => self.compare_values(doc_value, filter_value) == Some(Ordering::Less),
610 FilterOp::Lte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Less | Ordering::Equal)),
611 FilterOp::Contains => self.value_contains(doc_value, filter_value),
612 FilterOp::In => self.value_in_set(doc_value, filter_value),
613 }
614 }
615
616 fn compare_values(&self, a: &SochValue, b: &SochValue) -> Option<Ordering> {
618 match (a, b) {
619 (SochValue::Int(a), SochValue::Int(b)) => Some(a.cmp(b)),
620 (SochValue::UInt(a), SochValue::UInt(b)) => Some(a.cmp(b)),
621 (SochValue::Float(a), SochValue::Float(b)) => a.partial_cmp(b),
622 (SochValue::Text(a), SochValue::Text(b)) => Some(a.cmp(b)),
623 _ => None,
624 }
625 }
626
627 fn value_contains(&self, doc_value: &SochValue, search_value: &SochValue) -> bool {
629 match (doc_value, search_value) {
630 (SochValue::Text(text), SochValue::Text(search)) => text.contains(search.as_str()),
631 (SochValue::Array(arr), _) => arr.contains(search_value),
632 _ => false,
633 }
634 }
635
636 fn value_in_set(&self, doc_value: &SochValue, set_value: &SochValue) -> bool {
638 if let SochValue::Array(arr) = set_value {
639 arr.contains(doc_value)
640 } else {
641 false
642 }
643 }
644
645 fn fuse_scores(
647 &self,
648 candidates: Vec<CandidateDoc>,
649 query: &HybridQuery,
650 ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
651 let vector_weight = query.vector.as_ref().map(|v| v.weight).unwrap_or(0.0);
652 let lexical_weight = query.lexical.as_ref().map(|l| l.weight).unwrap_or(0.0);
653
654 let mut fused = candidates;
655
656 match query.fusion.method {
657 FusionMethod::Rrf => {
658 for doc in &mut fused {
661 let mut score = 0.0;
662
663 if let Some(rank) = doc.vector_rank {
664 score += vector_weight / (query.fusion.rrf_k + rank as f32);
665 }
666
667 if let Some(rank) = doc.lexical_rank {
668 score += lexical_weight / (query.fusion.rrf_k + rank as f32);
669 }
670
671 doc.fused_score = score;
672 }
673 }
674
675 FusionMethod::WeightedSum => {
676 for doc in &mut fused {
678 let mut score = 0.0;
679
680 if let Some(s) = doc.vector_score {
681 score += vector_weight * s;
682 }
683
684 if let Some(s) = doc.lexical_score {
685 score += lexical_weight * s;
686 }
687
688 doc.fused_score = score;
689 }
690 }
691
692 FusionMethod::Max => {
693 for doc in &mut fused {
695 let v_score = doc.vector_score.map(|s| vector_weight * s).unwrap_or(0.0);
696 let l_score = doc.lexical_score.map(|s| lexical_weight * s).unwrap_or(0.0);
697 doc.fused_score = v_score.max(l_score);
698 }
699 }
700
701 FusionMethod::Rsf => {
702 for doc in &mut fused {
704 let mut score = 0.0;
705 let mut count = 0;
706
707 if let Some(s) = doc.vector_score {
708 score += s;
709 count += 1;
710 }
711
712 if let Some(s) = doc.lexical_score {
713 score += s;
714 count += 1;
715 }
716
717 doc.fused_score = if count > 0 { score / count as f32 } else { 0.0 };
718 }
719 }
720 }
721
722 Ok(fused)
723 }
724
725 fn rerank(
727 &self,
728 candidates: &[CandidateDoc],
729 query: &str,
730 config: &RerankConfig,
731 ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
732 let to_rerank: Vec<_> = candidates.iter().take(config.top_n).cloned().collect();
734
735 let mut reranked = to_rerank;
738 let query_terms: HashSet<&str> = query.split_whitespace().collect();
739
740 for doc in &mut reranked {
741 let content_terms: HashSet<&str> = doc.content.split_whitespace().collect();
742 let overlap = query_terms.intersection(&content_terms).count();
743
744 doc.fused_score += (overlap as f32) * 0.01;
746 }
747
748 reranked.extend(candidates.iter().skip(config.top_n).cloned());
750
751 Ok(reranked)
752 }
753}
754
/// Working record for one document while a hybrid query is executed.
#[derive(Debug, Clone)]
struct CandidateDoc {
    /// Document id; candidates from both searches are merged on it.
    id: String,
    /// Document text.
    content: String,
    /// Metadata the filters are evaluated against.
    metadata: HashMap<String, SochValue>,
    /// Zero-based position in the vector results, if the document was found
    /// there.
    vector_rank: Option<usize>,
    /// Score reported by the vector search, if any.
    vector_score: Option<f32>,
    /// Zero-based position in the lexical results, if the document was found
    /// there.
    lexical_rank: Option<usize>,
    /// Score reported by the lexical search, if any.
    lexical_score: Option<f32>,
    /// Combined score; starts at `0.0` and is set during fusion.
    fused_score: f32,
}
767
768pub struct LexicalIndex {
774 collections: std::sync::RwLock<HashMap<String, InvertedIndex>>,
776}
777
778struct InvertedIndex {
780 postings: HashMap<String, Vec<(String, u32)>>,
782
783 doc_lengths: HashMap<String, u32>,
785
786 documents: HashMap<String, String>,
788
789 avg_doc_len: f32,
791
792 k1: f32,
794 b: f32,
795}
796
797#[derive(Debug, Clone)]
799pub struct LexicalSearchResult {
800 pub id: String,
801 pub score: f32,
802 pub content: String,
803}
804
805impl LexicalIndex {
806 pub fn new() -> Self {
808 Self {
809 collections: std::sync::RwLock::new(HashMap::new()),
810 }
811 }
812
813 pub fn create_collection(&self, name: &str) {
815 let mut collections = self.collections.write().unwrap();
816 collections.insert(name.to_string(), InvertedIndex {
817 postings: HashMap::new(),
818 doc_lengths: HashMap::new(),
819 documents: HashMap::new(),
820 avg_doc_len: 0.0,
821 k1: 1.2,
822 b: 0.75,
823 });
824 }
825
826 pub fn index_document(&self, collection: &str, id: &str, content: &str) -> Result<(), HybridQueryError> {
828 let mut collections = self.collections.write().unwrap();
829 let index = collections.get_mut(collection)
830 .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
831
832 let tokens: Vec<String> = content
834 .split_whitespace()
835 .map(|t| t.to_lowercase())
836 .collect();
837
838 let doc_len = tokens.len() as u32;
839
840 index.doc_lengths.insert(id.to_string(), doc_len);
842 index.documents.insert(id.to_string(), content.to_string());
843
844 let total_len: u32 = index.doc_lengths.values().sum();
846 index.avg_doc_len = total_len as f32 / index.doc_lengths.len() as f32;
847
848 let mut term_freqs: HashMap<String, u32> = HashMap::new();
850 for token in &tokens {
851 *term_freqs.entry(token.clone()).or_insert(0) += 1;
852 }
853
854 for (term, freq) in term_freqs {
856 index.postings
857 .entry(term)
858 .or_insert_with(Vec::new)
859 .push((id.to_string(), freq));
860 }
861
862 Ok(())
863 }
864
865 pub fn search(
867 &self,
868 collection: &str,
869 query: &str,
870 _fields: &[String],
871 limit: usize,
872 ) -> Result<Vec<LexicalSearchResult>, HybridQueryError> {
873 let collections = self.collections.read().unwrap();
874 let index = collections.get(collection)
875 .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
876
877 let query_terms: Vec<String> = query
879 .split_whitespace()
880 .map(|t| t.to_lowercase())
881 .collect();
882
883 let n = index.doc_lengths.len() as f32;
884 let mut scores: HashMap<String, f32> = HashMap::new();
885
886 for term in &query_terms {
888 if let Some(postings) = index.postings.get(term) {
889 let df = postings.len() as f32;
890 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
891
892 for (doc_id, tf) in postings {
893 let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
894 let tf = *tf as f32;
895
896 let score = idf * (tf * (index.k1 + 1.0)) /
898 (tf + index.k1 * (1.0 - index.b + index.b * doc_len / index.avg_doc_len));
899
900 *scores.entry(doc_id.clone()).or_insert(0.0) += score;
901 }
902 }
903 }
904
905 let mut results: Vec<_> = scores.into_iter().collect();
907 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
908
909 let results: Vec<LexicalSearchResult> = results
911 .into_iter()
912 .take(limit)
913 .map(|(id, score)| {
914 let content = index.documents.get(&id).cloned().unwrap_or_default();
915 LexicalSearchResult { id, score, content }
916 })
917 .collect();
918
919 Ok(results)
920 }
921}
922
923impl Default for LexicalIndex {
924 fn default() -> Self {
925 Self::new()
926 }
927}
928
/// One document returned by a hybrid query.
#[derive(Debug, Clone)]
pub struct HybridSearchResult {
    /// Document id.
    pub id: String,
    /// Final (fused) score of the document.
    pub score: f32,
    /// Document text.
    pub content: String,
    /// Document metadata.
    pub metadata: HashMap<String, SochValue>,
    /// Score from the vector search, if the document was found there.
    pub vector_score: Option<f32>,
    /// Score from the lexical search, if the document was found there.
    pub lexical_score: Option<f32>,
}
949
/// Outcome of executing a [`HybridQuery`].
#[derive(Debug, Clone)]
pub struct HybridQueryResult {
    /// Ranked results, best first.
    pub results: Vec<HybridSearchResult>,
    /// Copy of the query that produced these results.
    pub query: HybridQuery,
    /// Execution statistics.
    pub stats: HybridQueryStats,
}
960
/// Counters and timings describing one hybrid query execution.
///
/// NOTE(review): the field descriptions follow the field names; confirm the
/// exact meaning against the executor that fills them in.
#[derive(Debug, Clone, Default)]
pub struct HybridQueryStats {
    /// Number of candidates from the vector search.
    pub vector_candidates: usize,
    /// Number of candidates from the lexical search.
    pub lexical_candidates: usize,
    /// Candidate count associated with metadata filtering.
    pub filtered_candidates: usize,
    /// Time spent on score fusion, in microseconds.
    pub fusion_time_us: u64,
    /// Time spent on reranking, in microseconds.
    pub rerank_time_us: u64,
}
975
/// Errors produced while building indexes or executing hybrid queries.
///
/// Every variant carries a human-readable detail string; for
/// `CollectionNotFound` it is the collection name.
#[derive(Debug, Clone)]
pub enum HybridQueryError {
    /// The named collection does not exist.
    CollectionNotFound(String),
    /// The vector search failed.
    VectorSearchError(String),
    /// The lexical search failed.
    LexicalSearchError(String),
    /// A metadata filter could not be evaluated.
    FilterError(String),
    /// The rerank step failed.
    RerankError(String),
}

impl std::fmt::Display for HybridQueryError {
    /// Writes a short category label followed by the detail string.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            HybridQueryError::CollectionNotFound(name) => {
                f.write_fmt(format_args!("Collection not found: {}", name))
            }
            HybridQueryError::VectorSearchError(msg) => {
                f.write_fmt(format_args!("Vector search error: {}", msg))
            }
            HybridQueryError::LexicalSearchError(msg) => {
                f.write_fmt(format_args!("Lexical search error: {}", msg))
            }
            HybridQueryError::FilterError(msg) => {
                f.write_fmt(format_args!("Filter error: {}", msg))
            }
            HybridQueryError::RerankError(msg) => {
                f.write_fmt(format_args!("Rerank error: {}", msg))
            }
        }
    }
}

impl std::error::Error for HybridQueryError {}
1004
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records every component it is given.
    #[test]
    fn test_hybrid_query_builder() {
        let built = HybridQuery::new("documents")
            .with_vector(vec![0.1, 0.2, 0.3], 0.7)
            .with_lexical("search query", 0.3)
            .filter_eq("category", SochValue::Text("tech".to_string()))
            .with_fusion(FusionMethod::Rrf)
            .with_rerank("cross-encoder", 20)
            .limit(10);

        assert_eq!(built.collection, "documents");
        assert!(built.vector.is_some());
        assert!(built.lexical.is_some());
        assert_eq!(built.filters.len(), 1);
        assert_eq!(built.limit, 10);
    }

    /// Only documents sharing a term with the query are returned.
    #[test]
    fn test_lexical_index_bm25() {
        let lexical = LexicalIndex::new();
        lexical.create_collection("test");

        let corpus = [
            ("doc1", "the quick brown fox"),
            ("doc2", "the lazy dog sleeps"),
            ("doc3", "quick fox jumps over the lazy dog"),
        ];
        for (id, text) in corpus.iter() {
            lexical.index_document("test", id, text).unwrap();
        }

        let hits = lexical.search("test", "quick fox", &[], 10).unwrap();

        assert!(!hits.is_empty());
        let ids: Vec<&str> = hits.iter().map(|hit| hit.id.as_str()).collect();
        assert!(ids.iter().any(|id| *id == "doc1" || *id == "doc3"));
        assert!(ids.iter().all(|id| *id != "doc2"));
    }

    /// Sanity check of the RRF formula for vector rank 0 and lexical rank 5.
    #[test]
    fn test_rrf_fusion() {
        let k = 60.0;

        let vector_weight = 0.7;
        let lexical_weight = 0.3;

        let vector_part = vector_weight / (k + 0.0);
        let lexical_part = lexical_weight / (k + 5.0);
        let score = vector_part + lexical_part;

        assert!(score > 0.01 && score < 0.02);
    }

    /// Checks the metadata a pair of filters would be evaluated against.
    ///
    /// The filter list is only constructed here; the assertions inspect the
    /// candidate's metadata directly.
    #[test]
    fn test_filter_matching() {
        let filters = vec![
            MetadataFilter {
                field: "status".to_string(),
                op: FilterOp::Eq,
                value: SochValue::Text("active".to_string()),
            },
            MetadataFilter {
                field: "count".to_string(),
                op: FilterOp::Gte,
                value: SochValue::Int(10),
            },
        ];

        let metadata: HashMap<String, SochValue> = vec![
            ("status".to_string(), SochValue::Text("active".to_string())),
            ("count".to_string(), SochValue::Int(15)),
        ]
        .into_iter()
        .collect();

        let doc = CandidateDoc {
            id: "test".to_string(),
            content: "test content".to_string(),
            metadata,
            vector_rank: None,
            vector_score: None,
            lexical_rank: None,
            lexical_score: None,
            fused_score: 0.0,
        };

        let expected_status = SochValue::Text("active".to_string());
        assert!(doc.metadata.get("status") == Some(&expected_status));

        match doc.metadata.get("count") {
            Some(SochValue::Int(count)) => assert!(*count >= 10),
            _ => {}
        }
    }
}