// oxirs_vec/reranking/types.rs
use serde::{Deserialize, Serialize};
use thiserror::Error;

6pub type RerankingResult<T> = std::result::Result<T, RerankingError>;
8
9#[derive(Debug, Error, Clone, Serialize, Deserialize)]
11pub enum RerankingError {
12 #[error("Model not loaded: {model_name}")]
13 ModelNotLoaded { model_name: String },
14
15 #[error("Model error: {message}")]
16 ModelError { message: String },
17
18 #[error("Invalid configuration: {message}")]
19 InvalidConfiguration { message: String },
20
21 #[error("Batch size exceeded: {size} > {max}")]
22 BatchSizeExceeded { size: usize, max: usize },
23
24 #[error("Cache error: {message}")]
25 CacheError { message: String },
26
27 #[error("Score fusion error: {message}")]
28 FusionError { message: String },
29
30 #[error("API error: {message}")]
31 ApiError { message: String },
32
33 #[error("Timeout: operation took longer than {timeout_ms}ms")]
34 Timeout { timeout_ms: u64 },
35
36 #[error("Backend error: {message}")]
37 BackendError { message: String },
38
39 #[error("Internal error: {message}")]
40 InternalError { message: String },
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ScoredCandidate {
46 pub id: String,
48
49 pub retrieval_score: f32,
51
52 pub reranking_score: Option<f32>,
54
55 pub final_score: f32,
57
58 pub content: Option<String>,
60
61 pub metadata: std::collections::HashMap<String, String>,
63
64 pub original_rank: usize,
66}
67
68impl ScoredCandidate {
69 pub fn new(id: impl Into<String>, retrieval_score: f32, original_rank: usize) -> Self {
71 Self {
72 id: id.into(),
73 retrieval_score,
74 reranking_score: None,
75 final_score: retrieval_score,
76 content: None,
77 metadata: std::collections::HashMap::new(),
78 original_rank,
79 }
80 }
81
82 pub fn with_reranking_score(mut self, score: f32) -> Self {
84 self.reranking_score = Some(score);
85 self.final_score = score;
86 self
87 }
88
89 pub fn with_content(mut self, content: impl Into<String>) -> Self {
91 self.content = Some(content.into());
92 self
93 }
94
95 pub fn with_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
97 self.metadata.insert(key.into(), value.into());
98 self
99 }
100
101 pub fn effective_score(&self) -> f32 {
103 self.reranking_score.unwrap_or(self.retrieval_score)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scored_candidate_creation() {
        let candidate = ScoredCandidate::new("doc1", 0.85, 0);
        assert_eq!(candidate.id, "doc1");
        assert_eq!(candidate.retrieval_score, 0.85);
        assert_eq!(candidate.final_score, 0.85);
        assert_eq!(candidate.original_rank, 0);
        assert!(candidate.reranking_score.is_none());
        assert!(candidate.content.is_none());
        assert!(candidate.metadata.is_empty());
    }

    #[test]
    fn test_candidate_with_reranking() {
        let candidate = ScoredCandidate::new("doc1", 0.85, 3).with_reranking_score(0.92);

        assert_eq!(candidate.retrieval_score, 0.85);
        assert_eq!(candidate.reranking_score, Some(0.92));
        assert_eq!(candidate.final_score, 0.92);
        assert_eq!(candidate.effective_score(), 0.92);
        // Reranking must not disturb the rank the candidate was created with.
        assert_eq!(candidate.original_rank, 3);
    }

    #[test]
    fn test_candidate_with_content() {
        let candidate = ScoredCandidate::new("doc1", 0.85, 0).with_content("Test document");

        assert_eq!(candidate.content.as_deref(), Some("Test document"));
    }

    #[test]
    fn test_candidate_with_metadata() {
        let candidate = ScoredCandidate::new("doc1", 0.85, 0)
            .with_metadata("source", "wikipedia")
            .with_metadata("lang", "en");

        assert_eq!(
            candidate.metadata.get("source").map(String::as_str),
            Some("wikipedia")
        );
        assert_eq!(candidate.metadata.get("lang").map(String::as_str), Some("en"));
    }

    #[test]
    fn test_candidate_metadata_same_key_overwrites() {
        let candidate = ScoredCandidate::new("doc1", 0.85, 0)
            .with_metadata("lang", "en")
            .with_metadata("lang", "de");

        assert_eq!(candidate.metadata.len(), 1);
        assert_eq!(candidate.metadata.get("lang").map(String::as_str), Some("de"));
    }

    #[test]
    fn test_effective_score() {
        let mut candidate = ScoredCandidate::new("doc1", 0.85, 0);
        assert_eq!(candidate.effective_score(), 0.85);

        candidate.reranking_score = Some(0.92);
        assert_eq!(candidate.effective_score(), 0.92);
    }

    #[test]
    fn test_effective_score_ignores_final_score() {
        // Pins current behaviour: `effective_score` reads only
        // `reranking_score` / `retrieval_score`, never `final_score`.
        let mut candidate = ScoredCandidate::new("doc1", 0.85, 0);
        candidate.final_score = 0.5;
        assert_eq!(candidate.effective_score(), 0.85);
    }

    #[test]
    fn test_error_display() {
        let msg = || "boom".to_string();
        // Exact messages, so a change to any `#[error(...)]` format is caught.
        let cases = [
            (
                RerankingError::ModelNotLoaded {
                    model_name: "cross-encoder-ms-marco".to_string(),
                },
                "Model not loaded: cross-encoder-ms-marco",
            ),
            (
                RerankingError::ModelError { message: msg() },
                "Model error: boom",
            ),
            (
                RerankingError::InvalidConfiguration { message: msg() },
                "Invalid configuration: boom",
            ),
            (
                RerankingError::BatchSizeExceeded { size: 100, max: 50 },
                "Batch size exceeded: 100 > 50",
            ),
            (
                RerankingError::CacheError { message: msg() },
                "Cache error: boom",
            ),
            (
                RerankingError::FusionError { message: msg() },
                "Score fusion error: boom",
            ),
            (RerankingError::ApiError { message: msg() }, "API error: boom"),
            (
                RerankingError::Timeout { timeout_ms: 5000 },
                "Timeout: operation took longer than 5000ms",
            ),
            (
                RerankingError::BackendError { message: msg() },
                "Backend error: boom",
            ),
            (
                RerankingError::InternalError { message: msg() },
                "Internal error: boom",
            ),
        ];

        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
            // `Clone` must preserve the payload that feeds the message.
            assert_eq!(err.clone().to_string(), *expected);
        }
    }
}