// hermes_llm/tokenizer.rs

1use anyhow::Result;
2use std::path::Path;
3use tokenizers::Tokenizer as HfTokenizer;
4
5pub struct Tokenizer {
6    inner: HfTokenizer,
7    pad_token_id: u32,
8    bos_token_id: u32,
9    eos_token_id: u32,
10}
11
12impl Tokenizer {
13    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
14        let inner = HfTokenizer::from_file(path).map_err(|e| anyhow::anyhow!("{}", e))?;
15        let vocab_size = inner.get_vocab_size(true) as u32;
16        Ok(Self {
17            inner,
18            pad_token_id: 0,
19            bos_token_id: 1,
20            eos_token_id: 2.min(vocab_size - 1),
21        })
22    }
23
24    pub fn from_pretrained(identifier: &str) -> Result<Self> {
25        // Use hf-hub to download the tokenizer file (uses rustls, no openssl)
26        let api = hf_hub::api::sync::Api::new()?;
27        let repo = api.model(identifier.to_string());
28        let tokenizer_path = repo
29            .get("tokenizer.json")
30            .map_err(|e| anyhow::anyhow!("Failed to download tokenizer: {}", e))?;
31        Self::from_file(tokenizer_path)
32    }
33
34    pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Vec<u32>> {
35        let encoding = self
36            .inner
37            .encode(text, add_special_tokens)
38            .map_err(|e| anyhow::anyhow!("{}", e))?;
39        Ok(encoding.get_ids().to_vec())
40    }
41
42    pub fn encode_batch(&self, texts: &[&str], add_special_tokens: bool) -> Result<Vec<Vec<u32>>> {
43        let encodings = self
44            .inner
45            .encode_batch(texts.to_vec(), add_special_tokens)
46            .map_err(|e| anyhow::anyhow!("{}", e))?;
47        Ok(encodings
48            .into_iter()
49            .map(|e| e.get_ids().to_vec())
50            .collect())
51    }
52
53    pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
54        self.inner
55            .decode(ids, skip_special_tokens)
56            .map_err(|e| anyhow::anyhow!("{}", e))
57    }
58
59    pub fn vocab_size(&self) -> usize {
60        self.inner.get_vocab_size(true)
61    }
62
63    pub fn pad_token_id(&self) -> u32 {
64        self.pad_token_id
65    }
66
67    pub fn bos_token_id(&self) -> u32 {
68        self.bos_token_id
69    }
70
71    pub fn eos_token_id(&self) -> u32 {
72        self.eos_token_id
73    }
74
75    pub fn set_pad_token_id(&mut self, id: u32) {
76        self.pad_token_id = id;
77    }
78
79    pub fn set_bos_token_id(&mut self, id: u32) {
80        self.bos_token_id = id;
81    }
82
83    pub fn set_eos_token_id(&mut self, id: u32) {
84        self.eos_token_id = id;
85    }
86}
87
/// Configuration for training a byte-level BPE tokenizer.
#[derive(Debug, Clone)]
pub struct BPETrainer {
    /// Target vocabulary size handed to the underlying BPE trainer.
    vocab_size: usize,
    /// Minimum frequency a pair must reach to be merged.
    min_frequency: u32,
    /// Special tokens registered with the trainer and with the resulting tokenizer.
    special_tokens: Vec<String>,
}
93
94impl BPETrainer {
95    pub fn new(vocab_size: usize) -> Self {
96        Self {
97            vocab_size,
98            min_frequency: 2,
99            special_tokens: vec![
100                "<pad>".to_string(),
101                "<bos>".to_string(),
102                "<eos>".to_string(),
103                "<unk>".to_string(),
104            ],
105        }
106    }
107
108    pub fn with_min_frequency(mut self, freq: u32) -> Self {
109        self.min_frequency = freq;
110        self
111    }
112
113    pub fn with_special_tokens(mut self, tokens: Vec<String>) -> Self {
114        self.special_tokens = tokens;
115        self
116    }
117
118    /// Train tokenizer from files. Supports .gz and .zst/.zstd compressed files.
119    /// Uses streaming to avoid loading entire files into memory.
120    pub fn train_from_files(&self, files: &[&str], output_path: &str) -> Result<Tokenizer> {
121        use std::io::BufRead;
122        use tokenizers::models::bpe::{BPE, BpeTrainerBuilder};
123        use tokenizers::pre_tokenizers::byte_level::ByteLevel;
124        use tokenizers::tokenizer::Trainer;
125
126        let special_tokens: Vec<tokenizers::AddedToken> = self
127            .special_tokens
128            .iter()
129            .map(|s| tokenizers::AddedToken::from(s.as_str(), true))
130            .collect();
131
132        let mut trainer = BpeTrainerBuilder::default()
133            .vocab_size(self.vocab_size)
134            .min_frequency(self.min_frequency as u64)
135            .special_tokens(special_tokens.clone())
136            .build();
137
138        let mut model = BPE::default();
139
140        const BATCH_SIZE: usize = 10000;
141
142        for file in files {
143            let reader = crate::io::open_file(file)?;
144            let mut batch = Vec::with_capacity(BATCH_SIZE);
145
146            for line in reader.lines() {
147                let line = line?;
148                if !line.is_empty() {
149                    batch.push(line);
150                }
151
152                if batch.len() >= BATCH_SIZE {
153                    trainer
154                        .feed(batch.iter().map(|s| s.as_str()), |s| Ok(vec![s.to_owned()]))
155                        .map_err(|e| anyhow::anyhow!("{}", e))?;
156                    batch.clear();
157                }
158            }
159
160            // Feed remaining lines
161            if !batch.is_empty() {
162                trainer
163                    .feed(batch.iter().map(|s| s.as_str()), |s| Ok(vec![s.to_owned()]))
164                    .map_err(|e| anyhow::anyhow!("{}", e))?;
165            }
166        }
167
168        trainer
169            .train(&mut model)
170            .map_err(|e| anyhow::anyhow!("{}", e))?;
171
172        let mut tokenizer = HfTokenizer::new(model);
173        tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
174        tokenizer.add_special_tokens(&special_tokens);
175
176        tokenizer
177            .save(output_path, true)
178            .map_err(|e| anyhow::anyhow!("{}", e))?;
179
180        Tokenizer::from_file(output_path)
181    }
182
183    pub fn train_from_texts(&self, texts: &[&str], output_path: &str) -> Result<Tokenizer> {
184        use tokenizers::models::bpe::{BPE, BpeTrainerBuilder};
185        use tokenizers::pre_tokenizers::byte_level::ByteLevel;
186        use tokenizers::tokenizer::Trainer;
187
188        let special_tokens: Vec<tokenizers::AddedToken> = self
189            .special_tokens
190            .iter()
191            .map(|s| tokenizers::AddedToken::from(s.as_str(), true))
192            .collect();
193
194        let mut trainer = BpeTrainerBuilder::default()
195            .vocab_size(self.vocab_size)
196            .min_frequency(self.min_frequency as u64)
197            .special_tokens(special_tokens.clone())
198            .build();
199
200        let mut model = BPE::default();
201
202        trainer
203            .feed(texts.iter().copied(), |s| Ok(vec![s.to_owned()]))
204            .map_err(|e| anyhow::anyhow!("{}", e))?;
205
206        trainer
207            .train(&mut model)
208            .map_err(|e| anyhow::anyhow!("{}", e))?;
209
210        let mut tokenizer = HfTokenizer::new(model);
211        tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
212        tokenizer.add_special_tokens(&special_tokens);
213
214        tokenizer
215            .save(output_path, true)
216            .map_err(|e| anyhow::anyhow!("{}", e))?;
217
218        Tokenizer::from_file(output_path)
219    }
220}