candle_hf_hub/
lib.rs

1#![deny(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
3#[cfg(any(feature = "tokio", feature = "ureq"))]
4use rand::{distributions::Alphanumeric, Rng};
5use std::io::Write;
6use std::path::PathBuf;
7
/// The actual API to interact with the hub.
///
/// Only compiled when the `tokio` or the `ureq` feature is enabled.
#[cfg(any(feature = "tokio", feature = "ureq"))]
pub mod api;
11
/// The type of repo to interact with.
///
/// Implements `PartialEq`/`Eq`/`Hash` so repo types can be compared and used
/// as keys in maps or sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RepoType {
    /// This is a model, usually it consists of weight files and some configuration
    /// files.
    Model,
    /// This is a dataset, usually contains data within parquet files.
    Dataset,
    /// This is a space, usually a demo showcasing a given model or dataset.
    Space,
}
23
/// Local handle on the cache folder.
///
/// All cache lookups (repo folders, the token file, temporary files) are
/// resolved relative to the stored `path`.
#[derive(Clone, Debug)]
pub struct Cache {
    path: PathBuf,
}
29
30impl Cache {
31    /// Creates a new cache object location
32    pub fn new(path: PathBuf) -> Self {
33        Self { path }
34    }
35
36    /// Creates a new cache object location
37    pub fn path(&self) -> &PathBuf {
38        &self.path
39    }
40
41    /// Returns the location of the token file
42    pub fn token_path(&self) -> PathBuf {
43        let mut path = self.path.clone();
44        // Remove `"hub"`
45        path.pop();
46        path.push("token");
47        path
48    }
49
50    /// Returns the token value if it exists in the cache
51    /// Use `huggingface-cli login` to set it up.
52    pub fn token(&self) -> Option<String> {
53        let token_filename = self.token_path();
54        if !token_filename.exists() {
55            log::info!("Token file not found {token_filename:?}");
56        }
57        match std::fs::read_to_string(token_filename) {
58            Ok(token_content) => {
59                let token_content = token_content.trim();
60                if token_content.is_empty() {
61                    None
62                } else {
63                    Some(token_content.to_string())
64                }
65            }
66            Err(_) => None,
67        }
68    }
69
70    /// Creates a new handle [`CacheRepo`] which contains operations
71    /// on a particular [`Repo`]
72    pub fn repo(&self, repo: Repo) -> CacheRepo {
73        CacheRepo::new(self.clone(), repo)
74    }
75
76    /// Simple wrapper over
77    /// ```
78    /// # use hf_hub::{Cache, Repo, RepoType};
79    /// # let model_id = "gpt2".to_string();
80    /// let cache = Cache::new("/tmp/".into());
81    /// let cache = cache.repo(Repo::new(model_id, RepoType::Model));
82    /// ```
83    pub fn model(&self, model_id: String) -> CacheRepo {
84        self.repo(Repo::new(model_id, RepoType::Model))
85    }
86
87    /// Simple wrapper over
88    /// ```
89    /// # use hf_hub::{Cache, Repo, RepoType};
90    /// # let model_id = "gpt2".to_string();
91    /// let cache = Cache::new("/tmp/".into());
92    /// let cache = cache.repo(Repo::new(model_id, RepoType::Dataset));
93    /// ```
94    pub fn dataset(&self, model_id: String) -> CacheRepo {
95        self.repo(Repo::new(model_id, RepoType::Dataset))
96    }
97
98    /// Simple wrapper over
99    /// ```
100    /// # use hf_hub::{Cache, Repo, RepoType};
101    /// # let model_id = "gpt2".to_string();
102    /// let cache = Cache::new("/tmp/".into());
103    /// let cache = cache.repo(Repo::new(model_id, RepoType::Space));
104    /// ```
105    pub fn space(&self, model_id: String) -> CacheRepo {
106        self.repo(Repo::new(model_id, RepoType::Space))
107    }
108
109    #[cfg(any(feature = "tokio", feature = "ureq"))]
110    pub(crate) fn temp_path(&self) -> PathBuf {
111        let mut path = self.path().clone();
112        path.push("tmp");
113        std::fs::create_dir_all(&path).ok();
114
115        let s: String = rand::thread_rng()
116            .sample_iter(&Alphanumeric)
117            .take(7)
118            .map(char::from)
119            .collect();
120        path.push(s);
121        path.to_path_buf()
122    }
123}
124
125/// Shorthand for accessing things within a particular repo
126#[derive(Debug)]
127pub struct CacheRepo {
128    cache: Cache,
129    repo: Repo,
130}
131
132impl CacheRepo {
133    fn new(cache: Cache, repo: Repo) -> Self {
134        Self { cache, repo }
135    }
136    /// This will get the location of the file within the cache for the remote
137    /// `filename`. Will return `None` if file is not already present in cache.
138    pub fn get(&self, filename: &str) -> Option<PathBuf> {
139        let commit_path = self.ref_path();
140        let commit_hash = std::fs::read_to_string(commit_path).ok()?;
141        let mut pointer_path = self.pointer_path(&commit_hash);
142        pointer_path.push(filename);
143        if pointer_path.exists() {
144            Some(pointer_path)
145        } else {
146            None
147        }
148    }
149
150    fn path(&self) -> PathBuf {
151        let mut ref_path = self.cache.path.clone();
152        ref_path.push(self.repo.folder_name());
153        ref_path
154    }
155
156    fn ref_path(&self) -> PathBuf {
157        let mut ref_path = self.path();
158        ref_path.push("refs");
159        ref_path.push(self.repo.revision());
160        ref_path
161    }
162
163    /// Creates a reference in the cache directory that points branches to the correct
164    /// commits within the blobs.
165    pub fn create_ref(&self, commit_hash: &str) -> Result<(), std::io::Error> {
166        let ref_path = self.ref_path();
167        // Needs to be done like this because revision might contain `/` creating subfolders here.
168        std::fs::create_dir_all(ref_path.parent().unwrap())?;
169        let mut file = std::fs::OpenOptions::new()
170            .write(true)
171            .create(true)
172            .truncate(true)
173            .open(&ref_path)?;
174        file.write_all(commit_hash.trim().as_bytes())?;
175        Ok(())
176    }
177
178    #[cfg(any(feature = "tokio", feature = "ureq"))]
179    pub(crate) fn blob_path(&self, etag: &str) -> PathBuf {
180        let mut blob_path = self.path();
181        blob_path.push("blobs");
182        blob_path.push(etag);
183        blob_path
184    }
185
186    pub(crate) fn pointer_path(&self, commit_hash: &str) -> PathBuf {
187        let mut pointer_path = self.path();
188        pointer_path.push("snapshots");
189        pointer_path.push(commit_hash);
190        pointer_path
191    }
192}
193
194impl Default for Cache {
195    fn default() -> Self {
196        let mut path = match std::env::var("HF_HOME") {
197            Ok(home) => home.into(),
198            Err(_) => {
199                let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
200                cache.push(".cache");
201                cache.push("huggingface");
202                cache
203            }
204        };
205        path.push("hub");
206        Self::new(path)
207    }
208}
209
210/// The representation of a repo on the hub.
211#[derive(Clone, Debug)]
212pub struct Repo {
213    repo_id: String,
214    repo_type: RepoType,
215    revision: String,
216}
217
218impl Repo {
219    /// Repo with the default branch ("main").
220    pub fn new(repo_id: String, repo_type: RepoType) -> Self {
221        Self::with_revision(repo_id, repo_type, "main".to_string())
222    }
223
224    /// fully qualified Repo
225    pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
226        Self {
227            repo_id,
228            repo_type,
229            revision,
230        }
231    }
232
233    /// Shortcut for [`Repo::new`] with [`RepoType::Model`]
234    pub fn model(repo_id: String) -> Self {
235        Self::new(repo_id, RepoType::Model)
236    }
237
238    /// Shortcut for [`Repo::new`] with [`RepoType::Dataset`]
239    pub fn dataset(repo_id: String) -> Self {
240        Self::new(repo_id, RepoType::Dataset)
241    }
242
243    /// Shortcut for [`Repo::new`] with [`RepoType::Space`]
244    pub fn space(repo_id: String) -> Self {
245        Self::new(repo_id, RepoType::Space)
246    }
247
248    /// The normalized folder nameof the repo within the cache directory
249    pub fn folder_name(&self) -> String {
250        let prefix = match self.repo_type {
251            RepoType::Model => "models",
252            RepoType::Dataset => "datasets",
253            RepoType::Space => "spaces",
254        };
255        format!("{prefix}--{}", self.repo_id).replace('/', "--")
256    }
257
258    /// The revision
259    pub fn revision(&self) -> &str {
260        &self.revision
261    }
262
263    /// The actual URL part of the repo
264    #[cfg(any(feature = "tokio", feature = "ureq"))]
265    pub fn url(&self) -> String {
266        match self.repo_type {
267            RepoType::Model => self.repo_id.to_string(),
268            RepoType::Dataset => {
269                format!("datasets/{}", self.repo_id)
270            }
271            RepoType::Space => {
272                format!("spaces/{}", self.repo_id)
273            }
274        }
275    }
276
277    /// Revision needs to be url escaped before being used in a URL
278    #[cfg(any(feature = "tokio", feature = "ureq"))]
279    pub fn url_revision(&self) -> String {
280        self.revision.replace('/', "%2F")
281    }
282
283    /// Used to compute the repo's url part when accessing the metadata of the repo
284    #[cfg(any(feature = "tokio", feature = "ureq"))]
285    pub fn api_url(&self) -> String {
286        let prefix = match self.repo_type {
287            RepoType::Model => "models",
288            RepoType::Dataset => "datasets",
289            RepoType::Space => "spaces",
290        };
291        format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
292    }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// The token file must sit next to the `hub` folder: `$HF_HOME/token` when
    /// `HF_HOME` is set, `.../huggingface/token` otherwise.
    ///
    /// Paths are compared component-wise, so one test covers both `/` and `\`
    /// separators and tolerates a trailing separator in `HF_HOME`.
    #[test]
    fn token_path() {
        let cache = Cache::default();
        let token_path = cache.token_path();
        match std::env::var_os("HF_HOME") {
            Some(hf_home) => assert_eq!(token_path, Path::new(&hf_home).join("token")),
            None => assert!(
                token_path.ends_with("huggingface/token"),
                "unexpected token path {token_path:?}"
            ),
        }
    }
}