// hermes_core/query/vector.rs
//! Vector query types for dense and sparse vector search

use crate::dsl::Field;
use crate::segment::{SegmentReader, VectorSearchResult};
use crate::{DocId, Score, TERMINATED};

use super::ScoredPosition;
use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
/// How to merge per-value scores when a single document stores several
/// values in the same vector field.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MultiValueCombiner {
    /// Add every score together (the default; sparse dot-product
    /// contributions accumulate naturally under summation).
    #[default]
    Sum,
    /// Keep only the highest score among the document's values.
    Max,
    /// Use the arithmetic mean of all scores.
    Avg,
}

22/// Dense vector query for similarity search
23#[derive(Debug, Clone)]
24pub struct DenseVectorQuery {
25    /// Field containing the dense vectors
26    pub field: Field,
27    /// Query vector
28    pub vector: Vec<f32>,
29    /// Number of clusters to probe (for IVF indexes)
30    pub nprobe: usize,
31    /// Re-ranking factor (multiplied by k for candidate selection)
32    pub rerank_factor: usize,
33    /// How to combine scores for multi-valued documents
34    pub combiner: MultiValueCombiner,
35}
36
37impl DenseVectorQuery {
38    /// Create a new dense vector query
39    pub fn new(field: Field, vector: Vec<f32>) -> Self {
40        Self {
41            field,
42            vector,
43            nprobe: 32,
44            rerank_factor: 3,
45            combiner: MultiValueCombiner::Max,
46        }
47    }
48
49    /// Set the number of clusters to probe (for IVF indexes)
50    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
51        self.nprobe = nprobe;
52        self
53    }
54
55    /// Set the re-ranking factor
56    pub fn with_rerank_factor(mut self, factor: usize) -> Self {
57        self.rerank_factor = factor;
58        self
59    }
60
61    /// Set the multi-value score combiner
62    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
63        self.combiner = combiner;
64        self
65    }
66}
67
68impl Query for DenseVectorQuery {
69    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
70        let field = self.field;
71        let vector = self.vector.clone();
72        let rerank_factor = self.rerank_factor;
73        let combiner = self.combiner;
74        Box::pin(async move {
75            let results =
76                reader.search_dense_vector(field, &vector, limit, rerank_factor, combiner)?;
77
78            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
79        })
80    }
81
82    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
83        Box::pin(async move { Ok(u32::MAX) })
84    }
85}
86
87/// Scorer for dense vector search results with ordinal tracking
88struct DenseVectorScorer {
89    results: Vec<VectorSearchResult>,
90    position: usize,
91    field_id: u32,
92}
93
94impl DenseVectorScorer {
95    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
96        Self {
97            results,
98            position: 0,
99            field_id,
100        }
101    }
102}
103
104impl Scorer for DenseVectorScorer {
105    fn doc(&self) -> DocId {
106        if self.position < self.results.len() {
107            self.results[self.position].doc_id
108        } else {
109            TERMINATED
110        }
111    }
112
113    fn score(&self) -> Score {
114        if self.position < self.results.len() {
115            self.results[self.position].score
116        } else {
117            0.0
118        }
119    }
120
121    fn advance(&mut self) -> DocId {
122        self.position += 1;
123        self.doc()
124    }
125
126    fn seek(&mut self, target: DocId) -> DocId {
127        while self.doc() < target && self.doc() != TERMINATED {
128            self.advance();
129        }
130        self.doc()
131    }
132
133    fn size_hint(&self) -> u32 {
134        (self.results.len() - self.position) as u32
135    }
136
137    fn matched_positions(&self) -> Option<MatchedPositions> {
138        if self.position >= self.results.len() {
139            return None;
140        }
141        let result = &self.results[self.position];
142        let scored_positions: Vec<ScoredPosition> = result
143            .ordinals
144            .iter()
145            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
146            .collect();
147        Some(vec![(self.field_id, scored_positions)])
148    }
149}
150
151/// Sparse vector query for similarity search
152#[derive(Debug, Clone)]
153pub struct SparseVectorQuery {
154    /// Field containing the sparse vectors
155    pub field: Field,
156    /// Query vector as (dimension_id, weight) pairs
157    pub vector: Vec<(u32, f32)>,
158    /// How to combine scores for multi-valued documents
159    pub combiner: MultiValueCombiner,
160}
161
162impl SparseVectorQuery {
163    /// Create a new sparse vector query
164    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
165        Self {
166            field,
167            vector,
168            combiner: MultiValueCombiner::Sum,
169        }
170    }
171
172    /// Set the multi-value score combiner
173    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
174        self.combiner = combiner;
175        self
176    }
177
178    /// Create from separate indices and weights vectors
179    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
180        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
181        Self::new(field, vector)
182    }
183
184    /// Create from raw text using a HuggingFace tokenizer (single segment)
185    ///
186    /// This method tokenizes the text and creates a sparse vector query.
187    /// For multi-segment indexes, use `from_text_with_stats` instead.
188    ///
189    /// # Arguments
190    /// * `field` - The sparse vector field to search
191    /// * `text` - Raw text to tokenize
192    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
193    /// * `weighting` - Weighting strategy for tokens
194    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
195    #[cfg(feature = "native")]
196    pub fn from_text(
197        field: Field,
198        text: &str,
199        tokenizer_name: &str,
200        weighting: crate::structures::QueryWeighting,
201        sparse_index: Option<&crate::segment::SparseIndex>,
202    ) -> crate::Result<Self> {
203        use crate::structures::QueryWeighting;
204        use crate::tokenizer::tokenizer_cache;
205
206        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
207        let token_ids = tokenizer.tokenize_unique(text)?;
208
209        let weights: Vec<f32> = match weighting {
210            QueryWeighting::One => vec![1.0f32; token_ids.len()],
211            QueryWeighting::Idf => {
212                if let Some(index) = sparse_index {
213                    index.idf_weights(&token_ids)
214                } else {
215                    vec![1.0f32; token_ids.len()]
216                }
217            }
218        };
219
220        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
221        Ok(Self::new(field, vector))
222    }
223
224    /// Create from raw text using global statistics (multi-segment)
225    ///
226    /// This is the recommended method for multi-segment indexes as it uses
227    /// aggregated IDF values across all segments for consistent ranking.
228    ///
229    /// # Arguments
230    /// * `field` - The sparse vector field to search
231    /// * `text` - Raw text to tokenize
232    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
233    /// * `weighting` - Weighting strategy for tokens
234    /// * `global_stats` - Global statistics for IDF computation
235    #[cfg(feature = "native")]
236    pub fn from_text_with_stats(
237        field: Field,
238        text: &str,
239        tokenizer: &crate::tokenizer::HfTokenizer,
240        weighting: crate::structures::QueryWeighting,
241        global_stats: Option<&super::GlobalStats>,
242    ) -> crate::Result<Self> {
243        use crate::structures::QueryWeighting;
244
245        let token_ids = tokenizer.tokenize_unique(text)?;
246
247        let weights: Vec<f32> = match weighting {
248            QueryWeighting::One => vec![1.0f32; token_ids.len()],
249            QueryWeighting::Idf => {
250                if let Some(stats) = global_stats {
251                    // Clamp to zero: negative weights don't make sense for IDF
252                    stats
253                        .sparse_idf_weights(field, &token_ids)
254                        .into_iter()
255                        .map(|w| w.max(0.0))
256                        .collect()
257                } else {
258                    vec![1.0f32; token_ids.len()]
259                }
260            }
261        };
262
263        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
264        Ok(Self::new(field, vector))
265    }
266
267    /// Create from raw text, loading tokenizer from index directory
268    ///
269    /// This method supports the `index://` prefix for tokenizer paths,
270    /// loading tokenizer.json from the index directory.
271    ///
272    /// # Arguments
273    /// * `field` - The sparse vector field to search
274    /// * `text` - Raw text to tokenize
275    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
276    /// * `weighting` - Weighting strategy for tokens
277    /// * `global_stats` - Global statistics for IDF computation
278    #[cfg(feature = "native")]
279    pub fn from_text_with_tokenizer_bytes(
280        field: Field,
281        text: &str,
282        tokenizer_bytes: &[u8],
283        weighting: crate::structures::QueryWeighting,
284        global_stats: Option<&super::GlobalStats>,
285    ) -> crate::Result<Self> {
286        use crate::structures::QueryWeighting;
287        use crate::tokenizer::HfTokenizer;
288
289        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
290        let token_ids = tokenizer.tokenize_unique(text)?;
291
292        let weights: Vec<f32> = match weighting {
293            QueryWeighting::One => vec![1.0f32; token_ids.len()],
294            QueryWeighting::Idf => {
295                if let Some(stats) = global_stats {
296                    // Clamp to zero: negative weights don't make sense for IDF
297                    stats
298                        .sparse_idf_weights(field, &token_ids)
299                        .into_iter()
300                        .map(|w| w.max(0.0))
301                        .collect()
302                } else {
303                    vec![1.0f32; token_ids.len()]
304                }
305            }
306        };
307
308        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
309        Ok(Self::new(field, vector))
310    }
311}
312
313impl Query for SparseVectorQuery {
314    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
315        let field = self.field;
316        let vector = self.vector.clone();
317        let combiner = self.combiner;
318        Box::pin(async move {
319            let results = reader
320                .search_sparse_vector(field, &vector, limit, combiner)
321                .await?;
322
323            Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
324        })
325    }
326
327    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
328        Box::pin(async move { Ok(u32::MAX) })
329    }
330}
331
332/// Scorer for sparse vector search results with ordinal tracking
333struct SparseVectorScorer {
334    results: Vec<VectorSearchResult>,
335    position: usize,
336    field_id: u32,
337}
338
339impl SparseVectorScorer {
340    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
341        Self {
342            results,
343            position: 0,
344            field_id,
345        }
346    }
347}
348
349impl Scorer for SparseVectorScorer {
350    fn doc(&self) -> DocId {
351        if self.position < self.results.len() {
352            self.results[self.position].doc_id
353        } else {
354            TERMINATED
355        }
356    }
357
358    fn score(&self) -> Score {
359        if self.position < self.results.len() {
360            self.results[self.position].score
361        } else {
362            0.0
363        }
364    }
365
366    fn advance(&mut self) -> DocId {
367        self.position += 1;
368        self.doc()
369    }
370
371    fn seek(&mut self, target: DocId) -> DocId {
372        while self.doc() < target && self.doc() != TERMINATED {
373            self.advance();
374        }
375        self.doc()
376    }
377
378    fn size_hint(&self) -> u32 {
379        (self.results.len() - self.position) as u32
380    }
381
382    fn matched_positions(&self) -> Option<MatchedPositions> {
383        if self.position >= self.results.len() {
384            return None;
385        }
386        let result = &self.results[self.position];
387        let scored_positions: Vec<ScoredPosition> = result
388            .ordinals
389            .iter()
390            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
391            .collect();
392        Some(vec![(self.field_id, scored_positions)])
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
        // Dense queries default to Max and the builder can override it.
        assert_eq!(query.combiner, MultiValueCombiner::Max);
        let query = query.with_combiner(MultiValueCombiner::Avg);
        assert_eq!(query.combiner, MultiValueCombiner::Avg);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
        // Sparse queries default to Sum (dot-product accumulation).
        assert_eq!(query.combiner, MultiValueCombiner::Sum);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);

        // Mismatched lengths truncate to the shorter input (zip semantics).
        let query = SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5]);
        assert_eq!(query.vector, vec![(1, 0.5)]);
    }
}