Skip to main content

alith_models/local_model/
hf_loader.rs

//! Hugging Face Hub model loader. Downloaded files are stored in the Hub cache, e.g. "/root/.cache/huggingface/hub/".
2use anyhow::{Result, anyhow};
3use dotenvy::dotenv;
4use hf_hub::api::sync::{Api, ApiBuilder};
5use std::{cell::OnceCell, path::PathBuf};
6
/// Name of the environment variable (or `.env` key) consulted for the Hugging Face
/// token when no token has been set explicitly on the loader.
const DEFAULT_ENV_VAR: &str = "HUGGING_FACE_TOKEN";
8
9#[derive(Clone)]
10pub struct HuggingFaceLoader {
11    pub hf_token: Option<String>,
12    pub hf_token_env_var: String,
13    pub hf_api: OnceCell<Api>,
14}
15
16impl Default for HuggingFaceLoader {
17    fn default() -> Self {
18        Self::new()
19    }
20}
21
22impl HuggingFaceLoader {
23    pub fn new() -> Self {
24        Self {
25            hf_token: None,
26            hf_token_env_var: DEFAULT_ENV_VAR.to_string(),
27            hf_api: OnceCell::new(),
28        }
29    }
30
31    #[inline]
32    pub fn hf_api(&self) -> &Api {
33        self.hf_api.get_or_init(|| {
34            ApiBuilder::from_env()
35                .with_progress(true)
36                .with_token(self.load_hf_token())
37                .build()
38                .expect("Failed to build Hugging Face API")
39        })
40    }
41
42    fn load_hf_token(&self) -> Option<String> {
43        if let Some(hf_token) = &self.hf_token {
44            return Some(hf_token.to_owned());
45        }
46
47        dotenv().ok();
48
49        match dotenvy::var(&self.hf_token_env_var) {
50            Ok(hf_token) => Some(hf_token),
51            Err(_) => {
52                eprintln!(
53                    "{} not found in dotenv, nor was it set manually",
54                    self.hf_token_env_var
55                );
56                None
57            }
58        }
59    }
60
61    #[inline]
62    pub fn load_file<T: AsRef<str>, S: Into<String>>(
63        &self,
64        file_name: T,
65        repo_id: S,
66    ) -> Result<PathBuf> {
67        self.hf_api()
68            .model(repo_id.into())
69            .get(file_name.as_ref())
70            .map_err(|e| anyhow!(e))
71    }
72
73    pub fn load_model_safe_tensors<S: Into<String>>(&self, repo_id: S) -> Result<Vec<PathBuf>> {
74        let repo_id = repo_id.into();
75        let mut safe_tensor_filenames = vec![];
76        let siblings = self.hf_api().model(repo_id.clone()).info()?.siblings;
77        for sib in siblings {
78            if sib.rfilename.ends_with(".safetensors") {
79                safe_tensor_filenames.push(sib.rfilename);
80            }
81        }
82        let mut safe_tensor_paths = vec![];
83        for safe_tensor_filename in &safe_tensor_filenames {
84            let safe_tensor_path = self
85                .hf_api()
86                .model(repo_id.clone())
87                .get(safe_tensor_filename)
88                .map_err(|e| anyhow!(e))?;
89            let safe_tensor_path = Self::canonicalize_local_path(safe_tensor_path)?;
90            safe_tensor_paths.push(safe_tensor_path);
91        }
92        Ok(safe_tensor_paths)
93    }
94
95    #[inline]
96    pub fn canonicalize_local_path(local_path: PathBuf) -> Result<PathBuf> {
97        local_path.canonicalize().map_err(|e| anyhow!(e))
98    }
99
100    pub fn parse_full_model_url(model_url: &str) -> (String, String, String) {
101        if !model_url.starts_with("https://huggingface.co") {
102            panic!(
103                "URL does not start with https://huggingface.co\n Format should be like: https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/blob/main/zephyr-7b-alpha.Q8_0.gguf"
104            );
105        } else if !model_url.ends_with(".gguf") {
106            panic!(
107                "URL does not end with .gguf\n Format should be like: https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/blob/main/zephyr-7b-alpha.Q8_0.gguf"
108            );
109        } else {
110            let parts: Vec<&str> = model_url.split('/').collect();
111            if parts.len() < 5 {
112                panic!(
113                    "URL does not have enough parts\n Format should be like: https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/blob/main/zephyr-7b-alpha.Q8_0.gguf"
114                );
115            }
116            let model_id = parts[4].to_string();
117            let repo_id = format!("{}/{}", parts[3], parts[4]);
118            let gguf_model_filename = parts.last().unwrap_or(&"").to_string();
119            (model_id, repo_id, gguf_model_filename)
120        }
121    }
122
123    pub fn model_url_from_repo_and_local_filename(
124        repo_id: &str,
125        local_model_filename: &str,
126    ) -> String {
127        let filename = std::path::Path::new(local_model_filename)
128            .file_name()
129            .and_then(|os_str| os_str.to_str())
130            .unwrap_or(local_model_filename);
131
132        format!("https://huggingface.co/{}/blob/main/{}", repo_id, filename)
133    }
134
135    pub fn model_url_from_repo(repo_id: &str) -> String {
136        format!("https://huggingface.co/{}", repo_id)
137    }
138
139    pub fn model_id_from_url(model_url: &str) -> String {
140        let parts = Self::parse_full_model_url(model_url);
141        parts.0
142    }
143}
144
145impl HfTokenTrait for HuggingFaceLoader {
146    fn hf_token_mut(&mut self) -> &mut Option<String> {
147        &mut self.hf_token
148    }
149
150    fn hf_token_env_var_mut(&mut self) -> &mut String {
151        &mut self.hf_token_env_var
152    }
153}
154
/// Builder-style setters for types that hold a Hugging Face token and the name
/// of the environment variable that token is read from.
pub trait HfTokenTrait {
    /// Mutable access to the explicitly configured token slot.
    fn hf_token_mut(&mut self) -> &mut Option<String>;

    /// Mutable access to the token environment-variable name.
    fn hf_token_env_var_mut(&mut self) -> &mut String;

    /// Set the API key for the client. Otherwise it will attempt to load it from the .env file.
    fn hf_token<S: Into<String>>(mut self, hf_token: S) -> Self
    where
        Self: Sized,
    {
        let slot = self.hf_token_mut();
        *slot = Some(hf_token.into());
        self
    }

    /// Set the environment variable name for the API key. Default is "HUGGING_FACE_TOKEN".
    fn hf_token_env_var<S: Into<String>>(mut self, hf_token_env_var: S) -> Self
    where
        Self: Sized,
    {
        let slot = self.hf_token_env_var_mut();
        *slot = hf_token_env_var.into();
        self
    }
}