1#![deny(missing_docs)]
2#![cfg_attr(feature="ureq", doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md")))]
3#![cfg_attr(
4 not(feature = "ureq"),
5 doc = "Documentation is meant to be compiled with default features (at least ureq)"
6)]
7use std::io::Write;
8use std::path::PathBuf;
9
/// API module; compiled only when the `tokio` or `ureq` feature is enabled.
#[cfg(any(feature = "tokio", feature = "ureq"))]
pub mod api;

/// Name of the environment variable read by [`Cache::from_env`]; when set, the cache
/// directory is `$HF_HOME/hub` and the token file is `$HF_HOME/token`.
const HF_HOME: &str = "HF_HOME";
15
/// The kind of repository a [`Repo`] refers to.
///
/// It selects the plural prefix (`models`, `datasets`, `spaces`) used in cache folder
/// names and API paths.
#[derive(Copy, Clone, Debug)]
pub enum RepoType {
    /// A model repository.
    Model,
    /// A dataset repository.
    Dataset,
    /// A Space repository.
    Space,
}
27
/// Handle on the local hub cache directory.
///
/// `path` is the hub directory itself (by default `~/.cache/huggingface/hub`); the
/// token file lives in its parent directory.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Root directory of the hub cache.
    path: PathBuf,
}
33
34impl Cache {
35 pub fn new(path: PathBuf) -> Self {
37 Self { path }
38 }
39
40 pub fn from_env() -> Self {
43 match std::env::var(HF_HOME) {
44 Ok(home) => {
45 let mut path: PathBuf = home.into();
46 path.push("hub");
47 Self::new(path)
48 }
49 Err(_) => Self::default(),
50 }
51 }
52
53 pub fn path(&self) -> &PathBuf {
55 &self.path
56 }
57
58 pub fn token_path(&self) -> PathBuf {
60 let mut path = self.path.clone();
61 path.pop();
63 path.push("token");
64 path
65 }
66
67 pub fn token(&self) -> Option<String> {
70 let token_filename = self.token_path();
71 if token_filename.exists() {
72 log::info!("Using token file found {token_filename:?}");
73 }
74 match std::fs::read_to_string(token_filename) {
75 Ok(token_content) => {
76 let token_content = token_content.trim();
77 if token_content.is_empty() {
78 None
79 } else {
80 Some(token_content.to_string())
81 }
82 }
83 Err(_) => None,
84 }
85 }
86
87 pub fn repo(&self, repo: Repo) -> CacheRepo {
90 CacheRepo::new(self.clone(), repo)
91 }
92
93 pub fn model(&self, model_id: String) -> CacheRepo {
101 self.repo(Repo::new(model_id, RepoType::Model))
102 }
103
104 pub fn dataset(&self, model_id: String) -> CacheRepo {
112 self.repo(Repo::new(model_id, RepoType::Dataset))
113 }
114
115 pub fn space(&self, model_id: String) -> CacheRepo {
123 self.repo(Repo::new(model_id, RepoType::Space))
124 }
125}
126
127#[derive(Debug)]
129pub struct CacheRepo {
130 cache: Cache,
131 repo: Repo,
132}
133
134impl CacheRepo {
135 fn new(cache: Cache, repo: Repo) -> Self {
136 Self { cache, repo }
137 }
138
139 pub fn get(&self, filename: &str) -> Option<PathBuf> {
142 let commit_path = self.ref_path();
143 let commit_hash = std::fs::read_to_string(commit_path).ok()?;
144 let mut pointer_path = self.pointer_path(&commit_hash);
145 pointer_path.push(filename);
146 if pointer_path.exists() {
147 Some(pointer_path)
148 } else {
149 None
150 }
151 }
152
153 fn path(&self) -> PathBuf {
154 let mut ref_path = self.cache.path.clone();
155 ref_path.push(self.repo.folder_name());
156 ref_path
157 }
158
159 fn ref_path(&self) -> PathBuf {
160 let mut ref_path = self.path();
161 ref_path.push("refs");
162 ref_path.push(self.repo.revision());
163 ref_path
164 }
165
166 pub fn create_ref(&self, commit_hash: &str) -> Result<(), std::io::Error> {
169 let ref_path = self.ref_path();
170 std::fs::create_dir_all(ref_path.parent().unwrap())?;
172 let mut file = std::fs::OpenOptions::new()
173 .write(true)
174 .create(true)
175 .truncate(true)
176 .open(&ref_path)?;
177 file.write_all(commit_hash.trim().as_bytes())?;
178 Ok(())
179 }
180
181 #[cfg(any(feature = "tokio", feature = "ureq"))]
183 pub fn blob_path(&self, etag: &str) -> PathBuf {
184 let mut blob_path = self.path();
185 blob_path.push("blobs");
186 blob_path.push(etag);
187 blob_path
188 }
189
190
191 pub fn pointer_path(&self, commit_hash: &str) -> PathBuf {
195 let mut pointer_path = self.path();
196 pointer_path.push("snapshots");
197 pointer_path.push(commit_hash);
198 pointer_path
199 }
200}
201
202impl Default for Cache {
203 fn default() -> Self {
204 let mut path = dirs::home_dir().expect("Cache directory cannot be found");
205 path.push(".cache");
206 path.push("huggingface");
207 path.push("hub");
208 Self::new(path)
209 }
210}
211
212#[derive(Clone, Debug)]
214pub struct Repo {
215 repo_id: String,
216 repo_type: RepoType,
217 revision: String,
218}
219
220impl Repo {
221 pub fn new(repo_id: String, repo_type: RepoType) -> Self {
223 Self::with_revision(repo_id, repo_type, "main".to_string())
224 }
225
226 pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
228 Self {
229 repo_id,
230 repo_type,
231 revision,
232 }
233 }
234
235 pub fn model(repo_id: String) -> Self {
237 Self::new(repo_id, RepoType::Model)
238 }
239
240 pub fn dataset(repo_id: String) -> Self {
242 Self::new(repo_id, RepoType::Dataset)
243 }
244
245 pub fn space(repo_id: String) -> Self {
247 Self::new(repo_id, RepoType::Space)
248 }
249
250 pub fn folder_name(&self) -> String {
252 let prefix = match self.repo_type {
253 RepoType::Model => "models",
254 RepoType::Dataset => "datasets",
255 RepoType::Space => "spaces",
256 };
257 format!("{prefix}--{}", self.repo_id).replace('/', "--")
258 }
259
260 pub fn revision(&self) -> &str {
262 &self.revision
263 }
264
265 #[cfg(any(feature = "tokio", feature = "ureq"))]
267 pub fn url(&self) -> String {
268 match self.repo_type {
269 RepoType::Model => self.repo_id.to_string(),
270 RepoType::Dataset => {
271 format!("datasets/{}", self.repo_id)
272 }
273 RepoType::Space => {
274 format!("spaces/{}", self.repo_id)
275 }
276 }
277 }
278
279 #[cfg(any(feature = "tokio", feature = "ureq"))]
281 pub fn url_revision(&self) -> String {
282 self.revision.replace('/', "%2F")
283 }
284
285 #[cfg(any(feature = "tokio", feature = "ureq"))]
287 pub fn api_url(&self) -> String {
288 let prefix = match self.repo_type {
289 RepoType::Model => "models",
290 RepoType::Dataset => "datasets",
291 RepoType::Space => "spaces",
292 };
293 format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
294 }
295}
296
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// Asserts that two serializable values have identical pretty-printed JSON.
    ///
    /// On mismatch, both serializations are written to files in the platform temp
    /// directory. The test then panics with a unified `diff -U5` of them (`$right`
    /// first, `$left` second). This requires a `diff` executable on `PATH`; the
    /// temporary files are removed once the diff has been taken.
    #[macro_export]
    macro_rules! assert_no_diff {
        ($left: expr, $right: expr) => {
            let left = serde_json::to_string_pretty(&$left).unwrap();
            let right = serde_json::to_string_pretty(&$right).unwrap();
            if left != right {
                use rand::Rng;
                use std::process::Command;
                let rand_string: String = rand::rng()
                    .sample_iter(&rand::distr::Alphanumeric)
                    .take(6)
                    .map(char::from)
                    .collect();
                // Use the platform temp dir rather than a hard-coded `/tmp`.
                let tmp_dir = std::env::temp_dir();
                let left_filename = tmp_dir.join(format!("left-{rand_string}.txt"));
                std::fs::write(&left_filename, left.as_bytes()).unwrap();
                let right_filename = tmp_dir.join(format!("right-{rand_string}.txt"));
                std::fs::write(&right_filename, right.as_bytes()).unwrap();
                let output = Command::new("diff")
                    .arg("-U5")
                    .arg(&right_filename)
                    .arg(&left_filename)
                    .output()
                    .expect("Failed to diff")
                    .stdout;
                // Best-effort cleanup; a leftover temp file must not mask the real failure.
                let _ = std::fs::remove_file(&left_filename);
                let _ = std::fs::remove_file(&right_filename);
                let diff = String::from_utf8(output).expect("Invalid utf-8 diff output");
                panic!("{diff}")
            };
        };
    }

    /// The token file sits next to the hub directory, on every platform.
    #[test]
    fn token_path() {
        let cache = Cache::from_env();
        let token_path = cache.token_path();
        match std::env::var(HF_HOME) {
            // Compare as paths so separators and trailing slashes in HF_HOME do not matter.
            Ok(hf_home) => assert_eq!(token_path, PathBuf::from(hf_home).join("token")),
            Err(_) => assert!(
                token_path.ends_with(Path::new("huggingface").join("token")),
                "unexpected token path {token_path:?}"
            ),
        }
    }
}