Skip to main content

hermes_core/query/
vector.rs

//! Vector query types for dense and sparse vector search

use crate::dsl::Field;
use crate::segment::{SegmentReader, VectorSearchResult};
use crate::{DocId, Score, TERMINATED};

use super::ScoredPosition;
use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
/// Strategy for combining scores when a document has multiple values for the same field
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum all scores (accumulates dot product contributions)
    Sum,
    /// Take the maximum score
    Max,
    /// Take the average score
    Avg,
    /// Log-Sum-Exp: smooth maximum approximation (default)
    /// `score = (1/t) * log(Σ exp(t * sᵢ))`
    /// Higher temperature → closer to max; lower → closer to mean
    LogSumExp {
        /// Temperature parameter (default: 1.5)
        temperature: f32,
    },
    /// Weighted Top-K: weight top scores with exponential decay
    /// `score = Σ wᵢ * sorted_scores[i]` where `wᵢ = decay^i`
    /// (the weighted sum is normalized by Σ wᵢ, i.e. it is a weighted mean)
    WeightedTopK {
        /// Number of top scores to consider (default: 5)
        k: usize,
        /// Decay factor per rank (default: 0.7)
        decay: f32,
    },
}
35
36impl Default for MultiValueCombiner {
37    fn default() -> Self {
38        // LogSumExp with temperature 1.5 provides good balance between
39        // max (best relevance) and sum (saturation from multiple matches)
40        MultiValueCombiner::LogSumExp { temperature: 1.5 }
41    }
42}
43
44impl MultiValueCombiner {
45    /// Create LogSumExp combiner with default temperature (1.5)
46    pub fn log_sum_exp() -> Self {
47        Self::LogSumExp { temperature: 1.5 }
48    }
49
50    /// Create LogSumExp combiner with custom temperature
51    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
52        Self::LogSumExp { temperature }
53    }
54
55    /// Create WeightedTopK combiner with defaults (k=5, decay=0.7)
56    pub fn weighted_top_k() -> Self {
57        Self::WeightedTopK { k: 5, decay: 0.7 }
58    }
59
60    /// Create WeightedTopK combiner with custom parameters
61    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
62        Self::WeightedTopK { k, decay }
63    }
64
65    /// Combine multiple scores into a single score
66    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
67        if scores.is_empty() {
68            return 0.0;
69        }
70
71        match self {
72            MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
73            MultiValueCombiner::Max => scores
74                .iter()
75                .map(|(_, s)| *s)
76                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
77                .unwrap_or(0.0),
78            MultiValueCombiner::Avg => {
79                let sum: f32 = scores.iter().map(|(_, s)| s).sum();
80                sum / scores.len() as f32
81            }
82            MultiValueCombiner::LogSumExp { temperature } => {
83                // Numerically stable log-sum-exp:
84                // LSE(x) = max(x) + log(Σ exp(xᵢ - max(x)))
85                let t = *temperature;
86                let max_score = scores
87                    .iter()
88                    .map(|(_, s)| *s)
89                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
90                    .unwrap_or(0.0);
91
92                let sum_exp: f32 = scores
93                    .iter()
94                    .map(|(_, s)| (t * (s - max_score)).exp())
95                    .sum();
96
97                max_score + sum_exp.ln() / t
98            }
99            MultiValueCombiner::WeightedTopK { k, decay } => {
100                // Sort scores descending and take top k
101                let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
102                sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
103                sorted.truncate(*k);
104
105                // Apply exponential decay weights
106                let mut weight = 1.0f32;
107                let mut weighted_sum = 0.0f32;
108                let mut weight_total = 0.0f32;
109
110                for score in sorted {
111                    weighted_sum += weight * score;
112                    weight_total += weight;
113                    weight *= decay;
114                }
115
116                if weight_total > 0.0 {
117                    weighted_sum / weight_total
118                } else {
119                    0.0
120                }
121            }
122        }
123    }
124}
125
/// Dense vector query for similarity search
///
/// Construct with [`DenseVectorQuery::new`] and tune via the builder
/// methods (`with_nprobe`, `with_rerank_factor`, `with_combiner`).
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field containing the dense vectors
    pub field: Field,
    /// Query vector
    pub vector: Vec<f32>,
    /// Number of clusters to probe (for IVF indexes)
    pub nprobe: usize,
    /// Re-ranking factor (multiplied by k for candidate selection)
    pub rerank_factor: usize,
    /// How to combine scores for multi-valued documents
    pub combiner: MultiValueCombiner,
}
140
141impl DenseVectorQuery {
142    /// Create a new dense vector query
143    pub fn new(field: Field, vector: Vec<f32>) -> Self {
144        Self {
145            field,
146            vector,
147            nprobe: 32,
148            rerank_factor: 3,
149            combiner: MultiValueCombiner::Max,
150        }
151    }
152
153    /// Set the number of clusters to probe (for IVF indexes)
154    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
155        self.nprobe = nprobe;
156        self
157    }
158
159    /// Set the re-ranking factor
160    pub fn with_rerank_factor(mut self, factor: usize) -> Self {
161        self.rerank_factor = factor;
162        self
163    }
164
165    /// Set the multi-value score combiner
166    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
167        self.combiner = combiner;
168        self
169    }
170}
171
172impl Query for DenseVectorQuery {
173    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
174        let field = self.field;
175        let vector = self.vector.clone();
176        let rerank_factor = self.rerank_factor;
177        let combiner = self.combiner;
178        Box::pin(async move {
179            let results =
180                reader.search_dense_vector(field, &vector, limit, rerank_factor, combiner)?;
181
182            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
183        })
184    }
185
186    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
187        Box::pin(async move { Ok(u32::MAX) })
188    }
189}
190
/// Scorer for dense vector search results with ordinal tracking
///
/// Walks a pre-computed result list with a cursor; `field_id` is reported
/// alongside the per-value ordinals in `matched_positions`.
struct DenseVectorScorer {
    // Pre-computed hits from the segment's dense vector search.
    results: Vec<VectorSearchResult>,
    // Cursor into `results`; >= results.len() means exhausted.
    position: usize,
    // Numeric id of the searched field (Field.0).
    field_id: u32,
}
197
198impl DenseVectorScorer {
199    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
200        Self {
201            results,
202            position: 0,
203            field_id,
204        }
205    }
206}
207
208impl Scorer for DenseVectorScorer {
209    fn doc(&self) -> DocId {
210        if self.position < self.results.len() {
211            self.results[self.position].doc_id
212        } else {
213            TERMINATED
214        }
215    }
216
217    fn score(&self) -> Score {
218        if self.position < self.results.len() {
219            self.results[self.position].score
220        } else {
221            0.0
222        }
223    }
224
225    fn advance(&mut self) -> DocId {
226        self.position += 1;
227        self.doc()
228    }
229
230    fn seek(&mut self, target: DocId) -> DocId {
231        while self.doc() < target && self.doc() != TERMINATED {
232            self.advance();
233        }
234        self.doc()
235    }
236
237    fn size_hint(&self) -> u32 {
238        (self.results.len() - self.position) as u32
239    }
240
241    fn matched_positions(&self) -> Option<MatchedPositions> {
242        if self.position >= self.results.len() {
243            return None;
244        }
245        let result = &self.results[self.position];
246        let scored_positions: Vec<ScoredPosition> = result
247            .ordinals
248            .iter()
249            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
250            .collect();
251        Some(vec![(self.field_id, scored_positions)])
252    }
253}
254
/// Sparse vector query for similarity search
///
/// Construct with [`SparseVectorQuery::new`] or one of the `from_*`
/// helpers; tune via `with_combiner` / `with_heap_factor`.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field containing the sparse vectors
    pub field: Field,
    /// Query vector as (dimension_id, weight) pairs
    pub vector: Vec<(u32, f32)>,
    /// How to combine scores for multi-valued documents
    pub combiner: MultiValueCombiner,
    /// Approximate search factor (1.0 = exact, lower values = faster but approximate)
    /// Controls WAND pruning aggressiveness in block-max scoring
    pub heap_factor: f32,
}
268
269impl SparseVectorQuery {
270    /// Create a new sparse vector query
271    ///
272    /// Default combiner is `LogSumExp { temperature: 0.7 }` which provides
273    /// saturation for documents with many sparse vectors (e.g., 100+ ordinals).
274    /// This prevents over-weighting from multiple matches while still allowing
275    /// additional matches to contribute to the score.
276    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
277        Self {
278            field,
279            vector,
280            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
281            heap_factor: 1.0,
282        }
283    }
284
285    /// Set the multi-value score combiner
286    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
287        self.combiner = combiner;
288        self
289    }
290
291    /// Set the heap factor for approximate search
292    ///
293    /// Controls the trade-off between speed and recall:
294    /// - 1.0 = exact search (default)
295    /// - 0.8-0.9 = ~20-40% faster with minimal recall loss
296    /// - Lower values = more aggressive pruning, faster but lower recall
297    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
298        self.heap_factor = heap_factor.clamp(0.0, 1.0);
299        self
300    }
301
302    /// Create from separate indices and weights vectors
303    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
304        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
305        Self::new(field, vector)
306    }
307
308    /// Create from raw text using a HuggingFace tokenizer (single segment)
309    ///
310    /// This method tokenizes the text and creates a sparse vector query.
311    /// For multi-segment indexes, use `from_text_with_stats` instead.
312    ///
313    /// # Arguments
314    /// * `field` - The sparse vector field to search
315    /// * `text` - Raw text to tokenize
316    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
317    /// * `weighting` - Weighting strategy for tokens
318    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
319    #[cfg(feature = "native")]
320    pub fn from_text(
321        field: Field,
322        text: &str,
323        tokenizer_name: &str,
324        weighting: crate::structures::QueryWeighting,
325        sparse_index: Option<&crate::segment::SparseIndex>,
326    ) -> crate::Result<Self> {
327        use crate::structures::QueryWeighting;
328        use crate::tokenizer::tokenizer_cache;
329
330        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
331        let token_ids = tokenizer.tokenize_unique(text)?;
332
333        let weights: Vec<f32> = match weighting {
334            QueryWeighting::One => vec![1.0f32; token_ids.len()],
335            QueryWeighting::Idf => {
336                if let Some(index) = sparse_index {
337                    index.idf_weights(&token_ids)
338                } else {
339                    vec![1.0f32; token_ids.len()]
340                }
341            }
342            QueryWeighting::IdfFile => {
343                use crate::tokenizer::idf_weights_cache;
344                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
345                    token_ids.iter().map(|&id| idf.get(id)).collect()
346                } else {
347                    vec![1.0f32; token_ids.len()]
348                }
349            }
350        };
351
352        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
353        Ok(Self::new(field, vector))
354    }
355
356    /// Create from raw text using global statistics (multi-segment)
357    ///
358    /// This is the recommended method for multi-segment indexes as it uses
359    /// aggregated IDF values across all segments for consistent ranking.
360    ///
361    /// # Arguments
362    /// * `field` - The sparse vector field to search
363    /// * `text` - Raw text to tokenize
364    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
365    /// * `weighting` - Weighting strategy for tokens
366    /// * `global_stats` - Global statistics for IDF computation
367    #[cfg(feature = "native")]
368    pub fn from_text_with_stats(
369        field: Field,
370        text: &str,
371        tokenizer: &crate::tokenizer::HfTokenizer,
372        weighting: crate::structures::QueryWeighting,
373        global_stats: Option<&super::GlobalStats>,
374    ) -> crate::Result<Self> {
375        use crate::structures::QueryWeighting;
376
377        let token_ids = tokenizer.tokenize_unique(text)?;
378
379        let weights: Vec<f32> = match weighting {
380            QueryWeighting::One => vec![1.0f32; token_ids.len()],
381            QueryWeighting::Idf => {
382                if let Some(stats) = global_stats {
383                    // Clamp to zero: negative weights don't make sense for IDF
384                    stats
385                        .sparse_idf_weights(field, &token_ids)
386                        .into_iter()
387                        .map(|w| w.max(0.0))
388                        .collect()
389                } else {
390                    vec![1.0f32; token_ids.len()]
391                }
392            }
393            QueryWeighting::IdfFile => {
394                // IdfFile requires a tokenizer name for HF model lookup;
395                // this code path doesn't have one, so fall back to 1.0
396                vec![1.0f32; token_ids.len()]
397            }
398        };
399
400        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
401        Ok(Self::new(field, vector))
402    }
403
404    /// Create from raw text, loading tokenizer from index directory
405    ///
406    /// This method supports the `index://` prefix for tokenizer paths,
407    /// loading tokenizer.json from the index directory.
408    ///
409    /// # Arguments
410    /// * `field` - The sparse vector field to search
411    /// * `text` - Raw text to tokenize
412    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
413    /// * `weighting` - Weighting strategy for tokens
414    /// * `global_stats` - Global statistics for IDF computation
415    #[cfg(feature = "native")]
416    pub fn from_text_with_tokenizer_bytes(
417        field: Field,
418        text: &str,
419        tokenizer_bytes: &[u8],
420        weighting: crate::structures::QueryWeighting,
421        global_stats: Option<&super::GlobalStats>,
422    ) -> crate::Result<Self> {
423        use crate::structures::QueryWeighting;
424        use crate::tokenizer::HfTokenizer;
425
426        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
427        let token_ids = tokenizer.tokenize_unique(text)?;
428
429        let weights: Vec<f32> = match weighting {
430            QueryWeighting::One => vec![1.0f32; token_ids.len()],
431            QueryWeighting::Idf => {
432                if let Some(stats) = global_stats {
433                    // Clamp to zero: negative weights don't make sense for IDF
434                    stats
435                        .sparse_idf_weights(field, &token_ids)
436                        .into_iter()
437                        .map(|w| w.max(0.0))
438                        .collect()
439                } else {
440                    vec![1.0f32; token_ids.len()]
441                }
442            }
443            QueryWeighting::IdfFile => {
444                // IdfFile requires a tokenizer name for HF model lookup;
445                // this code path doesn't have one, so fall back to 1.0
446                vec![1.0f32; token_ids.len()]
447            }
448        };
449
450        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
451        Ok(Self::new(field, vector))
452    }
453}
454
455impl Query for SparseVectorQuery {
456    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
457        let field = self.field;
458        let vector = self.vector.clone();
459        let combiner = self.combiner;
460        let heap_factor = self.heap_factor;
461        Box::pin(async move {
462            let results = reader
463                .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
464                .await?;
465
466            Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
467        })
468    }
469
470    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
471        Box::pin(async move { Ok(u32::MAX) })
472    }
473}
474
/// Scorer for sparse vector search results with ordinal tracking
///
/// Walks a pre-computed result list with a cursor; `field_id` is reported
/// alongside the per-value ordinals in `matched_positions`.
struct SparseVectorScorer {
    // Pre-computed hits from the segment's sparse vector search.
    results: Vec<VectorSearchResult>,
    // Cursor into `results`; >= results.len() means exhausted.
    position: usize,
    // Numeric id of the searched field (Field.0).
    field_id: u32,
}
481
482impl SparseVectorScorer {
483    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
484        Self {
485            results,
486            position: 0,
487            field_id,
488        }
489    }
490}
491
492impl Scorer for SparseVectorScorer {
493    fn doc(&self) -> DocId {
494        if self.position < self.results.len() {
495            self.results[self.position].doc_id
496        } else {
497            TERMINATED
498        }
499    }
500
501    fn score(&self) -> Score {
502        if self.position < self.results.len() {
503            self.results[self.position].score
504        } else {
505            0.0
506        }
507    }
508
509    fn advance(&mut self) -> DocId {
510        self.position += 1;
511        self.doc()
512    }
513
514    fn seek(&mut self, target: DocId) -> DocId {
515        while self.doc() < target && self.doc() != TERMINATED {
516            self.advance();
517        }
518        self.doc()
519    }
520
521    fn size_hint(&self) -> u32 {
522        (self.results.len() - self.position) as u32
523    }
524
525    fn matched_positions(&self) -> Option<MatchedPositions> {
526        if self.position >= self.results.len() {
527            return None;
528        }
529        let result = &self.results[self.position];
530        let scored_positions: Vec<ScoredPosition> = result
531            .ordinals
532            .iter()
533            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
534            .collect();
535        Some(vec![(self.field_id, scored_positions)])
536    }
537}
538
#[cfg(test)]
mod tests {
    // Unit tests for the query builders and all MultiValueCombiner strategies,
    // including empty-input, single-score, and parameter edge cases.
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Sum;
        assert!((combiner.combine(&scores) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        let combiner = MultiValueCombiner::Max;
        assert!((combiner.combine(&scores) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Avg;
        assert!((combiner.combine(&scores) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::log_sum_exp();
        let result = combiner.combine(&scores);
        // LogSumExp should be between max (3.0) and max + log(n)/t
        assert!(result >= 3.0);
        assert!(result <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        // High temperature should approach max
        let combiner = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let result = combiner.combine(&scores);
        // Should be very close to max (5.0)
        assert!((result - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let result = combiner.combine(&scores);
        // Top 3: 5.0, 3.0, 1.0 with weights 1.0, 0.5, 0.25
        // weighted_sum = 5*1 + 3*0.5 + 1*0.25 = 6.75
        // weight_total = 1.75
        // result = 6.75 / 1.75 ≈ 3.857
        assert!((result - 3.857).abs() < 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let result = combiner.combine(&scores);
        // Only 2 scores, weights 1.0 and 0.7
        // weighted_sum = 2*1 + 1*0.7 = 2.7
        // weight_total = 1.7
        // result = 2.7 / 1.7 ≈ 1.588
        assert!((result - 1.588).abs() < 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        assert_eq!(MultiValueCombiner::Sum.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Max.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Avg.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::log_sum_exp().combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::weighted_top_k().combine(&scores), 0.0);
    }

    #[test]
    fn test_combiner_single_score() {
        let scores = vec![(0, 5.0)];
        // All combiners should return 5.0 for a single score
        assert!((MultiValueCombiner::Sum.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Max.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Avg.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::log_sum_exp().combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::weighted_top_k().combine(&scores) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        let combiner = MultiValueCombiner::default();
        match combiner {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!((temperature - 1.5).abs() < 1e-6);
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}