// llguidance/tokenizer_json.rs
use crate::HashMap;
use anyhow::{anyhow, bail, Result};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use toktrie::TokTrie;

/// One entry of the `added_tokens` array in a HuggingFace `tokenizer.json`.
#[derive(Debug, Deserialize, Serialize)]
struct AddedToken {
    // Token id, i.e. the index of this token in the vocabulary.
    id: usize,
    // Literal text of the token.
    content: String,
    // True for special/control tokens (e.g. "<s>", "</s>").
    special: bool,
}
13
/// Store `bytes` at slot `idx` of the token table, growing the table
/// with empty entries if it is not long enough yet.
fn add_bytes(tokens: &mut Vec<Vec<u8>>, idx: usize, bytes: Vec<u8>) {
    if idx >= tokens.len() {
        // Pad with empty byte strings up to and including `idx`.
        tokens.resize_with(idx + 1, Vec::new);
    }
    tokens[idx] = bytes;
}
20
/// True for characters that the byte-level encoding maps to themselves:
/// printable ASCII plus most of Latin-1, excluding a few control-ish gaps.
fn is_self_mapped(c: char) -> bool {
    let cp = c as u32;
    (0x21..=0x7E).contains(&cp) // '!'..='~'
        || (0xA1..=0xAC).contains(&cp) // '\u{00A1}'..='\u{00AC}'
        || (0xAE..=0xFF).contains(&cp) // '\u{00AE}'..='\u{00FF}'
}
26
27fn build_char_map() -> HashMap<char, u8> {
28 let mut res = HashMap::default();
29 let mut k = 0x100u32;
30 for byte in 0..=255u8 {
31 let c = byte as char;
32 if is_self_mapped(c) {
33 res.insert(c, byte);
34 } else {
35 res.insert(char::from_u32(k).unwrap(), byte);
36 k += 1;
37 }
38 }
39 res
40}
41
42pub fn token_bytes_from_tokenizer_json(tokenizer_json: &Value) -> Result<Vec<Vec<u8>>> {
44 let mut is_byte_level = false;
45 let mut is_byte_fallback = false;
46 let mut space_ch = ' ';
47
48 let decoder = &tokenizer_json["decoder"];
49 if decoder["type"].as_str() == Some("ByteLevel") {
50 is_byte_level = true;
51 } else if decoder["type"].as_str() == Some("Sequence") {
52 if let Some(decoders) = decoder["decoders"].as_array() {
53 for decoder in decoders {
54 if decoder["type"].as_str() == Some("ByteFallback") {
55 is_byte_fallback = true;
56 } else if decoder["type"].as_str() == Some("Replace")
57 && decoder["content"].as_str() == Some(" ")
58 {
59 if let Some(s) = decoder["pattern"]["String"].as_str() {
60 let s: Vec<char> = s.chars().collect();
61 if s.len() == 1 {
62 space_ch = s[0];
63 }
64 }
65 }
66 }
67 }
68 }
69
70 if !is_byte_fallback && !is_byte_level {
71 bail!("can't determine decoder type: {:?}", decoder["type"]);
72 }
73
74 let mut token_bytes = vec![];
75 let added_tokens: Vec<AddedToken> =
76 serde_json::from_value(tokenizer_json["added_tokens"].clone())
77 .map_err(|e| anyhow!("error parsing added_tokens: {}", e))?;
78
79 for info in added_tokens.iter() {
80 let mut bytes = info.content.as_bytes().to_vec();
81 if info.special {
82 bytes.insert(0, TokTrie::SPECIAL_TOKEN_MARKER);
83 }
84 add_bytes(&mut token_bytes, info.id, bytes);
85 }
86
87 let char_map = build_char_map();
88
89 let vocab: HashMap<String, usize> =
90 serde_json::from_value(tokenizer_json["model"]["vocab"].clone())
91 .map_err(|e| anyhow!("error parsing vocab: {}", e))?;
92
93 for (tok_name, &tok_id) in vocab.iter() {
94 if tok_id < token_bytes.len() && !token_bytes[tok_id].is_empty() {
95 continue; }
97
98 let bytes = if is_byte_fallback {
99 if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") {
100 let hex_str = &tok_name[3..5];
102 let byte = u8::from_str_radix(hex_str, 16).unwrap();
103 vec![byte]
104 } else {
105 assert!(!tok_name.starts_with("<0x"));
106 let tok_name = tok_name.replace(space_ch, " ");
107 tok_name.as_bytes().to_vec()
108 }
109 } else if is_byte_level {
110 let bytes: Result<Vec<u8>> = tok_name
111 .chars()
112 .map(|c| {
113 char_map
114 .get(&c)
115 .copied()
116 .ok_or_else(|| anyhow!("missing char: {}", c))
117 })
118 .collect();
119 match bytes {
120 Ok(b) => b,
121 Err(e) => {
122 bail!("error: {} decoding {:?}", e, tok_name);
123 }
124 }
125 } else {
126 panic!();
127 };
128 add_bytes(&mut token_bytes, tok_id, bytes);
129 }
130
131 Ok(token_bytes)
132}