Skip to main content

hermes_core/query/vector/
combiner.rs

1//! Multi-value score combination strategies for vector search
2
/// Strategy for combining scores when a document has multiple values for the same field.
///
/// The strategy is applied by [`MultiValueCombiner::combine`]; an empty score list
/// always combines to `0.0`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum all scores (accumulates dot product contributions)
    Sum,
    /// Take the maximum score
    Max,
    /// Take the average score
    Avg,
    /// Log-Sum-Exp: smooth maximum approximation (default)
    /// `score = (1/t) * log(Σ exp(t * sᵢ))`
    ///
    /// For `t > 0` and `n` scores the result lies in `[max, max + ln(n)/t]`:
    /// a higher temperature stays closer to the max, a lower one rewards
    /// additional matches more strongly. A negative temperature yields the
    /// mirrored smooth minimum. At `t == 0` the formula is undefined, so the
    /// plain average is used instead.
    LogSumExp {
        /// Temperature parameter (default: 1.5)
        temperature: f32,
    },
    /// Weighted Top-K: weighted average of the top scores with exponential decay
    /// `score = (Σ wᵢ * sorted_scores[i]) / (Σ wᵢ)` where `wᵢ = decay^i` and
    /// `sorted_scores` holds the `k` highest scores in descending order.
    WeightedTopK {
        /// Number of top scores to consider (default: 5)
        k: usize,
        /// Decay factor per rank (default: 0.7)
        decay: f32,
    },
}

impl Default for MultiValueCombiner {
    /// Returns `LogSumExp` with the default temperature (1.5).
    fn default() -> Self {
        // LogSumExp with temperature 1.5 provides good balance between
        // max (best relevance) and sum (saturation from multiple matches)
        Self::log_sum_exp()
    }
}

impl MultiValueCombiner {
    /// Temperature used by [`Self::log_sum_exp`] and by `Default`.
    const DEFAULT_TEMPERATURE: f32 = 1.5;
    /// Number of scores kept by [`Self::weighted_top_k`].
    const DEFAULT_TOP_K: usize = 5;
    /// Per-rank decay used by [`Self::weighted_top_k`].
    const DEFAULT_DECAY: f32 = 0.7;

    /// Create LogSumExp combiner with default temperature (1.5)
    pub fn log_sum_exp() -> Self {
        Self::LogSumExp {
            temperature: Self::DEFAULT_TEMPERATURE,
        }
    }

    /// Create LogSumExp combiner with custom temperature
    pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
        Self::LogSumExp { temperature }
    }

    /// Create WeightedTopK combiner with defaults (k=5, decay=0.7)
    pub fn weighted_top_k() -> Self {
        Self::WeightedTopK {
            k: Self::DEFAULT_TOP_K,
            decay: Self::DEFAULT_DECAY,
        }
    }

    /// Create WeightedTopK combiner with custom parameters
    pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
        Self::WeightedTopK { k, decay }
    }

    /// Combine multiple scores into a single score.
    ///
    /// `scores` is a list of `(id, score)` pairs; only the score component is
    /// used, the first component is ignored. Returns `0.0` for an empty slice,
    /// otherwise the value defined by the selected variant.
    pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
        if scores.is_empty() {
            return 0.0;
        }

        match *self {
            Self::Sum => scores.iter().map(|&(_, s)| s).sum(),
            Self::Max => Self::max_score(scores),
            Self::Avg => Self::mean_score(scores),
            Self::LogSumExp { temperature } => Self::smooth_max(scores, temperature),
            Self::WeightedTopK { k, decay } => Self::top_k_weighted_mean(scores, k, decay),
        }
    }

    /// Largest score in `scores` (`0.0` if empty); incomparable pairs count as equal.
    fn max_score(scores: &[(u32, f32)]) -> f32 {
        scores
            .iter()
            .map(|&(_, s)| s)
            .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0)
    }

    /// Smallest score in `scores` (`0.0` if empty); incomparable pairs count as equal.
    fn min_score(scores: &[(u32, f32)]) -> f32 {
        scores
            .iter()
            .map(|&(_, s)| s)
            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .unwrap_or(0.0)
    }

    /// Arithmetic mean of the scores; callers guarantee `scores` is non-empty.
    fn mean_score(scores: &[(u32, f32)]) -> f32 {
        let total: f32 = scores.iter().map(|&(_, s)| s).sum();
        total / scores.len() as f32
    }

    /// Numerically stable `(1/t) * log(Σ exp(t * sᵢ))`.
    fn smooth_max(scores: &[(u32, f32)], t: f32) -> f32 {
        // The formula divides by `t`; at exactly zero it would produce NaN or
        // infinity, so fall back to the average instead of a non-finite score.
        if t == 0.0 {
            return Self::mean_score(scores);
        }

        // LSE(x) = p + log(Σ exp(t * (xᵢ - p))) / t holds for any pivot `p`.
        // Pick the extremum that keeps every exponent <= 0 so `exp` cannot
        // overflow: the max for positive `t`, the min for negative `t`.
        let pivot = if t > 0.0 {
            Self::max_score(scores)
        } else {
            Self::min_score(scores)
        };

        let sum_exp: f32 = scores
            .iter()
            .map(|&(_, s)| (t * (s - pivot)).exp())
            .sum();

        pivot + sum_exp.ln() / t
    }

    /// Weighted mean of the `k` highest scores with weights `decay^rank`.
    fn top_k_weighted_mean(scores: &[(u32, f32)], k: usize, decay: f32) -> f32 {
        let mut ranked: Vec<f32> = scores.iter().map(|&(_, s)| s).collect();
        // `total_cmp` is a total order even when NaN is present; a comparator
        // built on `partial_cmp` is not, and sorting with it may panic.
        ranked.sort_unstable_by(|a, b| b.total_cmp(a));

        // Apply exponential decay weights to the best `k` scores
        let mut weight = 1.0f32;
        let mut weighted_sum = 0.0f32;
        let mut weight_total = 0.0f32;

        for score in ranked.into_iter().take(k) {
            weighted_sum += weight * score;
            weight_total += weight;
            weight *= decay;
        }

        // `k == 0` (or weights that cancel out) leaves nothing to normalise by.
        if weight_total > 0.0 {
            weighted_sum / weight_total
        } else {
            0.0
        }
    }
}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    /// One combiner of every variant, using the default parameters.
    fn every_combiner() -> [MultiValueCombiner; 5] {
        [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ]
    }

    #[test]
    fn test_combiner_sum() {
        let scores: &[(u32, f32)] = &[(0, 1.0), (1, 2.0), (2, 3.0)];
        assert_close(MultiValueCombiner::Sum.combine(scores), 6.0, 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let scores: &[(u32, f32)] = &[(0, 1.0), (1, 3.0), (2, 2.0)];
        assert_close(MultiValueCombiner::Max.combine(scores), 3.0, 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let scores: &[(u32, f32)] = &[(0, 1.0), (1, 2.0), (2, 3.0)];
        assert_close(MultiValueCombiner::Avg.combine(scores), 2.0, 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        let scores: &[(u32, f32)] = &[(0, 1.0), (1, 2.0), (2, 3.0)];
        let result = MultiValueCombiner::log_sum_exp().combine(scores);
        // The result is bounded below by the max (3.0) and above by max + ln(n)/t.
        let upper_bound = 3.0 + 3.0_f32.ln() / 1.5;
        assert!((3.0..=upper_bound).contains(&result));
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        let scores: &[(u32, f32)] = &[(0, 1.0), (1, 5.0), (2, 2.0)];
        // A temperature of 10 should land within 0.5 of the max (5.0).
        let result = MultiValueCombiner::log_sum_exp_with_temperature(10.0).combine(scores);
        assert_close(result, 5.0, 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        let scores: &[(u32, f32)] = &[(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)];
        let result = MultiValueCombiner::weighted_top_k_with_params(3, 0.5).combine(scores);
        // Top 3 are 5.0, 3.0, 1.0 with weights 1.0, 0.5, 0.25:
        //   numerator   = 5*1 + 3*0.5 + 1*0.25 = 6.75
        //   denominator = 1 + 0.5 + 0.25       = 1.75
        //   result      = 6.75 / 1.75          ≈ 3.857
        assert_close(result, 3.857, 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        let scores: &[(u32, f32)] = &[(0, 2.0), (1, 1.0)];
        let result = MultiValueCombiner::weighted_top_k_with_params(5, 0.7).combine(scores);
        // Only two scores are available, so the weights are 1.0 and 0.7:
        //   numerator   = 2*1 + 1*0.7 = 2.7
        //   denominator = 1 + 0.7     = 1.7
        //   result      = 2.7 / 1.7   ≈ 1.588
        assert_close(result, 1.588, 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let scores: &[(u32, f32)] = &[];
        for combiner in every_combiner().iter() {
            assert_eq!(combiner.combine(scores), 0.0);
        }
    }

    #[test]
    fn test_combiner_single_score() {
        let scores: &[(u32, f32)] = &[(0, 5.0)];
        // With a single score every strategy should return that score.
        for combiner in every_combiner().iter() {
            assert_close(combiner.combine(scores), 5.0, 1e-6);
        }
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        if let MultiValueCombiner::LogSumExp { temperature } = MultiValueCombiner::default() {
            assert_close(temperature, 1.5, 1e-6);
        } else {
            panic!("Default combiner should be LogSumExp");
        }
    }
}