1use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::HashMap;
11
/// Batch output of [`CTCLabelDecode::apply_with_positions`]; every vector
/// holds one entry per sample:
///
/// 0. the decoded text;
/// 1. the mean of the kept per-timestep probabilities (0.0 when none are kept);
/// 2. per-character position expressed as `timestep / sequence_length`;
/// 3. per-character timestep (column index into the model output);
/// 4. the number of timesteps in the sample.
pub type PositionedDecodeResult = (
    Vec<String>,     // decoded texts
    Vec<f32>,        // mean confidences
    Vec<Vec<f32>>,   // normalized character positions
    Vec<Vec<usize>>, // character column indices
    Vec<usize>,      // sequence lengths
);
20
21static ALPHANUMERIC_REGEX: Lazy<Regex> = Lazy::new(|| {
22 Regex::new(r"[a-zA-Z0-9 :*./%+-]")
23 .unwrap_or_else(|e| panic!("Failed to compile regex pattern: {e}"))
24});
25
/// Character table plus lookup map shared by the recognition label decoders.
pub struct BaseRecLabelDecode {
    /// Class table: the character emitted for each class index.
    character: Vec<char>,
    /// Reverse lookup from character to class index (for duplicate characters
    /// the last index wins).
    dict: HashMap<char, usize>,
    /// When `true`, `decode` passes its text through `pred_reverse`; the
    /// constructors in this file initialise it to `false`.
    reverse: bool,
}
41
42impl BaseRecLabelDecode {
43 pub fn new(character_str: Option<&str>, use_space_char: bool) -> Self {
53 let mut character_list: Vec<char> = if let Some(chars) = character_str {
54 chars.chars().collect()
55 } else {
56 "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
57 };
58
59 if use_space_char {
60 character_list.push(' ');
61 }
62
63 character_list = Self::add_special_char(character_list);
64
65 let mut dict = HashMap::new();
66 for (i, &char) in character_list.iter().enumerate() {
67 dict.insert(char, i);
68 }
69
70 Self {
71 reverse: false,
72 dict,
73 character: character_list,
74 }
75 }
76
77 pub fn from_string_list(character_list: Option<&[String]>, use_space_char: bool) -> Self {
88 let mut chars: Vec<char> = if let Some(list) = character_list {
89 list.iter().filter_map(|s| s.chars().next()).collect()
90 } else {
91 "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
92 };
93
94 if use_space_char {
95 chars.push(' ');
96 }
97
98 chars = Self::add_special_char(chars);
99
100 let mut dict = HashMap::new();
101 for (i, &char) in chars.iter().enumerate() {
102 dict.insert(char, i);
103 }
104
105 Self {
106 reverse: false,
107 dict,
108 character: chars,
109 }
110 }
111
112 fn pred_reverse(&self, pred: &str) -> String {
120 let mut pred_re = Vec::new();
121 let mut c_current = String::new();
122
123 for c in pred.chars() {
124 if !ALPHANUMERIC_REGEX.is_match(&c.to_string()) {
125 if !c_current.is_empty() {
126 pred_re.push(c_current.clone());
127 c_current.clear();
128 }
129 pred_re.push(c.to_string());
130 } else {
131 c_current.push(c);
132 }
133 }
134
135 if !c_current.is_empty() {
136 pred_re.push(c_current);
137 }
138
139 pred_re.reverse();
140 pred_re.join("")
141 }
142
143 fn add_special_char(character_list: Vec<char>) -> Vec<char> {
154 character_list
155 }
156
157 fn get_ignored_tokens(&self) -> Vec<usize> {
162 vec![self.get_blank_idx()]
163 }
164
165 pub fn decode(
175 &self,
176 text_index: &[Vec<usize>],
177 text_prob: Option<&[Vec<f32>]>,
178 is_remove_duplicate: bool,
179 ) -> Vec<(String, f32)> {
180 let mut result_list = Vec::new();
181 let ignored_tokens = self.get_ignored_tokens();
182
183 for (batch_idx, indices) in text_index.iter().enumerate() {
184 let mut selection = vec![true; indices.len()];
185
186 if is_remove_duplicate && indices.len() > 1 {
187 for i in 1..indices.len() {
188 if indices[i] == indices[i - 1] {
189 selection[i] = false;
190 }
191 }
192 }
193
194 for &ignored_token in &ignored_tokens {
195 for (i, &idx) in indices.iter().enumerate() {
196 if idx == ignored_token {
197 selection[i] = false;
198 }
199 }
200 }
201
202 let char_list: Vec<char> = indices
203 .iter()
204 .enumerate()
205 .filter(|(i, _)| selection[*i])
206 .filter_map(|(_, &text_id)| self.character.get(text_id).copied())
207 .collect();
208
209 let conf_list: Vec<f32> = if let Some(probs) = text_prob {
210 if batch_idx < probs.len() {
211 probs[batch_idx]
212 .iter()
213 .enumerate()
214 .filter(|(i, _)| *i < selection.len() && selection[*i])
215 .map(|(_, &prob)| prob)
216 .collect()
217 } else {
218 vec![1.0; char_list.len()]
219 }
220 } else {
221 vec![1.0; char_list.len()]
222 };
223
224 let conf_list = if conf_list.is_empty() {
225 vec![0.0]
226 } else {
227 conf_list
228 };
229
230 let mut text: String = char_list.iter().collect();
231
232 if self.reverse {
233 text = self.pred_reverse(&text);
234 }
235
236 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
237 result_list.push((text, mean_conf));
238 }
239
240 result_list
241 }
242
243 pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
253 if pred.is_empty() {
254 return (Vec::new(), Vec::new());
255 }
256
257 let batch_size = pred.shape()[0];
258 let mut all_texts = Vec::new();
259 let mut all_scores = Vec::new();
260
261 for batch_idx in 0..batch_size {
262 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
263
264 let mut sequence_idx = Vec::new();
265 let mut sequence_prob = Vec::new();
266
267 for row in preds.outer_iter() {
268 if let Some((idx, &prob)) = row
269 .iter()
270 .enumerate()
271 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
272 {
273 sequence_idx.push(idx);
274 sequence_prob.push(prob);
275 } else {
276 sequence_idx.push(0);
277 sequence_prob.push(0.0);
278 }
279 }
280
281 let text = self.decode(&[sequence_idx], Some(&[sequence_prob]), true);
282
283 for (t, score) in text {
284 all_texts.push(t);
285 all_scores.push(score);
286 }
287 }
288
289 (all_texts, all_scores)
290 }
291
292 fn get_blank_idx(&self) -> usize {
297 0
298 }
299}
300
301pub struct CTCLabelDecode {
310 base: BaseRecLabelDecode,
311 blank_index: usize,
312}
313
314impl std::fmt::Debug for CTCLabelDecode {
315 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
316 f.debug_struct("CTCLabelDecode")
317 .field("character_count", &self.base.character.len())
318 .field("reverse", &self.base.reverse)
319 .finish()
320 }
321}
322
323impl CTCLabelDecode {
324 pub fn new(character_list: Option<&str>, use_space_char: bool) -> Self {
334 let mut base = BaseRecLabelDecode::new(character_list, use_space_char);
335
336 let mut new_character = vec!['\0'];
338 new_character.extend(base.character);
339
340 let mut new_dict = HashMap::new();
341 for (i, &char) in new_character.iter().enumerate() {
342 new_dict.insert(char, i);
343 }
344
345 base.character = new_character;
346 base.dict = new_dict;
347
348 let blank_index = 0;
349
350 Self { base, blank_index }
351 }
352
353 pub fn from_string_list(
365 character_list: Option<&[String]>,
366 use_space_char: bool,
367 has_explicit_blank: bool,
368 ) -> Self {
369 if has_explicit_blank {
370 let base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
371 Self {
372 base,
373 blank_index: 0,
374 }
375 } else {
376 let mut base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
377
378 let mut new_character = vec!['\0'];
380 new_character.extend(base.character);
381
382 let mut new_dict = HashMap::new();
383 for (i, &char) in new_character.iter().enumerate() {
384 new_dict.insert(char, i);
385 }
386
387 base.character = new_character;
388 base.dict = new_dict;
389
390 Self {
391 base,
392 blank_index: 0,
393 }
394 }
395 }
396
397 pub fn get_blank_index(&self) -> usize {
402 self.blank_index
403 }
404
405 pub fn get_character_list(&self) -> &[char] {
410 &self.base.character
411 }
412
413 pub fn get_character_count(&self) -> usize {
418 self.base.character.len()
419 }
420
421 pub fn apply_with_positions(&self, pred: &crate::core::Tensor3D) -> PositionedDecodeResult {
437 if pred.is_empty() {
438 return (Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new());
439 }
440
441 let batch_size = pred.shape()[0];
442 let mut all_texts = Vec::new();
443 let mut all_scores = Vec::new();
444 let mut all_positions = Vec::new();
445 let mut all_col_indices = Vec::new();
446 let mut all_seq_lengths = Vec::new();
447
448 for batch_idx in 0..batch_size {
449 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
450 let seq_len = preds.shape()[0] as f32;
451
452 let mut sequence_idx = Vec::new();
453 let mut sequence_prob = Vec::new();
454 let mut sequence_timesteps = Vec::new();
455
456 for (timestep, row) in preds.outer_iter().enumerate() {
457 if let Some((idx, &prob)) = row
458 .iter()
459 .enumerate()
460 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
461 {
462 sequence_idx.push(idx);
463 sequence_prob.push(prob);
464 sequence_timesteps.push(timestep);
465 } else {
466 sequence_idx.push(self.blank_index);
467 sequence_prob.push(0.0);
468 sequence_timesteps.push(timestep);
469 }
470 }
471
472 let mut filtered_idx = Vec::new();
473 let mut filtered_prob = Vec::new();
474 let mut filtered_timesteps = Vec::new();
475 let mut selection = vec![true; sequence_idx.len()];
476
477 if sequence_idx.len() > 1 {
479 for i in 1..sequence_idx.len() {
480 if sequence_idx[i] == sequence_idx[i - 1] {
481 selection[i] = false;
482 }
483 }
484 }
485
486 for (i, &idx) in sequence_idx.iter().enumerate() {
488 if idx == self.blank_index {
489 selection[i] = false;
490 }
491 }
492
493 for (i, &idx) in sequence_idx.iter().enumerate() {
495 if selection[i] {
496 filtered_idx.push(idx);
497 filtered_prob.push(sequence_prob[i]);
498 filtered_timesteps.push(sequence_timesteps[i]);
499 }
500 }
501
502 let char_list: Vec<char> = filtered_idx
503 .iter()
504 .filter_map(|&text_id| self.base.character.get(text_id).copied())
505 .collect();
506
507 let conf_list = if filtered_prob.is_empty() {
508 vec![0.0]
509 } else {
510 filtered_prob
511 };
512
513 let char_positions: Vec<f32> = filtered_timesteps
515 .iter()
516 .map(|×tep| timestep as f32 / seq_len)
517 .collect();
518
519 let col_indices: Vec<usize> = filtered_timesteps.clone();
521
522 let text: String = char_list.iter().collect();
523 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
524
525 all_texts.push(text);
526 all_scores.push(mean_conf);
527 all_positions.push(char_positions);
528 all_col_indices.push(col_indices);
529 all_seq_lengths.push(seq_len as usize);
530 }
531
532 (
533 all_texts,
534 all_scores,
535 all_positions,
536 all_col_indices,
537 all_seq_lengths,
538 )
539 }
540
541 pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
557 if pred.is_empty() {
558 return (Vec::new(), Vec::new());
559 }
560
561 let batch_size = pred.shape()[0];
562 let mut all_texts = Vec::new();
563 let mut all_scores = Vec::new();
564 let mut batches_with_text = 0;
565
566 for batch_idx in 0..batch_size {
567 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
568
569 let mut sequence_idx = Vec::new();
570 let mut sequence_prob = Vec::new();
571
572 for row in preds.outer_iter() {
573 if let Some((idx, &prob)) = row
574 .iter()
575 .enumerate()
576 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
577 {
578 sequence_idx.push(idx);
579 sequence_prob.push(prob);
580 } else {
581 sequence_idx.push(self.blank_index);
582 sequence_prob.push(0.0);
583 }
584 }
585
586 let mut filtered_idx = Vec::new();
587 let mut filtered_prob = Vec::new();
588 let mut selection = vec![true; sequence_idx.len()];
589
590 if sequence_idx.len() > 1 {
591 for i in 1..sequence_idx.len() {
592 if sequence_idx[i] == sequence_idx[i - 1] {
593 selection[i] = false;
594 }
595 }
596 }
597
598 for (i, &idx) in sequence_idx.iter().enumerate() {
599 if idx == self.blank_index {
600 selection[i] = false;
601 }
602 }
603
604 for (i, &idx) in sequence_idx.iter().enumerate() {
605 if selection[i] {
606 filtered_idx.push(idx);
607 filtered_prob.push(sequence_prob[i]);
608 }
609 }
610
611 let char_list: Vec<char> = filtered_idx
612 .iter()
613 .filter_map(|&text_id| self.base.character.get(text_id).copied())
614 .collect();
615
616 let conf_list = if filtered_prob.is_empty() {
617 vec![0.0]
618 } else {
619 filtered_prob
620 };
621
622 let text: String = char_list.iter().collect();
623 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
624
625 if !text.is_empty() {
626 batches_with_text += 1;
627 }
628
629 all_texts.push(text);
630 all_scores.push(mean_conf);
631 }
632
633 tracing::debug!(
635 "CTC decode summary: batch_size={}, batches_with_text={}, empty_batches={}",
636 batch_size,
637 batches_with_text,
638 batch_size - batches_with_text
639 );
640
641 (all_texts, all_scores)
642 }
643}