oxirs_vec/reranking/
types.rs

1//! Core types for re-ranking
2
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5
6/// Result type for re-ranking operations
7pub type RerankingResult<T> = std::result::Result<T, RerankingError>;
8
9/// Errors that can occur during re-ranking
10#[derive(Debug, Error, Clone, Serialize, Deserialize)]
11pub enum RerankingError {
12    #[error("Model not loaded: {model_name}")]
13    ModelNotLoaded { model_name: String },
14
15    #[error("Model error: {message}")]
16    ModelError { message: String },
17
18    #[error("Invalid configuration: {message}")]
19    InvalidConfiguration { message: String },
20
21    #[error("Batch size exceeded: {size} > {max}")]
22    BatchSizeExceeded { size: usize, max: usize },
23
24    #[error("Cache error: {message}")]
25    CacheError { message: String },
26
27    #[error("Score fusion error: {message}")]
28    FusionError { message: String },
29
30    #[error("API error: {message}")]
31    ApiError { message: String },
32
33    #[error("Timeout: operation took longer than {timeout_ms}ms")]
34    Timeout { timeout_ms: u64 },
35
36    #[error("Backend error: {message}")]
37    BackendError { message: String },
38
39    #[error("Internal error: {message}")]
40    InternalError { message: String },
41}
42
43/// Candidate document with score
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ScoredCandidate {
46    /// Document ID
47    pub id: String,
48
49    /// Original retrieval score (from bi-encoder)
50    pub retrieval_score: f32,
51
52    /// Re-ranking score (from cross-encoder)
53    pub reranking_score: Option<f32>,
54
55    /// Fused final score
56    pub final_score: f32,
57
58    /// Document content (optional)
59    pub content: Option<String>,
60
61    /// Metadata
62    pub metadata: std::collections::HashMap<String, String>,
63
64    /// Original rank in retrieval results
65    pub original_rank: usize,
66}
67
68impl ScoredCandidate {
69    /// Create new candidate
70    pub fn new(id: impl Into<String>, retrieval_score: f32, original_rank: usize) -> Self {
71        Self {
72            id: id.into(),
73            retrieval_score,
74            reranking_score: None,
75            final_score: retrieval_score,
76            content: None,
77            metadata: std::collections::HashMap::new(),
78            original_rank,
79        }
80    }
81
82    /// Set re-ranking score
83    pub fn with_reranking_score(mut self, score: f32) -> Self {
84        self.reranking_score = Some(score);
85        self.final_score = score;
86        self
87    }
88
89    /// Set content
90    pub fn with_content(mut self, content: impl Into<String>) -> Self {
91        self.content = Some(content.into());
92        self
93    }
94
95    /// Add metadata
96    pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
97        self.metadata.insert(key.into(), value.into());
98        self
99    }
100
101    /// Get effective score (reranking if available, otherwise retrieval)
102    pub fn effective_score(&self) -> f32 {
103        self.reranking_score.unwrap_or(self.retrieval_score)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline candidate shared by most tests: id "doc1", score 0.85, rank 0.
    fn doc1() -> ScoredCandidate {
        ScoredCandidate::new("doc1", 0.85, 0)
    }

    #[test]
    fn test_scored_candidate_creation() {
        let c = doc1();
        assert_eq!(c.id, "doc1");
        assert_eq!(c.original_rank, 0);
        assert_eq!(c.reranking_score, None);
        // A fresh candidate's final score mirrors its retrieval score.
        assert_eq!((c.retrieval_score, c.final_score), (0.85, 0.85));
    }

    #[test]
    fn test_candidate_with_reranking() {
        let c = doc1().with_reranking_score(0.92);

        assert_eq!(c.retrieval_score, 0.85, "retrieval score must be preserved");
        assert_eq!(c.reranking_score, Some(0.92));
        assert_eq!(c.final_score, 0.92);
        assert_eq!(c.effective_score(), 0.92);
    }

    #[test]
    fn test_candidate_with_content() {
        let c = doc1().with_content("Test document");

        assert_eq!(c.content.as_deref(), Some("Test document"));
    }

    #[test]
    fn test_candidate_with_metadata() {
        let c = doc1()
            .with_metadata("source", "wikipedia")
            .with_metadata("lang", "en");
        let lookup = |key: &str| c.metadata.get(key).map(String::as_str);

        assert_eq!(lookup("source"), Some("wikipedia"));
        assert_eq!(lookup("lang"), Some("en"));
    }

    #[test]
    fn test_effective_score() {
        let mut c = doc1();
        assert_eq!(c.effective_score(), 0.85);

        // Once a re-ranking score exists it takes precedence.
        c.reranking_score = Some(0.92);
        assert_eq!(c.effective_score(), 0.92);
    }

    #[test]
    fn test_error_display() {
        // Each error paired with substrings its rendered message must contain.
        let cases = [
            (
                RerankingError::ModelNotLoaded {
                    model_name: "cross-encoder-ms-marco".to_string(),
                },
                vec!["cross-encoder-ms-marco"],
            ),
            (
                RerankingError::BatchSizeExceeded { size: 100, max: 50 },
                vec!["100", "50"],
            ),
            (RerankingError::Timeout { timeout_ms: 5000 }, vec!["5000"]),
        ];

        for (err, needles) in cases {
            let rendered = err.to_string();
            for needle in needles {
                assert!(
                    rendered.contains(needle),
                    "{:?} should mention {:?}",
                    rendered,
                    needle
                );
            }
        }
    }
}