1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum RerankingMode {
8 Full,
10 TopK,
12 Adaptive,
14 Disabled,
16}
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum FusionStrategy {
21 RerankingOnly,
23 RetrievalOnly,
25 Linear,
27 ReciprocalRank,
29 Learned,
31 Harmonic,
33 Geometric,
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct RerankingConfig {
40 pub mode: RerankingMode,
42
43 pub max_candidates: usize,
45
46 pub top_k: usize,
48
49 pub fusion_strategy: FusionStrategy,
51
52 pub retrieval_weight: f32,
54
55 pub batch_size: usize,
57
58 pub timeout_ms: Option<u64>,
60
61 pub enable_caching: bool,
63
64 pub cache_size: usize,
66
67 pub enable_diversity: bool,
69
70 pub diversity_weight: f32,
72
73 pub model_name: String,
75
76 pub model_backend: String,
78
79 pub enable_parallel: bool,
81
82 pub num_workers: usize,
84}
85
86impl RerankingConfig {
87 pub fn default_config() -> Self {
89 Self {
90 mode: RerankingMode::TopK,
91 max_candidates: 100,
92 top_k: 10,
93 fusion_strategy: FusionStrategy::Linear,
94 retrieval_weight: 0.3,
95 batch_size: 32,
96 timeout_ms: Some(5000),
97 enable_caching: true,
98 cache_size: 1000,
99 enable_diversity: false,
100 diversity_weight: 0.2,
101 model_name: "cross-encoder/ms-marco-MiniLM-L-12-v2".to_string(),
102 model_backend: "local".to_string(),
103 enable_parallel: true,
104 num_workers: 4,
105 }
106 }
107
108 pub fn accuracy_optimized() -> Self {
110 Self {
111 mode: RerankingMode::Full,
112 max_candidates: 200,
113 top_k: 10,
114 fusion_strategy: FusionStrategy::RerankingOnly,
115 retrieval_weight: 0.0,
116 batch_size: 16,
117 timeout_ms: Some(10000),
118 enable_caching: true,
119 cache_size: 2000,
120 enable_diversity: true,
121 diversity_weight: 0.3,
122 model_name: "cross-encoder/ms-marco-TinyBERT-L-6-v2".to_string(),
123 model_backend: "local".to_string(),
124 enable_parallel: true,
125 num_workers: 8,
126 }
127 }
128
129 pub fn speed_optimized() -> Self {
131 Self {
132 mode: RerankingMode::TopK,
133 max_candidates: 50,
134 top_k: 10,
135 fusion_strategy: FusionStrategy::Linear,
136 retrieval_weight: 0.5,
137 batch_size: 64,
138 timeout_ms: Some(2000),
139 enable_caching: true,
140 cache_size: 500,
141 enable_diversity: false,
142 diversity_weight: 0.0,
143 model_name: "cross-encoder/ms-marco-MiniLM-L-2-v2".to_string(),
144 model_backend: "local".to_string(),
145 enable_parallel: true,
146 num_workers: 2,
147 }
148 }
149
150 pub fn api_based(api_backend: impl Into<String>) -> Self {
152 Self {
153 mode: RerankingMode::TopK,
154 max_candidates: 100,
155 top_k: 10,
156 fusion_strategy: FusionStrategy::Linear,
157 retrieval_weight: 0.3,
158 batch_size: 16,
159 timeout_ms: Some(30000), enable_caching: true,
161 cache_size: 5000, enable_diversity: false,
163 diversity_weight: 0.2,
164 model_name: "rerank-v2".to_string(),
165 model_backend: api_backend.into(),
166 enable_parallel: false, num_workers: 1,
168 }
169 }
170
171 pub fn validate(&self) -> Result<(), String> {
173 if self.max_candidates == 0 {
174 return Err("max_candidates must be greater than 0".to_string());
175 }
176
177 if self.top_k == 0 {
178 return Err("top_k must be greater than 0".to_string());
179 }
180
181 if self.top_k > self.max_candidates {
182 return Err("top_k cannot exceed max_candidates".to_string());
183 }
184
185 if self.retrieval_weight < 0.0 || self.retrieval_weight > 1.0 {
186 return Err("retrieval_weight must be between 0.0 and 1.0".to_string());
187 }
188
189 if self.diversity_weight < 0.0 || self.diversity_weight > 1.0 {
190 return Err("diversity_weight must be between 0.0 and 1.0".to_string());
191 }
192
193 if self.batch_size == 0 {
194 return Err("batch_size must be greater than 0".to_string());
195 }
196
197 if self.cache_size == 0 && self.enable_caching {
198 return Err("cache_size must be greater than 0 when caching is enabled".to_string());
199 }
200
201 if self.num_workers == 0 && self.enable_parallel {
202 return Err(
203 "num_workers must be greater than 0 when parallel processing is enabled"
204 .to_string(),
205 );
206 }
207
208 if self.model_name.is_empty() {
209 return Err("model_name cannot be empty".to_string());
210 }
211
212 Ok(())
213 }
214
215 pub fn reranking_weight(&self) -> f32 {
217 1.0 - self.retrieval_weight
218 }
219}
220
221impl Default for RerankingConfig {
222 fn default() -> Self {
223 Self::default_config()
224 }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the default preset and applies a single tweak to it, so each
    /// validation probe starts from a known-good configuration.
    fn default_with(tweak: impl FnOnce(&mut RerankingConfig)) -> RerankingConfig {
        let mut config = RerankingConfig::default_config();
        tweak(&mut config);
        config
    }

    #[test]
    fn test_default_config() {
        let config = RerankingConfig::default_config();
        assert_eq!(config.mode, RerankingMode::TopK);
        assert_eq!(config.max_candidates, 100);
        assert_eq!(config.top_k, 10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_accuracy_optimized() {
        let config = RerankingConfig::accuracy_optimized();
        assert_eq!(config.mode, RerankingMode::Full);
        assert_eq!(config.fusion_strategy, FusionStrategy::RerankingOnly);
        assert!(config.enable_diversity);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_speed_optimized() {
        let config = RerankingConfig::speed_optimized();
        assert_eq!(config.max_candidates, 50);
        // Larger batches than the default preset's 32.
        assert!(config.batch_size > 32);
        assert!(!config.enable_diversity);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_api_based() {
        let config = RerankingConfig::api_based("cohere");
        assert_eq!(config.model_backend, "cohere");
        // Remote calls get a timeout and a cache larger than the local presets.
        assert!(matches!(config.timeout_ms, Some(ms) if ms > 10000));
        assert!(config.cache_size > 1000);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation() {
        assert!(RerankingConfig::default_config().validate().is_ok());

        // Candidate / top-k bounds.
        assert!(default_with(|c| c.max_candidates = 0).validate().is_err());
        assert!(default_with(|c| c.top_k = 0).validate().is_err());
        assert!(default_with(|c| {
            c.top_k = 200;
            c.max_candidates = 100;
        })
        .validate()
        .is_err());
        // top_k == max_candidates is the allowed boundary.
        assert!(default_with(|c| {
            c.top_k = 100;
            c.max_candidates = 100;
        })
        .validate()
        .is_ok());

        // Weights must stay within 0.0..=1.0 on both sides.
        assert!(default_with(|c| c.retrieval_weight = 1.5).validate().is_err());
        assert!(default_with(|c| c.retrieval_weight = -0.1).validate().is_err());
        assert!(default_with(|c| c.diversity_weight = 1.5).validate().is_err());
        assert!(default_with(|c| c.diversity_weight = -0.1).validate().is_err());

        assert!(default_with(|c| c.batch_size = 0).validate().is_err());

        assert!(default_with(|c| c.model_name = String::new()).validate().is_err());
    }

    #[test]
    fn test_validation_conditional_sizes() {
        // A zero cache is only an error while caching is enabled.
        assert!(default_with(|c| c.cache_size = 0).validate().is_err());
        assert!(default_with(|c| {
            c.cache_size = 0;
            c.enable_caching = false;
        })
        .validate()
        .is_ok());

        // Zero workers is only an error while parallel processing is enabled.
        assert!(default_with(|c| c.num_workers = 0).validate().is_err());
        assert!(default_with(|c| {
            c.num_workers = 0;
            c.enable_parallel = false;
        })
        .validate()
        .is_ok());
    }

    #[test]
    fn test_reranking_weight() {
        let mut config = RerankingConfig::default_config();
        config.retrieval_weight = 0.3;
        assert!((config.reranking_weight() - 0.7).abs() < 0.001);

        config.retrieval_weight = 0.0;
        assert!((config.reranking_weight() - 1.0).abs() < f32::EPSILON);

        config.retrieval_weight = 1.0;
        assert!(config.reranking_weight().abs() < f32::EPSILON);
    }

    #[test]
    fn test_fusion_strategies() {
        // Every fusion strategy must be accepted by validation.
        for strategy in [
            FusionStrategy::RerankingOnly,
            FusionStrategy::RetrievalOnly,
            FusionStrategy::Linear,
            FusionStrategy::ReciprocalRank,
            FusionStrategy::Learned,
            FusionStrategy::Harmonic,
            FusionStrategy::Geometric,
        ] {
            let config = default_with(|c| c.fusion_strategy = strategy);
            assert!(config.validate().is_ok());
        }
    }

    #[test]
    fn test_reranking_modes() {
        // Every reranking mode must be accepted by validation.
        for mode in [
            RerankingMode::Full,
            RerankingMode::TopK,
            RerankingMode::Adaptive,
            RerankingMode::Disabled,
        ] {
            let config = default_with(|c| c.mode = mode);
            assert!(config.validate().is_ok());
        }
    }
}