// llguidance/tokenizer_json.rs
1use crate::HashMap;
2use anyhow::{anyhow, bail, Result};
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use toktrie::TokTrie;
6
/// One entry of the `added_tokens` array in an HF `tokenizer.json`.
/// Field names mirror the JSON schema so serde can derive the mapping.
#[derive(Debug, Deserialize, Serialize)]
struct AddedToken {
    // Token id; used as an index into the resulting byte table.
    id: usize,
    // The token's text content.
    content: String,
    // True for special/control tokens (e.g. BOS/EOS); these get a
    // marker byte prepended when materialized.
    special: bool,
}
13
/// Store `bytes` at position `idx`, growing `tokens` with empty entries
/// as needed so the index is always in bounds.
fn add_bytes(tokens: &mut Vec<Vec<u8>>, idx: usize, bytes: Vec<u8>) {
    if idx >= tokens.len() {
        // Pad with empty byte strings up to and including `idx`.
        tokens.resize_with(idx + 1, Vec::new);
    }
    tokens[idx] = bytes;
}
20
// useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi

/// True for the chars that the GPT-2 byte-level encoding maps to
/// themselves (the visible ASCII range plus two Latin-1 runs); every
/// other byte gets a substitute code point starting at U+0100.
fn is_self_mapped(c: char) -> bool {
    ('!'..='~').contains(&c)
        || ('\u{00A1}'..='\u{00AC}').contains(&c)
        || ('\u{00AE}'..='\u{00FF}').contains(&c)
}
26
27fn build_char_map() -> HashMap<char, u8> {
28    let mut res = HashMap::default();
29    let mut k = 0x100u32;
30    for byte in 0..=255u8 {
31        let c = byte as char;
32        if is_self_mapped(c) {
33            res.insert(c, byte);
34        } else {
35            res.insert(char::from_u32(k).unwrap(), byte);
36            k += 1;
37        }
38    }
39    res
40}
41
42/// Parse HF tokenizer.json file and return bytes for every token
43pub fn token_bytes_from_tokenizer_json(tokenizer_json: &Value) -> Result<Vec<Vec<u8>>> {
44    let mut is_byte_level = false;
45    let mut is_byte_fallback = false;
46    let mut space_ch = ' ';
47
48    let decoder = &tokenizer_json["decoder"];
49    if decoder["type"].as_str() == Some("ByteLevel") {
50        is_byte_level = true;
51    } else if decoder["type"].as_str() == Some("Sequence") {
52        if let Some(decoders) = decoder["decoders"].as_array() {
53            for decoder in decoders {
54                if decoder["type"].as_str() == Some("ByteFallback") {
55                    is_byte_fallback = true;
56                } else if decoder["type"].as_str() == Some("Replace")
57                    && decoder["content"].as_str() == Some(" ")
58                {
59                    if let Some(s) = decoder["pattern"]["String"].as_str() {
60                        let s: Vec<char> = s.chars().collect();
61                        if s.len() == 1 {
62                            space_ch = s[0];
63                        }
64                    }
65                }
66            }
67        }
68    }
69
70    if !is_byte_fallback && !is_byte_level {
71        bail!("can't determine decoder type: {:?}", decoder["type"]);
72    }
73
74    let mut token_bytes = vec![];
75    let added_tokens: Vec<AddedToken> =
76        serde_json::from_value(tokenizer_json["added_tokens"].clone())
77            .map_err(|e| anyhow!("error parsing added_tokens: {}", e))?;
78
79    for info in added_tokens.iter() {
80        let mut bytes = info.content.as_bytes().to_vec();
81        if info.special {
82            bytes.insert(0, TokTrie::SPECIAL_TOKEN_MARKER);
83        }
84        add_bytes(&mut token_bytes, info.id, bytes);
85    }
86
87    let char_map = build_char_map();
88
89    let vocab: HashMap<String, usize> =
90        serde_json::from_value(tokenizer_json["model"]["vocab"].clone())
91            .map_err(|e| anyhow!("error parsing vocab: {}", e))?;
92
93    for (tok_name, &tok_id) in vocab.iter() {
94        if tok_id < token_bytes.len() && !token_bytes[tok_id].is_empty() {
95            continue; // skip specials already added
96        }
97
98        let bytes = if is_byte_fallback {
99            if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with(">") {
100                // parse hex number from tok_name
101                let hex_str = &tok_name[3..5];
102                let byte = u8::from_str_radix(hex_str, 16).unwrap();
103                vec![byte]
104            } else {
105                assert!(!tok_name.starts_with("<0x"));
106                let tok_name = tok_name.replace(space_ch, " ");
107                tok_name.as_bytes().to_vec()
108            }
109        } else if is_byte_level {
110            let bytes: Result<Vec<u8>> = tok_name
111                .chars()
112                .map(|c| {
113                    char_map
114                        .get(&c)
115                        .copied()
116                        .ok_or_else(|| anyhow!("missing char: {}", c))
117                })
118                .collect();
119            match bytes {
120                Ok(b) => b,
121                Err(e) => {
122                    bail!("error: {} decoding {:?}", e, tok_name);
123                }
124            }
125        } else {
126            panic!();
127        };
128        add_bytes(&mut token_bytes, tok_id, bytes);
129    }
130
131    Ok(token_bytes)
132}