oxirs_vec/reranking/
config.rs

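//! Configuration types for result re-ranking: the re-ranking mode, the score
//! fusion strategy, and the `RerankingConfig` struct with its presets.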
2
3use serde::{Deserialize, Serialize};
4
5/// Re-ranking mode
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum RerankingMode {
8    /// Re-rank all candidates
9    Full,
10    /// Re-rank top-k candidates only
11    TopK,
12    /// Adaptive re-ranking based on score distribution
13    Adaptive,
14    /// No re-ranking (passthrough)
15    Disabled,
16}
17
18/// Strategy for fusing retrieval and re-ranking scores
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum FusionStrategy {
21    /// Use only re-ranking scores
22    RerankingOnly,
23    /// Use only retrieval scores
24    RetrievalOnly,
25    /// Linear combination: α * retrieval + (1-α) * reranking
26    Linear,
27    /// Reciprocal rank fusion
28    ReciprocalRank,
29    /// Learned score fusion (trained weights)
30    Learned,
31    /// Harmonic mean
32    Harmonic,
33    /// Geometric mean
34    Geometric,
35}
36
37/// Configuration for re-ranking
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct RerankingConfig {
40    /// Re-ranking mode
41    pub mode: RerankingMode,
42
43    /// Maximum number of candidates to re-rank
44    pub max_candidates: usize,
45
46    /// Number of final results to return
47    pub top_k: usize,
48
49    /// Fusion strategy for combining scores
50    pub fusion_strategy: FusionStrategy,
51
52    /// Weight for retrieval score in linear fusion (0.0 to 1.0)
53    pub retrieval_weight: f32,
54
55    /// Batch size for cross-encoder inference
56    pub batch_size: usize,
57
58    /// Timeout for re-ranking (milliseconds)
59    pub timeout_ms: Option<u64>,
60
61    /// Enable result caching
62    pub enable_caching: bool,
63
64    /// Cache size (number of entries)
65    pub cache_size: usize,
66
67    /// Enable diversity-aware re-ranking
68    pub enable_diversity: bool,
69
70    /// Diversity weight (0.0 to 1.0)
71    pub diversity_weight: f32,
72
73    /// Model name or path
74    pub model_name: String,
75
76    /// Model backend (local, api, etc.)
77    pub model_backend: String,
78
79    /// Enable parallel batch processing
80    pub enable_parallel: bool,
81
82    /// Number of worker threads
83    pub num_workers: usize,
84}
85
86impl RerankingConfig {
87    /// Create default configuration
88    pub fn default_config() -> Self {
89        Self {
90            mode: RerankingMode::TopK,
91            max_candidates: 100,
92            top_k: 10,
93            fusion_strategy: FusionStrategy::Linear,
94            retrieval_weight: 0.3,
95            batch_size: 32,
96            timeout_ms: Some(5000),
97            enable_caching: true,
98            cache_size: 1000,
99            enable_diversity: false,
100            diversity_weight: 0.2,
101            model_name: "cross-encoder/ms-marco-MiniLM-L-12-v2".to_string(),
102            model_backend: "local".to_string(),
103            enable_parallel: true,
104            num_workers: 4,
105        }
106    }
107
108    /// Create configuration optimized for accuracy
109    pub fn accuracy_optimized() -> Self {
110        Self {
111            mode: RerankingMode::Full,
112            max_candidates: 200,
113            top_k: 10,
114            fusion_strategy: FusionStrategy::RerankingOnly,
115            retrieval_weight: 0.0,
116            batch_size: 16,
117            timeout_ms: Some(10000),
118            enable_caching: true,
119            cache_size: 2000,
120            enable_diversity: true,
121            diversity_weight: 0.3,
122            model_name: "cross-encoder/ms-marco-TinyBERT-L-6-v2".to_string(),
123            model_backend: "local".to_string(),
124            enable_parallel: true,
125            num_workers: 8,
126        }
127    }
128
129    /// Create configuration optimized for speed
130    pub fn speed_optimized() -> Self {
131        Self {
132            mode: RerankingMode::TopK,
133            max_candidates: 50,
134            top_k: 10,
135            fusion_strategy: FusionStrategy::Linear,
136            retrieval_weight: 0.5,
137            batch_size: 64,
138            timeout_ms: Some(2000),
139            enable_caching: true,
140            cache_size: 500,
141            enable_diversity: false,
142            diversity_weight: 0.0,
143            model_name: "cross-encoder/ms-marco-MiniLM-L-2-v2".to_string(),
144            model_backend: "local".to_string(),
145            enable_parallel: true,
146            num_workers: 2,
147        }
148    }
149
150    /// Create configuration for API-based models
151    pub fn api_based(api_backend: impl Into<String>) -> Self {
152        Self {
153            mode: RerankingMode::TopK,
154            max_candidates: 100,
155            top_k: 10,
156            fusion_strategy: FusionStrategy::Linear,
157            retrieval_weight: 0.3,
158            batch_size: 16,
159            timeout_ms: Some(30000), // Longer timeout for API
160            enable_caching: true,
161            cache_size: 5000, // Larger cache for API
162            enable_diversity: false,
163            diversity_weight: 0.2,
164            model_name: "rerank-v2".to_string(),
165            model_backend: api_backend.into(),
166            enable_parallel: false, // API handles parallelism
167            num_workers: 1,
168        }
169    }
170
171    /// Validate configuration
172    pub fn validate(&self) -> Result<(), String> {
173        if self.max_candidates == 0 {
174            return Err("max_candidates must be greater than 0".to_string());
175        }
176
177        if self.top_k == 0 {
178            return Err("top_k must be greater than 0".to_string());
179        }
180
181        if self.top_k > self.max_candidates {
182            return Err("top_k cannot exceed max_candidates".to_string());
183        }
184
185        if self.retrieval_weight < 0.0 || self.retrieval_weight > 1.0 {
186            return Err("retrieval_weight must be between 0.0 and 1.0".to_string());
187        }
188
189        if self.diversity_weight < 0.0 || self.diversity_weight > 1.0 {
190            return Err("diversity_weight must be between 0.0 and 1.0".to_string());
191        }
192
193        if self.batch_size == 0 {
194            return Err("batch_size must be greater than 0".to_string());
195        }
196
197        if self.cache_size == 0 && self.enable_caching {
198            return Err("cache_size must be greater than 0 when caching is enabled".to_string());
199        }
200
201        if self.num_workers == 0 && self.enable_parallel {
202            return Err(
203                "num_workers must be greater than 0 when parallel processing is enabled"
204                    .to_string(),
205            );
206        }
207
208        if self.model_name.is_empty() {
209            return Err("model_name cannot be empty".to_string());
210        }
211
212        Ok(())
213    }
214
215    /// Get fusion weight for reranking score
216    pub fn reranking_weight(&self) -> f32 {
217        1.0 - self.retrieval_weight
218    }
219}
220
221impl Default for RerankingConfig {
222    fn default() -> Self {
223        Self::default_config()
224    }
225}
226
#[cfg(test)]
mod tests {
    use super::*;

    // Each preset must carry its advertised characteristics and pass validation.

    #[test]
    fn test_default_config() {
        let config = RerankingConfig::default_config();
        assert_eq!(config.mode, RerankingMode::TopK);
        assert_eq!(config.max_candidates, 100);
        assert_eq!(config.top_k, 10);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_accuracy_optimized() {
        let config = RerankingConfig::accuracy_optimized();
        assert_eq!(config.mode, RerankingMode::Full);
        assert_eq!(config.fusion_strategy, FusionStrategy::RerankingOnly);
        assert!(config.enable_diversity);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_speed_optimized() {
        let config = RerankingConfig::speed_optimized();
        assert_eq!(config.max_candidates, 50);
        assert!(config.batch_size > 32); // Larger batches for speed
        assert!(!config.enable_diversity); // No diversity for speed
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_api_based() {
        let config = RerankingConfig::api_based("cohere");
        assert_eq!(config.model_backend, "cohere");
        assert!(config.timeout_ms.unwrap() > 10000); // Longer timeout
        assert!(config.cache_size > 1000); // Larger cache
        assert!(config.validate().is_ok());
    }

    // Validation: each case starts from a fresh default so that only the
    // field under test is invalid.

    #[test]
    fn test_validation() {
        let mut config = RerankingConfig::default_config();
        assert!(config.validate().is_ok());

        config.max_candidates = 0;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.top_k = 0;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.top_k = 200;
        config.max_candidates = 100;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.retrieval_weight = 1.5;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.model_name = "".to_string();
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_weight_bounds() {
        let mut config = RerankingConfig::default_config();
        config.retrieval_weight = -0.1;
        assert!(config.validate().is_err());

        // Both ends of the interval are inclusive.
        config.retrieval_weight = 0.0;
        assert!(config.validate().is_ok());
        config.retrieval_weight = 1.0;
        assert!(config.validate().is_ok());

        config = RerankingConfig::default_config();
        config.diversity_weight = -0.1;
        assert!(config.validate().is_err());
        config.diversity_weight = 1.5;
        assert!(config.validate().is_err());
        config.diversity_weight = 1.0;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation_top_k_boundary() {
        let mut config = RerankingConfig::default_config();
        config.top_k = config.max_candidates;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation_batch_size() {
        let mut config = RerankingConfig::default_config();
        config.batch_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_validation_cache_size() {
        let mut config = RerankingConfig::default_config();
        config.cache_size = 0;
        assert!(config.validate().is_err());

        // A zero-sized cache is only rejected while caching is enabled.
        config.enable_caching = false;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_validation_num_workers() {
        let mut config = RerankingConfig::default_config();
        config.num_workers = 0;
        assert!(config.validate().is_err());

        // Zero workers is only rejected while parallel processing is enabled.
        config.enable_parallel = false;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_reranking_weight() {
        let mut config = RerankingConfig::default_config();
        config.retrieval_weight = 0.3;
        assert!((config.reranking_weight() - 0.7).abs() < 0.001);

        config.retrieval_weight = 0.0;
        assert_eq!(config.reranking_weight(), 1.0);

        config.retrieval_weight = 1.0;
        assert_eq!(config.reranking_weight(), 0.0);
    }

    // Every enum variant must be accepted by validation.

    #[test]
    fn test_fusion_strategies() {
        let strategies = [
            FusionStrategy::RerankingOnly,
            FusionStrategy::RetrievalOnly,
            FusionStrategy::Linear,
            FusionStrategy::ReciprocalRank,
            FusionStrategy::Learned,
            FusionStrategy::Harmonic,
            FusionStrategy::Geometric,
        ];

        for strategy in strategies {
            let mut config = RerankingConfig::default_config();
            config.fusion_strategy = strategy;
            assert!(config.validate().is_ok());
        }
    }

    #[test]
    fn test_reranking_modes() {
        let modes = [
            RerankingMode::Full,
            RerankingMode::TopK,
            RerankingMode::Adaptive,
            RerankingMode::Disabled,
        ];

        for mode in modes {
            let mut config = RerankingConfig::default_config();
            config.mode = mode;
            assert!(config.validate().is_ok());
        }
    }
}