1use std::{fs::File, io::Read, path::Path, sync::Arc};
2
3use anyhow::{Error, Result};
4use tracing::debug;
5
6use crate::{
7 hub::download_tokenizer_from_hf, huggingface::HuggingFaceTokenizer,
8 tiktoken::TiktokenTokenizer, traits,
9};
10
/// Describes which kind of tokenizer a model name or tokenizer file corresponds to.
#[derive(Debug, Clone)]
pub enum TokenizerType {
    /// A HuggingFace tokenizer. `get_tokenizer_info` stores the inspected file path
    /// in the `String` payload.
    HuggingFace(String),
    /// The mock tokenizer. This variant is not constructed anywhere in this file.
    Mock,
    /// A Tiktoken tokenizer. This variant is not constructed in this file; the `String`
    /// payload is presumably the model or encoding name — TODO confirm against callers.
    Tiktoken(String),
}
19
/// Creates a tokenizer from a local file or directory path.
///
/// Thin convenience wrapper: forwards to `create_tokenizer_with_chat_template` with no
/// explicit chat template path (`None`).
///
/// # Errors
///
/// Returns whatever error `create_tokenizer_with_chat_template` returns for `file_path`.
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    create_tokenizer_with_chat_template(file_path, None)
}
28
29pub fn create_tokenizer_with_chat_template(
31 file_path: &str,
32 chat_template_path: Option<&str>,
33) -> Result<Arc<dyn traits::Tokenizer>> {
34 if file_path == "mock" || file_path == "test" {
36 return Ok(Arc::new(super::mock::MockTokenizer::new()));
37 }
38
39 let path = Path::new(file_path);
40
41 if !path.exists() {
43 return Err(Error::msg(format!("File not found: {}", file_path)));
44 }
45
46 if path.is_dir() {
48 let tokenizer_json = path.join("tokenizer.json");
49 if tokenizer_json.exists() {
50 let final_chat_template =
52 resolve_and_log_chat_template(chat_template_path, path, file_path);
53 let tokenizer_path_str = tokenizer_json.to_str().ok_or_else(|| {
54 Error::msg(format!(
55 "Tokenizer path is not valid UTF-8: {:?}",
56 tokenizer_json
57 ))
58 })?;
59 return create_tokenizer_with_chat_template(
60 tokenizer_path_str,
61 final_chat_template.as_deref(),
62 );
63 }
64
65 return Err(Error::msg(format!(
66 "Directory '{}' does not contain a valid tokenizer file (tokenizer.json, tokenizer_config.json, or vocab.json)",
67 file_path
68 )));
69 }
70
71 let extension = path
73 .extension()
74 .and_then(std::ffi::OsStr::to_str)
75 .map(|s| s.to_lowercase());
76
77 let result = match extension.as_deref() {
78 Some("json") => {
79 let tokenizer =
80 HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
81
82 Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
83 }
84 Some("model") => {
85 Err(Error::msg("SentencePiece models not yet supported"))
87 }
88 Some("gguf") => {
89 Err(Error::msg("GGUF format not yet supported"))
91 }
92 _ => {
93 auto_detect_tokenizer(file_path)
95 }
96 };
97
98 result
99}
100
101fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
103 let mut file = File::open(file_path)?;
104 let mut buffer = vec![0u8; 512]; let bytes_read = file.read(&mut buffer)?;
106 buffer.truncate(bytes_read);
107
108 if is_likely_json(&buffer) {
110 let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
111 return Ok(Arc::new(tokenizer));
112 }
113
114 if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
116 return Err(Error::msg("GGUF format detected but not yet supported"));
117 }
118
119 if is_likely_sentencepiece(&buffer) {
121 return Err(Error::msg(
122 "SentencePiece model detected but not yet supported",
123 ));
124 }
125
126 Err(Error::msg(format!(
127 "Unable to determine tokenizer type for file: {}",
128 file_path
129 )))
130}
131
/// Returns `true` when `buffer` looks like the start of a JSON document.
///
/// A leading UTF-8 byte-order mark is skipped, then ASCII whitespace; the content is
/// considered JSON if the first remaining byte opens an object (`{`) or an array (`[`).
/// Empty or whitespace-only input is not JSON.
fn is_likely_json(buffer: &[u8]) -> bool {
    const UTF8_BOM: &[u8] = &[0xEF, 0xBB, 0xBF];

    let body = buffer.strip_prefix(UTF8_BOM).unwrap_or(buffer);
    let first_significant = body
        .iter()
        .copied()
        .find(|byte| !byte.is_ascii_whitespace());

    matches!(first_significant, Some(b'{' | b'['))
}
148
/// Heuristically decides whether `buffer` is the start of a SentencePiece model file.
///
/// Buffers shorter than 12 bytes are never classified as SentencePiece. Otherwise the
/// buffer matches if it begins with one of two known leading byte pairs, or if it
/// contains one of the special-token markers `<unk`, `<s>` or `</s>` anywhere.
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    /// Minimum number of bytes required before the heuristics are applied.
    const MIN_LEN: usize = 12;
    /// Leading byte pairs accepted as a SentencePiece header.
    const HEADERS: [&[u8]; 2] = [b"\x0a\x09", b"\x08\x00"];
    /// Special-token markers searched for in the buffer.
    const MARKERS: [&[u8]; 3] = [b"<unk", b"<s>", b"</s>"];

    if buffer.len() < MIN_LEN {
        return false;
    }

    if HEADERS.iter().any(|header| buffer.starts_with(header)) {
        return true;
    }

    // Size the sliding window to each marker: a fixed 4-byte window can never line up
    // with the 3-byte `<s>` marker when it occupies the last three bytes of the buffer.
    MARKERS
        .iter()
        .any(|marker| buffer.windows(marker.len()).any(|window| window == *marker))
}
174
/// Looks for a chat template file inside `dir` and returns its path as a string.
///
/// Candidates are tried in priority order:
/// 1. `chat_template.json`
/// 2. `chat_template.jinja`
/// 3. any other file whose name ends in `.jinja`; when several exist, the
///    lexicographically smallest path is chosen so the result does not depend on the
///    platform's directory iteration order.
///
/// Returns `None` when nothing matches, when the directory cannot be read, or when the
/// selected path is not valid UTF-8.
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
    use std::fs;

    // Well-known file names win over the generic `.jinja` scan.
    for well_known in ["chat_template.json", "chat_template.jinja"].iter() {
        let candidate = dir.join(well_known);
        if candidate.exists() {
            return candidate.to_str().map(|s| s.to_string());
        }
    }

    // Unreadable directories are treated the same as "no template found".
    let entries = fs::read_dir(dir).ok()?;

    entries
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|name| name.to_str())
                .map_or(false, |name| {
                    name.ends_with(".jinja") && name != "chat_template.jinja"
                })
        })
        .min()
        .and_then(|path| path.to_str().map(|s| s.to_string()))
}
204
205fn resolve_and_log_chat_template(
210 provided_path: Option<&str>,
211 discovery_dir: &Path,
212 model_name: &str,
213) -> Option<String> {
214 let final_chat_template = provided_path
215 .map(|s| s.to_string())
216 .or_else(|| discover_chat_template_in_dir(discovery_dir));
217
218 match (&provided_path, &final_chat_template) {
219 (Some(provided), _) => {
220 debug!("Using provided chat template: {}", provided);
221 }
222 (None, Some(discovered)) => {
223 debug!(
224 "Auto-discovered chat template in '{}': {}",
225 discovery_dir.display(),
226 discovered
227 );
228 }
229 (None, None) => {
230 debug!(
231 "No chat template provided or discovered for model: {}",
232 model_name
233 );
234 }
235 }
236
237 final_chat_template
238}
239
/// Asynchronously creates a tokenizer from a local path or a model name.
///
/// Thin convenience wrapper: awaits `create_tokenizer_async_with_chat_template` with no
/// explicit chat template path (`None`).
///
/// # Errors
///
/// Returns whatever error `create_tokenizer_async_with_chat_template` returns.
pub async fn create_tokenizer_async(
    model_name_or_path: &str,
) -> Result<Arc<dyn traits::Tokenizer>> {
    create_tokenizer_async_with_chat_template(model_name_or_path, None).await
}
246
247pub async fn create_tokenizer_async_with_chat_template(
249 model_name_or_path: &str,
250 chat_template_path: Option<&str>,
251) -> Result<Arc<dyn traits::Tokenizer>> {
252 let path = Path::new(model_name_or_path);
254 if path.exists() {
255 return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
256 }
257
258 if model_name_or_path.contains("gpt-4")
261 || model_name_or_path.contains("gpt-3.5")
262 || model_name_or_path.contains("gpt-3")
263 || model_name_or_path.contains("turbo")
264 || model_name_or_path.contains("davinci")
265 || model_name_or_path.contains("curie")
266 || model_name_or_path.contains("babbage")
267 || model_name_or_path.contains("ada")
268 || model_name_or_path.contains("codex")
269 {
270 match TiktokenTokenizer::from_model_name(model_name_or_path) {
272 Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
273 Err(e) => {
274 debug!(
275 "Tiktoken failed for '{}': {}, falling back to HuggingFace",
276 model_name_or_path, e
277 );
278 }
279 }
280 }
281
282 match download_tokenizer_from_hf(model_name_or_path).await {
284 Ok(cache_dir) => {
285 let tokenizer_path = cache_dir.join("tokenizer.json");
287 if tokenizer_path.exists() {
288 let final_chat_template = resolve_and_log_chat_template(
290 chat_template_path,
291 &cache_dir,
292 model_name_or_path,
293 );
294
295 let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
296 Error::msg(format!(
297 "Tokenizer path is not valid UTF-8: {:?}",
298 tokenizer_path
299 ))
300 })?;
301 create_tokenizer_with_chat_template(
302 tokenizer_path_str,
303 final_chat_template.as_deref(),
304 )
305 } else {
306 let possible_files = ["tokenizer_config.json", "vocab.json"];
308 for file_name in &possible_files {
309 let file_path = cache_dir.join(file_name);
310 if file_path.exists() {
311 let final_chat_template = resolve_and_log_chat_template(
313 chat_template_path,
314 &cache_dir,
315 model_name_or_path,
316 );
317
318 let file_path_str = file_path.to_str().ok_or_else(|| {
319 Error::msg(format!("File path is not valid UTF-8: {:?}", file_path))
320 })?;
321 return create_tokenizer_with_chat_template(
322 file_path_str,
323 final_chat_template.as_deref(),
324 );
325 }
326 }
327 Err(Error::msg(format!(
328 "Downloaded model '{}' but couldn't find a suitable tokenizer file",
329 model_name_or_path
330 )))
331 }
332 }
333 Err(e) => Err(Error::msg(format!(
334 "Failed to download tokenizer from HuggingFace: {}",
335 e
336 ))),
337 }
338}
339
/// Creates a tokenizer from a local path or a model name, blocking until it is ready.
///
/// Thin convenience wrapper: forwards to `create_tokenizer_with_chat_template_blocking`
/// with no explicit chat template path (`None`).
///
/// # Errors
///
/// Returns whatever error `create_tokenizer_with_chat_template_blocking` returns.
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
}
347
348pub fn create_tokenizer_with_chat_template_blocking(
350 model_name_or_path: &str,
351 chat_template_path: Option<&str>,
352) -> Result<Arc<dyn traits::Tokenizer>> {
353 let path = Path::new(model_name_or_path);
355 if path.exists() {
356 return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
357 }
358
359 if model_name_or_path.contains("gpt-")
361 || model_name_or_path.contains("davinci")
362 || model_name_or_path.contains("curie")
363 || model_name_or_path.contains("babbage")
364 || model_name_or_path.contains("ada")
365 {
366 let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
367 return Ok(Arc::new(tokenizer));
368 }
369
370 if let Ok(handle) = tokio::runtime::Handle::try_current() {
373 tokio::task::block_in_place(|| {
375 handle.block_on(create_tokenizer_async_with_chat_template(
376 model_name_or_path,
377 chat_template_path,
378 ))
379 })
380 } else {
381 let rt = tokio::runtime::Runtime::new()?;
383 rt.block_on(create_tokenizer_async_with_chat_template(
384 model_name_or_path,
385 chat_template_path,
386 ))
387 }
388}
389
390pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
392 let path = Path::new(file_path);
393
394 if !path.exists() {
395 return Err(Error::msg(format!("File not found: {}", file_path)));
396 }
397
398 let extension = path
399 .extension()
400 .and_then(std::ffi::OsStr::to_str)
401 .map(|s| s.to_lowercase());
402
403 match extension.as_deref() {
404 Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
405 _ => {
406 use std::{fs::File, io::Read};
408
409 let mut file = File::open(file_path)?;
410 let mut buffer = vec![0u8; 512];
411 let bytes_read = file.read(&mut buffer)?;
412 buffer.truncate(bytes_read);
413
414 if is_likely_json(&buffer) {
415 Ok(TokenizerType::HuggingFace(file_path.to_string()))
416 } else {
417 Err(Error::msg("Unknown tokenizer type"))
418 }
419 }
420 }
421}
422
/// Unit tests for the tokenizer factory functions.
#[cfg(test)]
mod tests {
    use super::{
        create_tokenizer, create_tokenizer_async, create_tokenizer_from_file, is_likely_json,
    };

    /// `is_likely_json` accepts objects and arrays (with leading whitespace) and rejects
    /// plain text and empty input.
    #[test]
    fn test_json_detection() {
        assert!(is_likely_json(b"{\"test\": \"value\"}"));
        assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
        assert!(is_likely_json(b"[1, 2, 3]"));
        assert!(!is_likely_json(b"not json"));
        assert!(!is_likely_json(b""));
    }

    /// The reserved name "mock" yields the mock tokenizer without touching the filesystem.
    #[test]
    fn test_mock_tokenizer_creation() {
        let tokenizer = create_tokenizer_from_file("mock").unwrap();
        // The mock tokenizer is expected to report a vocabulary of 14 entries.
        assert_eq!(tokenizer.vocab_size(), 14);
    }

    /// A nonexistent path is reported with a "File not found" error.
    #[test]
    fn test_file_not_found() {
        let result = create_tokenizer_from_file("/nonexistent/file.json");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("File not found"));
        }
    }

    /// A "gpt-4" model name resolves to a working tokenizer that round-trips text.
    #[test]
    fn test_create_tiktoken_tokenizer() {
        let tokenizer = create_tokenizer("gpt-4").unwrap();
        assert!(tokenizer.vocab_size() > 0);

        let text = "Hello, world!";
        let encoding = tokenizer.encode(text, false).unwrap();
        let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    /// Best-effort download test: a failure is printed rather than asserted, so this
    /// test never fails on download errors.
    #[tokio::test]
    async fn test_download_tokenizer_from_hf() {
        // Skipped when running in CI without an HF_TOKEN in the environment.
        if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
            println!("Skipping HF download test in CI without HF_TOKEN");
            return;
        }

        let result = create_tokenizer_async("bert-base-uncased").await;

        match result {
            Ok(tokenizer) => {
                assert!(tokenizer.vocab_size() > 0);
                println!("Successfully downloaded and created tokenizer");
            }
            Err(e) => {
                println!("Download failed (this might be expected): {}", e);
            }
        }
    }
}