oar_ocr_core/utils/topk.rs
1//! Top-k classification result processing.
2
3use crate::core::OcrResult;
4use crate::core::errors::OCRError;
5use std::collections::HashMap;
6
/// Result of top-k extraction over a batch of classification outputs.
///
/// The three collections are parallel: entry `i` of each describes the `i`-th
/// prediction of the batch, and within an entry the elements are ordered from
/// the highest score to the lowest.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct TopkResult {
    /// Top-k class indexes, one inner vector per prediction.
    pub indexes: Vec<Vec<usize>>,
    /// Top-k confidence scores, one inner vector per prediction.
    /// `scores[i][j]` is the score of class `indexes[i][j]`.
    pub scores: Vec<Vec<f32>>,
    /// Class names matching `indexes`, one inner vector per prediction.
    /// `None` when no class name mapping was available to the processor.
    pub class_names: Option<Vec<Vec<String>>>,
}
23
/// Extracts the top-k most confident classes from classification outputs.
///
/// When a class ID → name mapping is configured, results also carry the
/// human-readable name of every selected class.
#[derive(Debug, Clone)]
pub struct Topk {
    /// Optional mapping from class IDs to class names.
    /// `None` means the processor reports indexes and scores only.
    class_id_map: Option<HashMap<usize, String>>,
}
33
34impl Topk {
35 /// Creates a new Topk processor with optional class name mapping.
36 ///
37 /// # Arguments
38 ///
39 /// * `class_id_map` - Optional mapping from class IDs to human-readable class names.
40 ///
41 /// # Examples
42 ///
43 /// ```rust,no_run
44 /// use std::collections::HashMap;
45 /// use oar_ocr_core::utils::topk::Topk;
46 ///
47 /// let mut class_map = HashMap::new();
48 /// class_map.insert(0, "cat".to_string());
49 /// class_map.insert(1, "dog".to_string());
50 ///
51 /// let topk = Topk::new(Some(class_map));
52 /// ```
53 pub fn new(class_id_map: Option<HashMap<usize, String>>) -> Self {
54 Self { class_id_map }
55 }
56
57 /// Creates a new Topk processor without class name mapping.
58 ///
59 /// # Examples
60 ///
61 /// ```rust,no_run
62 /// use oar_ocr_core::utils::topk::Topk;
63 ///
64 /// let topk = Topk::without_class_names();
65 /// ```
66 pub fn without_class_names() -> Self {
67 Self::new(None)
68 }
69
70 /// Creates a new Topk processor with class names from a vector.
71 ///
72 /// The vector index corresponds to the class ID.
73 ///
74 /// # Arguments
75 ///
76 /// * `class_names` - Vector of class names where index = class ID.
77 ///
78 /// # Examples
79 ///
80 /// ```rust,no_run
81 /// use oar_ocr_core::utils::topk::Topk;
82 ///
83 /// let class_names = vec!["cat".to_string(), "dog".to_string(), "bird".to_string()];
84 /// let topk = Topk::from_class_names(class_names);
85 /// ```
86 pub fn from_class_names(class_names: Vec<String>) -> Self {
87 let class_id_map: HashMap<usize, String> = class_names.into_iter().enumerate().collect();
88 Self::new(Some(class_id_map))
89 }
90
91 /// Processes classification outputs to extract top-k results.
92 ///
93 /// # Arguments
94 ///
95 /// * `predictions` - 2D vector where each inner vector contains the confidence
96 /// scores for all classes for one prediction.
97 /// * `k` - Number of top predictions to extract (must be > 0).
98 ///
99 /// # Returns
100 ///
101 /// * `Ok(TopkResult)` - The top-k results with indexes, scores, and optional class names.
102 /// * `Err(OCRError)` - If k is 0 or if the input is invalid.
103 ///
104 /// # Examples
105 ///
106 /// ```rust,no_run
107 /// use oar_ocr_core::utils::topk::Topk;
108 ///
109 /// # fn main() -> Result<(), oar_ocr_core::core::OCRError> {
110 /// let topk = Topk::without_class_names();
111 /// let predictions = vec![
112 /// vec![0.1, 0.8, 0.1], // Prediction 1: class 1 has highest score
113 /// vec![0.7, 0.2, 0.1], // Prediction 2: class 0 has highest score
114 /// ];
115 /// let result = topk.process(&predictions, 2)?;
116 /// # let _ = result;
117 /// # Ok(())
118 /// # }
119 /// ```
120 pub fn process(&self, predictions: &[Vec<f32>], k: usize) -> OcrResult<TopkResult> {
121 if k == 0 {
122 return Err(OCRError::InvalidInput {
123 message: "k must be greater than 0".to_string(),
124 });
125 }
126
127 if predictions.is_empty() {
128 return Ok(TopkResult {
129 indexes: vec![],
130 scores: vec![],
131 class_names: None,
132 });
133 }
134
135 let mut all_indexes = Vec::new();
136 let mut all_scores = Vec::new();
137 let mut all_class_names = if self.class_id_map.is_some() {
138 Some(Vec::new())
139 } else {
140 None
141 };
142
143 for prediction in predictions {
144 if prediction.is_empty() {
145 return Err(OCRError::InvalidInput {
146 message: "Empty prediction vector".to_string(),
147 });
148 }
149
150 let effective_k = k.min(prediction.len());
151 let (top_indexes, top_scores) =
152 self.extract_topk_from_prediction(prediction, effective_k);
153
154 all_indexes.push(top_indexes.clone());
155 all_scores.push(top_scores);
156
157 // Add class names if mapping is available
158 if let Some(ref mut class_names_vec) = all_class_names {
159 let names = self.map_indexes_to_names(&top_indexes);
160 class_names_vec.push(names);
161 }
162 }
163
164 Ok(TopkResult {
165 indexes: all_indexes,
166 scores: all_scores,
167 class_names: all_class_names,
168 })
169 }
170
171 /// Extracts top-k indexes and scores from a single prediction.
172 ///
173 /// # Arguments
174 ///
175 /// * `prediction` - Vector of confidence scores for all classes.
176 /// * `k` - Number of top predictions to extract.
177 ///
178 /// # Returns
179 ///
180 /// * `(Vec<usize>, Vec<f32>)` - Tuple of (top_indexes, top_scores).
181 fn extract_topk_from_prediction(&self, prediction: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
182 // Create pairs of (index, score) and sort by score in descending order
183 let mut indexed_scores: Vec<(usize, f32)> = prediction
184 .iter()
185 .enumerate()
186 .map(|(idx, &score)| (idx, score))
187 .collect();
188
189 // Sort by score in descending order
190 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
191
192 // Take top k
193 let top_k: Vec<(usize, f32)> = indexed_scores.into_iter().take(k).collect();
194
195 // Separate indexes and scores
196 let (indexes, scores): (Vec<usize>, Vec<f32>) = top_k.into_iter().unzip();
197
198 (indexes, scores)
199 }
200
201 /// Maps class indexes to class names using the internal mapping.
202 ///
203 /// # Arguments
204 ///
205 /// * `indexes` - Vector of class indexes.
206 ///
207 /// # Returns
208 ///
209 /// * `Vec<String>` - Vector of class names. Unknown indexes are mapped to "Unknown".
210 fn map_indexes_to_names(&self, indexes: &[usize]) -> Vec<String> {
211 if let Some(ref class_map) = self.class_id_map {
212 indexes
213 .iter()
214 .map(|&idx| {
215 class_map
216 .get(&idx)
217 .cloned()
218 .unwrap_or_else(|| format!("Unknown({})", idx))
219 })
220 .collect()
221 } else {
222 indexes.iter().map(|&idx| idx.to_string()).collect()
223 }
224 }
225
226 /// Gets the class name for a given class ID.
227 ///
228 /// # Arguments
229 ///
230 /// * `class_id` - The class ID to look up.
231 ///
232 /// # Returns
233 ///
234 /// * `Option<&String>` - The class name if available.
235 pub fn get_class_name(&self, class_id: usize) -> Option<&String> {
236 self.class_id_map.as_ref()?.get(&class_id)
237 }
238
239 /// Checks if class name mapping is available.
240 ///
241 /// # Returns
242 ///
243 /// * `true` - If class name mapping is available.
244 /// * `false` - If no class name mapping is available.
245 pub fn has_class_names(&self) -> bool {
246 self.class_id_map.is_some()
247 }
248
249 /// Gets the number of classes in the mapping.
250 ///
251 /// # Returns
252 ///
253 /// * `Option<usize>` - Number of classes if mapping is available.
254 pub fn num_classes(&self) -> Option<usize> {
255 self.class_id_map.as_ref().map(|map| map.len())
256 }
257
258 /// Updates the class name mapping.
259 ///
260 /// # Arguments
261 ///
262 /// * `class_id_map` - New class ID to name mapping.
263 pub fn set_class_mapping(&mut self, class_id_map: Option<HashMap<usize, String>>) {
264 self.class_id_map = class_id_map;
265 }
266
267 /// Processes a single prediction vector.
268 ///
269 /// # Arguments
270 ///
271 /// * `prediction` - Vector of confidence scores for all classes.
272 /// * `k` - Number of top predictions to extract.
273 ///
274 /// # Returns
275 ///
276 /// * `Ok(TopkResult)` - The top-k results for the single prediction.
277 /// * `Err(OCRError)` - If k is 0 or if the input is invalid.
278 pub fn process_single(&self, prediction: &[f32], k: usize) -> OcrResult<TopkResult> {
279 self.process(&[prediction.to_vec()], k)
280 }
281}
282
283impl Default for Topk {
284 /// Creates a default Topk processor without class name mapping.
285 fn default() -> Self {
286 Self::without_class_names()
287 }
288}
289
#[cfg(test)]
mod tests {
    //! Unit tests for the top-k processor.

    use super::*;

    #[test]
    fn test_topk_without_class_names() -> Result<(), OCRError> {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8, 0.1], vec![0.7, 0.2, 0.1]];

        let result = topk.process(&predictions, 2)?;
        assert_eq!(result.indexes.len(), 2);
        // Classes 0 and 2 tie at 0.1; the lower index wins the second slot.
        assert_eq!(result.indexes[0], vec![1, 0]); // Class 1 (0.8), Class 0 (0.1)
        assert_eq!(result.indexes[1], vec![0, 1]); // Class 0 (0.7), Class 1 (0.2)
        assert_eq!(result.scores[0], vec![0.8, 0.1]);
        assert_eq!(result.scores[1], vec![0.7, 0.2]);
        assert!(result.class_names.is_none());
        Ok(())
    }

    #[test]
    fn test_topk_with_class_names() -> Result<(), OCRError> {
        let mut class_map = HashMap::new();
        class_map.insert(0, "cat".to_string());
        class_map.insert(1, "dog".to_string());
        class_map.insert(2, "bird".to_string());

        let topk = Topk::new(Some(class_map));
        let predictions = vec![vec![0.1, 0.8, 0.1]];

        let result = topk.process(&predictions, 2)?;
        assert_eq!(result.indexes[0], vec![1, 0]);
        let Some(class_names) = result.class_names.as_ref() else {
            panic!("expected class_names to be present");
        };
        assert_eq!(class_names[0], vec!["dog", "cat"]);
        Ok(())
    }

    #[test]
    fn test_topk_unknown_class_name() -> Result<(), OCRError> {
        // Only class 0 is named; class 1 must fall back to the placeholder.
        let topk = Topk::from_class_names(vec!["cat".to_string()]);
        let predictions = vec![vec![0.2, 0.9]];

        let result = topk.process(&predictions, 2)?;
        assert_eq!(result.indexes[0], vec![1, 0]);
        let Some(class_names) = result.class_names.as_ref() else {
            panic!("expected class_names to be present");
        };
        assert_eq!(class_names[0], vec!["Unknown(1)", "cat"]);
        Ok(())
    }

    #[test]
    fn test_topk_from_class_names() {
        let class_names = vec!["cat".to_string(), "dog".to_string(), "bird".to_string()];
        let topk = Topk::from_class_names(class_names);

        assert!(topk.has_class_names());
        assert_eq!(topk.num_classes(), Some(3));
        assert_eq!(topk.get_class_name(0), Some(&"cat".to_string()));
    }

    #[test]
    fn test_topk_k_larger_than_classes() -> Result<(), OCRError> {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8]]; // Only 2 classes

        let result = topk.process(&predictions, 5)?; // Ask for 5
        assert_eq!(result.indexes[0].len(), 2); // Should only get 2
        Ok(())
    }

    #[test]
    fn test_topk_invalid_k() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8, 0.1]];

        assert!(topk.process(&predictions, 0).is_err());
    }

    #[test]
    fn test_topk_empty_predictions() -> Result<(), OCRError> {
        let topk = Topk::without_class_names();
        let predictions: Vec<Vec<f32>> = Vec::new();

        let result = topk.process(&predictions, 2)?;
        assert!(result.indexes.is_empty());
        assert!(result.scores.is_empty());
        Ok(())
    }

    #[test]
    fn test_topk_empty_prediction_vector() {
        let topk = Topk::without_class_names();
        let predictions: Vec<Vec<f32>> = vec![vec![]];

        assert!(topk.process(&predictions, 1).is_err());
    }

    #[test]
    fn test_process_single() -> Result<(), OCRError> {
        let topk = Topk::without_class_names();
        let prediction = vec![0.1, 0.8, 0.1];

        let result = topk.process_single(&prediction, 2)?;
        assert_eq!(result.indexes.len(), 1);
        assert_eq!(result.indexes[0], vec![1, 0]);
        Ok(())
    }
}