1use std::{fs::File, io::Read, path::Path, sync::Arc};
2
3use anyhow::{Error, Result};
4use tracing::debug;
5
6use crate::{
7 hub::download_tokenizer_from_hf,
8 huggingface::HuggingFaceTokenizer,
9 tiktoken::{has_tiktoken_file, is_tiktoken_file, TiktokenTokenizer},
10 traits,
11};
12
/// Identifies which tokenizer backend a given source maps to.
///
/// The `String` payloads carry the path (or name) the tokenizer was
/// identified from, exactly as supplied by the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TokenizerType {
    /// A HuggingFace tokenizer definition, identified by its file path.
    HuggingFace(String),
    /// The in-memory mock tokenizer.
    Mock,
    /// A tiktoken tokenizer, identified by its path or model name.
    Tiktoken(String),
}
21
22pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
28 create_tokenizer_with_chat_template(file_path, None)
29}
30
31pub fn create_tokenizer_with_chat_template(
33 file_path: &str,
34 chat_template_path: Option<&str>,
35) -> Result<Arc<dyn traits::Tokenizer>> {
36 if file_path == "mock" || file_path == "test" {
38 return Ok(Arc::new(super::mock::MockTokenizer::new()));
39 }
40
41 let path = Path::new(file_path);
42
43 if !path.exists() {
45 return Err(Error::msg(format!("File not found: {file_path}")));
46 }
47
48 if path.is_dir() {
50 let tokenizer_json = path.join("tokenizer.json");
51 if tokenizer_json.exists() {
52 let final_chat_template =
54 resolve_and_log_chat_template(chat_template_path, path, file_path);
55 let tokenizer_path_str = tokenizer_json.to_str().ok_or_else(|| {
56 Error::msg(format!(
57 "Tokenizer path is not valid UTF-8: {tokenizer_json:?}"
58 ))
59 })?;
60 return create_tokenizer_with_chat_template(
61 tokenizer_path_str,
62 final_chat_template.as_deref(),
63 );
64 }
65
66 if has_tiktoken_file(path) {
70 return Ok(Arc::new(TiktokenTokenizer::from_dir_with_chat_template(
71 path,
72 chat_template_path,
73 )?));
74 }
75
76 return Err(Error::msg(format!(
77 "Directory '{file_path}' does not contain a valid tokenizer file (tokenizer.json, tiktoken.model, *.tiktoken, or vocab.json)"
78 )));
79 }
80
81 let extension = path
83 .extension()
84 .and_then(std::ffi::OsStr::to_str)
85 .map(|s| s.to_lowercase());
86
87 let result = match extension.as_deref() {
88 Some("json") => {
89 let tokenizer =
90 HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
91
92 Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
93 }
94 Some("model") | Some("tiktoken") => {
95 if is_tiktoken_file(path) {
97 Ok(Arc::new(TiktokenTokenizer::from_file_with_chat_template(
98 path,
99 chat_template_path,
100 )?) as Arc<dyn traits::Tokenizer>)
101 } else {
102 Err(Error::msg("SentencePiece models not yet supported"))
103 }
104 }
105 Some("gguf") => {
106 Err(Error::msg("GGUF format not yet supported"))
108 }
109 _ => {
110 auto_detect_tokenizer(file_path)
112 }
113 };
114
115 result
116}
117
118fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
120 let mut file = File::open(file_path)?;
121 let mut buffer = vec![0u8; 512]; let bytes_read = file.read(&mut buffer)?;
123 buffer.truncate(bytes_read);
124
125 if is_likely_json(&buffer) {
127 let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
128 return Ok(Arc::new(tokenizer));
129 }
130
131 if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
133 return Err(Error::msg("GGUF format detected but not yet supported"));
134 }
135
136 if is_likely_sentencepiece(&buffer) {
138 return Err(Error::msg(
139 "SentencePiece model detected but not yet supported",
140 ));
141 }
142
143 Err(Error::msg(format!(
144 "Unable to determine tokenizer type for file: {file_path}"
145 )))
146}
147
/// Reports whether `buffer` looks like the start of a JSON document.
///
/// A leading UTF-8 byte-order mark is skipped, then ASCII whitespace; the
/// first remaining byte must be `{` or `[`. Empty or whitespace-only input
/// is not considered JSON.
fn is_likely_json(buffer: &[u8]) -> bool {
    // Drop a UTF-8 BOM (EF BB BF) if the buffer starts with one.
    let payload = match buffer {
        [0xEF, 0xBB, 0xBF, rest @ ..] => rest,
        whole => whole,
    };

    let first_significant = payload
        .iter()
        .copied()
        .find(|byte| !byte.is_ascii_whitespace());

    matches!(first_significant, Some(b'{' | b'['))
}
164
/// Heuristically reports whether `buffer` looks like the start of a
/// SentencePiece model.
///
/// Buffers shorter than 12 bytes are never matched. Otherwise the buffer
/// matches when it starts with one of two known byte prefixes, or when it
/// contains any of the special-token markers `<unk`, `<s>` or `</s>`
/// anywhere, including at the very end of the buffer.
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    /// Minimum number of bytes required before any heuristic is applied.
    const MIN_LEN: usize = 12;
    /// Special-token markers searched for inside the buffer.
    const MARKERS: [&[u8]; 3] = [b"<unk", b"<s>", b"</s>"];

    if buffer.len() < MIN_LEN {
        return false;
    }

    // NOTE(review): these prefixes are presumably protobuf field headers of a
    // serialized SentencePiece model — confirm against the format spec.
    if buffer.starts_with(b"\x0a\x09") || buffer.starts_with(b"\x08\x00") {
        return true;
    }

    // Size the window per marker so a short marker (e.g. the 3-byte `<s>`)
    // is still found when it sits in the last bytes of the buffer.
    MARKERS.iter().any(|marker| {
        buffer
            .windows(marker.len())
            .any(|window| window == *marker)
    })
}
190
/// Looks for a chat template file inside `dir` and returns its path.
///
/// Candidates are tried in priority order:
/// 1. `chat_template.json`
/// 2. `chat_template.jinja`
/// 3. any other `*.jinja` file; when several exist, the one whose file name
///    sorts first is chosen so the result does not depend on the order in
///    which the filesystem lists directory entries.
///
/// Returns `None` when nothing matches, when the directory cannot be read,
/// or when the selected path is not valid UTF-8.
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
    use std::fs;

    for well_known in ["chat_template.json", "chat_template.jinja"] {
        let candidate = dir.join(well_known);
        if candidate.exists() {
            return candidate.to_str().map(str::to_owned);
        }
    }

    // Entries that fail to read or whose names are not valid UTF-8 are skipped.
    let first_jinja = fs::read_dir(dir)
        .ok()?
        .flatten()
        .filter_map(|entry| entry.file_name().into_string().ok())
        .filter(|name| name.ends_with(".jinja") && name.as_str() != "chat_template.jinja")
        .min()?;

    dir.join(first_jinja).to_str().map(str::to_owned)
}
220
221fn resolve_and_log_chat_template(
226 provided_path: Option<&str>,
227 discovery_dir: &Path,
228 model_name: &str,
229) -> Option<String> {
230 let final_chat_template = provided_path
231 .map(|s| s.to_string())
232 .or_else(|| discover_chat_template_in_dir(discovery_dir));
233
234 match (&provided_path, &final_chat_template) {
235 (Some(provided), _) => {
236 debug!("Using provided chat template: {}", provided);
237 }
238 (None, Some(discovered)) => {
239 debug!(
240 "Auto-discovered chat template in '{}': {}",
241 discovery_dir.display(),
242 discovered
243 );
244 }
245 (None, None) => {
246 debug!(
247 "No chat template provided or discovered for model: {}",
248 model_name
249 );
250 }
251 }
252
253 final_chat_template
254}
255
256pub async fn create_tokenizer_async(
258 model_name_or_path: &str,
259) -> Result<Arc<dyn traits::Tokenizer>> {
260 create_tokenizer_async_with_chat_template(model_name_or_path, None).await
261}
262
/// Reports whether `name` looks like an OpenAI model identifier.
///
/// Only the part after the last `/` is inspected, so an organisation prefix
/// such as `openai/` is ignored. The remaining name matches when it is:
/// - `gpt-` immediately followed by an ASCII digit (e.g. `gpt-4o`),
/// - anything starting with `chatgpt-`,
/// - `o` plus one ASCII digit, either alone or followed by `-` (e.g. `o1`,
///   `o3-mini`),
/// - one of the legacy base names (`davinci`, `curie`, `babbage`, `ada`) or a
///   name starting with one of the legacy family prefixes.
fn is_likely_openai_model(name: &str) -> bool {
    /// Prefixes of the legacy completion / embedding model families.
    const LEGACY_PREFIXES: [&str; 7] = [
        "text-davinci",
        "code-davinci",
        "text-curie",
        "text-babbage",
        "text-ada",
        "text-embedding-ada",
        "code-cushman",
    ];

    let bare = name.rsplit('/').next().unwrap_or(name);

    let is_numbered_gpt = bare
        .strip_prefix("gpt-")
        .and_then(|rest| rest.bytes().next())
        .is_some_and(|byte| byte.is_ascii_digit());

    let is_chatgpt = bare.starts_with("chatgpt-");

    // `o<digit>` exactly, or `o<digit>-...`.
    let is_o_series = matches!(
        bare.as_bytes(),
        [b'o', digit] | [b'o', digit, b'-', ..] if digit.is_ascii_digit()
    );

    let is_legacy = matches!(bare, "davinci" | "curie" | "babbage" | "ada")
        || LEGACY_PREFIXES
            .iter()
            .any(|prefix| bare.starts_with(*prefix));

    is_numbered_gpt || is_chatgpt || is_o_series || is_legacy
}
312
313pub async fn create_tokenizer_async_with_chat_template(
315 model_name_or_path: &str,
316 chat_template_path: Option<&str>,
317) -> Result<Arc<dyn traits::Tokenizer>> {
318 let path = Path::new(model_name_or_path);
320 if path.exists() {
321 return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
322 }
323
324 if is_likely_openai_model(model_name_or_path) {
326 match TiktokenTokenizer::from_model_name(model_name_or_path) {
328 Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
329 Err(e) => {
330 debug!(
331 "Tiktoken failed for '{}': {}, falling back to HuggingFace",
332 model_name_or_path, e
333 );
334 }
335 }
336 }
337
338 match download_tokenizer_from_hf(model_name_or_path).await {
340 Ok(cache_dir) => {
341 let tokenizer_path = cache_dir.join("tokenizer.json");
343 if tokenizer_path.exists() {
344 let final_chat_template = resolve_and_log_chat_template(
346 chat_template_path,
347 &cache_dir,
348 model_name_or_path,
349 );
350
351 let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
352 Error::msg(format!(
353 "Tokenizer path is not valid UTF-8: {tokenizer_path:?}"
354 ))
355 })?;
356 create_tokenizer_with_chat_template(
357 tokenizer_path_str,
358 final_chat_template.as_deref(),
359 )
360 } else if has_tiktoken_file(&cache_dir) {
361 Ok(Arc::new(TiktokenTokenizer::from_dir_with_chat_template(
362 &cache_dir,
363 chat_template_path,
364 )?))
365 } else {
366 let possible_files = ["tokenizer_config.json", "vocab.json"];
368 for file_name in &possible_files {
369 let file_path = cache_dir.join(file_name);
370 if file_path.exists() {
371 let final_chat_template = resolve_and_log_chat_template(
373 chat_template_path,
374 &cache_dir,
375 model_name_or_path,
376 );
377
378 let file_path_str = file_path.to_str().ok_or_else(|| {
379 Error::msg(format!("File path is not valid UTF-8: {file_path:?}"))
380 })?;
381 return create_tokenizer_with_chat_template(
382 file_path_str,
383 final_chat_template.as_deref(),
384 );
385 }
386 }
387 Err(Error::msg(format!(
388 "Downloaded model '{model_name_or_path}' but couldn't find a suitable tokenizer file"
389 )))
390 }
391 }
392 Err(e) => Err(Error::msg(format!(
393 "Failed to download tokenizer from HuggingFace: {e}"
394 ))),
395 }
396}
397
398pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
403 create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
404}
405
406pub fn create_tokenizer_with_chat_template_blocking(
408 model_name_or_path: &str,
409 chat_template_path: Option<&str>,
410) -> Result<Arc<dyn traits::Tokenizer>> {
411 let path = Path::new(model_name_or_path);
413 if path.exists() {
414 return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
415 }
416
417 if is_likely_openai_model(model_name_or_path) {
420 match TiktokenTokenizer::from_model_name(model_name_or_path) {
421 Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
422 Err(e) => {
423 debug!(
424 "Tiktoken failed for '{}': {}, falling back to HuggingFace",
425 model_name_or_path, e
426 );
427 }
428 }
429 }
430
431 if let Ok(handle) = tokio::runtime::Handle::try_current() {
433 tokio::task::block_in_place(|| {
434 handle.block_on(create_tokenizer_async_with_chat_template(
435 model_name_or_path,
436 chat_template_path,
437 ))
438 })
439 } else {
440 let rt = tokio::runtime::Runtime::new()?;
441 rt.block_on(create_tokenizer_async_with_chat_template(
442 model_name_or_path,
443 chat_template_path,
444 ))
445 }
446}
447
448pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
450 let path = Path::new(file_path);
451
452 if !path.exists() {
453 return Err(Error::msg(format!("File not found: {file_path}")));
454 }
455
456 let extension = path
457 .extension()
458 .and_then(std::ffi::OsStr::to_str)
459 .map(|s| s.to_lowercase());
460
461 match extension.as_deref() {
462 Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
463 _ => {
464 use std::{fs::File, io::Read};
466
467 let mut file = File::open(file_path)?;
468 let mut buffer = vec![0u8; 512];
469 let bytes_read = file.read(&mut buffer)?;
470 buffer.truncate(bytes_read);
471
472 if is_likely_json(&buffer) {
473 Ok(TokenizerType::HuggingFace(file_path.to_string()))
474 } else {
475 Err(Error::msg("Unknown tokenizer type"))
476 }
477 }
478 }
479}
480
#[cfg(test)]
#[expect(
    clippy::print_stdout,
    reason = "diagnostic output in tests for CI skip messages and download results"
)]
mod tests {
    use super::{
        create_tokenizer, create_tokenizer_async, create_tokenizer_from_file, is_likely_json,
        is_likely_openai_model,
    };

    /// JSON sniffing accepts objects and arrays (with leading whitespace) and
    /// rejects plain text and empty input.
    #[test]
    fn test_json_detection() {
        let json_like: [&[u8]; 3] = [
            b"{\"test\": \"value\"}",
            b" \n\t{\"test\": \"value\"}",
            b"[1, 2, 3]",
        ];
        for sample in json_like {
            assert!(is_likely_json(sample));
        }

        let not_json: [&[u8]; 2] = [b"not json", b""];
        for sample in not_json {
            assert!(!is_likely_json(sample));
        }
    }

    /// The `"mock"` sentinel yields the mock tokenizer with its fixed vocabulary size.
    #[test]
    fn test_mock_tokenizer_creation() {
        let vocab_size = create_tokenizer_from_file("mock").unwrap().vocab_size();
        assert_eq!(vocab_size, 14);
    }

    /// A missing path is reported with a "File not found" error.
    #[test]
    fn test_file_not_found() {
        match create_tokenizer_from_file("/nonexistent/file.json") {
            Ok(_) => panic!("expected an error for a nonexistent file"),
            Err(e) => assert!(e.to_string().contains("File not found")),
        }
    }

    /// An OpenAI model name resolves to a tokenizer that round-trips text.
    #[test]
    fn test_create_tiktoken_tokenizer() {
        let tokenizer = create_tokenizer("gpt-4").unwrap();
        assert!(tokenizer.vocab_size() > 0);

        let original = "Hello, world!";
        let encoded = tokenizer.encode(original, false).unwrap();
        let round_tripped = tokenizer.decode(encoded.token_ids(), false).unwrap();
        assert_eq!(round_tripped, original);
    }

    /// Downloading from HuggingFace works when the network allows it; a
    /// download failure is reported but does not fail the test.
    #[tokio::test]
    async fn test_download_tokenizer_from_hf() {
        let running_in_ci = std::env::var("CI").is_ok();
        let missing_token = std::env::var("HF_TOKEN").is_err();
        if running_in_ci && missing_token {
            println!("Skipping HF download test in CI without HF_TOKEN");
            return;
        }

        match create_tokenizer_async("bert-base-uncased").await {
            Ok(tokenizer) => {
                assert!(tokenizer.vocab_size() > 0);
                println!("Successfully downloaded and created tokenizer");
            }
            Err(e) => {
                println!("Download failed (this might be expected): {e}");
            }
        }
    }

    /// Names that must be recognised as OpenAI models.
    #[test]
    fn test_is_likely_openai_model_positives() {
        const EXPECTED_MATCHES: [&str; 31] = [
            // GPT-4 family
            "gpt-4",
            "gpt-4o",
            "gpt-4o-mini",
            "gpt-4-turbo",
            "gpt-4-32k",
            "gpt-4.5-preview",
            // GPT-3.5 family
            "gpt-3.5-turbo",
            "gpt-3.5-turbo-16k",
            "gpt-3.5-turbo-instruct",
            // ChatGPT aliases
            "chatgpt-4o-latest",
            // o-series
            "o1",
            "o1-mini",
            "o1-preview",
            "o3",
            "o3-mini",
            "o3-pro",
            "o4-mini",
            // Legacy models
            "davinci",
            "text-davinci-003",
            "code-davinci-002",
            "curie",
            "text-curie-001",
            "babbage",
            "text-babbage-001",
            "ada",
            "text-ada-001",
            "text-embedding-ada-002",
            "code-cushman-001",
            // Organisation-prefixed names
            "openai/gpt-4",
            "openai/o1-mini",
            "openai/davinci",
        ];

        for name in EXPECTED_MATCHES {
            assert!(
                is_likely_openai_model(name),
                "expected '{name}' to be detected as an OpenAI model"
            );
        }
    }

    /// Names that must NOT be recognised as OpenAI models.
    #[test]
    fn test_is_likely_openai_model_negatives() {
        const EXPECTED_MISSES: [&str; 8] = [
            // Models hosted elsewhere
            "openai/gpt-oss-20b",
            "meta-llama/Llama-3-8B",
            "mistralai/Mistral-7B",
            "bert-base-uncased",
            // Look-alikes that only share a substring or prefix letter
            "turbo-llama",
            "adapter-v2",
            "oracle-7b",
            "open-llama",
        ];

        for name in EXPECTED_MISSES {
            assert!(
                !is_likely_openai_model(name),
                "expected '{name}' not to be detected as an OpenAI model"
            );
        }
    }
}