1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::traits::{CountFuture, Query, Scorer, ScorerFuture};
8
9#[derive(Debug, Clone)]
11pub struct DenseVectorQuery {
12 pub field: Field,
14 pub vector: Vec<f32>,
16 pub nprobe: usize,
18 pub rerank_factor: usize,
20}
21
22impl DenseVectorQuery {
23 pub fn new(field: Field, vector: Vec<f32>) -> Self {
25 Self {
26 field,
27 vector,
28 nprobe: 32,
29 rerank_factor: 3,
30 }
31 }
32
33 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
35 self.nprobe = nprobe;
36 self
37 }
38
39 pub fn with_rerank_factor(mut self, factor: usize) -> Self {
41 self.rerank_factor = factor;
42 self
43 }
44}
45
46impl Query for DenseVectorQuery {
47 fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
48 Box::pin(async move {
49 let results =
50 reader.search_dense_vector(self.field, &self.vector, limit, self.rerank_factor)?;
51
52 Ok(Box::new(DenseVectorScorer::new(results)) as Box<dyn Scorer>)
53 })
54 }
55
56 fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
57 Box::pin(async move { Ok(u32::MAX) })
58 }
59}
60
/// Cursor over the precomputed hits of a dense-vector search.
struct DenseVectorScorer {
    /// Search hits as `(doc id, distance)` pairs.
    results: Vec<(u32, f32)>,
    /// Index of the current hit inside `results`.
    position: usize,
}

impl DenseVectorScorer {
    /// Wraps `results` with the cursor placed on the first hit.
    fn new(results: Vec<(u32, f32)>) -> Self {
        Self {
            position: 0,
            results,
        }
    }
}
75
76impl Scorer for DenseVectorScorer {
77 fn doc(&self) -> DocId {
78 if self.position < self.results.len() {
79 self.results[self.position].0
80 } else {
81 TERMINATED
82 }
83 }
84
85 fn score(&self) -> Score {
86 if self.position < self.results.len() {
87 let distance = self.results[self.position].1;
89 1.0 / (1.0 + distance)
90 } else {
91 0.0
92 }
93 }
94
95 fn advance(&mut self) -> DocId {
96 self.position += 1;
97 self.doc()
98 }
99
100 fn seek(&mut self, target: DocId) -> DocId {
101 while self.doc() < target && self.doc() != TERMINATED {
102 self.advance();
103 }
104 self.doc()
105 }
106
107 fn size_hint(&self) -> u32 {
108 (self.results.len() - self.position) as u32
109 }
110}
111
112#[derive(Debug, Clone)]
114pub struct SparseVectorQuery {
115 pub field: Field,
117 pub vector: Vec<(u32, f32)>,
119}
120
121impl SparseVectorQuery {
122 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
124 Self { field, vector }
125 }
126
127 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
129 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
130 Self::new(field, vector)
131 }
132
133 #[cfg(feature = "native")]
145 pub fn from_text(
146 field: Field,
147 text: &str,
148 tokenizer_name: &str,
149 weighting: crate::structures::QueryWeighting,
150 sparse_index: Option<&crate::segment::SparseIndex>,
151 ) -> crate::Result<Self> {
152 use crate::structures::QueryWeighting;
153 use crate::tokenizer::tokenizer_cache;
154
155 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
156 let token_ids = tokenizer.tokenize_unique(text)?;
157
158 let weights: Vec<f32> = match weighting {
159 QueryWeighting::One => vec![1.0f32; token_ids.len()],
160 QueryWeighting::Idf => {
161 if let Some(index) = sparse_index {
162 index.idf_weights(&token_ids)
163 } else {
164 vec![1.0f32; token_ids.len()]
165 }
166 }
167 };
168
169 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
170 Ok(Self::new(field, vector))
171 }
172
173 #[cfg(feature = "native")]
185 pub fn from_text_with_stats(
186 field: Field,
187 text: &str,
188 tokenizer: &crate::tokenizer::HfTokenizer,
189 weighting: crate::structures::QueryWeighting,
190 global_stats: Option<&super::GlobalStats>,
191 ) -> crate::Result<Self> {
192 use crate::structures::QueryWeighting;
193
194 let token_ids = tokenizer.tokenize_unique(text)?;
195
196 let weights: Vec<f32> = match weighting {
197 QueryWeighting::One => vec![1.0f32; token_ids.len()],
198 QueryWeighting::Idf => {
199 if let Some(stats) = global_stats {
200 stats.sparse_idf_weights(field, &token_ids)
201 } else {
202 vec![1.0f32; token_ids.len()]
203 }
204 }
205 };
206
207 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
208 Ok(Self::new(field, vector))
209 }
210
211 #[cfg(feature = "native")]
223 pub fn from_text_with_tokenizer_bytes(
224 field: Field,
225 text: &str,
226 tokenizer_bytes: &[u8],
227 weighting: crate::structures::QueryWeighting,
228 global_stats: Option<&super::GlobalStats>,
229 ) -> crate::Result<Self> {
230 use crate::structures::QueryWeighting;
231 use crate::tokenizer::HfTokenizer;
232
233 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
234 let token_ids = tokenizer.tokenize_unique(text)?;
235
236 let weights: Vec<f32> = match weighting {
237 QueryWeighting::One => vec![1.0f32; token_ids.len()],
238 QueryWeighting::Idf => {
239 if let Some(stats) = global_stats {
240 stats.sparse_idf_weights(field, &token_ids)
241 } else {
242 vec![1.0f32; token_ids.len()]
243 }
244 }
245 };
246
247 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
248 Ok(Self::new(field, vector))
249 }
250}
251
252impl Query for SparseVectorQuery {
253 fn scorer<'a>(&'a self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
254 Box::pin(async move {
255 let results = reader
256 .search_sparse_vector(self.field, &self.vector, limit)
257 .await?;
258
259 Ok(Box::new(SparseVectorScorer::new(results)) as Box<dyn Scorer>)
260 })
261 }
262
263 fn count_estimate<'a>(&'a self, _reader: &'a SegmentReader) -> CountFuture<'a> {
264 Box::pin(async move { Ok(u32::MAX) })
265 }
266}
267
/// Cursor over the precomputed hits of a sparse-vector search.
struct SparseVectorScorer {
    /// Search hits as `(doc id, score)` pairs.
    results: Vec<(u32, f32)>,
    /// Index of the current hit inside `results`.
    position: usize,
}

impl SparseVectorScorer {
    /// Wraps `results` with the cursor placed on the first hit.
    fn new(results: Vec<(u32, f32)>) -> Self {
        Self {
            position: 0,
            results,
        }
    }
}
282
283impl Scorer for SparseVectorScorer {
284 fn doc(&self) -> DocId {
285 if self.position < self.results.len() {
286 self.results[self.position].0
287 } else {
288 TERMINATED
289 }
290 }
291
292 fn score(&self) -> Score {
293 if self.position < self.results.len() {
294 self.results[self.position].1
295 } else {
296 0.0
297 }
298 }
299
300 fn advance(&mut self) -> DocId {
301 self.position += 1;
302 self.doc()
303 }
304
305 fn seek(&mut self, target: DocId) -> DocId {
306 while self.doc() < target && self.doc() != TERMINATED {
307 self.advance();
308 }
309 self.doc()
310 }
311
312 fn size_hint(&self) -> u32 {
313 (self.results.len() - self.position) as u32
314 }
315}
316
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// The builder setters override the defaults and leave the field and
    /// vector untouched.
    #[test]
    fn test_dense_vector_query_builder() {
        let components = vec![1.0, 2.0, 3.0];
        let built = DenseVectorQuery::new(Field(0), components)
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(built.field, Field(0));
        assert_eq!(built.vector.len(), 3);
        assert_eq!(built.nprobe, 64);
        assert_eq!(built.rerank_factor, 5);
    }

    /// `new` stores the field and the pairs exactly as given.
    #[test]
    fn test_sparse_vector_query_new() {
        let expected = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let built = SparseVectorQuery::new(Field(0), expected.clone());

        assert_eq!(built.field, Field(0));
        assert_eq!(built.vector, expected);
    }

    /// Indices and weights are paired position by position.
    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let indices = vec![1, 5, 10];
        let weights = vec![0.5, 0.3, 0.2];
        let built = SparseVectorQuery::from_indices_weights(Field(0), indices, weights);

        assert_eq!(built.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
}