oar_ocr/utils/
topk.rs

1//! Top-k classification result processing.
2
3use std::collections::HashMap;
4
/// Result structure for top-k classification processing.
///
/// Contains the top-k class indexes and their corresponding confidence scores
/// for each prediction in a batch. The three collections are parallel: entry
/// `i` of each one describes prediction `i` of the batch.
///
/// `Default` yields an empty result (no predictions, no class names).
/// Equality is element-wise; because scores are `f32`, a result that contains
/// a NaN score never compares equal, not even to itself.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct TopkResult {
    /// Vector of vectors containing the class indexes for each prediction.
    /// Each inner vector contains the top-k class indexes for one prediction.
    pub indexes: Vec<Vec<usize>>,
    /// Vector of vectors containing the confidence scores for each prediction.
    /// Each inner vector contains the top-k scores corresponding to the indexes.
    pub scores: Vec<Vec<f32>>,
    /// Optional vector of vectors containing class names for each prediction.
    /// Only populated if class name mapping is provided.
    pub class_names: Option<Vec<Vec<String>>>,
}
21
/// A processor for extracting top-k results from classification outputs.
///
/// The `Topk` struct processes classification model outputs to extract the
/// top-k most confident predictions along with their class names (if available).
///
/// The processor is `Clone`; cloning duplicates the class-name table, so each
/// copy can be reconfigured independently.
#[derive(Debug, Clone)]
pub struct Topk {
    /// Optional mapping from class IDs to class names.
    /// `None` means results carry no class names at all.
    class_id_map: Option<HashMap<usize, String>>,
}
31
32impl Topk {
33    /// Creates a new Topk processor with optional class name mapping.
34    ///
35    /// # Arguments
36    ///
37    /// * `class_id_map` - Optional mapping from class IDs to human-readable class names.
38    ///
39    /// # Examples
40    ///
41    /// ```rust,no_run
42    /// use std::collections::HashMap;
43    /// use oar_ocr::utils::topk::Topk;
44    ///
45    /// let mut class_map = HashMap::new();
46    /// class_map.insert(0, "cat".to_string());
47    /// class_map.insert(1, "dog".to_string());
48    ///
49    /// let topk = Topk::new(Some(class_map));
50    /// ```
51    pub fn new(class_id_map: Option<HashMap<usize, String>>) -> Self {
52        Self { class_id_map }
53    }
54
55    /// Creates a new Topk processor without class name mapping.
56    ///
57    /// # Examples
58    ///
59    /// ```rust,no_run
60    /// use oar_ocr::utils::topk::Topk;
61    ///
62    /// let topk = Topk::without_class_names();
63    /// ```
64    pub fn without_class_names() -> Self {
65        Self::new(None)
66    }
67
68    /// Creates a new Topk processor with class names from a vector.
69    ///
70    /// The vector index corresponds to the class ID.
71    ///
72    /// # Arguments
73    ///
74    /// * `class_names` - Vector of class names where index = class ID.
75    ///
76    /// # Examples
77    ///
78    /// ```rust,no_run
79    /// use oar_ocr::utils::topk::Topk;
80    ///
81    /// let class_names = vec!["cat".to_string(), "dog".to_string(), "bird".to_string()];
82    /// let topk = Topk::from_class_names(class_names);
83    /// ```
84    pub fn from_class_names(class_names: Vec<String>) -> Self {
85        let class_id_map: HashMap<usize, String> = class_names.into_iter().enumerate().collect();
86        Self::new(Some(class_id_map))
87    }
88
89    /// Processes classification outputs to extract top-k results.
90    ///
91    /// # Arguments
92    ///
93    /// * `predictions` - 2D vector where each inner vector contains the confidence
94    ///   scores for all classes for one prediction.
95    /// * `k` - Number of top predictions to extract (must be > 0).
96    ///
97    /// # Returns
98    ///
99    /// * `Ok(TopkResult)` - The top-k results with indexes, scores, and optional class names.
100    /// * `Err(String)` - If k is 0 or if the input is invalid.
101    ///
102    /// # Examples
103    ///
104    /// ```rust,no_run
105    /// use oar_ocr::utils::topk::Topk;
106    ///
107    /// let topk = Topk::without_class_names();
108    /// let predictions = vec![
109    ///     vec![0.1, 0.8, 0.1],  // Prediction 1: class 1 has highest score
110    ///     vec![0.7, 0.2, 0.1],  // Prediction 2: class 0 has highest score
111    /// ];
112    /// let result = topk.process(&predictions, 2).unwrap();
113    /// ```
114    pub fn process(&self, predictions: &[Vec<f32>], k: usize) -> Result<TopkResult, String> {
115        if k == 0 {
116            return Err("k must be greater than 0".to_string());
117        }
118
119        if predictions.is_empty() {
120            return Ok(TopkResult {
121                indexes: vec![],
122                scores: vec![],
123                class_names: None,
124            });
125        }
126
127        let mut all_indexes = Vec::new();
128        let mut all_scores = Vec::new();
129        let mut all_class_names = if self.class_id_map.is_some() {
130            Some(Vec::new())
131        } else {
132            None
133        };
134
135        for prediction in predictions {
136            if prediction.is_empty() {
137                return Err("Empty prediction vector".to_string());
138            }
139
140            let effective_k = k.min(prediction.len());
141            let (top_indexes, top_scores) =
142                self.extract_topk_from_prediction(prediction, effective_k);
143
144            all_indexes.push(top_indexes.clone());
145            all_scores.push(top_scores);
146
147            // Add class names if mapping is available
148            if let Some(ref mut class_names_vec) = all_class_names {
149                let names = self.map_indexes_to_names(&top_indexes);
150                class_names_vec.push(names);
151            }
152        }
153
154        Ok(TopkResult {
155            indexes: all_indexes,
156            scores: all_scores,
157            class_names: all_class_names,
158        })
159    }
160
161    /// Extracts top-k indexes and scores from a single prediction.
162    ///
163    /// # Arguments
164    ///
165    /// * `prediction` - Vector of confidence scores for all classes.
166    /// * `k` - Number of top predictions to extract.
167    ///
168    /// # Returns
169    ///
170    /// * `(Vec<usize>, Vec<f32>)` - Tuple of (top_indexes, top_scores).
171    fn extract_topk_from_prediction(&self, prediction: &[f32], k: usize) -> (Vec<usize>, Vec<f32>) {
172        // Create pairs of (index, score) and sort by score in descending order
173        let mut indexed_scores: Vec<(usize, f32)> = prediction
174            .iter()
175            .enumerate()
176            .map(|(idx, &score)| (idx, score))
177            .collect();
178
179        // Sort by score in descending order
180        indexed_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
181
182        // Take top k
183        let top_k: Vec<(usize, f32)> = indexed_scores.into_iter().take(k).collect();
184
185        // Separate indexes and scores
186        let (indexes, scores): (Vec<usize>, Vec<f32>) = top_k.into_iter().unzip();
187
188        (indexes, scores)
189    }
190
191    /// Maps class indexes to class names using the internal mapping.
192    ///
193    /// # Arguments
194    ///
195    /// * `indexes` - Vector of class indexes.
196    ///
197    /// # Returns
198    ///
199    /// * `Vec<String>` - Vector of class names. Unknown indexes are mapped to "Unknown".
200    fn map_indexes_to_names(&self, indexes: &[usize]) -> Vec<String> {
201        if let Some(ref class_map) = self.class_id_map {
202            indexes
203                .iter()
204                .map(|&idx| {
205                    class_map
206                        .get(&idx)
207                        .cloned()
208                        .unwrap_or_else(|| format!("Unknown({})", idx))
209                })
210                .collect()
211        } else {
212            indexes.iter().map(|&idx| idx.to_string()).collect()
213        }
214    }
215
216    /// Gets the class name for a given class ID.
217    ///
218    /// # Arguments
219    ///
220    /// * `class_id` - The class ID to look up.
221    ///
222    /// # Returns
223    ///
224    /// * `Option<&String>` - The class name if available.
225    pub fn get_class_name(&self, class_id: usize) -> Option<&String> {
226        self.class_id_map.as_ref()?.get(&class_id)
227    }
228
229    /// Checks if class name mapping is available.
230    ///
231    /// # Returns
232    ///
233    /// * `true` - If class name mapping is available.
234    /// * `false` - If no class name mapping is available.
235    pub fn has_class_names(&self) -> bool {
236        self.class_id_map.is_some()
237    }
238
239    /// Gets the number of classes in the mapping.
240    ///
241    /// # Returns
242    ///
243    /// * `Option<usize>` - Number of classes if mapping is available.
244    pub fn num_classes(&self) -> Option<usize> {
245        self.class_id_map.as_ref().map(|map| map.len())
246    }
247
248    /// Updates the class name mapping.
249    ///
250    /// # Arguments
251    ///
252    /// * `class_id_map` - New class ID to name mapping.
253    pub fn set_class_mapping(&mut self, class_id_map: Option<HashMap<usize, String>>) {
254        self.class_id_map = class_id_map;
255    }
256
257    /// Processes a single prediction vector.
258    ///
259    /// # Arguments
260    ///
261    /// * `prediction` - Vector of confidence scores for all classes.
262    /// * `k` - Number of top predictions to extract.
263    ///
264    /// # Returns
265    ///
266    /// * `Ok(TopkResult)` - The top-k results for the single prediction.
267    /// * `Err(String)` - If k is 0 or if the input is invalid.
268    pub fn process_single(&self, prediction: &[f32], k: usize) -> Result<TopkResult, String> {
269        self.process(&[prediction.to_vec()], k)
270    }
271}
272
273impl Default for Topk {
274    /// Creates a default Topk processor without class name mapping.
275    fn default() -> Self {
276        Self::without_class_names()
277    }
278}
279
#[cfg(test)]
mod tests {
    //! Unit tests for the top-k processor: ranking order, score passthrough,
    //! class-name resolution, and the error paths.

    use super::*;

    #[test]
    fn test_topk_without_class_names() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8, 0.1], vec![0.7, 0.2, 0.1]];

        let result = topk.process(&predictions, 2).unwrap();
        assert_eq!(result.indexes.len(), 2);
        // Classes 0 and 2 tie at 0.1; the lower index wins the tie.
        assert_eq!(result.indexes[0], vec![1, 0]); // Class 1 (0.8), Class 0 (0.1)
        assert_eq!(result.indexes[1], vec![0, 1]); // Class 0 (0.7), Class 1 (0.2)
        // Scores are returned alongside the indexes, in the same order.
        assert_eq!(result.scores[0], vec![0.8, 0.1]);
        assert_eq!(result.scores[1], vec![0.7, 0.2]);
        assert!(result.class_names.is_none());
    }

    #[test]
    fn test_topk_with_class_names() {
        let mut class_map = HashMap::new();
        class_map.insert(0, "cat".to_string());
        class_map.insert(1, "dog".to_string());
        class_map.insert(2, "bird".to_string());

        let topk = Topk::new(Some(class_map));
        let predictions = vec![vec![0.1, 0.8, 0.1]];

        let result = topk.process(&predictions, 2).unwrap();
        assert_eq!(result.indexes[0], vec![1, 0]);
        assert_eq!(result.class_names.as_ref().unwrap()[0], vec!["dog", "cat"]);
    }

    #[test]
    fn test_topk_unknown_class_id() {
        // Only class 0 has a name; class 1 must fall back to `Unknown(1)`.
        let topk = Topk::from_class_names(vec!["cat".to_string()]);
        let predictions = vec![vec![0.3, 0.7]];

        let result = topk.process(&predictions, 2).unwrap();
        assert_eq!(result.indexes[0], vec![1, 0]);
        assert_eq!(
            result.class_names.as_ref().unwrap()[0],
            vec!["Unknown(1)", "cat"]
        );
    }

    #[test]
    fn test_topk_from_class_names() {
        let class_names = vec!["cat".to_string(), "dog".to_string(), "bird".to_string()];
        let topk = Topk::from_class_names(class_names);

        assert!(topk.has_class_names());
        assert_eq!(topk.num_classes(), Some(3));
        assert_eq!(topk.get_class_name(0), Some(&"cat".to_string()));
        assert_eq!(topk.get_class_name(3), None);
    }

    #[test]
    fn test_topk_k_larger_than_classes() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8]]; // Only 2 classes

        let result = topk.process(&predictions, 5).unwrap(); // Ask for 5
        assert_eq!(result.indexes[0].len(), 2); // Should only get 2
        assert_eq!(result.scores[0].len(), 2);
    }

    #[test]
    fn test_topk_invalid_k() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.1, 0.8, 0.1]];

        assert_eq!(
            topk.process(&predictions, 0).unwrap_err(),
            "k must be greater than 0"
        );
    }

    #[test]
    fn test_topk_empty_predictions() {
        let topk = Topk::without_class_names();
        let predictions: Vec<Vec<f32>> = Vec::new();

        let result = topk.process(&predictions, 2).unwrap();
        assert!(result.indexes.is_empty());
        assert!(result.scores.is_empty());
    }

    #[test]
    fn test_topk_empty_inner_prediction() {
        let topk = Topk::without_class_names();
        let predictions = vec![vec![0.5], Vec::<f32>::new()];

        assert_eq!(
            topk.process(&predictions, 1).unwrap_err(),
            "Empty prediction vector"
        );
    }

    #[test]
    fn test_process_single() {
        let topk = Topk::without_class_names();
        let prediction = vec![0.1, 0.8, 0.1];

        let result = topk.process_single(&prediction, 2).unwrap();
        assert_eq!(result.indexes.len(), 1);
        assert_eq!(result.indexes[0], vec![1, 0]);
        assert_eq!(result.scores[0], vec![0.8, 0.1]);
    }

    #[test]
    fn test_process_single_errors() {
        let topk = Topk::without_class_names();

        assert!(topk.process_single(&[0.1, 0.9], 0).is_err());
        assert!(topk.process_single(&[], 1).is_err());
    }

    #[test]
    fn test_set_class_mapping() {
        let mut topk = Topk::default();
        assert!(!topk.has_class_names());
        assert_eq!(topk.num_classes(), None);

        let mut class_map = HashMap::new();
        class_map.insert(0, "cat".to_string());
        topk.set_class_mapping(Some(class_map));
        assert!(topk.has_class_names());
        assert_eq!(topk.num_classes(), Some(1));

        topk.set_class_mapping(None);
        assert!(!topk.has_class_names());
    }
}