use crate::Result;
use cached_path::CacheBuilder;
use itertools::Itertools;
use reqwest::{blocking::Client, header};
use std::borrow::Cow;
use std::collections::HashMap;
use std::path::PathBuf;
/// Returns the directory used to cache files downloaded from the Hub.
///
/// If the `TOKENIZERS_CACHE` environment variable is set to a non-empty value,
/// that path is used as-is. Otherwise the platform cache directory is used
/// (falling back to the system temporary directory when none is known), with
/// `huggingface/tokenizers` appended.
fn cache_dir() -> PathBuf {
    // `var_os` (unlike `var`) also accepts values that are not valid UTF-8,
    // which are legal paths on Unix; `var` would silently ignore them.
    // An empty value is treated as unset so files never land in the CWD.
    match std::env::var_os("TOKENIZERS_CACHE") {
        Some(path) if !path.is_empty() => PathBuf::from(path),
        _ => {
            let mut dir = dirs::cache_dir().unwrap_or_else(std::env::temp_dir);
            dir.push("huggingface");
            dir.push("tokenizers");
            dir
        }
    }
}
/// Resolves the cache directory and makes sure it exists on disk,
/// creating any missing parent directories; returns the resolved path.
fn ensure_cache_dir() -> std::io::Result<PathBuf> {
    let path = cache_dir();
    std::fs::create_dir_all(&path).map(|()| path)
}
/// Replaces every `/` and `;` in `item` with `-`, since those characters
/// delimit entries in the user-agent string.
///
/// Returns the input borrowed when nothing needs replacing, and only
/// allocates when at least one reserved character is present.
fn sanitize_user_agent(item: &str) -> Cow<str> {
    let is_reserved = |c: char| c == '/' || c == ';';
    if item.contains(is_reserved) {
        Cow::Owned(item.replace(is_reserved, "-"))
    } else {
        Cow::Borrowed(item)
    }
}
/// Version of this crate, taken from Cargo at compile time and reported in the
/// `User-Agent` header as `tokenizers/<VERSION>`.
const VERSION: &str = env!("CARGO_PKG_VERSION");
#[allow(unstable_name_collisions)]
fn user_agent(additional_info: HashMap<String, String>) -> String {
let additional_str: String = additional_info
.iter()
.map(|(k, v)| format!("{}/{}", sanitize_user_agent(k), sanitize_user_agent(v)))
.intersperse("; ".to_string())
.collect();
let user_agent = format!(
"tokenizers/{}{}",
VERSION,
if !additional_str.is_empty() {
format!("; {}", additional_str)
} else {
String::new()
}
);
user_agent
}
/// Options controlling how [`from_pretrained`] fetches a tokenizer from the Hub.
///
/// `Debug` is implemented by hand so that `auth_token` is never printed in
/// clear text (e.g. when the parameters end up in logs).
#[derive(Clone)]
pub struct FromPretrainedParameters {
    /// Revision inserted into the `resolve/<revision>/` download URL; `"main"` by default.
    pub revision: String,
    /// Extra `key/value` pairs appended to the `User-Agent` header.
    pub user_agent: HashMap<String, String>,
    /// Optional token, sent as `Authorization: Bearer <token>` when present.
    pub auth_token: Option<String>,
}

impl std::fmt::Debug for FromPretrainedParameters {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FromPretrainedParameters")
            .field("revision", &self.revision)
            .field("user_agent", &self.user_agent)
            // Only reveal whether a token is set, never its value.
            .field("auth_token", &self.auth_token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}

impl Default for FromPretrainedParameters {
    /// Targets the `"main"` revision, with no extra user-agent info and no auth token.
    fn default() -> Self {
        Self {
            revision: "main".into(),
            user_agent: HashMap::new(),
            auth_token: None,
        }
    }
}
/// Downloads (or reuses a cached copy of) `tokenizer.json` for a model on the
/// Hugging Face Hub and returns the local path to it.
///
/// `identifier` is inserted verbatim into the Hub URL
/// (`https://huggingface.co/<identifier>/resolve/<revision>/tokenizer.json`).
/// When `params` is `None`, [`FromPretrainedParameters::default`] is used.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - the cache directory cannot be created;
/// - the auth token is not a valid header value;
/// - the download cache cannot be built;
/// - the file cannot be fetched.
///
/// In the last case the message includes the underlying cause, so network or
/// authentication failures are not mistaken for a missing tokenizer.
pub fn from_pretrained<S: AsRef<str>>(
    identifier: S,
    params: Option<FromPretrainedParameters>,
) -> Result<PathBuf> {
    let identifier = identifier.as_ref();
    let params = params.unwrap_or_default();
    let cache_dir = ensure_cache_dir()?;

    let mut headers = header::HeaderMap::new();
    if let Some(token) = &params.auth_token {
        let mut value = header::HeaderValue::from_str(&format!("Bearer {}", token))?;
        // Mark the credential sensitive so `Debug` output of the header
        // value does not reveal it.
        value.set_sensitive(true);
        headers.insert(header::AUTHORIZATION, value);
    }

    let client_builder = Client::builder()
        .user_agent(user_agent(params.user_agent))
        .default_headers(headers);
    let cache = CacheBuilder::with_client_builder(client_builder)
        .dir(cache_dir)
        .build()?;

    let url_to_download = format!(
        "https://huggingface.co/{}/resolve/{}/tokenizer.json",
        identifier, params.revision,
    );
    // Keep the historical message prefix, but surface the real cause instead of
    // discarding it.
    cache.cached_path(&url_to_download).map_err(|err| {
        format!(
            "Model \"{}\" on the Hub doesn't have a tokenizer ({})",
            identifier, err
        )
        .into()
    })
}