hermes_core/query/
vector.rs

1//! Vector query types for dense and sparse vector search
2
3use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::traits::{CountFuture, Query, Scorer, ScorerFuture};
8
/// Dense vector query for similarity search
///
/// Searches a dense-vector field via the segment reader; hits are scored by
/// inverse distance (see `DenseVectorScorer::score`). Defaults for the tuning
/// knobs are set by [`DenseVectorQuery::new`].
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field containing the dense vectors
    pub field: Field,
    /// Query vector
    pub vector: Vec<f32>,
    /// Number of clusters to probe (for IVF indexes); defaults to 32
    pub nprobe: usize,
    /// Re-ranking factor (multiplied by k for candidate selection); defaults to 3
    pub rerank_factor: usize,
}
21
22impl DenseVectorQuery {
23    /// Create a new dense vector query
24    pub fn new(field: Field, vector: Vec<f32>) -> Self {
25        Self {
26            field,
27            vector,
28            nprobe: 32,
29            rerank_factor: 3,
30        }
31    }
32
33    /// Set the number of clusters to probe (for IVF indexes)
34    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
35        self.nprobe = nprobe;
36        self
37    }
38
39    /// Set the re-ranking factor
40    pub fn with_rerank_factor(mut self, factor: usize) -> Self {
41        self.rerank_factor = factor;
42        self
43    }
44}
45
46impl Query for DenseVectorQuery {
47    fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
48        Box::pin(async move {
49            let results =
50                reader.search_dense_vector(self.field, &self.vector, limit, self.rerank_factor)?;
51
52            Ok(Box::new(DenseVectorScorer::new(results)) as Box<dyn Scorer>)
53        })
54    }
55
56    fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
57        Box::pin(async move { Ok(u32::MAX) })
58    }
59}
60
/// Scorer for dense vector search results
///
/// Iterates a pre-computed hit list; the cursor model matches the `Scorer`
/// contract (`TERMINATED` once past the last hit).
struct DenseVectorScorer {
    /// Search hits as `(doc_id, distance)` pairs, in the order the reader returned them
    results: Vec<(u32, f32)>,
    /// Cursor into `results`
    position: usize,
}
66
67impl DenseVectorScorer {
68    fn new(results: Vec<(u32, f32)>) -> Self {
69        Self {
70            results,
71            position: 0,
72        }
73    }
74}
75
76impl Scorer for DenseVectorScorer {
77    fn doc(&self) -> DocId {
78        if self.position < self.results.len() {
79            self.results[self.position].0
80        } else {
81            TERMINATED
82        }
83    }
84
85    fn score(&self) -> Score {
86        if self.position < self.results.len() {
87            // Convert distance to score (smaller distance = higher score)
88            let distance = self.results[self.position].1;
89            1.0 / (1.0 + distance)
90        } else {
91            0.0
92        }
93    }
94
95    fn advance(&mut self) -> DocId {
96        self.position += 1;
97        self.doc()
98    }
99
100    fn seek(&mut self, target: DocId) -> DocId {
101        while self.doc() < target && self.doc() != TERMINATED {
102            self.advance();
103        }
104        self.doc()
105    }
106
107    fn size_hint(&self) -> u32 {
108        (self.results.len() - self.position) as u32
109    }
110}
111
/// Sparse vector query for similarity search
///
/// The query vector is a list of `(dimension_id, weight)` pairs, built either
/// directly or via the `from_text*` constructors (tokenizer-based, `native` feature).
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field containing the sparse vectors
    pub field: Field,
    /// Query vector as (dimension_id, weight) pairs
    pub vector: Vec<(u32, f32)>,
}
120
121impl SparseVectorQuery {
122    /// Create a new sparse vector query
123    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
124        Self { field, vector }
125    }
126
127    /// Create from separate indices and weights vectors
128    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
129        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
130        Self::new(field, vector)
131    }
132
133    /// Create from raw text using a HuggingFace tokenizer (single segment)
134    ///
135    /// This method tokenizes the text and creates a sparse vector query.
136    /// For multi-segment indexes, use `from_text_with_stats` instead.
137    ///
138    /// # Arguments
139    /// * `field` - The sparse vector field to search
140    /// * `text` - Raw text to tokenize
141    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
142    /// * `weighting` - Weighting strategy for tokens
143    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
144    #[cfg(feature = "native")]
145    pub fn from_text(
146        field: Field,
147        text: &str,
148        tokenizer_name: &str,
149        weighting: crate::structures::QueryWeighting,
150        sparse_index: Option<&crate::segment::SparseIndex>,
151    ) -> crate::Result<Self> {
152        use crate::structures::QueryWeighting;
153        use crate::tokenizer::tokenizer_cache;
154
155        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
156        let token_ids = tokenizer.tokenize_unique(text)?;
157
158        let weights: Vec<f32> = match weighting {
159            QueryWeighting::One => vec![1.0f32; token_ids.len()],
160            QueryWeighting::Idf => {
161                if let Some(index) = sparse_index {
162                    index.idf_weights(&token_ids)
163                } else {
164                    vec![1.0f32; token_ids.len()]
165                }
166            }
167        };
168
169        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
170        Ok(Self::new(field, vector))
171    }
172
173    /// Create from raw text using global statistics (multi-segment)
174    ///
175    /// This is the recommended method for multi-segment indexes as it uses
176    /// aggregated IDF values across all segments for consistent ranking.
177    ///
178    /// # Arguments
179    /// * `field` - The sparse vector field to search
180    /// * `text` - Raw text to tokenize
181    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
182    /// * `weighting` - Weighting strategy for tokens
183    /// * `global_stats` - Global statistics for IDF computation
184    #[cfg(feature = "native")]
185    pub fn from_text_with_stats(
186        field: Field,
187        text: &str,
188        tokenizer: &crate::tokenizer::HfTokenizer,
189        weighting: crate::structures::QueryWeighting,
190        global_stats: Option<&super::GlobalStats>,
191    ) -> crate::Result<Self> {
192        use crate::structures::QueryWeighting;
193
194        let token_ids = tokenizer.tokenize_unique(text)?;
195
196        let weights: Vec<f32> = match weighting {
197            QueryWeighting::One => vec![1.0f32; token_ids.len()],
198            QueryWeighting::Idf => {
199                if let Some(stats) = global_stats {
200                    stats.sparse_idf_weights(field, &token_ids)
201                } else {
202                    vec![1.0f32; token_ids.len()]
203                }
204            }
205        };
206
207        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
208        Ok(Self::new(field, vector))
209    }
210
211    /// Create from raw text, loading tokenizer from index directory
212    ///
213    /// This method supports the `index://` prefix for tokenizer paths,
214    /// loading tokenizer.json from the index directory.
215    ///
216    /// # Arguments
217    /// * `field` - The sparse vector field to search
218    /// * `text` - Raw text to tokenize
219    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
220    /// * `weighting` - Weighting strategy for tokens
221    /// * `global_stats` - Global statistics for IDF computation
222    #[cfg(feature = "native")]
223    pub fn from_text_with_tokenizer_bytes(
224        field: Field,
225        text: &str,
226        tokenizer_bytes: &[u8],
227        weighting: crate::structures::QueryWeighting,
228        global_stats: Option<&super::GlobalStats>,
229    ) -> crate::Result<Self> {
230        use crate::structures::QueryWeighting;
231        use crate::tokenizer::HfTokenizer;
232
233        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
234        let token_ids = tokenizer.tokenize_unique(text)?;
235
236        let weights: Vec<f32> = match weighting {
237            QueryWeighting::One => vec![1.0f32; token_ids.len()],
238            QueryWeighting::Idf => {
239                if let Some(stats) = global_stats {
240                    stats.sparse_idf_weights(field, &token_ids)
241                } else {
242                    vec![1.0f32; token_ids.len()]
243                }
244            }
245        };
246
247        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
248        Ok(Self::new(field, vector))
249    }
250}
251
252impl Query for SparseVectorQuery {
253    fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
254        Box::pin(async move {
255            let results = reader
256                .search_sparse_vector(self.field, &self.vector, limit)
257                .await?;
258
259            Ok(Box::new(SparseVectorScorer::new(results)) as Box<dyn Scorer>)
260        })
261    }
262
263    fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
264        Box::pin(async move { Ok(u32::MAX) })
265    }
266}
267
/// Scorer for sparse vector search results
///
/// Iterates a pre-computed hit list; unlike the dense scorer, the stored
/// `f32` is used directly as the score (no distance conversion).
struct SparseVectorScorer {
    /// Search hits as `(doc_id, score)` pairs, in the order the reader returned them
    results: Vec<(u32, f32)>,
    /// Cursor into `results`
    position: usize,
}
273
274impl SparseVectorScorer {
275    fn new(results: Vec<(u32, f32)>) -> Self {
276        Self {
277            results,
278            position: 0,
279        }
280    }
281}
282
283impl Scorer for SparseVectorScorer {
284    fn doc(&self) -> DocId {
285        if self.position < self.results.len() {
286            self.results[self.position].0
287        } else {
288            TERMINATED
289        }
290    }
291
292    fn score(&self) -> Score {
293        if self.position < self.results.len() {
294            self.results[self.position].1
295        } else {
296            0.0
297        }
298    }
299
300    fn advance(&mut self) -> DocId {
301        self.position += 1;
302        self.doc()
303    }
304
305    fn seek(&mut self, target: DocId) -> DocId {
306        while self.doc() < target && self.doc() != TERMINATED {
307            self.advance();
308        }
309        self.doc()
310    }
311
312    fn size_hint(&self) -> u32 {
313        (self.results.len() - self.position) as u32
314    }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let built = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(built.field, Field(0));
        assert_eq!(built.vector.len(), 3);
        assert_eq!(built.nprobe, 64);
        assert_eq!(built.rerank_factor, 5);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let pairs: Vec<(u32, f32)> = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let built = SparseVectorQuery::new(Field(0), pairs.clone());

        assert_eq!(built.field, Field(0));
        assert_eq!(built.vector, pairs);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let built =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(built.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
}