Skip to main content

hermes_core/query/
vector.rs

1//! Vector query types for dense and sparse vector search
2
3use crate::dsl::Field;
4use crate::segment::{SegmentReader, VectorSearchResult};
5use crate::{DocId, Score, TERMINATED};
6
7use super::ScoredPosition;
8use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};
9
/// Strategy for combining scores when a document has multiple values for the same field
///
/// Each variant maps the per-value scores of one document to a single score via
/// [`MultiValueCombiner::combine`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum all scores (accumulates dot product contributions)
    Sum,
    /// Take the maximum score (NaN scores are ignored unless every score is NaN)
    Max,
    /// Take the average score
    Avg,
    /// Log-Sum-Exp: smooth maximum approximation (the `Default` combiner)
    /// `score = (1/t) * log(Σ exp(t * sᵢ))`
    ///
    /// For `n` scores the result lies in `[max(s), max(s) + ln(n)/t]`: the best
    /// score plus a bounded bonus for additional matches. Higher temperature
    /// shrinks that bonus (closer to max); lower temperature enlarges it.
    /// The temperature is expected to be positive.
    LogSumExp {
        /// Temperature parameter (default: 1.5)
        temperature: f32,
    },
    /// Weighted Top-K: decay-weighted mean of the top `k` scores
    /// `score = Σ wᵢ * sorted_scores[i] / Σ wᵢ` where `wᵢ = decay^i`
    /// and `sorted_scores` is in descending order
    WeightedTopK {
        /// Number of top scores to consider (default: 5)
        k: usize,
        /// Decay factor per rank (default: 0.7)
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    fn default() -> Self {
        // LogSumExp with temperature 1.5 provides good balance between
        // max (best relevance) and sum (saturation from multiple matches)
        MultiValueCombiner::LogSumExp { temperature: 1.5 }
    }
}

impl MultiValueCombiner {
    /// Create LogSumExp combiner with default temperature (1.5)
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp { temperature: 1.5 }
    }

    /// Create LogSumExp combiner with custom temperature (expected to be positive)
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// Create WeightedTopK combiner with defaults (k=5, decay=0.7)
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK { k: 5, decay: 0.7 }
    }

    /// Create WeightedTopK combiner with custom parameters
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Combine multiple `(ordinal, score)` pairs into a single score.
    ///
    /// Only the score half of each pair is used. Returns `0.0` for an empty
    /// slice; `WeightedTopK` with `k == 0` also yields `0.0`.
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match *self {
            MultiValueCombiner::Sum => scores.iter().map(|&(_, s)| s).sum(),
            MultiValueCombiner::Max => Self::max_score(scores),
            MultiValueCombiner::Avg => {
                let sum: f32 = scores.iter().map(|&(_, s)| s).sum();
                sum / scores.len() as f32
            }
            MultiValueCombiner::LogSumExp { temperature: t } => {
                // Numerically stable log-sum-exp:
                // LSE(x) = max(x) + log(Σ exp(xᵢ - max(x)))
                let max_score = Self::max_score(scores);

                // An infinite maximum dominates the sum; shifting by it would
                // evaluate inf - inf = NaN, so it is the answer as-is.
                if max_score.is_infinite() {
                    return max_score;
                }

                let sum_exp: f32 = scores
                    .iter()
                    .map(|&(_, s)| (t * (s - max_score)).exp())
                    .sum();

                max_score + sum_exp.ln() / t
            }
            MultiValueCombiner::WeightedTopK { k, decay } => {
                // Sort scores descending and take top k.
                // `total_cmp` is a total order even with NaN present; a
                // `partial_cmp(..).unwrap_or(Equal)` comparator is not, and the
                // standard sort may panic when given a non-total order.
                let mut sorted: Vec<f32> = scores.iter().map(|&(_, s)| s).collect();
                sorted.sort_unstable_by(|a, b| b.total_cmp(a));
                sorted.truncate(k);

                // Apply exponential decay weights, then normalise by their sum
                let mut weight = 1.0f32;
                let mut weighted_sum = 0.0f32;
                let mut weight_total = 0.0f32;

                for score in sorted {
                    weighted_sum += weight * score;
                    weight_total += weight;
                    weight *= decay;
                }

                if weight_total > 0.0 {
                    weighted_sum / weight_total
                } else {
                    0.0
                }
            }
        }
    }

    /// Largest score in `scores`, independent of element order.
    ///
    /// `f32::max` skips NaN operands, so NaN is returned only when every score
    /// is NaN. Returns `0.0` for an empty slice.
    fn max_score(scores: &[(u32, f32)]) -> f32 {
        scores
            .iter()
            .map(|&(_, s)| s)
            .reduce(f32::max)
            .unwrap_or(0.0)
    }
}
125
126/// Dense vector query for similarity search
127#[derive(Debug, Clone)]
128pub struct DenseVectorQuery {
129    /// Field containing the dense vectors
130    pub field: Field,
131    /// Query vector
132    pub vector: Vec<f32>,
133    /// Number of clusters to probe (for IVF indexes)
134    pub nprobe: usize,
135    /// Re-ranking factor (multiplied by k for candidate selection, e.g. 3.0)
136    pub rerank_factor: f32,
137    /// How to combine scores for multi-valued documents
138    pub combiner: MultiValueCombiner,
139}
140
141impl DenseVectorQuery {
142    /// Create a new dense vector query
143    pub fn new(field: Field, vector: Vec<f32>) -> Self {
144        Self {
145            field,
146            vector,
147            nprobe: 32,
148            rerank_factor: 3.0,
149            combiner: MultiValueCombiner::Max,
150        }
151    }
152
153    /// Set the number of clusters to probe (for IVF indexes)
154    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
155        self.nprobe = nprobe;
156        self
157    }
158
159    /// Set the re-ranking factor (e.g. 3.0 = fetch 3x candidates for reranking)
160    pub fn with_rerank_factor(mut self, factor: f32) -> Self {
161        self.rerank_factor = factor;
162        self
163    }
164
165    /// Set the multi-value score combiner
166    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
167        self.combiner = combiner;
168        self
169    }
170}
171
172impl Query for DenseVectorQuery {
173    fn scorer<'a>(
174        &self,
175        reader: &'a SegmentReader,
176        limit: usize,
177        _predicate: Option<super::DocPredicate<'a>>,
178    ) -> ScorerFuture<'a> {
179        let field = self.field;
180        let vector = self.vector.clone();
181        let nprobe = self.nprobe;
182        let rerank_factor = self.rerank_factor;
183        let combiner = self.combiner;
184        Box::pin(async move {
185            let results = reader
186                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
187                .await?;
188
189            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
190        })
191    }
192
193    #[cfg(feature = "sync")]
194    fn scorer_sync<'a>(
195        &self,
196        reader: &'a SegmentReader,
197        limit: usize,
198        _predicate: Option<super::DocPredicate<'a>>,
199    ) -> crate::Result<Box<dyn Scorer + 'a>> {
200        let results = reader.search_dense_vector_sync(
201            self.field,
202            &self.vector,
203            limit,
204            self.nprobe,
205            self.rerank_factor,
206            self.combiner,
207        )?;
208        Ok(Box::new(DenseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
209    }
210
211    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
212        Box::pin(async move { Ok(u32::MAX) })
213    }
214}
215
216/// Scorer for dense vector search results with ordinal tracking
217struct DenseVectorScorer {
218    results: Vec<VectorSearchResult>,
219    position: usize,
220    field_id: u32,
221}
222
223impl DenseVectorScorer {
224    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
225        Self {
226            results,
227            position: 0,
228            field_id,
229        }
230    }
231}
232
233impl Scorer for DenseVectorScorer {
234    fn doc(&self) -> DocId {
235        if self.position < self.results.len() {
236            self.results[self.position].doc_id
237        } else {
238            TERMINATED
239        }
240    }
241
242    fn score(&self) -> Score {
243        if self.position < self.results.len() {
244            self.results[self.position].score
245        } else {
246            0.0
247        }
248    }
249
250    fn advance(&mut self) -> DocId {
251        self.position += 1;
252        self.doc()
253    }
254
255    fn seek(&mut self, target: DocId) -> DocId {
256        while self.doc() < target && self.doc() != TERMINATED {
257            self.advance();
258        }
259        self.doc()
260    }
261
262    fn size_hint(&self) -> u32 {
263        (self.results.len() - self.position) as u32
264    }
265
266    fn matched_positions(&self) -> Option<MatchedPositions> {
267        if self.position >= self.results.len() {
268            return None;
269        }
270        let result = &self.results[self.position];
271        let scored_positions: Vec<ScoredPosition> = result
272            .ordinals
273            .iter()
274            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
275            .collect();
276        Some(vec![(self.field_id, scored_positions)])
277    }
278}
279
280/// Sparse vector query for similarity search
281#[derive(Debug, Clone)]
282pub struct SparseVectorQuery {
283    /// Field containing the sparse vectors
284    pub field: Field,
285    /// Query vector as (dimension_id, weight) pairs
286    pub vector: Vec<(u32, f32)>,
287    /// How to combine scores for multi-valued documents
288    pub combiner: MultiValueCombiner,
289    /// Approximate search factor (1.0 = exact, lower values = faster but approximate)
290    /// Controls MaxScore pruning aggressiveness in block-max scoring
291    pub heap_factor: f32,
292    /// Minimum abs(weight) for query dimensions (0.0 = no filtering)
293    /// Dimensions below this threshold are dropped before search.
294    pub weight_threshold: f32,
295    /// Maximum number of query dimensions to process (None = all)
296    /// Keeps only the top-k dimensions by abs(weight).
297    pub max_query_dims: Option<usize>,
298    /// Fraction of query dimensions to keep (0.0-1.0), same semantics as
299    /// indexing-time `pruning`: sort by abs(weight) descending,
300    /// keep top fraction. None or 1.0 = no pruning.
301    pub pruning: Option<f32>,
302}
303
304impl SparseVectorQuery {
305    /// Create a new sparse vector query
306    ///
307    /// Default combiner is `LogSumExp { temperature: 0.7 }` which provides
308    /// saturation for documents with many sparse vectors (e.g., 100+ ordinals).
309    /// This prevents over-weighting from multiple matches while still allowing
310    /// additional matches to contribute to the score.
311    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
312        Self {
313            field,
314            vector,
315            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
316            heap_factor: 1.0,
317            weight_threshold: 0.0,
318            max_query_dims: None,
319            pruning: None,
320        }
321    }
322
323    /// Set the multi-value score combiner
324    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
325        self.combiner = combiner;
326        self
327    }
328
329    /// Set the heap factor for approximate search
330    ///
331    /// Controls the trade-off between speed and recall:
332    /// - 1.0 = exact search (default)
333    /// - 0.8-0.9 = ~20-40% faster with minimal recall loss
334    /// - Lower values = more aggressive pruning, faster but lower recall
335    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
336        self.heap_factor = heap_factor.clamp(0.0, 1.0);
337        self
338    }
339
340    /// Set minimum weight threshold for query dimensions
341    /// Dimensions with abs(weight) below this are dropped before search.
342    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
343        self.weight_threshold = threshold;
344        self
345    }
346
347    /// Set maximum number of query dimensions (top-k by weight)
348    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
349        self.max_query_dims = Some(max_dims);
350        self
351    }
352
353    /// Set pruning fraction (0.0-1.0): keep top fraction of query dims by weight.
354    /// Same semantics as indexing-time `pruning`.
355    pub fn with_pruning(mut self, fraction: f32) -> Self {
356        self.pruning = Some(fraction.clamp(0.0, 1.0));
357        self
358    }
359
360    /// Apply weight_threshold, pruning, and max_query_dims, returning the pruned vector.
361    fn pruned_vector(&self) -> Vec<(u32, f32)> {
362        let original_len = self.vector.len();
363
364        // Step 1: weight_threshold — drop dimensions below minimum weight
365        let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
366            self.vector
367                .iter()
368                .copied()
369                .filter(|(_, w)| w.abs() >= self.weight_threshold)
370                .collect()
371        } else {
372            self.vector.clone()
373        };
374        let after_threshold = v.len();
375
376        // Step 2: pruning — keep top fraction by abs(weight), same as indexing
377        let mut sorted_by_weight = false;
378        if let Some(fraction) = self.pruning
379            && fraction < 1.0
380            && v.len() > 1
381        {
382            v.sort_unstable_by(|a, b| {
383                b.1.abs()
384                    .partial_cmp(&a.1.abs())
385                    .unwrap_or(std::cmp::Ordering::Equal)
386            });
387            sorted_by_weight = true;
388            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
389            v.truncate(keep);
390        }
391        let after_pruning = v.len();
392
393        // Step 3: max_query_dims — absolute cap on dimensions
394        if let Some(max_dims) = self.max_query_dims
395            && v.len() > max_dims
396        {
397            if !sorted_by_weight {
398                v.sort_unstable_by(|a, b| {
399                    b.1.abs()
400                        .partial_cmp(&a.1.abs())
401                        .unwrap_or(std::cmp::Ordering::Equal)
402                });
403            }
404            v.truncate(max_dims);
405        }
406
407        if v.len() < original_len {
408            log::debug!(
409                "[sparse query] field={}: pruned {}->{} dims \
410                 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{})",
411                self.field.0,
412                original_len,
413                v.len(),
414                original_len,
415                after_threshold,
416                after_threshold,
417                after_pruning,
418                after_pruning,
419                v.len(),
420            );
421            if log::log_enabled!(log::Level::Trace) {
422                for (dim, w) in &v {
423                    log::trace!("  dim={}, weight={:.4}", dim, w);
424                }
425            }
426        }
427
428        v
429    }
430
431    /// Create from separate indices and weights vectors
432    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
433        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
434        Self::new(field, vector)
435    }
436
437    /// Create from raw text using a HuggingFace tokenizer (single segment)
438    ///
439    /// This method tokenizes the text and creates a sparse vector query.
440    /// For multi-segment indexes, use `from_text_with_stats` instead.
441    ///
442    /// # Arguments
443    /// * `field` - The sparse vector field to search
444    /// * `text` - Raw text to tokenize
445    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
446    /// * `weighting` - Weighting strategy for tokens
447    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
448    #[cfg(feature = "native")]
449    pub fn from_text(
450        field: Field,
451        text: &str,
452        tokenizer_name: &str,
453        weighting: crate::structures::QueryWeighting,
454        sparse_index: Option<&crate::segment::SparseIndex>,
455    ) -> crate::Result<Self> {
456        use crate::structures::QueryWeighting;
457        use crate::tokenizer::tokenizer_cache;
458
459        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
460        let token_ids = tokenizer.tokenize_unique(text)?;
461
462        let weights: Vec<f32> = match weighting {
463            QueryWeighting::One => vec![1.0f32; token_ids.len()],
464            QueryWeighting::Idf => {
465                if let Some(index) = sparse_index {
466                    index.idf_weights(&token_ids)
467                } else {
468                    vec![1.0f32; token_ids.len()]
469                }
470            }
471            QueryWeighting::IdfFile => {
472                use crate::tokenizer::idf_weights_cache;
473                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
474                    token_ids.iter().map(|&id| idf.get(id)).collect()
475                } else {
476                    vec![1.0f32; token_ids.len()]
477                }
478            }
479        };
480
481        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
482        Ok(Self::new(field, vector))
483    }
484
485    /// Create from raw text using global statistics (multi-segment)
486    ///
487    /// This is the recommended method for multi-segment indexes as it uses
488    /// aggregated IDF values across all segments for consistent ranking.
489    ///
490    /// # Arguments
491    /// * `field` - The sparse vector field to search
492    /// * `text` - Raw text to tokenize
493    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
494    /// * `weighting` - Weighting strategy for tokens
495    /// * `global_stats` - Global statistics for IDF computation
496    #[cfg(feature = "native")]
497    pub fn from_text_with_stats(
498        field: Field,
499        text: &str,
500        tokenizer: &crate::tokenizer::HfTokenizer,
501        weighting: crate::structures::QueryWeighting,
502        global_stats: Option<&super::GlobalStats>,
503    ) -> crate::Result<Self> {
504        use crate::structures::QueryWeighting;
505
506        let token_ids = tokenizer.tokenize_unique(text)?;
507
508        let weights: Vec<f32> = match weighting {
509            QueryWeighting::One => vec![1.0f32; token_ids.len()],
510            QueryWeighting::Idf => {
511                if let Some(stats) = global_stats {
512                    // Clamp to zero: negative weights don't make sense for IDF
513                    stats
514                        .sparse_idf_weights(field, &token_ids)
515                        .into_iter()
516                        .map(|w| w.max(0.0))
517                        .collect()
518                } else {
519                    vec![1.0f32; token_ids.len()]
520                }
521            }
522            QueryWeighting::IdfFile => {
523                // IdfFile requires a tokenizer name for HF model lookup;
524                // this code path doesn't have one, so fall back to 1.0
525                vec![1.0f32; token_ids.len()]
526            }
527        };
528
529        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
530        Ok(Self::new(field, vector))
531    }
532
533    /// Create from raw text, loading tokenizer from index directory
534    ///
535    /// This method supports the `index://` prefix for tokenizer paths,
536    /// loading tokenizer.json from the index directory.
537    ///
538    /// # Arguments
539    /// * `field` - The sparse vector field to search
540    /// * `text` - Raw text to tokenize
541    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
542    /// * `weighting` - Weighting strategy for tokens
543    /// * `global_stats` - Global statistics for IDF computation
544    #[cfg(feature = "native")]
545    pub fn from_text_with_tokenizer_bytes(
546        field: Field,
547        text: &str,
548        tokenizer_bytes: &[u8],
549        weighting: crate::structures::QueryWeighting,
550        global_stats: Option<&super::GlobalStats>,
551    ) -> crate::Result<Self> {
552        use crate::structures::QueryWeighting;
553        use crate::tokenizer::HfTokenizer;
554
555        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
556        let token_ids = tokenizer.tokenize_unique(text)?;
557
558        let weights: Vec<f32> = match weighting {
559            QueryWeighting::One => vec![1.0f32; token_ids.len()],
560            QueryWeighting::Idf => {
561                if let Some(stats) = global_stats {
562                    // Clamp to zero: negative weights don't make sense for IDF
563                    stats
564                        .sparse_idf_weights(field, &token_ids)
565                        .into_iter()
566                        .map(|w| w.max(0.0))
567                        .collect()
568                } else {
569                    vec![1.0f32; token_ids.len()]
570                }
571            }
572            QueryWeighting::IdfFile => {
573                // IdfFile requires a tokenizer name for HF model lookup;
574                // this code path doesn't have one, so fall back to 1.0
575                vec![1.0f32; token_ids.len()]
576            }
577        };
578
579        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
580        Ok(Self::new(field, vector))
581    }
582}
583
584impl Query for SparseVectorQuery {
585    fn scorer<'a>(
586        &self,
587        reader: &'a SegmentReader,
588        limit: usize,
589        _predicate: Option<super::DocPredicate<'a>>,
590    ) -> ScorerFuture<'a> {
591        let field = self.field;
592        let vector = self.pruned_vector();
593        let combiner = self.combiner;
594        let heap_factor = self.heap_factor;
595        Box::pin(async move {
596            let results = reader
597                .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
598                .await?;
599
600            Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
601        })
602    }
603
604    #[cfg(feature = "sync")]
605    fn scorer_sync<'a>(
606        &self,
607        reader: &'a SegmentReader,
608        limit: usize,
609        _predicate: Option<super::DocPredicate<'a>>,
610    ) -> crate::Result<Box<dyn Scorer + 'a>> {
611        let vector = self.pruned_vector();
612        let results = reader.search_sparse_vector_sync(
613            self.field,
614            &vector,
615            limit,
616            self.combiner,
617            self.heap_factor,
618        )?;
619        Ok(Box::new(SparseVectorScorer::new(results, self.field.0)) as Box<dyn Scorer>)
620    }
621
622    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
623        Box::pin(async move { Ok(u32::MAX) })
624    }
625}
626
627/// Scorer for sparse vector search results with ordinal tracking
628struct SparseVectorScorer {
629    results: Vec<VectorSearchResult>,
630    position: usize,
631    field_id: u32,
632}
633
634impl SparseVectorScorer {
635    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
636        Self {
637            results,
638            position: 0,
639            field_id,
640        }
641    }
642}
643
644impl Scorer for SparseVectorScorer {
645    fn doc(&self) -> DocId {
646        if self.position < self.results.len() {
647            self.results[self.position].doc_id
648        } else {
649            TERMINATED
650        }
651    }
652
653    fn score(&self) -> Score {
654        if self.position < self.results.len() {
655            self.results[self.position].score
656        } else {
657            0.0
658        }
659    }
660
661    fn advance(&mut self) -> DocId {
662        self.position += 1;
663        self.doc()
664    }
665
666    fn seek(&mut self, target: DocId) -> DocId {
667        while self.doc() < target && self.doc() != TERMINATED {
668            self.advance();
669        }
670        self.doc()
671    }
672
673    fn size_hint(&self) -> u32 {
674        (self.results.len() - self.position) as u32
675    }
676
677    fn matched_positions(&self) -> Option<MatchedPositions> {
678        if self.position >= self.results.len() {
679            return None;
680        }
681        let result = &self.results[self.position];
682        let scored_positions: Vec<ScoredPosition> = result
683            .ordinals
684            .iter()
685            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
686            .collect();
687        Some(vec![(self.field_id, scored_positions)])
688    }
689}
690
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    /// Sparse query vector with distinct magnitudes, so ranking by abs(weight)
    /// is unambiguous: |-0.9| > 0.5 > 0.3 > 0.1.
    fn sample_sparse() -> Vec<(u32, f32)> {
        vec![(1, 0.1), (2, -0.9), (3, 0.5), (4, 0.3)]
    }

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5.0);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5.0);
    }

    #[test]
    fn test_query_default_combiners() {
        let dense = DenseVectorQuery::new(Field(0), vec![1.0]);
        assert_eq!(dense.combiner, MultiValueCombiner::Max);
        let sparse = SparseVectorQuery::new(Field(0), vec![(1, 1.0)]);
        assert_eq!(
            sparse.combiner,
            MultiValueCombiner::LogSumExp { temperature: 0.7 }
        );
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_sparse_builder_clamps() {
        let query = SparseVectorQuery::new(Field(0), vec![])
            .with_heap_factor(2.0)
            .with_pruning(-1.0);
        assert_eq!(query.heap_factor, 1.0);
        assert_eq!(query.pruning, Some(0.0));
    }

    #[test]
    fn test_pruned_vector_defaults_keep_everything() {
        let query = SparseVectorQuery::new(Field(0), sample_sparse());
        assert_eq!(query.pruned_vector(), sample_sparse());
    }

    #[test]
    fn test_pruned_vector_weight_threshold_preserves_order() {
        let query = SparseVectorQuery::new(Field(0), sample_sparse()).with_weight_threshold(0.25);
        assert_eq!(query.pruned_vector(), vec![(2, -0.9), (3, 0.5), (4, 0.3)]);
    }

    #[test]
    fn test_pruned_vector_max_query_dims_keeps_largest_abs_weights() {
        let query = SparseVectorQuery::new(Field(0), sample_sparse()).with_max_query_dims(2);
        assert_eq!(query.pruned_vector(), vec![(2, -0.9), (3, 0.5)]);
    }

    #[test]
    fn test_pruned_vector_pruning_fraction() {
        let query = SparseVectorQuery::new(Field(0), sample_sparse()).with_pruning(0.5);
        assert_eq!(query.pruned_vector(), vec![(2, -0.9), (3, 0.5)]);

        // A zero fraction still keeps the single strongest dimension.
        let query = SparseVectorQuery::new(Field(0), sample_sparse()).with_pruning(0.0);
        assert_eq!(query.pruned_vector(), vec![(2, -0.9)]);
    }

    #[test]
    fn test_pruned_vector_steps_compose() {
        // threshold drops dim 1; pruning keeps ceil(3 * 0.5) = 2; cap keeps 1.
        let query = SparseVectorQuery::new(Field(0), sample_sparse())
            .with_weight_threshold(0.2)
            .with_pruning(0.5)
            .with_max_query_dims(1);
        assert_eq!(query.pruned_vector(), vec![(2, -0.9)]);
    }

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Sum;
        assert!((combiner.combine(&scores) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        let combiner = MultiValueCombiner::Max;
        assert!((combiner.combine(&scores) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Avg;
        assert!((combiner.combine(&scores) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::log_sum_exp();
        let result = combiner.combine(&scores);
        // LogSumExp should be between max (3.0) and max + log(n)/t
        assert!(result >= 3.0);
        assert!(result <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        // High temperature should approach max
        let combiner = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let result = combiner.combine(&scores);
        // Should be very close to max (5.0)
        assert!((result - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let result = combiner.combine(&scores);
        // Top 3: 5.0, 3.0, 1.0 with weights 1.0, 0.5, 0.25
        // weighted_sum = 5*1 + 3*0.5 + 1*0.25 = 6.75
        // weight_total = 1.75
        // result = 6.75 / 1.75 ≈ 3.857
        assert!((result - 3.857).abs() < 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let result = combiner.combine(&scores);
        // Only 2 scores, weights 1.0 and 0.7
        // weighted_sum = 2*1 + 1*0.7 = 2.7
        // weight_total = 1.7
        // result = 2.7 / 1.7 ≈ 1.588
        assert!((result - 1.588).abs() < 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        assert_eq!(MultiValueCombiner::Sum.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Max.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Avg.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::log_sum_exp().combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::weighted_top_k().combine(&scores), 0.0);
    }

    #[test]
    fn test_combiner_single_score() {
        let scores = vec![(0, 5.0)];
        // All combiners should return 5.0 for a single score
        assert!((MultiValueCombiner::Sum.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Max.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Avg.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::log_sum_exp().combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::weighted_top_k().combine(&scores) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        let combiner = MultiValueCombiner::default();
        match combiner {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!((temperature - 1.5).abs() < 1e-6);
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}