1use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::ScoredPosition;
8use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Combination strategy carried by vector queries.
///
/// [`DenseVectorQuery`] and [`SparseVectorQuery`] store one of these values and
/// hand it unchanged to the segment reader's vector search; the combining
/// itself is not implemented in this file.
// NOTE(review): the per-variant descriptions below are inferred from the
// variant names only — confirm against the segment reader implementation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum MultiValueCombiner {
    /// Presumably adds the per-value scores together. This is the variant
    /// returned by `MultiValueCombiner::default()`.
    #[default]
    Sum,
    /// Presumably keeps the highest per-value score.
    Max,
    /// Presumably averages the per-value scores.
    Avg,
}
21
/// Query that searches a field with a dense `f32` vector.
///
/// All fields are public; the `with_*` builder methods offer a chained way to
/// override the defaults set by [`DenseVectorQuery::new`].
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field to search.
    pub field: Field,
    /// Query vector handed to the segment reader's dense search.
    pub vector: Vec<f32>,
    /// Probe count setting (32 by default).
    // NOTE(review): presumably the number of index partitions to probe; the
    // `Query::scorer` implementation for this type does not forward this value
    // to the segment reader — confirm where it is consumed.
    pub nprobe: usize,
    /// Re-ranking factor forwarded to the segment reader's dense search
    /// (3 by default).
    pub rerank_factor: usize,
    /// Combiner forwarded to the segment reader's dense search
    /// ([`MultiValueCombiner::Max`] by default).
    pub combiner: MultiValueCombiner,
}
36
37impl DenseVectorQuery {
38 pub fn new(field: Field, vector: Vec<f32>) -> Self {
40 Self {
41 field,
42 vector,
43 nprobe: 32,
44 rerank_factor: 3,
45 combiner: MultiValueCombiner::Max,
46 }
47 }
48
49 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
51 self.nprobe = nprobe;
52 self
53 }
54
55 pub fn with_rerank_factor(mut self, factor: usize) -> Self {
57 self.rerank_factor = factor;
58 self
59 }
60
61 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
63 self.combiner = combiner;
64 self
65 }
66}
67
68impl Query for DenseVectorQuery {
69 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
70 let field = self.field;
71 let vector = self.vector.clone();
72 let rerank_factor = self.rerank_factor;
73 let combiner = self.combiner;
74 Box::pin(async move {
75 let results =
76 reader.search_dense_vector(field, &vector, limit, rerank_factor, combiner)?;
77
78 Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
79 })
80 }
81
82 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
83 Box::pin(async move { Ok(u32::MAX) })
84 }
85}
86
/// Cursor over the hits produced by a dense vector search.
struct DenseVectorScorer {
    /// Hits as returned by the segment reader, walked in stored order.
    results: Vec<VectorSearchResult>,
    /// Index of the current hit in `results`; a value at or beyond
    /// `results.len()` means the scorer is exhausted.
    position: usize,
    /// Raw field id attached to the matched positions this scorer reports.
    field_id: u32,
}
93
94impl DenseVectorScorer {
95 fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
96 Self {
97 results,
98 position: 0,
99 field_id,
100 }
101 }
102}
103
104impl Scorer for DenseVectorScorer {
105 fn doc(&self) -> DocId {
106 if self.position < self.results.len() {
107 self.results[self.position].doc_id
108 } else {
109 TERMINATED
110 }
111 }
112
113 fn score(&self) -> Score {
114 if self.position < self.results.len() {
115 self.results[self.position].score
116 } else {
117 0.0
118 }
119 }
120
121 fn advance(&mut self) -> DocId {
122 self.position += 1;
123 self.doc()
124 }
125
126 fn seek(&mut self, target: DocId) -> DocId {
127 while self.doc() < target && self.doc() != TERMINATED {
128 self.advance();
129 }
130 self.doc()
131 }
132
133 fn size_hint(&self) -> u32 {
134 (self.results.len() - self.position) as u32
135 }
136
137 fn matched_positions(&self) -> Option<MatchedPositions> {
138 if self.position >= self.results.len() {
139 return None;
140 }
141 let result = &self.results[self.position];
142 let scored_positions: Vec<ScoredPosition> = result
143 .ordinals
144 .iter()
145 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
146 .collect();
147 Some(vec![(self.field_id, scored_positions)])
148 }
149}
150
/// Query that searches a field with a sparse vector.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field to search.
    pub field: Field,
    /// Query vector as `(index, weight)` pairs. The text-based constructors
    /// fill the index with tokenizer token ids.
    pub vector: Vec<(u32, f32)>,
    /// Combiner forwarded to the segment reader's sparse search
    /// ([`MultiValueCombiner::Sum`] by default).
    pub combiner: MultiValueCombiner,
}
161
162impl SparseVectorQuery {
163 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
165 Self {
166 field,
167 vector,
168 combiner: MultiValueCombiner::Sum,
169 }
170 }
171
172 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
174 self.combiner = combiner;
175 self
176 }
177
178 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
180 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
181 Self::new(field, vector)
182 }
183
184 #[cfg(feature = "native")]
196 pub fn from_text(
197 field: Field,
198 text: &str,
199 tokenizer_name: &str,
200 weighting: crate::structures::QueryWeighting,
201 sparse_index: Option<&crate::segment::SparseIndex>,
202 ) -> crate::Result<Self> {
203 use crate::structures::QueryWeighting;
204 use crate::tokenizer::tokenizer_cache;
205
206 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
207 let token_ids = tokenizer.tokenize_unique(text)?;
208
209 let weights: Vec<f32> = match weighting {
210 QueryWeighting::One => vec![1.0f32; token_ids.len()],
211 QueryWeighting::Idf => {
212 if let Some(index) = sparse_index {
213 index.idf_weights(&token_ids)
214 } else {
215 vec![1.0f32; token_ids.len()]
216 }
217 }
218 };
219
220 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
221 Ok(Self::new(field, vector))
222 }
223
224 #[cfg(feature = "native")]
236 pub fn from_text_with_stats(
237 field: Field,
238 text: &str,
239 tokenizer: &crate::tokenizer::HfTokenizer,
240 weighting: crate::structures::QueryWeighting,
241 global_stats: Option<&super::GlobalStats>,
242 ) -> crate::Result<Self> {
243 use crate::structures::QueryWeighting;
244
245 let token_ids = tokenizer.tokenize_unique(text)?;
246
247 let weights: Vec<f32> = match weighting {
248 QueryWeighting::One => vec![1.0f32; token_ids.len()],
249 QueryWeighting::Idf => {
250 if let Some(stats) = global_stats {
251 stats
253 .sparse_idf_weights(field, &token_ids)
254 .into_iter()
255 .map(|w| w.max(0.0))
256 .collect()
257 } else {
258 vec![1.0f32; token_ids.len()]
259 }
260 }
261 };
262
263 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
264 Ok(Self::new(field, vector))
265 }
266
267 #[cfg(feature = "native")]
279 pub fn from_text_with_tokenizer_bytes(
280 field: Field,
281 text: &str,
282 tokenizer_bytes: &[u8],
283 weighting: crate::structures::QueryWeighting,
284 global_stats: Option<&super::GlobalStats>,
285 ) -> crate::Result<Self> {
286 use crate::structures::QueryWeighting;
287 use crate::tokenizer::HfTokenizer;
288
289 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
290 let token_ids = tokenizer.tokenize_unique(text)?;
291
292 let weights: Vec<f32> = match weighting {
293 QueryWeighting::One => vec![1.0f32; token_ids.len()],
294 QueryWeighting::Idf => {
295 if let Some(stats) = global_stats {
296 stats
298 .sparse_idf_weights(field, &token_ids)
299 .into_iter()
300 .map(|w| w.max(0.0))
301 .collect()
302 } else {
303 vec![1.0f32; token_ids.len()]
304 }
305 }
306 };
307
308 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
309 Ok(Self::new(field, vector))
310 }
311}
312
313impl Query for SparseVectorQuery {
314 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
315 let field = self.field;
316 let vector = self.vector.clone();
317 let combiner = self.combiner;
318 Box::pin(async move {
319 let results = reader
320 .search_sparse_vector(field, &vector, limit, combiner)
321 .await?;
322
323 Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
324 })
325 }
326
327 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
328 Box::pin(async move { Ok(u32::MAX) })
329 }
330}
331
/// Cursor over the hits produced by a sparse vector search.
struct SparseVectorScorer {
    /// Hits as returned by the segment reader, walked in stored order.
    results: Vec<VectorSearchResult>,
    /// Index of the current hit in `results`; a value at or beyond
    /// `results.len()` means the scorer is exhausted.
    position: usize,
    /// Raw field id attached to the matched positions this scorer reports.
    field_id: u32,
}
338
339impl SparseVectorScorer {
340 fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
341 Self {
342 results,
343 position: 0,
344 field_id,
345 }
346 }
347}
348
349impl Scorer for SparseVectorScorer {
350 fn doc(&self) -> DocId {
351 if self.position < self.results.len() {
352 self.results[self.position].doc_id
353 } else {
354 TERMINATED
355 }
356 }
357
358 fn score(&self) -> Score {
359 if self.position < self.results.len() {
360 self.results[self.position].score
361 } else {
362 0.0
363 }
364 }
365
366 fn advance(&mut self) -> DocId {
367 self.position += 1;
368 self.doc()
369 }
370
371 fn seek(&mut self, target: DocId) -> DocId {
372 while self.doc() < target && self.doc() != TERMINATED {
373 self.advance();
374 }
375 self.doc()
376 }
377
378 fn size_hint(&self) -> u32 {
379 (self.results.len() - self.position) as u32
380 }
381
382 fn matched_positions(&self) -> Option<MatchedPositions> {
383 if self.position >= self.results.len() {
384 return None;
385 }
386 let result = &self.results[self.position];
387 let scored_positions: Vec<ScoredPosition> = result
388 .ordinals
389 .iter()
390 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
391 .collect();
392 Some(vec![(self.field_id, scored_positions)])
393 }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// The derived `Default` picks the variant marked `#[default]`.
    #[test]
    fn test_multi_value_combiner_default() {
        assert_eq!(MultiValueCombiner::default(), MultiValueCombiner::Sum);
    }

    /// `new` applies the documented stock settings.
    #[test]
    fn test_dense_vector_query_defaults() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0]);

        assert_eq!(query.nprobe, 32);
        assert_eq!(query.rerank_factor, 3);
        assert_eq!(query.combiner, MultiValueCombiner::Max);
    }

    /// Builder methods override only the setting they name.
    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
        // Untouched by the builder calls above.
        assert_eq!(query.combiner, MultiValueCombiner::Max);
    }

    /// `with_combiner` replaces the combiner and leaves the rest alone.
    #[test]
    fn test_dense_vector_query_with_combiner() {
        let query =
            DenseVectorQuery::new(Field(0), vec![1.0]).with_combiner(MultiValueCombiner::Avg);

        assert_eq!(query.combiner, MultiValueCombiner::Avg);
        assert_eq!(query.nprobe, 32);
        assert_eq!(query.rerank_factor, 3);
    }

    /// `new` stores the field and vector as given and defaults to `Sum`.
    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
        assert_eq!(query.combiner, MultiValueCombiner::Sum);
    }

    /// `with_combiner` replaces the combiner and keeps the vector.
    #[test]
    fn test_sparse_vector_query_with_combiner() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.5)])
            .with_combiner(MultiValueCombiner::Max);

        assert_eq!(query.combiner, MultiValueCombiner::Max);
        assert_eq!(query.vector, vec![(1, 0.5)]);
    }

    /// Parallel index/weight lists are paired element by element.
    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
}