1use std::path::Path;
7
8pub struct Tokenizer {
10 inner: tokenizers::Tokenizer,
11}
12
13#[derive(Debug, thiserror::Error)]
15pub enum TokenizerError {
16 #[error("failed to load tokenizer: {0}")]
17 Load(String),
18
19 #[error("encoding failed: {0}")]
20 Encode(String),
21
22 #[error("decoding failed: {0}")]
23 Decode(String),
24}
25
26impl Tokenizer {
27 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TokenizerError> {
29 let inner = tokenizers::Tokenizer::from_file(path.as_ref())
30 .map_err(|e| TokenizerError::Load(e.to_string()))?;
31 Ok(Self { inner })
32 }
33
34 pub fn from_json(json: &str) -> Result<Self, TokenizerError> {
36 let inner: tokenizers::Tokenizer =
37 json.parse()
38 .map_err(|e: Box<dyn std::error::Error + Send + Sync>| {
39 TokenizerError::Load(e.to_string())
40 })?;
41 Ok(Self { inner })
42 }
43
44 pub fn encode(&self, text: &str) -> Result<Vec<u32>, TokenizerError> {
46 let encoding = self
47 .inner
48 .encode(text, false)
49 .map_err(|e| TokenizerError::Encode(e.to_string()))?;
50 Ok(encoding.get_ids().to_vec())
51 }
52
53 pub fn encode_with_special(&self, text: &str) -> Result<Vec<u32>, TokenizerError> {
55 let encoding = self
56 .inner
57 .encode(text, true)
58 .map_err(|e| TokenizerError::Encode(e.to_string()))?;
59 Ok(encoding.get_ids().to_vec())
60 }
61
62 pub fn decode(&self, ids: &[u32]) -> Result<String, TokenizerError> {
64 self.inner
65 .decode(ids, true)
66 .map_err(|e| TokenizerError::Decode(e.to_string()))
67 }
68
69 pub fn decode_one(&self, id: u32) -> Result<String, TokenizerError> {
71 self.decode(&[id])
72 }
73
74 pub fn vocab_size(&self) -> usize {
76 self.inner.get_vocab_size(true)
77 }
78
79 pub fn token_to_id(&self, token: &str) -> Option<u32> {
81 self.inner.token_to_id(token)
82 }
83
84 pub fn bos_token_id(&self) -> Option<u32> {
86 self.token_to_id("<s>")
88 .or_else(|| self.token_to_id("<|begin_of_text|>"))
89 .or_else(|| self.token_to_id("<|startoftext|>"))
90 }
91
92 pub fn eos_token_id(&self) -> Option<u32> {
94 self.token_to_id("</s>")
95 .or_else(|| self.token_to_id("<|end_of_text|>"))
96 .or_else(|| self.token_to_id("<|endoftext|>"))
97 .or_else(|| self.token_to_id("<|im_end|>"))
98 }
99
100 pub fn stop_token_ids(&self) -> Vec<u32> {
103 let candidates = [
104 "</s>",
105 "<|end_of_text|>",
106 "<|endoftext|>",
107 "<|im_end|>",
108 "<|eot_id|>",
109 "<|end|>",
110 ];
111 candidates
112 .iter()
113 .filter_map(|&token| self.token_to_id(token))
114 .collect()
115 }
116
117 pub fn is_stop_token(&self, token_id: u32) -> bool {
119 self.stop_token_ids().contains(&token_id)
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// JSON for a tiny byte-level BPE tokenizer whose vocabulary holds the three
    /// special tokens `<s>`, `</s>`, `<unk>` and just enough pieces to cover "hello".
    fn minimal_tokenizer_json() -> String {
        r#"{
            "version": "1.0",
            "truncation": null,
            "padding": null,
            "added_tokens": [
                {"id": 0, "content": "<s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 1, "content": "</s>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true},
                {"id": 2, "content": "<unk>", "single_word": false, "lstrip": false, "rstrip": false, "normalized": false, "special": true}
            ],
            "normalizer": null,
            "pre_tokenizer": {"type": "ByteLevel", "add_prefix_space": false, "trim_offsets": true, "use_regex": true},
            "post_processor": null,
            "decoder": {"type": "ByteLevel", "add_prefix_space": true, "trim_offsets": true, "use_regex": true},
            "model": {
                "type": "BPE",
                "dropout": null,
                "unk_token": "<unk>",
                "continuing_subword_prefix": null,
                "end_of_word_suffix": null,
                "fuse_unk": false,
                "byte_fallback": false,
                "ignore_merges": false,
                "vocab": {
                    "<s>": 0, "</s>": 1, "<unk>": 2,
                    "h": 3, "e": 4, "l": 5, "o": 6,
                    "he": 7, "ll": 8, "lo": 9,
                    "hel": 10, "llo": 11
                },
                "merges": [
                    "h e", "l l", "l o", "he l", "ll o"
                ]
            }
        }"#
        .to_string()
    }

    /// Builds the fixture tokenizer shared by the tests below.
    fn minimal_tokenizer() -> Tokenizer {
        Tokenizer::from_json(&minimal_tokenizer_json())
            .expect("minimal tokenizer JSON should parse")
    }

    #[test]
    fn load_from_json() {
        let tok = minimal_tokenizer();
        assert!(tok.vocab_size() > 0);
    }

    #[test]
    fn load_from_invalid_json_fails() {
        // `matches!` avoids requiring `Debug` on the `Ok` type.
        assert!(matches!(
            Tokenizer::from_json("not json"),
            Err(TokenizerError::Load(_))
        ));
    }

    #[test]
    fn encode_decode_roundtrip() {
        let tok = minimal_tokenizer();

        let ids = tok.encode("hello").unwrap();
        assert!(!ids.is_empty());

        let text = tok.decode(&ids).unwrap();
        assert_eq!(text, "hello");
    }

    #[test]
    fn special_tokens() {
        let tok = minimal_tokenizer();

        assert_eq!(tok.bos_token_id(), Some(0));
        assert_eq!(tok.eos_token_id(), Some(1));
    }

    #[test]
    fn stop_tokens() {
        let tok = minimal_tokenizer();

        // Only `</s>` (id 1) from the candidate list exists in the fixture vocabulary.
        assert_eq!(tok.stop_token_ids(), vec![1]);
        assert!(tok.is_stop_token(1));
        assert!(!tok.is_stop_token(0));
    }

    #[test]
    fn decode_single_token() {
        let tok = minimal_tokenizer();

        // Id 3 maps to the ordinary piece "h" in the fixture vocabulary.
        let text = tok.decode_one(3).unwrap();
        assert!(!text.is_empty());
    }
}