hermes_core/query/vector.rs

//! Vector query types for dense and sparse vector search

use crate::dsl::Field;
use crate::segment::{SegmentReader, VectorSearchResult};
use crate::{DocId, Score, TERMINATED};

use super::ScoredPosition;
use super::traits::{CountFuture, MatchedPositions, Query, Scorer, ScorerFuture};

/// Strategy for combining scores when a document has multiple values for the same field
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum all scores (accumulates dot product contributions)
    Sum,
    /// Take the maximum score
    Max,
    /// Take the average score
    Avg,
    /// Log-Sum-Exp: smooth maximum approximation (default)
    /// `score = (1/t) * log(Σ exp(t * sᵢ))`
    /// Higher temperature → closer to max; lower → closer to mean
    LogSumExp {
        /// Temperature parameter (default: 1.5)
        temperature: f32,
    },
    /// Weighted Top-K: weighted average of the top scores with exponential decay
    /// `score = (Σ wᵢ * sorted_scores[i]) / (Σ wᵢ)` where `wᵢ = decay^i`
    WeightedTopK {
        /// Number of top scores to consider (default: 5)
        k: usize,
        /// Decay factor per rank (default: 0.7)
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    fn default() -> Self {
        // LogSumExp with temperature 1.5 provides a good balance between
        // max (best relevance) and sum (saturation from multiple matches)
        MultiValueCombiner::LogSumExp { temperature: 1.5 }
    }
}

impl MultiValueCombiner {
    /// Create LogSumExp combiner with default temperature (1.5)
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp { temperature: 1.5 }
    }

    /// Create LogSumExp combiner with custom temperature
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// Create WeightedTopK combiner with defaults (k=5, decay=0.7)
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK { k: 5, decay: 0.7 }
    }

    /// Create WeightedTopK combiner with custom parameters
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Combine multiple scores into a single score
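    ///
    /// # Example
    ///
    /// A minimal sketch of the strategies on one set of scores (the `u32`
    /// ordinals are ignored by every combiner; fenced as `ignore` since the
    /// re-export path may differ):
    ///
    /// ```ignore
    /// let scores = [(0u32, 1.0f32), (1, 2.0), (2, 3.0)];
    /// assert_eq!(MultiValueCombiner::Sum.combine(&scores), 6.0);
    /// assert_eq!(MultiValueCombiner::Max.combine(&scores), 3.0);
    /// assert_eq!(MultiValueCombiner::Avg.combine(&scores), 2.0);
    /// // LogSumExp lies between max and max + ln(n)/t:
    /// let lse = MultiValueCombiner::log_sum_exp().combine(&scores);
    /// assert!(lse >= 3.0 && lse <= 3.0 + 3.0f32.ln() / 1.5);
    /// ```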
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match self {
            MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
            MultiValueCombiner::Max => scores
                .iter()
                .map(|(_, s)| *s)
                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                .unwrap_or(0.0),
            MultiValueCombiner::Avg => {
                let sum: f32 = scores.iter().map(|(_, s)| s).sum();
                sum / scores.len() as f32
            }
            MultiValueCombiner::LogSumExp { temperature } => {
                // Numerically stable log-sum-exp with temperature t:
                // LSE_t(x) = max(x) + (1/t) * log(Σ exp(t * (xᵢ - max(x))))
                let t = *temperature;
                let max_score = scores
                    .iter()
                    .map(|(_, s)| *s)
                    .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
                    .unwrap_or(0.0);

                let sum_exp: f32 = scores
                    .iter()
                    .map(|(_, s)| (t * (s - max_score)).exp())
                    .sum();

                max_score + sum_exp.ln() / t
            }
            MultiValueCombiner::WeightedTopK { k, decay } => {
                // Sort scores descending and take top k
                let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
                sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
                sorted.truncate(*k);

                // Apply exponential decay weights
                let mut weight = 1.0f32;
                let mut weighted_sum = 0.0f32;
                let mut weight_total = 0.0f32;

                for score in sorted {
                    weighted_sum += weight * score;
                    weight_total += weight;
                    weight *= decay;
                }

                if weight_total > 0.0 {
                    weighted_sum / weight_total
                } else {
                    0.0
                }
            }
        }
    }
}

/// Dense vector query for similarity search
#[derive(Debug, Clone)]
pub struct DenseVectorQuery {
    /// Field containing the dense vectors
    pub field: Field,
    /// Query vector
    pub vector: Vec<f32>,
    /// Number of clusters to probe (for IVF indexes)
    pub nprobe: usize,
    /// Re-ranking factor (multiplied by the requested limit to size the candidate pool)
    pub rerank_factor: usize,
    /// How to combine scores for multi-valued documents
    pub combiner: MultiValueCombiner,
}

impl DenseVectorQuery {
    /// Create a new dense vector query
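    ///
    /// A construction sketch (field id and vector values are placeholders):
    ///
    /// ```ignore
    /// let query = DenseVectorQuery::new(Field(0), vec![0.1, 0.2, 0.3])
    ///     .with_nprobe(64)       // probe more IVF clusters for higher recall
    ///     .with_rerank_factor(5) // widen the candidate pool before re-ranking
    ///     .with_combiner(MultiValueCombiner::Max);
    /// ```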
    pub fn new(field: Field, vector: Vec<f32>) -> Self {
        Self {
            field,
            vector,
            nprobe: 32,
            rerank_factor: 3,
            combiner: MultiValueCombiner::Max,
        }
    }

    /// Set the number of clusters to probe (for IVF indexes)
    pub fn with_nprobe(mut self, nprobe: usize) -> Self {
        self.nprobe = nprobe;
        self
    }

    /// Set the re-ranking factor
    pub fn with_rerank_factor(mut self, factor: usize) -> Self {
        self.rerank_factor = factor;
        self
    }

    /// Set the multi-value score combiner
    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
        self.combiner = combiner;
        self
    }
}

impl Query for DenseVectorQuery {
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        let field = self.field;
        let vector = self.vector.clone();
        let nprobe = self.nprobe;
        let rerank_factor = self.rerank_factor;
        let combiner = self.combiner;
        Box::pin(async move {
            let results = reader
                .search_dense_vector(field, &vector, limit, nprobe, rerank_factor, combiner)
                .await?;

            Ok(Box::new(DenseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
        })
    }

    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
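        // A vector query has no cheap way to estimate its match count up
        // front, so report the maximum and let callers treat it as unknown.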
        Box::pin(async move { Ok(u32::MAX) })
    }
}

/// Scorer for dense vector search results with ordinal tracking
struct DenseVectorScorer {
    results: Vec<VectorSearchResult>,
    position: usize,
    field_id: u32,
}

impl DenseVectorScorer {
    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
        Self {
            results,
            position: 0,
            field_id,
        }
    }
}

impl Scorer for DenseVectorScorer {
    fn doc(&self) -> DocId {
        if self.position < self.results.len() {
            self.results[self.position].doc_id
        } else {
            TERMINATED
        }
    }

    fn score(&self) -> Score {
        if self.position < self.results.len() {
            self.results[self.position].score
        } else {
            0.0
        }
    }

    fn advance(&mut self) -> DocId {
        self.position += 1;
        self.doc()
    }

    fn seek(&mut self, target: DocId) -> DocId {
        while self.doc() < target && self.doc() != TERMINATED {
            self.advance();
        }
        self.doc()
    }

    fn size_hint(&self) -> u32 {
        // `position` can move past the end via `advance`; saturate to avoid underflow
        self.results.len().saturating_sub(self.position) as u32
    }

    fn matched_positions(&self) -> Option<MatchedPositions> {
        if self.position >= self.results.len() {
            return None;
        }
        let result = &self.results[self.position];
        let scored_positions: Vec<ScoredPosition> = result
            .ordinals
            .iter()
            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
            .collect();
        Some(vec![(self.field_id, scored_positions)])
    }
}
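
// A minimal sketch of how a caller drives any `Scorer` produced above
// (hypothetical driver loop, not part of this module's API):
//
//     while scorer.doc() != TERMINATED {
//         let (doc, score) = (scorer.doc(), scorer.score());
//         // ...collect (doc, score)...
//         scorer.advance();
//     }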

/// Sparse vector query for similarity search
#[derive(Debug, Clone)]
pub struct SparseVectorQuery {
    /// Field containing the sparse vectors
    pub field: Field,
    /// Query vector as (dimension_id, weight) pairs
    pub vector: Vec<(u32, f32)>,
    /// How to combine scores for multi-valued documents
    pub combiner: MultiValueCombiner,
    /// Approximate search factor (1.0 = exact, lower values = faster but approximate)
    /// Controls MaxScore pruning aggressiveness in block-max scoring
    pub heap_factor: f32,
    /// Minimum abs(weight) for query dimensions (0.0 = no filtering)
    /// Dimensions below this threshold are dropped before search.
    pub weight_threshold: f32,
    /// Maximum number of query dimensions to process (None = all)
    /// Keeps only the top-k dimensions by abs(weight).
    pub max_query_dims: Option<usize>,
    /// Fraction of query dimensions to keep (0.0-1.0), same semantics as
    /// indexing-time `pruning`: sort by abs(weight) descending,
    /// keep top fraction. None or 1.0 = no pruning.
    pub pruning: Option<f32>,
}

impl SparseVectorQuery {
    /// Create a new sparse vector query
    ///
    /// Default combiner is `LogSumExp { temperature: 0.7 }` which provides
    /// saturation for documents with many sparse vectors (e.g., 100+ ordinals).
    /// This prevents over-weighting from multiple matches while still allowing
    /// additional matches to contribute to the score.
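    ///
    /// A construction sketch (dimension ids and weights are placeholders):
    ///
    /// ```ignore
    /// let query = SparseVectorQuery::new(Field(1), vec![(101, 0.9), (2048, 0.4)])
    ///     .with_heap_factor(0.9) // trade a little recall for ~20-40% less work
    ///     .with_pruning(0.8);    // keep the top 80% of dims by |weight|
    /// ```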
    pub fn new(field: Field, vector: Vec<(u32, f32)>) -> Self {
        Self {
            field,
            vector,
            combiner: MultiValueCombiner::LogSumExp { temperature: 0.7 },
            heap_factor: 1.0,
            weight_threshold: 0.0,
            max_query_dims: None,
            pruning: None,
        }
    }

    /// Set the multi-value score combiner
    pub fn with_combiner(mut self, combiner: MultiValueCombiner) -> Self {
        self.combiner = combiner;
        self
    }

    /// Set the heap factor for approximate search
    ///
    /// Controls the trade-off between speed and recall:
    /// - 1.0 = exact search (default)
    /// - 0.8-0.9 = ~20-40% faster with minimal recall loss
    /// - Lower values = more aggressive pruning, faster but lower recall
    pub fn with_heap_factor(mut self, heap_factor: f32) -> Self {
        self.heap_factor = heap_factor.clamp(0.0, 1.0);
        self
    }

    /// Set minimum weight threshold for query dimensions
    /// Dimensions with abs(weight) below this are dropped before search.
    pub fn with_weight_threshold(mut self, threshold: f32) -> Self {
        self.weight_threshold = threshold;
        self
    }

    /// Set maximum number of query dimensions (top-k by weight)
    pub fn with_max_query_dims(mut self, max_dims: usize) -> Self {
        self.max_query_dims = Some(max_dims);
        self
    }

    /// Set pruning fraction (0.0-1.0): keep top fraction of query dims by weight.
    /// Same semantics as indexing-time `pruning`.
    pub fn with_pruning(mut self, fraction: f32) -> Self {
        self.pruning = Some(fraction.clamp(0.0, 1.0));
        self
    }

    /// Apply weight_threshold, pruning, and max_query_dims, returning the pruned vector.
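    ///
    /// Worked sketch of the three steps (illustrative values):
    ///
    /// ```ignore
    /// let q = SparseVectorQuery::new(Field(0), vec![(1, 0.9), (2, 0.5), (3, 0.1), (4, 0.05)])
    ///     .with_weight_threshold(0.08) // step 1: drops (4, 0.05)        -> 3 dims
    ///     .with_pruning(0.5);          // step 2: keep ceil(3 * 0.5) = 2 -> [(1, 0.9), (2, 0.5)]
    /// // step 3 (max_query_dims) is None here, so nothing more is cut
    /// assert_eq!(q.pruned_vector(), vec![(1, 0.9), (2, 0.5)]);
    /// ```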
    fn pruned_vector(&self) -> Vec<(u32, f32)> {
        let original_len = self.vector.len();

        // Step 1: weight_threshold — drop dimensions below minimum weight
        let mut v: Vec<(u32, f32)> = if self.weight_threshold > 0.0 {
            self.vector
                .iter()
                .copied()
                .filter(|(_, w)| w.abs() >= self.weight_threshold)
                .collect()
        } else {
            self.vector.clone()
        };
        let after_threshold = v.len();

        // Step 2: pruning — keep top fraction by abs(weight), same as indexing
        if let Some(fraction) = self.pruning
            && fraction < 1.0
            && v.len() > 1
        {
            v.sort_unstable_by(|a, b| {
                b.1.abs()
                    .partial_cmp(&a.1.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            let keep = ((v.len() as f64 * fraction as f64).ceil() as usize).max(1);
            v.truncate(keep);
        }
        let after_pruning = v.len();

        // Step 3: max_query_dims — absolute cap on dimensions
        if let Some(max_dims) = self.max_query_dims
            && v.len() > max_dims
        {
            // Sort by descending abs(weight), keep top-k
            v.sort_unstable_by(|a, b| {
                b.1.abs()
                    .partial_cmp(&a.1.abs())
                    .unwrap_or(std::cmp::Ordering::Equal)
            });
            v.truncate(max_dims);
        }

        if v.len() < original_len {
            log::debug!(
                "[sparse query] field={}: pruned {}->{} dims \
                 (threshold: {}->{}, pruning: {}->{}, max_dims: {}->{})",
                self.field.0,
                original_len,
                v.len(),
                original_len,
                after_threshold,
                after_threshold,
                after_pruning,
                after_pruning,
                v.len(),
            );
            if log::log_enabled!(log::Level::Trace) {
                for (dim, w) in &v {
                    log::trace!("  dim={}, weight={:.4}", dim, w);
                }
            }
        }

        v
    }

    /// Create from separate indices and weights vectors
    pub fn from_indices_weights(field: Field, indices: Vec<u32>, weights: Vec<f32>) -> Self {
        let vector: Vec<(u32, f32)> = indices.into_iter().zip(weights).collect();
        Self::new(field, vector)
    }

    /// Create from raw text using a HuggingFace tokenizer (single segment)
    ///
    /// This method tokenizes the text and creates a sparse vector query.
    /// For multi-segment indexes, use `from_text_with_stats` instead.
    ///
    /// # Arguments
    /// * `field` - The sparse vector field to search
    /// * `text` - Raw text to tokenize
    /// * `tokenizer_name` - HuggingFace tokenizer path (e.g., "bert-base-uncased")
    /// * `weighting` - Weighting strategy for tokens
    /// * `sparse_index` - Optional sparse index for IDF lookup (required for IDF weighting)
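    ///
    /// A usage sketch (`ignore`d: requires the `native` feature and a
    /// tokenizer that is resolvable at runtime):
    ///
    /// ```ignore
    /// let query = SparseVectorQuery::from_text(
    ///     Field(1),
    ///     "what is dense retrieval?",
    ///     "bert-base-uncased",
    ///     QueryWeighting::One,
    ///     None, // no sparse index needed when every token gets weight 1.0
    /// )?;
    /// ```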
    #[cfg(feature = "native")]
    pub fn from_text(
        field: Field,
        text: &str,
        tokenizer_name: &str,
        weighting: crate::structures::QueryWeighting,
        sparse_index: Option<&crate::segment::SparseIndex>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::tokenizer_cache;

        let tokenizer = tokenizer_cache().get_or_load(tokenizer_name)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(index) = sparse_index {
                    index.idf_weights(&token_ids)
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                use crate::tokenizer::idf_weights_cache;
                if let Some(idf) = idf_weights_cache().get_or_load(tokenizer_name) {
                    token_ids.iter().map(|&id| idf.get(id)).collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Create from raw text using global statistics (multi-segment)
    ///
    /// This is the recommended method for multi-segment indexes as it uses
    /// aggregated IDF values across all segments for consistent ranking.
    ///
    /// # Arguments
    /// * `field` - The sparse vector field to search
    /// * `text` - Raw text to tokenize
    /// * `tokenizer` - Pre-loaded HuggingFace tokenizer
    /// * `weighting` - Weighting strategy for tokens
    /// * `global_stats` - Global statistics for IDF computation
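    ///
    /// A usage sketch (`tokenizer` and `global_stats` are caller-provided
    /// placeholders):
    ///
    /// ```ignore
    /// let query = SparseVectorQuery::from_text_with_stats(
    ///     Field(1),
    ///     "what is dense retrieval?",
    ///     &tokenizer,          // pre-loaded HfTokenizer
    ///     QueryWeighting::Idf,
    ///     Some(&global_stats), // aggregated IDF across all segments
    /// )?;
    /// ```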
    #[cfg(feature = "native")]
    pub fn from_text_with_stats(
        field: Field,
        text: &str,
        tokenizer: &crate::tokenizer::HfTokenizer,
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;

        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    // Clamp to zero: negative weights don't make sense for IDF
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // IdfFile requires a tokenizer name for HF model lookup;
                // this code path doesn't have one, so fall back to 1.0
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }

    /// Create from raw text, loading tokenizer from index directory
    ///
    /// This method supports the `index://` prefix for tokenizer paths,
    /// loading tokenizer.json from the index directory.
    ///
    /// # Arguments
    /// * `field` - The sparse vector field to search
    /// * `text` - Raw text to tokenize
    /// * `tokenizer_bytes` - Tokenizer JSON bytes (pre-loaded from directory)
    /// * `weighting` - Weighting strategy for tokens
    /// * `global_stats` - Global statistics for IDF computation
    #[cfg(feature = "native")]
    pub fn from_text_with_tokenizer_bytes(
        field: Field,
        text: &str,
        tokenizer_bytes: &[u8],
        weighting: crate::structures::QueryWeighting,
        global_stats: Option<&super::GlobalStats>,
    ) -> crate::Result<Self> {
        use crate::structures::QueryWeighting;
        use crate::tokenizer::HfTokenizer;

        let tokenizer = HfTokenizer::from_bytes(tokenizer_bytes)?;
        let token_ids = tokenizer.tokenize_unique(text)?;

        let weights: Vec<f32> = match weighting {
            QueryWeighting::One => vec![1.0f32; token_ids.len()],
            QueryWeighting::Idf => {
                if let Some(stats) = global_stats {
                    // Clamp to zero: negative weights don't make sense for IDF
                    stats
                        .sparse_idf_weights(field, &token_ids)
                        .into_iter()
                        .map(|w| w.max(0.0))
                        .collect()
                } else {
                    vec![1.0f32; token_ids.len()]
                }
            }
            QueryWeighting::IdfFile => {
                // IdfFile requires a tokenizer name for HF model lookup;
                // this code path doesn't have one, so fall back to 1.0
                vec![1.0f32; token_ids.len()]
            }
        };

        let vector: Vec<(u32, f32)> = token_ids.into_iter().zip(weights).collect();
        Ok(Self::new(field, vector))
    }
}

impl Query for SparseVectorQuery {
    fn scorer<'a>(&self, reader: &'a SegmentReader, limit: usize) -> ScorerFuture<'a> {
        let field = self.field;
        let vector = self.pruned_vector();
        let combiner = self.combiner;
        let heap_factor = self.heap_factor;
        Box::pin(async move {
            let results = reader
                .search_sparse_vector(field, &vector, limit, combiner, heap_factor)
                .await?;

            Ok(Box::new(SparseVectorScorer::new(results, field.0)) as Box<dyn Scorer>)
        })
    }

    fn count_estimate<'a>(&self, _reader: &'a SegmentReader) -> CountFuture<'a> {
        Box::pin(async move { Ok(u32::MAX) })
    }
}

/// Scorer for sparse vector search results with ordinal tracking
struct SparseVectorScorer {
    results: Vec<VectorSearchResult>,
    position: usize,
    field_id: u32,
}

impl SparseVectorScorer {
    fn new(results: Vec<VectorSearchResult>, field_id: u32) -> Self {
        Self {
            results,
            position: 0,
            field_id,
        }
    }
}

impl Scorer for SparseVectorScorer {
    fn doc(&self) -> DocId {
        if self.position < self.results.len() {
            self.results[self.position].doc_id
        } else {
            TERMINATED
        }
    }

    fn score(&self) -> Score {
        if self.position < self.results.len() {
            self.results[self.position].score
        } else {
            0.0
        }
    }

    fn advance(&mut self) -> DocId {
        self.position += 1;
        self.doc()
    }

    fn seek(&mut self, target: DocId) -> DocId {
        while self.doc() < target && self.doc() != TERMINATED {
            self.advance();
        }
        self.doc()
    }

    fn size_hint(&self) -> u32 {
        // `position` can move past the end via `advance`; saturate to avoid underflow
        self.results.len().saturating_sub(self.position) as u32
    }

    fn matched_positions(&self) -> Option<MatchedPositions> {
        if self.position >= self.results.len() {
            return None;
        }
        let result = &self.results[self.position];
        let scored_positions: Vec<ScoredPosition> = result
            .ordinals
            .iter()
            .map(|(ordinal, score)| ScoredPosition::new(*ordinal, *score))
            .collect();
        Some(vec![(self.field_id, scored_positions)])
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::dsl::Field;

    #[test]
    fn test_dense_vector_query_builder() {
        let query = DenseVectorQuery::new(Field(0), vec![1.0, 2.0, 3.0])
            .with_nprobe(64)
            .with_rerank_factor(5);

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector.len(), 3);
        assert_eq!(query.nprobe, 64);
        assert_eq!(query.rerank_factor, 5);
    }

    #[test]
    fn test_sparse_vector_query_new() {
        let sparse = vec![(1, 0.5), (5, 0.3), (10, 0.2)];
        let query = SparseVectorQuery::new(Field(0), sparse.clone());

        assert_eq!(query.field, Field(0));
        assert_eq!(query.vector, sparse);
    }

    #[test]
    fn test_sparse_vector_query_from_indices_weights() {
        let query =
            SparseVectorQuery::from_indices_weights(Field(0), vec![1, 5, 10], vec![0.5, 0.3, 0.2]);

        assert_eq!(query.vector, vec![(1, 0.5), (5, 0.3), (10, 0.2)]);
    }

    #[test]
    fn test_combiner_sum() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Sum;
        assert!((combiner.combine(&scores) - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores = vec![(0, 1.0), (1, 3.0), (2, 2.0)];
        let combiner = MultiValueCombiner::Max;
        assert!((combiner.combine(&scores) - 3.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::Avg;
        assert!((combiner.combine(&scores) - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores = vec![(0, 1.0), (1, 2.0), (2, 3.0)];
        let combiner = MultiValueCombiner::log_sum_exp();
        let result = combiner.combine(&scores);
        // LogSumExp should be between max (3.0) and max + log(n)/t
        assert!(result >= 3.0);
        assert!(result <= 3.0 + (3.0_f32).ln() / 1.5);
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores = vec![(0, 1.0), (1, 5.0), (2, 2.0)];
        // High temperature should approach max
        let combiner = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        let result = combiner.combine(&scores);
        // Should be very close to max (5.0)
        assert!((result - 5.0).abs() < 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores = vec![(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let result = combiner.combine(&scores);
        // Top 3: 5.0, 3.0, 1.0 with weights 1.0, 0.5, 0.25
        // weighted_sum = 5*1 + 3*0.5 + 1*0.25 = 6.75
        // weight_total = 1.75
        // result = 6.75 / 1.75 ≈ 3.857
        assert!((result - 3.857).abs() < 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores = vec![(0, 2.0), (1, 1.0)];
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        let result = combiner.combine(&scores);
        // Only 2 scores, weights 1.0 and 0.7
        // weighted_sum = 2*1 + 1*0.7 = 2.7
        // weight_total = 1.7
        // result = 2.7 / 1.7 ≈ 1.588
        assert!((result - 1.588).abs() < 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: Vec<(u32, f32)> = vec![];
        assert_eq!(MultiValueCombiner::Sum.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Max.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::Avg.combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::log_sum_exp().combine(&scores), 0.0);
        assert_eq!(MultiValueCombiner::weighted_top_k().combine(&scores), 0.0);
    }

    #[test]
    fn test_combiner_single_score() {
        let scores = vec![(0, 5.0)];
        // All combiners should return 5.0 for a single score
        assert!((MultiValueCombiner::Sum.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Max.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::Avg.combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::log_sum_exp().combine(&scores) - 5.0).abs() < 1e-6);
        assert!((MultiValueCombiner::weighted_top_k().combine(&scores) - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        let combiner = MultiValueCombiner::default();
        match combiner {
            MultiValueCombiner::LogSumExp { temperature } => {
                assert!((temperature - 1.5).abs() < 1e-6);
            }
            _ => panic!("Default combiner should be LogSumExp"),
        }
    }
}