alith_models/local_model/
hf_loader.rs1use anyhow::{Result, anyhow};
3use dotenvy::dotenv;
4use hf_hub::api::sync::{Api, ApiBuilder};
5use std::{cell::OnceCell, path::PathBuf};
6
7const DEFAULT_ENV_VAR: &str = "HUGGING_FACE_TOKEN";
8
9#[derive(Clone)]
10pub struct HuggingFaceLoader {
11 pub hf_token: Option<String>,
12 pub hf_token_env_var: String,
13 pub hf_api: OnceCell<Api>,
14}
15
16impl Default for HuggingFaceLoader {
17 fn default() -> Self {
18 Self::new()
19 }
20}
21
22impl HuggingFaceLoader {
23 pub fn new() -> Self {
24 Self {
25 hf_token: None,
26 hf_token_env_var: DEFAULT_ENV_VAR.to_string(),
27 hf_api: OnceCell::new(),
28 }
29 }
30
31 #[inline]
32 pub fn hf_api(&self) -> &Api {
33 self.hf_api.get_or_init(|| {
34 ApiBuilder::from_env()
35 .with_progress(true)
36 .with_token(self.load_hf_token())
37 .build()
38 .expect("Failed to build Hugging Face API")
39 })
40 }
41
42 fn load_hf_token(&self) -> Option<String> {
43 if let Some(hf_token) = &self.hf_token {
44 return Some(hf_token.to_owned());
45 }
46
47 dotenv().ok();
48
49 match dotenvy::var(&self.hf_token_env_var) {
50 Ok(hf_token) => Some(hf_token),
51 Err(_) => {
52 eprintln!(
53 "{} not found in dotenv, nor was it set manually",
54 self.hf_token_env_var
55 );
56 None
57 }
58 }
59 }
60
61 #[inline]
62 pub fn load_file<T: AsRef<str>, S: Into<String>>(
63 &self,
64 file_name: T,
65 repo_id: S,
66 ) -> Result<PathBuf> {
67 self.hf_api()
68 .model(repo_id.into())
69 .get(file_name.as_ref())
70 .map_err(|e| anyhow!(e))
71 }
72
73 pub fn load_model_safe_tensors<S: Into<String>>(&self, repo_id: S) -> Result<Vec<PathBuf>> {
74 let repo_id = repo_id.into();
75 let mut safe_tensor_filenames = vec![];
76 let siblings = self.hf_api().model(repo_id.clone()).info()?.siblings;
77 for sib in siblings {
78 if sib.rfilename.ends_with(".safetensors") {
79 safe_tensor_filenames.push(sib.rfilename);
80 }
81 }
82 let mut safe_tensor_paths = vec![];
83 for safe_tensor_filename in &safe_tensor_filenames {
84 let safe_tensor_path = self
85 .hf_api()
86 .model(repo_id.clone())
87 .get(safe_tensor_filename)
88 .map_err(|e| anyhow!(e))?;
89 let safe_tensor_path = Self::canonicalize_local_path(safe_tensor_path)?;
90 safe_tensor_paths.push(safe_tensor_path);
91 }
92 Ok(safe_tensor_paths)
93 }
94
95 #[inline]
96 pub fn canonicalize_local_path(local_path: PathBuf) -> Result<PathBuf> {
97 local_path.canonicalize().map_err(|e| anyhow!(e))
98 }
99
100 pub fn parse_full_model_url(model_url: &str) -> (String, String, String) {
101 if !model_url.starts_with("https://huggingface.co") {
102 panic!(
103 "URL does not start with https://huggingface.co\n Format should be like: https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/blob/main/zephyr-7b-alpha.Q8_0.gguf"
104 );
105 } else if !model_url.ends_with(".gguf") {
106 panic!(
107 "URL does not end with .gguf\n Format should be like: https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/blob/main/zephyr-7b-alpha.Q8_0.gguf"
108 );
109 } else {
110 let parts: Vec<&str> = model_url.split('/').collect();
111 if parts.len() < 5 {
112 panic!(
113 "URL does not have enough parts\n Format should be like: https://huggingface.co/TheBloke/zephyr-7B-alpha-GGUF/blob/main/zephyr-7b-alpha.Q8_0.gguf"
114 );
115 }
116 let model_id = parts[4].to_string();
117 let repo_id = format!("{}/{}", parts[3], parts[4]);
118 let gguf_model_filename = parts.last().unwrap_or(&"").to_string();
119 (model_id, repo_id, gguf_model_filename)
120 }
121 }
122
123 pub fn model_url_from_repo_and_local_filename(
124 repo_id: &str,
125 local_model_filename: &str,
126 ) -> String {
127 let filename = std::path::Path::new(local_model_filename)
128 .file_name()
129 .and_then(|os_str| os_str.to_str())
130 .unwrap_or(local_model_filename);
131
132 format!("https://huggingface.co/{}/blob/main/{}", repo_id, filename)
133 }
134
135 pub fn model_url_from_repo(repo_id: &str) -> String {
136 format!("https://huggingface.co/{}", repo_id)
137 }
138
139 pub fn model_id_from_url(model_url: &str) -> String {
140 let parts = Self::parse_full_model_url(model_url);
141 parts.0
142 }
143}
144
145impl HfTokenTrait for HuggingFaceLoader {
146 fn hf_token_mut(&mut self) -> &mut Option<String> {
147 &mut self.hf_token
148 }
149
150 fn hf_token_env_var_mut(&mut self) -> &mut String {
151 &mut self.hf_token_env_var
152 }
153}
154
/// Builder-style setters for types that store a Hugging Face access token and
/// the name of the environment variable the token may be read from.
///
/// Implementors supply the two mutable accessors; the consuming, chainable
/// setters `hf_token` and `hf_token_env_var` are provided on top of them.
pub trait HfTokenTrait {
    /// Mutable access to the explicitly configured token.
    fn hf_token_mut(&mut self) -> &mut Option<String>;

    /// Mutable access to the name of the token environment variable.
    fn hf_token_env_var_mut(&mut self) -> &mut String;

    /// Stores `hf_token` as the explicit token and returns `self` for chaining.
    fn hf_token<S: Into<String>>(mut self, hf_token: S) -> Self
    where
        Self: Sized,
    {
        let token: String = hf_token.into();
        let _ = self.hf_token_mut().replace(token);
        self
    }

    /// Stores `hf_token_env_var` as the token environment-variable name and
    /// returns `self` for chaining.
    fn hf_token_env_var<S: Into<String>>(mut self, hf_token_env_var: S) -> Self
    where
        Self: Sized,
    {
        let var_name: String = hf_token_env_var.into();
        let slot = self.hf_token_env_var_mut();
        *slot = var_name;
        self
    }
}