tokenizers/utils/
from_pretrained.rs1use crate::Result;
2use hf_hub_enfer::{api::sync::ApiBuilder, Repo, RepoType};
3use std::collections::HashMap;
4use std::path::PathBuf;
5
/// Settings consumed by [`from_pretrained`] when fetching a tokenizer file.
#[derive(Clone, Debug)]
pub struct FromPretrainedParameters {
    /// Repository revision to fetch from; it is validated against the same
    /// character set as the model identifier before use.
    pub revision: String,
    /// Extra key/value pairs, presumably meant for the request's user agent.
    /// NOTE(review): the visible `from_pretrained` does not read this field —
    /// confirm whether it is consumed elsewhere.
    pub user_agent: HashMap<String, String>,
    /// Optional access token; when `Some`, it is handed to the API builder.
    pub token: Option<String>,
}
13
14impl Default for FromPretrainedParameters {
15 fn default() -> Self {
16 Self {
17 revision: "main".into(),
18 user_agent: HashMap::new(),
19 token: None,
20 }
21 }
22}
23
24pub fn from_pretrained<S: AsRef<str>>(
27 identifier: S,
28 params: Option<FromPretrainedParameters>,
29) -> Result<PathBuf> {
30 let identifier: String = identifier.as_ref().to_string();
31
32 let valid_chars = ['-', '_', '.', '/'];
33 let is_valid_char = |x: char| x.is_alphanumeric() || valid_chars.contains(&x);
34
35 let valid = identifier.chars().all(is_valid_char);
36 let valid_chars_stringified = valid_chars
37 .iter()
38 .fold(vec![], |mut buf, x| {
39 buf.push(format!("'{}'", x));
40 buf
41 })
42 .join(", "); if !valid {
44 return Err(format!(
45 "Model \"{}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}",
46 identifier
47 )
48 .into());
49 }
50 let params = params.unwrap_or_default();
51
52 let revision = ¶ms.revision;
53 let valid_revision = revision.chars().all(is_valid_char);
54 if !valid_revision {
55 return Err(format!(
56 "Revision \"{}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}",
57 revision
58 )
59 .into());
60 }
61
62 let mut builder = ApiBuilder::new();
63 if let Some(token) = params.token {
64 builder = builder.with_token(Some(token));
65 }
66 let api = builder.build()?;
67 let repo = Repo::with_revision(identifier, RepoType::Model, params.revision);
68 let api = api.repo(repo);
69 Ok(api.get("tokenizer.json")?)
70}