// tokenizers/decoders/wordpiece.rs
use crate::tokenizer::{Decoder, Result};

use serde::{Deserialize, Serialize};
4
/// The WordPiece decoder takes care of decoding a list of wordpiece tokens
/// back into a readable string.
#[derive(Deserialize, Clone, Debug, Serialize)]
#[serde(tag = "type")]
#[non_exhaustive]
pub struct WordPiece {
    /// The prefix that marks a continuing subword (e.g. "##" for BERT-style models).
    pub prefix: String,
    /// Whether to clean up tokenization artifacts (spaces before punctuation,
    /// spaced-out contractions, ...) after joining the pieces.
    pub cleanup: bool,
}
16
17impl WordPiece {
18 pub fn new(prefix: String, cleanup: bool) -> Self {
19 Self { prefix, cleanup }
20 }
21}
22
23impl Default for WordPiece {
24 fn default() -> Self {
25 Self {
26 prefix: "##".to_owned(),
27 cleanup: true,
28 }
29 }
30}
/// Clean up common tokenization artifacts in decoded text: spaces before
/// punctuation and spaced-out English contractions.
///
/// The substitutions are applied sequentially, in order, over the whole string.
pub fn cleanup(dirty_input: &str) -> String {
    // (pattern, replacement) pairs; order matters and mirrors the original
    // chain of `replace` calls.
    const SUBSTITUTIONS: [(&str, &str); 11] = [
        (" .", "."),
        (" ?", "?"),
        (" !", "!"),
        (" ,", ","),
        (" ' ", "'"),
        (" n't", "n't"),
        (" 'm", "'m"),
        (" do not", " don't"),
        (" 's", "'s"),
        (" 've", "'ve"),
        (" 're", "'re"),
    ];
    SUBSTITUTIONS
        .iter()
        .fold(dirty_input.to_owned(), |text, (from, to)| text.replace(from, to))
}
45
46impl Decoder for WordPiece {
47 fn decode_chain(&self, mut tokens: Vec<String>) -> Result<Vec<String>> {
48 tokens
49 .iter_mut()
50 .enumerate()
51 .map(|(i, token)| {
52 if i != 0 {
53 if token.starts_with(&self.prefix) {
54 *token = token.replacen(&self.prefix, "", 1);
55 } else {
56 *token = format!(" {token}");
57 }
58 }
59 if self.cleanup {
60 *token = cleanup(token);
61 }
62 Ok(token.to_string())
63 })
64 .collect::<Result<_>>()
65 }
66}
67
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn wordpiece_decoder() {
        // "##"-prefixed tokens are merged into the previous word; cleanup is
        // disabled so no punctuation rewriting happens. The leading "##" on
        // the very first token is kept as-is.
        let decoder = WordPiece::new("##".to_string(), false);

        let tokens: Vec<String> = ["##uelo", "Ara", "##új", "##o", "No", "##guera"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        assert_eq!(decoder.decode(tokens).unwrap(), "##uelo Araújo Noguera");
    }
}