// oar_ocr/processors/decode.rs

//! Text decoding utilities for OCR (Optical Character Recognition) systems.
//!
//! This module provides implementations for decoding text recognition results,
//! particularly focused on CTC (Connectionist Temporal Classification) decoding.
//! It includes structures and methods for converting model predictions into
//! readable text strings with confidence scores.
7
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::HashMap;
11
12static ALPHANUMERIC_REGEX: Lazy<Regex> =
13    Lazy::new(|| Regex::new(r"[a-zA-Z0-9 :*./%+-]").expect("Failed to compile regex pattern"));
14
/// A base decoder for text recognition that handles character mapping and basic decoding operations.
///
/// This struct converts sequences of predicted class indices into readable text.
/// It holds the vocabulary (index → character) and provides decoding with optional
/// duplicate removal and confidence scoring.
///
/// # Fields
/// * `reverse` - When `true`, decoded text is passed through `pred_reverse`
/// * `dict` - Character → index lookup built from `character` (later duplicates win);
///   not read by the decoding paths in this file
/// * `character` - The vocabulary; position `i` is the character for class index `i`
#[derive(Debug, Clone)]
pub struct BaseRecLabelDecode {
    reverse: bool,
    dict: HashMap<char, usize>,
    character: Vec<char>,
}
30
31impl BaseRecLabelDecode {
32    /// Creates a new `BaseRecLabelDecode` instance.
33    ///
34    /// # Arguments
35    /// * `character_str` - An optional string containing the character vocabulary.
36    ///   If None, a default alphanumeric character set is used.
37    /// * `use_space_char` - Whether to include a space character in the vocabulary.
38    ///
39    /// # Returns
40    /// A new `BaseRecLabelDecode` instance.
41    pub fn new(character_str: Option<&str>, use_space_char: bool) -> Self {
42        let mut character_list: Vec<char> = if let Some(chars) = character_str {
43            chars.chars().collect()
44        } else {
45            "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
46        };
47
48        if use_space_char {
49            character_list.push(' ');
50        }
51
52        character_list = Self::add_special_char(character_list);
53
54        let mut dict = HashMap::new();
55        for (i, &char) in character_list.iter().enumerate() {
56            dict.insert(char, i);
57        }
58
59        Self {
60            reverse: false,
61            dict,
62            character: character_list,
63        }
64    }
65
66    /// Creates a new `BaseRecLabelDecode` instance from a list of strings.
67    ///
68    /// # Arguments
69    /// * `character_list` - An optional slice of strings containing the character vocabulary.
70    ///   Only the first character of each string is used. If None, a default alphanumeric
71    ///   character set is used.
72    /// * `use_space_char` - Whether to include a space character in the vocabulary.
73    ///
74    /// # Returns
75    /// A new `BaseRecLabelDecode` instance.
76    pub fn from_string_list(character_list: Option<&[String]>, use_space_char: bool) -> Self {
77        let mut chars: Vec<char> = if let Some(list) = character_list {
78            list.iter().filter_map(|s| s.chars().next()).collect()
79        } else {
80            "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
81        };
82
83        if use_space_char {
84            chars.push(' ');
85        }
86
87        chars = Self::add_special_char(chars);
88
89        let mut dict = HashMap::new();
90        for (i, &char) in chars.iter().enumerate() {
91            dict.insert(char, i);
92        }
93
94        Self {
95            reverse: false,
96            dict,
97            character: chars,
98        }
99    }
100
101    /// Reverses the alphanumeric parts of a string while keeping non-alphanumeric parts in place.
102    ///
103    /// # Arguments
104    /// * `pred` - The input string to process.
105    ///
106    /// # Returns
107    /// A new string with alphanumeric parts reversed.
108    fn pred_reverse(&self, pred: &str) -> String {
109        let mut pred_re = Vec::new();
110        let mut c_current = String::new();
111
112        for c in pred.chars() {
113            if !ALPHANUMERIC_REGEX.is_match(&c.to_string()) {
114                if !c_current.is_empty() {
115                    pred_re.push(c_current.clone());
116                    c_current.clear();
117                }
118                pred_re.push(c.to_string());
119            } else {
120                c_current.push(c);
121            }
122        }
123
124        if !c_current.is_empty() {
125            pred_re.push(c_current);
126        }
127
128        pred_re.reverse();
129        pred_re.join("")
130    }
131
132    /// Adds special characters to the character list.
133    ///
134    /// This is a placeholder method that currently just returns the input list unchanged.
135    /// It can be overridden in subclasses to add special characters.
136    ///
137    /// # Arguments
138    /// * `character_list` - The input character list.
139    ///
140    /// # Returns
141    /// The character list with any special characters added.
142    fn add_special_char(character_list: Vec<char>) -> Vec<char> {
143        character_list
144    }
145
146    /// Gets a list of token indices that should be ignored during decoding.
147    ///
148    /// # Returns
149    /// A vector containing the indices of tokens to ignore.
150    fn get_ignored_tokens(&self) -> Vec<usize> {
151        vec![self.get_blank_idx()]
152    }
153
154    /// Decodes model predictions into text strings with confidence scores.
155    ///
156    /// # Arguments
157    /// * `text_index` - A slice of vectors containing the predicted character indices.
158    /// * `text_prob` - An optional slice of vectors containing the prediction probabilities.
159    /// * `is_remove_duplicate` - Whether to remove consecutive duplicate characters.
160    ///
161    /// # Returns
162    /// A vector of tuples, each containing a decoded text string and its confidence score.
163    pub fn decode(
164        &self,
165        text_index: &[Vec<usize>],
166        text_prob: Option<&[Vec<f32>]>,
167        is_remove_duplicate: bool,
168    ) -> Vec<(String, f32)> {
169        let mut result_list = Vec::new();
170        let ignored_tokens = self.get_ignored_tokens();
171
172        for (batch_idx, indices) in text_index.iter().enumerate() {
173            let mut selection = vec![true; indices.len()];
174
175            if is_remove_duplicate && indices.len() > 1 {
176                for i in 1..indices.len() {
177                    if indices[i] == indices[i - 1] {
178                        selection[i] = false;
179                    }
180                }
181            }
182
183            for &ignored_token in &ignored_tokens {
184                for (i, &idx) in indices.iter().enumerate() {
185                    if idx == ignored_token {
186                        selection[i] = false;
187                    }
188                }
189            }
190
191            let char_list: Vec<char> = indices
192                .iter()
193                .enumerate()
194                .filter(|(i, _)| selection[*i])
195                .filter_map(|(_, &text_id)| self.character.get(text_id).copied())
196                .collect();
197
198            let conf_list: Vec<f32> = if let Some(probs) = text_prob {
199                if batch_idx < probs.len() {
200                    probs[batch_idx]
201                        .iter()
202                        .enumerate()
203                        .filter(|(i, _)| *i < selection.len() && selection[*i])
204                        .map(|(_, &prob)| prob)
205                        .collect()
206                } else {
207                    vec![1.0; char_list.len()]
208                }
209            } else {
210                vec![1.0; char_list.len()]
211            };
212
213            let conf_list = if conf_list.is_empty() {
214                vec![0.0]
215            } else {
216                conf_list
217            };
218
219            let mut text: String = char_list.iter().collect();
220
221            if self.reverse {
222                text = self.pred_reverse(&text);
223            }
224
225            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
226            result_list.push((text, mean_conf));
227        }
228
229        result_list
230    }
231
232    /// Applies the decoder to a tensor of model predictions.
233    ///
234    /// # Arguments
235    /// * `pred` - A 3D tensor containing the model predictions.
236    ///
237    /// # Returns
238    /// A tuple containing:
239    /// * A vector of decoded text strings
240    /// * A vector of confidence scores for each text string
241    pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
242        if pred.is_empty() {
243            return (Vec::new(), Vec::new());
244        }
245
246        let batch_size = pred.shape()[0];
247        let mut all_texts = Vec::new();
248        let mut all_scores = Vec::new();
249
250        for batch_idx in 0..batch_size {
251            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
252
253            let mut sequence_idx = Vec::new();
254            let mut sequence_prob = Vec::new();
255
256            for row in preds.outer_iter() {
257                if let Some((idx, &prob)) = row
258                    .iter()
259                    .enumerate()
260                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
261                {
262                    sequence_idx.push(idx);
263                    sequence_prob.push(prob);
264                } else {
265                    sequence_idx.push(0);
266                    sequence_prob.push(0.0);
267                }
268            }
269
270            let text = self.decode(&[sequence_idx], Some(&[sequence_prob]), true);
271
272            for (t, score) in text {
273                all_texts.push(t);
274                all_scores.push(score);
275            }
276        }
277
278        (all_texts, all_scores)
279    }
280
281    /// Gets the index of the blank token.
282    ///
283    /// # Returns
284    /// The index of the blank token (always 0 in this base implementation).
285    fn get_blank_idx(&self) -> usize {
286        0
287    }
288}
289
290/// A decoder for CTC (Connectionist Temporal Classification) based text recognition models.
291///
292/// This struct extends `BaseRecLabelDecode` to provide specialized decoding for CTC models,
293/// which include a blank token that needs to be handled specially during decoding.
294///
295/// # Fields
296/// * `base` - The base decoder that handles character mapping and basic decoding operations
297/// * `blank_index` - The index of the blank token in the character vocabulary
298pub struct CTCLabelDecode {
299    base: BaseRecLabelDecode,
300    blank_index: usize,
301}
302
303impl std::fmt::Debug for CTCLabelDecode {
304    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305        f.debug_struct("CTCLabelDecode")
306            .field("character_count", &self.base.character.len())
307            .field("reverse", &self.base.reverse)
308            .finish()
309    }
310}
311
312impl CTCLabelDecode {
313    /// Creates a new `CTCLabelDecode` instance.
314    ///
315    /// # Arguments
316    /// * `character_list` - An optional string containing the character vocabulary.
317    ///   If None, a default alphanumeric character set is used.
318    /// * `use_space_char` - Whether to include a space character in the vocabulary.
319    ///
320    /// # Returns
321    /// A new `CTCLabelDecode` instance.
322    pub fn new(character_list: Option<&str>, use_space_char: bool) -> Self {
323        let mut base = BaseRecLabelDecode::new(character_list, use_space_char);
324
325        let mut new_character = vec![' '];
326        new_character.extend(base.character);
327
328        let mut new_dict = HashMap::new();
329        for (i, &char) in new_character.iter().enumerate() {
330            new_dict.insert(char, i);
331        }
332
333        base.character = new_character;
334        base.dict = new_dict;
335
336        let blank_index = 0;
337
338        Self { base, blank_index }
339    }
340
341    /// Creates a new `CTCLabelDecode` instance from a list of strings.
342    ///
343    /// # Arguments
344    /// * `character_list` - An optional slice of strings containing the character vocabulary.
345    ///   Only the first character of each string is used. If None, a default alphanumeric
346    ///   character set is used.
347    /// * `use_space_char` - Whether to include a space character in the vocabulary.
348    /// * `has_explicit_blank` - Whether the character list already includes a blank token.
349    ///
350    /// # Returns
351    /// A new `CTCLabelDecode` instance.
352    pub fn from_string_list(
353        character_list: Option<&[String]>,
354        use_space_char: bool,
355        has_explicit_blank: bool,
356    ) -> Self {
357        if has_explicit_blank {
358            let base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
359            Self {
360                base,
361                blank_index: 0,
362            }
363        } else {
364            let mut base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
365
366            let mut new_character = vec![' '];
367            new_character.extend(base.character);
368
369            let mut new_dict = HashMap::new();
370            for (i, &char) in new_character.iter().enumerate() {
371                new_dict.insert(char, i);
372            }
373
374            base.character = new_character;
375            base.dict = new_dict;
376
377            Self {
378                base,
379                blank_index: 0,
380            }
381        }
382    }
383
384    /// Gets the index of the blank token.
385    ///
386    /// # Returns
387    /// The index of the blank token.
388    pub fn get_blank_index(&self) -> usize {
389        self.blank_index
390    }
391
392    /// Gets the character list used by this decoder.
393    ///
394    /// # Returns
395    /// A slice containing the characters in the vocabulary.
396    pub fn get_character_list(&self) -> &[char] {
397        &self.base.character
398    }
399
400    /// Gets the number of characters in the vocabulary.
401    ///
402    /// # Returns
403    /// The number of characters in the vocabulary.
404    pub fn get_character_count(&self) -> usize {
405        self.base.character.len()
406    }
407
408    /// Applies the CTC decoder to a tensor of model predictions.
409    ///
410    /// This method handles the special requirements of CTC decoding:
411    /// 1. Removing blank tokens
412    /// 2. Removing consecutive duplicate characters
413    /// 3. Converting indices to characters
414    /// 4. Calculating confidence scores
415    ///
416    /// # Arguments
417    /// * `pred` - A 3D tensor containing the model predictions.
418    ///
419    /// # Returns
420    /// A tuple containing:
421    /// * A vector of decoded text strings
422    /// * A vector of confidence scores for each text string
423    pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
424        if pred.is_empty() {
425            return (Vec::new(), Vec::new());
426        }
427
428        let batch_size = pred.shape()[0];
429        let mut all_texts = Vec::new();
430        let mut all_scores = Vec::new();
431
432        for batch_idx in 0..batch_size {
433            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
434
435            let mut sequence_idx = Vec::new();
436            let mut sequence_prob = Vec::new();
437
438            for row in preds.outer_iter() {
439                if let Some((idx, &prob)) = row
440                    .iter()
441                    .enumerate()
442                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
443                {
444                    sequence_idx.push(idx);
445                    sequence_prob.push(prob);
446                } else {
447                    sequence_idx.push(self.blank_index);
448                    sequence_prob.push(0.0);
449                }
450            }
451
452            let mut filtered_idx = Vec::new();
453            let mut filtered_prob = Vec::new();
454            let mut selection = vec![true; sequence_idx.len()];
455
456            if sequence_idx.len() > 1 {
457                for i in 1..sequence_idx.len() {
458                    if sequence_idx[i] == sequence_idx[i - 1] {
459                        selection[i] = false;
460                    }
461                }
462            }
463
464            for (i, &idx) in sequence_idx.iter().enumerate() {
465                if idx == self.blank_index {
466                    selection[i] = false;
467                }
468            }
469
470            for (i, &idx) in sequence_idx.iter().enumerate() {
471                if selection[i] {
472                    filtered_idx.push(idx);
473                    filtered_prob.push(sequence_prob[i]);
474                }
475            }
476
477            let char_list: Vec<char> = filtered_idx
478                .iter()
479                .filter_map(|&text_id| self.base.character.get(text_id).copied())
480                .collect();
481
482            let conf_list = if filtered_prob.is_empty() {
483                vec![0.0]
484            } else {
485                filtered_prob
486            };
487
488            let text: String = char_list.iter().collect();
489            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
490
491            all_texts.push(text);
492            all_scores.push(mean_conf);
493        }
494
495        (all_texts, all_scores)
496    }
497}