1use super::error::{Result, TokenizerError};
4
/// Token id that `HfTokenizer::gpt2` uses as both its padding id and its EOS id.
///
/// NOTE(review): despite the name this is not the vocabulary size. 50256 is GPT-2's
/// `<|endoftext|>` id; the full GPT-2 vocabulary is usually quoted as 50257 entries.
/// Confirm, and consider renaming to something like `GPT2_ENDOFTEXT_ID`.
const GPT2_VOCAB_SIZE: u32 = 50256;
7
8pub use aprender::text::bpe::{
10 bytes_to_unicode, load_from_files as load_hf_from_files, load_from_json as load_hf_from_json,
11 BpeConfig as HfBpeConfig, BpeTokenizer as HfBpeTokenizer, MergeRule, Qwen2BpeTokenizer,
12};
13
/// Wrapper around aprender's BPE tokenizer (`HfBpeTokenizer`).
///
/// Also records the special-token ids this crate needs when adding BOS/EOS markers and
/// when padding batches.
#[derive(Debug, Clone)]
pub struct HfTokenizer {
    /// Underlying BPE tokenizer that performs the actual encoding and decoding.
    inner: HfBpeTokenizer,
    /// Id appended by `batch_encode` to right-pad shorter rows.
    pad_id: u32,
    /// Id appended by `encode_with_special`, if the tokenizer defines one.
    eos_id: Option<u32>,
    /// Id prepended by `encode_with_special`, if the tokenizer defines one.
    bos_id: Option<u32>,
}
24
25impl HfTokenizer {
26 #[must_use]
28 pub fn gpt2() -> Self {
29 Self {
30 inner: HfBpeTokenizer::gpt2_base(),
31 pad_id: GPT2_VOCAB_SIZE,
32 eos_id: Some(GPT2_VOCAB_SIZE),
33 bos_id: None,
34 }
35 }
36
37 #[must_use]
39 pub fn qwen2() -> Self {
40 Self {
41 inner: HfBpeTokenizer::new(HfBpeConfig::qwen2()),
42 pad_id: Qwen2BpeTokenizer::ENDOFTEXT_ID,
43 eos_id: Some(Qwen2BpeTokenizer::IM_END_ID),
44 bos_id: Some(Qwen2BpeTokenizer::IM_START_ID),
45 }
46 }
47
48 pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
53 let json = std::fs::read_to_string(path.as_ref())?;
54 Self::from_json(&json)
55 }
56
57 pub fn from_json(json: &str) -> Result<Self> {
62 let inner = load_hf_from_json(json).map_err(|e| {
63 TokenizerError::Serialization(format!("Failed to parse tokenizer JSON: {e}"))
64 })?;
65
66 let pad_id =
68 inner.token_to_id("<pad>").or_else(|| inner.token_to_id("<|endoftext|>")).unwrap_or(0);
69 let eos_id = inner
70 .token_to_id("</s>")
71 .or_else(|| inner.token_to_id("<|im_end|>"))
72 .or_else(|| inner.token_to_id("<|endoftext|>"));
73 let bos_id = inner.token_to_id("<s>").or_else(|| inner.token_to_id("<|im_start|>"));
74
75 Ok(Self { inner, pad_id, eos_id, bos_id })
76 }
77
78 #[must_use]
80 pub fn vocab_size(&self) -> usize {
81 self.inner.vocab_size()
82 }
83
84 #[must_use]
86 pub fn encode(&self, text: &str) -> Vec<u32> {
87 self.inner.encode(text)
88 }
89
90 #[must_use]
92 pub fn encode_with_special(&self, text: &str) -> Vec<u32> {
93 let mut tokens = Vec::new();
94 if let Some(bos) = self.bos_id {
95 tokens.push(bos);
96 }
97 tokens.extend(self.inner.encode(text));
98 if let Some(eos) = self.eos_id {
99 tokens.push(eos);
100 }
101 tokens
102 }
103
104 #[must_use]
106 pub fn decode(&self, ids: &[u32]) -> String {
107 self.inner.decode(ids)
108 }
109
110 #[must_use]
112 pub fn pad_id(&self) -> u32 {
113 self.pad_id
114 }
115
116 #[must_use]
118 pub fn eos_id(&self) -> Option<u32> {
119 self.eos_id
120 }
121
122 #[must_use]
124 pub fn bos_id(&self) -> Option<u32> {
125 self.bos_id
126 }
127
128 #[must_use]
130 pub fn batch_encode(&self, texts: &[&str], max_len: usize) -> Vec<Vec<u32>> {
131 let mut encoded: Vec<Vec<u32>> = texts
132 .iter()
133 .map(|text| {
134 let mut tokens = self.encode_with_special(text);
135 tokens.truncate(max_len);
136 tokens
137 })
138 .collect();
139
140 let batch_max = encoded.iter().map(Vec::len).max().unwrap_or(0);
141 let pad_to = batch_max.min(max_len);
142
143 for tokens in &mut encoded {
144 while tokens.len() < pad_to {
145 tokens.push(self.pad_id);
146 }
147 }
148
149 encoded
150 }
151
152 pub fn create_batches(
154 &self,
155 pairs: &[(&str, &str)],
156 max_len: usize,
157 batch_size: usize,
158 ) -> Vec<crate::train::Batch> {
159 use crate::Tensor;
160
161 pairs
162 .chunks(batch_size)
163 .map(|chunk| {
164 let inputs: Vec<&str> = chunk.iter().map(|(i, _)| *i).collect();
165 let targets: Vec<&str> = chunk.iter().map(|(_, t)| *t).collect();
166
167 let input_tokens = self.batch_encode(&inputs, max_len);
168 let target_tokens = self.batch_encode(&targets, max_len);
169
170 let input_data: Vec<f32> =
171 input_tokens.into_iter().flatten().map(|t| t as f32).collect();
172 let target_data: Vec<f32> =
173 target_tokens.into_iter().flatten().map(|t| t as f32).collect();
174
175 crate::train::Batch::new(
176 Tensor::from_vec(input_data, false),
177 Tensor::from_vec(target_data, false),
178 )
179 })
180 .collect()
181 }
182
183 pub fn create_causal_batches(
185 &self,
186 texts: &[&str],
187 max_len: usize,
188 batch_size: usize,
189 ) -> Vec<crate::train::Batch> {
190 use crate::Tensor;
191
192 texts
193 .chunks(batch_size)
194 .map(|chunk| {
195 let encoded = self.batch_encode(chunk, max_len);
196
197 let mut input_data: Vec<f32> = Vec::new();
198 let mut target_data: Vec<f32> = Vec::new();
199
200 for tokens in &encoded {
201 if tokens.len() > 1 {
202 input_data.extend(tokens[..tokens.len() - 1].iter().map(|&t| t as f32));
203 target_data.extend(tokens[1..].iter().map(|&t| t as f32));
204 }
205 }
206
207 crate::train::Batch::new(
208 Tensor::from_vec(input_data, false),
209 Tensor::from_vec(target_data, false),
210 )
211 })
212 .collect()
213 }
214}
215
216impl Default for HfTokenizer {
217 fn default() -> Self {
218 Self::gpt2()
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Several whitespace-separated words, so it always encodes to more ids than "Hi".
    const LONG_TEXT: &str = "Hello there my friend";

    #[test]
    fn test_hf_tokenizer_gpt2() {
        let tokenizer = HfTokenizer::gpt2();
        assert!(tokenizer.vocab_size() > 0);
        assert_eq!(tokenizer.pad_id(), GPT2_VOCAB_SIZE);
        assert_eq!(tokenizer.eos_id(), Some(GPT2_VOCAB_SIZE));
        assert_eq!(tokenizer.bos_id(), None);
    }

    #[test]
    fn test_hf_tokenizer_default_is_gpt2() {
        let tokenizer = HfTokenizer::default();
        assert_eq!(tokenizer.pad_id(), GPT2_VOCAB_SIZE);
        assert_eq!(tokenizer.eos_id(), Some(GPT2_VOCAB_SIZE));
        assert_eq!(tokenizer.bos_id(), None);
    }

    #[test]
    fn test_hf_tokenizer_qwen2() {
        let tokenizer = HfTokenizer::qwen2();
        assert_eq!(tokenizer.pad_id(), Qwen2BpeTokenizer::ENDOFTEXT_ID);
        assert_eq!(tokenizer.eos_id(), Some(Qwen2BpeTokenizer::IM_END_ID));
        assert_eq!(tokenizer.bos_id(), Some(Qwen2BpeTokenizer::IM_START_ID));
    }

    #[test]
    fn test_hf_tokenizer_encode() {
        let tokenizer = HfTokenizer::gpt2();
        let tokens = tokenizer.encode("Hello");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_hf_tokenizer_encode_with_special() {
        let tokenizer = HfTokenizer::gpt2();
        let plain = tokenizer.encode("Hi");
        let tokens = tokenizer.encode_with_special("Hi");

        // GPT-2 has no BOS, so the result is the plain encoding followed by EOS.
        let (last, body) = tokens.split_last().expect("GPT-2 always appends an EOS id");
        assert_eq!(Some(*last), tokenizer.eos_id());
        assert_eq!(body, plain.as_slice());
    }

    #[test]
    fn test_hf_tokenizer_batch_encode() {
        let tokenizer = HfTokenizer::gpt2();
        let texts = ["Hello", "Hi there"];
        let encoded = tokenizer.batch_encode(&texts, 32);

        assert_eq!(encoded.len(), 2);
        assert_eq!(encoded[0].len(), encoded[1].len());
    }

    #[test]
    fn test_hf_tokenizer_batch_encode_pads_to_longest_row() {
        let tokenizer = HfTokenizer::gpt2();
        let short = tokenizer.encode_with_special("Hi");
        let long = tokenizer.encode_with_special(LONG_TEXT);
        assert!(short.len() < long.len());

        let encoded = tokenizer.batch_encode(&["Hi", LONG_TEXT], 1024);

        // Rows are padded to the longest row in the batch, not to `max_len`.
        assert_eq!(encoded[1], long);
        assert_eq!(encoded[0].len(), long.len());
        assert_eq!(&encoded[0][..short.len()], short.as_slice());
        assert!(encoded[0][short.len()..].iter().all(|&id| id == tokenizer.pad_id()));
    }

    #[test]
    fn test_hf_tokenizer_batch_encode_truncates_after_special_tokens() {
        let tokenizer = HfTokenizer::gpt2();
        let full = tokenizer.encode_with_special(LONG_TEXT);
        let encoded = tokenizer.batch_encode(&[LONG_TEXT], 2);

        assert_eq!(encoded, vec![full[..2].to_vec()]);
    }

    #[test]
    fn test_hf_tokenizer_batch_encode_empty() {
        let tokenizer = HfTokenizer::gpt2();
        assert!(tokenizer.batch_encode(&[], 8).is_empty());
    }

    #[test]
    fn test_hf_tokenizer_create_batches() {
        let tokenizer = HfTokenizer::gpt2();
        let pairs = [("Hello", "World"), ("How are", "you")];
        let batches = tokenizer.create_batches(&pairs, 16, 2);

        assert_eq!(batches.len(), 1);
        assert!(!batches[0].inputs.is_empty());
    }

    #[test]
    fn test_hf_tokenizer_create_causal_batches() {
        let tokenizer = HfTokenizer::gpt2();
        let texts = ["Hello world", "Test text"];
        let batches = tokenizer.create_causal_batches(&texts, 16, 2);

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].inputs.len(), batches[0].targets.len());
    }

    #[test]
    fn test_hf_tokenizer_create_causal_batches_keeps_partial_last_batch() {
        let tokenizer = HfTokenizer::gpt2();
        let batches = tokenizer.create_causal_batches(&["a b", "c d", "e f"], 16, 2);

        assert_eq!(batches.len(), 2);
    }

    #[test]
    #[should_panic]
    fn test_hf_tokenizer_create_causal_batches_rejects_zero_batch_size() {
        let tokenizer = HfTokenizer::gpt2();
        let _ = tokenizer.create_causal_batches(&["Hello"], 16, 0);
    }

    #[test]
    fn test_hf_tokenizer_from_json() {
        let json = r#"{
            "model": {
                "vocab": {"hello": 0, "world": 1, "<|endoftext|>": 2},
                "merges": []
            },
            "added_tokens": []
        }"#;

        let tokenizer =
            HfTokenizer::from_json(json).expect("minimal tokenizer JSON should parse");
        assert_eq!(tokenizer.vocab_size(), 3);
    }

    #[test]
    fn test_hf_tokenizer_from_json_invalid() {
        let result = HfTokenizer::from_json("invalid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_hf_tokenizer_from_file_missing() {
        let result = HfTokenizer::from_file("this/path/does/not/exist/tokenizer.json");
        assert!(result.is_err());
    }
}