// hermes_core/query/vector.rs

1//! Vector query types for dense and sparse vector search
2
3use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::traits::{CountFuture, Query, Scorer, ScorerFuture};
8
/// Strategy for combining scores when a document has multiple values for the same field
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MultiValueCombiner {
    /// Sum all scores (default for sparse vectors - accumulates dot product contributions)
    #[default]
    Sum,
    /// Take the maximum score across the document's values
    Max,
    /// Take the arithmetic mean of the scores across the document's values
    Avg,
}
20
/// Dense vector query for similarity search
///
/// Constructed via [`DenseVectorQuery::new`] and tuned with the `with_*`
/// builder methods; executed against a segment through the [`Query`] impl.
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field containing the dense vectors
    pub field: Field,
    /// Query vector
    pub vector: Vec<f32>,
    /// Number of clusters to probe (for IVF indexes)
    pub nprobe: usize,
    /// Re-ranking factor (multiplied by k for candidate selection)
    pub rerank_factor: usize,
    /// How to combine scores for multi-valued documents (defaults to `Max` for dense)
    pub combiner: MultiValueCombiner,
}
35
36impl DenseVectorQuery {
37    /// Create a new dense vector query
38    pub fn new(field: Field, vector: Vec<f32>) -> Self {
39        Self {
40            field,
41            vector,
42            nprobe: 32,
43            rerank_factor: 3,
44            combiner: MultiValueCombiner::Max,
45        }
46    }
47
48    /// Set the number of clusters to probe (for IVF indexes)
49    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
50        self.nprobe = nprobe;
51        self
52    }
53
54    /// Set the re-ranking factor
55    pub fn with_rerank_factor(mut self, factor: usize) -> Self {
56        self.rerank_factor = factor;
57        self
58    }
59
60    /// Set the multi-value score combiner
61    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
62        self.combiner = combiner;
63        self
64    }
65}
66
67impl Query for DenseVectorQuery {
68    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
69        let field = self.field;
70        let vector = self.vector.clone();
71        let rerank_factor = self.rerank_factor;
72        let combiner = self.combiner;
73        Box::pin(async move {
74            let results =
75                reader.search_dense_vector(field, &vector, limit, rerank_factor, combiner)?;
76
77            Ok(Box::new(DenseVectorScorer::new(results)) as Box<dyn Scorer>)
78        })
79    }
80
81    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
82        Box::pin(async move { Ok(u32::MAX) })
83    }
84}
85
/// Scorer for dense vector search results
///
/// Wraps the pre-computed `(doc_id, score)` list returned by the segment's
/// dense vector search and exposes it through the `Scorer` iteration API.
struct DenseVectorScorer {
    // (doc_id, similarity score) pairs, as produced by `search_dense_vector`
    results: Vec<(u32, f32)>,
    // Cursor into `results`; advances monotonically toward the end
    position: usize,
}

impl DenseVectorScorer {
    fn new(results: Vec<(u32, f32)>) -> Self {
        Self {
            results,
            position: 0,
        }
    }
}
100
101impl Scorer for DenseVectorScorer {
102    fn doc(&self) -> DocId {
103        if self.position < self.results.len() {
104            self.results[self.position].0
105        } else {
106            TERMINATED
107        }
108    }
109
110    fn score(&self) -> Score {
111        if self.position < self.results.len() {
112            self.results[self.position].1
113        } else {
114            0.0
115        }
116    }
117
118    fn advance(&mut self) -> DocId {
119        self.position += 1;
120        self.doc()
121    }
122
123    fn seek(&mut self, target: DocId) -> DocId {
124        while self.doc() < target && self.doc() != TERMINATED {
125            self.advance();
126        }
127        self.doc()
128    }
129
130    fn size_hint(&self) -> u32 {
131        (self.results.len() - self.position) as u32
132    }
133}
134
/// Sparse vector query for similarity search
///
/// The query vector is a list of `(dimension_id, weight)` pairs; see the
/// `from_*` constructors for building one from raw text.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field containing the sparse vectors
    pub field: Field,
    /// Query vector as (dimension_id, weight) pairs
    pub vector: Vec<(u32, f32)>,
    /// How to combine scores for multi-valued documents (defaults to `Sum`)
    pub combiner: MultiValueCombiner,
}
145
146impl SparseVectorQuery {
147    /// Create a new sparse vector query
148    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
149        Self {
150            field,
151            vector,
152            combiner: MultiValueCombiner::Sum,
153        }
154    }
155
156    /// Set the multi-value score combiner
157    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
158        self.combiner = combiner;
159        self
160    }
161
162    /// Create from separate indices and weights vectors
163    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
164        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
165        Self::new(field, vector)
166    }
167
168    /// Create from raw text using a HuggingFace tokenizer (single segment)
169    ///
170    /// This method tokenizes the text and creates a sparse vector query.
171    /// For multi-segment indexes, use `from_text_with_stats` instead.
172    ///
173    /// # Arguments
174    /// * `field` - The sparse vector field to search
175    /// * `text` - Raw text to tokenize
176    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
177    /// * `weighting` - Weighting strategy for tokens
178    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
179    #[cfg(feature = "native")]
180    pub fn from_text(
181        field: Field,
182        text: &str,
183        tokenizer_name: &str,
184        weighting: crate::structures::QueryWeighting,
185        sparse_index: Option<&crate::segment::SparseIndex>,
186    ) -> crate::Result<Self> {
187        use crate::structures::QueryWeighting;
188        use crate::tokenizer::tokenizer_cache;
189
190        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
191        let token_ids = tokenizer.tokenize_unique(text)?;
192
193        let weights: Vec<f32> = match weighting {
194            QueryWeighting::One => vec![1.0f32; token_ids.len()],
195            QueryWeighting::Idf => {
196                if let Some(index) = sparse_index {
197                    index.idf_weights(&token_ids)
198                } else {
199                    vec![1.0f32; token_ids.len()]
200                }
201            }
202        };
203
204        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
205        Ok(Self::new(field, vector))
206    }
207
208    /// Create from raw text using global statistics (multi-segment)
209    ///
210    /// This is the recommended method for multi-segment indexes as it uses
211    /// aggregated IDF values across all segments for consistent ranking.
212    ///
213    /// # Arguments
214    /// * `field` - The sparse vector field to search
215    /// * `text` - Raw text to tokenize
216    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
217    /// * `weighting` - Weighting strategy for tokens
218    /// * `global_stats` - Global statistics for IDF computation
219    #[cfg(feature = "native")]
220    pub fn from_text_with_stats(
221        field: Field,
222        text: &str,
223        tokenizer: &crate::tokenizer::HfTokenizer,
224        weighting: crate::structures::QueryWeighting,
225        global_stats: Option<&super::GlobalStats>,
226    ) -> crate::Result<Self> {
227        use crate::structures::QueryWeighting;
228
229        let token_ids = tokenizer.tokenize_unique(text)?;
230
231        let weights: Vec<f32> = match weighting {
232            QueryWeighting::One => vec![1.0f32; token_ids.len()],
233            QueryWeighting::Idf => {
234                if let Some(stats) = global_stats {
235                    stats.sparse_idf_weights(field, &token_ids)
236                } else {
237                    vec![1.0f32; token_ids.len()]
238                }
239            }
240        };
241
242        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
243        Ok(Self::new(field, vector))
244    }
245
246    /// Create from raw text, loading tokenizer from index directory
247    ///
248    /// This method supports the `index://` prefix for tokenizer paths,
249    /// loading tokenizer.json from the index directory.
250    ///
251    /// # Arguments
252    /// * `field` - The sparse vector field to search
253    /// * `text` - Raw text to tokenize
254    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
255    /// * `weighting` - Weighting strategy for tokens
256    /// * `global_stats` - Global statistics for IDF computation
257    #[cfg(feature = "native")]
258    pub fn from_text_with_tokenizer_bytes(
259        field: Field,
260        text: &str,
261        tokenizer_bytes: &[u8],
262        weighting: crate::structures::QueryWeighting,
263        global_stats: Option<&super::GlobalStats>,
264    ) -> crate::Result<Self> {
265        use crate::structures::QueryWeighting;
266        use crate::tokenizer::HfTokenizer;
267
268        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
269        let token_ids = tokenizer.tokenize_unique(text)?;
270
271        let weights: Vec<f32> = match weighting {
272            QueryWeighting::One => vec![1.0f32; token_ids.len()],
273            QueryWeighting::Idf => {
274                if let Some(stats) = global_stats {
275                    stats.sparse_idf_weights(field, &token_ids)
276                } else {
277                    vec![1.0f32; token_ids.len()]
278                }
279            }
280        };
281
282        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
283        Ok(Self::new(field, vector))
284    }
285}
286
287impl Query for SparseVectorQuery {
288    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
289        let field = self.field;
290        let vector = self.vector.clone();
291        let combiner = self.combiner;
292        Box::pin(async move {
293            let results = reader
294                .search_sparse_vector(field, &vector, limit, combiner)
295                .await?;
296
297            Ok(Box::new(SparseVectorScorer::new(results)) as Box<dyn Scorer>)
298        })
299    }
300
301    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
302        Box::pin(async move { Ok(u32::MAX) })
303    }
304}
305
/// Scorer for sparse vector search results
///
/// Wraps the pre-computed `(doc_id, score)` list returned by the segment's
/// sparse vector search and exposes it through the `Scorer` iteration API.
struct SparseVectorScorer {
    // (doc_id, similarity score) pairs, as produced by `search_sparse_vector`
    results: Vec<(u32, f32)>,
    // Cursor into `results`; advances monotonically toward the end
    position: usize,
}

impl SparseVectorScorer {
    fn new(results: Vec<(u32, f32)>) -> Self {
        Self {
            results,
            position: 0,
        }
    }
}
320
321impl Scorer for SparseVectorScorer {
322    fn doc(&self) -> DocId {
323        if self.position < self.results.len() {
324            self.results[self.position].0
325        } else {
326            TERMINATED
327        }
328    }
329
330    fn score(&self) -> Score {
331        if self.position < self.results.len() {
332            self.results[self.position].1
333        } else {
334            0.0
335        }
336    }
337
338    fn advance(&mut self) -> DocId {
339        self.position += 1;
340        self.doc()
341    }
342
343    fn seek(&mut self, target: DocId) -> DocId {
344        while self.doc() < target && self.doc() != TERMINATED {
345            self.advance();
346        }
347        self.doc()
348    }
349
350    fn size_hint(&self) -> u32 {
351        (self.results.len() - self.position) as u32
352    }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
        // Dense queries default to Max combining.
        assert_eq!(query.combiner, MultiValueCombiner::Max);
    }

    #[test]
    fn test_dense_vector_query_with_combiner() {
        let query =
            DenseVectorQuery::new(Field(0), vec![1.0]).with_combiner(MultiValueCombiner::Avg);
        assert_eq!(query.combiner, MultiValueCombiner::Avg);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
        // Sparse queries default to Sum combining.
        assert_eq!(query.combiner, MultiValueCombiner::Sum);
    }

    #[test]
    fn test_sparse_vector_query_with_combiner() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.5)])
            .with_combiner(MultiValueCombiner::Max);
        assert_eq!(query.combiner, MultiValueCombiner::Max);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
}