tokenizers/decoders/
ctc.rs

1use crate::decoders::wordpiece;
2use crate::tokenizer::{Decoder, Result};
3
4use itertools::Itertools;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8/// The CTC (Connectionist Temporal Classification) decoder takes care
9/// of sanitizing a list of inputs token.
10/// Due to some alignement problem the output of some models can come
11/// with duplicated token.
12#[serde(tag = "type")]
13#[non_exhaustive]
14pub struct CTC {
15    /// The pad token used by CTC to delimit a new token.
16    pub pad_token: String,
17    /// The word delimiter token. It will be replaced by a `<space>`.
18    pub word_delimiter_token: String,
19    /// Whether to cleanup some tokenization artifacts.
20    /// Mainly spaces before punctuation, and some abbreviated english forms.
21    pub cleanup: bool,
22}
23
24impl CTC {
25    pub fn new(pad_token: String, word_delimiter_token: String, cleanup: bool) -> Self {
26        Self {
27            pad_token,
28            word_delimiter_token,
29            cleanup,
30        }
31    }
32}
33
34impl Default for CTC {
35    fn default() -> Self {
36        Self {
37            pad_token: "<pad>".to_string(),
38            word_delimiter_token: "|".to_string(),
39            cleanup: true,
40        }
41    }
42}
43
44impl Decoder for CTC {
45    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
46        Ok(tokens
47            .into_iter()
48            .dedup()
49            .filter_map(|token| {
50                let mut replaced = token.replace(&self.pad_token, "");
51                if self.cleanup {
52                    replaced =
53                        wordpiece::cleanup(&replaced).replace(&self.word_delimiter_token, " ");
54                }
55                if replaced.is_empty() {
56                    None
57                } else {
58                    Some(replaced)
59                }
60            })
61            .collect())
62    }
63}
64
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a space-separated fixture into owned tokens.
    fn tokens_of(raw: &str) -> Vec<String> {
        raw.split(' ').map(str::to_owned).collect()
    }

    /// Decodes a space-separated fixture with the default `CTC` decoder.
    fn decode_default(raw: &str) -> Vec<String> {
        CTC::default().decode_chain(tokens_of(raw)).unwrap()
    }

    #[test]
    fn handmade_sample() {
        let decoded = decode_default("<pad> <pad> h e e l l <pad> l o o o <pad>");
        assert_eq!(decoded, vec!["h", "e", "l", "l", "o"]);
    }

    #[test]
    fn handmade_with_delimiter_sample() {
        let decoded = decode_default(
            "<pad> <pad> h e e l l <pad> l o o o <pad> <pad> | <pad> w o o o r <pad> <pad> l l d <pad> <pad> <pad> <pad>",
        );
        assert_eq!(
            decoded,
            vec!["h", "e", "l", "l", "o", " ", "w", "o", "r", "l", "d"]
        );
    }

    #[test]
    fn librispeech_sample() {
        let decoded = decode_default(
            "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> A | | <pad> M <pad> <pad> <pad> <pad> A <pad> <pad> N <pad> <pad> <pad> | | | <pad> <pad> <pad> <pad> S <pad> <pad> <pad> A I <pad> D D | | T T <pad> O <pad> | | T H E E | | | <pad> U U <pad> N N <pad> I <pad> <pad> V <pad> <pad> <pad> E R R <pad> <pad> <pad> S E E | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> S S <pad> <pad> <pad> <pad> I <pad> R R <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> <pad> <pad> | <pad> <pad> <pad> E X <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> I <pad> S <pad> <pad> T <pad> <pad> | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>",
        );
        assert_eq!(
            decoded,
            vec![
                "A", " ", "M", "A", "N", " ", "S", "A", "I", "D", " ", "T", "O", " ", "T", "H",
                "E", " ", "U", "N", "I", "V", "E", "R", "S", "E", " ", "S", "I", "R", " ", "I",
                " ", "E", "X", "I", "S", "T", " "
            ]
        );
    }

    #[test]
    fn another_librispeech_sample() {
        let decoded = decode_default(
            "<pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> I <pad> S S | | <pad> <pad> <pad> I N <pad> <pad> S <pad> T T <pad> <pad> A N C C T <pad> | | | | | <pad> <pad> <pad> <pad> P <pad> <pad> <pad> <pad> A <pad> <pad> N N N <pad> <pad> I <pad> C <pad> <pad> | | <pad> W <pad> <pad> A S <pad> | | <pad> <pad> <pad> F <pad> <pad> O L <pad> <pad> L L O O W E E D | | <pad> B <pad> <pad> <pad> Y <pad> | | | A | | <pad> S S S <pad> M M <pad> <pad> <pad> A L L <pad> <pad> <pad> <pad> L <pad> | | | <pad> <pad> <pad> <pad> S H H <pad> <pad> <pad> <pad> A R R <pad> <pad> P <pad> <pad> | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> B <pad> <pad> L L <pad> <pad> <pad> <pad> <pad> O W W <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> H <pad> <pad> <pad> <pad> <pad> <pad> <pad> I G H H | | <pad> <pad> O N <pad> | | H <pad> I S S | | <pad> <pad> C H H <pad> <pad> <pad> E <pad> S S <pad> T T <pad> <pad> | | | <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad> <pad>",
        );
        assert_eq!(
            decoded,
            vec![
                "H", "I", "S", " ", "I", "N", "S", "T", "A", "N", "C", "T", " ", "P", "A", "N",
                "I", "C", " ", "W", "A", "S", " ", "F", "O", "L", "L", "O", "W", "E", "D", " ",
                "B", "Y", " ", "A", " ", "S", "M", "A", "L", "L", " ", "S", "H", "A", "R", "P",
                " ", "B", "L", "O", "W", " ", "H", "I", "G", "H", " ", "O", "N", " ", "H", "I",
                "S", " ", "C", "H", "E", "S", "T", " "
            ]
        );
    }
}