use std::path::PathBuf;
use hf_hub::Repo;
use crate::hf::JsonLoadError;
use crate::serde::SerdeError;
use crate::Tokenizer;
/// Error type returned by [`Tokenizer::from_pretrained`] and
/// [`Tokenizer::from_pretrained_with_options`].
#[derive(Debug)]
pub enum HubError {
    /// Building the `hf_hub` sync API client failed.
    ApiInit(hf_hub::api::sync::ApiError),
    /// A download from the Hub failed. In this file it is produced only when
    /// fetching `tokenizer.json` (the last-resort source) fails.
    Download(hf_hub::api::sync::ApiError),
    /// `tokenizer.json` was downloaded but could not be loaded.
    Load(JsonLoadError),
    /// A downloaded `tokenizer.tkz` could not be loaded.
    LoadBinary(SerdeError),
    /// No tokenizer was found; the payload is rendered as the repository name
    /// in the `Display` message.
    // NOTE(review): this variant is never constructed in this file — a missing
    // `tokenizer.json` currently surfaces as `Download`. Confirm whether other
    // modules construct it.
    NotFound(String),
}
impl std::fmt::Display for HubError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HubError::ApiInit(e) => write!(f, "failed to initialize HuggingFace Hub API: {}", e),
HubError::Download(e) => write!(f, "failed to download tokenizer: {}", e),
HubError::Load(e) => write!(f, "failed to load tokenizer: {}", e),
HubError::LoadBinary(e) => write!(f, "failed to load .tkz tokenizer: {}", e),
HubError::NotFound(repo) => {
write!(f, "tokenizer not found in repository '{}'", repo)
}
}
}
}
impl std::error::Error for HubError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
HubError::ApiInit(e) => Some(e),
HubError::Download(e) => Some(e),
HubError::Load(e) => Some(e),
HubError::LoadBinary(e) => Some(e),
HubError::NotFound(_) => None,
}
}
}
impl From<JsonLoadError> for HubError {
fn from(e: JsonLoadError) -> Self {
HubError::Load(e)
}
}
/// Options for [`Tokenizer::from_pretrained_with_options`].
///
/// `FromPretrainedOptions::default()` leaves every field `None`. The builder
/// methods below each consume `self`, set one field, and return the result, so
/// calls can be chained.
#[derive(Debug, Clone, Default)]
pub struct FromPretrainedOptions {
    /// Revision of the model repository to open. When `None`, the repository is
    /// opened with `Repo::model`, i.e. `hf_hub`'s default revision.
    pub revision: Option<String>,
    /// Cache directory passed to `hf_hub`'s `ApiBuilder::with_cache_dir`.
    pub cache_dir: Option<PathBuf>,
    /// Access token passed to `hf_hub`'s `ApiBuilder::with_token`.
    pub token: Option<String>,
}

impl FromPretrainedOptions {
    /// Returns these options with `revision` set, replacing any previous value.
    pub fn revision(self, revision: impl Into<String>) -> Self {
        Self {
            revision: Some(revision.into()),
            ..self
        }
    }

    /// Returns these options with `cache_dir` set, replacing any previous value.
    pub fn cache_dir(self, path: impl Into<PathBuf>) -> Self {
        Self {
            cache_dir: Some(path.into()),
            ..self
        }
    }

    /// Returns these options with `token` set, replacing any previous value.
    pub fn token(self, token: impl Into<String>) -> Self {
        Self {
            token: Some(token.into()),
            ..self
        }
    }
}
impl Tokenizer {
pub fn from_pretrained(repo_id: impl AsRef<str>) -> Result<Self, HubError> {
Self::from_pretrained_with_options(repo_id, FromPretrainedOptions::default())
}
pub fn from_pretrained_with_options(
repo_id: impl AsRef<str>,
options: FromPretrainedOptions,
) -> Result<Self, HubError> {
let repo_id = repo_id.as_ref();
let mut api_builder = hf_hub::api::sync::ApiBuilder::new();
if let Some(cache_dir) = options.cache_dir {
api_builder = api_builder.with_cache_dir(cache_dir);
}
if let Some(token) = options.token {
api_builder = api_builder.with_token(Some(token));
}
let api = api_builder.build().map_err(HubError::ApiInit)?;
let repo = if let Some(revision) = options.revision {
Repo::with_revision(repo_id.to_string(), hf_hub::RepoType::Model, revision)
} else {
Repo::model(repo_id.to_string())
};
let repo_api = api.repo(repo);
if let Ok(tkz_path) = repo_api.get("tokenizer.tkz") {
let mut tokenizer = Self::from_file(tkz_path).map_err(HubError::LoadBinary)?;
load_added_tokens_from_json(&mut tokenizer, &repo_api);
return Ok(tokenizer);
}
if let Some(tokiers_name) = tokiers_repo_name(repo_id) {
let tokiers_repo = Repo::model(format!("tokiers/{tokiers_name}"));
let tokiers_api = api.repo(tokiers_repo);
if let Ok(tkz_path) = tokiers_api.get("tokenizer.tkz") {
let mut tokenizer = Self::from_file(tkz_path).map_err(HubError::LoadBinary)?;
load_added_tokens_from_json(&mut tokenizer, &repo_api);
return Ok(tokenizer);
}
}
let tokenizer_path = repo_api.get("tokenizer.json").map_err(HubError::Download)?;
Self::from_json(tokenizer_path).map_err(HubError::Load)
}
}
/// Best-effort overlay of added and special tokens from the repository's
/// `tokenizer.json` onto `tokenizer` (used after loading a `.tkz` file).
///
/// Reads the top-level `added_tokens` array. Each entry needs two fields:
/// an `id` that is a non-negative integer fitting in `crate::types::TokenId`,
/// and a string `content`. Entries that don't qualify are skipped. From the
/// valid entries:
///
/// * those whose `content` is at least 2 bytes long are passed to
///   `set_added_tokens` as `(id, content bytes)`;
/// * those with `"special": true` (of any length) are passed to
///   `set_special_tokens` as `(content, id)`.
///
/// Each setter is called only when its list is non-empty, added tokens first.
/// If the JSON cannot be downloaded, read or parsed, or has no `added_tokens`
/// array, the function returns early and leaves `tokenizer` untouched.
fn load_added_tokens_from_json(tokenizer: &mut Tokenizer, repo_api: &hf_hub::api::sync::ApiRepo) {
    let Ok(json_path) = repo_api.get("tokenizer.json") else { return };
    let Ok(json_bytes) = std::fs::read(&json_path) else { return };
    let Ok(data) = serde_json::from_slice::<serde_json::Value>(&json_bytes) else { return };
    let Some(added) = data["added_tokens"].as_array() else { return };

    let mut tokens: Vec<(crate::types::TokenId, Vec<u8>)> = Vec::new();
    let mut special: Vec<(String, crate::types::TokenId)> = Vec::new();
    for entry in added {
        let Some((id, content)) = parse_added_token(entry) else { continue };
        // NOTE(review): contents shorter than 2 bytes are not registered as
        // added tokens; presumably single-byte tokens are already covered by
        // the base vocabulary — confirm.
        if content.len() >= 2 {
            tokens.push((id, content.as_bytes().to_vec()));
        }
        if entry["special"].as_bool().unwrap_or(false) {
            special.push((content.to_string(), id));
        }
    }

    if !tokens.is_empty() {
        tokenizer.set_added_tokens(&tokens);
    }
    if !special.is_empty() {
        tokenizer.set_special_tokens(special);
    }
}

/// Extracts `(id, content)` from one `added_tokens` entry.
///
/// Returns `None` when `id` is missing, not a non-negative integer, or too
/// large for `TokenId`, or when `content` is missing or not a string. A plain
/// `as` cast would silently wrap an oversized id onto an unrelated token, so
/// the conversion is checked instead.
fn parse_added_token(entry: &serde_json::Value) -> Option<(crate::types::TokenId, &str)> {
    let raw_id = entry["id"].as_u64()?;
    let id = <crate::types::TokenId as std::convert::TryFrom<u64>>::try_from(raw_id).ok()?;
    let content = entry["content"].as_str()?;
    Some((id, content))
}
/// Maps a HuggingFace repository id to the name of its pre-converted mirror
/// under the `tokiers` organization, or `None` if no mirror is known.
///
/// Matching ignores ASCII case. For models the Hub also serves under a legacy,
/// un-namespaced id (e.g. `gpt2`, served as `openai-community/gpt2`), both
/// spellings are accepted, so either form takes the mirror fast path.
fn tokiers_repo_name(repo_id: &str) -> Option<&'static str> {
    let key = repo_id.to_ascii_lowercase();
    match key.as_str() {
        "alibaba-nlp/gte-qwen2-7b-instruct" => Some("gte-Qwen2-7B-instruct"),
        "baai/bge-base-en-v1.5" => Some("bge-base-en-v1.5"),
        "baai/bge-en-icl" => Some("bge-en-icl"),
        "baai/bge-large-en-v1.5" => Some("bge-large-en-v1.5"),
        "baai/bge-small-en-v1.5" => Some("bge-small-en-v1.5"),
        "cohere/cohere-embed-english-v3.0" => Some("Cohere-embed-english-v3.0"),
        "cohere/cohere-embed-english-light-v3.0" => Some("Cohere-embed-english-light-v3.0"),
        "cohere/cohere-embed-multilingual-v3.0" => Some("Cohere-embed-multilingual-v3.0"),
        "cohere/cohere-embed-multilingual-light-v3.0" => Some("Cohere-embed-multilingual-light-v3.0"),
        "intfloat/e5-small-v2" => Some("e5-small-v2"),
        "intfloat/e5-base-v2" => Some("e5-base-v2"),
        "intfloat/e5-large-v2" => Some("e5-large-v2"),
        "jinaai/jina-embeddings-v2-base-en" => Some("jina-embeddings-v2-base-en"),
        "jinaai/jina-embeddings-v2-base-code" => Some("jina-embeddings-v2-base-code"),
        "jinaai/jina-embeddings-v3" => Some("jina-embeddings-v3"),
        "jinaai/jina-embeddings-v4" => Some("jina-embeddings-v4"),
        "mixedbread-ai/mxbai-embed-large-v1" => Some("mxbai-embed-large-v1"),
        "mixedbread-ai/mxbai-embed-2d-large-v1" => Some("mxbai-embed-2d-large-v1"),
        "mixedbread-ai/mxbai-embed-xsmall-v1" => Some("mxbai-embed-xsmall-v1"),
        "mixedbread-ai/deepset-mxbai-embed-de-large-v1" => Some("deepset-mxbai-embed-de-large-v1"),
        "nomic-ai/nomic-embed-text-v1" => Some("nomic-embed-text-v1"),
        "qwen/qwen3-embedding-0.6b" => Some("Qwen3-Embedding-0.6B"),
        "qwen/qwen3-embedding-4b" => Some("Qwen3-Embedding-4B"),
        "qwen/qwen3-embedding-8b" => Some("Qwen3-Embedding-8B"),
        "sentence-transformers/all-minilm-l6-v2" => Some("all-MiniLM-L6-v2"),
        "sentence-transformers/all-minilm-l12-v2" => Some("all-MiniLM-L12-v2"),
        "sentence-transformers/all-mpnet-base-v2" => Some("all-mpnet-base-v2"),
        "thenlper/gte-small" => Some("gte-small"),
        "thenlper/gte-base" => Some("gte-base"),
        "thenlper/gte-large" => Some("gte-large"),
        "voyageai/voyage-3" => Some("voyage-3"),
        "voyageai/voyage-3-lite" => Some("voyage-3-lite"),
        "voyageai/voyage-3-large" => Some("voyage-3-large"),
        "voyageai/voyage-3.5" => Some("voyage-3.5"),
        "voyageai/voyage-3.5-lite" => Some("voyage-3.5-lite"),
        "voyageai/voyage-code-2" => Some("voyage-code-2"),
        "voyageai/voyage-code-3" => Some("voyage-code-3"),
        "voyageai/voyage-finance-2" => Some("voyage-finance-2"),
        "voyageai/voyage-law-2" => Some("voyage-law-2"),
        "voyageai/voyage-multilingual-2" => Some("voyage-multilingual-2"),
        "voyageai/voyage-multimodal-3" => Some("voyage-multimodal-3"),
        "cross-encoder/ms-marco-minilm-l-4-v2" => Some("ms-marco-MiniLM-L-4-v2"),
        "cross-encoder/ms-marco-minilm-l-6-v2" => Some("ms-marco-MiniLM-L-6-v2"),
        // Legacy un-namespaced ids and their org-namespaced canonical forms.
        "bert-base-uncased" | "google-bert/bert-base-uncased" => Some("bert-base-uncased"),
        "roberta-base" | "facebookai/roberta-base" => Some("roberta-base"),
        "answerdotai/modernbert-base" => Some("ModernBERT-base"),
        "gpt2" | "openai-community/gpt2" => Some("gpt2"),
        "xenova/gpt-4" => Some("cl100k"),
        "xenova/gpt-4o" => Some("o200k"),
        "meta-llama/llama-3.2-1b" => Some("Llama-3.2-1B"),
        "meta-llama/llama-4-scout-17b-16e" => Some("Llama-4-Scout-17B-16E"),
        "codellama/codellama-7b-hf" => Some("CodeLlama-7b-hf"),
        "mistralai/mistral-7b-v0.1" => Some("Mistral-7B-v0.1"),
        "mistralai/mistral-nemo-base-2407" => Some("Mistral-Nemo-Base-2407"),
        "mistralai/mixtral-8x7b-v0.1" => Some("Mixtral-8x7B-v0.1"),
        "microsoft/phi-2" => Some("phi-2"),
        "microsoft/phi-3-mini-4k-instruct" => Some("Phi-3-mini-4k-instruct"),
        "qwen/qwen2-7b" => Some("Qwen2-7B"),
        "t5-base" | "google-t5/t5-base" => Some("t5-base"),
        "xlm-roberta-base" | "facebookai/xlm-roberta-base" => Some("xlm-roberta-base"),
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokiers_repo_name() {
        assert_eq!(tokiers_repo_name("BAAI/bge-base-en-v1.5"), Some("bge-base-en-v1.5"));
        assert_eq!(tokiers_repo_name("baai/bge-base-en-v1.5"), Some("bge-base-en-v1.5"));
        assert_eq!(tokiers_repo_name("sentence-transformers/all-MiniLM-L6-v2"), Some("all-MiniLM-L6-v2"));
        assert_eq!(tokiers_repo_name("openai-community/gpt2"), Some("gpt2"));
        assert_eq!(tokiers_repo_name("meta-llama/Llama-3.2-1B"), Some("Llama-3.2-1B"));
        assert_eq!(tokiers_repo_name("some-random/model"), None);
    }

    // The tests below download from the HuggingFace Hub, so they are opt-in
    // (`cargo test -- --ignored`).

    #[test]
    #[ignore = "requires network access to the HuggingFace Hub"]
    fn test_from_pretrained_gpt2() {
        let tokenizer = Tokenizer::from_pretrained("gpt2").expect("Failed to load GPT-2");
        let tokens = tokenizer.encode("Hello, world!", false);
        assert!(!tokens.ids.is_empty());
        let decoded = tokenizer.decode(&tokens.ids).unwrap();
        assert_eq!(decoded, "Hello, world!");
    }

    #[test]
    #[ignore = "requires network access to the HuggingFace Hub"]
    fn test_from_pretrained_with_revision() {
        let tokenizer = Tokenizer::from_pretrained_with_options(
            "gpt2",
            FromPretrainedOptions::default().revision("main"),
        )
        .expect("Failed to load GPT-2");
        let tokens = tokenizer.encode("Test", false);
        // Same emptiness check as `test_from_pretrained_gpt2`: inspect the ids.
        assert!(!tokens.ids.is_empty());
    }
}