1use crate::dsl::Field;
4use crate::segment::SegmentReader;
5use crate::{DocId, Score, TERMINATED};
6
7use super::traits::{CountFuture, Query, Scorer, ScorerFuture};
8
/// Strategy used to merge the scores of a document that stores several
/// vectors in the same field into one document score.
///
/// The derived `Default` is `Sum`. Note that the query constructors choose
/// their own combiner explicitly: [`DenseVectorQuery::new`] uses `Max` and
/// [`SparseVectorQuery::new`] uses `Sum`. The actual merging is performed by
/// the `SegmentReader` search routines, not in this file.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MultiValueCombiner {
    /// Presumably sums the per-value scores (default variant).
    #[default]
    Sum,
    /// Presumably keeps the best per-value score.
    Max,
    /// Presumably averages the per-value scores.
    Avg,
}
20
21#[derive(Debug, Clone)]
23pub struct DenseVectorQuery {
24 pub field: Field,
26 pub vector: Vec<f32>,
28 pub nprobe: usize,
30 pub rerank_factor: usize,
32 pub combiner: MultiValueCombiner,
34}
35
36impl DenseVectorQuery {
37 pub fn new(field: Field, vector: Vec<f32>) -> Self {
39 Self {
40 field,
41 vector,
42 nprobe: 32,
43 rerank_factor: 3,
44 combiner: MultiValueCombiner::Max,
45 }
46 }
47
48 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
50 self.nprobe = nprobe;
51 self
52 }
53
54 pub fn with_rerank_factor(mut self, factor: usize) -> Self {
56 self.rerank_factor = factor;
57 self
58 }
59
60 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
62 self.combiner = combiner;
63 self
64 }
65}
66
67impl Query for DenseVectorQuery {
68 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
69 let field = self.field;
70 let vector = self.vector.clone();
71 let rerank_factor = self.rerank_factor;
72 let combiner = self.combiner;
73 Box::pin(async move {
74 let results =
75 reader.search_dense_vector(field, &vector, limit, rerank_factor, combiner)?;
76
77 Ok(Box::new(DenseVectorScorer::new(results)) as Box<dyn Scorer>)
78 })
79 }
80
81 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
82 Box::pin(async move { Ok(u32::MAX) })
83 }
84}
85
86struct DenseVectorScorer {
88 results: Vec<(u32, f32)>,
89 position: usize,
90}
91
92impl DenseVectorScorer {
93 fn new(results: Vec<(u32, f32)>) -> Self {
94 Self {
95 results,
96 position: 0,
97 }
98 }
99}
100
101impl Scorer for DenseVectorScorer {
102 fn doc(&self) -> DocId {
103 if self.position < self.results.len() {
104 self.results[self.position].0
105 } else {
106 TERMINATED
107 }
108 }
109
110 fn score(&self) -> Score {
111 if self.position < self.results.len() {
112 self.results[self.position].1
113 } else {
114 0.0
115 }
116 }
117
118 fn advance(&mut self) -> DocId {
119 self.position += 1;
120 self.doc()
121 }
122
123 fn seek(&mut self, target: DocId) -> DocId {
124 while self.doc() < target && self.doc() != TERMINATED {
125 self.advance();
126 }
127 self.doc()
128 }
129
130 fn size_hint(&self) -> u32 {
131 (self.results.len() - self.position) as u32
132 }
133}
134
135#[derive(Debug, Clone)]
137pub struct SparseVectorQuery {
138 pub field: Field,
140 pub vector: Vec<(u32, f32)>,
142 pub combiner: MultiValueCombiner,
144}
145
146impl SparseVectorQuery {
147 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
149 Self {
150 field,
151 vector,
152 combiner: MultiValueCombiner::Sum,
153 }
154 }
155
156 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
158 self.combiner = combiner;
159 self
160 }
161
162 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
164 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
165 Self::new(field, vector)
166 }
167
168 #[cfg(feature = "native")]
180 pub fn from_text(
181 field: Field,
182 text: &str,
183 tokenizer_name: &str,
184 weighting: crate::structures::QueryWeighting,
185 sparse_index: Option<&crate::segment::SparseIndex>,
186 ) -> crate::Result<Self> {
187 use crate::structures::QueryWeighting;
188 use crate::tokenizer::tokenizer_cache;
189
190 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
191 let token_ids = tokenizer.tokenize_unique(text)?;
192
193 let weights: Vec<f32> = match weighting {
194 QueryWeighting::One => vec![1.0f32; token_ids.len()],
195 QueryWeighting::Idf => {
196 if let Some(index) = sparse_index {
197 index.idf_weights(&token_ids)
198 } else {
199 vec![1.0f32; token_ids.len()]
200 }
201 }
202 };
203
204 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
205 Ok(Self::new(field, vector))
206 }
207
208 #[cfg(feature = "native")]
220 pub fn from_text_with_stats(
221 field: Field,
222 text: &str,
223 tokenizer: &crate::tokenizer::HfTokenizer,
224 weighting: crate::structures::QueryWeighting,
225 global_stats: Option<&super::GlobalStats>,
226 ) -> crate::Result<Self> {
227 use crate::structures::QueryWeighting;
228
229 let token_ids = tokenizer.tokenize_unique(text)?;
230
231 let weights: Vec<f32> = match weighting {
232 QueryWeighting::One => vec![1.0f32; token_ids.len()],
233 QueryWeighting::Idf => {
234 if let Some(stats) = global_stats {
235 stats.sparse_idf_weights(field, &token_ids)
236 } else {
237 vec![1.0f32; token_ids.len()]
238 }
239 }
240 };
241
242 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
243 Ok(Self::new(field, vector))
244 }
245
246 #[cfg(feature = "native")]
258 pub fn from_text_with_tokenizer_bytes(
259 field: Field,
260 text: &str,
261 tokenizer_bytes: &[u8],
262 weighting: crate::structures::QueryWeighting,
263 global_stats: Option<&super::GlobalStats>,
264 ) -> crate::Result<Self> {
265 use crate::structures::QueryWeighting;
266 use crate::tokenizer::HfTokenizer;
267
268 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
269 let token_ids = tokenizer.tokenize_unique(text)?;
270
271 let weights: Vec<f32> = match weighting {
272 QueryWeighting::One => vec![1.0f32; token_ids.len()],
273 QueryWeighting::Idf => {
274 if let Some(stats) = global_stats {
275 stats.sparse_idf_weights(field, &token_ids)
276 } else {
277 vec![1.0f32; token_ids.len()]
278 }
279 }
280 };
281
282 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
283 Ok(Self::new(field, vector))
284 }
285}
286
287impl Query for SparseVectorQuery {
288 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
289 let field = self.field;
290 let vector = self.vector.clone();
291 let combiner = self.combiner;
292 Box::pin(async move {
293 let results = reader
294 .search_sparse_vector(field, &vector, limit, combiner)
295 .await?;
296
297 Ok(Box::new(SparseVectorScorer::new(results)) as Box<dyn Scorer>)
298 })
299 }
300
301 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
302 Box::pin(async move { Ok(u32::MAX) })
303 }
304}
305
306struct SparseVectorScorer {
308 results: Vec<(u32, f32)>,
309 position: usize,
310}
311
312impl SparseVectorScorer {
313 fn new(results: Vec<(u32, f32)>) -> Self {
314 Self {
315 results,
316 position: 0,
317 }
318 }
319}
320
321impl Scorer for SparseVectorScorer {
322 fn doc(&self) -> DocId {
323 if self.position < self.results.len() {
324 self.results[self.position].0
325 } else {
326 TERMINATED
327 }
328 }
329
330 fn score(&self) -> Score {
331 if self.position < self.results.len() {
332 self.results[self.position].1
333 } else {
334 0.0
335 }
336 }
337
338 fn advance(&mut self) -> DocId {
339 self.position += 1;
340 self.doc()
341 }
342
343 fn seek(&mut self, target: DocId) -> DocId {
344 while self.doc() < target && self.doc() != TERMINATED {
345 self.advance();
346 }
347 self.doc()
348 }
349
350 fn size_hint(&self) -> u32 {
351 (self.results.len() - self.position) as u32
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
    }

    /// Pins the documented defaults and the combiner builder for dense queries.
    #[test]
    fn test_dense_vector_query_defaults_and_combiner() {
        let query = DenseVectorQuery::new(Field(1), vec![0.0; 4]);
        assert_eq!(query.nprobe, 32);
        assert_eq!(query.rerank_factor, 3);
        assert_eq!(query.combiner, MultiValueCombiner::Max);

        let query = query.with_combiner(MultiValueCombiner::Avg);
        assert_eq!(query.combiner, MultiValueCombiner::Avg);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    /// Sparse queries default to `Sum`, which is also the enum's `Default`.
    #[test]
    fn test_sparse_vector_query_combiner() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 1.0)]);
        assert_eq!(query.combiner, MultiValueCombiner::Sum);
        assert_eq!(MultiValueCombiner::default(), MultiValueCombiner::Sum);

        let query = query.with_combiner(MultiValueCombiner::Max);
        assert_eq!(query.combiner, MultiValueCombiner::Max);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    /// Mismatched lengths are paired positionally and the surplus is dropped.
    #[test]
    fn test_sparse_vector_query_from_indices_weights_length_mismatch() {
        let query = SparseVectorQuery::from_indices_weights(Field(0), vec![1, 2, 3], vec![0.5]);
        assert_eq!(query.vector, vec![(1, 0.5)]);
    }

    /// Walks a dense scorer through doc/score/seek/advance down to exhaustion.
    #[test]
    fn test_dense_vector_scorer_iteration() {
        let mut scorer = DenseVectorScorer::new(vec![(1, 0.5), (3, 0.25), (8, 1.0)]);
        assert_eq!(scorer.size_hint(), 3);
        assert_eq!(scorer.doc(), 1);
        assert_eq!(scorer.score(), 0.5);
        assert_eq!(scorer.seek(2), 3);
        assert_eq!(scorer.score(), 0.25);
        assert_eq!(scorer.seek(3), 3);
        assert_eq!(scorer.advance(), 8);
        assert_eq!(scorer.advance(), TERMINATED);
        assert_eq!(scorer.score(), 0.0);
        assert_eq!(scorer.size_hint(), 0);
    }

    /// An empty sparse scorer starts out terminated.
    #[test]
    fn test_sparse_vector_scorer_empty() {
        let mut scorer = SparseVectorScorer::new(Vec::new());
        assert_eq!(scorer.doc(), TERMINATED);
        assert_eq!(scorer.seek(5), TERMINATED);
        assert_eq!(scorer.size_hint(), 0);
    }
}