use regex::Regex;
use serde_json::{from_str, Value};
use std::collections::{HashMap, HashSet};

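/// Build the reversible byte -> printable-character table used for byte-level
/// BPE: bytes that already have a printable representation keep their own code
/// point, and every remaining byte is assigned an unused code point above
/// U+00FF so that all 256 values round-trip through a text vocabulary.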
fn bytes_to_unicode() -> HashMap<u8, char> {
    // Start with the byte ranges that map to themselves.
    let mut bs: Vec<u8> = (b'!'..=b'~').collect();
    bs.extend(b'\xA1'..=b'\xAC');
    bs.extend(b'\xAE'..=b'\xFF');

    let mut cs: Vec<u32> = bs.iter().map(|&x| x as u32).collect();
    let mut n = 0;
    // Assign every remaining byte a code point starting at 256.
    for b in 0..=u8::MAX {
        if !bs.contains(&b) {
            bs.push(b);
            cs.push(2_u32.pow(8) + n);
            n += 1;
        }
    }

    let cs: Vec<char> = cs.iter().map(|&x| char::from_u32(x).unwrap()).collect();
    bs.into_iter().zip(cs).collect()
}

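/// Collect the set of adjacent symbol pairs in `word`
/// (e.g. `["h", "e", "y"]` yields `("h", "e")` and `("e", "y")`).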
fn get_pairs(word: &[String]) -> HashSet<(String, String)> {
    let mut pairs = HashSet::new();
    let mut prev_char = word.first().unwrap();
    for ch in word.iter().skip(1) {
        pairs.insert((prev_char.clone(), ch.clone()));
        prev_char = ch;
    }
    pairs
}

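/// Byte-pair encoder/decoder built from the `encoder.json` vocabulary and
/// `vocab.bpe` merge list bundled with the source (GPT-2-style byte-level BPE).
///
/// Example (values taken from the tests below):
///
/// ```ignore
/// let mut encoder = Encoder::new();
/// let ids = encoder.encode("This is some text".to_string());
/// assert_eq!(ids, vec![1212, 318, 617, 2420]);
/// assert_eq!(encoder.decode(ids), "This is some text");
/// ```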
pub struct Encoder {
    pat: Regex,
    byte_encoder: HashMap<u8, char>,
    byte_decoder: HashMap<char, u8>,
    encoder: HashMap<String, u64>,
    decoder: HashMap<u64, String>,
    bpe_ranks: HashMap<(String, String), usize>,
    cache: HashMap<String, String>,
}

impl Encoder {
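    /// Parse the bundled `encoder.json` and `vocab.bpe`, build the forward and
    /// reverse lookup tables, and compile the tokenization regex.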
    pub fn new() -> Self {
        let pat = Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap();
        let byte_encoder = bytes_to_unicode();
        let byte_decoder = byte_encoder.iter().map(|(&k, &v)| (v, k)).collect::<HashMap<_, _>>();

        // Vocabulary: token string -> id, plus the inverse map for decoding.
        let mut encoder = HashMap::new();
        let mut decoder = HashMap::new();

        let encoder_str = include_str!("encoder.json");
        let encoder_json: Value = from_str(encoder_str).expect("Unable to parse JSON");
        for (key, value) in encoder_json.as_object().unwrap() {
            encoder.insert(key.clone(), value.as_u64().unwrap());
            decoder.insert(value.as_u64().unwrap(), key.clone());
        }

        // Merge list: one pair per line, ordered by priority (lower rank merges first).
        let vocab_bpe = include_str!("vocab.bpe");
        let bpe_merges = vocab_bpe
            .split('\n')
            .skip(1) // the first line is a version header
            .take_while(|line| !line.is_empty())
            .map(|line| {
                let merge_str: Vec<&str> = line.split(' ').collect();
                (merge_str[0].to_owned(), merge_str[1].to_owned())
            })
            .collect::<Vec<(String, String)>>();

        let bpe_ranks = bpe_merges
            .into_iter()
            .enumerate()
            .map(|(rank, merge)| (merge, rank))
            .collect::<HashMap<_, _>>();

        let cache = HashMap::new();
        Self {
            pat,
            byte_encoder,
            byte_decoder,
            encoder,
            decoder,
            bpe_ranks,
            cache,
        }
    }

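    /// Apply the learned BPE merges to a single pre-tokenized, byte-encoded
    /// token and return its sub-word pieces joined by spaces. Results are
    /// memoized in `self.cache`.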
    fn bpe(&mut self, token: String) -> String {
        if let Some(cached_word) = self.cache.get(&token) {
            return cached_word.to_string();
        }

        // Start with one symbol per character of the byte-encoded token.
        let mut word = token.chars().map(|c| c.to_string()).collect::<Vec<_>>();
        let mut pairs = get_pairs(&word);
        if pairs.is_empty() {
            return token;
        }

        loop {
            // Pick the adjacent pair with the lowest merge rank; pairs that
            // never appear in the merge list rank as usize::MAX.
            let bigram = pairs
                .iter()
                .min_by_key(|&pair| self.bpe_ranks.get(pair).copied().unwrap_or(usize::MAX))
                .unwrap()
                .clone();

            if !self.bpe_ranks.contains_key(&bigram) {
                break;
            }

            // Merge every occurrence of the chosen pair into a single symbol.
            let (first, second) = &bigram;
            let mut new_word: Vec<String> = vec![];
            let mut i = 0;
            while i < word.len() {
                if let Some(j) = word[i..].iter().position(|c| c == first) {
                    new_word.extend(word[i..i + j].iter().cloned());
                    i += j;

                    if i + 1 < word.len() && &word[i] == first && &word[i + 1] == second {
                        new_word.push(format!("{}{}", first, second));
                        i += 2;
                    } else {
                        new_word.push(word[i].to_string());
                        i += 1;
                    }
                } else {
                    new_word.extend(word[i..].iter().cloned());
                    break;
                }
            }

            word = new_word;
            if word.len() == 1 {
                break;
            } else {
                pairs = get_pairs(&word);
            }
        }

        let word = word.join(" ");
        self.cache.insert(token, word.clone());
        word
    }

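    /// Split `text` with the tokenization regex, byte-encode each piece, run
    /// BPE on it, and map the resulting sub-word pieces to vocabulary ids.
    /// Pieces missing from the vocabulary are silently skipped.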
    pub fn encode(&mut self, text: String) -> Vec<u64> {
        let mut bpe_tokens = vec![];

        // Collect the regex matches up front so the borrow of `self.pat` ends
        // before the mutable call to `self.bpe` below.
        let matches: Vec<&str> = self.pat
            .find_iter(text.as_str())
            .map(|m| m.as_str())
            .filter(|s| !s.is_empty())
            .collect();

        for token in matches {
            // Map every byte of the match to its printable stand-in character.
            let token: String = token
                .bytes()
                .map(|b| *self.byte_encoder.get(&b).unwrap())
                .collect();

            // Run BPE and look up each resulting piece; unknown pieces are dropped.
            let mut new_tokens = self
                .bpe(token)
                .split(' ')
                .filter_map(|piece| self.encoder.get(piece).copied())
                .collect::<Vec<_>>();
            bpe_tokens.append(&mut new_tokens);
        }
        bpe_tokens
    }

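    /// Reverse of `encode`: look up each id's vocabulary string, undo the
    /// byte-to-unicode mapping, and decode the resulting bytes as UTF-8 (lossily).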
    pub fn decode(&self, tokens: Vec<u64>) -> String {
        let text: String = tokens.iter().map(|t| self.decoder.get(t).unwrap().as_str()).collect();
        let bytes: Vec<u8> = text.chars().map(|c| *self.byte_decoder.get(&c).unwrap()).collect();
        String::from_utf8_lossy(&bytes).to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test() -> Result<(), String> {
        let mut encoder = Encoder::new();

        test_encoder(&mut encoder, "Space", " ".to_string(), vec![220])?;
        test_encoder(&mut encoder, "Tab", "\t".to_string(), vec![197])?;
        test_encoder(&mut encoder, "Simple text", "This is some text".to_string(), vec![1212, 318, 617, 2420])?;
        test_encoder(&mut encoder, "indivisible", "indivisible".to_string(), vec![521, 452, 12843])?;
        test_encoder(&mut encoder, "emojis", "hello 👋 world 🌍".to_string(), vec![31373, 50169, 233, 995, 12520, 234, 235])?;
        test_encoder(&mut encoder, "properties of Object", "toString constructor hasOwnProperty valueOf".to_string(), vec![1462, 10100, 23772, 468, 23858, 21746, 1988, 5189])?;

        Ok(())
    }

    fn test_encoder(encoder: &mut Encoder, title: &str, text: String, expected: Vec<u64>) -> Result<(), String> {
        let encoded = encoder.encode(text.clone());
        if encoded != expected {
            return Err(format!("{}: encoded output did not match the expected output", title));
        }

        let decoded = encoder.decode(encoded);
        if decoded != text {
            return Err(format!("{}: decoded output did not match the input text", title));
        }

        Ok(())
    }
}