1use std::collections::HashMap;
2
3use anyhow::{Error, Result};
4use tokenizers::{processors::template::TemplateProcessing, tokenizer::Tokenizer as HfTokenizer};
5use tracing::debug;
6
7use crate::{
8 chat_template::{
9 load_chat_template_from_file, ChatTemplateContentFormat, ChatTemplateParams,
10 ChatTemplateState,
11 },
12 traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait},
13};
14
15pub struct HuggingFaceTokenizer {
17 tokenizer: HfTokenizer,
18 special_tokens: SpecialTokens,
19 vocab: HashMap<String, TokenIdType>,
20 reverse_vocab: HashMap<TokenIdType, String>,
21 chat_template: ChatTemplateState,
22}
23
24impl HuggingFaceTokenizer {
25 pub fn from_file(file_path: &str) -> Result<Self> {
27 let path = std::path::Path::new(file_path);
29 let chat_template_path = path
30 .parent()
31 .and_then(crate::factory::discover_chat_template_in_dir);
32 Self::from_file_with_chat_template(file_path, chat_template_path.as_deref())
33 }
34
35 pub fn from_file_with_chat_template(
37 file_path: &str,
38 chat_template_path: Option<&str>,
39 ) -> Result<Self> {
40 let mut tokenizer = HfTokenizer::from_file(file_path)
41 .map_err(|e| Error::msg(format!("Failed to load tokenizer: {e}")))?;
42
43 let special_tokens = Self::extract_special_tokens(&tokenizer);
45
46 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
49 .iter()
50 .map(|(token, &id)| (id, token.clone()))
51 .collect();
52
53 let (mut chat_template_str, add_bos_token, add_eos_token) =
56 Self::load_chat_template_and_config(file_path);
57
58 if let Some(template_path) = chat_template_path {
59 chat_template_str = load_chat_template_from_file(template_path)?;
60 }
61
62 let needs_eos = add_eos_token == Some(true);
65 let needs_bos = match add_bos_token {
66 Some(true) => true,
67 Some(false) => false,
68 None => needs_eos && Self::tokenizer_adds_special_tokens(&tokenizer),
70 };
71
72 if needs_bos || needs_eos {
73 if let Some(post_processor) =
74 Self::build_post_processor(needs_bos, needs_eos, &special_tokens, &vocab)
75 {
76 debug!(needs_bos, needs_eos, "Configured post_processor");
77 tokenizer.with_post_processor(Some(post_processor));
78 }
79 }
80
81 Ok(HuggingFaceTokenizer {
82 tokenizer,
83 special_tokens,
84 vocab,
85 reverse_vocab,
86 chat_template: ChatTemplateState::new(chat_template_str)?,
87 })
88 }
89
90 fn tokenizer_adds_special_tokens(tokenizer: &HfTokenizer) -> bool {
92 tokenizer
93 .encode("", true)
94 .map(|enc| !enc.get_ids().is_empty())
95 .unwrap_or(false)
96 }
97
98 fn build_post_processor(
101 add_bos_token: bool,
102 add_eos_token: bool,
103 special_tokens: &SpecialTokens,
104 vocab: &HashMap<String, TokenIdType>,
105 ) -> Option<TemplateProcessing> {
106 let mut template = String::with_capacity(32);
109 let mut tokens = Vec::with_capacity(2);
110
111 if add_bos_token {
112 let bos = special_tokens.bos_token.as_ref()?;
113 let bos_id = vocab.get(bos).copied()?;
114 template.push_str(bos);
115 template.push_str(":0 ");
116 tokens.push((bos.clone(), bos_id));
117 }
118
119 template.push_str("$A:0");
120
121 if add_eos_token {
122 let eos = special_tokens.eos_token.as_ref()?;
123 let eos_id = vocab.get(eos).copied()?;
124 template.push(' ');
125 template.push_str(eos);
126 template.push_str(":0");
127 tokens.push((eos.clone(), eos_id));
128 }
129
130 TemplateProcessing::builder()
131 .try_single(template.as_str())
132 .ok()?
133 .special_tokens(tokens)
134 .build()
135 .ok()
136 }
137
138 pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
140 let special_tokens = Self::extract_special_tokens(&tokenizer);
141 let vocab = tokenizer.get_vocab(true); let reverse_vocab: HashMap<TokenIdType, String> = vocab
143 .iter()
144 .map(|(token, &id)| (id, token.clone()))
145 .collect();
146
147 HuggingFaceTokenizer {
148 tokenizer,
149 special_tokens,
150 vocab,
151 reverse_vocab,
152 chat_template: ChatTemplateState::empty(),
153 }
154 }
155
156 fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
158 let vocab = tokenizer.get_vocab(true);
160
161 let find_token = |patterns: &[&str]| -> Option<String> {
162 for pattern in patterns {
163 if vocab.contains_key(*pattern) {
164 return Some((*pattern).to_string());
165 }
166 }
167 None
168 };
169
170 let additional_special_tokens: Vec<String> = tokenizer
172 .get_added_tokens_decoder()
173 .iter()
174 .filter(|(_id, token)| token.special) .map(|(_id, token)| token.content.clone())
176 .collect();
177
178 SpecialTokens {
179 bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
180 eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
181 unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
182 sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
183 pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
184 cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
185 mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
186 additional_special_tokens,
187 }
188 }
189
190 fn load_chat_template_and_config(
193 tokenizer_path: &str,
194 ) -> (Option<String>, Option<bool>, Option<bool>) {
195 (|| {
196 let path = std::path::Path::new(tokenizer_path);
197 let config_path = path.parent()?.join("tokenizer_config.json");
198
199 if !config_path.exists() {
200 return None;
201 }
202
203 let config_str = config_path.to_str()?;
204 let content = std::fs::read_to_string(&config_path).ok()?;
205 let config: serde_json::Value = serde_json::from_str(&content).ok()?;
206
207 let chat_template = super::chat_template::load_chat_template_from_config(config_str)
208 .ok()
209 .flatten();
210
211 let add_bos_token = config.get("add_bos_token").and_then(|v| v.as_bool());
212 let add_eos_token = config.get("add_eos_token").and_then(|v| v.as_bool());
213
214 Some((chat_template, add_bos_token, add_eos_token))
215 })()
216 .unwrap_or((None, None, None))
217 }
218}
219
220impl Encoder for HuggingFaceTokenizer {
221 fn encode(&self, input: &str, add_special_tokens: bool) -> Result<Encoding> {
222 self.tokenizer
223 .encode(input, add_special_tokens)
224 .map_err(|e| Error::msg(format!("Encoding failed: {e}")))
225 .map(|encoding| Encoding::Hf(Box::new(encoding)))
226 }
227
228 fn encode_batch(&self, inputs: &[&str], add_special_tokens: bool) -> Result<Vec<Encoding>> {
229 self.tokenizer
230 .encode_batch(inputs.to_vec(), add_special_tokens)
231 .map_err(|e| Error::msg(format!("Batch encoding failed: {e}")))
232 .map(|encodings| {
233 encodings
234 .into_iter()
235 .map(|e| Encoding::Hf(Box::new(e)))
236 .collect()
237 })
238 }
239}
240
241impl Decoder for HuggingFaceTokenizer {
242 fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
243 self.tokenizer
244 .decode(token_ids, skip_special_tokens)
245 .map_err(|e| Error::msg(format!("Decoding failed: {e}")))
246 }
247}
248
249impl TokenizerTrait for HuggingFaceTokenizer {
250 fn vocab_size(&self) -> usize {
251 self.tokenizer.get_vocab_size(false)
252 }
253
254 fn get_special_tokens(&self) -> &SpecialTokens {
255 &self.special_tokens
256 }
257
258 fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
259 self.vocab.get(token).copied()
260 }
261
262 fn id_to_token(&self, id: TokenIdType) -> Option<String> {
263 self.reverse_vocab.get(&id).cloned()
264 }
265
266 fn as_any(&self) -> &dyn std::any::Any {
267 self
268 }
269
270 fn apply_chat_template(
271 &self,
272 messages: &[serde_json::Value],
273 params: ChatTemplateParams,
274 ) -> Result<String> {
275 self.chat_template.apply(messages, params)
276 }
277
278 fn chat_template_content_format(&self) -> ChatTemplateContentFormat {
279 self.chat_template.content_format()
280 }
281
282 fn set_chat_template(&mut self, template: String) -> Result<()> {
283 self.chat_template.set(template)
284 }
285}