use std::{fs::File, io::Read, path::Path, sync::Arc};
use anyhow::{Error, Result};
use tracing::debug;
use crate::{
hub::download_tokenizer_from_hf,
huggingface::HuggingFaceTokenizer,
tiktoken::{has_tiktoken_file, is_tiktoken_file, TiktokenTokenizer},
traits,
};
#[derive(Debug, Clone)]
pub enum TokenizerType {
HuggingFace(String),
Mock,
Tiktoken(String),
}
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
create_tokenizer_with_chat_template(file_path, None)
}
pub fn create_tokenizer_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Arc<dyn traits::Tokenizer>> {
if file_path == "mock" || file_path == "test" {
return Ok(Arc::new(super::mock::MockTokenizer::new()));
}
let path = Path::new(file_path);
if !path.exists() {
return Err(Error::msg(format!("File not found: {file_path}")));
}
if path.is_dir() {
let tokenizer_json = path.join("tokenizer.json");
if tokenizer_json.exists() {
let final_chat_template =
resolve_and_log_chat_template(chat_template_path, path, file_path);
let tokenizer_path_str = tokenizer_json.to_str().ok_or_else(|| {
Error::msg(format!(
"Tokenizer path is not valid UTF-8: {tokenizer_json:?}"
))
})?;
return create_tokenizer_with_chat_template(
tokenizer_path_str,
final_chat_template.as_deref(),
);
}
if has_tiktoken_file(path) {
return Ok(Arc::new(TiktokenTokenizer::from_dir_with_chat_template(
path,
chat_template_path,
)?));
}
return Err(Error::msg(format!(
"Directory '{file_path}' does not contain a valid tokenizer file (tokenizer.json, tiktoken.model, *.tiktoken, or vocab.json)"
)));
}
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.map(|s| s.to_lowercase());
let result = match extension.as_deref() {
Some("json") => {
let tokenizer =
HuggingFaceTokenizer::from_file_with_chat_template(file_path, chat_template_path)?;
Ok(Arc::new(tokenizer) as Arc<dyn traits::Tokenizer>)
}
Some("model") | Some("tiktoken") => {
if is_tiktoken_file(path) {
Ok(Arc::new(TiktokenTokenizer::from_file_with_chat_template(
path,
chat_template_path,
)?) as Arc<dyn traits::Tokenizer>)
} else {
Err(Error::msg("SentencePiece models not yet supported"))
}
}
Some("gguf") => {
Err(Error::msg("GGUF format not yet supported"))
}
_ => {
auto_detect_tokenizer(file_path)
}
};
result
}
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512]; let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
if is_likely_json(&buffer) {
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
return Ok(Arc::new(tokenizer));
}
if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
return Err(Error::msg("GGUF format detected but not yet supported"));
}
if is_likely_sentencepiece(&buffer) {
return Err(Error::msg(
"SentencePiece model detected but not yet supported",
));
}
Err(Error::msg(format!(
"Unable to determine tokenizer type for file: {file_path}"
)))
}
fn is_likely_json(buffer: &[u8]) -> bool {
let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
&buffer[3..]
} else {
buffer
};
if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
*first_byte == b'{' || *first_byte == b'['
} else {
false
}
}
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
if buffer.len() < 12 {
return false;
}
if buffer.starts_with(b"\x0a\x09") || buffer.starts_with(b"\x08\x00") {
return true;
}
let patterns: &[&[u8]] = &[b"<unk", b"<s>", b"</s>"];
for window in buffer.windows(4) {
for pattern in patterns {
if window.starts_with(pattern) {
return true;
}
}
}
false
}
pub fn discover_chat_template_in_dir(dir: &Path) -> Option<String> {
use std::fs;
let json_template_path = dir.join("chat_template.json");
if json_template_path.exists() {
return json_template_path.to_str().map(|s| s.to_string());
}
let jinja_path = dir.join("chat_template.jinja");
if jinja_path.exists() {
return jinja_path.to_str().map(|s| s.to_string());
}
if let Ok(entries) = fs::read_dir(dir) {
for entry in entries.flatten() {
if let Some(name) = entry.file_name().to_str() {
if name.ends_with(".jinja") && name != "chat_template.jinja" {
return entry.path().to_str().map(|s| s.to_string());
}
}
}
}
None
}
fn resolve_and_log_chat_template(
provided_path: Option<&str>,
discovery_dir: &Path,
model_name: &str,
) -> Option<String> {
let final_chat_template = provided_path
.map(|s| s.to_string())
.or_else(|| discover_chat_template_in_dir(discovery_dir));
match (&provided_path, &final_chat_template) {
(Some(provided), _) => {
debug!("Using provided chat template: {}", provided);
}
(None, Some(discovered)) => {
debug!(
"Auto-discovered chat template in '{}': {}",
discovery_dir.display(),
discovered
);
}
(None, None) => {
debug!(
"No chat template provided or discovered for model: {}",
model_name
);
}
}
final_chat_template
}
pub async fn create_tokenizer_async(
model_name_or_path: &str,
) -> Result<Arc<dyn traits::Tokenizer>> {
create_tokenizer_async_with_chat_template(model_name_or_path, None).await
}
fn is_likely_openai_model(name: &str) -> bool {
let bare = name.rsplit('/').next().unwrap_or(name);
if bare.starts_with("gpt-") && bare.as_bytes().get(4).is_some_and(|b| b.is_ascii_digit()) {
return true;
}
if bare.starts_with("chatgpt-") {
return true;
}
if bare.starts_with('o')
&& bare.as_bytes().get(1).is_some_and(|b| b.is_ascii_digit())
&& bare.as_bytes().get(2).is_none_or(|b| *b == b'-')
{
return true;
}
matches!(bare, "davinci" | "curie" | "babbage" | "ada")
|| bare.starts_with("text-davinci")
|| bare.starts_with("code-davinci")
|| bare.starts_with("text-curie")
|| bare.starts_with("text-babbage")
|| bare.starts_with("text-ada")
|| bare.starts_with("text-embedding-ada")
|| bare.starts_with("code-cushman")
}
pub async fn create_tokenizer_async_with_chat_template(
model_name_or_path: &str,
chat_template_path: Option<&str>,
) -> Result<Arc<dyn traits::Tokenizer>> {
let path = Path::new(model_name_or_path);
if path.exists() {
return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
}
if is_likely_openai_model(model_name_or_path) {
match TiktokenTokenizer::from_model_name(model_name_or_path) {
Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
Err(e) => {
debug!(
"Tiktoken failed for '{}': {}, falling back to HuggingFace",
model_name_or_path, e
);
}
}
}
match download_tokenizer_from_hf(model_name_or_path).await {
Ok(cache_dir) => {
let tokenizer_path = cache_dir.join("tokenizer.json");
if tokenizer_path.exists() {
let final_chat_template = resolve_and_log_chat_template(
chat_template_path,
&cache_dir,
model_name_or_path,
);
let tokenizer_path_str = tokenizer_path.to_str().ok_or_else(|| {
Error::msg(format!(
"Tokenizer path is not valid UTF-8: {tokenizer_path:?}"
))
})?;
create_tokenizer_with_chat_template(
tokenizer_path_str,
final_chat_template.as_deref(),
)
} else if has_tiktoken_file(&cache_dir) {
Ok(Arc::new(TiktokenTokenizer::from_dir_with_chat_template(
&cache_dir,
chat_template_path,
)?))
} else {
let possible_files = ["tokenizer_config.json", "vocab.json"];
for file_name in &possible_files {
let file_path = cache_dir.join(file_name);
if file_path.exists() {
let final_chat_template = resolve_and_log_chat_template(
chat_template_path,
&cache_dir,
model_name_or_path,
);
let file_path_str = file_path.to_str().ok_or_else(|| {
Error::msg(format!("File path is not valid UTF-8: {file_path:?}"))
})?;
return create_tokenizer_with_chat_template(
file_path_str,
final_chat_template.as_deref(),
);
}
}
Err(Error::msg(format!(
"Downloaded model '{model_name_or_path}' but couldn't find a suitable tokenizer file"
)))
}
}
Err(e) => Err(Error::msg(format!(
"Failed to download tokenizer from HuggingFace: {e}"
))),
}
}
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
create_tokenizer_with_chat_template_blocking(model_name_or_path, None)
}
pub fn create_tokenizer_with_chat_template_blocking(
model_name_or_path: &str,
chat_template_path: Option<&str>,
) -> Result<Arc<dyn traits::Tokenizer>> {
let path = Path::new(model_name_or_path);
if path.exists() {
return create_tokenizer_with_chat_template(model_name_or_path, chat_template_path);
}
if is_likely_openai_model(model_name_or_path) {
match TiktokenTokenizer::from_model_name(model_name_or_path) {
Ok(tokenizer) => return Ok(Arc::new(tokenizer)),
Err(e) => {
debug!(
"Tiktoken failed for '{}': {}, falling back to HuggingFace",
model_name_or_path, e
);
}
}
}
if let Ok(handle) = tokio::runtime::Handle::try_current() {
tokio::task::block_in_place(|| {
handle.block_on(create_tokenizer_async_with_chat_template(
model_name_or_path,
chat_template_path,
))
})
} else {
let rt = tokio::runtime::Runtime::new()?;
rt.block_on(create_tokenizer_async_with_chat_template(
model_name_or_path,
chat_template_path,
))
}
}
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
let path = Path::new(file_path);
if !path.exists() {
return Err(Error::msg(format!("File not found: {file_path}")));
}
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.map(|s| s.to_lowercase());
match extension.as_deref() {
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
_ => {
use std::{fs::File, io::Read};
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512];
let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
if is_likely_json(&buffer) {
Ok(TokenizerType::HuggingFace(file_path.to_string()))
} else {
Err(Error::msg("Unknown tokenizer type"))
}
}
}
}
#[cfg(test)]
#[expect(
clippy::print_stdout,
reason = "diagnostic output in tests for CI skip messages and download results"
)]
mod tests {
use super::{
create_tokenizer, create_tokenizer_async, create_tokenizer_from_file, is_likely_json,
is_likely_openai_model,
};
#[test]
fn test_json_detection() {
assert!(is_likely_json(b"{\"test\": \"value\"}"));
assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
assert!(is_likely_json(b"[1, 2, 3]"));
assert!(!is_likely_json(b"not json"));
assert!(!is_likely_json(b""));
}
#[test]
fn test_mock_tokenizer_creation() {
let tokenizer = create_tokenizer_from_file("mock").unwrap();
assert_eq!(tokenizer.vocab_size(), 14); }
#[test]
fn test_file_not_found() {
let result = create_tokenizer_from_file("/nonexistent/file.json");
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("File not found"));
}
}
#[test]
fn test_create_tiktoken_tokenizer() {
let tokenizer = create_tokenizer("gpt-4").unwrap();
assert!(tokenizer.vocab_size() > 0);
let text = "Hello, world!";
let encoding = tokenizer.encode(text, false).unwrap();
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap();
assert_eq!(decoded, text);
}
#[tokio::test]
async fn test_download_tokenizer_from_hf() {
if std::env::var("CI").is_ok() && std::env::var("HF_TOKEN").is_err() {
println!("Skipping HF download test in CI without HF_TOKEN");
return;
}
let result = create_tokenizer_async("bert-base-uncased").await;
match result {
Ok(tokenizer) => {
assert!(tokenizer.vocab_size() > 0);
println!("Successfully downloaded and created tokenizer");
}
Err(e) => {
println!("Download failed (this might be expected): {e}");
}
}
}
#[test]
fn test_is_likely_openai_model_positives() {
assert!(is_likely_openai_model("gpt-4"));
assert!(is_likely_openai_model("gpt-4o"));
assert!(is_likely_openai_model("gpt-4o-mini"));
assert!(is_likely_openai_model("gpt-4-turbo"));
assert!(is_likely_openai_model("gpt-4-32k"));
assert!(is_likely_openai_model("gpt-4.5-preview"));
assert!(is_likely_openai_model("gpt-3.5-turbo"));
assert!(is_likely_openai_model("gpt-3.5-turbo-16k"));
assert!(is_likely_openai_model("gpt-3.5-turbo-instruct"));
assert!(is_likely_openai_model("chatgpt-4o-latest"));
assert!(is_likely_openai_model("o1"));
assert!(is_likely_openai_model("o1-mini"));
assert!(is_likely_openai_model("o1-preview"));
assert!(is_likely_openai_model("o3"));
assert!(is_likely_openai_model("o3-mini"));
assert!(is_likely_openai_model("o3-pro"));
assert!(is_likely_openai_model("o4-mini"));
assert!(is_likely_openai_model("davinci"));
assert!(is_likely_openai_model("text-davinci-003"));
assert!(is_likely_openai_model("code-davinci-002"));
assert!(is_likely_openai_model("curie"));
assert!(is_likely_openai_model("text-curie-001"));
assert!(is_likely_openai_model("babbage"));
assert!(is_likely_openai_model("text-babbage-001"));
assert!(is_likely_openai_model("ada"));
assert!(is_likely_openai_model("text-ada-001"));
assert!(is_likely_openai_model("text-embedding-ada-002"));
assert!(is_likely_openai_model("code-cushman-001"));
assert!(is_likely_openai_model("openai/gpt-4"));
assert!(is_likely_openai_model("openai/o1-mini"));
assert!(is_likely_openai_model("openai/davinci"));
}
#[test]
fn test_is_likely_openai_model_negatives() {
assert!(!is_likely_openai_model("openai/gpt-oss-20b"));
assert!(!is_likely_openai_model("meta-llama/Llama-3-8B"));
assert!(!is_likely_openai_model("mistralai/Mistral-7B"));
assert!(!is_likely_openai_model("bert-base-uncased"));
assert!(!is_likely_openai_model("turbo-llama"));
assert!(!is_likely_openai_model("adapter-v2"));
assert!(!is_likely_openai_model("oracle-7b"));
assert!(!is_likely_openai_model("open-llama"));
}
}