Skip to main content

oxirs_vec/reranking/
config.rs

1//! Configuration for re-ranking
2
3use serde::{Deserialize, Serialize};
4
5/// Re-ranking mode
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum RerankingMode {
8    /// Re-rank all candidates
9    Full,
10    /// Re-rank top-k candidates only
11    TopK,
12    /// Adaptive re-ranking based on score distribution
13    Adaptive,
14    /// No re-ranking (passthrough)
15    Disabled,
16}
17
18/// Strategy for fusing retrieval and re-ranking scores
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum FusionStrategy {
21    /// Use only re-ranking scores
22    RerankingOnly,
23    /// Use only retrieval scores
24    RetrievalOnly,
25    /// Linear combination: α * retrieval + (1-α) * reranking
26    Linear,
27    /// Reciprocal rank fusion
28    ReciprocalRank,
29    /// Learned score fusion (trained weights)
30    Learned,
31    /// Harmonic mean
32    Harmonic,
33    /// Geometric mean
34    Geometric,
35}
36
37/// Configuration for re-ranking
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct RerankingConfig {
40    /// Re-ranking mode
41    pub mode: RerankingMode,
42
43    /// Maximum number of candidates to re-rank
44    pub max_candidates: usize,
45
46    /// Number of final results to return
47    pub top_k: usize,
48
49    /// Fusion strategy for combining scores
50    pub fusion_strategy: FusionStrategy,
51
52    /// Weight for retrieval score in linear fusion (0.0 to 1.0)
53    pub retrieval_weight: f32,
54
55    /// Batch size for cross-encoder inference
56    pub batch_size: usize,
57
58    /// Timeout for re-ranking (milliseconds)
59    pub timeout_ms: Option<u64>,
60
61    /// Enable result caching
62    pub enable_caching: bool,
63
64    /// Cache size (number of entries)
65    pub cache_size: usize,
66
67    /// Enable diversity-aware re-ranking
68    pub enable_diversity: bool,
69
70    /// Diversity weight (0.0 to 1.0)
71    pub diversity_weight: f32,
72
73    /// Model name or path
74    pub model_name: String,
75
76    /// Model backend (local, api, etc.)
77    pub model_backend: String,
78
79    /// Enable parallel batch processing
80    pub enable_parallel: bool,
81
82    /// Number of worker threads
83    pub num_workers: usize,
84}
85
86impl RerankingConfig {
87    /// Create default configuration
88    pub fn default_config() -> Self {
89        Self {
90            mode: RerankingMode::TopK,
91            max_candidates: 100,
92            top_k: 10,
93            fusion_strategy: FusionStrategy::Linear,
94            retrieval_weight: 0.3,
95            batch_size: 32,
96            timeout_ms: Some(5000),
97            enable_caching: true,
98            cache_size: 1000,
99            enable_diversity: false,
100            diversity_weight: 0.2,
101            model_name: "cross-encoder/ms-marco-MiniLM-L-12-v2".to_string(),
102            model_backend: "local".to_string(),
103            enable_parallel: true,
104            num_workers: 4,
105        }
106    }
107
108    /// Create configuration optimized for accuracy
109    pub fn accuracy_optimized() -> Self {
110        Self {
111            mode: RerankingMode::Full,
112            max_candidates: 200,
113            top_k: 10,
114            fusion_strategy: FusionStrategy::RerankingOnly,
115            retrieval_weight: 0.0,
116            batch_size: 16,
117            timeout_ms: Some(10000),
118            enable_caching: true,
119            cache_size: 2000,
120            enable_diversity: true,
121            diversity_weight: 0.3,
122            model_name: "cross-encoder/ms-marco-TinyBERT-L-6-v2".to_string(),
123            model_backend: "local".to_string(),
124            enable_parallel: true,
125            num_workers: 8,
126        }
127    }
128
129    /// Create configuration optimized for speed
130    pub fn speed_optimized() -> Self {
131        Self {
132            mode: RerankingMode::TopK,
133            max_candidates: 50,
134            top_k: 10,
135            fusion_strategy: FusionStrategy::Linear,
136            retrieval_weight: 0.5,
137            batch_size: 64,
138            timeout_ms: Some(2000),
139            enable_caching: true,
140            cache_size: 500,
141            enable_diversity: false,
142            diversity_weight: 0.0,
143            model_name: "cross-encoder/ms-marco-MiniLM-L-2-v2".to_string(),
144            model_backend: "local".to_string(),
145            enable_parallel: true,
146            num_workers: 2,
147        }
148    }
149
150    /// Create configuration for API-based models
151    pub fn api_based(api_backend: impl Into<String>) -> Self {
152        Self {
153            mode: RerankingMode::TopK,
154            max_candidates: 100,
155            top_k: 10,
156            fusion_strategy: FusionStrategy::Linear,
157            retrieval_weight: 0.3,
158            batch_size: 16,
159            timeout_ms: Some(30000), // Longer timeout for API
160            enable_caching: true,
161            cache_size: 5000, // Larger cache for API
162            enable_diversity: false,
163            diversity_weight: 0.2,
164            model_name: "rerank-v2".to_string(),
165            model_backend: api_backend.into(),
166            enable_parallel: false, // API handles parallelism
167            num_workers: 1,
168        }
169    }
170
171    /// Validate configuration
172    pub fn validate(&self) -> Result<(), String> {
173        if self.max_candidates == 0 {
174            return Err("max_candidates must be greater than 0".to_string());
175        }
176
177        if self.top_k == 0 {
178            return Err("top_k must be greater than 0".to_string());
179        }
180
181        if self.top_k > self.max_candidates {
182            return Err("top_k cannot exceed max_candidates".to_string());
183        }
184
185        if self.retrieval_weight < 0.0 || self.retrieval_weight > 1.0 {
186            return Err("retrieval_weight must be between 0.0 and 1.0".to_string());
187        }
188
189        if self.diversity_weight < 0.0 || self.diversity_weight > 1.0 {
190            return Err("diversity_weight must be between 0.0 and 1.0".to_string());
191        }
192
193        if self.batch_size == 0 {
194            return Err("batch_size must be greater than 0".to_string());
195        }
196
197        if self.cache_size == 0 && self.enable_caching {
198            return Err("cache_size must be greater than 0 when caching is enabled".to_string());
199        }
200
201        if self.num_workers == 0 && self.enable_parallel {
202            return Err(
203                "num_workers must be greater than 0 when parallel processing is enabled"
204                    .to_string(),
205            );
206        }
207
208        if self.model_name.is_empty() {
209            return Err("model_name cannot be empty".to_string());
210        }
211
212        Ok(())
213    }
214
215    /// Get fusion weight for reranking score
216    pub fn reranking_weight(&self) -> f32 {
217        1.0 - self.retrieval_weight
218    }
219}
220
221impl Default for RerankingConfig {
222    fn default() -> Self {
223        Self::default_config()
224    }
225}
226
#[cfg(test)]
mod tests {
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
    use super::*;

    /// The default preset uses top-k mode with 100 candidates / 10 results.
    #[test]
    fn test_default_config() {
        let cfg = RerankingConfig::default_config();
        assert_eq!(cfg.mode, RerankingMode::TopK);
        assert_eq!(cfg.max_candidates, 100);
        assert_eq!(cfg.top_k, 10);
        assert!(cfg.validate().is_ok());
    }

    /// The accuracy preset re-ranks everything and enables diversity.
    #[test]
    fn test_accuracy_optimized() {
        let cfg = RerankingConfig::accuracy_optimized();
        assert_eq!(cfg.mode, RerankingMode::Full);
        assert_eq!(cfg.fusion_strategy, FusionStrategy::RerankingOnly);
        assert!(cfg.enable_diversity);
        assert!(cfg.validate().is_ok());
    }

    /// The speed preset trims candidates, grows batches, and skips diversity.
    #[test]
    fn test_speed_optimized() {
        let cfg = RerankingConfig::speed_optimized();
        assert_eq!(cfg.max_candidates, 50);
        assert!(cfg.batch_size > 32); // Larger batches for speed
        assert!(!cfg.enable_diversity); // No diversity for speed
        assert!(cfg.validate().is_ok());
    }

    /// The API preset records the backend and uses a longer timeout and a
    /// larger cache.
    #[test]
    fn test_api_based() -> Result<()> {
        let cfg = RerankingConfig::api_based("cohere");
        assert_eq!(cfg.model_backend, "cohere");
        let timeout = cfg.timeout_ms.expect("test value");
        assert!(timeout > 10000); // Longer timeout
        assert!(cfg.cache_size > 1000); // Larger cache
        assert!(cfg.validate().is_ok());
        Ok(())
    }

    /// Each invalid field, applied to an otherwise default config, is rejected.
    #[test]
    fn test_validation() {
        assert!(RerankingConfig::default_config().validate().is_ok());

        let no_candidates = RerankingConfig {
            max_candidates: 0,
            ..RerankingConfig::default_config()
        };
        assert!(no_candidates.validate().is_err());

        let no_results = RerankingConfig {
            top_k: 0,
            ..RerankingConfig::default_config()
        };
        assert!(no_results.validate().is_err());

        let top_k_too_large = RerankingConfig {
            top_k: 200,
            max_candidates: 100,
            ..RerankingConfig::default_config()
        };
        assert!(top_k_too_large.validate().is_err());

        let weight_out_of_range = RerankingConfig {
            retrieval_weight: 1.5,
            ..RerankingConfig::default_config()
        };
        assert!(weight_out_of_range.validate().is_err());

        let unnamed_model = RerankingConfig {
            model_name: "".to_string(),
            ..RerankingConfig::default_config()
        };
        assert!(unnamed_model.validate().is_err());
    }

    /// The re-ranking weight is the complement of the retrieval weight.
    #[test]
    fn test_reranking_weight() {
        let mut cfg = RerankingConfig::default_config();

        cfg.retrieval_weight = 0.3;
        let delta = (cfg.reranking_weight() - 0.7).abs();
        assert!(delta < 0.001);

        cfg.retrieval_weight = 0.0;
        assert_eq!(cfg.reranking_weight(), 1.0);

        cfg.retrieval_weight = 1.0;
        assert_eq!(cfg.reranking_weight(), 0.0);
    }

    /// Every fusion strategy is accepted by validation.
    #[test]
    fn test_fusion_strategies() {
        let all_strategies = [
            FusionStrategy::RerankingOnly,
            FusionStrategy::RetrievalOnly,
            FusionStrategy::Linear,
            FusionStrategy::ReciprocalRank,
            FusionStrategy::Learned,
            FusionStrategy::Harmonic,
            FusionStrategy::Geometric,
        ];

        for fusion_strategy in all_strategies {
            let cfg = RerankingConfig {
                fusion_strategy,
                ..RerankingConfig::default_config()
            };
            assert!(cfg.validate().is_ok());
        }
    }

    /// Every re-ranking mode is accepted by validation.
    #[test]
    fn test_reranking_modes() {
        let all_modes = [
            RerankingMode::Full,
            RerankingMode::TopK,
            RerankingMode::Adaptive,
            RerankingMode::Disabled,
        ];

        for mode in all_modes {
            let cfg = RerankingConfig {
                mode,
                ..RerankingConfig::default_config()
            };
            assert!(cfg.validate().is_ok());
        }
    }
}