1use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::ScoredPosition;
8use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Strategy for collapsing the per-ordinal scores of a multi-valued
/// vector field into a single document score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum of all ordinal scores.
    Sum,
    /// Maximum ordinal score.
    Max,
    /// Arithmetic mean of the ordinal scores.
    Avg,
    /// Smooth maximum: `max + ln(sum(exp(t * (s - max)))) / t`.
    /// Larger `temperature` approaches `Max`; smaller values blend in the
    /// lower scores more strongly.
    LogSumExp {
        /// Sharpness parameter `t`; non-positive values are treated as
        /// `Max` by [`combine`](Self::combine) to avoid division by zero.
        temperature: f32,
    },
    /// Weighted average of the top-`k` scores with geometrically decaying
    /// weights (`1, decay, decay^2, ...`).
    WeightedTopK {
        /// Number of top scores considered.
        k: usize,
        /// Multiplicative weight decay applied per rank.
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    /// Default combiner: `LogSumExp` with temperature 1.5.
    fn default() -> Self {
        MultiValueCombiner::LogSumExp { temperature: 1.5 }
    }
}

impl MultiValueCombiner {
    /// `LogSumExp` with the default temperature (1.5).
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp { temperature: 1.5 }
    }

    /// `LogSumExp` with an explicit temperature.
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// `WeightedTopK` with default parameters (k = 5, decay = 0.7).
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK { k: 5, decay: 0.7 }
    }

    /// `WeightedTopK` with explicit parameters.
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Collapse `(ordinal, score)` pairs into a single score.
    ///
    /// Returns 0.0 for an empty slice, for every variant.
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match self {
            MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
            MultiValueCombiner::Max => scores
                .iter()
                .map(|(_, s)| *s)
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(0.0),
            MultiValueCombiner::Avg => {
                let sum: f32 = scores.iter().map(|(_, s)| s).sum();
                sum / scores.len() as f32
            }
            MultiValueCombiner::LogSumExp { temperature } => {
                let t = *temperature;
                // Subtract the max before exponentiating for numerical
                // stability (the standard log-sum-exp trick).
                let max_score = scores
                    .iter()
                    .map(|(_, s)| *s)
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .unwrap_or(0.0);

                // A non-positive temperature would divide by zero below,
                // producing inf/NaN; degrade gracefully to the max score.
                if t <= 0.0 {
                    return max_score;
                }

                let sum_exp: f32 = scores
                    .iter()
                    .map(|(_, s)| (t * (s - max_score)).exp())
                    .sum();

                max_score + sum_exp.ln() / t
            }
            MultiValueCombiner::WeightedTopK { k, decay } => {
                // Take the k largest scores (descending), then average them
                // with geometrically decaying weights.
                let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
                sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
                sorted.truncate(*k);

                let mut weight = 1.0f32;
                let mut weighted_sum = 0.0f32;
                let mut weight_total = 0.0f32;

                for score in sorted {
                    weighted_sum += weight * score;
                    weight_total += weight;
                    weight *= decay;
                }

                // k == 0 leaves weight_total at 0; avoid dividing by zero.
                if weight_total > 0.0 {
                    weighted_sum / weight_total
                } else {
                    0.0
                }
            }
        }
    }
}
125
/// Query against a dense vector field. The search parameters (`nprobe`,
/// `rerank_factor`) are passed straight through to
/// `SegmentReader::search_dense_vector`.
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field holding the indexed dense vectors.
    pub field: Field,
    /// Query embedding; presumably must match the indexed dimensionality
    /// -- enforced by the reader, not here.
    pub vector: Vec<f32>,
    /// Search breadth knob forwarded to the reader (default 32 in `new`).
    pub nprobe: usize,
    /// Candidate over-fetch multiplier forwarded to the reader
    /// (default 3.0 in `new`).
    pub rerank_factor: f32,
    /// How scores from multiple values (ordinals) of one document combine
    /// (default `Max` in `new`).
    pub combiner: MultiValueCombiner,
}
140
141impl DenseVectorQuery {
142 pub fn new(field: Field, vector: Vec<f32>) -> Self {
144 Self {
145 field,
146 vector,
147 nprobe: 32,
148 rerank_factor: 3.0,
149 combiner: MultiValueCombiner::Max,
150 }
151 }
152
153 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
155 self.nprobe = nprobe;
156 self
157 }
158
159 pub fn with_rerank_factor(mut self, factor: f32) -> Self {
161 self.rerank_factor = factor;
162 self
163 }
164
165 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
167 self.combiner = combiner;
168 self
169 }
170}
171
172impl Query for DenseVectorQuery {
173 fn scorer<'a>(
174 &self,
175 reader: &'a SegmentReader,
176 limit: usize,
177 _predicate: Option<super::DocPredicate<'a>>,
178 ) -> ScorerFuture<'a> {
179 let field = self.field;
180 let vector = self.vector.clone();
181 let nprobe = self.nprobe;
182 let rerank_factor = self.rerank_factor;
183 let combiner = self.combiner;
184 Box::pin(async move {
185 let results = reader
186 .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
187 .await?;
188
189 Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
190 })
191 }
192
193 #[cfg(feature = "sync")]
194 fn scorer_sync<'a>(
195 &self,
196 reader: &'a SegmentReader,
197 limit: usize,
198 _predicate: Option<super::DocPredicate<'a>>,
199 ) -> crate::Result<Box<dyn Scorer + 'a>> {
200 let results = reader.search_dense_vector_sync(
201 self.field,
202 &self.vector,
203 limit,
204 self.nprobe,
205 self.rerank_factor,
206 self.combiner,
207 )?;
208 Ok(Box::new(DenseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
209 }
210
211 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
212 Box::pin(async move { Ok(u32::MAX) })
213 }
214}
215
/// Exposes a segment's dense-vector hits through the `Scorer` interface.
struct DenseVectorScorer {
    // Hits for this segment; assumed ordered by ascending doc_id so that
    // seek() can advance monotonically -- TODO confirm with the reader.
    results: Vec<VectorSearchResult>,
    // Cursor into `results`; doc() returns TERMINATED once it reaches the end.
    position: usize,
    // Field id reported alongside ordinal scores in matched_positions().
    field_id: u32,
}

impl DenseVectorScorer {
    /// Wrap a result list; the cursor starts at the first hit.
    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
        Self {
            results,
            position: 0,
            field_id,
        }
    }
}
232
233impl Scorer for DenseVectorScorer {
234 fn doc(&self) -> DocId {
235 if self.position < self.results.len() {
236 self.results[self.position].doc_id
237 } else {
238 TERMINATED
239 }
240 }
241
242 fn score(&self) -> Score {
243 if self.position < self.results.len() {
244 self.results[self.position].score
245 } else {
246 0.0
247 }
248 }
249
250 fn advance(&mut self) -> DocId {
251 self.position += 1;
252 self.doc()
253 }
254
255 fn seek(&mut self, target: DocId) -> DocId {
256 while self.doc() < target && self.doc() != TERMINATED {
257 self.advance();
258 }
259 self.doc()
260 }
261
262 fn size_hint(&self) -> u32 {
263 (self.results.len() - self.position) as u32
264 }
265
266 fn matched_positions(&self) -> Option<MatchedPositions> {
267 if self.position >= self.results.len() {
268 return None;
269 }
270 let result = &self.results[self.position];
271 let scored_positions: Vec<ScoredPosition> = result
272 .ordinals
273 .iter()
274 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
275 .collect();
276 Some(vec![(self.field_id, scored_positions)])
277 }
278}
279
/// Query against a sparse vector field, expressed as `(dimension, weight)`
/// pairs. The query vector can be reduced before searching via a weight
/// threshold, proportional pruning, and/or a hard dimension cap.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field holding the indexed sparse vectors.
    pub field: Field,
    /// Query as `(dimension, weight)` pairs.
    pub vector: Vec<(u32, f32)>,
    /// How scores from multiple values (ordinals) of one document combine
    /// (default `LogSumExp { temperature: 0.7 }` in `new`).
    pub combiner: MultiValueCombiner,
    /// Approximation knob forwarded to the sparse search; clamped to
    /// [0.0, 1.0] by `with_heap_factor`. Exact semantics live in the
    /// reader -- presumably a threshold-relaxation factor; confirm there.
    pub heap_factor: f32,
    /// Dimensions with |weight| below this are dropped from the query;
    /// 0.0 disables the filter.
    pub weight_threshold: f32,
    /// Optional hard cap on query dimensions (keeps the top by |weight|).
    pub max_query_dims: Option<usize>,
    /// Optional fraction (0.0..=1.0) of dimensions to keep, by |weight|;
    /// at least one dimension is always retained.
    pub pruning: Option<f32>,
}
303
impl SparseVectorQuery {
    /// Build a query over `field` from `(dimension, weight)` pairs with
    /// default parameters: `LogSumExp { temperature: 0.7 }` combiner,
    /// heap_factor 1.0, no weight threshold, no dimension cap, no pruning.
    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
        Self {
            field,
            vector,
            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
            heap_factor: 1.0,
            weight_threshold: 0.0,
            max_query_dims: None,
            pruning: None,
        }
    }

    /// Builder: override the multi-value combiner.
    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
        self.combiner = combiner;
        self
    }

    /// Builder: set the heap factor, clamped into [0.0, 1.0].
    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
        self.heap_factor = heap_factor.clamp(0.0, 1.0);
        self
    }

    /// Builder: drop query dimensions whose |weight| is below `threshold`
    /// (applied by `pruned_vector` at search time).
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self
    }

    /// Builder: keep at most `max_dims` of the highest-|weight| dimensions.
    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
        self.max_query_dims = Some(max_dims);
        self
    }

    /// Builder: keep only the top `fraction` (clamped to [0.0, 1.0]) of
    /// dimensions by |weight|.
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.pruning = Some(fraction.clamp(0.0, 1.0));
        self
    }

    /// Reduce the query vector by applying, in order: the absolute weight
    /// threshold, proportional pruning, then the hard dimension cap.
    /// The order matters: the flag below lets step 3 reuse step 2's sort.
    fn pruned_vector(&self) -> Vec<(u32, f32)> {
        let original_len = self.vector.len();

        // Step 1: drop dimensions below the absolute |weight| threshold.
        let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
            self.vector
                .iter()
                .copied()
                .filter(|(_, w)| w.abs() >= self.weight_threshold)
                .collect()
        } else {
            self.vector.clone()
        };
        let after_threshold = v.len();

        // Step 2: keep the top `fraction` of dimensions by |weight|.
        // Track whether we already sorted so step 3 can skip re-sorting.
        let mut sorted_by_weight = false;
        if let Some(fraction) = self.pruning
            && fraction < 1.0
            && v.len() > 1
        {
            v.sort_unstable_by(|a, b| {
                b.1.abs()
                    .partial_cmp(&a.1.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            sorted_by_weight = true;
            // ceil + max(1): never prune the vector down to nothing.
            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
            v.truncate(keep);
        }
        let after_pruning = v.len();

        // Step 3: hard cap on the number of dimensions.
        if let Some(max_dims) = self.max_query_dims
            && v.len() > max_dims
        {
            if !sorted_by_weight {
                v.sort_unstable_by(|a, b| {
                    b.1.abs()
                        .partial_cmp(&a.1.abs())
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
            }
            v.truncate(max_dims);
        }

        // Log the per-stage reduction, but only when something was removed.
        if v.len() < original_len {
            log::debug!(
                "[sparse query] field={}: pruned {}->{} dims \
                (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{})",
                self.field.0,
                original_len,
                v.len(),
                original_len,
                after_threshold,
                after_threshold,
                after_pruning,
                after_pruning,
                v.len(),
            );
            if log::log_enabled!(log::Level::Trace) {
                for (dim, w) in &v {
                    log::trace!("  dim={}, weight={:.4}", dim, w);
                }
            }
        }

        v
    }

    /// Zip parallel `indices`/`weights` lists into a query. Note: `zip`
    /// silently truncates to the shorter of the two lists.
    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
        Self::new(field, vector)
    }

    /// Tokenize `text` with the named (cached) tokenizer and weight each
    /// unique token id per `weighting`. Every IDF path falls back to
    /// uniform 1.0 weights when its stats source is unavailable.
    #[cfg(feature = "native")]
    pub fn from_text(
        field: Field,
        text: &str,
        tokenizer_name: &str,
        weighting: crate::structures::QueryWeighting,
        sparse_index: Option<&crate::segment::SparseIndex>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::tokenizer_cache;

        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(index) = sparse_index {
                    index.idf_weights(&token_ids)
                } else {
                    // No segment-level stats available: uniform weights.
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                use crate::tokenizer::idf_weights_cache;
                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
                    token_ids.iter().map(|&id| idf.get(id)).collect()
                } else {
                    // IDF file missing for this tokenizer: uniform weights.
                    vec![1.0f32; token_ids.len()]
                }
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Like [`from_text`](Self::from_text), but with an already-loaded
    /// tokenizer and cross-segment `GlobalStats` for IDF weighting.
    #[cfg(feature = "native")]
    pub fn from_text_with_stats(
        field: Field,
        text: &str,
        tokenizer: &crate::tokenizer::HfTokenizer,
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;

        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    stats
                        // Clamp: negative IDF values are floored at 0.0.
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // File-based IDF is not supported on this path; uniform.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Like [`from_text_with_stats`](Self::from_text_with_stats), but
    /// constructs the tokenizer from serialized bytes first.
    #[cfg(feature = "native")]
    pub fn from_text_with_tokenizer_bytes(
        field: Field,
        text: &str,
        tokenizer_bytes: &[u8],
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::HfTokenizer;

        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    stats
                        // Clamp: negative IDF values are floored at 0.0.
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // File-based IDF is not supported on this path; uniform.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }
}
583
584impl Query for SparseVectorQuery {
585 fn scorer<'a>(
586 &self,
587 reader: &'a SegmentReader,
588 limit: usize,
589 _predicate: Option<super::DocPredicate<'a>>,
590 ) -> ScorerFuture<'a> {
591 let field = self.field;
592 let vector = self.pruned_vector();
593 let combiner = self.combiner;
594 let heap_factor = self.heap_factor;
595 Box::pin(async move {
596 let results = reader
597 .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
598 .await?;
599
600 Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
601 })
602 }
603
604 #[cfg(feature = "sync")]
605 fn scorer_sync<'a>(
606 &self,
607 reader: &'a SegmentReader,
608 limit: usize,
609 _predicate: Option<super::DocPredicate<'a>>,
610 ) -> crate::Result<Box<dyn Scorer + 'a>> {
611 let vector = self.pruned_vector();
612 let results = reader.search_sparse_vector_sync(
613 self.field,
614 &vector,
615 limit,
616 self.combiner,
617 self.heap_factor,
618 )?;
619 Ok(Box::new(SparseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
620 }
621
622 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
623 Box::pin(async move { Ok(u32::MAX) })
624 }
625}
626
/// Exposes a segment's sparse-vector hits through the `Scorer` interface.
struct SparseVectorScorer {
    // Hits for this segment; assumed ordered by ascending doc_id so that
    // seek() can advance monotonically -- TODO confirm with the reader.
    results: Vec<VectorSearchResult>,
    // Cursor into `results`; doc() returns TERMINATED once it reaches the end.
    position: usize,
    // Field id reported alongside ordinal scores in matched_positions().
    field_id: u32,
}

impl SparseVectorScorer {
    /// Wrap a result list; the cursor starts at the first hit.
    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
        Self {
            results,
            position: 0,
            field_id,
        }
    }
}
643
644impl Scorer for SparseVectorScorer {
645 fn doc(&self) -> DocId {
646 if self.position < self.results.len() {
647 self.results[self.position].doc_id
648 } else {
649 TERMINATED
650 }
651 }
652
653 fn score(&self) -> Score {
654 if self.position < self.results.len() {
655 self.results[self.position].score
656 } else {
657 0.0
658 }
659 }
660
661 fn advance(&mut self) -> DocId {
662 self.position += 1;
663 self.doc()
664 }
665
666 fn seek(&mut self, target: DocId) -> DocId {
667 while self.doc() < target && self.doc() != TERMINATED {
668 self.advance();
669 }
670 self.doc()
671 }
672
673 fn size_hint(&self) -> u32 {
674 (self.results.len() - self.position) as u32
675 }
676
677 fn matched_positions(&self) -> Option<MatchedPositions> {
678 if self.position >= self.results.len() {
679 return None;
680 }
681 let result = &self.results[self.position];
682 let scored_positions: Vec<ScoredPosition> = result
683 .ordinals
684 .iter()
685 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
686 .collect();
687 Some(vec![(self.field_id, scored_positions)])
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// Absolute-difference comparison shared by the float assertions below.
    fn close(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    #[test]
    fn test_dense_vector_query_builder() {
        let q = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);

        assert_eq!(q.field, Field(0));
        assert_eq!(q.vector.len(), 3);
        assert_eq!(q.nprobe, 64);
        assert_eq!(q.rerank_factor, 5.0);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let dims = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let q = SparseVectorQuery::new(Field(0), dims.clone());

        assert_eq!(q.field, Field(0));
        assert_eq!(q.vector, dims);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let q =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(q.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let combined = MultiValueCombiner::Sum.combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        assert!(close(combined, 6.0, 1e-6));
    }

    #[test]
    fn test_combiner_max() {
        let combined = MultiValueCombiner::Max.combine(&[(0, 1.0), (1, 3.0), (2, 2.0)]);
        assert!(close(combined, 3.0, 1e-6));
    }

    #[test]
    fn test_combiner_avg() {
        let combined = MultiValueCombiner::Avg.combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        assert!(close(combined, 2.0, 1e-6));
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let combined = MultiValueCombiner::log_sum_exp().combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        // Smooth max lies between the true max and max + ln(n)/t.
        assert!(combined >= 3.0);
        assert!(combined <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let combined = MultiValueCombiner::log_sum_exp_with_temperature(10.0)
            .combine(&[(0, 1.0), (1, 5.0), (2, 2.0)]);
        assert!(close(combined, 5.0, 0.5));
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let combined = MultiValueCombiner::weighted_top_k_with_params(3, 0.5)
            .combine(&[(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]);
        // (5 + 0.5*3 + 0.25*1) / 1.75 = 3.857...
        assert!(close(combined, 3.857, 0.01));
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let combined =
            MultiValueCombiner::weighted_top_k_with_params(5, 0.7).combine(&[(0, 2.0), (1, 1.0)]);
        // (2 + 0.7*1) / 1.7 = 1.588...
        assert!(close(combined, 1.588, 0.01));
    }

    #[test]
    fn test_combiner_empty_scores() {
        let none: Vec<(u32, f32)> = Vec::new();
        for combiner in [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ] {
            assert_eq!(combiner.combine(&none), 0.0);
        }
    }

    #[test]
    fn test_combiner_single_score() {
        let single = [(0, 5.0)];
        for combiner in [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ] {
            assert!(close(combiner.combine(&single), 5.0, 1e-6));
        }
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        match MultiValueCombiner::default() {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!(close(temperature, 1.5, 1e-6));
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}