Skip to main content

entrenar/tokenizer/
hf.rs

1//! HuggingFace tokenizer integration via aprender.
2
3use super::error::{Result, TokenizerError};
4
/// GPT-2 special-token ID used as both the padding and the EOS token.
///
/// The GPT-2 vocabulary holds 50257 BPE tokens; 50256 is the ID that doubles as pad and EOS.
/// NOTE: despite the constant's name, the value is that token ID, not the vocabulary size.
const GPT2_VOCAB_SIZE: u32 = 50_256;
7
8// Re-export aprender BPE types for HuggingFace compatibility
9pub use aprender::text::bpe::{
10    bytes_to_unicode, load_from_files as load_hf_from_files, load_from_json as load_hf_from_json,
11    BpeConfig as HfBpeConfig, BpeTokenizer as HfBpeTokenizer, MergeRule, Qwen2BpeTokenizer,
12};
13
14/// HuggingFace-compatible tokenizer wrapper
15///
16/// Wraps aprender's BPE tokenizer to provide training batch utilities.
17#[derive(Debug, Clone)]
18pub struct HfTokenizer {
19    inner: HfBpeTokenizer,
20    pad_id: u32,
21    eos_id: Option<u32>,
22    bos_id: Option<u32>,
23}
24
25impl HfTokenizer {
26    /// Create a GPT-2 tokenizer with base vocabulary
27    #[must_use]
28    pub fn gpt2() -> Self {
29        Self {
30            inner: HfBpeTokenizer::gpt2_base(),
31            pad_id: GPT2_VOCAB_SIZE,
32            eos_id: Some(GPT2_VOCAB_SIZE),
33            bos_id: None,
34        }
35    }
36
37    /// Create a Qwen2 tokenizer
38    #[must_use]
39    pub fn qwen2() -> Self {
40        Self {
41            inner: HfBpeTokenizer::new(HfBpeConfig::qwen2()),
42            pad_id: Qwen2BpeTokenizer::ENDOFTEXT_ID,
43            eos_id: Some(Qwen2BpeTokenizer::IM_END_ID),
44            bos_id: Some(Qwen2BpeTokenizer::IM_START_ID),
45        }
46    }
47
48    /// Load tokenizer from HuggingFace tokenizer.json file
49    ///
50    /// # Errors
51    /// Returns error if file cannot be read or parsed.
52    pub fn from_file(path: impl AsRef<std::path::Path>) -> Result<Self> {
53        let json = std::fs::read_to_string(path.as_ref())?;
54        Self::from_json(&json)
55    }
56
57    /// Load tokenizer from JSON string
58    ///
59    /// # Errors
60    /// Returns error if JSON parsing fails.
61    pub fn from_json(json: &str) -> Result<Self> {
62        let inner = load_hf_from_json(json).map_err(|e| {
63            TokenizerError::Serialization(format!("Failed to parse tokenizer JSON: {e}"))
64        })?;
65
66        // Detect special tokens from vocab
67        let pad_id =
68            inner.token_to_id("<pad>").or_else(|| inner.token_to_id("<|endoftext|>")).unwrap_or(0);
69        let eos_id = inner
70            .token_to_id("</s>")
71            .or_else(|| inner.token_to_id("<|im_end|>"))
72            .or_else(|| inner.token_to_id("<|endoftext|>"));
73        let bos_id = inner.token_to_id("<s>").or_else(|| inner.token_to_id("<|im_start|>"));
74
75        Ok(Self { inner, pad_id, eos_id, bos_id })
76    }
77
78    /// Get vocabulary size
79    #[must_use]
80    pub fn vocab_size(&self) -> usize {
81        self.inner.vocab_size()
82    }
83
84    /// Encode text to token IDs
85    #[must_use]
86    pub fn encode(&self, text: &str) -> Vec<u32> {
87        self.inner.encode(text)
88    }
89
90    /// Encode text with special tokens (BOS/EOS)
91    #[must_use]
92    pub fn encode_with_special(&self, text: &str) -> Vec<u32> {
93        let mut tokens = Vec::new();
94        if let Some(bos) = self.bos_id {
95            tokens.push(bos);
96        }
97        tokens.extend(self.inner.encode(text));
98        if let Some(eos) = self.eos_id {
99            tokens.push(eos);
100        }
101        tokens
102    }
103
104    /// Decode token IDs to text
105    #[must_use]
106    pub fn decode(&self, ids: &[u32]) -> String {
107        self.inner.decode(ids)
108    }
109
110    /// Get padding token ID
111    #[must_use]
112    pub fn pad_id(&self) -> u32 {
113        self.pad_id
114    }
115
116    /// Get EOS token ID
117    #[must_use]
118    pub fn eos_id(&self) -> Option<u32> {
119        self.eos_id
120    }
121
122    /// Get BOS token ID
123    #[must_use]
124    pub fn bos_id(&self) -> Option<u32> {
125        self.bos_id
126    }
127
128    /// Batch encode texts with padding
129    #[must_use]
130    pub fn batch_encode(&self, texts: &[&str], max_len: usize) -> Vec<Vec<u32>> {
131        let mut encoded: Vec<Vec<u32>> = texts
132            .iter()
133            .map(|text| {
134                let mut tokens = self.encode_with_special(text);
135                tokens.truncate(max_len);
136                tokens
137            })
138            .collect();
139
140        let batch_max = encoded.iter().map(Vec::len).max().unwrap_or(0);
141        let pad_to = batch_max.min(max_len);
142
143        for tokens in &mut encoded {
144            while tokens.len() < pad_to {
145                tokens.push(self.pad_id);
146            }
147        }
148
149        encoded
150    }
151
152    /// Create training batches from text pairs
153    pub fn create_batches(
154        &self,
155        pairs: &[(&str, &str)],
156        max_len: usize,
157        batch_size: usize,
158    ) -> Vec<crate::train::Batch> {
159        use crate::Tensor;
160
161        pairs
162            .chunks(batch_size)
163            .map(|chunk| {
164                let inputs: Vec<&str> = chunk.iter().map(|(i, _)| *i).collect();
165                let targets: Vec<&str> = chunk.iter().map(|(_, t)| *t).collect();
166
167                let input_tokens = self.batch_encode(&inputs, max_len);
168                let target_tokens = self.batch_encode(&targets, max_len);
169
170                let input_data: Vec<f32> =
171                    input_tokens.into_iter().flatten().map(|t| t as f32).collect();
172                let target_data: Vec<f32> =
173                    target_tokens.into_iter().flatten().map(|t| t as f32).collect();
174
175                crate::train::Batch::new(
176                    Tensor::from_vec(input_data, false),
177                    Tensor::from_vec(target_data, false),
178                )
179            })
180            .collect()
181    }
182
183    /// Create causal LM batches (target = shifted input)
184    pub fn create_causal_batches(
185        &self,
186        texts: &[&str],
187        max_len: usize,
188        batch_size: usize,
189    ) -> Vec<crate::train::Batch> {
190        use crate::Tensor;
191
192        texts
193            .chunks(batch_size)
194            .map(|chunk| {
195                let encoded = self.batch_encode(chunk, max_len);
196
197                let mut input_data: Vec<f32> = Vec::new();
198                let mut target_data: Vec<f32> = Vec::new();
199
200                for tokens in &encoded {
201                    if tokens.len() > 1 {
202                        input_data.extend(tokens[..tokens.len() - 1].iter().map(|&t| t as f32));
203                        target_data.extend(tokens[1..].iter().map(|&t| t as f32));
204                    }
205                }
206
207                crate::train::Batch::new(
208                    Tensor::from_vec(input_data, false),
209                    Tensor::from_vec(target_data, false),
210                )
211            })
212            .collect()
213    }
214}
215
216impl Default for HfTokenizer {
217    fn default() -> Self {
218        Self::gpt2()
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// The GPT-2 preset has a non-empty vocabulary and pads with the GPT-2 special ID.
    #[test]
    fn test_hf_tokenizer_gpt2() {
        let tok = HfTokenizer::gpt2();
        assert!(tok.vocab_size() > 0);
        assert_eq!(tok.pad_id(), GPT2_VOCAB_SIZE);
    }

    /// The Qwen2 preset uses `<|im_end|>` as its EOS token.
    #[test]
    fn test_hf_tokenizer_qwen2() {
        let tok = HfTokenizer::qwen2();
        assert_eq!(tok.eos_id(), Some(Qwen2BpeTokenizer::IM_END_ID));
    }

    /// Plain encoding of non-empty text yields at least one token.
    #[test]
    fn test_hf_tokenizer_encode() {
        let tok = HfTokenizer::gpt2();
        let ids = tok.encode("Hello");
        assert!(!ids.is_empty());
    }

    /// Encoding with special tokens ends with the EOS ID.
    #[test]
    fn test_hf_tokenizer_encode_with_special() {
        let tok = HfTokenizer::gpt2();
        let ids = tok.encode_with_special("Hi");
        assert_eq!(ids.last().copied(), tok.eos_id());
    }

    /// All rows of an encoded batch are padded to the same length.
    #[test]
    fn test_hf_tokenizer_batch_encode() {
        let tok = HfTokenizer::gpt2();
        let texts = ["Hello", "Hi there"];
        let rows = tok.batch_encode(&texts, 32);

        assert_eq!(rows.len(), 2);
        assert_eq!(rows[0].len(), rows[1].len());
    }

    /// Two pairs with a batch size of two produce a single non-empty batch.
    #[test]
    fn test_hf_tokenizer_create_batches() {
        let tok = HfTokenizer::gpt2();
        let pairs = [("Hello", "World"), ("How are", "you")];
        let batches = tok.create_batches(&pairs, 16, 2);

        assert_eq!(batches.len(), 1);
        assert!(!batches[0].inputs.is_empty());
    }

    /// Causal batches have inputs and targets of equal length.
    #[test]
    fn test_hf_tokenizer_create_causal_batches() {
        let tok = HfTokenizer::gpt2();
        let texts = ["Hello world", "Test text"];
        let batches = tok.create_causal_batches(&texts, 16, 2);

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].inputs.len(), batches[0].targets.len());
    }

    /// A minimal tokenizer.json with three vocabulary entries loads successfully.
    #[test]
    fn test_hf_tokenizer_from_json() {
        let json = r#"{
            "model": {
                "vocab": {"hello": 0, "world": 1, "<|endoftext|>": 2},
                "merges": []
            },
            "added_tokens": []
        }"#;

        let tok = HfTokenizer::from_json(json).expect("operation should succeed");
        assert_eq!(tok.vocab_size(), 3);
    }

    /// Input that is not JSON is rejected with an error.
    #[test]
    fn test_hf_tokenizer_from_json_invalid() {
        assert!(HfTokenizer::from_json("invalid json").is_err());
    }
}