Skip to main content

oar_ocr_core/processors/
decode.rs

1//! Text decoding utilities for OCR (Optical Character Recognition) systems.
2//!
3//! This module provides implementations for decoding text recognition results,
4//! particularly focused on CTC (Connectionist Temporal Classification) decoding.
5//! It includes structures and methods for converting model predictions into
6//! readable text strings with confidence scores.
7
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::HashMap;
11
/// Decoded batch outputs along with positional metadata.
///
/// Each vector holds one entry per batch item, in batch order.
pub type PositionedDecodeResult = (
    // Decoded text strings.
    Vec<String>,
    // Confidence score of each text string.
    Vec<f32>,
    // Normalized (timestep / sequence length) position of every character of each text.
    Vec<Vec<f32>>,
    // Raw column (timestep) index of every character of each text.
    Vec<Vec<usize>>,
    // Total number of columns (timesteps) of each sequence.
    Vec<usize>,
);
20
21static ALPHANUMERIC_REGEX: Lazy<Regex> = Lazy::new(|| {
22    Regex::new(r"[a-zA-Z0-9 :*./%+-]")
23        .unwrap_or_else(|e| panic!("Failed to compile regex pattern: {e}"))
24});
25
/// Vocabulary-backed decoder that turns recognition-model label indices into text.
///
/// It stores the ordered vocabulary (index `i` of the model output maps to `character[i]`)
/// together with a reverse lookup from character to index. It supports optional removal of
/// consecutive duplicate labels and produces a mean confidence score per decoded string.
///
/// # Fields
/// * `character` - The vocabulary, where a character's position is its label index
/// * `dict` - Reverse lookup from a character to its index in `character`
/// * `reverse` - When `true`, decoded strings are reordered segment-wise before being returned
pub struct BaseRecLabelDecode {
    character: Vec<char>,
    dict: HashMap<char, usize>,
    reverse: bool,
}
41
42impl BaseRecLabelDecode {
43    /// Creates a new `BaseRecLabelDecode` instance.
44    ///
45    /// # Arguments
46    /// * `character_str` - An optional string containing the character vocabulary.
47    ///   If None, a default alphanumeric character set is used.
48    /// * `use_space_char` - Whether to include a space character in the vocabulary.
49    ///
50    /// # Returns
51    /// A new `BaseRecLabelDecode` instance.
52    pub fn new(character_str: Option<&str>, use_space_char: bool) -> Self {
53        let mut character_list: Vec<char> = if let Some(chars) = character_str {
54            chars.chars().collect()
55        } else {
56            "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
57        };
58
59        if use_space_char {
60            character_list.push(' ');
61        }
62
63        character_list = Self::add_special_char(character_list);
64
65        let mut dict = HashMap::new();
66        for (i, &char) in character_list.iter().enumerate() {
67            dict.insert(char, i);
68        }
69
70        Self {
71            reverse: false,
72            dict,
73            character: character_list,
74        }
75    }
76
77    /// Creates a new `BaseRecLabelDecode` instance from a list of strings.
78    ///
79    /// # Arguments
80    /// * `character_list` - An optional slice of strings containing the character vocabulary.
81    ///   Only the first character of each string is used. If None, a default alphanumeric
82    ///   character set is used.
83    /// * `use_space_char` - Whether to include a space character in the vocabulary.
84    ///
85    /// # Returns
86    /// A new `BaseRecLabelDecode` instance.
87    pub fn from_string_list(character_list: Option<&[String]>, use_space_char: bool) -> Self {
88        let mut chars: Vec<char> = if let Some(list) = character_list {
89            list.iter().filter_map(|s| s.chars().next()).collect()
90        } else {
91            "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
92        };
93
94        if use_space_char {
95            chars.push(' ');
96        }
97
98        chars = Self::add_special_char(chars);
99
100        let mut dict = HashMap::new();
101        for (i, &char) in chars.iter().enumerate() {
102            dict.insert(char, i);
103        }
104
105        Self {
106            reverse: false,
107            dict,
108            character: chars,
109        }
110    }
111
112    /// Reverses the alphanumeric parts of a string while keeping non-alphanumeric parts in place.
113    ///
114    /// # Arguments
115    /// * `pred` - The input string to process.
116    ///
117    /// # Returns
118    /// A new string with alphanumeric parts reversed.
119    fn pred_reverse(&self, pred: &str) -> String {
120        let mut pred_re = Vec::new();
121        let mut c_current = String::new();
122
123        for c in pred.chars() {
124            if !ALPHANUMERIC_REGEX.is_match(&c.to_string()) {
125                if !c_current.is_empty() {
126                    pred_re.push(c_current.clone());
127                    c_current.clear();
128                }
129                pred_re.push(c.to_string());
130            } else {
131                c_current.push(c);
132            }
133        }
134
135        if !c_current.is_empty() {
136            pred_re.push(c_current);
137        }
138
139        pred_re.reverse();
140        pred_re.join("")
141    }
142
143    /// Adds special characters to the character list.
144    ///
145    /// This is a placeholder method that currently just returns the input list unchanged.
146    /// It can be overridden in subclasses to add special characters.
147    ///
148    /// # Arguments
149    /// * `character_list` - The input character list.
150    ///
151    /// # Returns
152    /// The character list with any special characters added.
153    fn add_special_char(character_list: Vec<char>) -> Vec<char> {
154        character_list
155    }
156
157    /// Gets a list of token indices that should be ignored during decoding.
158    ///
159    /// # Returns
160    /// A vector containing the indices of tokens to ignore.
161    fn get_ignored_tokens(&self) -> Vec<usize> {
162        vec![self.get_blank_idx()]
163    }
164
165    /// Decodes model predictions into text strings with confidence scores.
166    ///
167    /// # Arguments
168    /// * `text_index` - A slice of vectors containing the predicted character indices.
169    /// * `text_prob` - An optional slice of vectors containing the prediction probabilities.
170    /// * `is_remove_duplicate` - Whether to remove consecutive duplicate characters.
171    ///
172    /// # Returns
173    /// A vector of tuples, each containing a decoded text string and its confidence score.
174    pub fn decode(
175        &self,
176        text_index: &[Vec<usize>],
177        text_prob: Option<&[Vec<f32>]>,
178        is_remove_duplicate: bool,
179    ) -> Vec<(String, f32)> {
180        let mut result_list = Vec::new();
181        let ignored_tokens = self.get_ignored_tokens();
182
183        for (batch_idx, indices) in text_index.iter().enumerate() {
184            let mut selection = vec![true; indices.len()];
185
186            if is_remove_duplicate && indices.len() > 1 {
187                for i in 1..indices.len() {
188                    if indices[i] == indices[i - 1] {
189                        selection[i] = false;
190                    }
191                }
192            }
193
194            for &ignored_token in &ignored_tokens {
195                for (i, &idx) in indices.iter().enumerate() {
196                    if idx == ignored_token {
197                        selection[i] = false;
198                    }
199                }
200            }
201
202            let char_list: Vec<char> = indices
203                .iter()
204                .enumerate()
205                .filter(|(i, _)| selection[*i])
206                .filter_map(|(_, &text_id)| self.character.get(text_id).copied())
207                .collect();
208
209            let conf_list: Vec<f32> = if let Some(probs) = text_prob {
210                if batch_idx < probs.len() {
211                    probs[batch_idx]
212                        .iter()
213                        .enumerate()
214                        .filter(|(i, _)| *i < selection.len() && selection[*i])
215                        .map(|(_, &prob)| prob)
216                        .collect()
217                } else {
218                    vec![1.0; char_list.len()]
219                }
220            } else {
221                vec![1.0; char_list.len()]
222            };
223
224            let conf_list = if conf_list.is_empty() {
225                vec![0.0]
226            } else {
227                conf_list
228            };
229
230            let mut text: String = char_list.iter().collect();
231
232            if self.reverse {
233                text = self.pred_reverse(&text);
234            }
235
236            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
237            result_list.push((text, mean_conf));
238        }
239
240        result_list
241    }
242
243    /// Applies the decoder to a tensor of model predictions.
244    ///
245    /// # Arguments
246    /// * `pred` - A 3D tensor containing the model predictions.
247    ///
248    /// # Returns
249    /// A tuple containing:
250    /// * A vector of decoded text strings
251    /// * A vector of confidence scores for each text string
252    pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
253        if pred.is_empty() {
254            return (Vec::new(), Vec::new());
255        }
256
257        let batch_size = pred.shape()[0];
258        let mut all_texts = Vec::new();
259        let mut all_scores = Vec::new();
260
261        for batch_idx in 0..batch_size {
262            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
263
264            let mut sequence_idx = Vec::new();
265            let mut sequence_prob = Vec::new();
266
267            for row in preds.outer_iter() {
268                if let Some((idx, &prob)) = row
269                    .iter()
270                    .enumerate()
271                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
272                {
273                    sequence_idx.push(idx);
274                    sequence_prob.push(prob);
275                } else {
276                    sequence_idx.push(0);
277                    sequence_prob.push(0.0);
278                }
279            }
280
281            let text = self.decode(&[sequence_idx], Some(&[sequence_prob]), true);
282
283            for (t, score) in text {
284                all_texts.push(t);
285                all_scores.push(score);
286            }
287        }
288
289        (all_texts, all_scores)
290    }
291
292    /// Gets the index of the blank token.
293    ///
294    /// # Returns
295    /// The index of the blank token (always 0 in this base implementation).
296    fn get_blank_idx(&self) -> usize {
297        0
298    }
299}
300
301/// A decoder for CTC (Connectionist Temporal Classification) based text recognition models.
302///
303/// This struct extends `BaseRecLabelDecode` to provide specialized decoding for CTC models,
304/// which include a blank token that needs to be handled specially during decoding.
305///
306/// # Fields
307/// * `base` - The base decoder that handles character mapping and basic decoding operations
308/// * `blank_index` - The index of the blank token in the character vocabulary
309pub struct CTCLabelDecode {
310    base: BaseRecLabelDecode,
311    blank_index: usize,
312}
313
314impl std::fmt::Debug for CTCLabelDecode {
315    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316        f.debug_struct("CTCLabelDecode")
317            .field("character_count", &self.base.character.len())
318            .field("reverse", &self.base.reverse)
319            .finish()
320    }
321}
322
323impl CTCLabelDecode {
324    /// Creates a new `CTCLabelDecode` instance.
325    ///
326    /// # Arguments
327    /// * `character_list` - An optional string containing the character vocabulary.
328    ///   If None, a default alphanumeric character set is used.
329    /// * `use_space_char` - Whether to include a space character in the vocabulary.
330    ///
331    /// # Returns
332    /// A new `CTCLabelDecode` instance.
333    pub fn new(character_list: Option<&str>, use_space_char: bool) -> Self {
334        let mut base = BaseRecLabelDecode::new(character_list, use_space_char);
335
336        // Use null char for blank to distinguish from actual space
337        let mut new_character = vec!['\0'];
338        new_character.extend(base.character);
339
340        let mut new_dict = HashMap::new();
341        for (i, &char) in new_character.iter().enumerate() {
342            new_dict.insert(char, i);
343        }
344
345        base.character = new_character;
346        base.dict = new_dict;
347
348        let blank_index = 0;
349
350        Self { base, blank_index }
351    }
352
353    /// Creates a new `CTCLabelDecode` instance from a list of strings.
354    ///
355    /// # Arguments
356    /// * `character_list` - An optional slice of strings containing the character vocabulary.
357    ///   Only the first character of each string is used. If None, a default alphanumeric
358    ///   character set is used.
359    /// * `use_space_char` - Whether to include a space character in the vocabulary.
360    /// * `has_explicit_blank` - Whether the character list already includes a blank token.
361    ///
362    /// # Returns
363    /// A new `CTCLabelDecode` instance.
364    pub fn from_string_list(
365        character_list: Option<&[String]>,
366        use_space_char: bool,
367        has_explicit_blank: bool,
368    ) -> Self {
369        if has_explicit_blank {
370            let base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
371            Self {
372                base,
373                blank_index: 0,
374            }
375        } else {
376            let mut base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
377
378            // Use null char for blank to distinguish from actual space
379            let mut new_character = vec!['\0'];
380            new_character.extend(base.character);
381
382            let mut new_dict = HashMap::new();
383            for (i, &char) in new_character.iter().enumerate() {
384                new_dict.insert(char, i);
385            }
386
387            base.character = new_character;
388            base.dict = new_dict;
389
390            Self {
391                base,
392                blank_index: 0,
393            }
394        }
395    }
396
397    /// Gets the index of the blank token.
398    ///
399    /// # Returns
400    /// The index of the blank token.
401    pub fn get_blank_index(&self) -> usize {
402        self.blank_index
403    }
404
405    /// Gets the character list used by this decoder.
406    ///
407    /// # Returns
408    /// A slice containing the characters in the vocabulary.
409    pub fn get_character_list(&self) -> &[char] {
410        &self.base.character
411    }
412
413    /// Gets the number of characters in the vocabulary.
414    ///
415    /// # Returns
416    /// The number of characters in the vocabulary.
417    pub fn get_character_count(&self) -> usize {
418        self.base.character.len()
419    }
420
421    /// Applies the CTC decoder to a tensor of model predictions with character position tracking.
422    ///
423    /// This method handles the special requirements of CTC decoding and additionally tracks
424    /// the timestep positions of each character for word box generation.
425    ///
426    /// # Arguments
427    /// * `pred` - A 3D tensor containing the model predictions.
428    ///
429    /// # Returns
430    /// A tuple containing:
431    /// * A vector of decoded text strings
432    /// * A vector of confidence scores for each text string
433    /// * A vector of character positions (normalized 0.0-1.0) for each text string
434    /// * A vector of column indices for each character in each text string
435    /// * A vector of sequence lengths (total columns) for each text string
436    pub fn apply_with_positions(&self, pred: &crate::core::Tensor3D) -> PositionedDecodeResult {
437        if pred.is_empty() {
438            return (Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new());
439        }
440
441        let batch_size = pred.shape()[0];
442        let mut all_texts = Vec::new();
443        let mut all_scores = Vec::new();
444        let mut all_positions = Vec::new();
445        let mut all_col_indices = Vec::new();
446        let mut all_seq_lengths = Vec::new();
447
448        for batch_idx in 0..batch_size {
449            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
450            let seq_len = preds.shape()[0] as f32;
451
452            let mut sequence_idx = Vec::new();
453            let mut sequence_prob = Vec::new();
454            let mut sequence_timesteps = Vec::new();
455
456            for (timestep, row) in preds.outer_iter().enumerate() {
457                if let Some((idx, &prob)) = row
458                    .iter()
459                    .enumerate()
460                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
461                {
462                    sequence_idx.push(idx);
463                    sequence_prob.push(prob);
464                    sequence_timesteps.push(timestep);
465                } else {
466                    sequence_idx.push(self.blank_index);
467                    sequence_prob.push(0.0);
468                    sequence_timesteps.push(timestep);
469                }
470            }
471
472            let mut filtered_idx = Vec::new();
473            let mut filtered_prob = Vec::new();
474            let mut filtered_timesteps = Vec::new();
475            let mut selection = vec![true; sequence_idx.len()];
476
477            // Remove consecutive duplicates
478            if sequence_idx.len() > 1 {
479                for i in 1..sequence_idx.len() {
480                    if sequence_idx[i] == sequence_idx[i - 1] {
481                        selection[i] = false;
482                    }
483                }
484            }
485
486            // Remove blanks
487            for (i, &idx) in sequence_idx.iter().enumerate() {
488                if idx == self.blank_index {
489                    selection[i] = false;
490                }
491            }
492
493            // Collect filtered results
494            for (i, &idx) in sequence_idx.iter().enumerate() {
495                if selection[i] {
496                    filtered_idx.push(idx);
497                    filtered_prob.push(sequence_prob[i]);
498                    filtered_timesteps.push(sequence_timesteps[i]);
499                }
500            }
501
502            let char_list: Vec<char> = filtered_idx
503                .iter()
504                .filter_map(|&text_id| self.base.character.get(text_id).copied())
505                .collect();
506
507            let conf_list = if filtered_prob.is_empty() {
508                vec![0.0]
509            } else {
510                filtered_prob
511            };
512
513            // Calculate normalized character positions (0.0 to 1.0)
514            let char_positions: Vec<f32> = filtered_timesteps
515                .iter()
516                .map(|&timestep| timestep as f32 / seq_len)
517                .collect();
518
519            // Store column indices (raw timesteps) for accurate word box generation
520            let col_indices: Vec<usize> = filtered_timesteps.clone();
521
522            let text: String = char_list.iter().collect();
523            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
524
525            all_texts.push(text);
526            all_scores.push(mean_conf);
527            all_positions.push(char_positions);
528            all_col_indices.push(col_indices);
529            all_seq_lengths.push(seq_len as usize);
530        }
531
532        (
533            all_texts,
534            all_scores,
535            all_positions,
536            all_col_indices,
537            all_seq_lengths,
538        )
539    }
540
541    /// Applies the CTC decoder to a tensor of model predictions.
542    ///
543    /// This method handles the special requirements of CTC decoding:
544    /// 1. Removing blank tokens
545    /// 2. Removing consecutive duplicate characters
546    /// 3. Converting indices to characters
547    /// 4. Calculating confidence scores
548    ///
549    /// # Arguments
550    /// * `pred` - A 3D tensor containing the model predictions.
551    ///
552    /// # Returns
553    /// A tuple containing:
554    /// * A vector of decoded text strings
555    /// * A vector of confidence scores for each text string
556    pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
557        if pred.is_empty() {
558            return (Vec::new(), Vec::new());
559        }
560
561        let batch_size = pred.shape()[0];
562        let mut all_texts = Vec::new();
563        let mut all_scores = Vec::new();
564        let mut batches_with_text = 0;
565
566        for batch_idx in 0..batch_size {
567            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
568
569            let mut sequence_idx = Vec::new();
570            let mut sequence_prob = Vec::new();
571
572            for row in preds.outer_iter() {
573                if let Some((idx, &prob)) = row
574                    .iter()
575                    .enumerate()
576                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
577                {
578                    sequence_idx.push(idx);
579                    sequence_prob.push(prob);
580                } else {
581                    sequence_idx.push(self.blank_index);
582                    sequence_prob.push(0.0);
583                }
584            }
585
586            let mut filtered_idx = Vec::new();
587            let mut filtered_prob = Vec::new();
588            let mut selection = vec![true; sequence_idx.len()];
589
590            if sequence_idx.len() > 1 {
591                for i in 1..sequence_idx.len() {
592                    if sequence_idx[i] == sequence_idx[i - 1] {
593                        selection[i] = false;
594                    }
595                }
596            }
597
598            for (i, &idx) in sequence_idx.iter().enumerate() {
599                if idx == self.blank_index {
600                    selection[i] = false;
601                }
602            }
603
604            for (i, &idx) in sequence_idx.iter().enumerate() {
605                if selection[i] {
606                    filtered_idx.push(idx);
607                    filtered_prob.push(sequence_prob[i]);
608                }
609            }
610
611            let char_list: Vec<char> = filtered_idx
612                .iter()
613                .filter_map(|&text_id| self.base.character.get(text_id).copied())
614                .collect();
615
616            let conf_list = if filtered_prob.is_empty() {
617                vec![0.0]
618            } else {
619                filtered_prob
620            };
621
622            let text: String = char_list.iter().collect();
623            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
624
625            if !text.is_empty() {
626                batches_with_text += 1;
627            }
628
629            all_texts.push(text);
630            all_scores.push(mean_conf);
631        }
632
633        // Log summary of decoding results
634        tracing::debug!(
635            "CTC decode summary: batch_size={}, batches_with_text={}, empty_batches={}",
636            batch_size,
637            batches_with_text,
638            batch_size - batches_with_text
639        );
640
641        (all_texts, all_scores)
642    }
643}