use serde::{Deserialize, Serialize};
/// How aggressively the reranking stage is applied.
///
/// NOTE(review): only the variant names are declared here; what each mode
/// actually does is decided by the code that reads [`RerankingConfig::mode`].
/// The per-variant notes below are inferred from the names and should be
/// confirmed against that consumer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RerankingMode {
    /// Presumably reranks every candidate. Selected by the
    /// `accuracy_optimized` preset.
    Full,

    /// Presumably reranks only a leading slice of the candidates. Selected by
    /// the default, speed-optimized and API-based presets.
    TopK,

    /// Presumably lets the pipeline decide per query how much to rerank —
    /// confirm with the consumer.
    Adaptive,

    /// Presumably bypasses reranking altogether — confirm with the consumer.
    Disabled,
}
/// How the first-stage retrieval score and the reranker score are combined.
///
/// NOTE(review): the fusion formulas themselves are not part of this file; the
/// per-variant notes below are inferred from the variant names and should be
/// confirmed against the code that reads [`RerankingConfig::fusion_strategy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FusionStrategy {
    /// Presumably uses the reranker score alone. Selected by the
    /// `accuracy_optimized` preset (together with a retrieval weight of 0.0).
    RerankingOnly,

    /// Presumably uses the retrieval score alone.
    RetrievalOnly,

    /// Presumably a weighted sum of both scores driven by
    /// [`RerankingConfig::retrieval_weight`]. Selected by the default,
    /// speed-optimized and API-based presets.
    Linear,

    /// Presumably reciprocal-rank fusion over the two rankings — confirm.
    ReciprocalRank,

    /// Presumably a learned combination of the two scores — confirm.
    Learned,

    /// Presumably the harmonic mean of the two scores — confirm.
    Harmonic,

    /// Presumably the geometric mean of the two scores — confirm.
    Geometric,
}
/// Settings for the reranking stage.
///
/// Values are plain data: nothing is checked on construction or
/// deserialization. Call [`RerankingConfig::validate`] before using a
/// configuration that did not come from one of the built-in presets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankingConfig {
    // --- What to rerank -----------------------------------------------------
    /// Reranking mode to run.
    pub mode: RerankingMode,

    /// Candidate limit. `validate` requires it to be greater than 0.
    /// Presumably the number of first-stage results handed to the reranker —
    /// confirm with the consumer.
    pub max_candidates: usize,

    /// Result count. `validate` requires `0 < top_k <= max_candidates`.
    pub top_k: usize,

    // --- How scores are combined ---------------------------------------------
    /// Strategy used to merge retrieval and reranker scores.
    pub fusion_strategy: FusionStrategy,

    /// Share given to the retrieval score; `validate` requires `0.0..=1.0`.
    /// `reranking_weight()` returns the complement, `1.0 - retrieval_weight`.
    pub retrieval_weight: f32,

    // --- Execution -----------------------------------------------------------
    /// Batch size; `validate` requires it to be greater than 0. Presumably the
    /// number of candidates scored per model call — confirm with the consumer.
    pub batch_size: usize,

    /// Optional timeout in milliseconds. Not checked by `validate`.
    /// NOTE(review): `None` presumably means "no timeout" — confirm.
    pub timeout_ms: Option<u64>,

    // --- Caching -------------------------------------------------------------
    /// Turns result caching on or off.
    pub enable_caching: bool,

    /// Cache capacity; `validate` requires it to be greater than 0 whenever
    /// `enable_caching` is set. The unit (entries vs. bytes) is not visible
    /// in this file.
    pub cache_size: usize,

    // --- Diversity -----------------------------------------------------------
    /// Turns diversity handling on or off.
    pub enable_diversity: bool,

    /// Diversity weight; `validate` requires `0.0..=1.0` regardless of whether
    /// `enable_diversity` is set.
    pub diversity_weight: f32,

    // --- Model ---------------------------------------------------------------
    /// Identifier of the reranking model; `validate` rejects an empty string.
    pub model_name: String,

    /// Backend that serves the model. Free-form and not checked by `validate`;
    /// the local presets use `"local"`, `api_based` stores the caller's value.
    pub model_backend: String,

    // --- Parallelism ---------------------------------------------------------
    /// Turns parallel processing on or off.
    pub enable_parallel: bool,

    /// Worker count; `validate` requires it to be greater than 0 whenever
    /// `enable_parallel` is set.
    pub num_workers: usize,
}
impl RerankingConfig {
    /// Returns the general-purpose preset.
    ///
    /// Top-K mode over at most 100 candidates, linear fusion with a retrieval
    /// weight of 0.3, batches of 32, a 5 s timeout, a 1000-slot cache,
    /// diversity off, and 4 parallel workers on the `"local"` backend.
    pub fn default_config() -> Self {
        Self {
            mode: RerankingMode::TopK,
            max_candidates: 100,
            top_k: 10,
            fusion_strategy: FusionStrategy::Linear,
            retrieval_weight: 0.3,
            batch_size: 32,
            timeout_ms: Some(5000),
            enable_caching: true,
            cache_size: 1000,
            enable_diversity: false,
            diversity_weight: 0.2,
            model_name: "cross-encoder/ms-marco-MiniLM-L-12-v2".to_string(),
            model_backend: "local".to_string(),
            enable_parallel: true,
            num_workers: 4,
        }
    }

    /// Returns the preset that favors result quality over latency.
    ///
    /// Full mode over at most 200 candidates, reranker score only (retrieval
    /// weight 0.0), batches of 16, a 10 s timeout, a 2000-slot cache,
    /// diversity on with weight 0.3, and 8 parallel workers on the `"local"`
    /// backend.
    pub fn accuracy_optimized() -> Self {
        Self {
            mode: RerankingMode::Full,
            max_candidates: 200,
            top_k: 10,
            fusion_strategy: FusionStrategy::RerankingOnly,
            retrieval_weight: 0.0,
            batch_size: 16,
            timeout_ms: Some(10000),
            enable_caching: true,
            cache_size: 2000,
            enable_diversity: true,
            diversity_weight: 0.3,
            model_name: "cross-encoder/ms-marco-TinyBERT-L-6-v2".to_string(),
            model_backend: "local".to_string(),
            enable_parallel: true,
            num_workers: 8,
        }
    }

    /// Returns the preset that favors latency over result quality.
    ///
    /// Top-K mode over at most 50 candidates, linear fusion with a retrieval
    /// weight of 0.5, batches of 64, a 2 s timeout, a 500-slot cache,
    /// diversity off (weight 0.0), and 2 parallel workers on the `"local"`
    /// backend.
    pub fn speed_optimized() -> Self {
        Self {
            mode: RerankingMode::TopK,
            max_candidates: 50,
            top_k: 10,
            fusion_strategy: FusionStrategy::Linear,
            retrieval_weight: 0.5,
            batch_size: 64,
            timeout_ms: Some(2000),
            enable_caching: true,
            cache_size: 500,
            enable_diversity: false,
            diversity_weight: 0.0,
            model_name: "cross-encoder/ms-marco-MiniLM-L-2-v2".to_string(),
            model_backend: "local".to_string(),
            enable_parallel: true,
            num_workers: 2,
        }
    }

    /// Returns a preset for a reranker reached through `api_backend`.
    ///
    /// `api_backend` is stored verbatim in `model_backend`; it is not checked
    /// here or by [`validate`](Self::validate). Compared with
    /// [`default_config`](Self::default_config) this preset uses batches of
    /// 16, a 30 s timeout, a 5000-slot cache, the `"rerank-v2"` model name,
    /// and disables parallel processing (a single worker).
    pub fn api_based(api_backend: impl Into<String>) -> Self {
        Self {
            mode: RerankingMode::TopK,
            max_candidates: 100,
            top_k: 10,
            fusion_strategy: FusionStrategy::Linear,
            retrieval_weight: 0.3,
            batch_size: 16,
            timeout_ms: Some(30000),
            enable_caching: true,
            cache_size: 5000,
            enable_diversity: false,
            diversity_weight: 0.2,
            model_name: "rerank-v2".to_string(),
            model_backend: api_backend.into(),
            enable_parallel: false,
            num_workers: 1,
        }
    }

    /// Checks the configuration for internally inconsistent values.
    ///
    /// `mode`, `fusion_strategy`, `timeout_ms` and `model_backend` are not
    /// inspected.
    ///
    /// # Errors
    ///
    /// Returns the message for the first rule that fails, checked in this
    /// order:
    ///
    /// 1. `max_candidates` is 0
    /// 2. `top_k` is 0
    /// 3. `top_k` exceeds `max_candidates`
    /// 4. `retrieval_weight` is outside `0.0..=1.0` (NaN included)
    /// 5. `diversity_weight` is outside `0.0..=1.0` (NaN included), whether or
    ///    not diversity is enabled
    /// 6. `batch_size` is 0
    /// 7. `cache_size` is 0 while `enable_caching` is set
    /// 8. `num_workers` is 0 while `enable_parallel` is set
    /// 9. `model_name` is empty
    pub fn validate(&self) -> Result<(), String> {
        if self.max_candidates == 0 {
            return Err("max_candidates must be greater than 0".to_string());
        }
        if self.top_k == 0 {
            return Err("top_k must be greater than 0".to_string());
        }
        if self.top_k > self.max_candidates {
            return Err("top_k cannot exceed max_candidates".to_string());
        }
        // `contains` is false for NaN, so a NaN weight is rejected here. A pair
        // of `< 0.0 || > 1.0` comparisons would both be false for NaN and let
        // it through.
        if !(0.0..=1.0).contains(&self.retrieval_weight) {
            return Err("retrieval_weight must be between 0.0 and 1.0".to_string());
        }
        if !(0.0..=1.0).contains(&self.diversity_weight) {
            return Err("diversity_weight must be between 0.0 and 1.0".to_string());
        }
        if self.batch_size == 0 {
            return Err("batch_size must be greater than 0".to_string());
        }
        if self.enable_caching && self.cache_size == 0 {
            return Err("cache_size must be greater than 0 when caching is enabled".to_string());
        }
        if self.enable_parallel && self.num_workers == 0 {
            return Err(
                "num_workers must be greater than 0 when parallel processing is enabled"
                    .to_string(),
            );
        }
        if self.model_name.is_empty() {
            return Err("model_name cannot be empty".to_string());
        }
        Ok(())
    }

    /// Returns the complement of `retrieval_weight`, i.e.
    /// `1.0 - retrieval_weight`.
    ///
    /// No range check happens here: the result lies in `0.0..=1.0` only when
    /// `retrieval_weight` does (which [`validate`](Self::validate) enforces).
    pub fn reranking_weight(&self) -> f32 {
        1.0 - self.retrieval_weight
    }
}
impl Default for RerankingConfig {
fn default() -> Self {
Self::default_config()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lets a test bail out with `?` instead of panicking on a missing value.
    type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;

    /// The general-purpose preset has the expected headline values and is valid.
    #[test]
    fn test_default_config() {
        let config = RerankingConfig::default_config();
        assert_eq!(config.mode, RerankingMode::TopK);
        assert_eq!(config.max_candidates, 100);
        assert_eq!(config.top_k, 10);
        assert!(config.validate().is_ok());
    }

    /// `Default::default()` hands back the same preset as `default_config()`.
    #[test]
    fn test_default_trait_matches_default_config() {
        let via_trait = RerankingConfig::default();
        let via_preset = RerankingConfig::default_config();
        assert_eq!(via_trait.mode, via_preset.mode);
        assert_eq!(via_trait.max_candidates, via_preset.max_candidates);
        assert_eq!(via_trait.top_k, via_preset.top_k);
        assert_eq!(via_trait.fusion_strategy, via_preset.fusion_strategy);
        assert_eq!(via_trait.model_name, via_preset.model_name);
        assert_eq!(via_trait.model_backend, via_preset.model_backend);
    }

    #[test]
    fn test_accuracy_optimized() {
        let config = RerankingConfig::accuracy_optimized();
        assert_eq!(config.mode, RerankingMode::Full);
        assert_eq!(config.fusion_strategy, FusionStrategy::RerankingOnly);
        assert!(config.enable_diversity);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_speed_optimized() {
        let config = RerankingConfig::speed_optimized();
        assert_eq!(config.max_candidates, 50);
        assert!(config.batch_size > 32);
        assert!(!config.enable_diversity);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_api_based() -> Result<()> {
        let config = RerankingConfig::api_based("cohere");
        assert_eq!(config.model_backend, "cohere");
        // Surface a missing timeout as a test error rather than a panic.
        let timeout_ms = config
            .timeout_ms
            .ok_or("api_based preset should set a timeout")?;
        assert!(timeout_ms > 10000);
        assert!(config.cache_size > 1000);
        assert!(config.validate().is_ok());
        Ok(())
    }

    /// One representative failure for the count, weight and name rules.
    #[test]
    fn test_validation() {
        let mut config = RerankingConfig::default_config();
        assert!(config.validate().is_ok());

        config.max_candidates = 0;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.top_k = 0;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.top_k = 200;
        config.max_candidates = 100;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.retrieval_weight = 1.5;
        assert!(config.validate().is_err());

        config = RerankingConfig::default_config();
        config.model_name = String::new();
        assert!(config.validate().is_err());
    }

    /// Both weights accept the closed interval 0.0..=1.0 and reject values on
    /// either side of it.
    #[test]
    fn test_validation_weight_bounds() {
        for weight in [0.0_f32, 1.0] {
            let mut config = RerankingConfig::default_config();
            config.retrieval_weight = weight;
            config.diversity_weight = weight;
            assert!(config.validate().is_ok());
        }
        for weight in [-0.1_f32, 1.5] {
            let mut config = RerankingConfig::default_config();
            config.retrieval_weight = weight;
            assert!(config.validate().is_err());

            let mut config = RerankingConfig::default_config();
            config.diversity_weight = weight;
            assert!(config.validate().is_err());
        }
    }

    /// `top_k` may equal `max_candidates`; `batch_size` may not be zero.
    #[test]
    fn test_validation_count_limits() {
        let mut config = RerankingConfig::default_config();
        config.top_k = config.max_candidates;
        assert!(config.validate().is_ok());

        let mut config = RerankingConfig::default_config();
        config.batch_size = 0;
        assert!(config.validate().is_err());
    }

    /// Zero cache/worker sizes only matter while the matching feature is on.
    #[test]
    fn test_validation_conditional_limits() {
        let mut config = RerankingConfig::default_config();
        config.cache_size = 0;
        assert!(config.validate().is_err());
        config.enable_caching = false;
        assert!(config.validate().is_ok());

        let mut config = RerankingConfig::default_config();
        config.num_workers = 0;
        assert!(config.validate().is_err());
        config.enable_parallel = false;
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_reranking_weight() {
        let mut config = RerankingConfig::default_config();
        config.retrieval_weight = 0.3;
        assert!((config.reranking_weight() - 0.7).abs() < 0.001);

        // 1.0 - 0.0 and 1.0 - 1.0 are exactly representable, so exact
        // comparison is safe for these two cases.
        config.retrieval_weight = 0.0;
        assert_eq!(config.reranking_weight(), 1.0);
        config.retrieval_weight = 1.0;
        assert_eq!(config.reranking_weight(), 0.0);
    }

    /// Every fusion strategy is accepted by `validate`.
    #[test]
    fn test_fusion_strategies() {
        let strategies = [
            FusionStrategy::RerankingOnly,
            FusionStrategy::RetrievalOnly,
            FusionStrategy::Linear,
            FusionStrategy::ReciprocalRank,
            FusionStrategy::Learned,
            FusionStrategy::Harmonic,
            FusionStrategy::Geometric,
        ];
        for strategy in strategies {
            let mut config = RerankingConfig::default_config();
            config.fusion_strategy = strategy;
            assert!(config.validate().is_ok());
        }
    }

    /// Every reranking mode is accepted by `validate`.
    #[test]
    fn test_reranking_modes() {
        let modes = [
            RerankingMode::Full,
            RerankingMode::TopK,
            RerankingMode::Adaptive,
            RerankingMode::Disabled,
        ];
        for mode in modes {
            let mut config = RerankingConfig::default_config();
            config.mode = mode;
            assert!(config.validate().is_ok());
        }
    }
}