Skip to main content

hf_hub/
lib.rs

1#![deny(missing_docs)]
2#![cfg_attr(feature="ureq", doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md")))]
3#![cfg_attr(
4    not(feature = "ureq"),
5    doc = "Documentation is meant to be compiled with default features (at least ureq)"
6)]
7use std::io::Write;
8use std::path::PathBuf;
9
/// The actual Api to interact with the hub.
///
/// Only compiled when the `tokio` or the `ureq` feature is enabled.
#[cfg(any(feature = "tokio", feature = "ureq"))]
pub mod api;

/// Name of the environment variable read by [`Cache::from_env`] to locate the
/// Hugging Face home directory.
const HF_HOME: &str = "HF_HOME";
15
/// The type of repo to interact with
///
/// The type is a plain, field-less tag: it is cheap to copy, can be compared
/// for equality and can be used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RepoType {
    /// This is a model, usually it consists of weight files and some configuration
    /// files
    Model,
    /// This is a dataset, usually contains data within parquet files
    Dataset,
    /// This is a space, usually a demo showcasing a given model or dataset
    Space,
}
27
/// A local struct used to fetch information from the cache folder.
#[derive(Clone, Debug)]
pub struct Cache {
    /// Root directory of the cache; each repo folder is resolved directly under it.
    path: PathBuf,
}
33
34impl Cache {
35    /// Creates a new cache object location
36    pub fn new(path: PathBuf) -> Self {
37        Self { path }
38    }
39
40    /// Creates cache from environment variable HF_HOME (if defined) otherwise
41    /// defaults to [`home_dir`]/.cache/huggingface/
42    pub fn from_env() -> Self {
43        match std::env::var(HF_HOME) {
44            Ok(home) => {
45                let mut path: PathBuf = home.into();
46                path.push("hub");
47                Self::new(path)
48            }
49            Err(_) => Self::default(),
50        }
51    }
52
53    /// Creates a new cache object location
54    pub fn path(&self) -> &PathBuf {
55        &self.path
56    }
57
58    /// Returns the location of the token file
59    pub fn token_path(&self) -> PathBuf {
60        let mut path = self.path.clone();
61        // Remove `"hub"`
62        path.pop();
63        path.push("token");
64        path
65    }
66
67    /// Returns the token value if it exists in the cache
68    /// Use `huggingface-cli login` to set it up.
69    pub fn token(&self) -> Option<String> {
70        let token_filename = self.token_path();
71        if token_filename.exists() {
72            log::info!("Using token file found {token_filename:?}");
73        }
74        match std::fs::read_to_string(token_filename) {
75            Ok(token_content) => {
76                let token_content = token_content.trim();
77                if token_content.is_empty() {
78                    None
79                } else {
80                    Some(token_content.to_string())
81                }
82            }
83            Err(_) => None,
84        }
85    }
86
87    /// Creates a new handle [`CacheRepo`] which contains operations
88    /// on a particular [`Repo`]
89    pub fn repo(&self, repo: Repo) -> CacheRepo {
90        CacheRepo::new(self.clone(), repo)
91    }
92
93    /// Simple wrapper over
94    /// ```
95    /// # use hf_hub::{Cache, Repo, RepoType};
96    /// # let model_id = "gpt2".to_string();
97    /// let cache = Cache::new("/tmp/".into());
98    /// let cache = cache.repo(Repo::new(model_id, RepoType::Model));
99    /// ```
100    pub fn model(&self, model_id: String) -> CacheRepo {
101        self.repo(Repo::new(model_id, RepoType::Model))
102    }
103
104    /// Simple wrapper over
105    /// ```
106    /// # use hf_hub::{Cache, Repo, RepoType};
107    /// # let model_id = "gpt2".to_string();
108    /// let cache = Cache::new("/tmp/".into());
109    /// let cache = cache.repo(Repo::new(model_id, RepoType::Dataset));
110    /// ```
111    pub fn dataset(&self, model_id: String) -> CacheRepo {
112        self.repo(Repo::new(model_id, RepoType::Dataset))
113    }
114
115    /// Simple wrapper over
116    /// ```
117    /// # use hf_hub::{Cache, Repo, RepoType};
118    /// # let model_id = "gpt2".to_string();
119    /// let cache = Cache::new("/tmp/".into());
120    /// let cache = cache.repo(Repo::new(model_id, RepoType::Space));
121    /// ```
122    pub fn space(&self, model_id: String) -> CacheRepo {
123        self.repo(Repo::new(model_id, RepoType::Space))
124    }
125}
126
/// Shorthand for accessing things within a particular repo
#[derive(Debug)]
pub struct CacheRepo {
    /// Cache whose root directory contains the folder of `repo`.
    cache: Cache,
    /// Repo whose folder name and revision are used to build the paths.
    repo: Repo,
}
133
134impl CacheRepo {
135    fn new(cache: Cache, repo: Repo) -> Self {
136        Self { cache, repo }
137    }
138
139    /// This will get the location of the file within the cache for the remote
140    /// `filename`. Will return `None` if file is not already present in cache.
141    pub fn get(&self, filename: &str) -> Option<PathBuf> {
142        let commit_path = self.ref_path();
143        let commit_hash = std::fs::read_to_string(commit_path).ok()?;
144        let mut pointer_path = self.pointer_path(&commit_hash);
145        pointer_path.push(filename);
146        if pointer_path.exists() {
147            Some(pointer_path)
148        } else {
149            None
150        }
151    }
152
153    fn path(&self) -> PathBuf {
154        let mut ref_path = self.cache.path.clone();
155        ref_path.push(self.repo.folder_name());
156        ref_path
157    }
158
159    fn ref_path(&self) -> PathBuf {
160        let mut ref_path = self.path();
161        ref_path.push("refs");
162        ref_path.push(self.repo.revision());
163        ref_path
164    }
165
166    /// Creates a reference in the cache directory that points branches to the correct
167    /// commits within the blobs.
168    pub fn create_ref(&self, commit_hash: &str) -> Result<(), std::io::Error> {
169        let ref_path = self.ref_path();
170        // Needs to be done like this because revision might contain `/` creating subfolders here.
171        std::fs::create_dir_all(ref_path.parent().unwrap())?;
172        let mut file = std::fs::OpenOptions::new()
173            .write(true)
174            .create(true)
175            .truncate(true)
176            .open(&ref_path)?;
177        file.write_all(commit_hash.trim().as_bytes())?;
178        Ok(())
179    }
180
181    /// Get the path of the blob with the given etag.
182    #[cfg(any(feature = "tokio", feature = "ureq"))]
183    pub fn blob_path(&self, etag: &str) -> PathBuf {
184        let mut blob_path = self.path();
185        blob_path.push("blobs");
186        blob_path.push(etag);
187        blob_path
188    }
189
190
191    /// Get the path of the snapshot with the given commit hash.
192    ///
193    /// This path contains symlink pointers to the files for this commit.
194    pub fn pointer_path(&self, commit_hash: &str) -> PathBuf {
195        let mut pointer_path = self.path();
196        pointer_path.push("snapshots");
197        pointer_path.push(commit_hash);
198        pointer_path
199    }
200}
201
202impl Default for Cache {
203    fn default() -> Self {
204        let mut path = dirs::home_dir().expect("Cache directory cannot be found");
205        path.push(".cache");
206        path.push("huggingface");
207        path.push("hub");
208        Self::new(path)
209    }
210}
211
/// The representation of a repo on the hub.
#[derive(Clone, Debug)]
pub struct Repo {
    /// Identifier of the repo; any `/` in it becomes `--` in the cache folder name.
    repo_id: String,
    /// Kind of repo, selects the `models` / `datasets` / `spaces` prefix.
    repo_type: RepoType,
    /// Revision of the repo, `"main"` when built through [`Repo::new`].
    revision: String,
}
219
220impl Repo {
221    /// Repo with the default branch ("main").
222    pub fn new(repo_id: String, repo_type: RepoType) -> Self {
223        Self::with_revision(repo_id, repo_type, "main".to_string())
224    }
225
226    /// fully qualified Repo
227    pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
228        Self {
229            repo_id,
230            repo_type,
231            revision,
232        }
233    }
234
235    /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
236    pub fn model(repo_id: String) -> Self {
237        Self::new(repo_id, RepoType::Model)
238    }
239
240    /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
241    pub fn dataset(repo_id: String) -> Self {
242        Self::new(repo_id, RepoType::Dataset)
243    }
244
245    /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
246    pub fn space(repo_id: String) -> Self {
247        Self::new(repo_id, RepoType::Space)
248    }
249
250    /// The normalized folder nameof the repo within the cache directory
251    pub fn folder_name(&self) -> String {
252        let prefix = match self.repo_type {
253            RepoType::Model => "models",
254            RepoType::Dataset => "datasets",
255            RepoType::Space => "spaces",
256        };
257        format!("{prefix}--{}", self.repo_id).replace('/', "--")
258    }
259
260    /// The revision
261    pub fn revision(&self) -> &str {
262        &self.revision
263    }
264
265    /// The actual URL part of the repo
266    #[cfg(any(feature = "tokio", feature = "ureq"))]
267    pub fn url(&self) -> String {
268        match self.repo_type {
269            RepoType::Model => self.repo_id.to_string(),
270            RepoType::Dataset => {
271                format!("datasets/{}", self.repo_id)
272            }
273            RepoType::Space => {
274                format!("spaces/{}", self.repo_id)
275            }
276        }
277    }
278
279    /// Revision needs to be url escaped before being used in a URL
280    #[cfg(any(feature = "tokio", feature = "ureq"))]
281    pub fn url_revision(&self) -> String {
282        self.revision.replace('/', "%2F")
283    }
284
285    /// Used to compute the repo's url part when accessing the metadata of the repo
286    #[cfg(any(feature = "tokio", feature = "ureq"))]
287    pub fn api_url(&self) -> String {
288        let prefix = match self.repo_type {
289            RepoType::Model => "models",
290            RepoType::Dataset => "datasets",
291            RepoType::Space => "spaces",
292        };
293        format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Internal macro used to show cleaner errors
    /// on the payloads received from the hub.
    ///
    /// Both operands are rendered as pretty JSON. When the renderings differ
    /// they are written to two files in the system temporary directory and
    /// the test panics with the output of `diff -U5` on those files.
    #[macro_export]
    macro_rules! assert_no_diff {
        ($left: expr, $right: expr) => {
            let left = serde_json::to_string_pretty(&$left).unwrap();
            let right = serde_json::to_string_pretty(&$right).unwrap();
            if left != right {
                use rand::Rng;
                use std::io::Write;
                use std::process::Command;
                // Random suffix so concurrently failing tests do not share files.
                let rand_string: String = rand::rng()
                    .sample_iter(&rand::distr::Alphanumeric)
                    .take(6)
                    .map(char::from)
                    .collect();
                let tmp_dir = std::env::temp_dir();
                let left_filename = tmp_dir.join(format!("left-{rand_string}.txt"));
                let mut file = std::fs::File::create(&left_filename).unwrap();
                file.write_all(left.as_bytes()).unwrap();
                let right_filename = tmp_dir.join(format!("right-{rand_string}.txt"));
                let mut file = std::fs::File::create(&right_filename).unwrap();
                file.write_all(right.as_bytes()).unwrap();
                let output = Command::new("diff")
                    // Reverse order seems to be more appropriate for how we set up the tests.
                    .arg("-U5")
                    .arg(&right_filename)
                    .arg(&left_filename)
                    .output()
                    .expect("Failed to diff")
                    .stdout;
                let diff = String::from_utf8(output).expect("Invalid utf-8 diff output");
                panic!("{diff}")
            };
        };
    }

    /// The token file must sit directly in the Hugging Face home directory.
    ///
    /// Paths are compared component-wise, so the same test covers every
    /// platform's separator and tolerates a trailing separator in `HF_HOME`.
    #[test]
    fn token_path() {
        let cache = Cache::from_env();
        let token_path = cache.token_path();
        match std::env::var(HF_HOME) {
            Ok(hf_home) => {
                assert_eq!(token_path, PathBuf::from(hf_home).join("token"));
            }
            Err(_) => {
                let expected_tail = PathBuf::from("huggingface").join("token");
                assert!(
                    token_path.ends_with(&expected_tail),
                    "{token_path:?} does not end with {expected_tail:?}"
                );
            }
        }
    }
}