// hermes_core/query/vector/combiner.rs

/// Strategy for collapsing several `(u32, f32)` scores into a single `f32` score.
///
/// Presumably one score per value of a multi-valued vector field (TODO confirm with callers);
/// only the `f32` part of each pair is used by the strategies.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MultiValueCombiner {
    /// Sum of all scores.
    Sum,
    /// Largest score.
    Max,
    /// Arithmetic mean of all scores.
    Avg,
    /// Smooth maximum: `(1 / temperature) · ln Σ exp(temperature · score)`.
    ///
    /// For `n` scores and a positive temperature the result lies between the maximum score
    /// and `max + ln(n) / temperature`. Larger temperatures therefore move it closer to the
    /// plain maximum, and adding a value never lowers the result. This is the default
    /// combiner, with `temperature = 1.5`.
    LogSumExp {
        /// Sharpness of the smooth maximum; expected to be positive (larger is closer to `Max`).
        temperature: f32,
    },
    /// Weighted mean of the `k` highest scores, where the `i`-th highest (0-based) has
    /// weight `decay^i`.
    WeightedTopK {
        /// Number of highest scores that contribute; `0` yields a combined score of `0.0`.
        k: usize,
        /// Multiplicative weight factor applied between consecutive ranks.
        decay: f32,
    },
}

29impl Default for MultiValueCombiner {
30 fn default() -> Self {
31 MultiValueCombiner::LogSumExp { temperature: 1.5 }
34 }
35}
36
37impl MultiValueCombiner {
38 pub fn log_sum_exp() -> Self {
40 Self::LogSumExp { temperature: 1.5 }
41 }
42
43 pub fn log_sum_exp_with_temperature(temperature: f32) -> Self {
45 Self::LogSumExp { temperature }
46 }
47
48 pub fn weighted_top_k() -> Self {
50 Self::WeightedTopK { k: 5, decay: 0.7 }
51 }
52
53 pub fn weighted_top_k_with_params(k: usize, decay: f32) -> Self {
55 Self::WeightedTopK { k, decay }
56 }
57
58 pub fn combine(&self, scores: &[(u32, f32)]) -> f32 {
60 if scores.is_empty() {
61 return 0.0;
62 }
63
64 match self {
65 MultiValueCombiner::Sum => scores.iter().map(|(_, s)| s).sum(),
66 MultiValueCombiner::Max => scores
67 .iter()
68 .map(|(_, s)| *s)
69 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
70 .unwrap_or(0.0),
71 MultiValueCombiner::Avg => {
72 let sum: f32 = scores.iter().map(|(_, s)| s).sum();
73 sum / scores.len() as f32
74 }
75 MultiValueCombiner::LogSumExp { temperature } => {
76 let t = *temperature;
79 let max_score = scores
80 .iter()
81 .map(|(_, s)| *s)
82 .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
83 .unwrap_or(0.0);
84
85 let sum_exp: f32 = scores
86 .iter()
87 .map(|(_, s)| (t * (s - max_score)).exp())
88 .sum();
89
90 max_score + sum_exp.ln() / t
91 }
92 MultiValueCombiner::WeightedTopK { k, decay } => {
93 let mut sorted: Vec<f32> = scores.iter().map(|(_, s)| *s).collect();
95 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
96 sorted.truncate(*k);
97
98 let mut weight = 1.0f32;
100 let mut weighted_sum = 0.0f32;
101 let mut weight_total = 0.0f32;
102
103 for score in sorted {
104 weighted_sum += weight * score;
105 weight_total += weight;
106 weight *= decay;
107 }
108
109 if weight_total > 0.0 {
110 weighted_sum / weight_total
111 } else {
112 0.0
113 }
114 }
115 }
116 }
117}
118
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the calling test unless `actual` lies strictly within `tol` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    #[test]
    fn test_combiner_sum() {
        let total = MultiValueCombiner::Sum.combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        assert_close(total, 6.0, 1e-6);
    }

    #[test]
    fn test_combiner_max() {
        let best = MultiValueCombiner::Max.combine(&[(0, 1.0), (1, 3.0), (2, 2.0)]);
        assert_close(best, 3.0, 1e-6);
    }

    #[test]
    fn test_combiner_avg() {
        let mean = MultiValueCombiner::Avg.combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        assert_close(mean, 2.0, 1e-6);
    }

    #[test]
    fn test_combiner_log_sum_exp() {
        // Log-sum-exp is bounded below by the maximum (3.0) and above by max + ln(n) / t.
        let soft_max =
            MultiValueCombiner::log_sum_exp().combine(&[(0, 1.0), (1, 2.0), (2, 3.0)]);
        let upper = 3.0 + 3.0_f32.ln() / 1.5;
        assert!((3.0..=upper).contains(&soft_max));
    }

    #[test]
    fn test_combiner_log_sum_exp_approaches_max_with_high_temp() {
        // A large temperature pushes log-sum-exp towards the plain maximum (5.0).
        let sharp = MultiValueCombiner::log_sum_exp_with_temperature(10.0);
        assert_close(sharp.combine(&[(0, 1.0), (1, 5.0), (2, 2.0)]), 5.0, 0.5);
    }

    #[test]
    fn test_combiner_weighted_top_k() {
        // Top 3 = [5, 3, 1] with weights [1, 0.5, 0.25]:
        // (5 + 1.5 + 0.25) / (1 + 0.5 + 0.25) = 6.75 / 1.75 ≈ 3.857
        let combiner = MultiValueCombiner::weighted_top_k_with_params(3, 0.5);
        let weighted = combiner.combine(&[(0, 5.0), (1, 3.0), (2, 1.0), (3, 0.5)]);
        assert_close(weighted, 3.857, 0.01);
    }

    #[test]
    fn test_combiner_weighted_top_k_less_than_k() {
        // Only two scores for k = 5, weights [1, 0.7]: (2 + 0.7) / 1.7 ≈ 1.588
        let combiner = MultiValueCombiner::weighted_top_k_with_params(5, 0.7);
        assert_close(combiner.combine(&[(0, 2.0), (1, 1.0)]), 1.588, 0.01);
    }

    #[test]
    fn test_combiner_empty_scores() {
        let none: &[(u32, f32)] = &[];
        for combiner in [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ] {
            assert_eq!(combiner.combine(none), 0.0);
        }
    }

    #[test]
    fn test_combiner_single_score() {
        let only: [(u32, f32); 1] = [(0, 5.0)];
        for combiner in [
            MultiValueCombiner::Sum,
            MultiValueCombiner::Max,
            MultiValueCombiner::Avg,
            MultiValueCombiner::log_sum_exp(),
            MultiValueCombiner::weighted_top_k(),
        ] {
            assert_close(combiner.combine(&only), 5.0, 1e-6);
        }
    }

    #[test]
    fn test_default_combiner_is_log_sum_exp() {
        if let MultiValueCombiner::LogSumExp { temperature } = MultiValueCombiner::default() {
            assert_close(temperature, 1.5, 1e-6);
        } else {
            panic!("Default combiner should be LogSumExp");
        }
    }
}