1use once_cell::sync::Lazy;
9use regex::Regex;
10use std::collections::HashMap;
11
/// Result tuple produced by `CTCLabelDecode::apply_with_positions`.
///
/// Every vector holds one entry per batch item, in batch order:
///
/// 0. decoded text,
/// 1. mean confidence of the timesteps kept after collapsing,
/// 2. normalised positions (`timestep / sequence_length`) of the kept timesteps,
/// 3. raw timestep (column) indices of the kept timesteps,
/// 4. sequence length (number of timesteps) of the item.
pub type PositionedDecodeResult = (
    Vec<String>,
    Vec<f32>,
    Vec<Vec<f32>>,
    Vec<Vec<usize>>,
    Vec<usize>,
);
20
/// Lazily compiled pattern matching a single character from the set
/// `a-z`, `A-Z`, `0-9`, space, `:`, `*`, `.`, `/`, `%`, `+` and `-`.
///
/// `BaseRecLabelDecode::pred_reverse` uses it to decide which characters
/// are grouped into a run that keeps its internal order.
static ALPHANUMERIC_REGEX: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"[a-zA-Z0-9 :*./%+-]").expect("Failed to compile regex pattern"));
23
/// Maps class indices to characters and turns index sequences into text.
pub struct BaseRecLabelDecode {
    /// When `true`, `decode` passes the decoded text through `pred_reverse`.
    /// Both constructors in this file set it to `false`.
    reverse: bool,
    /// Character -> class index lookup (the last occurrence wins for duplicates).
    dict: HashMap<char, usize>,
    /// Class index -> character table.
    character: Vec<char>,
}
39
40impl BaseRecLabelDecode {
41 pub fn new(character_str: Option<&str>, use_space_char: bool) -> Self {
51 let mut character_list: Vec<char> = if let Some(chars) = character_str {
52 chars.chars().collect()
53 } else {
54 "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
55 };
56
57 if use_space_char {
58 character_list.push(' ');
59 }
60
61 character_list = Self::add_special_char(character_list);
62
63 let mut dict = HashMap::new();
64 for (i, &char) in character_list.iter().enumerate() {
65 dict.insert(char, i);
66 }
67
68 Self {
69 reverse: false,
70 dict,
71 character: character_list,
72 }
73 }
74
75 pub fn from_string_list(character_list: Option<&[String]>, use_space_char: bool) -> Self {
86 let mut chars: Vec<char> = if let Some(list) = character_list {
87 list.iter().filter_map(|s| s.chars().next()).collect()
88 } else {
89 "0123456789abcdefghijklmnopqrstuvwxyz".chars().collect()
90 };
91
92 if use_space_char {
93 chars.push(' ');
94 }
95
96 chars = Self::add_special_char(chars);
97
98 let mut dict = HashMap::new();
99 for (i, &char) in chars.iter().enumerate() {
100 dict.insert(char, i);
101 }
102
103 Self {
104 reverse: false,
105 dict,
106 character: chars,
107 }
108 }
109
110 fn pred_reverse(&self, pred: &str) -> String {
118 let mut pred_re = Vec::new();
119 let mut c_current = String::new();
120
121 for c in pred.chars() {
122 if !ALPHANUMERIC_REGEX.is_match(&c.to_string()) {
123 if !c_current.is_empty() {
124 pred_re.push(c_current.clone());
125 c_current.clear();
126 }
127 pred_re.push(c.to_string());
128 } else {
129 c_current.push(c);
130 }
131 }
132
133 if !c_current.is_empty() {
134 pred_re.push(c_current);
135 }
136
137 pred_re.reverse();
138 pred_re.join("")
139 }
140
/// Hook applied to the alphabet by both constructors before the
/// character -> index map is built.
///
/// This implementation returns `character_list` unchanged.
fn add_special_char(character_list: Vec<char>) -> Vec<char> {
    character_list
}
154
155 fn get_ignored_tokens(&self) -> Vec<usize> {
160 vec![self.get_blank_idx()]
161 }
162
163 pub fn decode(
173 &self,
174 text_index: &[Vec<usize>],
175 text_prob: Option<&[Vec<f32>]>,
176 is_remove_duplicate: bool,
177 ) -> Vec<(String, f32)> {
178 let mut result_list = Vec::new();
179 let ignored_tokens = self.get_ignored_tokens();
180
181 for (batch_idx, indices) in text_index.iter().enumerate() {
182 let mut selection = vec![true; indices.len()];
183
184 if is_remove_duplicate && indices.len() > 1 {
185 for i in 1..indices.len() {
186 if indices[i] == indices[i - 1] {
187 selection[i] = false;
188 }
189 }
190 }
191
192 for &ignored_token in &ignored_tokens {
193 for (i, &idx) in indices.iter().enumerate() {
194 if idx == ignored_token {
195 selection[i] = false;
196 }
197 }
198 }
199
200 let char_list: Vec<char> = indices
201 .iter()
202 .enumerate()
203 .filter(|(i, _)| selection[*i])
204 .filter_map(|(_, &text_id)| self.character.get(text_id).copied())
205 .collect();
206
207 let conf_list: Vec<f32> = if let Some(probs) = text_prob {
208 if batch_idx < probs.len() {
209 probs[batch_idx]
210 .iter()
211 .enumerate()
212 .filter(|(i, _)| *i < selection.len() && selection[*i])
213 .map(|(_, &prob)| prob)
214 .collect()
215 } else {
216 vec![1.0; char_list.len()]
217 }
218 } else {
219 vec![1.0; char_list.len()]
220 };
221
222 let conf_list = if conf_list.is_empty() {
223 vec![0.0]
224 } else {
225 conf_list
226 };
227
228 let mut text: String = char_list.iter().collect();
229
230 if self.reverse {
231 text = self.pred_reverse(&text);
232 }
233
234 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
235 result_list.push((text, mean_conf));
236 }
237
238 result_list
239 }
240
241 pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
251 if pred.is_empty() {
252 return (Vec::new(), Vec::new());
253 }
254
255 let batch_size = pred.shape()[0];
256 let mut all_texts = Vec::new();
257 let mut all_scores = Vec::new();
258
259 for batch_idx in 0..batch_size {
260 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
261
262 let mut sequence_idx = Vec::new();
263 let mut sequence_prob = Vec::new();
264
265 for row in preds.outer_iter() {
266 if let Some((idx, &prob)) = row
267 .iter()
268 .enumerate()
269 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
270 {
271 sequence_idx.push(idx);
272 sequence_prob.push(prob);
273 } else {
274 sequence_idx.push(0);
275 sequence_prob.push(0.0);
276 }
277 }
278
279 let text = self.decode(&[sequence_idx], Some(&[sequence_prob]), true);
280
281 for (t, score) in text {
282 all_texts.push(t);
283 all_scores.push(score);
284 }
285 }
286
287 (all_texts, all_scores)
288 }
289
/// Returns the class index treated as the blank token; always `0` here.
fn get_blank_idx(&self) -> usize {
    0
}
297}
298
/// Greedy CTC decoder built on top of a `BaseRecLabelDecode` character table.
pub struct CTCLabelDecode {
    /// Character table and lookup map used to turn class indices into text.
    base: BaseRecLabelDecode,
    /// Class index of the CTC blank; both constructors in this file set it to `0`.
    blank_index: usize,
}
311
312impl std::fmt::Debug for CTCLabelDecode {
313 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
314 f.debug_struct("CTCLabelDecode")
315 .field("character_count", &self.base.character.len())
316 .field("reverse", &self.base.reverse)
317 .finish()
318 }
319}
320
321impl CTCLabelDecode {
322 pub fn new(character_list: Option<&str>, use_space_char: bool) -> Self {
332 let mut base = BaseRecLabelDecode::new(character_list, use_space_char);
333
334 let mut new_character = vec!['\0'];
336 new_character.extend(base.character);
337
338 let mut new_dict = HashMap::new();
339 for (i, &char) in new_character.iter().enumerate() {
340 new_dict.insert(char, i);
341 }
342
343 base.character = new_character;
344 base.dict = new_dict;
345
346 let blank_index = 0;
347
348 Self { base, blank_index }
349 }
350
351 pub fn from_string_list(
363 character_list: Option<&[String]>,
364 use_space_char: bool,
365 has_explicit_blank: bool,
366 ) -> Self {
367 if has_explicit_blank {
368 let base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
369 Self {
370 base,
371 blank_index: 0,
372 }
373 } else {
374 let mut base = BaseRecLabelDecode::from_string_list(character_list, use_space_char);
375
376 let mut new_character = vec!['\0'];
378 new_character.extend(base.character);
379
380 let mut new_dict = HashMap::new();
381 for (i, &char) in new_character.iter().enumerate() {
382 new_dict.insert(char, i);
383 }
384
385 base.character = new_character;
386 base.dict = new_dict;
387
388 Self {
389 base,
390 blank_index: 0,
391 }
392 }
393 }
394
/// Returns the class index this decoder treats as the CTC blank.
pub fn get_blank_index(&self) -> usize {
    self.blank_index
}
402
403 pub fn get_character_list(&self) -> &[char] {
408 &self.base.character
409 }
410
411 pub fn get_character_count(&self) -> usize {
416 self.base.character.len()
417 }
418
419 pub fn apply_with_positions(&self, pred: &crate::core::Tensor3D) -> PositionedDecodeResult {
435 if pred.is_empty() {
436 return (Vec::new(), Vec::new(), Vec::new(), Vec::new(), Vec::new());
437 }
438
439 let batch_size = pred.shape()[0];
440 let mut all_texts = Vec::new();
441 let mut all_scores = Vec::new();
442 let mut all_positions = Vec::new();
443 let mut all_col_indices = Vec::new();
444 let mut all_seq_lengths = Vec::new();
445
446 for batch_idx in 0..batch_size {
447 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
448 let seq_len = preds.shape()[0] as f32;
449
450 let mut sequence_idx = Vec::new();
451 let mut sequence_prob = Vec::new();
452 let mut sequence_timesteps = Vec::new();
453
454 for (timestep, row) in preds.outer_iter().enumerate() {
455 if let Some((idx, &prob)) = row
456 .iter()
457 .enumerate()
458 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
459 {
460 sequence_idx.push(idx);
461 sequence_prob.push(prob);
462 sequence_timesteps.push(timestep);
463 } else {
464 sequence_idx.push(self.blank_index);
465 sequence_prob.push(0.0);
466 sequence_timesteps.push(timestep);
467 }
468 }
469
470 let mut filtered_idx = Vec::new();
471 let mut filtered_prob = Vec::new();
472 let mut filtered_timesteps = Vec::new();
473 let mut selection = vec![true; sequence_idx.len()];
474
475 if sequence_idx.len() > 1 {
477 for i in 1..sequence_idx.len() {
478 if sequence_idx[i] == sequence_idx[i - 1] {
479 selection[i] = false;
480 }
481 }
482 }
483
484 for (i, &idx) in sequence_idx.iter().enumerate() {
486 if idx == self.blank_index {
487 selection[i] = false;
488 }
489 }
490
491 for (i, &idx) in sequence_idx.iter().enumerate() {
493 if selection[i] {
494 filtered_idx.push(idx);
495 filtered_prob.push(sequence_prob[i]);
496 filtered_timesteps.push(sequence_timesteps[i]);
497 }
498 }
499
500 let char_list: Vec<char> = filtered_idx
501 .iter()
502 .filter_map(|&text_id| self.base.character.get(text_id).copied())
503 .collect();
504
505 let conf_list = if filtered_prob.is_empty() {
506 vec![0.0]
507 } else {
508 filtered_prob
509 };
510
511 let char_positions: Vec<f32> = filtered_timesteps
513 .iter()
514 .map(|×tep| timestep as f32 / seq_len)
515 .collect();
516
517 let col_indices: Vec<usize> = filtered_timesteps.clone();
519
520 let text: String = char_list.iter().collect();
521 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
522
523 all_texts.push(text);
524 all_scores.push(mean_conf);
525 all_positions.push(char_positions);
526 all_col_indices.push(col_indices);
527 all_seq_lengths.push(seq_len as usize);
528 }
529
530 (
531 all_texts,
532 all_scores,
533 all_positions,
534 all_col_indices,
535 all_seq_lengths,
536 )
537 }
538
539 pub fn apply(&self, pred: &crate::core::Tensor3D) -> (Vec<String>, Vec<f32>) {
555 if pred.is_empty() {
556 return (Vec::new(), Vec::new());
557 }
558
559 let batch_size = pred.shape()[0];
560 let mut all_texts = Vec::new();
561 let mut all_scores = Vec::new();
562 let mut batches_with_text = 0;
563
564 for batch_idx in 0..batch_size {
565 let preds = pred.index_axis(ndarray::Axis(0), batch_idx);
566
567 let mut sequence_idx = Vec::new();
568 let mut sequence_prob = Vec::new();
569
570 for row in preds.outer_iter() {
571 if let Some((idx, &prob)) = row
572 .iter()
573 .enumerate()
574 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
575 {
576 sequence_idx.push(idx);
577 sequence_prob.push(prob);
578 } else {
579 sequence_idx.push(self.blank_index);
580 sequence_prob.push(0.0);
581 }
582 }
583
584 let mut filtered_idx = Vec::new();
585 let mut filtered_prob = Vec::new();
586 let mut selection = vec![true; sequence_idx.len()];
587
588 if sequence_idx.len() > 1 {
589 for i in 1..sequence_idx.len() {
590 if sequence_idx[i] == sequence_idx[i - 1] {
591 selection[i] = false;
592 }
593 }
594 }
595
596 for (i, &idx) in sequence_idx.iter().enumerate() {
597 if idx == self.blank_index {
598 selection[i] = false;
599 }
600 }
601
602 for (i, &idx) in sequence_idx.iter().enumerate() {
603 if selection[i] {
604 filtered_idx.push(idx);
605 filtered_prob.push(sequence_prob[i]);
606 }
607 }
608
609 let char_list: Vec<char> = filtered_idx
610 .iter()
611 .filter_map(|&text_id| self.base.character.get(text_id).copied())
612 .collect();
613
614 let conf_list = if filtered_prob.is_empty() {
615 vec![0.0]
616 } else {
617 filtered_prob
618 };
619
620 let text: String = char_list.iter().collect();
621 let mean_conf = conf_list.iter().sum::<f32>() / conf_list.len() as f32;
622
623 if !text.is_empty() {
624 batches_with_text += 1;
625 }
626
627 all_texts.push(text);
628 all_scores.push(mean_conf);
629 }
630
631 tracing::debug!(
633 "CTC decode summary: batch_size={}, batches_with_text={}, empty_batches={}",
634 batch_size,
635 batches_with_text,
636 batch_size - batches_with_text
637 );
638
639 (all_texts, all_scores)
640 }
641}