tokenizers/utils/from_pretrained.rs

1use crate::Result;
2use hf_hub_enfer::{api::sync::ApiBuilder, Repo, RepoType};
3use std::collections::HashMap;
4use std::path::PathBuf;
5
/// Defines the additional parameters available for the `from_pretrained` function.
#[derive(Debug, Clone)]
pub struct FromPretrainedParameters {
    /// Revision of the Hub repository to download from.
    /// The `Default` implementation sets this to `"main"`.
    pub revision: String,
    /// Extra user-agent entries (key/value pairs).
    ///
    /// NOTE(review): `from_pretrained` does not currently forward these to the
    /// Hub client; confirm whether that is intended.
    pub user_agent: HashMap<String, String>,
    /// Optional access token. When `Some`, it is passed to the Hub API builder.
    pub token: Option<String>,
}
13
14impl Default for FromPretrainedParameters {
15    fn default() -> Self {
16        Self {
17            revision: "main".into(),
18            user_agent: HashMap::new(),
19            token: None,
20        }
21    }
22}
23
24/// Downloads and cache the identified tokenizer if it exists on
25/// the Hugging Face Hub, and returns a local path to the file
26pub fn from_pretrained<S: AsRef<str>>(
27    identifier: S,
28    params: Option<FromPretrainedParameters>,
29) -> Result<PathBuf> {
30    let identifier: String = identifier.as_ref().to_string();
31
32    let valid_chars = ['-', '_', '.', '/'];
33    let is_valid_char = |x: char| x.is_alphanumeric() || valid_chars.contains(&x);
34
35    let valid = identifier.chars().all(is_valid_char);
36    let valid_chars_stringified = valid_chars
37        .iter()
38        .fold(vec![], |mut buf, x| {
39            buf.push(format!("'{}'", x));
40            buf
41        })
42        .join(", "); // "'/', '-', '_', '.'"
43    if !valid {
44        return Err(format!(
45            "Model \"{}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}",
46            identifier
47        )
48        .into());
49    }
50    let params = params.unwrap_or_default();
51
52    let revision = &params.revision;
53    let valid_revision = revision.chars().all(is_valid_char);
54    if !valid_revision {
55        return Err(format!(
56            "Revision \"{}\" contains invalid characters, expected only alphanumeric or {valid_chars_stringified}",
57            revision
58        )
59        .into());
60    }
61
62    let mut builder = ApiBuilder::new();
63    if let Some(token) = params.token {
64        builder = builder.with_token(Some(token));
65    }
66    let api = builder.build()?;
67    let repo = Repo::with_revision(identifier, RepoType::Model, params.revision);
68    let api = api.repo(repo);
69    Ok(api.get("tokenizer.json")?)
70}