1use hf_hub::{Repo, RepoType};
2use httpdate::parse_http_date;
3use kalosm_model_types::{FileLoadingProgress, FileSource};
4use reqwest::{
5 header::{HeaderValue, CONTENT_LENGTH, LAST_MODIFIED, RANGE},
6 IntoUrl,
7};
8use reqwest::{Response, StatusCode};
9use std::path::PathBuf;
10use std::str::FromStr;
11use tokio::fs::{File, OpenOptions};
12use tokio::io::AsyncWriteExt;
13
14#[derive(Debug, thiserror::Error)]
15pub enum CacheError {
16 #[error("Hugging Face API error: {0}")]
17 HuggingFaceApi(#[from] hf_hub::api::sync::ApiError),
18 #[error("Unable to get file metadata for {0}: {1}")]
19 UnableToGetFileMetadata(PathBuf, #[source] tokio::io::Error),
20 #[error("IO error: {0}")]
21 Io(#[from] std::io::Error),
22 #[error("HTTP error: {0}")]
23 Http(#[from] reqwest::Error),
24 #[error("Unexpected status code: {0}")]
25 UnexpectedStatusCode(StatusCode),
26}
27
/// An on-disk cache for files fetched from remote sources.
#[derive(Debug, Clone)]
pub struct Cache {
    /// Root directory under which downloaded files are stored.
    location: PathBuf,
    /// Token sent as a bearer credential with Hugging Face requests, if set.
    huggingface_token: Option<String>,
}
34
35impl Cache {
36 pub fn new(location: PathBuf) -> Self {
38 Self {
39 location,
40 huggingface_token: None,
41 }
42 }
43
44 pub fn with_huggingface_token(mut self, token: Option<String>) -> Self {
46 self.huggingface_token = token;
47 self
48 }
49
50 pub fn exists(&self, source: &FileSource) -> bool {
52 match source {
53 FileSource::HuggingFace {
54 model_id,
55 revision,
56 file,
57 ..
58 } => {
59 let path = self.location.join(model_id).join(revision);
60 let complete_download = path.join(file);
61 complete_download.exists()
62 }
63 FileSource::Local(path) => path.exists(),
64 }
65 }
66
67 pub async fn get(
69 &self,
70 source: &FileSource,
71 progress: impl FnMut(FileLoadingProgress),
72 ) -> Result<PathBuf, CacheError> {
73 match source {
74 FileSource::HuggingFace {
75 model_id,
76 revision,
77 file,
78 } => {
79 let token = self.huggingface_token.clone().or_else(huggingface_token);
80
81 let path = self.location.join(model_id).join(revision);
82 let complete_download = path.join(file);
83
84 let repo = Repo::with_revision(
85 model_id.to_string(),
86 RepoType::Model,
87 revision.to_string(),
88 );
89 let api = hf_hub::api::sync::Api::new()?.repo(repo);
90 let url = api.url(file);
91 let client = reqwest::Client::new();
92 tracing::trace!("Fetching metadata for {file} from {url}");
93 let response = client
94 .head(&url)
95 .with_authorization_header(token.clone())
96 .send()
97 .await;
98
99 if complete_download.exists() {
100 let metadata = tokio::fs::metadata(&complete_download).await.map_err(|e| {
101 CacheError::UnableToGetFileMetadata(complete_download.clone(), e)
102 })?;
103 let file_last_modified = metadata.modified()?;
104 if let Some(last_updated) = response
106 .as_ref()
107 .ok()
108 .and_then(|response| response.headers().get(LAST_MODIFIED))
109 .and_then(|last_updated| last_updated.to_str().ok())
110 .and_then(|s| parse_http_date(s).ok())
111 {
112 if last_updated <= file_last_modified {
113 return Ok(complete_download);
114 }
115 } else {
116 return Ok(complete_download);
118 }
119 }
120 let incomplete_download = path.join(format!("{}.partial", file));
121
122 tracing::trace!("Downloading into {:?}", incomplete_download);
123
124 download_into(
125 url,
126 &incomplete_download,
127 response?,
128 client,
129 token,
130 progress,
131 )
132 .await?;
133
134 tokio::fs::rename(&incomplete_download, &complete_download).await?;
136
137 Ok(complete_download)
138 }
139 FileSource::Local(path) => Ok(path.clone()),
140 }
141 }
142}
143
144impl Default for Cache {
145 fn default() -> Self {
146 Self {
147 location: dirs::data_dir().unwrap().join("kalosm").join("cache"),
148 huggingface_token: None,
149 }
150 }
151}
152
153async fn download_into<U: IntoUrl>(
154 url: U,
155 file: &PathBuf,
156 head: Response,
157 client: reqwest::Client,
158 token: Option<String>,
159 mut progress: impl FnMut(FileLoadingProgress),
160) -> Result<(), CacheError> {
161 let length = head
162 .headers()
163 .get(CONTENT_LENGTH)
164 .ok_or("response doesn't include the content length")
165 .unwrap();
166 let length = length.to_str().ok().and_then(|s| u64::from_str(s).ok());
167
168 let (start, mut output_file) = if let Ok(metadata) = tokio::fs::metadata(file).await {
169 let start = metadata.len();
170 let output_file = OpenOptions::new().append(true).open(file).await.unwrap();
171 (start, output_file)
172 } else {
173 tokio::fs::create_dir_all(file.parent().unwrap()).await?;
174 (0, File::create(file).await.unwrap())
175 };
176
177 if let Some(length) = length {
178 progress(FileLoadingProgress {
179 progress: start,
180 cached_size: start,
181 size: length,
182 start_time: std::time::Instant::now(),
183 });
184 }
185
186 if Some(start) == length {
187 tracing::trace!("File {} already downloaded", file.display());
188 progress(FileLoadingProgress {
189 progress: start,
190 cached_size: start,
191 size: length.unwrap_or(0),
192 start_time: std::time::Instant::now(),
193 });
194 return Ok(());
195 }
196
197 let range = length
198 .and_then(|length| HeaderValue::from_str(&format!("bytes={}-{}", start, length - 1)).ok());
199
200 tracing::trace!("Fetching range {:?}", range);
201 let mut request = client.get(url).with_authorization_header(token);
202 if let Some(range) = range {
203 request = request.header(RANGE, range);
204 }
205 let mut response = request.send().await?;
206
207 let status = response.status();
208 if !(status == StatusCode::OK || status == StatusCode::PARTIAL_CONTENT) {
209 return Err(CacheError::UnexpectedStatusCode(status));
210 }
211
212 let mut current_progress = start;
213
214 while let Some(chunk) = response.chunk().await? {
215 output_file.write_all(&chunk).await?;
216 tracing::trace!("wrote chunk of size {}", chunk.len());
217 current_progress += chunk.len() as u64;
218 if let Some(length) = length {
219 progress(FileLoadingProgress {
220 progress: current_progress,
221 cached_size: start,
222 size: length,
223 start_time: std::time::Instant::now(),
224 });
225 }
226 }
227
228 tracing::trace!("Download of {} complete", file.display());
229
230 Ok(())
231}
232
233trait RequestBuilderExt {
234 fn with_authorization_header(self, token: Option<String>) -> Self;
235}
236
237impl RequestBuilderExt for reqwest::RequestBuilder {
238 fn with_authorization_header(self, token: Option<String>) -> Self {
239 if let Some(token) = token {
240 self.header(reqwest::header::AUTHORIZATION, format!("Bearer {token}"))
241 } else {
242 self
243 }
244 }
245}
246
/// Downloads a small payload from httpbin.org and checks that the target file
/// is created. Requires network access.
#[cfg(test)]
#[tokio::test]
async fn downloads_work() {
    let url = "https://httpbin.org/range/102400?duration=2";
    // Use a per-process path in the temp directory rather than the working
    // directory so concurrent or aborted runs do not interfere with each other.
    let file = std::env::temp_dir().join(format!(
        "kalosm-download-test-{}.bin",
        std::process::id()
    ));
    // A leftover file would be treated as a partial download; start clean.
    let _ = tokio::fs::remove_file(&file).await;

    let progress = |p| {
        println!("Progress: {:?}", p);
    };
    let client = reqwest::Client::new();
    let response = client.head(url).send().await.unwrap();
    let result = download_into(url, &file, response, client, None, progress).await;

    // Clean up before asserting so a failure does not leave the file behind.
    let downloaded = file.exists();
    let _ = tokio::fs::remove_file(&file).await;

    result.unwrap();
    assert!(downloaded);
}
263
264fn huggingface_token() -> Option<String> {
265 let cache = hf_hub::Cache::default();
266 cache.token().or_else(|| std::env::var("HF_TOKEN").ok())
267}