1use anyhow::Result;
2use std::path::Path;
3use tokenizers::Tokenizer as HfTokenizer;
4
5pub struct Tokenizer {
6 inner: HfTokenizer,
7 pad_token_id: u32,
8 bos_token_id: u32,
9 eos_token_id: u32,
10}
11
12impl Tokenizer {
13 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
14 let inner = HfTokenizer::from_file(path).map_err(|e| anyhow::anyhow!("{}", e))?;
15 let vocab_size = inner.get_vocab_size(true) as u32;
16 Ok(Self {
17 inner,
18 pad_token_id: 0,
19 bos_token_id: 1,
20 eos_token_id: 2.min(vocab_size - 1),
21 })
22 }
23
24 pub fn from_pretrained(identifier: &str) -> Result<Self> {
25 let api = hf_hub::api::sync::Api::new()?;
27 let repo = api.model(identifier.to_string());
28 let tokenizer_path = repo
29 .get("tokenizer.json")
30 .map_err(|e| anyhow::anyhow!("Failed to download tokenizer: {}", e))?;
31 Self::from_file(tokenizer_path)
32 }
33
34 pub fn encode(&self, text: &str, add_special_tokens: bool) -> Result<Vec<u32>> {
35 let encoding = self
36 .inner
37 .encode(text, add_special_tokens)
38 .map_err(|e| anyhow::anyhow!("{}", e))?;
39 Ok(encoding.get_ids().to_vec())
40 }
41
42 pub fn encode_batch(&self, texts: &[&str], add_special_tokens: bool) -> Result<Vec<Vec<u32>>> {
43 let encodings = self
44 .inner
45 .encode_batch(texts.to_vec(), add_special_tokens)
46 .map_err(|e| anyhow::anyhow!("{}", e))?;
47 Ok(encodings
48 .into_iter()
49 .map(|e| e.get_ids().to_vec())
50 .collect())
51 }
52
53 pub fn decode(&self, ids: &[u32], skip_special_tokens: bool) -> Result<String> {
54 self.inner
55 .decode(ids, skip_special_tokens)
56 .map_err(|e| anyhow::anyhow!("{}", e))
57 }
58
59 pub fn vocab_size(&self) -> usize {
60 self.inner.get_vocab_size(true)
61 }
62
63 pub fn pad_token_id(&self) -> u32 {
64 self.pad_token_id
65 }
66
67 pub fn bos_token_id(&self) -> u32 {
68 self.bos_token_id
69 }
70
71 pub fn eos_token_id(&self) -> u32 {
72 self.eos_token_id
73 }
74
75 pub fn set_pad_token_id(&mut self, id: u32) {
76 self.pad_token_id = id;
77 }
78
79 pub fn set_bos_token_id(&mut self, id: u32) {
80 self.bos_token_id = id;
81 }
82
83 pub fn set_eos_token_id(&mut self, id: u32) {
84 self.eos_token_id = id;
85 }
86}
87
/// Settings used to train a BPE tokenizer.
pub struct BPETrainer {
    /// Vocabulary size handed to the BPE trainer builder.
    vocab_size: usize,
    /// Minimum frequency handed to the BPE trainer builder.
    min_frequency: u32,
    /// Special tokens registered with both the trainer and the resulting
    /// tokenizer.
    special_tokens: Vec<String>,
}
93
94impl BPETrainer {
95 pub fn new(vocab_size: usize) -> Self {
96 Self {
97 vocab_size,
98 min_frequency: 2,
99 special_tokens: vec![
100 "<pad>".to_string(),
101 "<bos>".to_string(),
102 "<eos>".to_string(),
103 "<unk>".to_string(),
104 ],
105 }
106 }
107
108 pub fn with_min_frequency(mut self, freq: u32) -> Self {
109 self.min_frequency = freq;
110 self
111 }
112
113 pub fn with_special_tokens(mut self, tokens: Vec<String>) -> Self {
114 self.special_tokens = tokens;
115 self
116 }
117
118 pub fn train_from_files(&self, files: &[&str], output_path: &str) -> Result<Tokenizer> {
121 use std::io::BufRead;
122 use tokenizers::models::bpe::{BPE, BpeTrainerBuilder};
123 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
124 use tokenizers::tokenizer::Trainer;
125
126 let special_tokens: Vec<tokenizers::AddedToken> = self
127 .special_tokens
128 .iter()
129 .map(|s| tokenizers::AddedToken::from(s.as_str(), true))
130 .collect();
131
132 let mut trainer = BpeTrainerBuilder::default()
133 .vocab_size(self.vocab_size)
134 .min_frequency(self.min_frequency as u64)
135 .special_tokens(special_tokens.clone())
136 .build();
137
138 let mut model = BPE::default();
139
140 const BATCH_SIZE: usize = 10000;
141
142 for file in files {
143 let reader = crate::io::open_file(file)?;
144 let mut batch = Vec::with_capacity(BATCH_SIZE);
145
146 for line in reader.lines() {
147 let line = line?;
148 if !line.is_empty() {
149 batch.push(line);
150 }
151
152 if batch.len() >= BATCH_SIZE {
153 trainer
154 .feed(batch.iter().map(|s| s.as_str()), |s| Ok(vec![s.to_owned()]))
155 .map_err(|e| anyhow::anyhow!("{}", e))?;
156 batch.clear();
157 }
158 }
159
160 if !batch.is_empty() {
162 trainer
163 .feed(batch.iter().map(|s| s.as_str()), |s| Ok(vec![s.to_owned()]))
164 .map_err(|e| anyhow::anyhow!("{}", e))?;
165 }
166 }
167
168 trainer
169 .train(&mut model)
170 .map_err(|e| anyhow::anyhow!("{}", e))?;
171
172 let mut tokenizer = HfTokenizer::new(model);
173 tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
174 tokenizer.add_special_tokens(&special_tokens);
175
176 tokenizer
177 .save(output_path, true)
178 .map_err(|e| anyhow::anyhow!("{}", e))?;
179
180 Tokenizer::from_file(output_path)
181 }
182
183 pub fn train_from_texts(&self, texts: &[&str], output_path: &str) -> Result<Tokenizer> {
184 use tokenizers::models::bpe::{BPE, BpeTrainerBuilder};
185 use tokenizers::pre_tokenizers::byte_level::ByteLevel;
186 use tokenizers::tokenizer::Trainer;
187
188 let special_tokens: Vec<tokenizers::AddedToken> = self
189 .special_tokens
190 .iter()
191 .map(|s| tokenizers::AddedToken::from(s.as_str(), true))
192 .collect();
193
194 let mut trainer = BpeTrainerBuilder::default()
195 .vocab_size(self.vocab_size)
196 .min_frequency(self.min_frequency as u64)
197 .special_tokens(special_tokens.clone())
198 .build();
199
200 let mut model = BPE::default();
201
202 trainer
203 .feed(texts.iter().copied(), |s| Ok(vec![s.to_owned()]))
204 .map_err(|e| anyhow::anyhow!("{}", e))?;
205
206 trainer
207 .train(&mut model)
208 .map_err(|e| anyhow::anyhow!("{}", e))?;
209
210 let mut tokenizer = HfTokenizer::new(model);
211 tokenizer.with_pre_tokenizer(Some(ByteLevel::default()));
212 tokenizer.add_special_tokens(&special_tokens);
213
214 tokenizer
215 .save(output_path, true)
216 .map_err(|e| anyhow::anyhow!("{}", e))?;
217
218 Tokenizer::from_file(output_path)
219 }
220}