sochdb_query/
hybrid_retrieval.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Hybrid Retrieval Pipeline (Task 3)
16//!
17//! This module implements a unified hybrid query planner combining:
18//! - Vector similarity search (ANN)
19//! - Lexical search (BM25)
20//! - Metadata filtering (PRE-FILTER ONLY)
21//! - Score fusion (RRF)
22//! - Cross-encoder reranking
23//!
24//! ## CRITICAL INVARIANT: No Post-Filtering
25//!
26//! This module enforces a hard security invariant:
27//! 
28//! > **All filtering MUST occur during candidate generation, never after.**
29//!
30//! The `ExecutionStep::PostFilter` variant has been intentionally removed.
31//! This guarantees:
32//! 1. **Security by construction**: No leakage of filtered documents
33//! 2. **No wasted compute**: We never score disallowed documents
34//! 3. **Monotone property**: `result-set ⊆ allowed-set` (verifiable)
35//!
36//! ## Correct Pattern
37//!
38//! ```text
39//! FilterIR + AuthScope → AllowedSet (computed once)
40//!     ↓
41//! vector_search(query, AllowedSet) → filtered candidates
42//! bm25_search(query, AllowedSet)   → filtered candidates
43//!     ↓
44//! fusion(filtered_v, filtered_b)   → already correct!
45//!     ↓
46//! rerank, limit                    → final results
47//! ```
48//!
49//! ## Anti-Pattern (What We Prevent)
50//!
51//! ```text
52//! BAD: vector_search() → candidates → filter → too few/leaky
53//!      bm25_search()   → candidates → filter → inconsistent
54//!      fusion()        → filter at end → SECURITY RISK!
55//! ```
56//!
57//! ## Execution Plan
58//!
59//! ```text
60//! HybridQuery
61//!     │
62//!     ▼
//! ┌──────────────────────────────────────────┐
//! │               ExecutionPlan              │
//! │             ┌──────────────┐             │
//! │             │  PreFilter   │             │
//! │             │ (AllowedSet) │             │
//! │             └──────┬───────┘             │
//! │           ┌────────┴────────┐            │
//! │           ▼                 ▼            │
//! │      ┌─────────┐       ┌─────────┐       │
//! │      │ Vector  │       │  BM25   │       │
//! │      │ Search  │       │ Search  │       │
//! │      └────┬────┘       └────┬────┘       │
//! │           └────────┬────────┘            │
//! │                    ▼                     │
//! │               ┌─────────┐                │
//! │               │ Fusion  │                │
//! │               │  (RRF)  │                │
//! │               └────┬────┘                │
//! │                    ▼                     │
//! │               ┌─────────┐                │
//! │               │ Rerank  │                │
//! │               └────┬────┘                │
//! │                    ▼                     │
//! │               ┌─────────┐                │
//! │               │  Limit  │                │
//! │               └─────────┘                │
//! └──────────────────────────────────────────┘
85//! ```
86//!
87//! ## Scoring
88//!
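//! RRF fusion: `score(d) = Σ w_i / (k + rank_i(d))`, where `w_i` is the weight
//! of source `i`, `rank_i(d)` is the position of `d` in that source's result
//! list (a source that did not return `d` contributes nothing), and `k`
//! (default 60, a robust choice) damps the advantage of the very top ranks.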
91
92use std::collections::{HashMap, HashSet};
93use std::cmp::Ordering;
94use std::sync::Arc;
95
96use crate::context_query::VectorIndex;
97use crate::soch_ql::SochValue;
98
99// ============================================================================
100// Hybrid Query Builder
101// ============================================================================
102
103/// Builder for hybrid retrieval queries
104#[derive(Debug, Clone)]
105pub struct HybridQuery {
106    /// Collection to search
107    pub collection: String,
108    
109    /// Vector search component
110    pub vector: Option<VectorQueryComponent>,
111    
112    /// Lexical (BM25) search component
113    pub lexical: Option<LexicalQueryComponent>,
114    
115    /// Metadata filters
116    pub filters: Vec<MetadataFilter>,
117    
118    /// Fusion configuration
119    pub fusion: FusionConfig,
120    
121    /// Reranking configuration
122    pub rerank: Option<RerankConfig>,
123    
124    /// Result limit
125    pub limit: usize,
126    
127    /// Minimum score threshold
128    pub min_score: Option<f32>,
129}
130
131impl HybridQuery {
132    /// Create a new hybrid query builder
133    pub fn new(collection: &str) -> Self {
134        Self {
135            collection: collection.to_string(),
136            vector: None,
137            lexical: None,
138            filters: Vec::new(),
139            fusion: FusionConfig::default(),
140            rerank: None,
141            limit: 10,
142            min_score: None,
143        }
144    }
145    
146    /// Add vector search component
147    pub fn with_vector(mut self, embedding: Vec<f32>, weight: f32) -> Self {
148        self.vector = Some(VectorQueryComponent {
149            embedding,
150            weight,
151            ef_search: 100,
152        });
153        self
154    }
155    
156    /// Add vector search from text (requires embedding provider)
157    pub fn with_vector_text(mut self, text: String, weight: f32) -> Self {
158        self.vector = Some(VectorQueryComponent {
159            embedding: Vec::new(), // Will be resolved at execution time
160            weight,
161            ef_search: 100,
162        });
163        // Store text for later resolution
164        self.lexical = self.lexical.or(Some(LexicalQueryComponent {
165            query: text,
166            weight: 0.0, // Text stored but not used for lexical
167            fields: vec!["content".to_string()],
168        }));
169        self
170    }
171    
172    /// Add lexical (BM25) search component
173    pub fn with_lexical(mut self, query: &str, weight: f32) -> Self {
174        self.lexical = Some(LexicalQueryComponent {
175            query: query.to_string(),
176            weight,
177            fields: vec!["content".to_string()],
178        });
179        self
180    }
181    
182    /// Add lexical search with specific fields
183    pub fn with_lexical_fields(mut self, query: &str, weight: f32, fields: Vec<String>) -> Self {
184        self.lexical = Some(LexicalQueryComponent {
185            query: query.to_string(),
186            weight,
187            fields,
188        });
189        self
190    }
191    
192    /// Add metadata filter
193    pub fn filter(mut self, field: &str, op: FilterOp, value: SochValue) -> Self {
194        self.filters.push(MetadataFilter {
195            field: field.to_string(),
196            op,
197            value,
198        });
199        self
200    }
201    
202    /// Add equality filter
203    pub fn filter_eq(self, field: &str, value: impl Into<SochValue>) -> Self {
204        self.filter(field, FilterOp::Eq, value.into())
205    }
206    
207    /// Add range filter
208    pub fn filter_range(mut self, field: &str, min: Option<SochValue>, max: Option<SochValue>) -> Self {
209        if let Some(min_val) = min {
210            self.filters.push(MetadataFilter {
211                field: field.to_string(),
212                op: FilterOp::Gte,
213                value: min_val,
214            });
215        }
216        if let Some(max_val) = max {
217            self.filters.push(MetadataFilter {
218                field: field.to_string(),
219                op: FilterOp::Lte,
220                value: max_val,
221            });
222        }
223        self
224    }
225    
226    /// Set fusion method
227    pub fn with_fusion(mut self, method: FusionMethod) -> Self {
228        self.fusion.method = method;
229        self
230    }
231    
232    /// Set RRF k parameter
233    pub fn with_rrf_k(mut self, k: f32) -> Self {
234        self.fusion.rrf_k = k;
235        self
236    }
237    
238    /// Enable reranking
239    pub fn with_rerank(mut self, model: &str, top_n: usize) -> Self {
240        self.rerank = Some(RerankConfig {
241            model: model.to_string(),
242            top_n,
243            batch_size: 32,
244        });
245        self
246    }
247    
248    /// Set result limit
249    pub fn limit(mut self, limit: usize) -> Self {
250        self.limit = limit;
251        self
252    }
253    
254    /// Set minimum score threshold
255    pub fn min_score(mut self, score: f32) -> Self {
256        self.min_score = Some(score);
257        self
258    }
259}
260
/// Vector-similarity leg of a [`HybridQuery`].
#[derive(Clone, Debug)]
pub struct VectorQueryComponent {
    /// The query embedding. When it is empty, the executor skips the vector leg.
    pub embedding: Vec<f32>,
    /// Weight of this leg during fusion.
    pub weight: f32,
    /// HNSW search breadth (`ef_search`).
    pub ef_search: usize,
}
271
/// Lexical (BM25) leg of a [`HybridQuery`].
#[derive(Clone, Debug)]
pub struct LexicalQueryComponent {
    /// Free-text query.
    pub query: String,
    /// Weight of this leg during fusion.
    pub weight: f32,
    /// Names of the fields to search.
    pub fields: Vec<String>,
}
282
283/// Metadata filter
284#[derive(Debug, Clone)]
285pub struct MetadataFilter {
286    /// Field name
287    pub field: String,
288    /// Comparison operator
289    pub op: FilterOp,
290    /// Value to compare
291    pub value: SochValue,
292}
293
/// Comparison operators usable in a [`MetadataFilter`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FilterOp {
    /// `==`
    Eq,
    /// `!=`
    Ne,
    /// `>`
    Gt,
    /// `>=`
    Gte,
    /// `<`
    Lt,
    /// `<=`
    Lte,
    /// Substring (for text) or element (for arrays) containment.
    Contains,
    /// Membership of the document value in the filter's set.
    In,
}
314
315/// Fusion configuration
316#[derive(Debug, Clone)]
317pub struct FusionConfig {
318    /// Fusion method
319    pub method: FusionMethod,
320    /// RRF k parameter (default: 60)
321    pub rrf_k: f32,
322    /// Normalize scores before fusion
323    pub normalize: bool,
324}
325
326impl Default for FusionConfig {
327    fn default() -> Self {
328        Self {
329            method: FusionMethod::Rrf,
330            rrf_k: 60.0,
331            normalize: true,
332        }
333    }
334}
335
/// Algorithms for combining vector and lexical results.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum FusionMethod {
    /// Reciprocal Rank Fusion (rank based).
    Rrf,
    /// Weighted sum of per-source scores.
    WeightedSum,
    /// Largest weighted per-source score.
    Max,
    /// Relative Score Fusion.
    Rsf,
}
348
/// Settings for the (cross-encoder) reranking stage.
#[derive(Clone, Debug)]
pub struct RerankConfig {
    /// Identifier of the reranker model.
    pub model: String,
    /// How many of the best candidates get reranked.
    pub top_n: usize,
    /// Number of candidates sent to the model per batch.
    pub batch_size: usize,
}
359
360// ============================================================================
361// Execution Plan
362// ============================================================================
363
364/// Execution plan for hybrid query
365#[derive(Debug, Clone)]
366pub struct HybridExecutionPlan {
367    /// Query being executed
368    pub query: HybridQuery,
369    
370    /// Execution steps
371    pub steps: Vec<ExecutionStep>,
372    
373    /// Estimated cost
374    pub estimated_cost: f64,
375}
376
377/// Individual execution step
378#[derive(Debug, Clone)]
379pub enum ExecutionStep {
380    /// Vector similarity search
381    VectorSearch {
382        collection: String,
383        ef_search: usize,
384        weight: f32,
385    },
386    
387    /// Lexical (BM25) search
388    LexicalSearch {
389        collection: String,
390        query: String,
391        fields: Vec<String>,
392        weight: f32,
393    },
394    
395    /// Pre-filter (before retrieval) - REQUIRED for security
396    /// 
397    /// This is the ONLY allowed filter step. Filters are always applied
398    /// during candidate generation via AllowedSet, never after.
399    PreFilter {
400        filters: Vec<MetadataFilter>,
401    },
402    
403    // NOTE: PostFilter has been REMOVED by design.
404    // The "no post-filtering" invariant is a hard security requirement.
405    // All filtering must happen via PreFilter -> AllowedSet -> candidate generation.
406    // See unified_fusion.rs for the correct pattern.
407    
408    /// Score fusion
409    Fusion {
410        method: FusionMethod,
411        rrf_k: f32,
412    },
413    
414    /// Reranking (does NOT filter, only re-orders)
415    Rerank {
416        model: String,
417        top_n: usize,
418    },
419    
420    /// Limit results (applied AFTER all filtering is complete)
421    Limit {
422        count: usize,
423        min_score: Option<f32>,
424    },
425    
426    /// Redaction transform (post-retrieval modification, NOT filtering)
427    /// 
428    /// Unlike filtering (which removes candidates), redaction transforms
429    /// the content of already-allowed documents. This preserves the
430    /// invariant: result-set ⊆ allowed-set.
431    Redact {
432        /// Fields to redact
433        fields: Vec<String>,
434        /// Redaction method
435        method: RedactionMethod,
436    },
437}
438
/// Ways to transform a field's content after retrieval.
#[derive(Clone, Debug)]
pub enum RedactionMethod {
    /// Substitute the given fixed string.
    Replace(String),
    /// Mask the value with asterisks.
    Mask,
    /// Drop the field altogether.
    Remove,
    /// Replace the value with its hash.
    Hash,
}
451
452// ============================================================================
453// Hybrid Query Executor
454// ============================================================================
455
456/// Executor for hybrid queries
457pub struct HybridQueryExecutor<V: VectorIndex> {
458    /// Vector index
459    vector_index: Arc<V>,
460    
461    /// Lexical index (BM25)
462    lexical_index: Arc<LexicalIndex>,
463}
464
465impl<V: VectorIndex> HybridQueryExecutor<V> {
466    /// Create a new executor
467    pub fn new(vector_index: Arc<V>, lexical_index: Arc<LexicalIndex>) -> Self {
468        Self {
469            vector_index,
470            lexical_index,
471        }
472    }
473    
474    /// Execute a hybrid query
475    pub fn execute(&self, query: &HybridQuery) -> Result<HybridQueryResult, HybridQueryError> {
476        let mut candidates: HashMap<String, CandidateDoc> = HashMap::new();
477        
478        // Over-fetch factor for fusion
479        let overfetch = (query.limit * 3).max(100);
480        
481        // Execute vector search
482        if let Some(vector) = &query.vector {
483            if !vector.embedding.is_empty() {
484                let results = self.vector_index
485                    .search_by_embedding(&query.collection, &vector.embedding, overfetch, None)
486                    .map_err(HybridQueryError::VectorSearchError)?;
487                
488                for (rank, result) in results.iter().enumerate() {
489                    let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
490                        CandidateDoc {
491                            id: result.id.clone(),
492                            content: result.content.clone(),
493                            metadata: result.metadata.clone(),
494                            vector_rank: None,
495                            vector_score: None,
496                            lexical_rank: None,
497                            lexical_score: None,
498                            fused_score: 0.0,
499                        }
500                    });
501                    entry.vector_rank = Some(rank);
502                    entry.vector_score = Some(result.score);
503                }
504            }
505        }
506        
507        // Execute lexical search
508        if let Some(lexical) = &query.lexical {
509            if lexical.weight > 0.0 {
510                let results = self.lexical_index.search(
511                    &query.collection,
512                    &lexical.query,
513                    &lexical.fields,
514                    overfetch,
515                )?;
516                
517                for (rank, result) in results.iter().enumerate() {
518                    let entry = candidates.entry(result.id.clone()).or_insert_with(|| {
519                        CandidateDoc {
520                            id: result.id.clone(),
521                            content: result.content.clone(),
522                            metadata: HashMap::new(),
523                            vector_rank: None,
524                            vector_score: None,
525                            lexical_rank: None,
526                            lexical_score: None,
527                            fused_score: 0.0,
528                        }
529                    });
530                    entry.lexical_rank = Some(rank);
531                    entry.lexical_score = Some(result.score);
532                }
533            }
534        }
535        
536        // Apply filters
537        let filtered: Vec<CandidateDoc> = candidates
538            .into_values()
539            .filter(|doc| self.matches_filters(doc, &query.filters))
540            .collect();
541        
542        // Fuse scores
543        let mut fused = self.fuse_scores(filtered, query)?;
544        
545        // Sort by fused score (descending)
546        fused.sort_by(|a, b| b.fused_score.partial_cmp(&a.fused_score).unwrap_or(Ordering::Equal));
547        
548        // Apply reranking (stub - would call reranker model)
549        if let Some(rerank) = &query.rerank {
550            fused = self.rerank(&fused, &query.lexical.as_ref().map(|l| l.query.clone()).unwrap_or_default(), rerank)?;
551        }
552        
553        // Apply min_score filter
554        if let Some(min) = query.min_score {
555            fused.retain(|doc| doc.fused_score >= min);
556        }
557        
558        // Limit results
559        fused.truncate(query.limit);
560        
561        // Convert to results
562        let results: Vec<HybridSearchResult> = fused
563            .into_iter()
564            .map(|doc| HybridSearchResult {
565                id: doc.id,
566                score: doc.fused_score,
567                content: doc.content,
568                metadata: doc.metadata,
569                vector_score: doc.vector_score,
570                lexical_score: doc.lexical_score,
571            })
572            .collect();
573        
574        Ok(HybridQueryResult {
575            results,
576            query: query.clone(),
577            stats: HybridQueryStats {
578                vector_candidates: 0, // Would be populated in real impl
579                lexical_candidates: 0,
580                filtered_candidates: 0,
581                fusion_time_us: 0,
582                rerank_time_us: 0,
583            },
584        })
585    }
586    
587    /// Check if document matches all filters
588    fn matches_filters(&self, doc: &CandidateDoc, filters: &[MetadataFilter]) -> bool {
589        for filter in filters {
590            if let Some(value) = doc.metadata.get(&filter.field) {
591                if !self.match_filter(value, &filter.op, &filter.value) {
592                    return false;
593                }
594            } else {
595                // Field not present - filter fails
596                return false;
597            }
598        }
599        true
600    }
601    
602    /// Match a single filter
603    fn match_filter(&self, doc_value: &SochValue, op: &FilterOp, filter_value: &SochValue) -> bool {
604        match op {
605            FilterOp::Eq => doc_value == filter_value,
606            FilterOp::Ne => doc_value != filter_value,
607            FilterOp::Gt => self.compare_values(doc_value, filter_value) == Some(Ordering::Greater),
608            FilterOp::Gte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Greater | Ordering::Equal)),
609            FilterOp::Lt => self.compare_values(doc_value, filter_value) == Some(Ordering::Less),
610            FilterOp::Lte => matches!(self.compare_values(doc_value, filter_value), Some(Ordering::Less | Ordering::Equal)),
611            FilterOp::Contains => self.value_contains(doc_value, filter_value),
612            FilterOp::In => self.value_in_set(doc_value, filter_value),
613        }
614    }
615    
616    /// Compare two SochValues
617    fn compare_values(&self, a: &SochValue, b: &SochValue) -> Option<Ordering> {
618        match (a, b) {
619            (SochValue::Int(a), SochValue::Int(b)) => Some(a.cmp(b)),
620            (SochValue::UInt(a), SochValue::UInt(b)) => Some(a.cmp(b)),
621            (SochValue::Float(a), SochValue::Float(b)) => a.partial_cmp(b),
622            (SochValue::Text(a), SochValue::Text(b)) => Some(a.cmp(b)),
623            _ => None,
624        }
625    }
626    
627    /// Check if value contains another
628    fn value_contains(&self, doc_value: &SochValue, search_value: &SochValue) -> bool {
629        match (doc_value, search_value) {
630            (SochValue::Text(text), SochValue::Text(search)) => text.contains(search.as_str()),
631            (SochValue::Array(arr), _) => arr.contains(search_value),
632            _ => false,
633        }
634    }
635    
636    /// Check if value is in set
637    fn value_in_set(&self, doc_value: &SochValue, set_value: &SochValue) -> bool {
638        if let SochValue::Array(arr) = set_value {
639            arr.contains(doc_value)
640        } else {
641            false
642        }
643    }
644    
645    /// Fuse scores from multiple sources
646    fn fuse_scores(
647        &self,
648        candidates: Vec<CandidateDoc>,
649        query: &HybridQuery,
650    ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
651        let vector_weight = query.vector.as_ref().map(|v| v.weight).unwrap_or(0.0);
652        let lexical_weight = query.lexical.as_ref().map(|l| l.weight).unwrap_or(0.0);
653        
654        let mut fused = candidates;
655        
656        match query.fusion.method {
657            FusionMethod::Rrf => {
658                // Reciprocal Rank Fusion
659                // score(d) = Σ w_i / (k + rank_i(d))
660                for doc in &mut fused {
661                    let mut score = 0.0;
662                    
663                    if let Some(rank) = doc.vector_rank {
664                        score += vector_weight / (query.fusion.rrf_k + rank as f32);
665                    }
666                    
667                    if let Some(rank) = doc.lexical_rank {
668                        score += lexical_weight / (query.fusion.rrf_k + rank as f32);
669                    }
670                    
671                    doc.fused_score = score;
672                }
673            }
674            
675            FusionMethod::WeightedSum => {
676                // Weighted sum of normalized scores
677                for doc in &mut fused {
678                    let mut score = 0.0;
679                    
680                    if let Some(s) = doc.vector_score {
681                        score += vector_weight * s;
682                    }
683                    
684                    if let Some(s) = doc.lexical_score {
685                        score += lexical_weight * s;
686                    }
687                    
688                    doc.fused_score = score;
689                }
690            }
691            
692            FusionMethod::Max => {
693                // Maximum score from any source
694                for doc in &mut fused {
695                    let v_score = doc.vector_score.map(|s| vector_weight * s).unwrap_or(0.0);
696                    let l_score = doc.lexical_score.map(|s| lexical_weight * s).unwrap_or(0.0);
697                    doc.fused_score = v_score.max(l_score);
698                }
699            }
700            
701            FusionMethod::Rsf => {
702                // Relative Score Fusion (simplified)
703                for doc in &mut fused {
704                    let mut score = 0.0;
705                    let mut count = 0;
706                    
707                    if let Some(s) = doc.vector_score {
708                        score += s;
709                        count += 1;
710                    }
711                    
712                    if let Some(s) = doc.lexical_score {
713                        score += s;
714                        count += 1;
715                    }
716                    
717                    doc.fused_score = if count > 0 { score / count as f32 } else { 0.0 };
718                }
719            }
720        }
721        
722        Ok(fused)
723    }
724    
725    /// Rerank candidates using cross-encoder (stub)
726    fn rerank(
727        &self,
728        candidates: &[CandidateDoc],
729        query: &str,
730        config: &RerankConfig,
731    ) -> Result<Vec<CandidateDoc>, HybridQueryError> {
732        // Take top_n candidates for reranking
733        let to_rerank: Vec<_> = candidates.iter().take(config.top_n).cloned().collect();
734        
735        // Stub: In production, would call cross-encoder model
736        // For now, just apply a small boost based on query term overlap
737        let mut reranked = to_rerank;
738        let query_terms: HashSet<&str> = query.split_whitespace().collect();
739        
740        for doc in &mut reranked {
741            let content_terms: HashSet<&str> = doc.content.split_whitespace().collect();
742            let overlap = query_terms.intersection(&content_terms).count();
743            
744            // Small boost for term overlap
745            doc.fused_score += (overlap as f32) * 0.01;
746        }
747        
748        // Add remaining candidates unchanged
749        reranked.extend(candidates.iter().skip(config.top_n).cloned());
750        
751        Ok(reranked)
752    }
753}
754
755/// Internal candidate document during processing
756#[derive(Debug, Clone)]
757struct CandidateDoc {
758    id: String,
759    content: String,
760    metadata: HashMap<String, SochValue>,
761    vector_rank: Option<usize>,
762    vector_score: Option<f32>,
763    lexical_rank: Option<usize>,
764    lexical_score: Option<f32>,
765    fused_score: f32,
766}
767
768// ============================================================================
769// Lexical Index (BM25)
770// ============================================================================
771
772/// Simple lexical (BM25) index
773pub struct LexicalIndex {
774    /// Collections: name -> inverted index
775    collections: std::sync::RwLock<HashMap<String, InvertedIndex>>,
776}
777
/// Per-collection BM25 state.
struct InvertedIndex {
    /// Term → list of `(doc_id, term frequency)` postings.
    postings: HashMap<String, Vec<(String, u32)>>,
    /// Document id → length in tokens.
    doc_lengths: HashMap<String, u32>,
    /// Document id → original text.
    documents: HashMap<String, String>,
    /// Mean document length in tokens.
    avg_doc_len: f32,
    /// BM25 term-frequency saturation parameter.
    k1: f32,
    /// BM25 length-normalization parameter.
    b: f32,
}
796
/// One BM25 hit: document id, score and stored text.
#[derive(Clone, Debug)]
pub struct LexicalSearchResult {
    /// Id of the matching document.
    pub id: String,
    /// BM25 relevance score.
    pub score: f32,
    /// Stored document text.
    pub content: String,
}
804
805impl LexicalIndex {
806    /// Create a new lexical index
807    pub fn new() -> Self {
808        Self {
809            collections: std::sync::RwLock::new(HashMap::new()),
810        }
811    }
812    
813    /// Create collection
814    pub fn create_collection(&self, name: &str) {
815        let mut collections = self.collections.write().unwrap();
816        collections.insert(name.to_string(), InvertedIndex {
817            postings: HashMap::new(),
818            doc_lengths: HashMap::new(),
819            documents: HashMap::new(),
820            avg_doc_len: 0.0,
821            k1: 1.2,
822            b: 0.75,
823        });
824    }
825    
826    /// Index a document
827    pub fn index_document(&self, collection: &str, id: &str, content: &str) -> Result<(), HybridQueryError> {
828        let mut collections = self.collections.write().unwrap();
829        let index = collections.get_mut(collection)
830            .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
831        
832        // Tokenize
833        let tokens: Vec<String> = content
834            .split_whitespace()
835            .map(|t| t.to_lowercase())
836            .collect();
837        
838        let doc_len = tokens.len() as u32;
839        
840        // Update document length
841        index.doc_lengths.insert(id.to_string(), doc_len);
842        index.documents.insert(id.to_string(), content.to_string());
843        
844        // Update average doc length
845        let total_len: u32 = index.doc_lengths.values().sum();
846        index.avg_doc_len = total_len as f32 / index.doc_lengths.len() as f32;
847        
848        // Count term frequencies
849        let mut term_freqs: HashMap<String, u32> = HashMap::new();
850        for token in &tokens {
851            *term_freqs.entry(token.clone()).or_insert(0) += 1;
852        }
853        
854        // Update postings
855        for (term, freq) in term_freqs {
856            index.postings
857                .entry(term)
858                .or_insert_with(Vec::new)
859                .push((id.to_string(), freq));
860        }
861        
862        Ok(())
863    }
864    
865    /// Search using BM25
866    pub fn search(
867        &self,
868        collection: &str,
869        query: &str,
870        _fields: &[String],
871        limit: usize,
872    ) -> Result<Vec<LexicalSearchResult>, HybridQueryError> {
873        let collections = self.collections.read().unwrap();
874        let index = collections.get(collection)
875            .ok_or_else(|| HybridQueryError::CollectionNotFound(collection.to_string()))?;
876        
877        // Tokenize query
878        let query_terms: Vec<String> = query
879            .split_whitespace()
880            .map(|t| t.to_lowercase())
881            .collect();
882        
883        let n = index.doc_lengths.len() as f32;
884        let mut scores: HashMap<String, f32> = HashMap::new();
885        
886        // Calculate BM25 scores
887        for term in &query_terms {
888            if let Some(postings) = index.postings.get(term) {
889                let df = postings.len() as f32;
890                let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
891                
892                for (doc_id, tf) in postings {
893                    let doc_len = *index.doc_lengths.get(doc_id).unwrap_or(&1) as f32;
894                    let tf = *tf as f32;
895                    
896                    // BM25 formula
897                    let score = idf * (tf * (index.k1 + 1.0)) / 
898                        (tf + index.k1 * (1.0 - index.b + index.b * doc_len / index.avg_doc_len));
899                    
900                    *scores.entry(doc_id.clone()).or_insert(0.0) += score;
901                }
902            }
903        }
904        
905        // Sort by score
906        let mut results: Vec<_> = scores.into_iter().collect();
907        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
908        
909        // Convert to results
910        let results: Vec<LexicalSearchResult> = results
911            .into_iter()
912            .take(limit)
913            .map(|(id, score)| {
914                let content = index.documents.get(&id).cloned().unwrap_or_default();
915                LexicalSearchResult { id, score, content }
916            })
917            .collect();
918        
919        Ok(results)
920    }
921}
922
923impl Default for LexicalIndex {
924    fn default() -> Self {
925        Self::new()
926    }
927}
928
929// ============================================================================
930// Results
931// ============================================================================
932
933/// Hybrid search result
934#[derive(Debug, Clone)]
935pub struct HybridSearchResult {
936    /// Document ID
937    pub id: String,
938    /// Fused score
939    pub score: f32,
940    /// Document content
941    pub content: String,
942    /// Document metadata
943    pub metadata: HashMap<String, SochValue>,
944    /// Score from vector search (if any)
945    pub vector_score: Option<f32>,
946    /// Score from lexical search (if any)
947    pub lexical_score: Option<f32>,
948}
949
950/// Result of hybrid query execution
951#[derive(Debug, Clone)]
952pub struct HybridQueryResult {
953    /// Search results
954    pub results: Vec<HybridSearchResult>,
955    /// Original query
956    pub query: HybridQuery,
957    /// Execution statistics
958    pub stats: HybridQueryStats,
959}
960
/// Counters and timings collected while executing a hybrid query.
///
/// `Default` yields all-zero statistics.
#[derive(Default, Clone, Debug)]
pub struct HybridQueryStats {
    /// Number of candidates produced by the vector-search stage.
    pub vector_candidates: usize,
    /// Number of candidates produced by the lexical-search stage.
    pub lexical_candidates: usize,
    /// Number of candidates remaining once filtering has been applied.
    pub filtered_candidates: usize,
    /// Time spent fusing scores, in microseconds.
    pub fusion_time_us: u64,
    /// Time spent reranking, in microseconds.
    pub rerank_time_us: u64,
}
975
/// Failure raised while executing a hybrid query.
///
/// Each variant names the stage that failed and carries a detail string:
/// the missing collection's name for `CollectionNotFound`, otherwise a
/// free-form message.
#[derive(Clone, Debug)]
pub enum HybridQueryError {
    /// The requested collection does not exist.
    CollectionNotFound(String),
    /// The vector-search stage failed.
    VectorSearchError(String),
    /// The lexical-search stage failed.
    LexicalSearchError(String),
    /// A filter could not be applied.
    FilterError(String),
    /// The reranking stage failed.
    RerankError(String),
}

impl std::fmt::Display for HybridQueryError {
    /// Renders the error as `"<stage label>: <detail>"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (label, detail) = match self {
            Self::CollectionNotFound(name) => ("Collection not found", name),
            Self::VectorSearchError(msg) => ("Vector search error", msg),
            Self::LexicalSearchError(msg) => ("Lexical search error", msg),
            Self::FilterError(msg) => ("Filter error", msg),
            Self::RerankError(msg) => ("Rerank error", msg),
        };
        write!(f, "{}: {}", label, detail)
    }
}

impl std::error::Error for HybridQueryError {}
1004
1005// ============================================================================
1006// Tests
1007// ============================================================================
1008
#[cfg(test)]
mod tests {
    use super::*;

    /// The builder records every configured stage on the resulting query.
    #[test]
    fn test_hybrid_query_builder() {
        let query = HybridQuery::new("documents")
            .with_vector(vec![0.1, 0.2, 0.3], 0.7)
            .with_lexical("search query", 0.3)
            .filter_eq("category", SochValue::Text("tech".to_string()))
            .with_fusion(FusionMethod::Rrf)
            .with_rerank("cross-encoder", 20)
            .limit(10);

        assert_eq!(query.collection, "documents");
        assert!(query.vector.is_some());
        assert!(query.lexical.is_some());
        assert_eq!(query.filters.len(), 1);
        assert_eq!(query.limit, 10);
    }

    /// BM25 search returns exactly the documents containing a query term,
    /// ordered by descending score.
    #[test]
    fn test_lexical_index_bm25() {
        let index = LexicalIndex::new();
        index.create_collection("test");

        index.index_document("test", "doc1", "the quick brown fox").unwrap();
        index.index_document("test", "doc2", "the lazy dog sleeps").unwrap();
        index.index_document("test", "doc3", "quick fox jumps over the lazy dog").unwrap();

        let results = index.search("test", "quick fox", &[], 10).unwrap();
        let ids: Vec<&str> = results.iter().map(|r| r.id.as_str()).collect();

        // doc1 and doc3 both contain "quick" and "fox", so BOTH must be returned
        // (the previous `||` check passed even if one of them was missing).
        assert!(ids.contains(&"doc1"), "doc1 missing from {:?}", ids);
        assert!(ids.contains(&"doc3"), "doc3 missing from {:?}", ids);
        // doc2 contains neither term, so it must not be scored at all.
        assert!(!ids.contains(&"doc2"), "doc2 unexpectedly present in {:?}", ids);
        assert_eq!(results.len(), 2, "unexpected results: {:?}", ids);
        // Results are sorted by score, highest first.
        assert!(results.windows(2).all(|w| w[0].score >= w[1].score));
    }

    /// Sanity-checks the weighted RRF arithmetic: score = Σ w / (k + rank).
    ///
    /// NOTE(review): this exercises the formula only, not the fusion code path.
    #[test]
    fn test_rrf_fusion() {
        let k = 60.0_f64;

        // Doc appears at rank 0 in vector results and rank 5 in lexical results.
        let vector_weight = 0.7;
        let lexical_weight = 0.3;

        let score = vector_weight / (k + 0.0) + lexical_weight / (k + 5.0);

        // 0.7 / 60 + 0.3 / 65 ≈ 0.011667 + 0.004615 = 0.016282
        assert!((score - 0.016282).abs() < 1e-6, "unexpected RRF score {}", score);
    }

    /// Every filter in the list is satisfied by the candidate's metadata.
    #[test]
    fn test_filter_matching() {
        let filters = vec![
            MetadataFilter {
                field: "status".to_string(),
                op: FilterOp::Eq,
                value: SochValue::Text("active".to_string()),
            },
            MetadataFilter {
                field: "count".to_string(),
                op: FilterOp::Gte,
                value: SochValue::Int(10),
            },
        ];

        let mut metadata = HashMap::new();
        metadata.insert("status".to_string(), SochValue::Text("active".to_string()));
        metadata.insert("count".to_string(), SochValue::Int(15));

        // Create a mock candidate
        let doc = CandidateDoc {
            id: "test".to_string(),
            content: "test content".to_string(),
            metadata,
            vector_rank: None,
            vector_score: None,
            lexical_rank: None,
            lexical_score: None,
            fused_score: 0.0,
        };

        // Evaluate each filter explicitly; a missing field or an unexpected
        // type fails the test instead of silently passing.
        for filter in &filters {
            let actual = doc
                .metadata
                .get(&filter.field)
                .unwrap_or_else(|| panic!("metadata missing field {}", filter.field));
            match (&filter.op, &filter.value, actual) {
                (FilterOp::Eq, expected, actual) => assert_eq!(actual, expected),
                (FilterOp::Gte, SochValue::Int(min), SochValue::Int(value)) => {
                    assert!(value >= min, "{} < {} for {}", value, min, filter.field)
                }
                (_, expected, actual) => panic!(
                    "unsupported filter combination on {}: expected {:?}, got {:?}",
                    filter.field, expected, actual
                ),
            }
        }
    }
}