use crate::dsl::Field;
use crate::segment::{SegmentReader, VectorSearchResult};
use crate::{DocId, Score, TERMINATED};

use super::ScoredPosition;
use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Strategy for collapsing the per-value scores of one document into a single
/// score.
///
/// Each variant is evaluated by [`MultiValueCombiner::combine`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Add all scores together.
    Sum,
    /// Keep only the highest score.
    Max,
    /// Arithmetic mean of the scores.
    Avg,
    /// Smooth maximum: `max + ln(sum(exp(t * (s - max)))) / t`, where `t` is
    /// `temperature`.
    LogSumExp {
        /// Scale applied inside the exponent and divided out afterwards.
        /// Larger values pull the result towards the plain maximum.
        temperature: f32,
    },
    /// Weighted mean of the `k` highest scores, with geometrically decaying
    /// weights `1, decay, decay^2, ...` from best to worst.
    WeightedTopK {
        /// Number of highest scores that take part in the mean.
        k: usize,
        /// Factor by which the weight shrinks from one rank to the next.
        decay: f32,
    },
}
35
36impl Default for MultiValueCombiner {
37 fn default() -> Self {
38 MultiValueCombiner::LogSumExp { temperature: 1.5 }
41 }
42}
43
44impl MultiValueCombiner {
45 pub fn log_sum_exp() -> Self {
47 Self::LogSumExp { temperature: 1.5 }
48 }
49
50 pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
52 Self::LogSumExp { temperature }
53 }
54
55 pub fn weighted_top_k() -> Self {
57 Self::WeightedTopK { k: 5, decay: 0.7 }
58 }
59
60 pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
62 Self::WeightedTopK { k, decay }
63 }
64
65 pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
67 if scores.is_empty() {
68 return 0.0;
69 }
70
71 match self {
72 MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
73 MultiValueCombiner::Max => scores
74 .iter()
75 .map(|(_, s)| *s)
76 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
77 .unwrap_or(0.0),
78 MultiValueCombiner::Avg => {
79 let sum: f32 = scores.iter().map(|(_, s)| s).sum();
80 sum / scores.len() as f32
81 }
82 MultiValueCombiner::LogSumExp { temperature } => {
83 let t = *temperature;
86 let max_score = scores
87 .iter()
88 .map(|(_, s)| *s)
89 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
90 .unwrap_or(0.0);
91
92 let sum_exp: f32 = scores
93 .iter()
94 .map(|(_, s)| (t * (s - max_score)).exp())
95 .sum();
96
97 max_score + sum_exp.ln() / t
98 }
99 MultiValueCombiner::WeightedTopK { k, decay } => {
100 let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
102 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
103 sorted.truncate(*k);
104
105 let mut weight = 1.0f32;
107 let mut weighted_sum = 0.0f32;
108 let mut weight_total = 0.0f32;
109
110 for score in sorted {
111 weighted_sum += weight * score;
112 weight_total += weight;
113 weight *= decay;
114 }
115
116 if weight_total > 0.0 {
117 weighted_sum / weight_total
118 } else {
119 0.0
120 }
121 }
122 }
123 }
124}
125
126#[derive(Debug, Clone)]
128pub struct DenseVectorQuery {
129 pub field: Field,
131 pub vector: Vec<f32>,
133 pub nprobe: usize,
135 pub rerank_factor: usize,
137 pub combiner: MultiValueCombiner,
139}
140
141impl DenseVectorQuery {
142 pub fn new(field: Field, vector: Vec<f32>) -> Self {
144 Self {
145 field,
146 vector,
147 nprobe: 32,
148 rerank_factor: 3,
149 combiner: MultiValueCombiner::Max,
150 }
151 }
152
153 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
155 self.nprobe = nprobe;
156 self
157 }
158
159 pub fn with_rerank_factor(mut self, factor: usize) -> Self {
161 self.rerank_factor = factor;
162 self
163 }
164
165 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
167 self.combiner = combiner;
168 self
169 }
170}
171
172impl Query for DenseVectorQuery {
173 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
174 let field = self.field;
175 let vector = self.vector.clone();
176 let rerank_factor = self.rerank_factor;
177 let combiner = self.combiner;
178 Box::pin(async move {
179 let results =
180 reader.search_dense_vector(field, &vector, limit, rerank_factor, combiner)?;
181
182 Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
183 })
184 }
185
186 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
187 Box::pin(async move { Ok(u32::MAX) })
188 }
189}
190
191struct DenseVectorScorer {
193 results: Vec<VectorSearchResult>,
194 position: usize,
195 field_id: u32,
196}
197
198impl DenseVectorScorer {
199 fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
200 Self {
201 results,
202 position: 0,
203 field_id,
204 }
205 }
206}
207
208impl Scorer for DenseVectorScorer {
209 fn doc(&self) -> DocId {
210 if self.position < self.results.len() {
211 self.results[self.position].doc_id
212 } else {
213 TERMINATED
214 }
215 }
216
217 fn score(&self) -> Score {
218 if self.position < self.results.len() {
219 self.results[self.position].score
220 } else {
221 0.0
222 }
223 }
224
225 fn advance(&mut self) -> DocId {
226 self.position += 1;
227 self.doc()
228 }
229
230 fn seek(&mut self, target: DocId) -> DocId {
231 while self.doc() < target && self.doc() != TERMINATED {
232 self.advance();
233 }
234 self.doc()
235 }
236
237 fn size_hint(&self) -> u32 {
238 (self.results.len() - self.position) as u32
239 }
240
241 fn matched_positions(&self) -> Option<MatchedPositions> {
242 if self.position >= self.results.len() {
243 return None;
244 }
245 let result = &self.results[self.position];
246 let scored_positions: Vec<ScoredPosition> = result
247 .ordinals
248 .iter()
249 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
250 .collect();
251 Some(vec![(self.field_id, scored_positions)])
252 }
253}
254
255#[derive(Debug, Clone)]
257pub struct SparseVectorQuery {
258 pub field: Field,
260 pub vector: Vec<(u32, f32)>,
262 pub combiner: MultiValueCombiner,
264 pub heap_factor: f32,
267}
268
269impl SparseVectorQuery {
270 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
272 Self {
273 field,
274 vector,
275 combiner: MultiValueCombiner::Sum,
276 heap_factor: 1.0,
277 }
278 }
279
280 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
282 self.combiner = combiner;
283 self
284 }
285
286 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
293 self.heap_factor = heap_factor.clamp(0.0, 1.0);
294 self
295 }
296
297 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
299 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
300 Self::new(field, vector)
301 }
302
303 #[cfg(feature = "native")]
315 pub fn from_text(
316 field: Field,
317 text: &str,
318 tokenizer_name: &str,
319 weighting: crate::structures::QueryWeighting,
320 sparse_index: Option<&crate::segment::SparseIndex>,
321 ) -> crate::Result<Self> {
322 use crate::structures::QueryWeighting;
323 use crate::tokenizer::tokenizer_cache;
324
325 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
326 let token_ids = tokenizer.tokenize_unique(text)?;
327
328 let weights: Vec<f32> = match weighting {
329 QueryWeighting::One => vec![1.0f32; token_ids.len()],
330 QueryWeighting::Idf => {
331 if let Some(index) = sparse_index {
332 index.idf_weights(&token_ids)
333 } else {
334 vec![1.0f32; token_ids.len()]
335 }
336 }
337 };
338
339 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
340 Ok(Self::new(field, vector))
341 }
342
343 #[cfg(feature = "native")]
355 pub fn from_text_with_stats(
356 field: Field,
357 text: &str,
358 tokenizer: &crate::tokenizer::HfTokenizer,
359 weighting: crate::structures::QueryWeighting,
360 global_stats: Option<&super::GlobalStats>,
361 ) -> crate::Result<Self> {
362 use crate::structures::QueryWeighting;
363
364 let token_ids = tokenizer.tokenize_unique(text)?;
365
366 let weights: Vec<f32> = match weighting {
367 QueryWeighting::One => vec![1.0f32; token_ids.len()],
368 QueryWeighting::Idf => {
369 if let Some(stats) = global_stats {
370 stats
372 .sparse_idf_weights(field, &token_ids)
373 .into_iter()
374 .map(|w| w.max(0.0))
375 .collect()
376 } else {
377 vec![1.0f32; token_ids.len()]
378 }
379 }
380 };
381
382 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
383 Ok(Self::new(field, vector))
384 }
385
386 #[cfg(feature = "native")]
398 pub fn from_text_with_tokenizer_bytes(
399 field: Field,
400 text: &str,
401 tokenizer_bytes: &[u8],
402 weighting: crate::structures::QueryWeighting,
403 global_stats: Option<&super::GlobalStats>,
404 ) -> crate::Result<Self> {
405 use crate::structures::QueryWeighting;
406 use crate::tokenizer::HfTokenizer;
407
408 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
409 let token_ids = tokenizer.tokenize_unique(text)?;
410
411 let weights: Vec<f32> = match weighting {
412 QueryWeighting::One => vec![1.0f32; token_ids.len()],
413 QueryWeighting::Idf => {
414 if let Some(stats) = global_stats {
415 stats
417 .sparse_idf_weights(field, &token_ids)
418 .into_iter()
419 .map(|w| w.max(0.0))
420 .collect()
421 } else {
422 vec![1.0f32; token_ids.len()]
423 }
424 }
425 };
426
427 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
428 Ok(Self::new(field, vector))
429 }
430}
431
432impl Query for SparseVectorQuery {
433 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
434 let field = self.field;
435 let vector = self.vector.clone();
436 let combiner = self.combiner;
437 let heap_factor = self.heap_factor;
438 Box::pin(async move {
439 let results = reader
440 .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
441 .await?;
442
443 Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
444 })
445 }
446
447 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
448 Box::pin(async move { Ok(u32::MAX) })
449 }
450}
451
452struct SparseVectorScorer {
454 results: Vec<VectorSearchResult>,
455 position: usize,
456 field_id: u32,
457}
458
459impl SparseVectorScorer {
460 fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
461 Self {
462 results,
463 position: 0,
464 field_id,
465 }
466 }
467}
468
469impl Scorer for SparseVectorScorer {
470 fn doc(&self) -> DocId {
471 if self.position < self.results.len() {
472 self.results[self.position].doc_id
473 } else {
474 TERMINATED
475 }
476 }
477
478 fn score(&self) -> Score {
479 if self.position < self.results.len() {
480 self.results[self.position].score
481 } else {
482 0.0
483 }
484 }
485
486 fn advance(&mut self) -> DocId {
487 self.position += 1;
488 self.doc()
489 }
490
491 fn seek(&mut self, target: DocId) -> DocId {
492 while self.doc() < target && self.doc() != TERMINATED {
493 self.advance();
494 }
495 self.doc()
496 }
497
498 fn size_hint(&self) -> u32 {
499 (self.results.len() - self.position) as u32
500 }
501
502 fn matched_positions(&self) -> Option<MatchedPositions> {
503 if self.position >= self.results.len() {
504 return None;
505 }
506 let result = &self.results[self.position];
507 let scored_positions: Vec<ScoredPosition> = result
508 .ordinals
509 .iter()
510 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
511 .collect();
512 Some(vec![(self.field_id, scored_positions)])
513 }
514}
515
/// Unit tests for the query builders and for `MultiValueCombiner::combine`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
    }

    #[test]
    fn test_dense_vector_query_defaults() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0]);

        assert_eq!(query.nprobe, 32);
        assert_eq!(query.rerank_factor, 3);
        assert_eq!(query.combiner, MultiValueCombiner::Max);

        let query = query.with_combiner(MultiValueCombiner::Avg);
        assert_eq!(query.combiner, MultiValueCombiner::Avg);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_defaults() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.5)]);

        assert_eq!(query.combiner, MultiValueCombiner::Sum);
        assert_eq!(query.heap_factor, 1.0);

        let query = query.with_combiner(MultiValueCombiner::Max);
        assert_eq!(query.combiner, MultiValueCombiner::Max);
    }

    #[test]
    fn test_sparse_vector_query_heap_factor_is_clamped() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 0.5)]);

        assert_eq!(query.clone().with_heap_factor(0.5).heap_factor, 0.5);
        assert_eq!(query.clone().with_heap_factor(2.0).heap_factor, 1.0);
        assert_eq!(query.with_heap_factor(-0.5).heap_factor, 0.0);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Sum;
        assert!((combiner.combine(&scores) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        let combiner = MultiValueCombiner::Max;
        assert!((combiner.combine(&scores) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Avg;
        assert!((combiner.combine(&scores) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::log_sum_exp();
        let result = combiner.combine(&scores);
        // Never below the maximum ...
        assert!(result >= 3.0);
        // ... and at most `max + ln(n) / t`, reached when all scores are equal.
        assert!(result <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        let combiner = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let result = combiner.combine(&scores);
        assert!((result - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let result = combiner.combine(&scores);
        // (5*1 + 3*0.5 + 1*0.25) / (1 + 0.5 + 0.25) = 6.75 / 1.75
        assert!((result - 3.857).abs() < 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let result = combiner.combine(&scores);
        // (2*1 + 1*0.7) / (1 + 0.7) = 2.7 / 1.7
        assert!((result - 1.588).abs() < 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_zero_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(0, 0.7);
        assert_eq!(combiner.combine(&scores), 0.0);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        assert_eq!(MultiValueCombiner::Sum.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Max.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Avg.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::log_sum_exp().combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::weighted_top_k().combine(&scores), 0.0);
    }

    #[test]
    fn test_combiner_single_score() {
        let scores = vec![(0, 5.0)];
        assert!((MultiValueCombiner::Sum.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Max.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Avg.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::log_sum_exp().combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::weighted_top_k().combine(&scores) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        let combiner = MultiValueCombiner::default();
        match combiner {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!((temperature - 1.5).abs() < 1e-6);
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}