use crate::dsl::Field;
use crate::segment::{SegmentReader, VectorSearchResult};
use crate::{DocId, Score, TERMINATED};

use super::ScoredPosition;
use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};

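/// Strategy for collapsing the per-value scores of a multi-valued vector field
/// into a single document score.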
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum of all value scores.
    Sum,
    /// Maximum value score.
    Max,
    /// Arithmetic mean of the value scores.
    Avg,
    /// Smooth maximum: a temperature-scaled log-sum-exp over the value scores.
    LogSumExp {
        /// Scaling temperature; larger values push the result toward the maximum score.
        temperature: f32,
    },
    /// Weighted average of the top `k` scores, with geometrically decaying weights.
    WeightedTopK {
        /// Number of top scores to keep.
        k: usize,
        /// Multiplicative decay applied to each successive weight.
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    fn default() -> Self {
        MultiValueCombiner::LogSumExp { temperature: 1.5 }
    }
}

impl MultiValueCombiner {
    /// `LogSumExp` with the default temperature of 1.5.
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp { temperature: 1.5 }
    }

    /// `LogSumExp` with an explicit temperature.
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// `WeightedTopK` with default parameters (`k = 5`, `decay = 0.7`).
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK { k: 5, decay: 0.7 }
    }

    /// `WeightedTopK` with explicit parameters.
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Combine multiple `(ordinal, score)` pairs into a single score.
    ///
    /// Returns `0.0` when `scores` is empty.
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match self {
            MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
            MultiValueCombiner::Max => scores
                .iter()
                .map(|(_, s)| *s)
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(0.0),
            MultiValueCombiner::Avg => {
                let sum: f32 = scores.iter().map(|(_, s)| s).sum();
                sum / scores.len() as f32
            }
            MultiValueCombiner::LogSumExp { temperature } => {
                let t = *temperature;
                // Subtract the max before exponentiating for numerical stability:
                // result = max + ln(sum(exp(t * (s - max)))) / t
                let max_score = scores
                    .iter()
                    .map(|(_, s)| *s)
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .unwrap_or(0.0);

                let sum_exp: f32 = scores
                    .iter()
                    .map(|(_, s)| (t * (s - max_score)).exp())
                    .sum();

                max_score + sum_exp.ln() / t
            }
            MultiValueCombiner::WeightedTopK { k, decay } => {
                // Average the k highest scores, weighting them 1, decay, decay^2, ...
                let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
                sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
                sorted.truncate(*k);

                let mut weight = 1.0f32;
                let mut weighted_sum = 0.0f32;
                let mut weight_total = 0.0f32;

                for score in sorted {
                    weighted_sum += weight * score;
                    weight_total += weight;
                    weight *= decay;
                }

                if weight_total > 0.0 {
                    weighted_sum / weight_total
                } else {
                    0.0
                }
            }
        }
    }
}

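/// Query that scores documents by dense vector similarity against `field`.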
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field containing the dense vectors.
    pub field: Field,
    /// The query vector.
    pub vector: Vec<f32>,
    /// Number of probes forwarded to the segment's dense vector search.
    pub nprobe: usize,
    /// Rerank factor forwarded to the segment's dense vector search.
    pub rerank_factor: f32,
    /// How scores from multiple vector values in one document are combined.
    pub combiner: MultiValueCombiner,
}
140
141impl DenseVectorQuery {
142 pub fn new(field: Field, vector: Vec<f32>) -> Self {
144 Self {
145 field,
146 vector,
147 nprobe: 32,
148 rerank_factor: 3.0,
149 combiner: MultiValueCombiner::Max,
150 }
151 }
152
153 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
155 self.nprobe = nprobe;
156 self
157 }
158
159 pub fn with_rerank_factor(mut self, factor: f32) -> Self {
161 self.rerank_factor = factor;
162 self
163 }
164
165 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
167 self.combiner = combiner;
168 self
169 }
170}
171
impl Query for DenseVectorQuery {
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        let field = self.field;
        let vector = self.vector.clone();
        let nprobe = self.nprobe;
        let rerank_factor = self.rerank_factor;
        let combiner = self.combiner;
        Box::pin(async move {
            let results = reader
                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
                .await?;

            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
        })
    }

    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
        Box::pin(async move { Ok(u32::MAX) })
    }
}

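/// Scorer over a precomputed list of dense vector search results.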
struct DenseVectorScorer {
    results: Vec<VectorSearchResult>,
    position: usize,
    field_id: u32,
}

impl DenseVectorScorer {
    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
        Self {
            results,
            position: 0,
            field_id,
        }
    }
}

impl Scorer for DenseVectorScorer {
    fn doc(&self) -> DocId {
        if self.position < self.results.len() {
            self.results[self.position].doc_id
        } else {
            TERMINATED
        }
    }

    fn score(&self) -> Score {
        if self.position < self.results.len() {
            self.results[self.position].score
        } else {
            0.0
        }
    }

    fn advance(&mut self) -> DocId {
        self.position += 1;
        self.doc()
    }

    fn seek(&mut self, target: DocId) -> DocId {
        while self.doc() < target && self.doc() != TERMINATED {
            self.advance();
        }
        self.doc()
    }

    fn size_hint(&self) -> u32 {
        // Saturate so a call after the scorer is exhausted cannot underflow.
        self.results.len().saturating_sub(self.position) as u32
    }

    fn matched_positions(&self) -> Option<MatchedPositions> {
        if self.position >= self.results.len() {
            return None;
        }
        let result = &self.results[self.position];
        let scored_positions: Vec<ScoredPosition> = result
            .ordinals
            .iter()
            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
            .collect();
        Some(vec![(self.field_id, scored_positions)])
    }
}

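/// Query that scores documents against a sparse query vector of
/// `(dimension, weight)` pairs.
///
/// Usage sketch (illustrative values only):
///
/// ```ignore
/// let query = SparseVectorQuery::new(Field(0), vec![(1, 0.5), (42, 1.2)])
///     .with_weight_threshold(0.1)
///     .with_pruning(0.5);
/// ```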
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field containing the sparse vectors.
    pub field: Field,
    /// Query vector as `(dimension, weight)` pairs.
    pub vector: Vec<(u32, f32)>,
    /// How scores from multiple vector values in one document are combined.
    pub combiner: MultiValueCombiner,
    /// Heap factor forwarded to the segment's sparse vector search.
    pub heap_factor: f32,
    /// Query dimensions with an absolute weight below this threshold are dropped.
    pub weight_threshold: f32,
    /// Optional cap on the number of query dimensions; the highest-weight dimensions are kept.
    pub max_query_dims: Option<usize>,
    /// Optional fraction of query dimensions to keep, by descending absolute weight.
    pub pruning: Option<f32>,
}

impl SparseVectorQuery {
    /// Create a query with default parameters (`LogSumExp` combiner with temperature 0.7,
    /// `heap_factor = 1.0`, no pruning).
    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
        Self {
            field,
            vector,
            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
            heap_factor: 1.0,
            weight_threshold: 0.0,
            max_query_dims: None,
            pruning: None,
        }
    }

    /// Set the multi-value combiner.
    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
        self.combiner = combiner;
        self
    }

    /// Set the heap factor, clamped to `[0.0, 1.0]`.
    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
        self.heap_factor = heap_factor.clamp(0.0, 1.0);
        self
    }

    /// Drop query dimensions whose absolute weight is below `threshold`.
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self
    }

    /// Keep at most `max_dims` query dimensions (the highest-weight ones).
    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
        self.max_query_dims = Some(max_dims);
        self
    }

    /// Keep only the given fraction of query dimensions, clamped to `[0.0, 1.0]`.
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.pruning = Some(fraction.clamp(0.0, 1.0));
        self
    }

    /// Apply `weight_threshold`, `pruning`, and `max_query_dims` to produce the
    /// vector that is actually sent to the segment search.
    fn pruned_vector(&self) -> Vec<(u32, f32)> {
        let original_len = self.vector.len();

        // 1. Drop dimensions below the absolute weight threshold.
        let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
            self.vector
                .iter()
                .copied()
                .filter(|(_, w)| w.abs() >= self.weight_threshold)
                .collect()
        } else {
            self.vector.clone()
        };
        let after_threshold = v.len();

        // 2. Keep only the top fraction of dimensions by absolute weight.
        let mut sorted_by_weight = false;
        if let Some(fraction) = self.pruning
            && fraction < 1.0
            && v.len() > 1
        {
            v.sort_unstable_by(|a, b| {
                b.1.abs()
                    .partial_cmp(&a.1.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            sorted_by_weight = true;
            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
            v.truncate(keep);
        }
        let after_pruning = v.len();

        // 3. Enforce the hard cap on query dimensions.
        if let Some(max_dims) = self.max_query_dims
            && v.len() > max_dims
        {
            if !sorted_by_weight {
                v.sort_unstable_by(|a, b| {
                    b.1.abs()
                        .partial_cmp(&a.1.abs())
                        .unwrap_or(std::cmp::Ordering::Equal)
                });
            }
            v.truncate(max_dims);
        }

        if v.len() < original_len {
            log::debug!(
                "[sparse query] field={}: pruned {}->{} dims \
                 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{})",
                self.field.0,
                original_len,
                v.len(),
                original_len,
                after_threshold,
                after_threshold,
                after_pruning,
                after_pruning,
                v.len(),
            );
            if log::log_enabled!(log::Level::Trace) {
                for (dim, w) in &v {
                    log::trace!("  dim={}, weight={:.4}", dim, w);
                }
            }
        }

        v
    }

    /// Build a query from parallel index and weight vectors.
    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
        Self::new(field, vector)
    }

    #[cfg(feature = "native")]
    /// Build a query by tokenizing `text` with the named tokenizer and weighting the
    /// resulting token ids according to `weighting`.
    ///
    /// With `QueryWeighting::Idf`, IDF weights come from `sparse_index` when provided;
    /// otherwise the weights fall back to 1.0.
    pub fn from_text(
        field: Field,
        text: &str,
        tokenizer_name: &str,
        weighting: crate::structures::QueryWeighting,
        sparse_index: Option<&crate::segment::SparseIndex>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::tokenizer_cache;

        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(index) = sparse_index {
                    index.idf_weights(&token_ids)
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                use crate::tokenizer::idf_weights_cache;
                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
                    token_ids.iter().map(|&id| idf.get(id)).collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    #[cfg(feature = "native")]
    /// Like `from_text`, but uses an already-constructed tokenizer and optional
    /// global statistics for IDF weighting.
    pub fn from_text_with_stats(
        field: Field,
        text: &str,
        tokenizer: &crate::tokenizer::HfTokenizer,
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;

        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    // Negative IDF values are clamped to zero.
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // No IDF-file cache is consulted here; fall back to uniform weights.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    #[cfg(feature = "native")]
    /// Like `from_text_with_stats`, but constructs the tokenizer from serialized bytes.
    pub fn from_text_with_tokenizer_bytes(
        field: Field,
        text: &str,
        tokenizer_bytes: &[u8],
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::HfTokenizer;

        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    // Negative IDF values are clamped to zero.
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // No IDF-file cache is consulted here; fall back to uniform weights.
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }
}

impl Query for SparseVectorQuery {
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        let field = self.field;
        let vector = self.pruned_vector();
        let combiner = self.combiner;
        let heap_factor = self.heap_factor;
        Box::pin(async move {
            let results = reader
                .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
                .await?;

            Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
        })
    }

    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
        Box::pin(async move { Ok(u32::MAX) })
    }
}

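/// Scorer over a precomputed list of sparse vector search results.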
struct SparseVectorScorer {
    results: Vec<VectorSearchResult>,
    position: usize,
    field_id: u32,
}

impl SparseVectorScorer {
    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
        Self {
            results,
            position: 0,
            field_id,
        }
    }
}

impl Scorer for SparseVectorScorer {
    fn doc(&self) -> DocId {
        if self.position < self.results.len() {
            self.results[self.position].doc_id
        } else {
            TERMINATED
        }
    }

    fn score(&self) -> Score {
        if self.position < self.results.len() {
            self.results[self.position].score
        } else {
            0.0
        }
    }

    fn advance(&mut self) -> DocId {
        self.position += 1;
        self.doc()
    }

    fn seek(&mut self, target: DocId) -> DocId {
        while self.doc() < target && self.doc() != TERMINATED {
            self.advance();
        }
        self.doc()
    }

    fn size_hint(&self) -> u32 {
        // Saturate so a call after the scorer is exhausted cannot underflow.
        self.results.len().saturating_sub(self.position) as u32
    }

    fn matched_positions(&self) -> Option<MatchedPositions> {
        if self.position >= self.results.len() {
            return None;
        }
        let result = &self.results[self.position];
        let scored_positions: Vec<ScoredPosition> = result
            .ordinals
            .iter()
            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
            .collect();
        Some(vec![(self.field_id, scored_positions)])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5.0);
    }
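
    // Added check: `with_combiner` should override the default `Max` combiner.
    #[test]
    fn test_dense_vector_query_with_combiner() {
        let query =
            DenseVectorQuery::new(Field(0), vec![1.0]).with_combiner(MultiValueCombiner::Sum);
        assert_eq!(query.combiner, MultiValueCombiner::Sum);
    }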

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }
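
    // Added checks: exercise the query-side pruning pipeline and the clamping builders.
    #[test]
    fn test_sparse_vector_query_pruned_vector() {
        let query =
            SparseVectorQuery::new(Field(0), vec![(1, 0.9), (2, 0.5), (3, 0.05), (4, 0.4)])
                .with_weight_threshold(0.1)
                .with_pruning(0.5);

        // The threshold drops dim 3, then pruning keeps ceil(3 * 0.5) = 2 top-weight dims.
        assert_eq!(query.pruned_vector(), vec![(1, 0.9), (2, 0.5)]);
    }

    #[test]
    fn test_sparse_vector_query_builder_clamps() {
        let query = SparseVectorQuery::new(Field(0), vec![(1, 1.0)])
            .with_heap_factor(1.5)
            .with_pruning(-0.5);

        assert_eq!(query.heap_factor, 1.0);
        assert_eq!(query.pruning, Some(0.0));
    }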

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Sum;
        assert!((combiner.combine(&scores) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        let combiner = MultiValueCombiner::Max;
        assert!((combiner.combine(&scores) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Avg;
        assert!((combiner.combine(&scores) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::log_sum_exp();
        let result = combiner.combine(&scores);
        // LogSumExp is bounded below by the max score and above by max + ln(n) / t.
        assert!(result >= 3.0);
        assert!(result <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        let combiner = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let result = combiner.combine(&scores);
        // With a high temperature the smooth max should be close to the true max.
        assert!((result - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let result = combiner.combine(&scores);
        // Top 3 scores (5, 3, 1) with weights (1, 0.5, 0.25): 6.75 / 1.75 = 3.857...
        assert!((result - 3.857).abs() < 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let result = combiner.combine(&scores);
        // Only two scores available: (2 + 0.7 * 1) / (1 + 0.7) = 1.588...
        assert!((result - 1.588).abs() < 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        assert_eq!(MultiValueCombiner::Sum.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Max.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Avg.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::log_sum_exp().combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::weighted_top_k().combine(&scores), 0.0);
    }

    #[test]
    fn test_combiner_single_score() {
        // Every combiner should reduce a single score to that score.
        let scores = vec![(0, 5.0)];
        assert!((MultiValueCombiner::Sum.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Max.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Avg.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::log_sum_exp().combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::weighted_top_k().combine(&scores) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        let combiner = MultiValueCombiner::default();
        match combiner {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!((temperature - 1.5).abs() < 1e-6);
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}