1#![deny(missing_docs)]
2#![doc = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/README.md"))]
3#[cfg(any(feature = "tokio", feature = "ureq"))]
4use rand::{distributions::Alphanumeric, Rng};
5use std::io::Write;
6use std::path::PathBuf;
7
/// API module (presumably the network client for the hub; its contents live in
/// a separate file — confirm there).
///
/// Only compiled when the `tokio` or `ureq` feature is enabled.
#[cfg(any(feature = "tokio", feature = "ureq"))]
pub mod api;
11
/// The kind of repository hosted on the hub.
///
/// The kind selects the plural prefix used both for the on-disk cache folder
/// (`models--…`, `datasets--…`, `spaces--…`) and for hub URL paths.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RepoType {
    /// A model repository (prefix `models`).
    Model,
    /// A dataset repository (prefix `datasets`).
    Dataset,
    /// A space repository (prefix `spaces`).
    Space,
}
23
/// Handle to a local hub cache directory.
///
/// The wrapped path is the cache root; each repository lives in a sub-folder
/// of it, and the token file sits next to it (in the root's parent directory).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Cache {
    /// Root directory of the cache.
    path: PathBuf,
}
29
30impl Cache {
31 pub fn new(path: PathBuf) -> Self {
33 Self { path }
34 }
35
36 pub fn path(&self) -> &PathBuf {
38 &self.path
39 }
40
41 pub fn token_path(&self) -> PathBuf {
43 let mut path = self.path.clone();
44 path.pop();
46 path.push("token");
47 path
48 }
49
50 pub fn token(&self) -> Option<String> {
53 let token_filename = self.token_path();
54 if !token_filename.exists() {
55 log::info!("Token file not found {token_filename:?}");
56 }
57 match std::fs::read_to_string(token_filename) {
58 Ok(token_content) => {
59 let token_content = token_content.trim();
60 if token_content.is_empty() {
61 None
62 } else {
63 Some(token_content.to_string())
64 }
65 }
66 Err(_) => None,
67 }
68 }
69
70 pub fn repo(&self, repo: Repo) -> CacheRepo {
73 CacheRepo::new(self.clone(), repo)
74 }
75
76 pub fn model(&self, model_id: String) -> CacheRepo {
84 self.repo(Repo::new(model_id, RepoType::Model))
85 }
86
87 pub fn dataset(&self, model_id: String) -> CacheRepo {
95 self.repo(Repo::new(model_id, RepoType::Dataset))
96 }
97
98 pub fn space(&self, model_id: String) -> CacheRepo {
106 self.repo(Repo::new(model_id, RepoType::Space))
107 }
108
109 #[cfg(any(feature = "tokio", feature = "ureq"))]
110 pub(crate) fn temp_path(&self) -> PathBuf {
111 let mut path = self.path().clone();
112 path.push("tmp");
113 std::fs::create_dir_all(&path).ok();
114
115 let s: String = rand::thread_rng()
116 .sample_iter(&Alphanumeric)
117 .take(7)
118 .map(char::from)
119 .collect();
120 path.push(s);
121 path.to_path_buf()
122 }
123}
124
125#[derive(Debug)]
127pub struct CacheRepo {
128 cache: Cache,
129 repo: Repo,
130}
131
132impl CacheRepo {
133 fn new(cache: Cache, repo: Repo) -> Self {
134 Self { cache, repo }
135 }
136 pub fn get(&self, filename: &str) -> Option<PathBuf> {
139 let commit_path = self.ref_path();
140 let commit_hash = std::fs::read_to_string(commit_path).ok()?;
141 let mut pointer_path = self.pointer_path(&commit_hash);
142 pointer_path.push(filename);
143 if pointer_path.exists() {
144 Some(pointer_path)
145 } else {
146 None
147 }
148 }
149
150 fn path(&self) -> PathBuf {
151 let mut ref_path = self.cache.path.clone();
152 ref_path.push(self.repo.folder_name());
153 ref_path
154 }
155
156 fn ref_path(&self) -> PathBuf {
157 let mut ref_path = self.path();
158 ref_path.push("refs");
159 ref_path.push(self.repo.revision());
160 ref_path
161 }
162
163 pub fn create_ref(&self, commit_hash: &str) -> Result<(), std::io::Error> {
166 let ref_path = self.ref_path();
167 std::fs::create_dir_all(ref_path.parent().unwrap())?;
169 let mut file = std::fs::OpenOptions::new()
170 .write(true)
171 .create(true)
172 .truncate(true)
173 .open(&ref_path)?;
174 file.write_all(commit_hash.trim().as_bytes())?;
175 Ok(())
176 }
177
178 #[cfg(any(feature = "tokio", feature = "ureq"))]
179 pub(crate) fn blob_path(&self, etag: &str) -> PathBuf {
180 let mut blob_path = self.path();
181 blob_path.push("blobs");
182 blob_path.push(etag);
183 blob_path
184 }
185
186 pub(crate) fn pointer_path(&self, commit_hash: &str) -> PathBuf {
187 let mut pointer_path = self.path();
188 pointer_path.push("snapshots");
189 pointer_path.push(commit_hash);
190 pointer_path
191 }
192}
193
194impl Default for Cache {
195 fn default() -> Self {
196 let mut path = match std::env::var("HF_HOME") {
197 Ok(home) => home.into(),
198 Err(_) => {
199 let mut cache = dirs::home_dir().expect("Cache directory cannot be found");
200 cache.push(".cache");
201 cache.push("huggingface");
202 cache
203 }
204 };
205 path.push("hub");
206 Self::new(path)
207 }
208}
209
210#[derive(Clone, Debug)]
212pub struct Repo {
213 repo_id: String,
214 repo_type: RepoType,
215 revision: String,
216}
217
218impl Repo {
219 pub fn new(repo_id: String, repo_type: RepoType) -> Self {
221 Self::with_revision(repo_id, repo_type, "main".to_string())
222 }
223
224 pub fn with_revision(repo_id: String, repo_type: RepoType, revision: String) -> Self {
226 Self {
227 repo_id,
228 repo_type,
229 revision,
230 }
231 }
232
233 pub fn model(repo_id: String) -> Self {
235 Self::new(repo_id, RepoType::Model)
236 }
237
238 pub fn dataset(repo_id: String) -> Self {
240 Self::new(repo_id, RepoType::Dataset)
241 }
242
243 pub fn space(repo_id: String) -> Self {
245 Self::new(repo_id, RepoType::Space)
246 }
247
248 pub fn folder_name(&self) -> String {
250 let prefix = match self.repo_type {
251 RepoType::Model => "models",
252 RepoType::Dataset => "datasets",
253 RepoType::Space => "spaces",
254 };
255 format!("{prefix}--{}", self.repo_id).replace('/', "--")
256 }
257
258 pub fn revision(&self) -> &str {
260 &self.revision
261 }
262
263 #[cfg(any(feature = "tokio", feature = "ureq"))]
265 pub fn url(&self) -> String {
266 match self.repo_type {
267 RepoType::Model => self.repo_id.to_string(),
268 RepoType::Dataset => {
269 format!("datasets/{}", self.repo_id)
270 }
271 RepoType::Space => {
272 format!("spaces/{}", self.repo_id)
273 }
274 }
275 }
276
277 #[cfg(any(feature = "tokio", feature = "ureq"))]
279 pub fn url_revision(&self) -> String {
280 self.revision.replace('/', "%2F")
281 }
282
283 #[cfg(any(feature = "tokio", feature = "ureq"))]
285 pub fn api_url(&self) -> String {
286 let prefix = match self.repo_type {
287 RepoType::Model => "models",
288 RepoType::Dataset => "datasets",
289 RepoType::Space => "spaces",
290 };
291 format!("{prefix}/{}/revision/{}", self.repo_id, self.url_revision())
292 }
293}
294
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// The token file sits next to the `hub` folder: `$HF_HOME/token` when
    /// `HF_HOME` is set, otherwise `…/huggingface/token`.
    ///
    /// Paths are compared component-wise rather than as formatted strings.
    /// That makes one test work on every platform's separator and tolerates a
    /// trailing separator in `HF_HOME`.
    #[test]
    fn token_path() {
        let token_path = Cache::default().token_path();
        match std::env::var_os("HF_HOME") {
            Some(hf_home) => assert_eq!(token_path, Path::new(&hf_home).join("token")),
            None => assert!(
                token_path.ends_with(Path::new("huggingface").join("token")),
                "unexpected token path: {token_path:?}"
            ),
        }
    }

    #[test]
    fn folder_name() {
        assert_eq!(
            Repo::model("owner/name".to_string()).folder_name(),
            "models--owner--name"
        );
        assert_eq!(
            Repo::dataset("owner/name".to_string()).folder_name(),
            "datasets--owner--name"
        );
        assert_eq!(Repo::space("name".to_string()).folder_name(), "spaces--name");
    }

    #[test]
    fn revision_defaults_to_main() {
        assert_eq!(Repo::model("a".to_string()).revision(), "main");
        let repo = Repo::with_revision("a".to_string(), RepoType::Model, "refs/pr/1".to_string());
        assert_eq!(repo.revision(), "refs/pr/1");
    }

    #[test]
    #[cfg(any(feature = "tokio", feature = "ureq"))]
    fn url_revision_escapes_slashes() {
        let repo = Repo::with_revision("a/b".to_string(), RepoType::Model, "refs/pr/1".to_string());
        assert_eq!(repo.url_revision(), "refs%2Fpr%2F1");
        assert_eq!(repo.api_url(), "models/a/b/revision/refs%2Fpr%2F1");
    }

    /// A ref written with `create_ref` resolves files placed in the matching
    /// snapshot folder.
    #[test]
    fn create_ref_then_get() {
        let root = std::env::temp_dir().join(format!("hf-hub-cache-test-{}", std::process::id()));
        let cache = Cache::new(root.clone());
        let repo = cache.model("owner/name".to_string());

        repo.create_ref("abc123\n").unwrap();
        assert!(repo.get("config.json").is_none());

        let snapshot = repo.pointer_path("abc123");
        std::fs::create_dir_all(&snapshot).unwrap();
        std::fs::write(snapshot.join("config.json"), b"{}").unwrap();
        assert_eq!(repo.get("config.json"), Some(snapshot.join("config.json")));

        std::fs::remove_dir_all(&root).ok();
    }
}