1use crate::decoders::wordpiece;
2use crate::tokenizer::{Decoder, Result};
3
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
/// Decoder for the raw token sequence produced by a CTC model.
///
/// Consecutive repeated tokens are merged, the padding token is stripped and, when
/// `cleanup` is enabled, the word delimiter token is turned into a space.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub struct CTC {
    /// Token removed from every decoded token (`"<pad>"` by default). Because merging of
    /// repeats happens before it is removed, it is what keeps genuinely doubled characters
    /// apart.
    pub pad_token: String,
    /// Token replaced by a single space when `cleanup` is enabled (`"|"` by default).
    pub word_delimiter_token: String,
    /// Whether to run the wordpiece cleanup pass and the word-delimiter replacement
    /// (`true` by default).
    pub cleanup: bool,
}
23
24impl CTC {
25 pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self {
26 Self {
27 pad_token,
28 word_delimiter_token,
29 cleanup,
30 }
31 }
32}
33
34impl Default for CTC {
35 fn default() -> Self {
36 Self {
37 pad_token: "<pad>".to_string(),
38 word_delimiter_token: "|".to_string(),
39 cleanup: true,
40 }
41 }
42}
43
44impl Decoder for CTC {
45 fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
46 Ok(tokens
47 .into_iter()
48 .dedup()
49 .filter_map(|token| {
50 let mut replaced = token.replace(&self.pad_token, "");
51 if self.cleanup {
52 replaced =
53 wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
54 }
55 if replaced.is_empty() {
56 None
57 } else {
58 Some(replaced)
59 }
60 })
61 .collect())
62 }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a space-separated string of raw CTC tokens and decodes it with the default
    /// decoder settings.
    fn decode_with_default(raw: &str) -> Vec<String> {
        let tokens: Vec<String> = raw.split(' ').map(String::from).collect();
        CTC::default().decode_chain(tokens).unwrap()
    }

    #[test]
    fn handmade_sample() {
        assert_eq!(
            decode_with_default("<pad> <pad> h e e l l <pad> l o o o <pad>"),
            vec!["h", "e", "l", "l", "o"]
        );
    }

    #[test]
    fn handmade_with_delimiter_sample() {
        assert_eq!(
            decode_with_default("<pad> <pad> h e e l l <pad> l o o o <pad> <pad> | <pad> w o o o r <pad> <pad> l l d <pad> <pad> <pad> <pad>"),
            vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
        );
    }

    #[test]
    fn librispeech_sample() {
        let raw = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>";
        assert_eq!(
            decode_with_default(raw),
            vec![
                "A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
                "E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
                " ", "E", "X", "I", "S", "T", " "
            ]
        );
    }

    #[test]
    fn another_librispeech_sample() {
        let raw = "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>";
        assert_eq!(
            decode_with_default(raw),
            vec![
                "H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
                "I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
                "B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
                " ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
                "S", " ", "C", "H", "E", "S", "T", " "
            ]
        );
    }
}