use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Convenience alias for a `Result` whose error type is [`RerankingError`].
pub type RerankingResult<T> = std::result::Result<T, RerankingError>;
/// Error type carried by [`RerankingResult`].
///
/// The `Display` text of every variant is taken from its `#[error(...)]`
/// attribute. The enum derives `Clone`, `Serialize` and `Deserialize`, so all
/// payloads are plain owned data (`String`, `usize`, `u64`).
///
/// NOTE(review): the code that constructs these variants is not in this file
/// section, so the per-variant descriptions below restate the message
/// templates rather than the exact conditions under which each is raised.
#[derive(Debug, Error, Clone, Serialize, Deserialize)]
pub enum RerankingError {
/// The model identified by `model_name` is not loaded.
#[error("Model not loaded: {model_name}")]
ModelNotLoaded { model_name: String },
/// A model-related failure described by `message`.
#[error("Model error: {message}")]
ModelError { message: String },
/// An invalid configuration, described by `message`.
#[error("Invalid configuration: {message}")]
InvalidConfiguration { message: String },
/// A batch of `size` exceeded the limit `max` (rendered as `size > max`).
#[error("Batch size exceeded: {size} > {max}")]
BatchSizeExceeded { size: usize, max: usize },
/// A cache-related failure described by `message`.
#[error("Cache error: {message}")]
CacheError { message: String },
/// A score-fusion failure described by `message`.
#[error("Score fusion error: {message}")]
FusionError { message: String },
/// An API failure described by `message`.
#[error("API error: {message}")]
ApiError { message: String },
/// An operation took longer than `timeout_ms` milliseconds.
#[error("Timeout: operation took longer than {timeout_ms}ms")]
Timeout { timeout_ms: u64 },
/// A backend failure described by `message`.
#[error("Backend error: {message}")]
BackendError { message: String },
/// An internal failure described by `message`.
#[error("Internal error: {message}")]
InternalError { message: String },
}
/// A candidate identified by `id` that carries a retrieval score and,
/// optionally, a reranking score.
///
/// All fields are public, so they can be assigned directly as well as through
/// `ScoredCandidate::new` and the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoredCandidate {
/// Identifier of the candidate.
pub id: String,
/// Score passed to `new`.
/// NOTE(review): presumably produced by the initial retrieval stage — confirm with callers.
pub retrieval_score: f32,
/// `None` as constructed by `new`; `with_reranking_score` stores `Some(score)` here.
pub reranking_score: Option<f32>,
/// Initialised by `new` to `retrieval_score` and overwritten by
/// `with_reranking_score` with the reranking score.
pub final_score: f32,
/// Optional text; `None` as constructed by `new`, set by `with_content`.
pub content: Option<String>,
/// String key/value pairs; empty as constructed by `new`, extended by `with_metadata`.
pub metadata: std::collections::HashMap<String, String>,
/// Rank passed to `new`.
/// NOTE(review): presumably the candidate's position before reranking — confirm with callers.
pub original_rank: usize,
}
impl ScoredCandidate {
    /// Creates a candidate that has no reranking score yet.
    ///
    /// * `id` – identifier, converted into an owned `String`.
    /// * `retrieval_score` – stored as-is and also used as the initial `final_score`.
    /// * `original_rank` – stored as-is.
    ///
    /// `reranking_score` and `content` start as `None`; `metadata` starts empty.
    #[must_use]
    pub fn new(id: impl Into<String>, retrieval_score: f32, original_rank: usize) -> Self {
        Self {
            id: id.into(),
            retrieval_score,
            reranking_score: None,
            final_score: retrieval_score,
            content: None,
            metadata: std::collections::HashMap::new(),
            original_rank,
        }
    }

    /// Stores `score` as the reranking score and makes it the `final_score`.
    ///
    /// `retrieval_score` is left untouched. The candidate is consumed and
    /// returned, so the return value must be kept.
    #[must_use = "builder methods consume the candidate and return the updated value"]
    pub fn with_reranking_score(mut self, score: f32) -> Self {
        self.reranking_score = Some(score);
        self.final_score = score;
        self
    }

    /// Sets `content`, replacing any previous value.
    ///
    /// The candidate is consumed and returned, so the return value must be kept.
    #[must_use = "builder methods consume the candidate and return the updated value"]
    pub fn with_content(mut self, content: impl Into<String>) -> Self {
        self.content = Some(content.into());
        self
    }

    /// Inserts one `key`/`value` pair into `metadata`.
    ///
    /// An existing entry under the same key is overwritten (`HashMap::insert`
    /// semantics). The candidate is consumed and returned, so the return value
    /// must be kept.
    #[must_use = "builder methods consume the candidate and return the updated value"]
    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Returns the reranking score when one is present, otherwise the
    /// retrieval score.
    ///
    /// Only `reranking_score` and `retrieval_score` are consulted;
    /// `final_score` is never read here, so assigning `final_score` directly
    /// does not change this method's result.
    #[must_use]
    pub fn effective_score(&self) -> f32 {
        self.reranking_score.unwrap_or(self.retrieval_score)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed candidate mirrors its inputs and has no reranking score.
    #[test]
    fn test_scored_candidate_creation() {
        let fresh = ScoredCandidate::new("doc1", 0.85, 0);

        assert!(fresh.reranking_score.is_none());
        assert_eq!(fresh.id, "doc1");
        assert_eq!(fresh.original_rank, 0);
        assert_eq!(fresh.retrieval_score, 0.85);
        assert_eq!(fresh.final_score, 0.85);
    }

    /// Attaching a reranking score updates the final and effective scores but
    /// leaves the retrieval score alone.
    #[test]
    fn test_candidate_with_reranking() {
        let base = ScoredCandidate::new("doc1", 0.85, 0);
        let reranked = base.with_reranking_score(0.92);

        assert_eq!(reranked.reranking_score, Some(0.92));
        assert_eq!(reranked.final_score, 0.92);
        assert_eq!(reranked.effective_score(), 0.92);
        assert_eq!(reranked.retrieval_score, 0.85);
    }

    /// `with_content` stores the supplied text.
    #[test]
    fn test_candidate_with_content() {
        let with_text = ScoredCandidate::new("doc1", 0.85, 0).with_content("Test document");

        assert_eq!(with_text.content.as_deref(), Some("Test document"));
    }

    /// Every key added through `with_metadata` can be read back.
    #[test]
    fn test_candidate_with_metadata() {
        let tagged = ScoredCandidate::new("doc1", 0.85, 0)
            .with_metadata("source", "wikipedia")
            .with_metadata("lang", "en");

        for (key, expected) in [("source", "wikipedia"), ("lang", "en")] {
            assert_eq!(tagged.metadata.get(key).map(String::as_str), Some(expected));
        }
    }

    /// `effective_score` falls back to the retrieval score until a reranking
    /// score is present, even when the field is set directly.
    #[test]
    fn test_effective_score() {
        let plain = ScoredCandidate::new("doc1", 0.85, 0);
        assert_eq!(plain.effective_score(), 0.85);

        // Set the field directly (not via the builder) so `final_score` stays untouched.
        let reranked = ScoredCandidate {
            reranking_score: Some(0.92),
            ..plain
        };
        assert_eq!(reranked.effective_score(), 0.92);
    }

    /// The rendered error text includes the payload values.
    #[test]
    fn test_error_display() {
        let cases: Vec<(RerankingError, Vec<&str>)> = vec![
            (
                RerankingError::ModelNotLoaded {
                    model_name: "cross-encoder-ms-marco".to_string(),
                },
                vec!["cross-encoder-ms-marco"],
            ),
            (
                RerankingError::BatchSizeExceeded { size: 100, max: 50 },
                vec!["100", "50"],
            ),
            (RerankingError::Timeout { timeout_ms: 5000 }, vec!["5000"]),
        ];

        for (err, fragments) in cases {
            let rendered = err.to_string();
            for fragment in fragments {
                assert!(rendered.contains(fragment));
            }
        }
    }
}