oar_ocr/utils/topk.rs
1//! Top-k classification result processing.
2
3use std::collections::HashMap;
4
/// Result structure for top-k classification processing.
///
/// Contains the top-k class indexes and their corresponding confidence scores
/// for each prediction in a batch. `indexes[i]` and `scores[i]` always have the
/// same length and are aligned element-by-element.
#[derive(Debug, Clone, PartialEq)]
pub struct TopkResult {
    /// Class indexes for each prediction, one inner vector per prediction.
    /// Each inner vector lists the top-k class indexes, highest score first.
    pub indexes: Vec<Vec<usize>>,
    /// Confidence scores for each prediction, one inner vector per prediction.
    /// Each inner vector holds the scores matching `indexes` position-by-position.
    pub scores: Vec<Vec<f32>>,
    /// Optional class names for each prediction, aligned with `indexes`.
    /// Only populated if a class name mapping is provided.
    pub class_names: Option<Vec<Vec<String>>>,
}
21
/// A processor for extracting top-k results from classification outputs.
///
/// The `Topk` struct processes classification model outputs to extract the
/// top-k most confident predictions along with their class names (if available).
#[derive(Debug, Clone)]
pub struct Topk {
    /// Optional mapping from class IDs to class names.
    class_id_map: Option<HashMap<usize, String>>,
}
31
32impl Topk {
33 /// Creates a new Topk processor with optional class name mapping.
34 ///
35 /// # Arguments
36 ///
37 /// * `class_id_map` - Optional mapping from class IDs to human-readable class names.
38 ///
39 /// # Examples
40 ///
41 /// ```rust,no_run
42 /// use std::collections::HashMap;
43 /// use oar_ocr::utils::topk::Topk;
44 ///
45 /// let mut class_map = HashMap::new();
46 /// class_map.insert(0, "cat".to_string());
47 /// class_map.insert(1, "dog".to_string());
48 ///
49 /// let topk = Topk::new(Some(class_map));
50 /// ```
51 pub fn new(class_id_map: Option<HashMap<usize, String>>) -> Self {
52 Self { class_id_map }
53 }
54
55 /// Creates a new Topk processor without class name mapping.
56 ///
57 /// # Examples
58 ///
59 /// ```rust,no_run
60 /// use oar_ocr::utils::topk::Topk;
61 ///
62 /// let topk = Topk::without_class_names();
63 /// ```
64 pub fn without_class_names() -> Self {
65 Self::new(None)
66 }
67
68 /// Creates a new Topk processor with class names from a vector.
69 ///
70 /// The vector index corresponds to the class ID.
71 ///
72 /// # Arguments
73 ///
74 /// * `class_names` - Vector of class names where index = class ID.
75 ///
76 /// # Examples
77 ///
78 /// ```rust,no_run
79 /// use oar_ocr::utils::topk::Topk;
80 ///
81 /// let class_names = vec!["cat".to_string(), "dog".to_string(), "bird".to_string()];
82 /// let topk = Topk::from_class_names(class_names);
83 /// ```
84 pub fn from_class_names(class_names: Vec<String>) -> Self {
85 let class_id_map: HashMap<usize, String> = class_names.into_iter().enumerate().collect();
86 Self::new(Some(class_id_map))
87 }
88
89 /// Processes classification outputs to extract top-k results.
90 ///
91 /// # Arguments
92 ///
93 /// * `predictions` - 2D vector where each inner vector contains the confidence
94 /// scores for all classes for one prediction.
95 /// * `k` - Number of top predictions to extract (must be > 0).
96 ///
97 /// # Returns
98 ///
99 /// * `Ok(TopkResult)` - The top-k results with indexes, scores, and optional class names.
100 /// * `Err(String)` - If k is 0 or if the input is invalid.
101 ///
102 /// # Examples
103 ///
104 /// ```rust,no_run
105 /// use oar_ocr::utils::topk::Topk;
106 ///
107 /// let topk = Topk::without_class_names();
108 /// let predictions = vec![
109 /// vec![0.1, 0.8, 0.1], // Prediction 1: class 1 has highest score
110 /// vec![0.7, 0.2, 0.1], // Prediction 2: class 0 has highest score
111 /// ];
112 /// let result = topk.process(&predictions, 2).unwrap();
113 /// ```
114 pub fn process(&self, predictions: &[Vec<f32>], k: usize) -> Result<TopkResult, String> {
115 if k == 0 {
116 return Err("k must be greater than 0".to_string());
117 }
118
119 if predictions.is_empty() {
120 return Ok(TopkResult {
121 indexes: vec![],
122 scores: vec![],
123 class_names: None,
124 });
125 }
126
127 let mut all_indexes = Vec::new();
128 let mut all_scores = Vec::new();
129 let mut all_class_names = if self.class_id_map.is_some() {
130 Some(Vec::new())
131 } else {
132 None
133 };
134
135 for prediction in predictions {
136 if prediction.is_empty() {
137 return Err("Empty prediction vector".to_string());
138 }
139
140 let effective_k = k.min(prediction.len());
141 let (top_indexes, top_scores) =
142 self.extract_topk_from_prediction(prediction, effective_k);
143
144 all_indexes.push(top_indexes.clone());
145 all_scores.push(top_scores);
146
147 // Add class names if mapping is available
148 if let Some(ref mut class_names_vec) = all_class_names {
149 let names = self.map_indexes_to_names(&top_indexes);
150 class_names_vec.push(names);
151 }
152 }
153
154 Ok(TopkResult {
155 indexes: all_indexes,
156 scores: all_scores,
157 class_names: all_class_names,
158 })
159 }
160
161 /// Extracts top-k indexes and scores from a single prediction.
162 ///
163 /// # Arguments
164 ///
165 /// * `prediction` - Vector of confidence scores for all classes.
166 /// * `k` - Number of top predictions to extract.
167 ///
168 /// # Returns
169 ///
170 /// * `(Vec<usize>, Vec<f32>)` - Tuple of (top_indexes, top_scores).
171 fn extract_topk_from_prediction(&self, prediction: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
172 // Create pairs of (index, score) and sort by score in descending order
173 let mut indexed_scores: Vec<(usize, f32)> = prediction
174 .iter()
175 .enumerate()
176 .map(|(idx, &score)| (idx, score))
177 .collect();
178
179 // Sort by score in descending order
180 indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181
182 // Take top k
183 let top_k: Vec<(usize, f32)> = indexed_scores.into_iter().take(k).collect();
184
185 // Separate indexes and scores
186 let (indexes, scores): (Vec<usize>, Vec<f32>) = top_k.into_iter().unzip();
187
188 (indexes, scores)
189 }
190
191 /// Maps class indexes to class names using the internal mapping.
192 ///
193 /// # Arguments
194 ///
195 /// * `indexes` - Vector of class indexes.
196 ///
197 /// # Returns
198 ///
199 /// * `Vec<String>` - Vector of class names. Unknown indexes are mapped to "Unknown".
200 fn map_indexes_to_names(&self, indexes: &[usize]) -> Vec<String> {
201 if let Some(ref class_map) = self.class_id_map {
202 indexes
203 .iter()
204 .map(|&idx| {
205 class_map
206 .get(&idx)
207 .cloned()
208 .unwrap_or_else(|| format!("Unknown({})", idx))
209 })
210 .collect()
211 } else {
212 indexes.iter().map(|&idx| idx.to_string()).collect()
213 }
214 }
215
216 /// Gets the class name for a given class ID.
217 ///
218 /// # Arguments
219 ///
220 /// * `class_id` - The class ID to look up.
221 ///
222 /// # Returns
223 ///
224 /// * `Option<&String>` - The class name if available.
225 pub fn get_class_name(&self, class_id: usize) -> Option<&String> {
226 self.class_id_map.as_ref()?.get(&class_id)
227 }
228
229 /// Checks if class name mapping is available.
230 ///
231 /// # Returns
232 ///
233 /// * `true` - If class name mapping is available.
234 /// * `false` - If no class name mapping is available.
235 pub fn has_class_names(&self) -> bool {
236 self.class_id_map.is_some()
237 }
238
239 /// Gets the number of classes in the mapping.
240 ///
241 /// # Returns
242 ///
243 /// * `Option<usize>` - Number of classes if mapping is available.
244 pub fn num_classes(&self) -> Option<usize> {
245 self.class_id_map.as_ref().map(|map| map.len())
246 }
247
248 /// Updates the class name mapping.
249 ///
250 /// # Arguments
251 ///
252 /// * `class_id_map` - New class ID to name mapping.
253 pub fn set_class_mapping(&mut self, class_id_map: Option<HashMap<usize, String>>) {
254 self.class_id_map = class_id_map;
255 }
256
257 /// Processes a single prediction vector.
258 ///
259 /// # Arguments
260 ///
261 /// * `prediction` - Vector of confidence scores for all classes.
262 /// * `k` - Number of top predictions to extract.
263 ///
264 /// # Returns
265 ///
266 /// * `Ok(TopkResult)` - The top-k results for the single prediction.
267 /// * `Err(String)` - If k is 0 or if the input is invalid.
268 pub fn process_single(&self, prediction: &[f32], k: usize) -> Result<TopkResult, String> {
269 self.process(&[prediction.to_vec()], k)
270 }
271}
272
273impl Default for Topk {
274 /// Creates a default Topk processor without class name mapping.
275 fn default() -> Self {
276 Self::without_class_names()
277 }
278}
279
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_topk_without_class_names() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8, 0.1], vec![0.7, 0.2, 0.1]];

        let result = topk.process(&predictions, 2).unwrap();
        assert_eq!(result.indexes.len(), 2);
        assert_eq!(result.indexes[0], vec![1, 0]); // Class 1 (0.8), Class 0 (0.1)
        assert_eq!(result.indexes[1], vec![0, 1]); // Class 0 (0.7), Class 1 (0.2)
        assert_eq!(result.scores[0], vec![0.8_f32, 0.1]);
        assert_eq!(result.scores[1], vec![0.7_f32, 0.2]);
        assert!(result.class_names.is_none());
    }

    #[test]
    fn test_topk_with_class_names() {
        let mut class_map = HashMap::new();
        class_map.insert(0, "cat".to_string());
        class_map.insert(1, "dog".to_string());
        class_map.insert(2, "bird".to_string());

        let topk = Topk::new(Some(class_map));
        let predictions = vec![vec![0.1, 0.8, 0.1]];

        let result = topk.process(&predictions, 2).unwrap();
        assert_eq!(result.indexes[0], vec![1, 0]);
        assert_eq!(result.class_names.as_ref().unwrap()[0], vec!["dog", "cat"]);
    }

    #[test]
    fn test_topk_unknown_class_id() {
        // Class 1 is missing from the mapping and must be rendered as "Unknown(1)".
        let mut class_map = HashMap::new();
        class_map.insert(0, "cat".to_string());

        let topk = Topk::new(Some(class_map));
        let result = topk.process(&[vec![0.2, 0.9]], 2).unwrap();
        assert_eq!(result.class_names.unwrap()[0], vec!["Unknown(1)", "cat"]);
    }

    #[test]
    fn test_topk_from_class_names() {
        let class_names = vec!["cat".to_string(), "dog".to_string(), "bird".to_string()];
        let topk = Topk::from_class_names(class_names);

        assert!(topk.has_class_names());
        assert_eq!(topk.num_classes(), Some(3));
        assert_eq!(topk.get_class_name(0), Some(&"cat".to_string()));
        assert_eq!(topk.get_class_name(7), None);
    }

    #[test]
    fn test_topk_k_larger_than_classes() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8]]; // Only 2 classes

        let result = topk.process(&predictions, 5).unwrap(); // Ask for 5
        assert_eq!(result.indexes[0].len(), 2); // Should only get 2
        assert_eq!(result.indexes[0], vec![1, 0]);
    }

    #[test]
    fn test_topk_invalid_k() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8, 0.1]];

        assert!(topk.process(&predictions, 0).is_err());
        assert!(topk.process_single(&predictions[0], 0).is_err());
    }

    #[test]
    fn test_topk_empty_prediction_row() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.5], vec![]];

        assert!(topk.process(&predictions, 1).is_err());
        assert!(topk.process_single(&[], 1).is_err());
    }

    #[test]
    fn test_topk_empty_predictions() {
        let topk = Topk::without_class_names();
        let predictions = vec![];

        let result = topk.process(&predictions, 2).unwrap();
        assert!(result.indexes.is_empty());
        assert!(result.scores.is_empty());
    }

    #[test]
    fn test_process_single() {
        let topk = Topk::without_class_names();
        let prediction = vec![0.1, 0.8, 0.1];

        let result = topk.process_single(&prediction, 2).unwrap();
        assert_eq!(result.indexes.len(), 1);
        assert_eq!(result.indexes[0], vec![1, 0]);
        assert_eq!(result.scores[0], vec![0.8_f32, 0.1]);
    }
}