1use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::ScoredPosition;
8use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Strategy for merging the scores of the individual values of a
/// multi-valued vector field into a single document score.
///
/// [`combine`](MultiValueCombiner::combine) receives `(ordinal, score)` pairs;
/// only the score component takes part in the combination.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum of all scores.
    Sum,
    /// Largest score.
    Max,
    /// Arithmetic mean of the scores.
    Avg,
    /// Smooth maximum `max + ln(sum_i exp(t * (s_i - max))) / t`, `t = temperature`.
    ///
    /// For a positive temperature the result lies in `[max, max + ln(n) / t]`
    /// and approaches `max` as `t` grows.
    LogSumExp {
        /// The `t` in the formula above. A zero or non-finite value makes
        /// `combine` return the plain maximum instead.
        temperature: f32,
    },
    /// Weighted mean of the `k` largest scores; the score at rank `i`
    /// (0-based, descending) has weight `decay^i`.
    WeightedTopK {
        /// How many of the largest scores are kept.
        k: usize,
        /// Per-rank geometric weight factor.
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    /// `LogSumExp` with temperature `1.5` (same as [`MultiValueCombiner::log_sum_exp`]).
    fn default() -> Self {
        Self::log_sum_exp()
    }
}

impl MultiValueCombiner {
    /// `LogSumExp` with temperature `1.5`.
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp { temperature: 1.5 }
    }

    /// `LogSumExp` with a caller-chosen temperature.
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// `WeightedTopK` with `k = 5` and `decay = 0.7`.
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK { k: 5, decay: 0.7 }
    }

    /// `WeightedTopK` with caller-chosen `k` and `decay`.
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Combines per-value `(ordinal, score)` pairs into one score.
    ///
    /// Returns `0.0` for an empty slice (and for `WeightedTopK` with `k == 0`).
    /// NaN scores never cause a panic: the maximum used by `Max` and
    /// `LogSumExp` ignores NaN whenever a non-NaN score exists, and the
    /// top-k selection sorts with a total order.
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match *self {
            MultiValueCombiner::Sum => scores.iter().map(|&(_, s)| s).sum(),
            MultiValueCombiner::Max => Self::max_score(scores),
            MultiValueCombiner::Avg => {
                let sum: f32 = scores.iter().map(|&(_, s)| s).sum();
                sum / scores.len() as f32
            }
            MultiValueCombiner::LogSumExp { temperature: t } => {
                let max_score = Self::max_score(scores);
                // The shifted formula degenerates to NaN/inf for a zero or
                // non-finite temperature and for non-finite scores; the
                // maximum is the meaningful answer there (and the exact limit
                // as t -> +inf).
                if t == 0.0 || !t.is_finite() || !max_score.is_finite() {
                    return max_score;
                }
                // Shifting by the max keeps every exponent <= 0 for t > 0,
                // so exp() cannot overflow.
                let sum_exp: f32 = scores
                    .iter()
                    .map(|&(_, s)| (t * (s - max_score)).exp())
                    .sum();
                max_score + sum_exp.ln() / t
            }
            MultiValueCombiner::WeightedTopK { k, decay } => {
                let mut sorted: Vec<f32> = scores.iter().map(|&(_, s)| s).collect();
                // `total_cmp` is a total order even with NaN present;
                // `partial_cmp(..).unwrap_or(Equal)` is not, and slice
                // sorting may panic on an inconsistent comparator.
                sorted.sort_unstable_by(|a, b| b.total_cmp(a));
                sorted.truncate(k);

                let mut weight = 1.0f32;
                let mut weighted_sum = 0.0f32;
                let mut weight_total = 0.0f32;
                for score in sorted {
                    weighted_sum += weight * score;
                    weight_total += weight;
                    weight *= decay;
                }

                if weight_total > 0.0 {
                    weighted_sum / weight_total
                } else {
                    0.0
                }
            }
        }
    }

    /// Largest score in `scores`, ignoring NaN unless every score is NaN.
    /// Returns `0.0` for an empty slice.
    fn max_score(scores: &[(u32, f32)]) -> f32 {
        scores.iter().map(|&(_, s)| s).reduce(f32::max).unwrap_or(0.0)
    }
}
125
126#[derive(Debug, Clone)]
128pub struct DenseVectorQuery {
129 pub field: Field,
131 pub vector: Vec<f32>,
133 pub nprobe: usize,
135 pub rerank_factor: usize,
137 pub combiner: MultiValueCombiner,
139}
140
141impl DenseVectorQuery {
142 pub fn new(field: Field, vector: Vec<f32>) -> Self {
144 Self {
145 field,
146 vector,
147 nprobe: 32,
148 rerank_factor: 3,
149 combiner: MultiValueCombiner::Max,
150 }
151 }
152
153 pub fn with_nprobe(mut self, nprobe: usize) -> Self {
155 self.nprobe = nprobe;
156 self
157 }
158
159 pub fn with_rerank_factor(mut self, factor: usize) -> Self {
161 self.rerank_factor = factor;
162 self
163 }
164
165 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
167 self.combiner = combiner;
168 self
169 }
170}
171
172impl Query for DenseVectorQuery {
173 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
174 let field = self.field;
175 let vector = self.vector.clone();
176 let nprobe = self.nprobe;
177 let rerank_factor = self.rerank_factor;
178 let combiner = self.combiner;
179 Box::pin(async move {
180 let results = reader.search_dense_vector(
181 field,
182 &vector,
183 limit,
184 nprobe,
185 rerank_factor,
186 combiner,
187 )?;
188
189 Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
190 })
191 }
192
193 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
194 Box::pin(async move { Ok(u32::MAX) })
195 }
196}
197
198struct DenseVectorScorer {
200 results: Vec<VectorSearchResult>,
201 position: usize,
202 field_id: u32,
203}
204
205impl DenseVectorScorer {
206 fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
207 Self {
208 results,
209 position: 0,
210 field_id,
211 }
212 }
213}
214
215impl Scorer for DenseVectorScorer {
216 fn doc(&self) -> DocId {
217 if self.position < self.results.len() {
218 self.results[self.position].doc_id
219 } else {
220 TERMINATED
221 }
222 }
223
224 fn score(&self) -> Score {
225 if self.position < self.results.len() {
226 self.results[self.position].score
227 } else {
228 0.0
229 }
230 }
231
232 fn advance(&mut self) -> DocId {
233 self.position += 1;
234 self.doc()
235 }
236
237 fn seek(&mut self, target: DocId) -> DocId {
238 while self.doc() < target && self.doc() != TERMINATED {
239 self.advance();
240 }
241 self.doc()
242 }
243
244 fn size_hint(&self) -> u32 {
245 (self.results.len() - self.position) as u32
246 }
247
248 fn matched_positions(&self) -> Option<MatchedPositions> {
249 if self.position >= self.results.len() {
250 return None;
251 }
252 let result = &self.results[self.position];
253 let scored_positions: Vec<ScoredPosition> = result
254 .ordinals
255 .iter()
256 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
257 .collect();
258 Some(vec![(self.field_id, scored_positions)])
259 }
260}
261
262#[derive(Debug, Clone)]
264pub struct SparseVectorQuery {
265 pub field: Field,
267 pub vector: Vec<(u32, f32)>,
269 pub combiner: MultiValueCombiner,
271 pub heap_factor: f32,
274}
275
276impl SparseVectorQuery {
277 pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
284 Self {
285 field,
286 vector,
287 combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
288 heap_factor: 1.0,
289 }
290 }
291
292 pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
294 self.combiner = combiner;
295 self
296 }
297
298 pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
305 self.heap_factor = heap_factor.clamp(0.0, 1.0);
306 self
307 }
308
309 pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
311 let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
312 Self::new(field, vector)
313 }
314
315 #[cfg(feature = "native")]
327 pub fn from_text(
328 field: Field,
329 text: &str,
330 tokenizer_name: &str,
331 weighting: crate::structures::QueryWeighting,
332 sparse_index: Option<&crate::segment::SparseIndex>,
333 ) -> crate::Result<Self> {
334 use crate::structures::QueryWeighting;
335 use crate::tokenizer::tokenizer_cache;
336
337 let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
338 let token_ids = tokenizer.tokenize_unique(text)?;
339
340 let weights: Vec<f32> = match weighting {
341 QueryWeighting::One => vec![1.0f32; token_ids.len()],
342 QueryWeighting::Idf => {
343 if let Some(index) = sparse_index {
344 index.idf_weights(&token_ids)
345 } else {
346 vec![1.0f32; token_ids.len()]
347 }
348 }
349 QueryWeighting::IdfFile => {
350 use crate::tokenizer::idf_weights_cache;
351 if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
352 token_ids.iter().map(|&id| idf.get(id)).collect()
353 } else {
354 vec![1.0f32; token_ids.len()]
355 }
356 }
357 };
358
359 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
360 Ok(Self::new(field, vector))
361 }
362
363 #[cfg(feature = "native")]
375 pub fn from_text_with_stats(
376 field: Field,
377 text: &str,
378 tokenizer: &crate::tokenizer::HfTokenizer,
379 weighting: crate::structures::QueryWeighting,
380 global_stats: Option<&super::GlobalStats>,
381 ) -> crate::Result<Self> {
382 use crate::structures::QueryWeighting;
383
384 let token_ids = tokenizer.tokenize_unique(text)?;
385
386 let weights: Vec<f32> = match weighting {
387 QueryWeighting::One => vec![1.0f32; token_ids.len()],
388 QueryWeighting::Idf => {
389 if let Some(stats) = global_stats {
390 stats
392 .sparse_idf_weights(field, &token_ids)
393 .into_iter()
394 .map(|w| w.max(0.0))
395 .collect()
396 } else {
397 vec![1.0f32; token_ids.len()]
398 }
399 }
400 QueryWeighting::IdfFile => {
401 vec![1.0f32; token_ids.len()]
404 }
405 };
406
407 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
408 Ok(Self::new(field, vector))
409 }
410
411 #[cfg(feature = "native")]
423 pub fn from_text_with_tokenizer_bytes(
424 field: Field,
425 text: &str,
426 tokenizer_bytes: &[u8],
427 weighting: crate::structures::QueryWeighting,
428 global_stats: Option<&super::GlobalStats>,
429 ) -> crate::Result<Self> {
430 use crate::structures::QueryWeighting;
431 use crate::tokenizer::HfTokenizer;
432
433 let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
434 let token_ids = tokenizer.tokenize_unique(text)?;
435
436 let weights: Vec<f32> = match weighting {
437 QueryWeighting::One => vec![1.0f32; token_ids.len()],
438 QueryWeighting::Idf => {
439 if let Some(stats) = global_stats {
440 stats
442 .sparse_idf_weights(field, &token_ids)
443 .into_iter()
444 .map(|w| w.max(0.0))
445 .collect()
446 } else {
447 vec![1.0f32; token_ids.len()]
448 }
449 }
450 QueryWeighting::IdfFile => {
451 vec![1.0f32; token_ids.len()]
454 }
455 };
456
457 let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
458 Ok(Self::new(field, vector))
459 }
460}
461
462impl Query for SparseVectorQuery {
463 fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
464 let field = self.field;
465 let vector = self.vector.clone();
466 let combiner = self.combiner;
467 let heap_factor = self.heap_factor;
468 Box::pin(async move {
469 let results = reader
470 .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
471 .await?;
472
473 Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
474 })
475 }
476
477 fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
478 Box::pin(async move { Ok(u32::MAX) })
479 }
480}
481
482struct SparseVectorScorer {
484 results: Vec<VectorSearchResult>,
485 position: usize,
486 field_id: u32,
487}
488
489impl SparseVectorScorer {
490 fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
491 Self {
492 results,
493 position: 0,
494 field_id,
495 }
496 }
497}
498
499impl Scorer for SparseVectorScorer {
500 fn doc(&self) -> DocId {
501 if self.position < self.results.len() {
502 self.results[self.position].doc_id
503 } else {
504 TERMINATED
505 }
506 }
507
508 fn score(&self) -> Score {
509 if self.position < self.results.len() {
510 self.results[self.position].score
511 } else {
512 0.0
513 }
514 }
515
516 fn advance(&mut self) -> DocId {
517 self.position += 1;
518 self.doc()
519 }
520
521 fn seek(&mut self, target: DocId) -> DocId {
522 while self.doc() < target && self.doc() != TERMINATED {
523 self.advance();
524 }
525 self.doc()
526 }
527
528 fn size_hint(&self) -> u32 {
529 (self.results.len() - self.position) as u32
530 }
531
532 fn matched_positions(&self) -> Option<MatchedPositions> {
533 if self.position >= self.results.len() {
534 return None;
535 }
536 let result = &self.results[self.position];
537 let scored_positions: Vec<ScoredPosition> = result
538 .ordinals
539 .iter()
540 .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
541 .collect();
542 Some(vec![(self.field_id, scored_positions)])
543 }
544}
545
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// Fails unless `actual` lies strictly within `tol` of `expected`.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Every combiner variant built via its default constructor.
    fn all_combiners() -> [MultiValueCombiner; 5] {
        [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ]
    }

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_rerank_factor(5)
            .with_nprobe(64);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let pairs = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), pairs.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, pairs);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let indices = vec![1, 5, 10];
        let weights = vec![0.5, 0.3, 0.2];
        let query = SparseVectorQuery::from_indices_weights(Field(0), indices, weights);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let combined = MultiValueCombiner::Sum.combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        assert_close(combined, 6.0, 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let combined = MultiValueCombiner::Max.combine(&[(0, 1.0), (1, 3.0), (2, 2.0)]);
        assert_close(combined, 3.0, 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let combined = MultiValueCombiner::Avg.combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        assert_close(combined, 2.0, 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let combined =
            MultiValueCombiner::log_sum_exp().combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        // Bounded below by the max and above by max + ln(n) / temperature.
        let upper = 3.0 + (3.0_f32).ln() / 1.5;
        assert!((3.0..=upper).contains(&combined));
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let sharp = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let combined = sharp.combine(&[(0, 1.0), (1, 5.0), (2, 2.0)]);
        assert_close(combined, 5.0, 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        // Top 3 = [5, 3, 1] with weights [1, 0.5, 0.25]: 6.75 / 1.75.
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let combined = combiner.combine(&[(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]);
        assert_close(combined, 3.857, 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        // Only two scores: (2 * 1 + 1 * 0.7) / (1 + 0.7).
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let combined = combiner.combine(&[(0, 2.0), (1, 1.0)]);
        assert_close(combined, 1.588, 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let no_scores: &[(u32, f32)] = &[];
        for combiner in all_combiners().iter() {
            assert_eq!(combiner.combine(no_scores), 0.0);
        }
    }

    #[test]
    fn test_combiner_single_score() {
        let single = [(0u32, 5.0f32)];
        for combiner in all_combiners().iter() {
            assert_close(combiner.combine(&single), 5.0, 1e-6);
        }
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        if let MultiValueCombiner::LogSumExp { temperature } = MultiValueCombiner::default() {
            assert_close(temperature, 1.5, 1e-6);
        } else {
            panic!("Default combiner should be LogSumExp");
        }
    }
}