use crate::Result;
use hf_hub::{api::sync::ApiBuilder, Repo, RepoType};
use std::collections::HashMap;
use std::path::PathBuf;
/// Options controlling how `from_pretrained` locates a model on the Hub.
#[derive(Debug, Clone)]
pub struct FromPretrainedParameters {
    /// Repository revision handed to `hf_hub::Repo::with_revision`; `"main"` by default.
    pub revision: String,
    /// Extra user-agent key/value pairs.
    /// NOTE(review): `from_pretrained` does not currently read this field — confirm
    /// whether it is meant to be forwarded to the HTTP client.
    pub user_agent: HashMap<String, String>,
    /// Optional access token; when `Some`, it is passed to the API builder's `with_token`.
    pub token: Option<String>,
}

impl Default for FromPretrainedParameters {
    /// Targets the `"main"` revision with no user-agent entries and no token.
    fn default() -> Self {
        let revision = String::from("main");
        let user_agent = HashMap::default();
        let token = Option::None;
        FromPretrainedParameters {
            revision,
            user_agent,
            token,
        }
    }
}
/// Resolves the `tokenizer.json` file of a Hugging Face Hub model to a local
/// path using `hf_hub`'s synchronous API.
///
/// # Arguments
///
/// * `identifier` - Hub model id, e.g. `"org/model"`.
/// * `params` - Download options; `None` falls back to
///   `FromPretrainedParameters::default()` (revision `"main"`, no token).
///
/// # Errors
///
/// Returns an error, before the `hf_hub` client is built, when the model id or
/// the revision is empty, contains a character that is neither alphanumeric
/// nor one of `'-'`, `'_'`, `'.'`, `'/'`, or contains `".."`. Errors from
/// building the `hf_hub` API client or from fetching the file are propagated.
pub fn from_pretrained<S: AsRef<str>>(
    identifier: S,
    params: Option<FromPretrainedParameters>,
) -> Result<PathBuf> {
    /// Rejects values that cannot name a Hub repo or revision; `kind`
    /// ("Model" / "Revision") prefixes the error message.
    fn validate_name(kind: &str, value: &str) -> Result<()> {
        // Non-alphanumeric characters accepted in model ids and revisions.
        const VALID_CHARS: [char; 4] = ['-', '_', '.', '/'];

        // An empty name would otherwise pass the `all()` check below and
        // produce a request for a nonexistent resource.
        if value.is_empty() {
            return Err(format!("{kind} must be a non-empty string").into());
        }

        let is_valid_char = |c: char| c.is_alphanumeric() || VALID_CHARS.contains(&c);
        if !value.chars().all(is_valid_char) {
            // Built only on the error path; renders as "'-', '_', '.', '/'".
            let allowed = VALID_CHARS
                .iter()
                .map(|c| format!("'{c}'"))
                .collect::<Vec<_>>()
                .join(", ");
            return Err(format!(
                "{kind} \"{value}\" contains invalid characters, expected only alphanumeric or {allowed}"
            )
            .into());
        }

        // Hub repo ids and git ref names both forbid "..", and rejecting it
        // keeps the value from acting as a parent-directory component if it
        // is ever joined into a URL or filesystem path.
        if value.contains("..") {
            return Err(format!("{kind} \"{value}\" must not contain \"..\"").into());
        }
        Ok(())
    }

    let identifier: String = identifier.as_ref().to_string();
    validate_name("Model", &identifier)?;

    let params = params.unwrap_or_default();
    validate_name("Revision", &params.revision)?;

    // NOTE(review): `params.user_agent` is never read here, so user-agent
    // entries supplied by callers are currently ignored.
    let mut builder = ApiBuilder::from_env();
    // Only override the builder's token when the caller supplied one, so a
    // `None` here does not replace whatever `from_env` configured.
    if let Some(token) = params.token {
        builder = builder.with_token(Some(token));
    }
    let api = builder.build()?;
    let repo = Repo::with_revision(identifier, RepoType::Model, params.revision);
    let api = api.repo(repo);
    Ok(api.get("tokenizer.json")?)
}