1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum RerankingMode {
8 Full,
10 TopK,
12 Adaptive,
14 Disabled,
16}
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum FusionStrategy {
21 RerankingOnly,
23 RetrievalOnly,
25 Linear,
27 ReciprocalRank,
29 Learned,
31 Harmonic,
33 Geometric,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct RerankingConfig {
40 pub mode: RerankingMode,
42
43 pub max_candidates: usize,
45
46 pub top_k: usize,
48
49 pub fusion_strategy: FusionStrategy,
51
52 pub retrieval_weight: f32,
54
55 pub batch_size: usize,
57
58 pub timeout_ms: Option<u64>,
60
61 pub enable_caching: bool,
63
64 pub cache_size: usize,
66
67 pub enable_diversity: bool,
69
70 pub diversity_weight: f32,
72
73 pub model_name: String,
75
76 pub model_backend: String,
78
79 pub enable_parallel: bool,
81
82 pub num_workers: usize,
84}
85
86impl RerankingConfig {
87 pub fn default_config() -> Self {
89 Self {
90 mode: RerankingMode::TopK,
91 max_candidates: 100,
92 top_k: 10,
93 fusion_strategy: FusionStrategy::Linear,
94 retrieval_weight: 0.3,
95 batch_size: 32,
96 timeout_ms: Some(5000),
97 enable_caching: true,
98 cache_size: 1000,
99 enable_diversity: false,
100 diversity_weight: 0.2,
101 model_name: "cross-encoder/ms-marco-MiniLM-L-12-v2".to_string(),
102 model_backend: "local".to_string(),
103 enable_parallel: true,
104 num_workers: 4,
105 }
106 }
107
108 pub fn accuracy_optimized() -> Self {
110 Self {
111 mode: RerankingMode::Full,
112 max_candidates: 200,
113 top_k: 10,
114 fusion_strategy: FusionStrategy::RerankingOnly,
115 retrieval_weight: 0.0,
116 batch_size: 16,
117 timeout_ms: Some(10000),
118 enable_caching: true,
119 cache_size: 2000,
120 enable_diversity: true,
121 diversity_weight: 0.3,
122 model_name: "cross-encoder/ms-marco-TinyBERT-L-6-v2".to_string(),
123 model_backend: "local".to_string(),
124 enable_parallel: true,
125 num_workers: 8,
126 }
127 }
128
129 pub fn speed_optimized() -> Self {
131 Self {
132 mode: RerankingMode::TopK,
133 max_candidates: 50,
134 top_k: 10,
135 fusion_strategy: FusionStrategy::Linear,
136 retrieval_weight: 0.5,
137 batch_size: 64,
138 timeout_ms: Some(2000),
139 enable_caching: true,
140 cache_size: 500,
141 enable_diversity: false,
142 diversity_weight: 0.0,
143 model_name: "cross-encoder/ms-marco-MiniLM-L-2-v2".to_string(),
144 model_backend: "local".to_string(),
145 enable_parallel: true,
146 num_workers: 2,
147 }
148 }
149
150 pub fn api_based(api_backend: impl Into<String>) -> Self {
152 Self {
153 mode: RerankingMode::TopK,
154 max_candidates: 100,
155 top_k: 10,
156 fusion_strategy: FusionStrategy::Linear,
157 retrieval_weight: 0.3,
158 batch_size: 16,
159 timeout_ms: Some(30000), enable_caching: true,
161 cache_size: 5000, enable_diversity: false,
163 diversity_weight: 0.2,
164 model_name: "rerank-v2".to_string(),
165 model_backend: api_backend.into(),
166 enable_parallel: false, num_workers: 1,
168 }
169 }
170
171 pub fn validate(&self) -> Result<(), String> {
173 if self.max_candidates == 0 {
174 return Err("max_candidates must be greater than 0".to_string());
175 }
176
177 if self.top_k == 0 {
178 return Err("top_k must be greater than 0".to_string());
179 }
180
181 if self.top_k > self.max_candidates {
182 return Err("top_k cannot exceed max_candidates".to_string());
183 }
184
185 if self.retrieval_weight < 0.0 || self.retrieval_weight > 1.0 {
186 return Err("retrieval_weight must be between 0.0 and 1.0".to_string());
187 }
188
189 if self.diversity_weight < 0.0 || self.diversity_weight > 1.0 {
190 return Err("diversity_weight must be between 0.0 and 1.0".to_string());
191 }
192
193 if self.batch_size == 0 {
194 return Err("batch_size must be greater than 0".to_string());
195 }
196
197 if self.cache_size == 0 && self.enable_caching {
198 return Err("cache_size must be greater than 0 when caching is enabled".to_string());
199 }
200
201 if self.num_workers == 0 && self.enable_parallel {
202 return Err(
203 "num_workers must be greater than 0 when parallel processing is enabled"
204 .to_string(),
205 );
206 }
207
208 if self.model_name.is_empty() {
209 return Err("model_name cannot be empty".to_string());
210 }
211
212 Ok(())
213 }
214
215 pub fn reranking_weight(&self) -> f32 {
217 1.0 - self.retrieval_weight
218 }
219}
220
221impl Default for RerankingConfig {
222 fn default() -> Self {
223 Self::default_config()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    // Lets a test return `Ok(())`; any boxed error would fail the test.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    // Starts from the default preset, applies `tweak`, and reports whether
    // validation accepts the outcome.
    fn valid_after(tweak: impl FnOnce(&mut RerankingConfig)) -> bool {
        let mut cfg = RerankingConfig::default_config();
        tweak(&mut cfg);
        cfg.validate().is_ok()
    }

    #[test]
    fn test_default_config() {
        let cfg = RerankingConfig::default_config();

        assert_eq!(cfg.mode, RerankingMode::TopK);
        assert_eq!(cfg.max_candidates, 100);
        assert_eq!(cfg.top_k, 10);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_accuracy_optimized() {
        let cfg = RerankingConfig::accuracy_optimized();

        assert_eq!(cfg.mode, RerankingMode::Full);
        assert_eq!(cfg.fusion_strategy, FusionStrategy::RerankingOnly);
        assert!(cfg.enable_diversity);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_speed_optimized() {
        let cfg = RerankingConfig::speed_optimized();

        assert_eq!(cfg.max_candidates, 50);
        // Batches are larger than the default preset's 32.
        assert!(cfg.batch_size > 32);
        assert!(!cfg.enable_diversity);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_api_based() -> Result<()> {
        let cfg = RerankingConfig::api_based("cohere");

        assert_eq!(cfg.model_backend, "cohere");
        let timeout = cfg.timeout_ms.expect("test value");
        assert!(timeout > 10000);
        assert!(cfg.cache_size > 1000);
        assert!(cfg.validate().is_ok());
        Ok(())
    }

    #[test]
    fn test_validation() {
        // Untouched defaults pass.
        assert!(valid_after(|_| {}));

        // Each case breaks exactly one rule, starting from fresh defaults.
        assert!(!valid_after(|c| c.max_candidates = 0));
        assert!(!valid_after(|c| c.top_k = 0));
        assert!(!valid_after(|c| {
            c.top_k = 200;
            c.max_candidates = 100;
        }));
        assert!(!valid_after(|c| c.retrieval_weight = 1.5));
        assert!(!valid_after(|c| c.model_name = "".to_string()));
    }

    #[test]
    fn test_reranking_weight() {
        let mut cfg = RerankingConfig::default_config();

        cfg.retrieval_weight = 0.3;
        let delta = cfg.reranking_weight() - 0.7;
        assert!(delta.abs() < 0.001);

        cfg.retrieval_weight = 0.0;
        assert_eq!(cfg.reranking_weight(), 1.0);

        cfg.retrieval_weight = 1.0;
        assert_eq!(cfg.reranking_weight(), 0.0);
    }

    #[test]
    fn test_fusion_strategies() {
        let all = [
            FusionStrategy::RerankingOnly,
            FusionStrategy::RetrievalOnly,
            FusionStrategy::Linear,
            FusionStrategy::ReciprocalRank,
            FusionStrategy::Learned,
            FusionStrategy::Harmonic,
            FusionStrategy::Geometric,
        ];

        // Validation accepts every strategy.
        for &strategy in &all {
            assert!(valid_after(|c| c.fusion_strategy = strategy));
        }
    }

    #[test]
    fn test_reranking_modes() {
        let all = [
            RerankingMode::Full,
            RerankingMode::TopK,
            RerankingMode::Adaptive,
            RerankingMode::Disabled,
        ];

        // Validation accepts every mode.
        for &mode in &all {
            assert!(valid_after(|c| c.mode = mode));
        }
    }
}