oar_ocr/processors/decode.rs
//! Text decoding utilities for OCR (Optical Character Recognition) systems.
//!
//! This module turns the raw output of text-recognition models into strings,
//! with a focus on CTC (Connectionist Temporal Classification) decoding. The
//! decoders take the most likely class at each time step, collapse repeats, drop
//! blank tokens, map class indices to characters, and report a confidence score
//! for each decoded string.
7
8use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::HashMap;
11
12static ALPHANUMERIC_REGEX: Lazy<Regex> =
13 Lazy::new(|| Regex::new(r"[a-zA-Z0-9 :*./%+-]").expect("Failed to compile regex pattern"));
14
/// A base decoder for text recognition that maps predicted class indices to characters.
///
/// It owns the character vocabulary (class id → character) and provides decoding
/// with optional removal of consecutive duplicates plus a mean-confidence score per
/// decoded string. The CTC-specific decoder in this module wraps it by composition.
///
/// # Fields
/// * `reverse` - Whether decoded text is post-processed for right-to-left output
/// * `dict` - Character → index lookup, the inverse of `character`
/// * `character` - The vocabulary, indexed by model class id
#[derive(Debug, Clone)]
pub struct BaseRecLabelDecode {
    /// When `true`, decoded text is passed through `pred_reverse`.
    /// Every constructor in this module sets it to `false`.
    reverse: bool,
    /// Character → index lookup. A character that appears more than once in the
    /// vocabulary maps to its last position.
    dict: HashMap<char, usize>,
    /// Vocabulary indexed by model class id.
    character: Vec<char>,
}
30
31impl BaseRecLabelDecode {
32 /// Creates a new `BaseRecLabelDecode` instance.
33 ///
34 /// # Arguments
35 /// * `character_str` - An optional string containing the character vocabulary.
36 /// If None, a default alphanumeric character set is used.
37 /// * `use_space_char` - Whether to include a space character in the vocabulary.
38 ///
39 /// # Returns
40 /// A new `BaseRecLabelDecode` instance.
41 pub fn new(character_str: Option<&str>, use_space_char: bool) -> Self {
42 let mut character_list: Vec<char> = if let Some(chars) = character_str {
43 chars.chars().collect()
44 } else {
45 "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
46 };
47
48 if use_space_char {
49 character_list.push(' ');
50 }
51
52 character_list = Self::add_special_char(character_list);
53
54 let mut dict = HashMap::new();
55 for (i, &char) in character_list.iter().enumerate() {
56 dict.insert(char, i);
57 }
58
59 Self {
60 reverse: false,
61 dict,
62 character: character_list,
63 }
64 }
65
66 /// Creates a new `BaseRecLabelDecode` instance from a list of strings.
67 ///
68 /// # Arguments
69 /// * `character_list` - An optional slice of strings containing the character vocabulary.
70 /// Only the first character of each string is used. If None, a default alphanumeric
71 /// character set is used.
72 /// * `use_space_char` - Whether to include a space character in the vocabulary.
73 ///
74 /// # Returns
75 /// A new `BaseRecLabelDecode` instance.
76 pub fn from_string_list(character_list: Option<&[String]>, use_space_char: bool) -> Self {
77 let mut chars: Vec<char> = if let Some(list) = character_list {
78 list.iter().filter_map(|s| s.chars().next()).collect()
79 } else {
80 "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
81 };
82
83 if use_space_char {
84 chars.push(' ');
85 }
86
87 chars = Self::add_special_char(chars);
88
89 let mut dict = HashMap::new();
90 for (i, &char) in chars.iter().enumerate() {
91 dict.insert(char, i);
92 }
93
94 Self {
95 reverse: false,
96 dict,
97 character: chars,
98 }
99 }
100
101 /// Reverses the alphanumeric parts of a string while keeping non-alphanumeric parts in place.
102 ///
103 /// # Arguments
104 /// * `pred` - The input string to process.
105 ///
106 /// # Returns
107 /// A new string with alphanumeric parts reversed.
108 fn pred_reverse(&self, pred: &str) -> String {
109 let mut pred_re = Vec::new();
110 let mut c_current = String::new();
111
112 for c in pred.chars() {
113 if !ALPHANUMERIC_REGEX.is_match(&c.to_string()) {
114 if !c_current.is_empty() {
115 pred_re.push(c_current.clone());
116 c_current.clear();
117 }
118 pred_re.push(c.to_string());
119 } else {
120 c_current.push(c);
121 }
122 }
123
124 if !c_current.is_empty() {
125 pred_re.push(c_current);
126 }
127
128 pred_re.reverse();
129 pred_re.join("")
130 }
131
132 /// Adds special characters to the character list.
133 ///
134 /// This is a placeholder method that currently just returns the input list unchanged.
135 /// It can be overridden in subclasses to add special characters.
136 ///
137 /// # Arguments
138 /// * `character_list` - The input character list.
139 ///
140 /// # Returns
141 /// The character list with any special characters added.
142 fn add_special_char(character_list: Vec<char>) -> Vec<char> {
143 character_list
144 }
145
146 /// Gets a list of token indices that should be ignored during decoding.
147 ///
148 /// # Returns
149 /// A vector containing the indices of tokens to ignore.
150 fn get_ignored_tokens(&self) -> Vec<usize> {
151 vec![self.get_blank_idx()]
152 }
153
154 /// Decodes model predictions into text strings with confidence scores.
155 ///
156 /// # Arguments
157 /// * `text_index` - A slice of vectors containing the predicted character indices.
158 /// * `text_prob` - An optional slice of vectors containing the prediction probabilities.
159 /// * `is_remove_duplicate` - Whether to remove consecutive duplicate characters.
160 ///
161 /// # Returns
162 /// A vector of tuples, each containing a decoded text string and its confidence score.
163 pub fn decode(
164 &self,
165 text_index: &[Vec<usize>],
166 text_prob: Option<&[Vec<f32>]>,
167 is_remove_duplicate: bool,
168 ) -> Vec<(String, f32)> {
169 let mut result_list = Vec::new();
170 let ignored_tokens = self.get_ignored_tokens();
171
172 for (batch_idx, indices) in text_index.iter().enumerate() {
173 let mut selection = vec![true; indices.len()];
174
175 if is_remove_duplicate && indices.len() > 1 {
176 for i in 1..indices.len() {
177 if indices[i] == indices[i - 1] {
178 selection[i] = false;
179 }
180 }
181 }
182
183 for &ignored_token in &ignored_tokens {
184 for (i, &idx) in indices.iter().enumerate() {
185 if idx == ignored_token {
186 selection[i] = false;
187 }
188 }
189 }
190
191 let char_list: Vec<char> = indices
192 .iter()
193 .enumerate()
194 .filter(|(i, _)| selection[*i])
195 .filter_map(|(_, &text_id)| self.character.get(text_id).copied())
196 .collect();
197
198 let conf_list: Vec<f32> = if let Some(probs) = text_prob {
199 if batch_idx < probs.len() {
200 probs[batch_idx]
201 .iter()
202 .enumerate()
203 .filter(|(i, _)| *i < selection.len() && selection[*i])
204 .map(|(_, &prob)| prob)
205 .collect()
206 } else {
207 vec![1.0; char_list.len()]
208 }
209 } else {
210 vec![1.0; char_list.len()]
211 };
212
213 let conf_list = if conf_list.is_empty() {
214 vec![0.0]
215 } else {
216 conf_list
217 };
218
219 let mut text: String = char_list.iter().collect();
220
221 if self.reverse {
222 text = self.pred_reverse(&text);
223 }
224
225 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
226 result_list.push((text, mean_conf));
227 }
228
229 result_list
230 }
231
232 /// Applies the decoder to a tensor of model predictions.
233 ///
234 /// # Arguments
235 /// * `pred` - A 3D tensor containing the model predictions.
236 ///
237 /// # Returns
238 /// A tuple containing:
239 /// * A vector of decoded text strings
240 /// * A vector of confidence scores for each text string
241 pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
242 if pred.is_empty() {
243 return (Vec::new(), Vec::new());
244 }
245
246 let batch_size = pred.shape()[0];
247 let mut all_texts = Vec::new();
248 let mut all_scores = Vec::new();
249
250 for batch_idx in 0..batch_size {
251 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
252
253 let mut sequence_idx = Vec::new();
254 let mut sequence_prob = Vec::new();
255
256 for row in preds.outer_iter() {
257 if let Some((idx, &prob)) = row
258 .iter()
259 .enumerate()
260 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
261 {
262 sequence_idx.push(idx);
263 sequence_prob.push(prob);
264 } else {
265 sequence_idx.push(0);
266 sequence_prob.push(0.0);
267 }
268 }
269
270 let text = self.decode(&[sequence_idx], Some(&[sequence_prob]), true);
271
272 for (t, score) in text {
273 all_texts.push(t);
274 all_scores.push(score);
275 }
276 }
277
278 (all_texts, all_scores)
279 }
280
281 /// Gets the index of the blank token.
282 ///
283 /// # Returns
284 /// The index of the blank token (always 0 in this base implementation).
285 fn get_blank_idx(&self) -> usize {
286 0
287 }
288}
289
290/// A decoder for CTC (Connectionist Temporal Classification) based text recognition models.
291///
292/// This struct extends `BaseRecLabelDecode` to provide specialized decoding for CTC models,
293/// which include a blank token that needs to be handled specially during decoding.
294///
295/// # Fields
296/// * `base` - The base decoder that handles character mapping and basic decoding operations
297/// * `blank_index` - The index of the blank token in the character vocabulary
298pub struct CTCLabelDecode {
299 base: BaseRecLabelDecode,
300 blank_index: usize,
301}
302
303impl std::fmt::Debug for CTCLabelDecode {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 f.debug_struct("CTCLabelDecode")
306 .field("character_count", &self.base.character.len())
307 .field("reverse", &self.base.reverse)
308 .finish()
309 }
310}
311
312impl CTCLabelDecode {
313 /// Creates a new `CTCLabelDecode` instance.
314 ///
315 /// # Arguments
316 /// * `character_list` - An optional string containing the character vocabulary.
317 /// If None, a default alphanumeric character set is used.
318 /// * `use_space_char` - Whether to include a space character in the vocabulary.
319 ///
320 /// # Returns
321 /// A new `CTCLabelDecode` instance.
322 pub fn new(character_list: Option<&str>, use_space_char: bool) -> Self {
323 let mut base = BaseRecLabelDecode::new(character_list, use_space_char);
324
325 let mut new_character = vec![' '];
326 new_character.extend(base.character);
327
328 let mut new_dict = HashMap::new();
329 for (i, &char) in new_character.iter().enumerate() {
330 new_dict.insert(char, i);
331 }
332
333 base.character = new_character;
334 base.dict = new_dict;
335
336 let blank_index = 0;
337
338 Self { base, blank_index }
339 }
340
341 /// Creates a new `CTCLabelDecode` instance from a list of strings.
342 ///
343 /// # Arguments
344 /// * `character_list` - An optional slice of strings containing the character vocabulary.
345 /// Only the first character of each string is used. If None, a default alphanumeric
346 /// character set is used.
347 /// * `use_space_char` - Whether to include a space character in the vocabulary.
348 /// * `has_explicit_blank` - Whether the character list already includes a blank token.
349 ///
350 /// # Returns
351 /// A new `CTCLabelDecode` instance.
352 pub fn from_string_list(
353 character_list: Option<&[String]>,
354 use_space_char: bool,
355 has_explicit_blank: bool,
356 ) -> Self {
357 if has_explicit_blank {
358 let base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
359 Self {
360 base,
361 blank_index: 0,
362 }
363 } else {
364 let mut base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
365
366 let mut new_character = vec![' '];
367 new_character.extend(base.character);
368
369 let mut new_dict = HashMap::new();
370 for (i, &char) in new_character.iter().enumerate() {
371 new_dict.insert(char, i);
372 }
373
374 base.character = new_character;
375 base.dict = new_dict;
376
377 Self {
378 base,
379 blank_index: 0,
380 }
381 }
382 }
383
384 /// Gets the index of the blank token.
385 ///
386 /// # Returns
387 /// The index of the blank token.
388 pub fn get_blank_index(&self) -> usize {
389 self.blank_index
390 }
391
392 /// Gets the character list used by this decoder.
393 ///
394 /// # Returns
395 /// A slice containing the characters in the vocabulary.
396 pub fn get_character_list(&self) -> &[char] {
397 &self.base.character
398 }
399
400 /// Gets the number of characters in the vocabulary.
401 ///
402 /// # Returns
403 /// The number of characters in the vocabulary.
404 pub fn get_character_count(&self) -> usize {
405 self.base.character.len()
406 }
407
408 /// Applies the CTC decoder to a tensor of model predictions.
409 ///
410 /// This method handles the special requirements of CTC decoding:
411 /// 1. Removing blank tokens
412 /// 2. Removing consecutive duplicate characters
413 /// 3. Converting indices to characters
414 /// 4. Calculating confidence scores
415 ///
416 /// # Arguments
417 /// * `pred` - A 3D tensor containing the model predictions.
418 ///
419 /// # Returns
420 /// A tuple containing:
421 /// * A vector of decoded text strings
422 /// * A vector of confidence scores for each text string
423 pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
424 if pred.is_empty() {
425 return (Vec::new(), Vec::new());
426 }
427
428 let batch_size = pred.shape()[0];
429 let mut all_texts = Vec::new();
430 let mut all_scores = Vec::new();
431
432 for batch_idx in 0..batch_size {
433 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
434
435 let mut sequence_idx = Vec::new();
436 let mut sequence_prob = Vec::new();
437
438 for row in preds.outer_iter() {
439 if let Some((idx, &prob)) = row
440 .iter()
441 .enumerate()
442 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
443 {
444 sequence_idx.push(idx);
445 sequence_prob.push(prob);
446 } else {
447 sequence_idx.push(self.blank_index);
448 sequence_prob.push(0.0);
449 }
450 }
451
452 let mut filtered_idx = Vec::new();
453 let mut filtered_prob = Vec::new();
454 let mut selection = vec![true; sequence_idx.len()];
455
456 if sequence_idx.len() > 1 {
457 for i in 1..sequence_idx.len() {
458 if sequence_idx[i] == sequence_idx[i - 1] {
459 selection[i] = false;
460 }
461 }
462 }
463
464 for (i, &idx) in sequence_idx.iter().enumerate() {
465 if idx == self.blank_index {
466 selection[i] = false;
467 }
468 }
469
470 for (i, &idx) in sequence_idx.iter().enumerate() {
471 if selection[i] {
472 filtered_idx.push(idx);
473 filtered_prob.push(sequence_prob[i]);
474 }
475 }
476
477 let char_list: Vec<char> = filtered_idx
478 .iter()
479 .filter_map(|&text_id| self.base.character.get(text_id).copied())
480 .collect();
481
482 let conf_list = if filtered_prob.is_empty() {
483 vec![0.0]
484 } else {
485 filtered_prob
486 };
487
488 let text: String = char_list.iter().collect();
489 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
490
491 all_texts.push(text);
492 all_scores.push(mean_conf);
493 }
494
495 (all_texts, all_scores)
496 }
497}