oar_ocr_core/processors/
decode.rs

1//! Text decoding utilities for OCR (Optical Character Recognition) systems.
2//!
3//! This module provides implementations for decoding text recognition results,
4//! particularly focused on CTC (Connectionist Temporal Classification) decoding.
5//! It includes structures and methods for converting model predictions into
6//! readable text strings with confidence scores.
7
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::HashMap;
11
/// Decoded batch outputs along with positional metadata.
///
/// This is the return type of `CTCLabelDecode::apply_with_positions`. Every vector has
/// one entry per batch item. Tuple elements, in order:
/// 1. `Vec<String>` - decoded text.
/// 2. `Vec<f32>` - mean confidence of the kept characters (`0.0` when none were kept).
/// 3. `Vec<Vec<f32>>` - per-character position, computed as timestep / sequence length.
/// 4. `Vec<Vec<usize>>` - per-character raw timestep (column) index.
/// 5. `Vec<usize>` - number of timesteps in the sequence.
pub type PositionedDecodeResult = (
    Vec<String>,
    Vec<f32>,
    Vec<Vec<f32>>,
    Vec<Vec<usize>>,
    Vec<usize>,
);
20
/// Unanchored pattern that matches when the haystack contains an ASCII letter, an ASCII
/// digit, a space, or one of `: * . / % + -`.
///
/// Despite the name, the character class is wider than strictly alphanumeric characters.
/// `BaseRecLabelDecode::pred_reverse` applies it to one character at a time to decide
/// whether that character belongs to a run or stands alone.
///
/// The regex is compiled lazily on first use; the `expect` can only fire if the literal
/// pattern itself is invalid.
static ALPHANUMERIC_REGEX: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"[a-zA-Z0-9 :*./%+-]").expect("Failed to compile regex pattern"));
23
/// A base decoder for text recognition that handles character mapping and basic decoding operations.
///
/// This struct is responsible for converting model predictions into readable text strings.
/// It maintains a character dictionary for mapping indices to characters and provides
/// methods for decoding text with optional duplicate removal and confidence scoring.
///
/// # Fields
/// * `reverse` - Flag indicating whether decoded text is passed through `pred_reverse`.
///   Both constructors in this file initialise it to `false`.
/// * `dict` - A mapping from characters to their indices in the character list. When a
///   character occurs more than once in the vocabulary, the last index is stored.
/// * `character` - A list of characters in the vocabulary, indexed by their position
///   (the position is the class index looked up during decoding).
pub struct BaseRecLabelDecode {
    reverse: bool,
    dict: HashMap<char, usize>,
    character: Vec<char>,
}
39
40impl BaseRecLabelDecode {
41    /// Creates a new `BaseRecLabelDecode` instance.
42    ///
43    /// # Arguments
44    /// * `character_str` - An optional string containing the character vocabulary.
45    ///   If None, a default alphanumeric character set is used.
46    /// * `use_space_char` - Whether to include a space character in the vocabulary.
47    ///
48    /// # Returns
49    /// A new `BaseRecLabelDecode` instance.
50    pub fn new(character_str: Option<&str>, use_space_char: bool) -> Self {
51        let mut character_list: Vec<char> = if let Some(chars) = character_str {
52            chars.chars().collect()
53        } else {
54            "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
55        };
56
57        if use_space_char {
58            character_list.push(' ');
59        }
60
61        character_list = Self::add_special_char(character_list);
62
63        let mut dict = HashMap::new();
64        for (i, &char) in character_list.iter().enumerate() {
65            dict.insert(char, i);
66        }
67
68        Self {
69            reverse: false,
70            dict,
71            character: character_list,
72        }
73    }
74
75    /// Creates a new `BaseRecLabelDecode` instance from a list of strings.
76    ///
77    /// # Arguments
78    /// * `character_list` - An optional slice of strings containing the character vocabulary.
79    ///   Only the first character of each string is used. If None, a default alphanumeric
80    ///   character set is used.
81    /// * `use_space_char` - Whether to include a space character in the vocabulary.
82    ///
83    /// # Returns
84    /// A new `BaseRecLabelDecode` instance.
85    pub fn from_string_list(character_list: Option<&[String]>, use_space_char: bool) -> Self {
86        let mut chars: Vec<char> = if let Some(list) = character_list {
87            list.iter().filter_map(|s| s.chars().next()).collect()
88        } else {
89            "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
90        };
91
92        if use_space_char {
93            chars.push(' ');
94        }
95
96        chars = Self::add_special_char(chars);
97
98        let mut dict = HashMap::new();
99        for (i, &char) in chars.iter().enumerate() {
100            dict.insert(char, i);
101        }
102
103        Self {
104            reverse: false,
105            dict,
106            character: chars,
107        }
108    }
109
110    /// Reverses the alphanumeric parts of a string while keeping non-alphanumeric parts in place.
111    ///
112    /// # Arguments
113    /// * `pred` - The input string to process.
114    ///
115    /// # Returns
116    /// A new string with alphanumeric parts reversed.
117    fn pred_reverse(&self, pred: &str) -> String {
118        let mut pred_re = Vec::new();
119        let mut c_current = String::new();
120
121        for c in pred.chars() {
122            if !ALPHANUMERIC_REGEX.is_match(&c.to_string()) {
123                if !c_current.is_empty() {
124                    pred_re.push(c_current.clone());
125                    c_current.clear();
126                }
127                pred_re.push(c.to_string());
128            } else {
129                c_current.push(c);
130            }
131        }
132
133        if !c_current.is_empty() {
134            pred_re.push(c_current);
135        }
136
137        pred_re.reverse();
138        pred_re.join("")
139    }
140
    /// Hook for adding special characters to the character list.
    ///
    /// The base implementation returns the input list unchanged. Both
    /// `BaseRecLabelDecode` constructors pass their vocabulary through this function
    /// before building the character-to-index map, so any change made here affects
    /// every decoder built by them.
    ///
    /// # Arguments
    /// * `character_list` - The input character list.
    ///
    /// # Returns
    /// The character list with any special characters added (currently none).
    fn add_special_char(character_list: Vec<char>) -> Vec<char> {
        character_list
    }
154
    /// Gets a list of token indices that should be ignored during decoding.
    ///
    /// The list holds exactly one entry: the blank index reported by `get_blank_idx`.
    ///
    /// # Returns
    /// A vector containing the indices of tokens to ignore.
    fn get_ignored_tokens(&self) -> Vec<usize> {
        vec![self.get_blank_idx()]
    }
162
163    /// Decodes model predictions into text strings with confidence scores.
164    ///
165    /// # Arguments
166    /// * `text_index` - A slice of vectors containing the predicted character indices.
167    /// * `text_prob` - An optional slice of vectors containing the prediction probabilities.
168    /// * `is_remove_duplicate` - Whether to remove consecutive duplicate characters.
169    ///
170    /// # Returns
171    /// A vector of tuples, each containing a decoded text string and its confidence score.
172    pub fn decode(
173        &self,
174        text_index: &[Vec<usize>],
175        text_prob: Option<&[Vec<f32>]>,
176        is_remove_duplicate: bool,
177    ) -> Vec<(String, f32)> {
178        let mut result_list = Vec::new();
179        let ignored_tokens = self.get_ignored_tokens();
180
181        for (batch_idx, indices) in text_index.iter().enumerate() {
182            let mut selection = vec![true; indices.len()];
183
184            if is_remove_duplicate && indices.len() > 1 {
185                for i in 1..indices.len() {
186                    if indices[i] == indices[i - 1] {
187                        selection[i] = false;
188                    }
189                }
190            }
191
192            for &ignored_token in &ignored_tokens {
193                for (i, &idx) in indices.iter().enumerate() {
194                    if idx == ignored_token {
195                        selection[i] = false;
196                    }
197                }
198            }
199
200            let char_list: Vec<char> = indices
201                .iter()
202                .enumerate()
203                .filter(|(i, _)| selection[*i])
204                .filter_map(|(_, &text_id)| self.character.get(text_id).copied())
205                .collect();
206
207            let conf_list: Vec<f32> = if let Some(probs) = text_prob {
208                if batch_idx < probs.len() {
209                    probs[batch_idx]
210                        .iter()
211                        .enumerate()
212                        .filter(|(i, _)| *i < selection.len() && selection[*i])
213                        .map(|(_, &prob)| prob)
214                        .collect()
215                } else {
216                    vec![1.0; char_list.len()]
217                }
218            } else {
219                vec![1.0; char_list.len()]
220            };
221
222            let conf_list = if conf_list.is_empty() {
223                vec![0.0]
224            } else {
225                conf_list
226            };
227
228            let mut text: String = char_list.iter().collect();
229
230            if self.reverse {
231                text = self.pred_reverse(&text);
232            }
233
234            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
235            result_list.push((text, mean_conf));
236        }
237
238        result_list
239    }
240
241    /// Applies the decoder to a tensor of model predictions.
242    ///
243    /// # Arguments
244    /// * `pred` - A 3D tensor containing the model predictions.
245    ///
246    /// # Returns
247    /// A tuple containing:
248    /// * A vector of decoded text strings
249    /// * A vector of confidence scores for each text string
250    pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
251        if pred.is_empty() {
252            return (Vec::new(), Vec::new());
253        }
254
255        let batch_size = pred.shape()[0];
256        let mut all_texts = Vec::new();
257        let mut all_scores = Vec::new();
258
259        for batch_idx in 0..batch_size {
260            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
261
262            let mut sequence_idx = Vec::new();
263            let mut sequence_prob = Vec::new();
264
265            for row in preds.outer_iter() {
266                if let Some((idx, &prob)) = row
267                    .iter()
268                    .enumerate()
269                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
270                {
271                    sequence_idx.push(idx);
272                    sequence_prob.push(prob);
273                } else {
274                    sequence_idx.push(0);
275                    sequence_prob.push(0.0);
276                }
277            }
278
279            let text = self.decode(&[sequence_idx], Some(&[sequence_prob]), true);
280
281            for (t, score) in text {
282                all_texts.push(t);
283                all_scores.push(score);
284            }
285        }
286
287        (all_texts, all_scores)
288    }
289
    /// Gets the index of the blank token.
    ///
    /// `get_ignored_tokens` uses this value, so index `0` is always skipped by `decode`.
    ///
    /// # Returns
    /// The index of the blank token (always 0 in this base implementation).
    fn get_blank_idx(&self) -> usize {
        0
    }
297}
298
/// A decoder for CTC (Connectionist Temporal Classification) based text recognition models.
///
/// This struct wraps a `BaseRecLabelDecode` to provide specialized decoding for CTC models,
/// which include a blank token that needs to be handled specially during decoding.
///
/// # Fields
/// * `base` - The base decoder that holds the character vocabulary and lookup map
/// * `blank_index` - The index of the blank token in the character vocabulary. Every
///   constructor in this file sets it to `0`.
pub struct CTCLabelDecode {
    base: BaseRecLabelDecode,
    blank_index: usize,
}
311
312impl std::fmt::Debug for CTCLabelDecode {
313    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314        f.debug_struct("CTCLabelDecode")
315            .field("character_count", &self.base.character.len())
316            .field("reverse", &self.base.reverse)
317            .finish()
318    }
319}
320
321impl CTCLabelDecode {
322    /// Creates a new `CTCLabelDecode` instance.
323    ///
324    /// # Arguments
325    /// * `character_list` - An optional string containing the character vocabulary.
326    ///   If None, a default alphanumeric character set is used.
327    /// * `use_space_char` - Whether to include a space character in the vocabulary.
328    ///
329    /// # Returns
330    /// A new `CTCLabelDecode` instance.
331    pub fn new(character_list: Option<&str>, use_space_char: bool) -> Self {
332        let mut base = BaseRecLabelDecode::new(character_list, use_space_char);
333
334        // Use null char for blank to distinguish from actual space
335        let mut new_character = vec!['\0'];
336        new_character.extend(base.character);
337
338        let mut new_dict = HashMap::new();
339        for (i, &char) in new_character.iter().enumerate() {
340            new_dict.insert(char, i);
341        }
342
343        base.character = new_character;
344        base.dict = new_dict;
345
346        let blank_index = 0;
347
348        Self { base, blank_index }
349    }
350
351    /// Creates a new `CTCLabelDecode` instance from a list of strings.
352    ///
353    /// # Arguments
354    /// * `character_list` - An optional slice of strings containing the character vocabulary.
355    ///   Only the first character of each string is used. If None, a default alphanumeric
356    ///   character set is used.
357    /// * `use_space_char` - Whether to include a space character in the vocabulary.
358    /// * `has_explicit_blank` - Whether the character list already includes a blank token.
359    ///
360    /// # Returns
361    /// A new `CTCLabelDecode` instance.
362    pub fn from_string_list(
363        character_list: Option<&[String]>,
364        use_space_char: bool,
365        has_explicit_blank: bool,
366    ) -> Self {
367        if has_explicit_blank {
368            let base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
369            Self {
370                base,
371                blank_index: 0,
372            }
373        } else {
374            let mut base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
375
376            // Use null char for blank to distinguish from actual space
377            let mut new_character = vec!['\0'];
378            new_character.extend(base.character);
379
380            let mut new_dict = HashMap::new();
381            for (i, &char) in new_character.iter().enumerate() {
382                new_dict.insert(char, i);
383            }
384
385            base.character = new_character;
386            base.dict = new_dict;
387
388            Self {
389                base,
390                blank_index: 0,
391            }
392        }
393    }
394
    /// Gets the index of the blank token.
    ///
    /// This is the index that `apply` and `apply_with_positions` drop from the decoded
    /// sequence.
    ///
    /// # Returns
    /// The index of the blank token.
    pub fn get_blank_index(&self) -> usize {
        self.blank_index
    }
402
    /// Gets the character list used by this decoder.
    ///
    /// The slice is indexed by class index. When the decoder was built with an inserted
    /// blank token, that token is the `'\0'` entry at index 0.
    ///
    /// # Returns
    /// A slice containing the characters in the vocabulary.
    pub fn get_character_list(&self) -> &[char] {
        &self.base.character
    }
410
    /// Gets the number of characters in the vocabulary.
    ///
    /// The count is the length of the list returned by `get_character_list`, so it
    /// includes the blank entry when one was inserted.
    ///
    /// # Returns
    /// The number of characters in the vocabulary.
    pub fn get_character_count(&self) -> usize {
        self.base.character.len()
    }
418
419    /// Applies the CTC decoder to a tensor of model predictions with character position tracking.
420    ///
421    /// This method handles the special requirements of CTC decoding and additionally tracks
422    /// the timestep positions of each character for word box generation.
423    ///
424    /// # Arguments
425    /// * `pred` - A 3D tensor containing the model predictions.
426    ///
427    /// # Returns
428    /// A tuple containing:
429    /// * A vector of decoded text strings
430    /// * A vector of confidence scores for each text string
431    /// * A vector of character positions (normalized 0.0-1.0) for each text string
432    /// * A vector of column indices for each character in each text string
433    /// * A vector of sequence lengths (total columns) for each text string
434    pub fn apply_with_positions(&self, pred: &crate::core::Tensor3D) -> PositionedDecodeResult {
435        if pred.is_empty() {
436            return (Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new());
437        }
438
439        let batch_size = pred.shape()[0];
440        let mut all_texts = Vec::new();
441        let mut all_scores = Vec::new();
442        let mut all_positions = Vec::new();
443        let mut all_col_indices = Vec::new();
444        let mut all_seq_lengths = Vec::new();
445
446        for batch_idx in 0..batch_size {
447            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
448            let seq_len = preds.shape()[0] as f32;
449
450            let mut sequence_idx = Vec::new();
451            let mut sequence_prob = Vec::new();
452            let mut sequence_timesteps = Vec::new();
453
454            for (timestep, row) in preds.outer_iter().enumerate() {
455                if let Some((idx, &prob)) = row
456                    .iter()
457                    .enumerate()
458                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
459                {
460                    sequence_idx.push(idx);
461                    sequence_prob.push(prob);
462                    sequence_timesteps.push(timestep);
463                } else {
464                    sequence_idx.push(self.blank_index);
465                    sequence_prob.push(0.0);
466                    sequence_timesteps.push(timestep);
467                }
468            }
469
470            let mut filtered_idx = Vec::new();
471            let mut filtered_prob = Vec::new();
472            let mut filtered_timesteps = Vec::new();
473            let mut selection = vec![true; sequence_idx.len()];
474
475            // Remove consecutive duplicates
476            if sequence_idx.len() > 1 {
477                for i in 1..sequence_idx.len() {
478                    if sequence_idx[i] == sequence_idx[i - 1] {
479                        selection[i] = false;
480                    }
481                }
482            }
483
484            // Remove blanks
485            for (i, &idx) in sequence_idx.iter().enumerate() {
486                if idx == self.blank_index {
487                    selection[i] = false;
488                }
489            }
490
491            // Collect filtered results
492            for (i, &idx) in sequence_idx.iter().enumerate() {
493                if selection[i] {
494                    filtered_idx.push(idx);
495                    filtered_prob.push(sequence_prob[i]);
496                    filtered_timesteps.push(sequence_timesteps[i]);
497                }
498            }
499
500            let char_list: Vec<char> = filtered_idx
501                .iter()
502                .filter_map(|&text_id| self.base.character.get(text_id).copied())
503                .collect();
504
505            let conf_list = if filtered_prob.is_empty() {
506                vec![0.0]
507            } else {
508                filtered_prob
509            };
510
511            // Calculate normalized character positions (0.0 to 1.0)
512            let char_positions: Vec<f32> = filtered_timesteps
513                .iter()
514                .map(|&timestep| timestep as f32 / seq_len)
515                .collect();
516
517            // Store column indices (raw timesteps) for accurate word box generation
518            let col_indices: Vec<usize> = filtered_timesteps.clone();
519
520            let text: String = char_list.iter().collect();
521            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
522
523            all_texts.push(text);
524            all_scores.push(mean_conf);
525            all_positions.push(char_positions);
526            all_col_indices.push(col_indices);
527            all_seq_lengths.push(seq_len as usize);
528        }
529
530        (
531            all_texts,
532            all_scores,
533            all_positions,
534            all_col_indices,
535            all_seq_lengths,
536        )
537    }
538
539    /// Applies the CTC decoder to a tensor of model predictions.
540    ///
541    /// This method handles the special requirements of CTC decoding:
542    /// 1. Removing blank tokens
543    /// 2. Removing consecutive duplicate characters
544    /// 3. Converting indices to characters
545    /// 4. Calculating confidence scores
546    ///
547    /// # Arguments
548    /// * `pred` - A 3D tensor containing the model predictions.
549    ///
550    /// # Returns
551    /// A tuple containing:
552    /// * A vector of decoded text strings
553    /// * A vector of confidence scores for each text string
554    pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
555        if pred.is_empty() {
556            return (Vec::new(), Vec::new());
557        }
558
559        let batch_size = pred.shape()[0];
560        let mut all_texts = Vec::new();
561        let mut all_scores = Vec::new();
562        let mut batches_with_text = 0;
563
564        for batch_idx in 0..batch_size {
565            let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
566
567            let mut sequence_idx = Vec::new();
568            let mut sequence_prob = Vec::new();
569
570            for row in preds.outer_iter() {
571                if let Some((idx, &prob)) = row
572                    .iter()
573                    .enumerate()
574                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
575                {
576                    sequence_idx.push(idx);
577                    sequence_prob.push(prob);
578                } else {
579                    sequence_idx.push(self.blank_index);
580                    sequence_prob.push(0.0);
581                }
582            }
583
584            let mut filtered_idx = Vec::new();
585            let mut filtered_prob = Vec::new();
586            let mut selection = vec![true; sequence_idx.len()];
587
588            if sequence_idx.len() > 1 {
589                for i in 1..sequence_idx.len() {
590                    if sequence_idx[i] == sequence_idx[i - 1] {
591                        selection[i] = false;
592                    }
593                }
594            }
595
596            for (i, &idx) in sequence_idx.iter().enumerate() {
597                if idx == self.blank_index {
598                    selection[i] = false;
599                }
600            }
601
602            for (i, &idx) in sequence_idx.iter().enumerate() {
603                if selection[i] {
604                    filtered_idx.push(idx);
605                    filtered_prob.push(sequence_prob[i]);
606                }
607            }
608
609            let char_list: Vec<char> = filtered_idx
610                .iter()
611                .filter_map(|&text_id| self.base.character.get(text_id).copied())
612                .collect();
613
614            let conf_list = if filtered_prob.is_empty() {
615                vec![0.0]
616            } else {
617                filtered_prob
618            };
619
620            let text: String = char_list.iter().collect();
621            let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
622
623            if !text.is_empty() {
624                batches_with_text += 1;
625            }
626
627            all_texts.push(text);
628            all_scores.push(mean_conf);
629        }
630
631        // Log summary of decoding results
632        tracing::debug!(
633            "CTC decode summary: batch_size={}, batches_with_text={}, empty_batches={}",
634            batch_size,
635            batches_with_text,
636            batch_size - batches_with_text
637        );
638
639        (all_texts, all_scores)
640    }
641}