1use std::collections::HashMap;
6use std::sync::OnceLock;
7
/// The character U+2581 (LOWER ONE EIGHTH BLOCK), written here as its glyph.
///
/// NOTE(review): the name suggests this is the word-boundary marker used by
/// metaspace-style tokenizers; no use site is visible in this part of the
/// file, so confirm against callers.
pub const METASPACE: char = '▁';
10
/// Lookup tables for the GPT-2 style byte-level alphabet.
///
/// The two fields describe the same one-to-one mapping between the 256 byte
/// values and 256 distinct Unicode scalar values, once in each direction.
struct Tables {
    /// Indexed by byte value; yields the character standing in for that byte.
    byte_to_char: [char; 256],
    /// Inverse of `byte_to_char`, keyed by the character's code point.
    char_to_byte: HashMap<u32, u8>,
}

/// Returns the shared byte/char tables, building them on first use.
///
/// Bytes in `33..=126`, `161..=172` and `174..=255` map to the character with
/// the same code point. Every remaining byte is assigned, in ascending byte
/// order, the next unused code point starting at U+0100.
fn tables() -> &'static Tables {
    static T: OnceLock<Tables> = OnceLock::new();
    T.get_or_init(|| {
        let mut byte_to_char = ['\0'; 256];
        let mut char_to_byte: HashMap<u32, u8> = HashMap::with_capacity(256);
        // Next code point to hand out to a byte that does not map to itself.
        let mut next_shifted: u32 = 256;
        for b in 0u8..=u8::MAX {
            let code = if matches!(b, 33..=126 | 161..=172 | 174..=255) {
                u32::from(b)
            } else {
                let assigned = next_shifted;
                next_shifted += 1;
                assigned
            };
            // Codes stay at or below U+0143, far from the surrogate range, so
            // the conversion cannot fail.
            let c = char::from_u32(code).expect("GPT-2 char point should be valid");
            byte_to_char[usize::from(b)] = c;
            char_to_byte.insert(code, b);
        }
        Tables { byte_to_char, char_to_byte }
    })
}
51
52pub fn byte_to_char(b: u8) -> char {
54 tables().byte_to_char[b as usize]
55}
56
57pub fn char_to_byte(c: char) -> Option<u8> {
60 tables().char_to_byte.get(&(c as u32)).copied()
61}
62
63pub fn decode_byte_level_token(raw_token: &str) -> Vec<u8> {
68 let mut buf: Vec<u8> = Vec::with_capacity(raw_token.len());
69 for c in raw_token.chars() {
70 match char_to_byte(c) {
71 Some(b) => buf.push(b),
72 None => {
73 let mut tmp = [0u8; 4];
75 let s = c.encode_utf8(&mut tmp);
76 buf.extend_from_slice(s.as_bytes());
77 }
78 }
79 }
80 buf
81}
82
83pub fn encode_byte_level_chars(bytes: &[u8]) -> String {
86 let mut s = String::with_capacity(bytes.len());
87 for &b in bytes {
88 s.push(byte_to_char(b));
89 }
90 s
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte 0x20 and U+0120 map to each other in both directions.
    #[test]
    fn space_maps_to_capital_g_with_dot() {
        let space: u8 = 0x20;
        let g_with_dot = '\u{0120}';
        assert_eq!(byte_to_char(space), g_with_dot);
        assert_eq!(char_to_byte(g_with_dot), Some(space));
    }

    /// Mapping any byte to its character and back yields the original byte.
    #[test]
    fn round_trip_all_bytes() {
        for b in u8::MIN..=u8::MAX {
            assert_eq!(char_to_byte(byte_to_char(b)), Some(b));
        }
    }
}