tokenizers/decoders/
wordpiece.rs

1use crate::tokenizer::{Decoder, Result};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Deserialize, Clone, Debug, Serialize)]
6/// The WordPiece decoder takes care of decoding a list of wordpiece tokens
7/// back into a readable string.
8#[serde(tag = "type")]
9#[non_exhaustive]
10pub struct WordPiece {
11    /// The prefix to be used for continuing subwords
12    pub prefix: String,
13    /// Whether to cleanup some tokenization artifacts (spaces before punctuation, ...)
14    pub cleanup: bool,
15}
16
17impl WordPiece {
18    pub fn new(prefix: String, cleanup: bool) -> Self {
19        Self { prefix, cleanup }
20    }
21}
22
23impl Default for WordPiece {
24    fn default() -> Self {
25        Self {
26            prefix: "##".to_owned(),
27            cleanup: true,
28        }
29    }
30}
/// Removes common tokenization artifacts from `dirty_input`.
///
/// A fixed list of literal substitutions is applied one after another, in order:
/// the space in front of `.`, `?`, `!` and `,` is dropped, a space-delimited
/// apostrophe is tightened, English contraction suffixes (`n't`, `'m`, `'s`,
/// `'ve`, `'re`) are re-attached to the preceding word, and `" do not"` is
/// rewritten to `" don't"`.
///
/// Returns a newly allocated `String`; the input is left untouched.
pub fn cleanup(dirty_input: &str) -> String {
    // Order matters: each rewrite runs on the output of the previous one.
    const REWRITES: [(&str, &str); 11] = [
        (" .", "."),
        (" ?", "?"),
        (" !", "!"),
        (" ,", ","),
        (" ' ", "'"),
        (" n't", "n't"),
        (" 'm", "'m"),
        (" do not", " don't"),
        (" 's", "'s"),
        (" 've", "'ve"),
        (" 're", "'re"),
    ];

    REWRITES
        .iter()
        .fold(dirty_input.to_owned(), |text, &(from, to)| {
            text.replace(from, to)
        })
}
45
46impl Decoder for WordPiece {
47    fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
48        tokens
49            .iter_mut()
50            .enumerate()
51            .map(|(i, token)| {
52                if i != 0 {
53                    if token.starts_with(&self.prefix) {
54                        *token = token.replacen(&self.prefix, "", 1);
55                    } else {
56                        *token = format!(" {token}");
57                    }
58                }
59                if self.cleanup {
60                    *token = cleanup(token);
61                }
62                Ok(token.to_string())
63            })
64            .collect::<Result<_>>()
65    }
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding without cleanup: continuation pieces are glued to the previous
    /// token, other tokens are separated by a space, and a prefix on the very
    /// first token is left in place.
    #[test]
    fn wordpiece_decoder() {
        let pieces: Vec<String> = ["##uelo", "Ara", "##új", "##o", "No", "##guera"]
            .iter()
            .map(|piece| piece.to_string())
            .collect();
        let decoder = WordPiece::new(String::from("##"), false);

        let decoded = decoder.decode(pieces).unwrap();

        assert_eq!(decoded, "##uelo Araújo Noguera");
    }
}