gpt_encoder/
lib.rs

1//! # GPT-Encoder
2//! Rust BPE Encoder Decoder for GPT-2 / GPT-3
3//! 
//! This is a Rust rewrite of [OpenAI's GPT-2 encoder](https://github.com/openai/gpt-2/blob/master/src/encoder.py) and [Latitude Games' GPT-3-Encoder](https://github.com/latitudegames/GPT-3-Encoder).
5//! 
6//! # Example
7//! ```
8//! use gpt_encoder::Encoder;
9//! 
10//! fn main() {
11//!     let mut encoder = Encoder::new();
12//!     let encoded = encoder.encode("Hello, World".to_string());
13//!     println!("{:?}", encoded); 
14//!     // prints: [15496, 11, 2159]
15//! 
16//!     let decoded = encoder.decode(encoded);
17//!     println!("{:?}", decoded); 
18//!     // prints: "Hello, World"
19//! }
20//! ```
21
22use regex::Regex;
23use serde_json::{from_str, Value};
24use std::collections::{HashMap, HashSet};
25
/// Builds GPT-2's reversible byte -> character table used for byte-level BPE.
///
/// Bytes in the ranges `!`..=`~`, `0xA1..=0xAC` and `0xAE..=0xFF` map to the `char`
/// with the same code point. Every other byte is assigned, in ascending byte order,
/// the next consecutive code point starting at 256 (U+0100). The result therefore
/// has exactly 256 entries with 256 distinct characters.
fn bytes_to_unicode() -> HashMap<u8, char> {
    let keeps_own_code_point = |byte: u8| matches!(byte, b'!'..=b'~' | 0xA1..=0xAC | 0xAE..=0xFF);

    // Next code point to hand out to a byte that does not keep its own.
    let mut next_remapped: u32 = 1 << 8;

    (0..=u8::MAX)
        .map(|byte| {
            let code_point = if keeps_own_code_point(byte) {
                u32::from(byte)
            } else {
                let assigned = next_remapped;
                next_remapped += 1;
                assigned
            };
            (byte, char::from_u32(code_point).unwrap())
        })
        .collect()
}
45
/// Returns the set of adjacent symbol pairs in `word`.
///
/// For `["a", "b", "c"]` this is `{("a", "b"), ("b", "c")}`. A word with fewer than
/// two symbols (including an empty word) has no adjacent pairs, so the result is empty.
fn get_pairs(word: &[String]) -> HashSet<(String, String)> {
    word.windows(2)
        .map(|pair| (pair[0].clone(), pair[1].clone()))
        .collect()
}
55
56pub struct Encoder {
57    pat: Regex,
58    byte_encoder: HashMap<u8, char>,
59    byte_decoder: HashMap<char, u8>,
60    encoder: HashMap<String, u64>,
61    decoder: HashMap<u64, String>,
62    bpe_ranks: HashMap<(String, String), usize>,
63    cache: HashMap<String, String>,
64}
65
66impl Encoder {
67    /// To create new instance of Encoder
68    pub fn new() -> Self {
69        let pat = Regex::new(r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+").unwrap();
70        let byte_encoder = bytes_to_unicode();
71        let byte_decoder = byte_encoder.iter().map(|(&k, &v)| (v, k)).collect::<HashMap<_, _>>();
72
73        let mut encoder = HashMap::new();
74        let mut decoder = HashMap::new();
75
76        let encoder_str = include_str!("encoder.json");
77        let encoder_json: Value = from_str(&encoder_str).expect("Unable to parse JSON");
78        for (key, value) in encoder_json.as_object().unwrap() {
79            encoder.insert(key.clone(), value.as_u64().unwrap());
80            decoder.insert(value.as_u64().unwrap(), key.clone());
81        }
82
83        let vocab_bpe = include_str!("vocab.bpe");
84        let bpe_merges = vocab_bpe
85            .split('\n')
86            .skip(1)
87            .take_while(|line| !line.is_empty())
88            .map(|line| {
89                let merge_str: Vec<&str> = line.split(' ').collect();
90                (merge_str[0].to_owned(), merge_str[1].to_owned())
91            })
92            .collect::<Vec<(String, String)>>();
93
94        let idx = 0..bpe_merges.len();
95        let bpe_ranks = bpe_merges.into_iter().zip(idx).collect::<HashMap<_, _>>();
96
97        let cache = HashMap::new();
98        Self {
99            pat,
100            byte_encoder,
101            byte_decoder,
102            encoder,
103            decoder,
104            bpe_ranks,
105            cache,
106        }
107    }
108
109    fn bpe(&mut self, token: String) -> String {
110        if let Some(cached_word) = self.cache.get(&token) {
111            return cached_word.to_string();
112        }
113        
114        let mut word = token.chars().map(|c| c.to_string()).collect::<Vec<_>>();
115        let mut pairs = get_pairs(&word);
116        if pairs.is_empty() {
117            return token;
118        }
119    
120        loop {
121            let bigram = pairs
122                .iter()
123                .min_by_key(|pair| self.bpe_ranks.get(pair).unwrap_or(&usize::MAX))
124                .unwrap();
125
126            if !self.bpe_ranks.contains_key(bigram) {
127                break;
128            }
129    
130            let (first, second) = bigram;
131            let mut new_word: Vec<String> = vec![];
132            let mut i = 0;
133            while i < word.len() {
134                if let Some(j) = word[i..].iter().position(|c| c == first) {
135                    new_word.extend(word[i..i+j].iter().map(|c| c.to_string()));
136                    i += j;
137
138                    if i < word.len() - 1 && &word[i] == first && &word[i + 1] == second {
139                        new_word.push(first.to_string() + &second.to_string());
140                        i += 2;
141                    } else {
142                        new_word.push(word[i].to_string());
143                        i += 1;
144                    }
145                } else {
146                    new_word.extend(word[i..].iter().map(|c| c.to_string()));
147                    break;
148                }
149            }
150    
151            word = new_word;
152            if word.len() == 1 {
153                break;
154            } else {
155                pairs = get_pairs(&word);
156            }
157        }
158    
159        let word = word.join(" ");
160        self.cache.insert(token, word.clone());
161        word
162    }
163    
164    /// # Example
165    /// ```
166    /// use gpt_encoder::Encoder;
167    /// 
168    /// let mut encoder = Encoder::new();
169    /// let encoded = encoder.encode("Hello, World".to_string());
170    /// assert_eq!(encoded, vec![15496, 11, 2159]);
171    /// ```
172    pub fn encode(&mut self, text: String) -> Vec<u64> {
173        let mut bpe_tokens = vec![];
174
175        let matches: Vec<&str> = self.pat
176            .find_iter(text.as_str())
177            .map(|m| m.as_str())
178            .filter(|s| !s.is_empty())
179            .collect();
180
181        for token in matches {
182            let token = token
183                .bytes()
184                .map(|x| self.byte_encoder.get(&x).unwrap().to_string())
185                .collect::<Vec<_>>()
186                .join("");
187            
188            let mut new_tokens = self.bpe(token)
189                .split(' ')
190                .map(|x| self.encoder.get(&x.to_string()))
191                .filter(|x| x.is_some())
192                .map(|x| x.unwrap().clone())
193                .collect::<Vec<_>>();
194            bpe_tokens.append(&mut new_tokens);
195        }
196        bpe_tokens
197    }
198
199    /// # Example
200    /// ```
201    /// use gpt_encoder::Encoder;
202    /// 
203    /// let encoder = Encoder::new();
204    /// let decoded = encoder.decode(vec![15496, 11, 2159]);
205    /// assert_eq!(decoded, "Hello, World".to_string());
206    /// ```
207    pub fn decode(&self, token: Vec<u64>) -> String {
208        let text: String = token.iter().map(|t| self.decoder.get(t).unwrap().clone()).collect::<Vec<_>>().join("");
209        let text: Vec<u8> = text.chars().map(|c| self.byte_decoder.get(&c).unwrap().clone()).collect::<Vec<_>>();
210        String::from_utf8_lossy(&text).to_string()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips known inputs through `encode`/`decode` and checks the GPT-2 token ids.
    #[test]
    fn test() -> Result<(), String> {
        let mut encoder = Encoder::new();

        // (title, input text, expected token ids)
        let cases: Vec<(&str, &str, Vec<u64>)> = vec![
            ("Space", " ", vec![220]),
            ("Tab", "\t", vec![197]),
            ("Simple text", "This is some text", vec![1212, 318, 617, 2420]),
            ("indivisible", "indivisible", vec![521, 452, 12843]),
            (
                "emojis",
                "hello 👋 world 🌍",
                vec![31373, 50169, 233, 995, 12520, 234, 235],
            ),
            (
                "properties of Object",
                "toString constructor hasOwnProperty valueOf",
                vec![1462, 10100, 23772, 468, 23858, 21746, 1988, 5189],
            ),
        ];

        for (title, text, expected) in cases {
            test_encoder(&mut encoder, title, text.to_string(), expected)?;
        }
        Ok(())
    }

    /// Encodes `text`, compares the result with `expected`, then checks that decoding
    /// reproduces `text`. The returned error message names the failing case via `title`.
    fn test_encoder(encoder: &mut Encoder, title: &str, text: String, expected: Vec<u64>) -> Result<(), String> {
        let encoded = encoder.encode(text.clone());
        if encoded != expected {
            return Err(format!("{}: encoded output did not match the expected output", title));
        }

        match encoder.decode(encoded) {
            decoded if decoded == text => Ok(()),
            _ => Err(format!("{}: decoded output did not match the input text", title)),
        }
    }
}