1use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::ScoredPosition;
8use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Strategy for collapsing multiple per-ordinal similarity scores of a single
/// document (one score per value of a multi-valued field) into one
/// document-level score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum of all ordinal scores.
    Sum,
    /// Maximum ordinal score.
    Max,
    /// Arithmetic mean of the ordinal scores.
    Avg,
    /// Smooth maximum: `max + ln(sum(exp(t * (s - max)))) / t`.
    /// Approaches `Max` as the temperature grows, `Sum`-like behavior as it
    /// shrinks (see `combine`).
    LogSumExp {
        /// Sharpness parameter `t`; must be positive (not validated here).
        temperature: f32,
    },
    /// Weighted average of the top-k scores with geometrically decaying
    /// weights (1, decay, decay², ...).
    WeightedTopK {
        /// Number of highest scores to keep.
        k: usize,
        /// Multiplicative weight decay applied per rank.
        decay: f32,
    },
}
35
36impl Default for MultiValueCombiner {
37 fn default() -> Self {
38 MultiValueCombiner::LogSumExp { temperature: 1.5 }
41 }
42}
43
44impl MultiValueCombiner {
45 pub fn log_sum_exp() -> Self {
47 Self::LogSumExp { temperature: 1.5 }
48 }
49
50 pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
52 Self::LogSumExp { temperature }
53 }
54
55 pub fn weighted_top_k() -> Self {
57 Self::WeightedTopK { k: 5, decay: 0.7 }
58 }
59
60 pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
62 Self::WeightedTopK { k, decay }
63 }
64
65 pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
67 if scores.is_empty() {
68 return 0.0;
69 }
70
71 match self {
72 MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
73 MultiValueCombiner::Max => scores
74 .iter()
75 .map(|(_, s)| *s)
76 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
77 .unwrap_or(0.0),
78 MultiValueCombiner::Avg => {
79 let sum: f32 = scores.iter().map(|(_, s)| s).sum();
80 sum / scores.len() as f32
81 }
82 MultiValueCombiner::LogSumExp { temperature } => {
83 let t = *temperature;
86 let max_score = scores
87 .iter()
88 .map(|(_, s)| *s)
89 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
90 .unwrap_or(0.0);
91
92 let sum_exp: f32 = scores
93 .iter()
94 .map(|(_, s)| (t * (s - max_score)).exp())
95 .sum();
96
97 max_score + sum_exp.ln() / t
98 }
99 MultiValueCombiner::WeightedTopK { k, decay } => {
100 let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
102 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
103 sorted.truncate(*k);
104
105 let mut weight = 1.0f32;
107 let mut weighted_sum = 0.0f32;
108 let mut weight_total = 0.0f32;
109
110 for score in sorted {
111 weighted_sum += weight * score;
112 weight_total += weight;
113 weight *= decay;
114 }
115
116 if weight_total > 0.0 {
117 weighted_sum / weight_total
118 } else {
119 0.0
120 }
121 }
122 }
123 }
124}
125
/// Approximate nearest-neighbor query against a dense vector field.
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Target dense vector field.
    pub field: Field,
    /// Query embedding. Presumably must match the indexed dimensionality —
    /// TODO confirm against the segment reader's validation.
    pub vector: Vec<f32>,
    /// Number of probes for the ANN search (default 32); presumably the
    /// number of IVF cells visited — confirm against the index implementation.
    pub nprobe: usize,
    /// Over-fetch multiplier applied before reranking (default 3.0).
    pub rerank_factor: f32,
    /// How scores of multiple ordinals in one document are combined
    /// (default `Max`, see `DenseVectorQuery::new`).
    pub combiner: MultiValueCombiner,
}
140
141impl std::fmt::Display for DenseVectorQuery {
142 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
143 write!(
144 f,
145 "Dense({}, dim={}, nprobe={}, rerank={})",
146 self.field.0,
147 self.vector.len(),
148 self.nprobe,
149 self.rerank_factor
150 )
151 }
152}
153
154impl DenseVectorQuery {
155 pub fn new(field: Field, vector: Vec<f32>) -> Self {
157 Self {
158 field,
159 vector,
160 nprobe: 32,
161 rerank_factor: 3.0,
162 combiner: MultiValueCombiner::Max,
163 }
164 }
165
166 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
168 self.nprobe = nprobe;
169 self
170 }
171
172 pub fn with_rerank_factor(mut self, factor: f32) -> Self {
174 self.rerank_factor = factor;
175 self
176 }
177
178 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
180 self.combiner = combiner;
181 self
182 }
183}
184
impl Query for DenseVectorQuery {
    /// Run the ANN search asynchronously and wrap the hits in a
    /// doc-id-ordered scorer.
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        // Copy the parameters out of `self` so the returned future does not
        // borrow the query; it is tied only to the reader's lifetime 'a.
        let field = self.field;
        let vector = self.vector.clone();
        let nprobe = self.nprobe;
        let rerank_factor = self.rerank_factor;
        let combiner = self.combiner;
        Box::pin(async move {
            let results = reader
                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
                .await?;

            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
        })
    }

    /// Blocking variant of [`Query::scorer`], available with the `sync` feature.
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        let results = reader.search_dense_vector_sync(
            self.field,
            &self.vector,
            limit,
            self.nprobe,
            self.rerank_factor,
            self.combiner,
        )?;
        Ok(Box::new(DenseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
    }

    /// No cheap cardinality bound exists for an ANN query, so report the
    /// maximum ("potentially every document").
    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
        Box::pin(async move { Ok(u32::MAX) })
    }
}
222
/// Scorer over pre-computed ANN hits, exposed as a doc-id-sorted DocSet.
struct DenseVectorScorer {
    // Hits sorted by ascending doc_id (sorted in `new`).
    results: Vec<VectorSearchResult>,
    // Index of the current hit; >= results.len() means exhausted.
    position: usize,
    // Field id reported in `matched_positions`.
    field_id: u32,
}
229
230impl DenseVectorScorer {
231 fn new(mut results: Vec<VectorSearchResult>, field_id: u32) -> Self {
232 results.sort_unstable_by_key(|r| r.doc_id);
234 Self {
235 results,
236 position: 0,
237 field_id,
238 }
239 }
240}
241
242impl super::docset::DocSet for DenseVectorScorer {
243 fn doc(&self) -> DocId {
244 if self.position < self.results.len() {
245 self.results[self.position].doc_id
246 } else {
247 TERMINATED
248 }
249 }
250
251 fn advance(&mut self) -> DocId {
252 self.position += 1;
253 self.doc()
254 }
255
256 fn seek(&mut self, target: DocId) -> DocId {
257 let remaining = &self.results[self.position..];
259 let offset = remaining.partition_point(|r| r.doc_id < target);
260 self.position += offset;
261 self.doc()
262 }
263
264 fn size_hint(&self) -> u32 {
265 (self.results.len() - self.position) as u32
266 }
267}
268
269impl Scorer for DenseVectorScorer {
270 fn score(&self) -> Score {
271 if self.position < self.results.len() {
272 self.results[self.position].score
273 } else {
274 0.0
275 }
276 }
277
278 fn matched_positions(&self) -> Option<MatchedPositions> {
279 if self.position >= self.results.len() {
280 return None;
281 }
282 let result = &self.results[self.position];
283 let scored_positions: Vec<ScoredPosition> = result
284 .ordinals
285 .iter()
286 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
287 .collect();
288 Some(vec![(self.field_id, scored_positions)])
289 }
290}
291
/// Query over a sparse vector field, expressed as `(dimension, weight)` pairs.
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Target sparse vector field.
    pub field: Field,
    /// Original, unpruned `(dimension, weight)` pairs.
    pub vector: Vec<(u32, f32)>,
    /// Combiner for multiple ordinal scores per document
    /// (default `LogSumExp { temperature: 0.7 }`, see `new`).
    pub combiner: MultiValueCombiner,
    /// Factor in [0, 1] forwarded to each per-dimension term query; 1.0 is
    /// the default. NOTE(review): presumably a score-pruning/approximation
    /// knob where lower is more aggressive — confirm in the scorer impl.
    pub heap_factor: f32,
    /// Dimensions with |weight| below this are dropped (0.0 = disabled, default).
    pub weight_threshold: f32,
    /// Hard cap on the number of query dimensions, keeping the largest |weight|.
    pub max_query_dims: Option<usize>,
    /// Keep only this fraction (clamped to [0, 1]) of dimensions by |weight|.
    pub pruning: Option<f32>,
    /// Over-fetch multiplier forwarded to term queries (min 1.0, default 2.0).
    pub over_fetch_factor: f32,
    // Cached output of compute_pruned_vector(); None until a pruning builder
    // (with_weight_threshold / with_max_query_dims / with_pruning) runs.
    pruned: Option<Vec<(u32, f32)>>,
}
319
320impl std::fmt::Display for SparseVectorQuery {
321 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
322 let dims = self.pruned_dims();
323 write!(f, "Sparse({}, dims={}", self.field.0, dims.len())?;
324 if self.heap_factor < 1.0 {
325 write!(f, ", heap={}", self.heap_factor)?;
326 }
327 if self.vector.len() != dims.len() {
328 write!(f, ", orig={}", self.vector.len())?;
329 }
330 write!(f, ")")
331 }
332}
333
impl SparseVectorQuery {
    /// Create a query from raw `(dimension, weight)` pairs with defaults:
    /// `LogSumExp { temperature: 0.7 }` combiner, exact scoring
    /// (`heap_factor = 1.0`), no pruning, `over_fetch_factor = 2.0`.
    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
        Self {
            field,
            vector,
            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
            heap_factor: 1.0,
            weight_threshold: 0.0,
            max_query_dims: None,
            pruning: None,
            over_fetch_factor: 2.0,
            pruned: None,
        }
    }

    /// Dimensions actually used for matching: the cached pruned vector if any
    /// pruning builder has been applied, otherwise the full original vector.
    fn pruned_dims(&self) -> &[(u32, f32)] {
        self.pruned.as_deref().unwrap_or(&self.vector)
    }

    /// Override how multiple per-document ordinal scores are combined.
    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
        self.combiner = combiner;
        self
    }

    /// Set the over-fetch multiplier (clamped below at 1.0, since fetching
    /// fewer than `limit` candidates could lose results).
    pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
        self.over_fetch_factor = factor.max(1.0);
        self
    }

    /// Set the heap factor, clamped into [0, 1]; 1.0 means exact scoring.
    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
        self.heap_factor = heap_factor.clamp(0.0, 1.0);
        self
    }

    /// Drop dimensions with |weight| below `threshold`. Eagerly recomputes
    /// the cached pruned vector so later reads see the effect immediately.
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Cap the number of query dimensions, keeping the largest |weight|.
    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
        self.max_query_dims = Some(max_dims);
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Keep only `fraction` (clamped to [0, 1]) of dimensions by |weight|.
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.pruning = Some(fraction.clamp(0.0, 1.0));
        self.pruned = Some(self.compute_pruned_vector());
        self
    }

    /// Apply the three pruning stages in order:
    /// 1. drop dims with |weight| < `weight_threshold`,
    /// 2. keep the top `pruning` fraction of dims by |weight| (at least one),
    /// 3. cap at `max_query_dims` (again by |weight|).
    fn compute_pruned_vector(&self) -> Vec<(u32, f32)> {
        let original_len = self.vector.len();

        // Stage 1: absolute-weight threshold filter.
        let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
            self.vector
                .iter()
                .copied()
                .filter(|(_, w)| w.abs() >= self.weight_threshold)
                .collect()
        } else {
            self.vector.clone()
        };
        let after_threshold = v.len();

        // Stage 2: keep the top fraction by |weight|. Tracks whether we have
        // already sorted so stage 3 can skip a redundant sort.
        let mut sorted_by_weight = false;
        if let Some(fraction) = self.pruning
            && fraction < 1.0
            && v.len() > 1
        {
            v.sort_unstable_by(|a, b| {
                b.1.abs()
                    .partial_cmp(&a.1.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            sorted_by_weight = true;
            // ceil + max(1): never prune the query down to zero dimensions.
            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
            v.truncate(keep);
        }
        let after_pruning = v.len();

        // Stage 3: hard cap on dimension count.
        if let Some(max_dims) = self.max_query_dims
            && v.len() > max_dims
        {
            if !sorted_by_weight {
                v.sort_unstable_by(|a, b| {
                    b.1.abs()
                        .partial_cmp(&a.1.abs())
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
            }
            v.truncate(max_dims);
        }

        // Debug-log the per-stage shrinkage only when something was pruned.
        if v.len() < original_len {
            let src: Vec<_> = self
                .vector
                .iter()
                .map(|(d, w)| format!("({},{:.4})", d, w))
                .collect();
            let pruned_fmt: Vec<_> = v.iter().map(|(d, w)| format!("({},{:.4})", d, w)).collect();
            log::debug!(
                "[sparse query] field={}: pruned {}->{} dims \
                 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{}), \
                 source=[{}], pruned=[{}]",
                self.field.0,
                original_len,
                v.len(),
                original_len,
                after_threshold,
                after_threshold,
                after_pruning,
                after_pruning,
                v.len(),
                src.join(", "),
                pruned_fmt.join(", "),
            );
        }

        v
    }

    /// Build from parallel index/weight vectors; extra elements of the longer
    /// vector are silently dropped by `zip`.
    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
        Self::new(field, vector)
    }

    /// Tokenize `text` with a cached named tokenizer and weight the unique
    /// token ids according to `weighting`.
    ///
    /// For `Idf`, segment-level IDF comes from `sparse_index` when provided;
    /// for `IdfFile`, weights come from the tokenizer's IDF file cache. Both
    /// fall back to uniform weights of 1.0 when unavailable.
    #[cfg(feature = "native")]
    pub fn from_text(
        field: Field,
        text: &str,
        tokenizer_name: &str,
        weighting: crate::structures::QueryWeighting,
        sparse_index: Option<&crate::segment::SparseIndex>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::tokenizer_cache;

        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(index) = sparse_index {
                    index.idf_weights(&token_ids)
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                use crate::tokenizer::idf_weights_cache;
                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
                    token_ids.iter().map(|&id| idf.get(id)).collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Like [`Self::from_text`], but with a caller-supplied tokenizer and
    /// cross-segment IDF statistics.
    ///
    /// `IdfFile` is not supported on this path and falls back to uniform
    /// weights. IDF weights are clamped to be non-negative.
    #[cfg(feature = "native")]
    pub fn from_text_with_stats(
        field: Field,
        text: &str,
        tokenizer: &crate::tokenizer::HfTokenizer,
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;

        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        // Negative IDF (very common terms) is clamped to 0.
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // No IDF file cache on this path; uniform weights instead.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Like [`Self::from_text_with_stats`], but constructs the tokenizer from
    /// serialized bytes (e.g. a HuggingFace tokenizer.json payload).
    #[cfg(feature = "native")]
    pub fn from_text_with_tokenizer_bytes(
        field: Field,
        text: &str,
        tokenizer_bytes: &[u8],
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::HfTokenizer;

        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        // Negative IDF (very common terms) is clamped to 0.
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // No IDF file cache on this path; uniform weights instead.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }
}
636
637impl SparseVectorQuery {
638 fn build_inner_query(&self, reader: &SegmentReader) -> Option<Box<dyn Query>> {
644 let si = reader.sparse_index(self.field)?;
645 let matched: Vec<(u32, f32)> = self
646 .pruned_dims()
647 .iter()
648 .filter(|(d, _)| si.has_dimension(*d))
649 .copied()
650 .collect();
651 if matched.is_empty() {
652 return None;
653 }
654
655 let make_term = |(dim_id, weight)| {
656 SparseTermQuery::new(self.field, dim_id, weight)
657 .with_heap_factor(self.heap_factor)
658 .with_combiner(self.combiner)
659 .with_over_fetch_factor(self.over_fetch_factor)
660 };
661
662 if matched.len() == 1 {
663 return Some(Box::new(make_term(matched[0])));
664 }
665
666 let mut bool_q = super::BooleanQuery::new();
667 for dims in matched {
668 bool_q = bool_q.should(make_term(dims));
669 }
670 Some(Box::new(bool_q))
671 }
672}
673
674impl Query for SparseVectorQuery {
675 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
676 match self.build_inner_query(reader) {
677 None => Box::pin(async { Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer>) }),
678 Some(q) => q.scorer(reader, limit),
679 }
680 }
681
682 #[cfg(feature = "sync")]
683 fn scorer_sync<'a>(
684 &self,
685 reader: &'a SegmentReader,
686 limit: usize,
687 ) -> crate::Result<Box<dyn Scorer + 'a>> {
688 match self.build_inner_query(reader) {
689 None => Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
690 Some(q) => q.scorer_sync(reader, limit),
691 }
692 }
693
694 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
695 Box::pin(async move { Ok(u32::MAX) })
696 }
697}
698
/// Query matching a single dimension of a sparse vector field.
///
/// Normally constructed by `SparseVectorQuery::build_inner_query`, which
/// forwards its tuning parameters to each term.
#[derive(Debug, Clone)]
pub struct SparseTermQuery {
    /// Target sparse vector field.
    pub field: Field,
    /// Sparse dimension (token) id to match.
    pub dim_id: u32,
    /// Query-side weight passed to the cursor; presumably multiplied into
    /// the stored per-doc weight — confirm in `TermCursor::sparse`.
    pub weight: f32,
    /// Approximation factor in [0, 1]; 1.0 = exact (see `SparseVectorQuery`).
    pub heap_factor: f32,
    /// Combiner for multi-valued documents.
    pub combiner: MultiValueCombiner,
    /// Over-fetch multiplier (min 1.0).
    pub over_fetch_factor: f32,
}
718
719impl std::fmt::Display for SparseTermQuery {
720 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
721 write!(
722 f,
723 "SparseTerm({}, dim={}, w={:.3})",
724 self.field.0, self.dim_id, self.weight
725 )
726 }
727}
728
729impl SparseTermQuery {
730 pub fn new(field: Field, dim_id: u32, weight: f32) -> Self {
731 Self {
732 field,
733 dim_id,
734 weight,
735 heap_factor: 1.0,
736 combiner: MultiValueCombiner::default(),
737 over_fetch_factor: 2.0,
738 }
739 }
740
741 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
742 self.heap_factor = heap_factor;
743 self
744 }
745
746 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
747 self.combiner = combiner;
748 self
749 }
750
751 pub fn with_over_fetch_factor(mut self, factor: f32) -> Self {
752 self.over_fetch_factor = factor.max(1.0);
753 self
754 }
755
756 fn make_scorer<'a>(
759 &self,
760 reader: &'a SegmentReader,
761 ) -> crate::Result<Option<SparseTermScorer<'a>>> {
762 let si = match reader.sparse_index(self.field) {
763 Some(si) => si,
764 None => return Ok(None),
765 };
766 let (skip_start, skip_count, global_max, block_data_offset) =
767 match si.get_skip_range_full(self.dim_id) {
768 Some(v) => v,
769 None => return Ok(None),
770 };
771 let cursor = super::TermCursor::sparse(
772 si,
773 self.weight,
774 skip_start,
775 skip_count,
776 global_max,
777 block_data_offset,
778 );
779 Ok(Some(SparseTermScorer {
780 cursor,
781 field_id: self.field.0,
782 }))
783 }
784}
785
impl Query for SparseTermQuery {
    /// Build the cursor-backed scorer, or an empty scorer when the dimension
    /// is absent from this segment.
    fn scorer<'a>(&self, reader: &'a SegmentReader, _limit: usize) -> ScorerFuture<'a> {
        // Clone so the async block owns the query parameters.
        let query = self.clone();
        Box::pin(async move {
            let mut scorer = match query.make_scorer(reader)? {
                Some(s) => s,
                None => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
            };
            // Best-effort preload of the first posting block; `.ok()` drops
            // the error here, deferring failures to iteration time.
            scorer.cursor.ensure_block_loaded().await.ok();
            Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
        })
    }

    /// Blocking variant of [`Query::scorer`].
    #[cfg(feature = "sync")]
    fn scorer_sync<'a>(
        &self,
        reader: &'a SegmentReader,
        _limit: usize,
    ) -> crate::Result<Box<dyn Scorer + 'a>> {
        let mut scorer = match self.make_scorer(reader)? {
            Some(s) => s,
            None => return Ok(Box::new(super::EmptyScorer) as Box<dyn Scorer + 'a>),
        };
        scorer.cursor.ensure_block_loaded_sync().ok();
        Ok(Box::new(scorer) as Box<dyn Scorer + 'a>)
    }

    /// Upper-bound doc count from the skip list: `skip_count` blocks times
    /// 256. NOTE(review): assumes 256 docs per posting block — confirm
    /// against the sparse index block format.
    fn count_estimate<'a>(&self, reader: &'a SegmentReader) -> CountFuture<'a> {
        let field = self.field;
        let dim_id = self.dim_id;
        Box::pin(async move {
            let si = match reader.sparse_index(field) {
                Some(si) => si,
                None => return Ok(0),
            };
            match si.get_skip_range_full(dim_id) {
                Some((_, skip_count, _, _)) => Ok((skip_count * 256) as u32),
                None => Ok(0),
            }
        })
    }

    /// Expose term-level parameters so outer queries can introspect/optimize
    /// sparse terms without downcasting.
    fn as_sparse_term_query_info(&self) -> Option<super::SparseTermQueryInfo> {
        Some(super::SparseTermQueryInfo {
            field: self.field,
            dim_id: self.dim_id,
            weight: self.weight,
            heap_factor: self.heap_factor,
            combiner: self.combiner,
            over_fetch_factor: self.over_fetch_factor,
        })
    }
}
839
/// Streaming scorer over a single sparse dimension's posting list.
struct SparseTermScorer<'a> {
    // Cursor over the postings; reports u32::MAX when exhausted.
    cursor: super::TermCursor<'a>,
    // Field id reported in `matched_positions`.
    field_id: u32,
}
848
849impl super::docset::DocSet for SparseTermScorer<'_> {
850 fn doc(&self) -> DocId {
851 let d = self.cursor.doc();
852 if d == u32::MAX { TERMINATED } else { d }
853 }
854
855 fn advance(&mut self) -> DocId {
856 match self.cursor.advance_sync() {
857 Ok(d) if d == u32::MAX => TERMINATED,
858 Ok(d) => d,
859 Err(_) => TERMINATED,
860 }
861 }
862
863 fn seek(&mut self, target: DocId) -> DocId {
864 match self.cursor.seek_sync(target) {
865 Ok(d) if d == u32::MAX => TERMINATED,
866 Ok(d) => d,
867 Err(_) => TERMINATED,
868 }
869 }
870
871 fn size_hint(&self) -> u32 {
872 0
873 }
874}
875
876impl Scorer for SparseTermScorer<'_> {
877 fn score(&self) -> Score {
878 self.cursor.score()
879 }
880
881 fn matched_positions(&self) -> Option<MatchedPositions> {
882 let ordinal = self.cursor.ordinal();
883 let score = self.cursor.score();
884 if score == 0.0 {
885 return None;
886 }
887 Some(vec![(
888 self.field_id,
889 vec![ScoredPosition::new(ordinal as u32, score)],
890 )])
891 }
892}
893
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    // Builder methods should set fields without disturbing the others.
    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5.0);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    // Parallel index/weight vectors should zip into pairs.
    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Sum;
        assert!((combiner.combine(&scores) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        let combiner = MultiValueCombiner::Max;
        assert!((combiner.combine(&scores) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Avg;
        assert!((combiner.combine(&scores) - 2.0).abs() < 1e-6);
    }

    // LogSumExp is a smooth max: result lies between the max score and
    // max + ln(n)/t (the bound reached when all scores are equal).
    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::log_sum_exp();
        let result = combiner.combine(&scores);
        assert!(result >= 3.0);
        assert!(result <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    // High temperature sharpens LogSumExp toward a hard max.
    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        let combiner = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let result = combiner.combine(&scores);
        assert!((result - 5.0).abs() < 0.5);
    }

    // Top-3 of {5,3,1,0.5} with decay 0.5: weights 1, 0.5, 0.25 give
    // (5 + 1.5 + 0.25) / 1.75 = 3.857...
    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let result = combiner.combine(&scores);
        assert!((result - 3.857).abs() < 0.01);
    }

    // Fewer scores than k: only the available weights are used.
    // (2 + 0.7 * 1) / (1 + 0.7) = 1.588...
    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let result = combiner.combine(&scores);
        assert!((result - 1.588).abs() < 0.01);
    }

    // Every combiner maps an empty slice to exactly 0.0.
    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        assert_eq!(MultiValueCombiner::Sum.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Max.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Avg.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::log_sum_exp().combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::weighted_top_k().combine(&scores), 0.0);
    }

    // Every combiner is the identity on a single score.
    #[test]
    fn test_combiner_single_score() {
        let scores = vec![(0, 5.0)];
        assert!((MultiValueCombiner::Sum.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Max.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Avg.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::log_sum_exp().combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::weighted_top_k().combine(&scores) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        let combiner = MultiValueCombiner::default();
        match combiner {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!((temperature - 1.5).abs() < 1e-6);
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}